
Commit fe9eee3

Add files via upload
first commit
1 parent aac84ea commit fe9eee3

13 files changed: +31193 -0 lines changed

README.md

+319
Large diffs are not rendered by default.

config/cpm-medium.json

+33
@@ -0,0 +1,33 @@
{
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 1,
  "embd_pdrop": 0.1,
  "eos_token_id": 2,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 1024,
  "n_head": 16,
  "n_layer": 24,
  "n_positions": 1024,
  "n_special": 0,
  "predict_special_tokens": true,
  "resid_pdrop": 0.1,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "vocab_size": 30000
}
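
This configuration file follows the standard Hugging Face GPT-2 format. As a quick sanity check, a config can be loaded and used to build a freshly initialized model; the sketch below assumes the transformers library (which this commit depends on but does not include) is installed.

from transformers import GPT2Config, GPT2LMHeadModel

# Load the medium config shipped in this commit and build an untrained model from it.
config = GPT2Config.from_json_file("config/cpm-medium.json")
model = GPT2LMHeadModel(config)
print(config.n_layer, config.n_embd, config.vocab_size)  # 24 1024 30000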

config/cpm-small.json

+31
@@ -0,0 +1,31 @@
{
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_layer": 12,
  "n_positions": 1024,
  "resid_pdrop": 0.1,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "vocab_size": 30000
}

data_parallel.py

+109
@@ -0,0 +1,109 @@
from torch.nn.parallel import DataParallel
import torch
from torch.nn.parallel._functions import Scatter
from torch.nn.parallel.parallel_apply import parallel_apply


def scatter(inputs, target_gpus, chunk_sizes, dim=0):
    r"""
    Slices tensors into approximately equal chunks and
    distributes them across given GPUs. Duplicates
    references to objects that are not tensors.
    """
    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            try:
                return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
            except Exception:
                print('obj', obj.size())
                print('dim', dim)
                print('chunk_sizes', chunk_sizes)
                quit()
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return list(map(list, zip(*map(scatter_map, obj))))
        if isinstance(obj, dict) and len(obj) > 0:
            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        return [obj for targets in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None


def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):
    r"""Scatter with support for kwargs dictionary"""
    inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []
    kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []
    if len(inputs) < len(kwargs):
        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
    elif len(kwargs) < len(inputs):
        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs


class BalancedDataParallel(DataParallel):
    """DataParallel variant that gives GPU 0 a smaller (possibly zero) share of each batch,
    controlled by gpu0_bsz."""

    def __init__(self, gpu0_bsz, *args, **kwargs):
        self.gpu0_bsz = gpu0_bsz
        super().__init__(*args, **kwargs)

    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)
        if self.gpu0_bsz == 0:
            device_ids = self.device_ids[1:]
        else:
            device_ids = self.device_ids
        inputs, kwargs = self.scatter(inputs, kwargs, device_ids)

        # print('len(inputs): ', str(len(inputs)))
        # print('self.device_ids[:len(inputs)]', str(self.device_ids[:len(inputs)]))

        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        if self.gpu0_bsz == 0:
            replicas = self.replicate(self.module, self.device_ids)
        else:
            replicas = self.replicate(self.module, self.device_ids[:len(inputs)])

        # replicas = self.replicate(self.module, device_ids[:len(inputs)])
        if self.gpu0_bsz == 0:
            replicas = replicas[1:]

        # print('replicas:', str(len(replicas)))

        outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def parallel_apply(self, replicas, device_ids, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs, device_ids[:len(inputs)])

    def scatter(self, inputs, kwargs, device_ids):
        bsz = inputs[0].size(self.dim)
        num_dev = len(self.device_ids)
        gpu0_bsz = self.gpu0_bsz
        bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
        if gpu0_bsz < bsz_unit:
            chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
            delta = bsz - sum(chunk_sizes)
            for i in range(delta):
                chunk_sizes[i + 1] += 1
            if gpu0_bsz == 0:
                chunk_sizes = chunk_sizes[1:]
        else:
            return super().scatter(inputs, kwargs, device_ids)

        # print('bsz: ', bsz)
        # print('num_dev: ', num_dev)
        # print('gpu0_bsz: ', gpu0_bsz)
        # print('bsz_unit: ', bsz_unit)
        # print('chunk_sizes: ', chunk_sizes)
        return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)
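
BalancedDataParallel subclasses torch.nn.DataParallel so that GPU 0, which also holds the gathered outputs, can receive a smaller slice of each batch (gpu0_bsz), or none at all when gpu0_bsz == 0. A minimal usage sketch follows; it assumes at least two visible CUDA devices, and the toy module and batch size are illustrative rather than taken from this commit.

import torch
import torch.nn as nn
from data_parallel import BalancedDataParallel

model = nn.Linear(16, 4).cuda()
# Place 2 samples per step on GPU 0 and split the remaining samples across the other GPUs.
model = BalancedDataParallel(2, model, dim=0)

x = torch.randn(16, 16).cuda()
out = model(x)      # scattered with the custom chunk sizes, then gathered on GPU 0
print(out.shape)    # torch.Size([16, 4])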

dataset.py

+21
@@ -0,0 +1,21 @@
from torch.utils.data import Dataset
import torch


class CPMDataset(Dataset):
    """
    Wraps a list of tokenized samples; each item is truncated to max_len
    and returned as a LongTensor of token ids.
    """

    def __init__(self, input_list, max_len):
        self.input_list = input_list
        self.max_len = max_len

    def __getitem__(self, index):
        input_ids = self.input_list[index]
        input_ids = input_ids[:self.max_len]
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        return input_ids

    def __len__(self):
        return len(self.input_list)
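
CPMDataset returns variable-length LongTensors, so batching needs a collate function that pads them to a common length. A minimal sketch follows; the padding value and the toy token-id lists are assumptions for illustration, not values taken from this commit.

from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from dataset import CPMDataset

def collate_fn(batch):
    # Pad every sequence in the batch to the length of the longest one.
    return pad_sequence(batch, batch_first=True, padding_value=0)

input_list = [[10, 11, 12], [20, 21], [30, 31, 32, 33]]  # toy token-id sequences
dataset = CPMDataset(input_list, max_len=200)
loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

for input_ids in loader:
    print(input_ids.shape)  # (batch_size, longest sequence in the batch)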

generate.py

+124
@@ -0,0 +1,124 @@
import torch
import torch.nn.functional as F
import os
import argparse
from tqdm import trange
from transformers import GPT2LMHeadModel, GPT2Config, CpmTokenizer
from utils import top_k_top_p_filtering, set_logger
from os.path import join, exists


def generate_next_token(input_ids):
    """
    Generate the next token for the given context.
    """
    outputs = model(input_ids=input_ids)
    logits = outputs.logits
    # next_token_logits are the prediction scores for the last token's hidden state,
    # i.e. the distribution over the next token to be generated
    next_token_logits = logits[0, -1, :]
    next_token_logits = next_token_logits / args.temperature
    # Set the logit of <unk> to negative infinity so the model can never predict the [UNK] token
    next_token_logits[unk_id] = -float('Inf')
    filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=args.topk, top_p=args.topp)
    # torch.multinomial draws num_samples elements from the candidate set without replacement;
    # tokens with higher weights are more likely to be drawn. It returns the indices of the samples.
    next_token_id = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
    return next_token_id


def generate(max_len):
    # Tokenize the title and the context
    title_ids = tokenizer.encode(title, add_special_tokens=False)
    context_ids = tokenizer.encode(context, add_special_tokens=False)
    input_ids = title_ids + [sep_id] + context_ids
    cur_len = len(input_ids)
    last_token_id = input_ids[-1]  # last token of the content generated so far
    input_ids = torch.tensor([input_ids], dtype=torch.long, device=device)

    while True:
        next_token_id = generate_next_token(input_ids)
        input_ids = torch.cat((input_ids, next_token_id.unsqueeze(0)), dim=1)
        cur_len += 1
        word = tokenizer.convert_ids_to_tokens(next_token_id.item())
        # if cur_len >= max_len:
        #     break
        # Maximum length exceeded and a line break was generated
        if cur_len >= max_len and last_token_id == 8 and next_token_id == 3:
            break
        # Maximum length exceeded and a punctuation mark was generated
        if cur_len >= max_len and word in [".", "。", "!", "!", "?", "?", ",", ","]:
            break
        # End-of-document token was generated
        if next_token_id == eod_id:
            break
    result = tokenizer.decode(input_ids.squeeze(0))
    return result


if __name__ == '__main__':
    # Argument settings
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='0', type=str, required=False, help='device used for generation')
    parser.add_argument('--temperature', default=1, type=float, required=False, help='sampling temperature')
    parser.add_argument('--topk', default=0, type=int, required=False, help='sample only from the k highest-probability tokens')
    parser.add_argument('--topp', default=0.85, type=float, required=False, help='nucleus (cumulative probability) threshold')
    parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False, help='repetition penalty')
    parser.add_argument('--max_len', default=200, type=int, required=False, help='maximum generation length')
    parser.add_argument('--log_path', default='log/generate.log', type=str, required=False, help='path of the log file')
    parser.add_argument('--no_cuda', action='store_true', help='do not use the GPU for inference')
    parser.add_argument('--model_path', type=str, default='model/zuowen_epoch40', help='path of the model')
    # parser.add_argument('--title', type=str, default='徜徉在书籍的阳光世界', help='essay title')
    # parser.add_argument('--context', type=str, default='一本书是一个人的眼睛,它可以让你看到另一个世界的奇妙', help='essay opening')
    parser.add_argument('--title', type=str, default='家乡的四季', help='essay title')
    parser.add_argument('--context', type=str, default='家乡的四季,最美不过了', help='essay opening')
    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = args.device  # select which GPUs the program may use
    args.cuda = torch.cuda.is_available() and not args.no_cuda  # use the GPU only when requested and available
    device = 'cuda:0' if args.cuda else 'cpu'
    # device = 'cpu'

    # Create the logger
    logger = set_logger(args.log_path)

    # Initialize the tokenizer
    tokenizer = CpmTokenizer(vocab_file="vocab/chinese_vocab.model")
    eod_id = tokenizer.convert_tokens_to_ids("<eod>")  # end-of-document token
    sep_id = tokenizer.sep_token_id
    unk_id = tokenizer.unk_token_id

    # Load the model
    model = GPT2LMHeadModel.from_pretrained(args.model_path)
    model.eval()
    model = model.to(device)

    title = args.title
    context = args.context
    logger.info("title:{}".format(title))
    logger.info("context:{}".format(context))

    # Start generating
    result = generate(args.max_len)
    result = result.split("<sep>")[1]
    logger.info("result:{}\n".format(result))

    # Interactive generation loop via the console
    # print('Generation started, press CTRL + Z to exit')
    # while True:
    #     try:
    #         # The user enters the title and the context
    #         title = input("Please enter the essay title: ")
    #         context = input("Please enter the opening sentence: ")
    #
    #         logger.info("title:{}".format(title))
    #         logger.info("context:{}".format(context))
    #
    #         # Start generating
    #         result = generate(args.max_len)
    #         result = result.split("<sep>")[1]
    #         logger.info("result:{}\n".format(result))
    #         break
    #
    #     except KeyboardInterrupt:
    #         break
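
generate.py imports top_k_top_p_filtering and set_logger from utils, whose contents are not shown above. The sketch below shows the standard top-k / nucleus filtering logic that such a helper typically implements on a 1-D logits vector; it is an illustration of the technique, not this repository's exact code.

import torch
import torch.nn.functional as F

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Keep only the top_k tokens and/or the smallest set of tokens whose
    cumulative probability exceeds top_p; everything else gets filter_value."""
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))
        # Remove every token whose logit is below the k-th largest logit.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Drop tokens once the cumulative probability exceeds top_p,
        # shifting by one position so the first token above the threshold is kept.
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits

With the defaults defined in generate.py, generation can be launched with, for example: python generate.py --model_path model/zuowen_epoch40 --topp 0.85 --max_len 200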
