-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.lua
68 lines (56 loc) · 1.35 KB
/
main.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
--
-- User: peyman
-- Date: 11/25/16
-- Time: 11:42 PM
-- to run: th main.lua
require 'torch'
require 'paths'
require 'xlua'
require 'optim'
-- Global project namespace table; this file only records a version tag in it.
virusClassifier = {}
virusClassifier.version = 1
torch.setdefaulttensortype('torch.FloatTensor')
----------------------------------------------------------------------
-- Command-line options. Deliberately global (no `local`): the project
-- scripts loaded below via paths.dofile presumably read `opt` directly
-- (TODO confirm against data/model/criteria/utils).
opt = dofile('opts.lua').parse(arg)
if opt.gpu then
   require 'cutorch'
   require 'cunn'
end
print('Saving everything to: ' .. opt.save)
-- Single-quote the directory for the POSIX shell (embedded ' becomes '\'')
-- so a path containing spaces or shell metacharacters reaches `mkdir -p`
-- as one literal argument instead of being split or interpreted.
os.execute("mkdir -p '" .. (tostring(opt.save):gsub("'", "'\\''")) .. "'")
-- Project scripts, run into the global environment. Together with the
-- train/validate/test scripts loaded later they presumably define the
-- globals used by the run section (e.g. `net`, `saveWs`) — confirm.
paths.dofile('data/data.lua')
paths.dofile('criteria.lua')
paths.dofile('model.lua')
paths.dofile('utils.lua')
-- Either evaluate (opt.testMode) or train for up to opt.nEpochs epochs,
-- stopping early when validate() asks to, then save the final network.
if opt.testMode then
   paths.dofile('test.lua')
   test()
else
   paths.dofile('train.lua')
   paths.dofile('validate.lua')
   -- `epoch` stays global on purpose: train()/validate() presumably read it
   -- (TODO confirm). It advances only when validation does not request an
   -- early stop, exactly as before, so it ends one past the last epoch when
   -- the loop runs to completion.
   epoch = 1
   -- Number of the last epoch that actually ran (0 if opt.nEpochs < 1);
   -- this is what the save message reports.
   local lastEpoch = 0
   -- opt.nEpochs: number of total epochs to run.
   for i = 1, opt.nEpochs do
      train()
      lastEpoch = i
      -- A truthy result from validate() requests early stopping.
      if validate() then
         break
      end
      epoch = epoch + 1
   end
   -- save
   print("\n\nSaving model at epoch stage: " .. lastEpoch)
   local filename = paths.concat(opt.save, opt.network)
   -- Use `paths` (required above) rather than `sys`, which this file never
   -- requires. Quote the directory so spaces/metacharacters are literal.
   local modelDir = paths.dirname(filename) or '.'
   os.execute("mkdir -p '" .. (modelDir:gsub("'", "'\\''")) .. "'")
   print('saving final model to ' .. filename)
   -- `bestModel` is a global set elsewhere (presumably during validation);
   -- when present it replaces the current model before saving.
   if bestModel ~= nil then
      print('found best model')
      net.model = bestModel
   end
   torch.save(filename, net)
   saveWs(net)
end
print('done')