Tensor2Tensor
1. Models are easy to modify and generalize well.
2. Well suited to GPU/TPU acceleration and distributed computation.
Applications:
translation
image classification
speech recognition
All of the following are command-line commands (the leading ! is the Jupyter shell-escape prefix).
# standard install
!pip install tensor2tensor
# GPU install
!pip install tensor2tensor[tensorflow_gpu]
# Mind the version compatibility between tensor2tensor and tensorflow;
# this generally requires a tensorflow release before 1.6.0.
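A quick way to sanity-check which versions actually got installed (the exact compatible pair depends on your tensor2tensor release):
!pip show tensor2tensor tensorflow | grep -E "Name|Version"
!python -c "import tensorflow as tf; print(tf.__version__)"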
# Generate the data
!t2t-datagen \
  --data_dir=/t2t_data \
  --tmp_dir=/tmp/t2t_datagen \
  --problem=translate_enzh_wmt32k
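After generation it is worth inspecting the output directory; expect a vocabulary file plus sharded TFRecord files (exact filenames vary by problem and release):
!ls /t2t_data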
# Train
!t2t-trainer \
  --data_dir=/t2t_data \
  --problem=translate_enzh_wmt32k \
  --model=transformer \
  --hparams_set=transformer_base_single_gpu \
  --output_dir=/t2t_train
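For a quick smoke test you can cap the number of steps with t2t-trainer's --train_steps flag, and monitor progress with TensorBoard (adjust paths to your setup):
!t2t-trainer \
  --data_dir=/t2t_data \
  --problem=translate_enzh_wmt32k \
  --model=transformer \
  --hparams_set=transformer_base_single_gpu \
  --output_dir=/t2t_train \
  --train_steps=1000
!tensorboard --logdir=/t2t_train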
# Decode
# Note: in a notebook each ! line spawns its own shell, so the plain VAR=...
# assignments below will not persist across lines; run these in a terminal
# (without the !) or set them with IPython's %env magic instead.
!DECODE_FILE=/t2t_data/decode_this.txt
!echo "Hello world" >> $DECODE_FILE
!echo "Goodbye world" >> $DECODE_FILE
# The reference file is only needed if you score with t2t-bleu afterwards. These
# German lines come from the official En-De walkthrough; for translate_enzh_wmt32k
# you would supply Chinese references instead.
!echo -e 'Hallo Welt\nAuf Wiedersehen Welt' > ref-translation.de
!BEAM_SIZE=4
!ALPHA=0.6
!t2t-decoder \
  --data_dir=/t2t_data \
  --problem=translate_enzh_wmt32k \
  --model=transformer \
  --hparams_set=transformer_base_single_gpu \
  --output_dir=/t2t_train \
  --decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA" \
  --decode_from_file=$DECODE_FILE \
  --decode_to_file=translation.en
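To score the decoded output, tensor2tensor ships a t2t-bleu tool; a minimal invocation, taken from the upstream En-De walkthrough (assuming translation and reference files in matching languages for your pair):
!t2t-bleu --translation=translation.en --reference=ref-translation.de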
# Custom problem
from tensor2tensor.utils import registry
from tensor2tensor.data_generators import problem, text_problems

# A custom problem must be registered with this decorator so the t2t tools can
# look it up by name (MyProblem registers as "my_problem"). The @property
# decorators below turn methods into read-only attributes.
@registry.register_problem
class MyProblem(text_problems.Text2TextProblem):
  @property
  def approx_vocab_size(self):
    return 100

  @property
  def is_generate_per_split(self):
    return True

  @property
  def dataset_splits(self):
    return [{
        "split": problem.DatasetSplit.TRAIN,
        "shards": 5,
    }, {
        "split": problem.DatasetSplit.EVAL,
        "shards": 5,
    }]

  def generate_samples(self, data_dir, tmp_dir, dataset_split):
    del data_dir
    del tmp_dir
    del dataset_split
    # Read the raw training samples.
    with open("./rawdata/questions.txt", "r") as f:
      comment_list = f.readlines()
    with open("./rawdata/answers.txt", "r") as f:
      tag_list = f.readlines()
    for comment, tag in zip(comment_list, tag_list):
      yield {
          "inputs": comment.strip(),
          "targets": tag.strip(),
      }
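To use the registered problem, point the t2t tools at the directory containing this module via --t2t_usr_dir (the directory needs an __init__.py that imports the module); the registry snake-cases MyProblem to my_problem. A sketch, assuming the module lives under the hypothetical directory ./usr_problem:
!t2t-datagen \
  --t2t_usr_dir=./usr_problem \
  --data_dir=/t2t_data \
  --tmp_dir=/tmp/t2t_datagen \
  --problem=my_problem
The same --t2t_usr_dir flag works for t2t-trainer and t2t-decoder, so the whole pipeline above can be rerun against the custom problem.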