-
Notifications
You must be signed in to change notification settings - Fork 73
/
Copy pathSetup.lua
290 lines (254 loc) · 14 KB
/
Setup.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
require 'logroll'
local _ = require 'moses'
local classic = require 'classic'
local cjson = require 'cjson'
-- Setup class: parses and validates command-line options, then performs global
-- initialisation (experiment directory, logging, Torch and optional GPU state)
local Setup = classic.class('Setup')
-- Performs global setup: parses options, creates the experiment directory,
-- saves the options as JSON, configures logging, validates options, augments
-- the environment class, and configures Torch (plus CUDA if a GPU is selected)
-- @param arg Command-line argument table (forwarded to parseOptions)
function Setup:_init(arg)
  -- Create log10 for Lua 5.2
  if not math.log10 then
    math.log10 = function(x)
      return math.log(x, 10)
    end
  end

  -- Parse command-line options
  self.opt = self:parseOptions(arg)

  -- Create experiment directory (base directory first, then the per-experiment one)
  local expDir = paths.concat(self.opt.experiments, self.opt._id)
  if not paths.dirp(self.opt.experiments) then
    paths.mkdir(self.opt.experiments)
  end
  paths.mkdir(expDir)

  -- Save options for reference (done before opt.Tensor, a function, is added below)
  local file = torch.DiskFile(paths.concat(expDir, 'opts.json'), 'w')
  file:writeString(cjson.encode(self.opt))
  file:close()

  -- Set up logging to log.txt in the experiment directory and via print_logger
  local flog = logroll.file_logger(paths.concat(expDir, 'log.txt'))
  local plog = logroll.print_logger()
  log = logroll.combine(flog, plog) -- Global logger (deliberately global; abortIf reads it)

  -- Validate command-line options (logging errors)
  self:validateOptions()
  -- Augment environments to meet spec
  self:augmentEnv()

  -- Torch setup
  log.info('Setting up Torch7')
  torch.setnumthreads(self.opt.threads) -- Number of BLAS threads
  torch.setdefaulttensortype(self.opt.tensorType) -- Float is more efficient than double
  torch.manualSeed(self.opt.seed)

  -- Tensor creation function for removing need to cast to CUDA if GPU is enabled
  -- TODO: Replace with local functions across codebase
  self.opt.Tensor = function(...)
    return torch.Tensor(...)
  end

  -- GPU setup
  if self.opt.gpu > 0 then
    log.info('Setting up GPU')
    -- parseOptions only *attempts* `require 'cutorch'` (via pcall), so a user
    -- passing -gpu N without CUDA support would otherwise hit an opaque
    -- "attempt to index global 'cutorch'" error on the next line
    if not cutorch then
      local msg = 'GPU ' .. self.opt.gpu .. ' requested but cutorch is not available (use -gpu 0 to disable)'
      log.error(msg)
      error(msg)
    end
    cutorch.setDevice(self.opt.gpu)
    -- Set manual seeds using random numbers to reduce correlations
    cutorch.manualSeed(torch.random())
    -- Replace tensor creation function
    self.opt.Tensor = function(...)
      return torch.CudaTensor(...)
    end
  end

  classic.strict(self)
end
-- Options declared as the strings 'true'/'false' and converted to Lua booleans
-- after parsing (Torch fails to accept false on the command line)
local BOOLEAN_OPTIONS = {
  'cudnn', 'duel', 'recurrent', 'discretiseMem', 'doubleQ', 'reportWeights',
  'fullActions', 'lifeLossTerminal', 'checkpoint', 'verbose', 'record', 'noValidation'
}
-- Boolean/enum options where the empty string means "disabled" ('' -> false)
local EMPTY_MEANS_FALSE_OPTIONS = {'colorSpace', 'memPriority', 'async', 'saliency'}

-- Parses command-line options
-- Declares every option, parses `arg`, normalises boolean/enum values, derives a
-- default experiment ID, and instantiates one environment to record its
-- state/action specs (and display spec, if the environment provides one).
-- @param arg Command-line argument table
-- @return Parsed and post-processed options table
function Setup:parseOptions(arg)
  -- Detect and use GPU 1 by default
  local cuda = pcall(require, 'cutorch')
  local cmd = torch.CmdLine()
  -- Base Torch7 options
  cmd:option('-seed', 1, 'Random seed')
  cmd:option('-threads', 4, 'Number of BLAS or async threads')
  cmd:option('-tensorType', 'torch.FloatTensor', 'Default tensor type')
  cmd:option('-gpu', cuda and 1 or 0, 'GPU device ID (0 to disable)')
  cmd:option('-cudnn', 'false', 'Utilise cuDNN (if available)')
  -- Environment options
  cmd:option('-env', 'rlenvs.Catch', 'Environment class (Lua file to be loaded/rlenv)')
  cmd:option('-zoom', 1, 'Display zoom (requires QT)')
  cmd:option('-game', '', 'Name of Atari ROM (stored in "roms" directory)')
  -- Training vs. evaluate mode
  cmd:option('-mode', 'train', 'Train vs. test mode: train|eval')
  -- State preprocessing options (for visual states)
  cmd:option('-height', 0, 'Resized screen height (0 to disable)')
  cmd:option('-width', 0, 'Resize screen width (0 to disable)')
  cmd:option('-colorSpace', '', 'Colour space conversion (screen is RGB): <none>|y|lab|yuv|hsl|hsv|nrgb')
  -- Model options
  cmd:option('-modelBody', 'models.Catch', 'Path to Torch nn model to be used as DQN "body"')
  cmd:option('-hiddenSize', 512, 'Number of units in the hidden fully connected layer')
  cmd:option('-histLen', 4, 'Number of consecutive states processed/used for backpropagation-through-time') -- DQN standard is 4, DRQN is 10
  cmd:option('-duel', 'true', 'Use dueling network architecture (learns advantage function)')
  cmd:option('-bootstraps', 10, 'Number of bootstrap heads (0 to disable)')
  --cmd:option('-bootstrapMask', 1, 'Independent probability of masking a transition for each bootstrap head ~ Ber(bootstrapMask) (1 to disable)')
  cmd:option('-recurrent', 'false', 'Use recurrent connections')
  -- Experience replay options
  cmd:option('-discretiseMem', 'true', 'Discretise states to integers ∈ [0, 255] for storage')
  cmd:option('-memSize', 1e6, 'Experience replay memory size (number of tuples)')
  cmd:option('-memSampleFreq', 4, 'Interval of steps between sampling from memory to learn')
  cmd:option('-memNSamples', 1, 'Number of times to sample per learning step')
  cmd:option('-memPriority', '', 'Type of prioritised experience replay: <none>|rank|proportional') -- TODO: Implement proportional prioritised experience replay
  cmd:option('-alpha', 0.65, 'Prioritised experience replay exponent α') -- Best vals are rank = 0.7, proportional = 0.6
  cmd:option('-betaZero', 0.45, 'Initial value of importance-sampling exponent β') -- Best vals are rank = 0.5, proportional = 0.4
  -- Reinforcement learning parameters
  cmd:option('-gamma', 0.99, 'Discount rate γ')
  cmd:option('-epsilonStart', 1, 'Initial value of greediness ε')
  cmd:option('-epsilonEnd', 0.01, 'Final value of greediness ε') -- Tuned DDQN final greediness (1/10 that of DQN)
  cmd:option('-epsilonSteps', 1e6, 'Number of steps to linearly decay epsilonStart to epsilonEnd') -- Usually same as memory size
  cmd:option('-tau', 30000, 'Steps between target net updates τ') -- Tuned DDQN target net update interval (3x that of DQN)
  cmd:option('-rewardClip', 1, 'Clips reward magnitude at rewardClip (0 to disable)')
  cmd:option('-tdClip', 1, 'Clips TD-error δ magnitude at tdClip (0 to disable)')
  cmd:option('-doubleQ', 'true', 'Use Double Q-learning')
  -- Note from Georg Ostrovski: The advantage operators and Double DQN are not entirely orthogonal as the increased action gap seems to reduce the statistical bias that leads to value over-estimation in a similar way that Double DQN does
  cmd:option('-PALpha', 0.9, 'Persistent advantage learning parameter α (0 to disable)')
  -- Training options
  cmd:option('-optimiser', 'rmspropm', 'Training algorithm') -- RMSProp with momentum as found in "Generating Sequences With Recurrent Neural Networks"
  cmd:option('-eta', 0.0000625, 'Learning rate η') -- Prioritised experience replay learning rate (1/4 that of DQN; does not account for Duel as well)
  cmd:option('-momentum', 0.95, 'Gradient descent momentum')
  cmd:option('-batchSize', 32, 'Minibatch size')
  cmd:option('-steps', 5e7, 'Training iterations (steps)') -- Frame := step in ALE; Time step := consecutive frames treated atomically by the agent
  cmd:option('-learnStart', 50000, 'Number of steps after which learning starts')
  cmd:option('-gradClip', 10, 'Clips L2 norm of gradients at gradClip (0 to disable)')
  -- Evaluation options
  cmd:option('-progFreq', 10000, 'Interval of steps between reporting progress')
  cmd:option('-reportWeights', 'false', 'Report weight and weight gradient statistics')
  cmd:option('-noValidation', 'false', 'Disable asynchronous agent validation thread') -- TODO: Make behaviour consistent across Master/AsyncMaster
  cmd:option('-valFreq', 250000, 'Interval of steps between validating agent') -- valFreq steps is used as an epoch, hence #epochs = steps/valFreq
  cmd:option('-valSteps', 125000, 'Number of steps to use for validation')
  cmd:option('-valSize', 500, 'Number of transitions to use for calculating validation statistics')
  -- Async options
  cmd:option('-async', '', 'Async agent: <none>|Sarsa|OneStepQ|NStepQ|A3C') -- TODO: Change names
  cmd:option('-rmsEpsilon', 0.1, 'Epsilon for sharedRmsProp')
  cmd:option('-entropyBeta', 0.01, 'Policy entropy regularisation β')
  -- ALEWrap options
  cmd:option('-fullActions', 'false', 'Use full set of 18 actions')
  cmd:option('-actRep', 4, 'Times to repeat action') -- Independent of history length
  cmd:option('-randomStarts', 30, 'Max number of no-op actions played before presenting the start of each training episode')
  cmd:option('-poolFrmsType', 'max', 'Type of pooling over previous emulator frames: max|mean')
  cmd:option('-poolFrmsSize', 2, 'Number of emulator frames to pool over')
  cmd:option('-lifeLossTerminal', 'true', 'Use life loss as terminal signal (training only)')
  cmd:option('-flickering', 0, 'Probability of screen flickering (Catch only)')
  -- Experiment options
  cmd:option('-experiments', 'experiments', 'Base directory to store experiments')
  cmd:option('-_id', '', 'ID of experiment (used to store saved results, defaults to env name plus game name if set)')
  cmd:option('-network', '', 'Saved network weights file to load (weights.t7)')
  cmd:option('-checkpoint', 'false', 'Checkpoint network weights (instead of saving just latest weights)')
  cmd:option('-verbose', 'false', 'Log info for every episode (only in train mode)')
  cmd:option('-saliency', '', 'Display saliency maps (requires QT): <none>|normal|guided|deconvnet')
  cmd:option('-record', 'false', 'Record screen (only in eval mode)')
  local opt = cmd:parse(arg)

  -- Process boolean options (Torch fails to accept false on the command line)
  for i = 1, #BOOLEAN_OPTIONS do
    local key = BOOLEAN_OPTIONS[i]
    opt[key] = opt[key] == 'true'
  end
  -- Process boolean/enum options ('' means disabled)
  for i = 1, #EMPTY_MEANS_FALSE_OPTIONS do
    local key = EMPTY_MEANS_FALSE_OPTIONS[i]
    if opt[key] == '' then
      opt[key] = false
    end
  end
  if opt.async then opt.gpu = 0 end -- Asynchronous agents are CPU-only

  -- Set ID as env (plus game name) if not set
  if opt._id == '' then
    local envName = paths.basename(opt.env)
    if opt.game == '' then
      opt._id = envName
    else
      opt._id = envName .. '.' .. opt.game
    end
  end

  -- Create one environment to extract specifications
  local Env = require(opt.env)
  local env = Env(opt)
  opt.stateSpec = env:getStateSpec()
  opt.actionSpec = env:getActionSpec()
  -- Process display if available (can be used for saliency recordings even without QT).
  -- Require getDisplaySpec as well as getDisplay: the former is what is called here,
  -- so an env providing only getDisplay would otherwise crash at this line.
  if env.getDisplay and env.getDisplaySpec then
    opt.displaySpec = env:getDisplaySpec()
  end
  return opt
end
--- Logs `msg` through the global logger and raises it as an error when `err` is truthy.
-- Does nothing (and returns nothing) when `err` is nil or false.
-- @param err Condition; any truthy value triggers the abort
-- @param msg Message to log and raise
local function abortIf(err, msg)
  if not err then
    return
  end
  log.error(msg)
  error(msg)
end
-- Validates setup options, logging and raising an error (via abortIf) on the
-- first violation. Also updates opt.stateSpec in place to reflect resizing and
-- greyscale conversion, so the later checks here see the processed shape.
function Setup:validateOptions()
  local opt = self.opt
  -- Check environment state is a single tensor
  abortIf(#opt.stateSpec ~= 3 or not _.isArray(opt.stateSpec[2]), 'Environment state is not a single tensor')
  -- Check environment has discrete actions
  abortIf(opt.actionSpec[1] ~= 'int' or opt.actionSpec[2] ~= 1, 'Environment does not have discrete actions')

  -- State shape, indexed below as {channels, height, width}; mutated in place
  local shape = opt.stateSpec[2]
  -- Change state spec if resizing
  if opt.height ~= 0 then
    shape[2] = opt.height
  end
  if opt.width ~= 0 then
    shape[3] = opt.width
  end
  -- Check colour conversions
  if opt.colorSpace then
    abortIf(not _.contains({'y', 'lab', 'yuv', 'hsl', 'hsv', 'nrgb'}, opt.colorSpace), 'Unsupported colour space for conversion')
    abortIf(shape[1] ~= 3, 'Original colour space must be RGB for conversion')
    -- Change state spec if converting from colour to greyscale
    if opt.colorSpace == 'y' then
      shape[1] = 1
    end
  end

  -- Check start of learning occurs after at least one minibatch of data has been collected
  abortIf(opt.learnStart <= opt.batchSize, 'learnStart must be greater than batchSize')
  -- Check enough validation transitions will be collected before first validation
  abortIf(opt.valFreq <= opt.valSize, 'valFreq must be greater than valSize')
  -- Check prioritised experience replay options
  abortIf(opt.memPriority and not _.contains({'rank', 'proportional'}, opt.memPriority), 'Type of prioritised experience replay unrecognised')
  abortIf(opt.memPriority == 'proportional', 'Proportional prioritised experience replay not implemented yet') -- TODO: Implement
  -- Check no prioritized replay is done when bootstrap
  abortIf(opt.bootstraps > 0 and _.contains({'rank', 'proportional'}, opt.memPriority), 'Prioritized experience replay not possible with bootstrap')
  -- Check start of learning occurs after at least 1/100 of memory has been filled
  abortIf(opt.learnStart <= opt.memSize/100, 'learnStart must be greater than memSize/100')
  -- Check memory size is multiple of 100 (makes prioritised sampling partitioning simpler)
  abortIf(opt.memSize % 100 ~= 0, 'memSize must be a multiple of 100')
  -- Check learning does not start before the first progress report (equality is allowed)
  abortIf(opt.learnStart < opt.progFreq, 'learnStart must be greater than or equal to progFreq')

  -- Check saliency map options
  abortIf(opt.saliency and not _.contains({'normal', 'guided', 'deconvnet'}, opt.saliency), 'Unrecognised method for visualising saliency maps')
  -- Check saliency is valid
  abortIf(opt.saliency and not opt.displaySpec, 'Saliency cannot be shown without env:getDisplay()')
  -- A visual state is 3D with 1 (greyscale) or 3 (colour) channels. Its negation is
  -- `#shape ~= 3 or (c ~= 1 and c ~= 3)`; the previous form
  -- `#shape ~= 3 and (c ~= 3 or c ~= 1)` had an always-true channel test.
  abortIf(opt.saliency and (#shape ~= 3 or (shape[1] ~= 1 and shape[1] ~= 3)), 'Saliency cannot be shown without visual state')

  -- Check async options
  if opt.async then
    abortIf(opt.recurrent and opt.async ~= 'OneStepQ', 'Recurrent connections only supported for OneStepQ in async for now')
    abortIf(opt.PALpha > 0, 'Persistent advantage learning not supported in async modes yet')
    abortIf(opt.bootstraps > 0, 'Bootstrap heads not supported in async mode yet')
    abortIf(opt.async == 'A3C' and opt.duel, 'Dueling networks and A3C are incompatible')
    abortIf(opt.async == 'A3C' and opt.doubleQ, 'Double Q-learning and A3C are incompatible')
    abortIf(opt.saliency, 'Saliency maps not supported in async modes yet')
  end
end
-- Patches the environment class with no-op training()/evaluate() methods when an
-- instance of it does not already provide them, so callers can switch modes
-- without checking for their presence. Constructs one throwaway instance to probe.
function Setup:augmentEnv()
  local EnvClass = require(self.opt.env)
  local probe = EnvClass(self.opt)
  local modeMethods = {'training', 'evaluate'}
  for i = 1, #modeMethods do
    local method = modeMethods[i]
    if not probe[method] then
      EnvClass[method] = function() end -- Fake mode switch
    end
  end
end
-- Module value: the Setup class
return Setup