-
Notifications
You must be signed in to change notification settings - Fork 927
/
basic_language_model_gpt2_ml.py
57 lines (44 loc) · 2.95 KB
/
basic_language_model_gpt2_ml.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#! -*- coding: utf-8 -*-
# 基本测试:中文GPT2_ML模型
# 介绍链接:https://kexue.fm/archives/7292
import numpy as np
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.snippets import AutoRegressiveDecoder
from bert4keras.snippets import uniout
config_path = '/root/kg/bert/gpt2_ml/config.json'
checkpoint_path = '/root/kg/bert/gpt2_ml/model.ckpt-100000'
dict_path = '/root/kg/bert/gpt2_ml/vocab.txt'
tokenizer = Tokenizer(
dict_path, token_start=None, token_end=None, do_lower_case=True
) # 建立分词器
model = build_transformer_model(
config_path=config_path, checkpoint_path=checkpoint_path, model='gpt2_ml'
) # 建立模型,加载权重
class ArticleCompletion(AutoRegressiveDecoder):
"""基于随机采样的文章续写
"""
@AutoRegressiveDecoder.wraps(default_rtype='probas')
def predict(self, inputs, output_ids, states):
token_ids = np.concatenate([inputs[0], output_ids], 1)
return self.last_token(model).predict(token_ids)
def generate(self, text, n=1, topp=0.95):
token_ids, _ = tokenizer.encode(text)
results = self.random_sample([token_ids], n, topp=topp) # 基于随机采样
return [text + tokenizer.decode(ids) for ids in results]
article_completion = ArticleCompletion(
start_id=None,
end_id=511, # 511是中文句号
maxlen=256,
minlen=128
)
print(article_completion.generate(u'今天天气不错'))
"""
部分结果:
>>> article_completion.generate(u'今天天气不错')
[u'今天天气不错,可以去跑步。昨晚看了一个关于跑步的纪录片,里面的女主讲述的是一个女孩子的成长,很励志,也很美丽。我也想跑,但是我不知道跑步要穿运动鞋,所以就买了一双运动鞋。这个纪录片是关于运动鞋的,有一 集讲了一个女孩子,从小学开始就没有穿过运动鞋,到了高中才开始尝试跑步。']
>>> article_completion.generate(u'双十一')
[u'双十一马上就要到了!你还在为双11的物流配送而担心吗?你还在为没时间去仓库取货而发愁吗?你还在为不知道怎么买到便宜货而发愁吗?你还在为买不到心仪的产品而懊恼吗?那么,双十一就来了!今天小编带你来看看这些 快递,都是怎么送货的!1. 物流配送快递公司的配送,主要是由快递公司负责,快递公司负责派件,物流服务。']
>>> article_completion.generate(u'科学空间')
[u'科学空间站科学空间站(英文:science space station),是中华人民共和国的一个空间站。该空间站是中国科学院大连物理研究所研制,主要研发和使用中国科学院大连物理研究所的核能动力空间站。科学空间站位于北京市海淀区,距离地面393米,总建筑面积约为1万平方米,总投资约为5亿元人民币。科学空间站于2018年12月26日开始动工,2021年6月建成并投入使用。']
"""