trainer.py
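
"""trainer.py

Entry point for training and testing a DQN-based agent on the Banana
collection environment (raycast observations). Supports a pytorch or a
tensorflow model backend, as well as the double-DQN, prioritized experience
replay and dueling-DQN improvements; see the argument parser in __main__
for the available options.
"""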
import os
import numpy as np
import argparse
import time
from tqdm import tqdm
from collections import deque
from collections import defaultdict

# imports from the navigation package
from navigation import agent_raycast        # agent for the raycast-based environment
from navigation import model_pytorch        # pytorch-based model
from navigation import model_tensorflow     # tensorflow-based model
from navigation.envs import mlagents        # simple environment wrapper
from navigation.dqn.utils import config     # configuration load/save functionality

# logging functionality
import logger

from IPython.core.debugger import set_trace

TEST                                = True      # global variable, set by the argparser
TIME_START                          = 0         # global variable, set in __main__
RESULTS_FOLDER                      = 'results' # global variable, where to place the results of training
SEED                                = 0         # global variable, set by the argparser
CONFIG_AGENT                        = ''        # global variable, set by the argparser
CONFIG_MODEL                        = ''        # global variable, set by the argparser
USE_DOUBLE_DQN                      = False     # global variable, set by the argparser
USE_PRIORITIZED_EXPERIENCE_REPLAY   = False     # global variable, set by the argparser
USE_DUELING_DQN                     = False     # global variable, set by the argparser
def train( env, agent, sessionId, savefile, resultsFilename, replayFilename ) :
    MAX_EPISODES = agent.learningMaxSteps
    MAX_STEPS_EPISODE = 1000
    LOG_WINDOW_SIZE = 100

    _progressbar = tqdm( range( 1, MAX_EPISODES + 1 ), desc = 'Training>', leave = True )
    _maxAvgScore = -np.inf
    _scores = []
    _scoresAvgs = []
    _scoresWindow = deque( maxlen = LOG_WINDOW_SIZE )
    _stepsWindow = deque( maxlen = LOG_WINDOW_SIZE )

    _timeStart = TIME_START

    for iepisode in _progressbar :

        _state = env.reset( training = True )
        _score = 0
        _nsteps = 0

        while True :
            # grab action from dqn agent: runs through model, e-greedy, etc.
            _action = agent.act( _state, inference = False )
            # apply action in simulator to get the transition
            _snext, _reward, _done, _ = env.step( _action )
            ## env.render()
            _transition = ( _state, _action, _snext, _reward, _done )
            # send this transition back to the agent (to learn whenever it decides to)
            ## set_trace()
            agent.step( _transition )

            # prepare for next iteration
            _state = _snext
            _score += _reward
            _nsteps += 1

            if _done :
                break

        _scores.append( _score )
        _scoresWindow.append( _score )
        _stepsWindow.append( _nsteps )

        if iepisode >= LOG_WINDOW_SIZE :
            _avgScore = np.mean( _scoresWindow )
            _avgSteps = np.mean( _stepsWindow )
            _scoresAvgs.append( _avgScore )

            if _avgScore > _maxAvgScore :
                _maxAvgScore = _avgScore

            # log results
            if agent._usePrioritizedExpReplay :
                _progressbar.set_description( 'Training> Max-Avg=%.2f, Curr-Avg=%.2f, Curr=%.2f, Eps=%.2f, Beta=%.2f' \
                                              % ( _maxAvgScore, _avgScore, _score, agent.epsilon, agent._rbuffer.beta ) )
            else :
                _progressbar.set_description( 'Training> Max-Avg=%.2f, Curr-Avg=%.2f, Curr=%.2f, Eps=%.2f' \
                                              % ( _maxAvgScore, _avgScore, _score, agent.epsilon ) )
            _progressbar.refresh()

    # save the trained model
    agent.save( savefile )

    _timeStop = int( time.time() )
    _trainingTime = _timeStop - _timeStart

    # save training results for later visualization and analysis
    logger.saveTrainingResults( resultsFilename,
                                sessionId,
                                _timeStart,
                                _scores,
                                _scoresAvgs,
                                agent.actorModel.losses,
                                agent.actorModel.bellmanErrors,
                                agent.actorModel.gradients )

    # save a replay batch for later visualization and analysis
    _minibatch = agent.replayBuffer.sample( 100 )
    _ss, _aa, _rr, _ssnext = _minibatch[0], _minibatch[1], _minibatch[2], _minibatch[3]
    _q_s_batch = [ agent.actorModel.eval( agent._preprocess( state ) ) \
                   for state in _ss ]
    _replayBatch = { 'states' : _ss, 'actions' : _aa, 'rewards' : _rr, 'nextStates' : _ssnext }

    logger.saveReplayBatch( replayFilename,
                            sessionId,
                            TIME_START,
                            _replayBatch,
                            _q_s_batch )
def test( env, agent ) :
    _progressbar = tqdm( range( 1, 10 + 1 ), desc = 'Testing>', leave = True )
    for _ in _progressbar :

        _state = env.reset( training = False )
        _score = 0.0
        _goodBananas = 0
        _badBananas = 0

        while True :
            _action = agent.act( _state, inference = True )
            _state, _reward, _done, _ = env.step( _action )

            if _reward > 0 :
                _goodBananas += 1
                _progressbar.write( 'Got banana! :D. So far: %d' % _goodBananas )
            elif _reward < 0 :
                _badBananas += 1
                _progressbar.write( 'Got bad banana :/. So far: %d' % _badBananas )

            _score += _reward

            if _done :
                break

        _progressbar.set_description( 'Testing> Score=%.2f' % ( _score ) )
        _progressbar.refresh()
def experiment( sessionId,
                library,
                savefile,
                resultsFilename,
                replayFilename,
                agentConfigFilename,
                modelConfigFilename ) :
    # grab the factory-method for the model according to the library requested
    _modelBuilder = model_pytorch.DqnModelBuilder if library == 'pytorch' \
                        else model_tensorflow.DqnModelBuilder
    # grab the initialization-method for the backend according to the library requested
    _backendInitializer = model_pytorch.BackendInitializer if library == 'pytorch' \
                              else model_tensorflow.BackendInitializer

    # paths to the environment executables
    _bananaExecPath = os.path.join( os.getcwd(), 'executables/Banana_Linux/Banana.x86_64' )
    _bananaHeadlessExecPath = os.path.join( os.getcwd(), 'executables/Banana_Linux_NoVis/Banana.x86_64' )

    # load user-provided configurations (if given)
    if CONFIG_AGENT != '' :
        agent_raycast.AGENT_CONFIG = config.DqnAgentConfig.load( CONFIG_AGENT )

    if CONFIG_MODEL != '' :
        agent_raycast.MODEL_CONFIG = config.DqnModelConfig.load( CONFIG_MODEL )

    # instantiate the environment
    _env = mlagents.createDiscreteActionsEnv( _bananaExecPath, seed = SEED )

    # set the seed for the agent
    agent_raycast.AGENT_CONFIG.seed = SEED

    # set improvement flags
    agent_raycast.AGENT_CONFIG.useDoubleDqn             = USE_DOUBLE_DQN
    agent_raycast.AGENT_CONFIG.usePrioritizedExpReplay  = USE_PRIORITIZED_EXPERIENCE_REPLAY
    agent_raycast.AGENT_CONFIG.useDuelingDqn            = USE_DUELING_DQN

    _agent = agent_raycast.CreateAgent( agent_raycast.AGENT_CONFIG,
                                        agent_raycast.MODEL_CONFIG,
                                        _modelBuilder,
                                        _backendInitializer )

    # save the agent and model configurations actually used
    config.DqnAgentConfig.save( agent_raycast.AGENT_CONFIG, agentConfigFilename )
    config.DqnModelConfig.save( agent_raycast.MODEL_CONFIG, modelConfigFilename )

    if not TEST :
        train( _env, _agent, sessionId, savefile, resultsFilename, replayFilename )
    else :
        _agent.load( savefile )
        test( _env, _agent )
if __name__ == '__main__' :
    _parser = argparse.ArgumentParser()
    _parser.add_argument( 'mode',
                          help = 'mode of execution (train|test)',
                          type = str,
                          choices = [ 'train', 'test' ] )
    _parser.add_argument( '--library',
                          help = 'deep learning library to use (pytorch|tensorflow)',
                          type = str,
                          choices = [ 'pytorch', 'tensorflow' ],
                          default = 'pytorch' )
    _parser.add_argument( '--sessionId',
                          help = 'identifier of this training run',
                          type = str,
                          default = 'banana_simple' )
    _parser.add_argument( '--seed',
                          help = 'random seed for the environment and generators',
                          type = int,
                          default = 0 )
    _parser.add_argument( '--visual',
                          help = 'whether or not to use the visual-banana environment (true|false)',
                          type = str,
                          default = 'false' )
    _parser.add_argument( '--ddqn',
                          help = 'whether or not to use double dqn (true|false)',
                          type = str,
                          default = 'false' )
    _parser.add_argument( '--prioritizedExpReplay',
                          help = 'whether or not to use prioritized experience replay (true|false)',
                          type = str,
                          default = 'false' )
    _parser.add_argument( '--duelingDqn',
                          help = 'whether or not to use dueling dqn (true|false)',
                          type = str,
                          default = 'false' )
    _parser.add_argument( '--configAgent',
                          help = 'configuration file for the agent (hyperparameters, etc.)',
                          type = str,
                          default = '' )
    _parser.add_argument( '--configModel',
                          help = 'configuration file for the model (architecture, etc.)',
                          type = str,
                          default = '' )
    _args = _parser.parse_args()

    # whether or not we are in test mode
    TEST = ( _args.mode == 'test' )
    # the actual seed for the environment
    SEED = _args.seed
    # timestamp of the start of execution
    TIME_START = int( time.time() )

    _sessionfolder = os.path.join( RESULTS_FOLDER, _args.sessionId )
    if not os.path.exists( _sessionfolder ) :
        os.makedirs( _sessionfolder )

    # file where to save the trained model
    _savefile = _args.sessionId
    _savefile += '_model_'
    _savefile += _args.library
    _savefile += ( '.pth' if _args.library == 'pytorch' else '.h5' )
    _savefile = os.path.join( _sessionfolder, _savefile )

    # file where to save the training results statistics
    _resultsFilename = os.path.join( _sessionfolder,
                                     _args.sessionId + '_results.pkl' )

    # file where to save the replay information (for further analysis)
    _replayFilename = os.path.join( _sessionfolder,
                                    _args.sessionId + '_replay.pkl' )

    # configuration files for this training session
    _agentConfigFilename = os.path.join( _sessionfolder, _args.sessionId + '_agentconfig.json' )
    _modelConfigFilename = os.path.join( _sessionfolder, _args.sessionId + '_modelconfig.json' )

    # whether or not to use the visual-banana environment
    VISUAL = ( _args.visual.lower() == 'true' )

    # DQN improvement options
    USE_DOUBLE_DQN                      = ( _args.ddqn.lower() == 'true' )
    USE_PRIORITIZED_EXPERIENCE_REPLAY   = ( _args.prioritizedExpReplay.lower() == 'true' )
    USE_DUELING_DQN                     = ( _args.duelingDqn.lower() == 'true' )

    # configuration files with training information (provided by the user)
    CONFIG_AGENT = _args.configAgent
    CONFIG_MODEL = _args.configModel
    print( '#############################################################' )
    print( '#                                                           #' )
    print( '#                Environment and agent setup                #' )
    print( '#                                                           #' )
    print( '#############################################################' )
    print( 'Mode                 : ', _args.mode )
    print( 'Library              : ', _args.library )
    print( 'SessionId            : ', _args.sessionId )
    print( 'Savefile             : ', _savefile )
    print( 'ResultsFilename      : ', _resultsFilename )
    print( 'ReplayFilename       : ', _replayFilename )
    print( 'AgentConfigFilename  : ', _agentConfigFilename )
    print( 'ModelConfigFilename  : ', _modelConfigFilename )
    print( 'VisualBanana         : ', _args.visual )
    print( 'DoubleDqn            : ', _args.ddqn )
    print( 'PrioritizedExpReplay : ', _args.prioritizedExpReplay )
    print( 'DuelingDqn           : ', _args.duelingDqn )
    print( 'Agent config file    : ', 'None' if _args.configAgent == '' else _args.configAgent )
    print( 'Model config file    : ', 'None' if _args.configModel == '' else _args.configModel )
    print( '#############################################################' )
    experiment( _args.sessionId,
                _args.library,
                _savefile,
                _resultsFilename,
                _replayFilename,
                _agentConfigFilename,
                _modelConfigFilename )
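
# Example invocations (a sketch; these assume the Banana_Linux executables and
# the navigation package are set up as referenced above):
#
#   python trainer.py train --library=pytorch --sessionId=banana_ddqn --ddqn=true
#   python trainer.py test --library=pytorch --sessionId=banana_ddqn
#
# Test mode loads the model file saved under results/<sessionId>/ by a previous
# training run with the same sessionId and library.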