整理一下在keras中使用T5模型的要点,尤其是中文场景下的使用要点。以多国语言版mT5为例。
博客链接:https://kexue.fm/archives/7867
本项目实验环境:tensorflow 1.14 + keras 2.3.1 + bert4keras 0.9.1
首先,要想办法下载Google开放的权重,最简单的方式,是找一台能科学上网的服务器,在上面安装gsutil,然后执行
gsutil cp -r gs://t5-data/pretrained_models/mt5/small .
T5使用sentencepiece作为tokenizer,mT5的tokenizer模型下载地址为
gsutil cp -r gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model .
笔者精简好的tokenizer文件:sentencepiece_cn.model和sentencepiece_cn_keep_tokens.json
另外,为了方便国内用户,笔者将small版和base版整理分享到百度网盘(mwfc)了。
T5模型的配置文件是gin格式的,这不符合bert4keras的输入,使用者请根据所给的gin和下述模版构建对应的config.json文件。
下面是mT5 small版的参考config.json:
{
"hidden_dropout_prob": 0.1,
"hidden_size": 512,
"initializer_range": 0.02,
"intermediate_size": 1024,
"num_attention_heads": 6,
"attention_head_size": 64,
"num_hidden_layers": 8,
"vocab_size": 250112,
"hidden_act": ["gelu", "linear"]
}
一般要修改的是hidden_size
、intermediate_size
、num_attention_heads
、attention_head_size
和num_hidden_layers
这几个参数。
# 模型路径
config_path = '/root/kg/bert/mt5/mt5_small/t5_config.json'
checkpoint_path = '/root/kg/bert/mt5/mt5_small/model.ckpt-1000000'
spm_path = '/root/kg/bert/mt5/sentencepiece.model'
# 加载分词器
tokenizer = SpTokenizer(spm_path, token_start=None, token_end='</s>')
# 加载模型
t5 = build_transformer_model(
config_path=config_path,
checkpoint_path=checkpoint_path,
model='t5.1.1',
return_keras_model=False,
name='T5',
)
encoder = t5.encoder
decoder = t5.decoder
model = t5.model
# 模型路径
config_path = '/root/kg/bert/mt5/mt5_base/t5_config.json'
checkpoint_path = '/root/kg/bert/mt5/mt5_base/model.ckpt-1000000'
spm_path = '/root/kg/bert/mt5/sentencepiece_cn.model'
keep_tokens_path = '/root/kg/bert/mt5/sentencepiece_cn_keep_tokens.json'
# 加载分词器
tokenizer = SpTokenizer(spm_path, token_start=None, token_end='</s>')
keep_tokens = json.load(open(keep_tokens_path))
# 加载模型
t5 = build_transformer_model(
config_path=config_path,
checkpoint_path=checkpoint_path,
keep_tokens=keep_tokens,
model='t5.1.1',
return_keras_model=False,
name='T5',
)
encoder = t5.encoder
decoder = t5.decoder
model = t5.model
细节请参考:task_autotitle_csl.py。
QQ交流群:67729435,微信群请加机器人微信号spaces_ac_cn