diff --git a/examples/ppo/ppo_lib.py b/examples/ppo/ppo_lib.py
index 0451791c70..5a3932969e 100644
--- a/examples/ppo/ppo_lib.py
+++ b/examples/ppo/ppo_lib.py
@@ -29,8 +29,8 @@
 import agent
 import test_episodes
 
-@functools.partial(jax.vmap, in_axes=(1, 1, 1, None, None), out_axes=1)
 @jax.jit
+@functools.partial(jax.vmap, in_axes=(1, 1, 1, None, None), out_axes=1)
 def gae_advantages(
     rewards: onp.ndarray,
     terminal_masks: onp.ndarray,
@@ -69,7 +69,7 @@ def gae_advantages(
   advantages = advantages[::-1]
   return jnp.array(advantages)
 
-@functools.partial(jax.jit, static_argnums=(6))
+@functools.partial(jax.jit, static_argnums=6)
 def train_step(
     optimizer: flax.optim.base.Optimizer,
     trajectories: Tuple,
@@ -207,10 +207,9 @@ def process_experience(
   returns = advantages + values[:-1, :]
   # After preprocessing, concatenate data from all agents.
   trajectories = (states, actions, log_probs, returns, advantages)
+  trajectory_len = num_agents * actor_steps
   trajectories = tuple(map(
-    lambda x: onp.reshape(
-      x, (num_agents * actor_steps,) + x.shape[2:]),
-    trajectories))
+    lambda x: onp.reshape(x, (trajectory_len,) + x.shape[2:]), trajectories))
   return trajectories
 
 def train(
diff --git a/examples/ppo/ppo_lib_test.py b/examples/ppo/ppo_lib_test.py
index c48bbb9e52..5cee68bca8 100644
--- a/examples/ppo/ppo_lib_test.py
+++ b/examples/ppo/ppo_lib_test.py
@@ -64,17 +64,17 @@ def test_creation(self):
     game = self.choose_random_game()
     env = env_utils.create_env(game, clip_rewards=True)
     obs = env.reset()
-    self.assertTrue(obs.shape == frame_shape)
+    self.assertEqual(obs.shape, frame_shape)
 
   def test_step(self):
     frame_shape = (84, 84, 4)
     game = self.choose_random_game()
-    env = env_utils.create_env(game, clip_rewards=False)
+    env = env_utils.create_env(game, clip_rewards=True)
     obs = env.reset()
     actions = [1, 2, 3, 0]
     for a in actions:
       obs, reward, done, info = env.step(a)
-      self.assertTrue(obs.shape == frame_shape)
+      self.assertEqual(obs.shape, frame_shape)
       self.assertTrue(reward <= 1. and reward >= -1.)
       self.assertTrue(isinstance(done, bool))
       self.assertTrue(isinstance(info, dict))
@@ -95,9 +95,9 @@ def test_model(self):
     test_batch_size, obs_shape = 10, (84, 84, 4)
     random_input = onp.random.random(size=(test_batch_size,) + obs_shape)
     log_probs, values = optimizer.target(random_input)
-    self.assertTrue(values.shape == (test_batch_size, 1))
+    self.assertEqual(values.shape, (test_batch_size, 1))
     sum_probs = onp.sum(onp.exp(log_probs), axis=1)
-    self.assertTrue(sum_probs.shape == (test_batch_size, ))
+    self.assertEqual(sum_probs.shape, (test_batch_size, ))
     onp_testing.assert_allclose(sum_probs, onp.ones((test_batch_size, )), atol=1e-6)