使用方法请参考https://github.com/pass-lin/bert4keras3
这是一个基于keras的版本的模型
支持tensorflow,jax,pytorch多后端调用
未来还能支持苹果的mlx
理论上numpy也能调用推理
from transformers import AutoTokenizer
from bert4keras3.models import build_transformer_model
import numpy as np
import os
os.environ["KERAS_BACKEND"] = "jax"#选择后端,jax是最优后端
os.environ["FLASH_ATTN"]='0'#是否开启flash attention,这个需要自己去安装
#jax使用flash参考https://github.com/nshepperd/flash_attn_jax/releases这里安装flash
os.environ["ENABLE_LORA"] = "0"#1就是开启lora的
tokenizer = AutoTokenizer.from_pretrained(dict_path)#用hf加载tokenizer
LLAMA = build_transformer_model(
config_path,#config文件
keras_weights_path=weights_path,#weights.h5文件
model='llama',
with_lm=True,
return_keras_model=False,
)
model = LLAMA.model#训练的模型
generate_model=LLAMA.build_cache_model([max_len],end_token=end_token,#end token是你对应模型的end token
progress_print=True,search_mode='topp',k=0.7)#推理模型
inputs = [start_token]+tokenizer.encode('hello world')#start token同理
inputs = np.reshape(inputs,[1,-1])
generate_model.predict(inputs)
model.predict(inputs)
评论