Commit
Merge pull request #470 from wrzadkow:rl-example-ppo
PiperOrigin-RevId: 334768921
Flax Authors committed Oct 1, 2020
2 parents fed1aaf + f3a9d03 · commit 45937af
Showing 2 changed files with 9 additions and 10 deletions.
examples/ppo/ppo_lib.py: 9 changes (4 additions, 5 deletions)
@@ -29,8 +29,8 @@
 import agent
 import test_episodes

-@functools.partial(jax.vmap, in_axes=(1, 1, 1, None, None), out_axes=1)
 @jax.jit
+@functools.partial(jax.vmap, in_axes=(1, 1, 1, None, None), out_axes=1)
 def gae_advantages(
     rewards: onp.ndarray,
     terminal_masks: onp.ndarray,
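For reference on the decorator swap above: stacked decorators apply bottom-up, so after this change jax.vmap is applied to gae_advantages first and jax.jit then compiles the already-vmapped function (jit outermost). A minimal sketch of that pattern on a hypothetical toy function, not taken from this example:

import functools
import jax
import jax.numpy as jnp

@jax.jit
@functools.partial(jax.vmap, in_axes=(0, None))
def scale(x, factor):
  # vmap maps over axis 0 of x; jit then compiles the batched function as a whole.
  return x * factor

print(scale(jnp.arange(4.0), 2.0))  # [0. 2. 4. 6.]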
@@ -69,7 +69,7 @@ def gae_advantages(
   advantages = advantages[::-1]
   return jnp.array(advantages)

-@functools.partial(jax.jit, static_argnums=(6))
+@functools.partial(jax.jit, static_argnums=6)
 def train_step(
     optimizer: flax.optim.base.Optimizer,
     trajectories: Tuple,
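A note on static_argnums above: in Python, (6) is just the integer 6 (a one-element tuple would be (6,)), so the old and new decorators pass the same value and the change only drops the misleading parentheses. jax.jit accepts either a single int or a collection of ints for static_argnums. A small illustration with a hypothetical function, not from this example:

import functools
import jax

print((6) == 6)  # True: parentheses alone do not create a tuple

# static_argnums=6 is equivalent to static_argnums=(6,): argument 6 (a Python
# bool here) is treated as static and triggers retracing when its value changes.
@functools.partial(jax.jit, static_argnums=6)
def step(a0, a1, a2, a3, a4, a5, use_sum):
  return a0 + a1 if use_sum else a0 - a1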
@@ -207,10 +207,9 @@ def process_experience(
   returns = advantages + values[:-1, :]
   # After preprocessing, concatenate data from all agents.
   trajectories = (states, actions, log_probs, returns, advantages)
+  trajectory_len = num_agents * actor_steps
   trajectories = tuple(map(
-      lambda x: onp.reshape(
-          x, (num_agents * actor_steps,) + x.shape[2:]),
-      trajectories))
+      lambda x: onp.reshape(x, (trajectory_len,) + x.shape[2:]), trajectories))
   return trajectories

 def train(
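The reshape in process_experience above flattens the leading (steps, agents) axes of each trajectory array into a single batch axis of length num_agents * actor_steps. A minimal numeric sketch with made-up shapes (the names and sizes here are illustrative only):

import numpy as onp

actor_steps, num_agents = 128, 8
states = onp.zeros((actor_steps, num_agents, 84, 84, 4))

trajectory_len = num_agents * actor_steps
flat_states = onp.reshape(states, (trajectory_len,) + states.shape[2:])
print(flat_states.shape)  # (1024, 84, 84, 4)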
examples/ppo/ppo_lib_test.py: 10 changes (5 additions, 5 deletions)
@@ -64,17 +64,17 @@ def test_creation(self):
     game = self.choose_random_game()
     env = env_utils.create_env(game, clip_rewards=True)
     obs = env.reset()
-    self.assertTrue(obs.shape == frame_shape)
+    self.assertEqual(obs.shape, frame_shape)

   def test_step(self):
     frame_shape = (84, 84, 4)
     game = self.choose_random_game()
-    env = env_utils.create_env(game, clip_rewards=False)
+    env = env_utils.create_env(game, clip_rewards=True)
     obs = env.reset()
     actions = [1, 2, 3, 0]
     for a in actions:
       obs, reward, done, info = env.step(a)
-      self.assertTrue(obs.shape == frame_shape)
+      self.assertEqual(obs.shape, frame_shape)
       self.assertTrue(reward <= 1. and reward >= -1.)
       self.assertTrue(isinstance(done, bool))
       self.assertTrue(isinstance(info, dict))
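Two notes on the test changes above: assertEqual reports both values when a check fails, while assertTrue(a == b) only reports that the expression was falsy; and creating the environment with clip_rewards=True is what makes the reward-range assertion (-1 <= reward <= 1) a meaningful check. A small sketch of the failure-message difference, using hypothetical shapes rather than anything from the test:

import unittest

class ShapeMessages(unittest.TestCase):
  def test_shape(self):
    got, want = (84, 84, 1), (84, 84, 4)
    # self.assertTrue(got == want)  -> would fail with only "False is not true"
    # self.assertEqual(got, want)   -> would fail with "Tuples differ: (84, 84, 1) != (84, 84, 4)"
    self.assertEqual(got[:2], want[:2])  # a passing check, so the sketch runs cleanly

if __name__ == '__main__':
  unittest.main()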
@@ -95,9 +95,9 @@ def test_model(self):
     test_batch_size, obs_shape = 10, (84, 84, 4)
     random_input = onp.random.random(size=(test_batch_size,) + obs_shape)
     log_probs, values = optimizer.target(random_input)
-    self.assertTrue(values.shape == (test_batch_size, 1))
+    self.assertEqual(values.shape, (test_batch_size, 1))
     sum_probs = onp.sum(onp.exp(log_probs), axis=1)
-    self.assertTrue(sum_probs.shape == (test_batch_size, ))
+    self.assertEqual(sum_probs.shape, (test_batch_size, ))
     onp_testing.assert_allclose(sum_probs, onp.ones((test_batch_size, )),
                                 atol=1e-6)
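The last two assertions check the policy head: exponentiating the per-action log-probabilities and summing over the action axis should give 1 for every element of the batch. A small numeric sketch with made-up log-probabilities (not model output):

import numpy as onp

log_probs = onp.log(onp.array([[0.1, 0.2, 0.3, 0.4],
                               [0.25, 0.25, 0.25, 0.25]]))
sum_probs = onp.sum(onp.exp(log_probs), axis=1)
print(sum_probs.shape)  # (2,)
print(sum_probs)        # [1. 1.] up to floating-point error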

