-
Notifications
You must be signed in to change notification settings - Fork 57
/
laia-reuse-model
executable file
·59 lines (47 loc) · 1.67 KB
/
laia-reuse-model
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#!/usr/bin/env th
require 'laia'
-- Command-line interface of the tool: one option (--seed) followed by three
-- positional arguments (input_file, output_size, output_file).

-- Help texts are kept in locals so the parser declarations below stay short.
local tool_description =
  'Loads an existing model, replaces the last fully connected ' ..
  'layer with a new one with the given number of output symbols and saves ' ..
  'the model to disk.'
local output_size_help =
  'Number of output symbols. If you are going to use the CTC ' ..
  'loss include one additional element! If 0 the last fully connected layer is kept as given.'

local parser = laia.argparse(){
  name = 'laia-reuse-model',
  description = tool_description
}

-- Options
parser:option('--seed -s', 'Seed for random numbers generation.',
              0x012345, laia.toint)

-- Positional arguments, declared in the order they are expected
parser:argument('input_file', 'File containing the model to reuse.')
parser:argument('output_size', output_size_help)
  :convert(laia.toint)
  :ge(0)  -- presumably rejects negative sizes; 0 means "keep the last layer"
parser:argument('output_file', 'Output file to store the model')

-- Let laia.Version and laia.log register their own options on the parser
laia.Version():registerOptions(parser)
laia.log.registerOptions(parser)

-- Parsed command-line values, read by the rest of the script
local opt = parser:parse()
-- Load *BEST* model from the input checkpoint.
local model = laia.Checkpoint():load(opt.input_file):Best():getModel()
assert(model ~= nil, 'No model was found in input file!')

-- Initialize random seeds before any new layer is created, presumably so the
-- new layer's initial weights are reproducible for a given --seed.
laia.manualSeed(opt.seed)

-- Replace the last layer only when a new output size was requested.
-- With output_size == 0 the model is kept exactly as loaded, so the last
-- layer is not inspected at all in that case (it may legitimately be a
-- module without parameters).
if opt.output_size > 0 then
  -- The input size of the new layer is taken from the second dimension of
  -- the first parameter tensor (the weight matrix) of the current last layer.
  local last_layer = model:get(model:size())
  local last_params = last_layer and last_layer:parameters()
  assert(last_params and last_params[1] and last_params[1]:dim() == 2,
         'The last layer of the model has no 2D weight matrix, so the ' ..
         'input size of the new layer cannot be determined!')
  local rnn_units = last_params[1]:size(2)

  model:remove(model:size())
  model:add(nn.Linear(rnn_units, opt.output_size))
  model:float()
end
-- Store the model in a fresh checkpoint, registered as both its "Best" and
-- its "Last" entry, and write that checkpoint to the requested output file.
local out_checkpoint = laia.Checkpoint()
out_checkpoint:Best():setModel(model)
out_checkpoint:Last():setModel(model)
out_checkpoint:save(opt.output_file)

-- Report what was written. The parameters are flattened only after saving,
-- so the saved checkpoint is produced before getParameters() is called.
local flat_params = model:getParameters()
local num_params = flat_params:nElement()
laia.log.info('\n' .. model:__tostring__())
laia.log.info('Saved model with %d parameters to %q',
              num_params, opt.output_file)