# ppo2.py (stable-baselines PPO2 implementation, forked from openai/baselines)
import time
import gym
import numpy as np
import tensorflow as tf
from stable_baselines import logger
from stable_baselines.common import explained_variance, ActorCriticRLModel, tf_util, SetVerbosity, TensorboardWriter
from stable_baselines.common.runners import AbstractEnvRunner
from stable_baselines.common.policies import ActorCriticPolicy, RecurrentActorCriticPolicy
from stable_baselines.common.schedules import get_schedule_fn
from stable_baselines.common.tf_util import total_episode_reward_logger
from stable_baselines.common.math_util import safe_mean
class PPO2(ActorCriticRLModel):
"""
Proximal Policy Optimization algorithm (GPU version).
Paper: https://arxiv.org/abs/1707.06347
:param policy: (ActorCriticPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, CnnLstmPolicy, ...)
:param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
:param gamma: (float) Discount factor
:param n_steps: (int) The number of steps to run for each environment per update
(i.e. batch size is n_steps * n_env where n_env is the number of environment copies running in parallel)
:param ent_coef: (float) Entropy coefficient for the loss calculation
:param learning_rate: (float or callable) The learning rate; it can be a function
:param vf_coef: (float) Value function coefficient for the loss calculation
:param max_grad_norm: (float) The maximum value for the gradient clipping
:param lam: (float) Factor for trade-off of bias vs variance for Generalized Advantage Estimator
:param nminibatches: (int) Number of training minibatches per update. For recurrent policies,
the number of environments run in parallel should be a multiple of nminibatches.
:param noptepochs: (int) Number of epochs when optimizing the surrogate objective
:param cliprange: (float or callable) Clipping parameter; it can be a function
:param cliprange_vf: (float or callable) Clipping parameter for the value function; it can be a function.
This is a parameter specific to the OpenAI implementation. If None is passed (default),
then `cliprange` (that is used for the policy) will be used.
IMPORTANT: this clipping depends on the reward scaling.
To deactivate value function clipping (and recover the original PPO implementation),
you have to pass a negative value (e.g. -1).
:param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
:param tensorboard_log: (str) the log location for tensorboard (if None, no logging)
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
:param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
:param full_tensorboard_log: (bool) enable additional logging when using tensorboard
WARNING: this logging can take a lot of space quickly
:param seed: (int) Seed for the pseudo-random generators (python, numpy, tensorflow).
If None (default), use random seed. Note that if you want completely deterministic
results, you must set `n_cpu_tf_sess` to 1.
:param n_cpu_tf_sess: (int) The number of threads for TensorFlow operations.
If None, the number of CPUs of the current machine will be used.
"""
def __init__(self, policy, env, gamma=0.99, n_steps=128, ent_coef=0.01, learning_rate=2.5e-4, vf_coef=0.5,
max_grad_norm=0.5, lam=0.95, nminibatches=4, noptepochs=4, cliprange=0.2, cliprange_vf=None,
verbose=0, tensorboard_log=None, _init_setup_model=True, policy_kwargs=None,
full_tensorboard_log=False, seed=None, n_cpu_tf_sess=None):
self.learning_rate = learning_rate
self.cliprange = cliprange
self.cliprange_vf = cliprange_vf
self.n_steps = n_steps
self.ent_coef = ent_coef
self.vf_coef = vf_coef
self.max_grad_norm = max_grad_norm
self.gamma = gamma
self.lam = lam
self.nminibatches = nminibatches
self.noptepochs = noptepochs
self.tensorboard_log = tensorboard_log
self.full_tensorboard_log = full_tensorboard_log
self.action_ph = None
self.advs_ph = None
self.rewards_ph = None
self.old_neglog_pac_ph = None
self.old_vpred_ph = None
self.learning_rate_ph = None
self.clip_range_ph = None
self.entropy = None
self.vf_loss = None
self.pg_loss = None
self.approxkl = None
self.clipfrac = None
self._train = None
self.loss_names = None
self.train_model = None
self.act_model = None
self.value = None
self.n_batch = None
self.summary = None
super().__init__(policy=policy, env=env, verbose=verbose, requires_vec_env=True,
_init_setup_model=_init_setup_model, policy_kwargs=policy_kwargs,
seed=seed, n_cpu_tf_sess=n_cpu_tf_sess)
if _init_setup_model:
self.setup_model()
def _make_runner(self):
return Runner(env=self.env, model=self, n_steps=self.n_steps,
gamma=self.gamma, lam=self.lam)
def _get_pretrain_placeholders(self):
policy = self.act_model
if isinstance(self.action_space, gym.spaces.Discrete):
return policy.obs_ph, self.action_ph, policy.policy
return policy.obs_ph, self.action_ph, policy.deterministic_action
def setup_model(self):
with SetVerbosity(self.verbose):
assert issubclass(self.policy, ActorCriticPolicy), "Error: the input policy for the PPO2 model must be " \
"an instance of common.policies.ActorCriticPolicy."
self.n_batch = self.n_envs * self.n_steps
self.graph = tf.Graph()
with self.graph.as_default():
self.set_random_seed(self.seed)
self.sess = tf_util.make_session(num_cpu=self.n_cpu_tf_sess, graph=self.graph)
n_batch_step = None
n_batch_train = None
if issubclass(self.policy, RecurrentActorCriticPolicy):
assert self.n_envs % self.nminibatches == 0, "For recurrent policies, "\
"the number of environments run in parallel should be a multiple of nminibatches."
n_batch_step = self.n_envs
n_batch_train = self.n_batch // self.nminibatches
act_model = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs, 1,
n_batch_step, reuse=False, **self.policy_kwargs)
with tf.variable_scope("train_model", reuse=True,
custom_getter=tf_util.outer_scope_getter("train_model")):
train_model = self.policy(self.sess, self.observation_space, self.action_space,
self.n_envs // self.nminibatches, self.n_steps, n_batch_train,
reuse=True, **self.policy_kwargs)
with tf.variable_scope("loss", reuse=False):
self.action_ph = train_model.pdtype.sample_placeholder([None], name="action_ph")
self.advs_ph = tf.placeholder(tf.float32, [None], name="advs_ph")
self.rewards_ph = tf.placeholder(tf.float32, [None], name="rewards_ph")
self.old_neglog_pac_ph = tf.placeholder(tf.float32, [None], name="old_neglog_pac_ph")
self.old_vpred_ph = tf.placeholder(tf.float32, [None], name="old_vpred_ph")
self.learning_rate_ph = tf.placeholder(tf.float32, [], name="learning_rate_ph")
self.clip_range_ph = tf.placeholder(tf.float32, [], name="clip_range_ph")
neglogpac = train_model.proba_distribution.neglogp(self.action_ph)
self.entropy = tf.reduce_mean(train_model.proba_distribution.entropy())
vpred = train_model.value_flat
# Value function clipping: not present in the original PPO
if self.cliprange_vf is None:
# Default behavior (legacy from OpenAI baselines):
# use the same clipping as for the policy
self.clip_range_vf_ph = self.clip_range_ph
self.cliprange_vf = self.cliprange
elif isinstance(self.cliprange_vf, (float, int)) and self.cliprange_vf < 0:
# Original PPO implementation: no value function clipping
self.clip_range_vf_ph = None
else:
# Last possible behavior: clipping range
# specific to the value function
self.clip_range_vf_ph = tf.placeholder(tf.float32, [], name="clip_range_vf_ph")
if self.clip_range_vf_ph is None:
# No clipping
vpred_clipped = train_model.value_flat
else:
# Clip the difference between the old and the new value predictions
# NOTE: this depends on the reward scaling
vpred_clipped = self.old_vpred_ph + \
tf.clip_by_value(train_model.value_flat - self.old_vpred_ph,
- self.clip_range_vf_ph, self.clip_range_vf_ph)
vf_losses1 = tf.square(vpred - self.rewards_ph)
vf_losses2 = tf.square(vpred_clipped - self.rewards_ph)
self.vf_loss = .5 * tf.reduce_mean(tf.maximum(vf_losses1, vf_losses2))
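# Clipped surrogate objective from the PPO paper:
# ratio = exp(log pi_new(a|s) - log pi_old(a|s)) = pi_new(a|s) / pi_old(a|s).
# Taking the maximum of the unclipped and clipped negative surrogate losses
# is the pessimistic (lower) bound on the policy improvement objective.
# `approxkl` is a sample-based estimate of the KL divergence between the old
# and new policies, and `clipfrac` is the fraction of samples whose ratio was
# clipped; both are logged as training diagnostics only.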
ratio = tf.exp(self.old_neglog_pac_ph - neglogpac)
pg_losses = -self.advs_ph * ratio
pg_losses2 = -self.advs_ph * tf.clip_by_value(ratio, 1.0 - self.clip_range_ph, 1.0 +
self.clip_range_ph)
self.pg_loss = tf.reduce_mean(tf.maximum(pg_losses, pg_losses2))
self.approxkl = .5 * tf.reduce_mean(tf.square(neglogpac - self.old_neglog_pac_ph))
self.clipfrac = tf.reduce_mean(tf.cast(tf.greater(tf.abs(ratio - 1.0),
self.clip_range_ph), tf.float32))
loss = self.pg_loss - self.entropy * self.ent_coef + self.vf_loss * self.vf_coef
tf.summary.scalar('entropy_loss', self.entropy)
tf.summary.scalar('policy_gradient_loss', self.pg_loss)
tf.summary.scalar('value_function_loss', self.vf_loss)
tf.summary.scalar('approximate_kullback-leibler', self.approxkl)
tf.summary.scalar('clip_factor', self.clipfrac)
tf.summary.scalar('loss', loss)
with tf.variable_scope('model'):
self.params = tf.trainable_variables()
if self.full_tensorboard_log:
for var in self.params:
tf.summary.histogram(var.name, var)
grads = tf.gradients(loss, self.params)
if self.max_grad_norm is not None:
grads, _grad_norm = tf.clip_by_global_norm(grads, self.max_grad_norm)
grads = list(zip(grads, self.params))
trainer = tf.train.AdamOptimizer(learning_rate=self.learning_rate_ph, epsilon=1e-5)
self._train = trainer.apply_gradients(grads)
self.loss_names = ['policy_loss', 'value_loss', 'policy_entropy', 'approxkl', 'clipfrac']
with tf.variable_scope("input_info", reuse=False):
tf.summary.scalar('discounted_rewards', tf.reduce_mean(self.rewards_ph))
tf.summary.scalar('learning_rate', tf.reduce_mean(self.learning_rate_ph))
tf.summary.scalar('advantage', tf.reduce_mean(self.advs_ph))
tf.summary.scalar('clip_range', tf.reduce_mean(self.clip_range_ph))
if self.clip_range_vf_ph is not None:
tf.summary.scalar('clip_range_vf', tf.reduce_mean(self.clip_range_vf_ph))
tf.summary.scalar('old_neglog_action_probability', tf.reduce_mean(self.old_neglog_pac_ph))
tf.summary.scalar('old_value_pred', tf.reduce_mean(self.old_vpred_ph))
if self.full_tensorboard_log:
tf.summary.histogram('discounted_rewards', self.rewards_ph)
tf.summary.histogram('learning_rate', self.learning_rate_ph)
tf.summary.histogram('advantage', self.advs_ph)
tf.summary.histogram('clip_range', self.clip_range_ph)
tf.summary.histogram('old_neglog_action_probability', self.old_neglog_pac_ph)
tf.summary.histogram('old_value_pred', self.old_vpred_ph)
if tf_util.is_image(self.observation_space):
tf.summary.image('observation', train_model.obs_ph)
else:
tf.summary.histogram('observation', train_model.obs_ph)
self.train_model = train_model
self.act_model = act_model
self.step = act_model.step
self.proba_step = act_model.proba_step
self.value = act_model.value
self.initial_state = act_model.initial_state
tf.global_variables_initializer().run(session=self.sess) # pylint: disable=E1101
self.summary = tf.summary.merge_all()
def _train_step(self, learning_rate, cliprange, obs, returns, masks, actions, values, neglogpacs, update,
writer, states=None, cliprange_vf=None):
"""
Perform one gradient update of PPO2 on a single minibatch
:param learning_rate: (float) learning rate
:param cliprange: (float) Clipping factor
:param obs: (np.ndarray) The current observation of the environment
:param returns: (np.ndarray) the returns (advantages + value estimates, used as value function targets)
:param masks: (np.ndarray) The last masks for done episodes (used in recurrent policies)
:param actions: (np.ndarray) the actions
:param values: (np.ndarray) the values
:param neglogpacs: (np.ndarray) Negative Log-likelihood probability of Actions
:param update: (int) the current step iteration
:param writer: (TensorFlow Summary.writer) the writer for tensorboard
:param states: (np.ndarray) For recurrent policies, the internal state of the recurrent model
:param cliprange_vf: (float) Clipping factor for the value function
:return: policy gradient loss, value function loss, policy entropy,
approximation of the KL divergence, fraction of clipped samples
"""
advs = returns - values
advs = (advs - advs.mean()) / (advs.std() + 1e-8)
td_map = {self.train_model.obs_ph: obs, self.action_ph: actions,
self.advs_ph: advs, self.rewards_ph: returns,
self.learning_rate_ph: learning_rate, self.clip_range_ph: cliprange,
self.old_neglog_pac_ph: neglogpacs, self.old_vpred_ph: values}
if states is not None:
td_map[self.train_model.states_ph] = states
td_map[self.train_model.dones_ph] = masks
if cliprange_vf is not None and cliprange_vf >= 0:
td_map[self.clip_range_vf_ph] = cliprange_vf
if states is None:
update_fac = max(self.n_batch // self.nminibatches // self.noptepochs, 1)
else:
update_fac = max(self.n_batch // self.nminibatches // self.noptepochs // self.n_steps, 1)
if writer is not None:
# run loss backprop with summary, but once every 10 runs save the metadata (memory, compute time, ...)
if self.full_tensorboard_log and (1 + update) % 10 == 0:
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
summary, policy_loss, value_loss, policy_entropy, approxkl, clipfrac, _ = self.sess.run(
[self.summary, self.pg_loss, self.vf_loss, self.entropy, self.approxkl, self.clipfrac, self._train],
td_map, options=run_options, run_metadata=run_metadata)
writer.add_run_metadata(run_metadata, 'step%d' % (update * update_fac))
else:
summary, policy_loss, value_loss, policy_entropy, approxkl, clipfrac, _ = self.sess.run(
[self.summary, self.pg_loss, self.vf_loss, self.entropy, self.approxkl, self.clipfrac, self._train],
td_map)
writer.add_summary(summary, (update * update_fac))
else:
policy_loss, value_loss, policy_entropy, approxkl, clipfrac, _ = self.sess.run(
[self.pg_loss, self.vf_loss, self.entropy, self.approxkl, self.clipfrac, self._train], td_map)
return policy_loss, value_loss, policy_entropy, approxkl, clipfrac
def set_env(self, env):
super().set_env(env)
self.n_batch = self.n_envs * self.n_steps
def learn(self, total_timesteps, callback=None, log_interval=1, tb_log_name="PPO2",
reset_num_timesteps=True):
# Transform to callable if needed
self.learning_rate = get_schedule_fn(self.learning_rate)
self.cliprange = get_schedule_fn(self.cliprange)
cliprange_vf = get_schedule_fn(self.cliprange_vf)
new_tb_log = self._init_num_timesteps(reset_num_timesteps)
callback = self._init_callback(callback)
with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name, new_tb_log) \
as writer:
self._setup_learn()
t_first_start = time.time()
n_updates = total_timesteps // self.n_batch
callback.on_training_start(locals(), globals())
for update in range(1, n_updates + 1):
assert self.n_batch % self.nminibatches == 0, ("The number of minibatches (`nminibatches`) "
"is not a factor of the total number of samples "
"collected per rollout (`n_batch`); "
"some samples would not be used.")
batch_size = self.n_batch // self.nminibatches
t_start = time.time()
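# `frac` is the fraction of training remaining (1.0 at the first update,
# approaching 0 at the last). The schedule functions returned by
# get_schedule_fn are evaluated on it, so constant hyperparameters stay
# constant while callables can anneal over training.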
frac = 1.0 - (update - 1.0) / n_updates
lr_now = self.learning_rate(frac)
cliprange_now = self.cliprange(frac)
cliprange_vf_now = cliprange_vf(frac)
callback.on_rollout_start()
# true_reward is the reward without discount
rollout = self.runner.run(callback)
# Unpack
obs, returns, masks, actions, values, neglogpacs, states, ep_infos, true_reward = rollout
callback.on_rollout_end()
# Early stopping due to the callback
if not self.runner.continue_training:
break
self.ep_info_buf.extend(ep_infos)
mb_loss_vals = []
if states is None: # nonrecurrent version
update_fac = max(self.n_batch // self.nminibatches // self.noptepochs, 1)
inds = np.arange(self.n_batch)
for epoch_num in range(self.noptepochs):
np.random.shuffle(inds)
for start in range(0, self.n_batch, batch_size):
timestep = self.num_timesteps // update_fac + ((epoch_num *
self.n_batch + start) // batch_size)
end = start + batch_size
mbinds = inds[start:end]
slices = (arr[mbinds] for arr in (obs, returns, masks, actions, values, neglogpacs))
mb_loss_vals.append(self._train_step(lr_now, cliprange_now, *slices, writer=writer,
update=timestep, cliprange_vf=cliprange_vf_now))
else: # recurrent version
update_fac = max(self.n_batch // self.nminibatches // self.noptepochs // self.n_steps, 1)
assert self.n_envs % self.nminibatches == 0
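# For recurrent policies, minibatches are built from whole environments
# (all n_steps of each selected environment) so that the stored
# per-environment recurrent states in `mb_states` stay aligned with the
# sampled trajectories.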
env_indices = np.arange(self.n_envs)
flat_indices = np.arange(self.n_envs * self.n_steps).reshape(self.n_envs, self.n_steps)
envs_per_batch = batch_size // self.n_steps
for epoch_num in range(self.noptepochs):
np.random.shuffle(env_indices)
for start in range(0, self.n_envs, envs_per_batch):
timestep = self.num_timesteps // update_fac + ((epoch_num *
self.n_envs + start) // envs_per_batch)
end = start + envs_per_batch
mb_env_inds = env_indices[start:end]
mb_flat_inds = flat_indices[mb_env_inds].ravel()
slices = (arr[mb_flat_inds] for arr in (obs, returns, masks, actions, values, neglogpacs))
mb_states = states[mb_env_inds]
mb_loss_vals.append(self._train_step(lr_now, cliprange_now, *slices, update=timestep,
writer=writer, states=mb_states,
cliprange_vf=cliprange_vf_now))
loss_vals = np.mean(mb_loss_vals, axis=0)
t_now = time.time()
fps = int(self.n_batch / (t_now - t_start))
if writer is not None:
total_episode_reward_logger(self.episode_reward,
true_reward.reshape((self.n_envs, self.n_steps)),
masks.reshape((self.n_envs, self.n_steps)),
writer, self.num_timesteps)
if self.verbose >= 1 and (update % log_interval == 0 or update == 1):
explained_var = explained_variance(values, returns)
logger.logkv("serial_timesteps", update * self.n_steps)
logger.logkv("n_updates", update)
logger.logkv("total_timesteps", self.num_timesteps)
logger.logkv("fps", fps)
logger.logkv("explained_variance", float(explained_var))
if len(self.ep_info_buf) > 0 and len(self.ep_info_buf[0]) > 0:
logger.logkv('ep_reward_mean', safe_mean([ep_info['r'] for ep_info in self.ep_info_buf]))
logger.logkv('ep_len_mean', safe_mean([ep_info['l'] for ep_info in self.ep_info_buf]))
logger.logkv('time_elapsed', t_start - t_first_start)
for (loss_val, loss_name) in zip(loss_vals, self.loss_names):
logger.logkv(loss_name, loss_val)
logger.dumpkvs()
callback.on_training_end()
return self
def save(self, save_path, cloudpickle=False):
data = {
"gamma": self.gamma,
"n_steps": self.n_steps,
"vf_coef": self.vf_coef,
"ent_coef": self.ent_coef,
"max_grad_norm": self.max_grad_norm,
"learning_rate": self.learning_rate,
"lam": self.lam,
"nminibatches": self.nminibatches,
"noptepochs": self.noptepochs,
"cliprange": self.cliprange,
"cliprange_vf": self.cliprange_vf,
"verbose": self.verbose,
"policy": self.policy,
"observation_space": self.observation_space,
"action_space": self.action_space,
"n_envs": self.n_envs,
"n_cpu_tf_sess": self.n_cpu_tf_sess,
"seed": self.seed,
"_vectorize_action": self._vectorize_action,
"policy_kwargs": self.policy_kwargs
}
params_to_save = self.get_parameters()
self._save_to_file(save_path, data=data, params=params_to_save, cloudpickle=cloudpickle)
class Runner(AbstractEnvRunner):
def __init__(self, *, env, model, n_steps, gamma, lam):
"""
A runner that collects rollouts (experience) from the environment for a given model
:param env: (Gym environment) The environment to learn from
:param model: (Model) The model to learn
:param n_steps: (int) The number of steps to run for each environment
:param gamma: (float) Discount factor
:param lam: (float) Factor for trade-off of bias vs variance for Generalized Advantage Estimator
"""
super().__init__(env=env, model=model, n_steps=n_steps)
self.lam = lam
self.gamma = gamma
def _run(self):
"""
Collect experience by running the current model for n_steps in each environment
:return:
- observations: (np.ndarray) the observations
- rewards: (np.ndarray) the rewards
- masks: (numpy bool) whether an episode is over or not
- actions: (np.ndarray) the actions
- values: (np.ndarray) the value function output
- negative log probabilities: (np.ndarray)
- states: (np.ndarray) the internal states of the recurrent policies
- infos: (list) a list of episode info dictionaries gathered from the environment
"""
# mb stands for minibatch
mb_obs, mb_rewards, mb_actions, mb_values, mb_dones, mb_neglogpacs = [], [], [], [], [], []
mb_states = self.states
ep_infos = []
for _ in range(self.n_steps):
actions, values, self.states, neglogpacs = self.model.step(self.obs, self.states, self.dones) # pytype: disable=attribute-error
mb_obs.append(self.obs.copy())
mb_actions.append(actions)
mb_values.append(values)
mb_neglogpacs.append(neglogpacs)
mb_dones.append(self.dones)
clipped_actions = actions
# Clip the actions to avoid out-of-bounds errors
if isinstance(self.env.action_space, gym.spaces.Box):
clipped_actions = np.clip(actions, self.env.action_space.low, self.env.action_space.high)
self.obs[:], rewards, self.dones, infos = self.env.step(clipped_actions)
self.model.num_timesteps += self.n_envs
if self.callback is not None:
# Abort training early
self.callback.update_locals(locals())
if self.callback.on_step() is False:
self.continue_training = False
# Return dummy values
return [None] * 9
for info in infos:
maybe_ep_info = info.get('episode')
if maybe_ep_info is not None:
ep_infos.append(maybe_ep_info)
mb_rewards.append(rewards)
# batch of steps to batch of rollouts
mb_obs = np.asarray(mb_obs, dtype=self.obs.dtype)
mb_rewards = np.asarray(mb_rewards, dtype=np.float32)
mb_actions = np.asarray(mb_actions)
mb_values = np.asarray(mb_values, dtype=np.float32)
mb_neglogpacs = np.asarray(mb_neglogpacs, dtype=np.float32)
mb_dones = np.asarray(mb_dones, dtype=np.bool)
last_values = self.model.value(self.obs, self.states, self.dones) # pytype: disable=attribute-error
# discount/bootstrap off value fn
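# Generalized Advantage Estimation (GAE):
#   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
#   A_t     = delta_t + gamma * lam * (1 - done_t) * A_{t+1}
# The returns used as value targets are then A_t + V(s_t). The final step
# bootstraps from `last_values`, the value estimate of the observation
# reached after the last collected transition.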
mb_advs = np.zeros_like(mb_rewards)
true_reward = np.copy(mb_rewards)
last_gae_lam = 0
for step in reversed(range(self.n_steps)):
if step == self.n_steps - 1:
nextnonterminal = 1.0 - self.dones
nextvalues = last_values
else:
nextnonterminal = 1.0 - mb_dones[step + 1]
nextvalues = mb_values[step + 1]
delta = mb_rewards[step] + self.gamma * nextvalues * nextnonterminal - mb_values[step]
mb_advs[step] = last_gae_lam = delta + self.gamma * self.lam * nextnonterminal * last_gae_lam
mb_returns = mb_advs + mb_values
mb_obs, mb_returns, mb_dones, mb_actions, mb_values, mb_neglogpacs, true_reward = \
map(swap_and_flatten, (mb_obs, mb_returns, mb_dones, mb_actions, mb_values, mb_neglogpacs, true_reward))
return mb_obs, mb_returns, mb_dones, mb_actions, mb_values, mb_neglogpacs, mb_states, ep_infos, true_reward
# obs, returns, masks, actions, values, neglogpacs, states = runner.run()
def swap_and_flatten(arr):
"""
swap and then flatten axes 0 and 1
:param arr: (np.ndarray)
:return: (np.ndarray)
"""
shape = arr.shape
return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:])
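# Illustrative example (not part of the original file): with n_steps=3 and
# n_envs=2, an input of shape (3, 2) ordered [[a0, b0], [a1, b1], [a2, b2]]
# becomes the flat array [a0, a1, a2, b0, b1, b2] of shape (6,). Each
# environment's trajectory stays contiguous, which the recurrent minibatch
# indexing in PPO2.learn relies on.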