-- trainer.lua
require 'nn'
require 'optim'
require 'sys'
require 'xlua' -- for xlua.progress, used below

-- Load the Facebook Optim package
paths.dofile('Optim.lua')

local trainer = {}
-- This function must be called before any other in the trainer package.
-- It takes a torch network, a criterion and the training options as input.
function trainer.initialize(network, criterion, options)
   local optim_state = {
      learningRate = options.lr,
      momentum = options.mom,
      learningRateDecay = options.lrd,
      weightDecay = options.wd,
   }
   trainer.tensor_type = torch.getdefaulttensortype()
   if not options.no_cuda then
      trainer.tensor_type = 'torch.CudaTensor'
   end
   trainer.batch_size = options.bs
   trainer.network = network
   if criterion then
      trainer.criterion = criterion
      trainer.optimizer = nn.Optim(network, optim_state)
   end
end
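
-- Example initialization (a minimal sketch; the model, criterion and option
-- values below are illustrative assumptions, not defaults of this repo —
-- only the option field names come from the code above):
--
--   local model = nn.Sequential()
--   model:add(nn.SpatialConvolution(3, 16, 5, 5)) -- 3x32x32 -> 16x28x28
--   model:add(nn.ReLU())
--   model:add(nn.Reshape(16 * 28 * 28))
--   model:add(nn.Linear(16 * 28 * 28, 10))
--   model:add(nn.LogSoftMax())
--   local criterion = nn.ClassNLLCriterion()
--   trainer.initialize(model, criterion, {
--      lr  = 0.01,      -- learning rate
--      mom = 0.9,       -- momentum
--      lrd = 1e-5,      -- learning rate decay
--      wd  = 5e-4,      -- weight decay
--      bs  = 32,        -- batch size
--      no_cuda = false, -- false: use CUDA tensors on the GPU
--   })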
-- Main training function.
-- This performs one epoch of training on the network given during
-- initialization using the given dataset.
-- Returns the total error accumulated over the epoch (the per-sample
-- mean is printed).
function trainer.train(dataset)
   if not trainer.optimizer then
      error('Trainer not initialized properly. Use trainer.initialize first.')
   end
   -- do one epoch
   print('<trainer> on training set:')
   local epoch_error = 0
   local nbr_samples = dataset.data:size(1)
   -- assumes square images: the last dimension gives both height and width
   local size_samples = dataset.data:size()[dataset.data:dim()]
   local time = sys.clock()

   -- generate random training batches
   local indices = torch.randperm(nbr_samples):long():split(trainer.batch_size)
   -- drop the last batch, which may be partial, so every batch is full-sized
   indices[#indices] = nil

   -- preallocate input and target tensors
   local inputs = torch.zeros(trainer.batch_size, 3,
                              size_samples, size_samples,
                              trainer.tensor_type)
   local targets = torch.zeros(trainer.batch_size, 1,
                               trainer.tensor_type)

   for t, ind in ipairs(indices) do
      -- get the minibatch
      inputs:copy(dataset.data:index(1, ind))
      targets:copy(dataset.label:index(1, ind))
      epoch_error = epoch_error + trainer.optimizer:optimize(optim.sgd,
                                                             inputs,
                                                             targets,
                                                             trainer.criterion)
      -- display progress
      xlua.progress(t * trainer.batch_size, nbr_samples)
   end
   -- finish progress
   xlua.progress(nbr_samples, nbr_samples)

   -- time taken
   time = sys.clock() - time
   time = time / nbr_samples
   print("<trainer> time to learn 1 sample = " .. (time * 1000) .. 'ms')
   print("<trainer> mean error (train set) = " .. epoch_error / nbr_samples)
   return epoch_error
end
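
-- trainer.train expects the dataset as a table holding a 'data' tensor of
-- shape N x 3 x S x S (square RGB images) and a matching 'label' tensor with
-- one class index per sample. A minimal sketch, with made-up sizes:
--
--   local train_set = {
--      data  = torch.randn(1000, 3, 32, 32),        -- 1000 RGB 32x32 images
--      label = torch.Tensor(1000, 1):random(1, 10), -- class labels in 1..10
--   }
--   local total_error = trainer.train(train_set)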
-- Main testing function.
-- This performs a full test on the given dataset using the network
-- given during the initialization.
-- Returns the total error over the dataset and the classification accuracy.
function trainer.test(dataset)
   if not trainer.network then
      error('Trainer not initialized properly. Use trainer.initialize first.')
   end
   -- test over the given dataset
   print('')
   print('<trainer> on testing set:')
   local time = sys.clock()
   local nbr_samples = dataset.data:size(1)
   local size_samples = dataset.data:size()[dataset.data:dim()]
   local epoch_error = 0
   local correct = 0
   local all = 0

   -- generate indices and split them into batches
   local indices = torch.range(1, nbr_samples):long()
   indices = indices:split(trainer.batch_size)

   -- preallocate input and target tensors
   local inputs = torch.zeros(trainer.batch_size, 3,
                              size_samples, size_samples,
                              trainer.tensor_type)
   local targets = torch.zeros(trainer.batch_size, 1,
                               trainer.tensor_type)

   for t, ind in ipairs(indices) do
      -- the last batch may not be full
      local local_batch_size = ind:size(1)
      -- resize the preallocated tensors (should only happen on the last batch)
      inputs:resize(local_batch_size, 3, size_samples, size_samples)
      targets:resize(local_batch_size, 1)
      inputs:copy(dataset.data:index(1, ind))
      targets:copy(dataset.label:index(1, ind))
      -- test the samples
      local scores = trainer.network:forward(inputs)
      epoch_error = epoch_error + trainer.criterion:forward(scores, targets)
      -- the prediction is the class with the highest score
      local _, preds = scores:max(2)
      correct = correct + preds:float():eq(targets:float()):sum()
      all = all + preds:size(1)
      -- display progress
      xlua.progress(t * trainer.batch_size, nbr_samples)
   end
   -- finish progress
   xlua.progress(nbr_samples, nbr_samples)

   -- timing
   time = sys.clock() - time
   time = time / nbr_samples
   print("<trainer> time to test 1 sample = " .. (time * 1000) .. 'ms')
   print("<trainer> mean error (test set) = " .. epoch_error / nbr_samples)
   local accuracy = correct / all
   print("<trainer> accuracy % (test set) = " .. (accuracy * 100))
   print('')
   return epoch_error, accuracy
end
return trainer
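
-- Example end-to-end loop (a sketch; loading the module via paths.dofile,
-- the epoch count and the variable names are assumptions about the calling
-- script, not part of this file):
--
--   local trainer = paths.dofile('trainer.lua')
--   trainer.initialize(model, criterion, opts)
--   for epoch = 1, 10 do
--      trainer.train(train_set)
--      local _, accuracy = trainer.test(test_set)
--   end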