-
Notifications
You must be signed in to change notification settings - Fork 28
/
opts.lua
133 lines (122 loc) · 6.5 KB
/
opts.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
--
-- Copyright (c) 2016, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found here
-- https://github.com/facebook/fb.resnet.torch/blob/master/LICENSE. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
-- Code modified for Shake-Shake by Xavier Gastaldi
--
--- Command-line option parsing for the Shake-Shake training scripts.
-- The module exposes a single entry point, `M.parse`. It builds a
-- torch.CmdLine, parses an argument vector and post-processes the
-- resulting option table (boolean flags, dataset defaults, tensor type).

-- Per-dataset fallbacks applied when -shortcutType is left as '' or
-- -nEpochs is left as 0. A dataset missing from this table is rejected.
local DATASET_DEFAULTS = {
   imagenet   = { shortcutType = 'B', nEpochs = 90 },
   cifar10    = { shortcutType = 'A', nEpochs = 164 },
   cifar100   = { shortcutType = 'A', nEpochs = 164 },
   imagenet32 = { shortcutType = 'A', nEpochs = 164 },
}

-- Tensor type name chosen for each accepted -precision value.
local TENSOR_TYPES = {
   single = 'torch.CudaTensor',
   double = 'torch.CudaDoubleTensor',
   half   = 'torch.CudaHalfTensor',
}

-- Flags declared with string defaults that parse() turns into Lua booleans.
-- Any value other than the exact string 'false' becomes true.
-- NOTE(review): -LRdec, -forwardShake, -backwardShake and -shakeImage are
-- deliberately left as strings; presumably their consumers compare against
-- 'true' -- confirm before converting them here.
local BOOLEAN_FLAGS = {
   'testOnly', 'tenCrop', 'shareGradInput', 'optnet', 'resetClassifier',
}

local M = { }

-- Registers the help header, every option (in help-output order) and the
-- trailing blank help line on `cmd`.
local function declareOptions(cmd)
   cmd:text()
   cmd:text('Torch-7 ResNet Training script')
   cmd:text('See https://github.com/facebook/fb.resnet.torch/blob/master/TRAINING.md for examples')
   cmd:text()
   cmd:text('Options:')
   ------------ General options --------------------
   cmd:option('-data',        '',         'Path to dataset')
   cmd:option('-dataset',     'imagenet', 'Options: imagenet | cifar10 | cifar100 | imagenet32')
   cmd:option('-manualSeed',  0,          'Manually set RNG seed')
   cmd:option('-nGPU',        1,          'Number of GPUs to use by default')
   cmd:option('-backend',     'cudnn',    'Options: cudnn | cunn')
   cmd:option('-cudnn',       'fastest',  'Options: fastest | default | deterministic')
   cmd:option('-gen',         'gen',      'Path to save generated files')
   cmd:option('-precision',   'single',   'Options: single | double | half')
   cmd:option('-irun',        1,          'Run index: 1 | 2 | 3 | 4 | ...', 'number')
   ------------- Data options ------------------------
   cmd:option('-nThreads',    3,          'number of data loading threads')
   ------------- Training options --------------------
   cmd:option('-algorithmType', 'SGDW',   'Options: SGD | SGDdec | SGDW | ADAM | ADAMdec | ADAMW')
   cmd:option('-LRdec',       'true',     'Whether to decay LR or not')
   cmd:option('-Te',          1,          'Initial period of restarts: 1 | 2 | 3 | 4 | ...', 'number')
   cmd:option('-Tmult',       2,          'Multiplicative factor of Te at every restart: 1 | 2 | 3 | 4 | ...', 'number')
   cmd:option('-nEpochs',     0,          'Number of total epochs to run')
   cmd:option('-epochNumber', 1,          'Manual epoch number (useful on restarts)')
   cmd:option('-batchSize',   128,        'mini-batch size (1 = pure stochastic)')
   cmd:option('-testOnly',    'false',    'Run on validation set only')
   cmd:option('-tenCrop',     'false',    'Ten-crop testing')
   ------------- Checkpointing options ---------------
   cmd:option('-save',        'checkpoints', 'Directory in which to save checkpoints')
   cmd:option('-resume',      'none',     'Resume from the latest checkpoint in this directory')
   ---------- Optimization options ----------------------
   cmd:option('-LR',          0.1,        'initial learning rate')
   cmd:option('-momentum',    0.9,        'momentum')
   cmd:option('-weightDecay', 5e-4,       'weight decay')
   ---------- Model options ----------------------------------
   cmd:option('-netType',     'resnet',   'Options: resnet | preresnet')
   cmd:option('-depth',       34,         'ResNet depth: 18 | 34 | 50 | 101 | ...', 'number')
   cmd:option('-shortcutType', '',        'Options: A | B | C')
   cmd:option('-retrain',     'none',     'Path to model to retrain with')
   cmd:option('-optimState',  'none',     'Path to an optimState to reload from')
   ---------- Memory / fine-tuning options -------------------
   cmd:option('-shareGradInput',  'false', 'Share gradInput tensors to reduce memory usage')
   cmd:option('-optnet',          'false', 'Use optnet to reduce memory usage')
   cmd:option('-resetClassifier', 'false', 'Reset the fully connected layer for fine-tuning')
   cmd:option('-nClasses',        0,       'Number of classes in the dataset')
   ---------- Shake-Shake options ----------------------------
   cmd:option('-widenFactor',   1,           'Widening factor: 1 | 2 | 3 | 4 | ...', 'number')
   cmd:option('-lrShape',       'multistep', 'learning rate annealing function, multistep or cosine')
   cmd:option('-nCycles',       1,           'number of learning rate annealing cycles')
   cmd:option('-forwardShake',  'true',      'Sample random numbers during the forward pass')
   cmd:option('-backwardShake', 'true',      'Sample random numbers during the backward pass')
   cmd:option('-shakeImage',    'true',      'Use a different random number for each image in the mini-batch')
   cmd:text()
end

-- Reports through cmd:error unless `dataDir` and `dataDir`/train both exist.
-- The most common ImageNet mistake is forgetting the -data flag entirely.
local function checkImageNetLayout(cmd, dataDir)
   local trainDir = paths.concat(dataDir, 'train')
   if not paths.dirp(dataDir) then
      cmd:error('error: missing ImageNet data directory')
   elseif not paths.dirp(trainDir) then
      cmd:error('error: ImageNet missing `train` directory: ' .. trainDir)
   end
end

--- Parse training options.
-- @param arg argument vector (array of strings); nil is treated as {}.
-- @return option table. Besides the raw options it has boolean versions of
--   the flags in BOOLEAN_FLAGS, dataset-specific shortcutType/nEpochs
--   fallbacks, and `tensorType` derived from -precision.
-- Invalid combinations are reported through cmd:error.
function M.parse(arg)
   local cmd = torch.CmdLine()
   declareOptions(cmd)
   local opt = cmd:parse(arg or {})

   for _, flag in ipairs(BOOLEAN_FLAGS) do
      opt[flag] = opt[flag] ~= 'false'
   end

   -- Make sure the checkpoint directory exists; mkdir is only tried when
   -- the directory is not already present.
   if not (paths.dirp(opt.save) or paths.mkdir(opt.save)) then
      cmd:error('error: unable to create checkpoint directory: ' .. opt.save .. '\n')
   end

   local defaults = DATASET_DEFAULTS[opt.dataset]
   if defaults == nil then
      cmd:error('unknown dataset: ' .. opt.dataset)
   else
      if opt.dataset == 'imagenet' then
         checkImageNetLayout(cmd, opt.data)
      end
      if opt.shortcutType == '' then
         opt.shortcutType = defaults.shortcutType
      end
      if opt.nEpochs == 0 then
         opt.nEpochs = defaults.nEpochs
      end
   end

   local precision = opt.precision
   if precision == nil then
      precision = 'single'
   end
   local tensorType = TENSOR_TYPES[precision]
   if tensorType then
      opt.tensorType = tensorType
   else
      cmd:error('unknown precision: ' .. opt.precision)
   end

   -- A reset classifier needs to know how many outputs to build.
   if opt.resetClassifier and opt.nClasses == 0 then
      cmd:error('-nClasses required when resetClassifier is set')
   end
   if opt.shareGradInput and opt.optnet then
      cmd:error('error: cannot use both -shareGradInput and -optnet')
   end
   return opt
end

return M