donkeys.lua
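-- donkeys.lua: sets up a pool of worker threads ("donkeys"), each owning a
-- Provider that loads batches in parallel, and runs the per-epoch train()
-- and test() loops on the main thread. Globals such as params, model,
-- criterion, optimMethod, cmTrain, cmTest, dataSizes and epoch are expected
-- to be defined by the calling script.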
Threads = require "threads"
dofile("display.lua")

-- Spawn the pool of worker threads. Each thread loads provider.lua and
-- builds its own Provider so batches can be fetched in parallel.
function newEpoch()
   local threadParams = params -- upvalue serialized into each worker thread
   donkeys = Threads(
      params.nThreads,
      function(idx)
         tid = idx
         dofile("provider.lua")
         params = threadParams
         prov = Provider.new(tid, params.nThreads, params.cv)
         --print(string.format("Initialized thread %d of %d.", tid, params.nThreads))
      end
   )
end
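-- Counter is not defined in this file; it presumably comes from display.lua
-- or another file loaded by the caller. The guarded fallback below is a
-- minimal sketch of the interface reset() relies on (a set keyed by thread
-- id), not the project's actual implementation.
if Counter == nil then
   Counter = {}
   Counter.__index = Counter

   function Counter.new()
      return setmetatable({seen = {}, n = 0}, Counter)
   end

   -- Record a key (thread id), counting each one at most once.
   function Counter:add(key)
      if not self.seen[key] then
         self.seen[key] = true
         self.n = self.n + 1
      end
   end

   function Counter:size()
      return self.n
   end
end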
-- Reset every provider's cursors so the next epoch starts from the first
-- example. Jobs are added until every thread id has reported back; the
-- end-callback runs on the main thread and records which threads answered.
function reset()
   local nThreadsResetted = Counter.new()
   while true do
      donkeys:addjob(
         function()
            prov.trainData.currentIdx = 1
            prov.testData.currentIdx = 1
            prov.trainData.finished = 0
            prov.testData.finished = 0
            return tid
         end,
         function(tid)
            nThreadsResetted:add(tid)
         end
      )
      if nThreadsResetted:size() == params.nThreads then
         break
      end
   end
   donkeys:synchronize() -- drain leftover reset jobs so none fire mid-epoch
end
-- Run one training epoch. Each job fetches a batch in a worker thread; the
-- end-callback does the forward/backward pass and the optimizer step on the
-- main thread.
function train()
   local count = 0
   local nThreadsFinished = 0
   optimState = {
      learningRate = params.lr,
      beta1 = 0.9, -- beta/epsilon fields suggest optimMethod is optim.adam
      beta2 = 0.999,
      epsilon = 1e-8,
      weightDecay = params.weightDecay
   }
   local epochLosses = {}
   reset()
   while true do
      donkeys:addjob(
         function()
            imgPaths, inputs, targets = prov:getBatch("train")
            return tid, prov.trainData.finished, imgPaths, inputs, targets
         end,
         function(tid, finished, imgPaths, inputs, targets)
            if finished <= 1 then
               if finished == 1 then
                  nThreadsFinished = nThreadsFinished + 1
                  print(string.format("Thread %d finished training (total = %d)", tid, nThreadsFinished))
               end
               -- Lazily flatten the model's parameters on the first batch.
               if parameters == nil and model then
                  parameters, gradParameters = model:getParameters()
                  print("Number of parameters ==>", parameters:nElement())
               end
               local outputs
               local dLoss_dO
               local loss
               local function feval(x)
                  if x ~= parameters then parameters:copy(x) end
                  model:training()
                  gradParameters:zero()
                  outputs = model:forward(inputs) -- only one input for training, unlike testing
                  loss = criterion:forward(outputs, targets)
                  dLoss_dO = criterion:backward(outputs, targets)
                  model:backward(inputs, dLoss_dO)
                  return loss, gradParameters
               end
               optimMethod(feval, parameters, optimState)
               epochLosses[#epochLosses + 1] = loss
               count = count + targets:size(1)
               xlua.progress(count, dataSizes.trainCV)
               cmTrain:batchAdd(outputs, targets)
               if torch.uniform() < 0.1 and params.display == 1 then display(inputs, "train") end
            elseif finished == 2 then
               -- This thread's data is already exhausted; nothing left to do.
            end
         end
      )
      if nThreadsFinished == params.nThreads then break end
      donkeys:synchronize()
   end
   donkeys:synchronize() -- wait for the job added just before the break
   local meanLoss = torch.Tensor(epochLosses):mean()
   print(string.format("Finished training epoch %d with %d examples seen. Mean loss = %f.", epoch, count, meanLoss))
   print(cmTrain)
   cmTrain:zero()
end
-- Run one evaluation pass over the test split; no parameter updates. Relies
-- on reset() having been called (by train()) to rewind the test cursors.
function test()
   local count = 0
   local nThreadsFinished = 0
   local epochLosses = {}
   while true do
      donkeys:addjob(
         function()
            imgPaths, inputs, targets = prov:getBatch("test")
            return tid, prov.testData.finished, imgPaths, inputs, targets
         end,
         function(tid, finished, imgPaths, inputs, targets)
            if finished <= 1 then
               if finished == 1 then
                  nThreadsFinished = nThreadsFinished + 1
                  print(string.format("Thread %d finished testing (total = %d)", tid, nThreadsFinished))
               end
               model:evaluate() -- disable dropout/batch-norm updates for evaluation
               local outputs = model:forward(inputs)
               local loss = criterion:forward(outputs, targets)
               count = count + targets:size(1)
               xlua.progress(count, dataSizes.testCV)
               cmTest:batchAdd(outputs, targets)
               epochLosses[#epochLosses + 1] = loss
               if torch.uniform() < 0.1 and params.display == 1 then display(inputs, "test") end
            elseif finished == 2 then
               -- This thread's data is already exhausted; nothing left to do.
            end
         end
      )
      if nThreadsFinished == params.nThreads then break end
      donkeys:synchronize()
   end
   donkeys:synchronize() -- wait for the job added just before the break
   local meanLoss = torch.Tensor(epochLosses):mean()
   print(string.format("Finished test epoch %d with %d examples seen. Mean loss = %f.", epoch, count, meanLoss))
   print(cmTest)
   cmTest:zero()
end
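-- Usage sketch (an assumption, not this repo's actual driver): a calling
-- script is expected to define the globals used above, then drive the
-- epochs roughly like this (params.maxEpochs is a hypothetical field):
--
--   newEpoch()              -- spawn the donkey pool once
--   for i = 1, params.maxEpochs do
--      epoch = i            -- train()/test() read the global epoch
--      train()
--      test()
--   end
--   donkeys:terminate()     -- shut the pool down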