forked from OpenNMT/OpenNMT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.lua
175 lines (142 loc) · 5.48 KB
/
train.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
-- Load the OpenNMT framework and tds (C-backed data structures usable
-- across threads, needed by the data loaders).
require('onmt.init')
require('tds')

-- Extended command-line parser: supports option groups, validation and
-- config logging (see cmd:logConfig in main()).
local cmd = onmt.utils.ExtendedCmdLine.new('train.lua')

-- First argument define the model type: seq2seq/lm - default is seq2seq.
-- NOTE(review): '-model_type' is peeked at before cmd:parse() so the chosen
-- model class can declare its own options below, before full parsing.
local modelType = cmd.getArgument(arg, '-model_type') or 'seq2seq'
local modelClass = onmt.ModelSelector(modelType)

-- Options declaration.
-- Each entry: flag, default value, help text, optional constraints table.
local options = {
{
'-data', '',
[[Path to the data package `*-train.t7` generated by the preprocessing step.]],
{
-- Required: reject an empty value at parse time.
valid = onmt.utils.ExtendedCmdLine.nonEmpty
}
},
{
'-save_model', '',
[[Model filename (the model will be saved as `<save_model>_epochN_PPL.t7`
where `PPL` is the validation perplexity.]],
{
-- Required: reject an empty value at parse time.
valid = onmt.utils.ExtendedCmdLine.nonEmpty
}
}
}
cmd:setCmdLineOptions(options, 'Data')

-- Let each subsystem register its own options; the call order fixes the
-- grouping/order of sections in the generated help text.
onmt.data.SampledDataset.declareOpts(cmd)
onmt.Model.declareOpts(cmd)
modelClass.declareOpts(cmd)
onmt.train.Optim.declareOpts(cmd)
onmt.train.Trainer.declareOpts(cmd)
onmt.train.Checkpoint.declareOpts(cmd)
onmt.utils.CrayonLogger.declareOpts(cmd)
onmt.utils.Cuda.declareOpts(cmd)
onmt.utils.Logger.declareOpts(cmd)
cmd:text('')
cmd:text('**Other options**')
cmd:text('')
onmt.utils.Memory.declareOpts(cmd)
onmt.utils.Profiler.declareOpts(cmd)
cmd:option('-seed', 3435, [[Random seed.]], {valid=onmt.utils.ExtendedCmdLine.isUInt()})

-- Parse the actual command line into the global-to-this-script `opt` table.
local opt = cmd:parse(arg)
--- Train an OpenNMT model.
-- Reads the preprocessed data package pointed to by `-data`, builds (or
-- resumes from checkpoint) the model selected by `-model_type`, replicates
-- it on each device via Parallel, and runs the Trainer loop.
-- Relies on the script-level upvalues `opt`, `cmd`, `modelType` and
-- `modelClass` declared above.
local function main()
  torch.manualSeed(opt.seed)

  -- Globals (_G) so that worker threads spawned by Parallel can reach them.
  _G.logger = onmt.utils.Logger.new(opt.log_file, opt.disable_logs, opt.log_level)
  _G.profiler = onmt.utils.Profiler.new(false)
  _G.crayon_logger = onmt.utils.CrayonLogger.new(opt)

  onmt.utils.Cuda.init(opt)
  onmt.utils.Parallel.init(opt)

  -- Resuming from a checkpoint may rewrite `opt` and report options whose
  -- values changed since the checkpoint was saved (`paramChanges`).
  local checkpoint, paramChanges
  checkpoint, opt, paramChanges = onmt.train.Checkpoint.loadFromCheckpoint(opt)
  cmd:logConfig(opt)

  _G.logger:info('Training '..modelClass.modelName()..' model')

  -- Create the data loader class.
  _G.logger:info('Loading data from \'' .. opt.data .. '\'...')
  local dataset = torch.load(opt.data, 'binary', false)

  -- Keep backward compatibility: older packages carry no dataType field.
  dataset.dataType = dataset.dataType or 'bitext'

  -- Check if data type matches the model.
  if not modelClass.dataType(dataset.dataType) then
    _G.logger:error('Data type: \'' .. dataset.dataType .. '\' does not match model type: \'' .. modelClass.modelName() .. '\'')
    -- Exit non-zero: this is a fatal configuration error, not a normal
    -- shutdown (previously exited with status 0, hiding the failure from
    -- calling scripts).
    os.exit(1)
  end

  -- record datatype in the options, and preprocessing options if present
  opt.data_type = dataset.dataType
  opt.preprocess = dataset.opt

  -- With `-sample > 0` use importance-sampled batches, otherwise the plain
  -- dataset over the full training corpus.
  local trainData
  if opt.sample > 0 then
    trainData = onmt.data.SampledDataset.new(dataset.train.src, dataset.train.tgt, opt)
  else
    trainData = onmt.data.Dataset.new(dataset.train.src, dataset.train.tgt)
  end
  local validData = onmt.data.Dataset.new(dataset.valid.src, dataset.valid.tgt)

  local nTrainBatch, batchUsage = trainData:setBatchSize(opt.max_batch_size, opt.uneven_batches)
  validData:setBatchSize(opt.max_batch_size, opt.uneven_batches)

  -- Log corpus statistics. Bitext data has separate source/target dicts;
  -- a '*'-prefixed size means pre-encoded inputs without a vocabulary.
  if dataset.dataType ~= 'monotext' then
    local srcVocSize
    local srcFeatSize = '-'
    if dataset.dicts.src then
      srcVocSize = dataset.dicts.src.words:size()
      srcFeatSize = #dataset.dicts.src.features
    else
      srcVocSize = '*'..dataset.dicts.srcInputSize
    end
    local tgtVocSize
    local tgtFeatSize = '-'
    if dataset.dicts.tgt then
      tgtVocSize = dataset.dicts.tgt.words:size()
      tgtFeatSize = #dataset.dicts.tgt.features
    else
      tgtVocSize = '*'..dataset.dicts.tgtInputSize
    end
    _G.logger:info(' * vocabulary size: source = %s; target = %s',
                   srcVocSize, tgtVocSize)
    _G.logger:info(' * additional features: source = %s; target = %s',
                   srcFeatSize, tgtFeatSize)
  else
    _G.logger:info(' * vocabulary size: %d', dataset.dicts.src.words:size())
    _G.logger:info(' * additional features: %d', #dataset.dicts.src.features)
  end
  _G.logger:info(' * maximum sequence length: source = %d; target = %d',
                 trainData.maxSourceLength, trainData.maxTargetLength)
  _G.logger:info(' * number of training sentences: %d', #trainData.src)
  _G.logger:info(' * number of batches: %d', nTrainBatch)
  _G.logger:info('   - source sequence lengths: %s', opt.uneven_batches and 'variable' or 'equal')
  _G.logger:info('   - maximum size: %d', opt.max_batch_size)
  _G.logger:info('   - average size: %.2f', #trainData.src / nTrainBatch)
  -- Round capacity up to one decimal place for display.
  _G.logger:info('   - capacity: %.2f%%', math.ceil(batchUsage * 1000) / 10)

  _G.logger:info('Building model...')

  local model

  -- Build or load model from checkpoint and copy to GPUs.
  -- The launch function runs once per worker; only worker 1's model is
  -- kept in this thread (the gather function below).
  onmt.utils.Parallel.launch(function(idx)
    -- Resolve the class inside the worker: workers have their own _G.
    local _modelClass = onmt.ModelSelector(modelType)
    if checkpoint.models then
      _G.model = _modelClass.load(opt, checkpoint.models, dataset.dicts, idx > 1)
      -- dynamic parameter changes
      if not onmt.utils.Table.empty(paramChanges) then
        _G.model:changeParameters(paramChanges)
      end
    else
      -- Only the first worker logs construction details.
      local verbose = idx == 1
      _G.model = _modelClass.new(opt, dataset.dicts, verbose)
    end
    onmt.utils.Cuda.convert(_G.model)
    return idx, _G.model
  end, function(idx, themodel)
    if idx == 1 then
      model = themodel
    end
  end)

  -- The sampled dataset needs model information to size its sampling.
  if opt.sample > 0 then
    trainData:checkModel(model)
  end

  -- Define optimization method, restoring optimizer state when resuming.
  local optimStates = (checkpoint.info and checkpoint.info.optimStates) or nil
  local optim = onmt.train.Optim.new(opt, optimStates)

  -- Initialize trainer.
  local trainer = onmt.train.Trainer.new(opt)

  -- Launch training.
  trainer:train(model, optim, trainData, validData, dataset, checkpoint.info)

  _G.logger:shutDown()
end
-- Script entry point.
main()