Commit

Merge branch 'main' into feat/mixed_experience_replay
EdanToledo authored Sep 20, 2024
2 parents 3fddc3b + 3cadd21 commit 57f77ca
Showing 23 changed files with 76 additions and 72 deletions.
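The changes shown below are one mechanical substitution: calls to the deprecated `jax.tree_map` are replaced with `jax.tree.map`, the spelling recommended by recent JAX releases. `jax.tree.map` is an alias of `jax.tree_util.tree_map`, so the substitution is behaviour-preserving. A minimal sketch of the migration, using a hypothetical pytree rather than any state from this repository:

```python
import jax
import jax.numpy as jnp

# Hypothetical pytree standing in for the params/opt_state/buffer_state trees below.
tree = {"w": jnp.ones((4, 3)), "b": jnp.zeros(3)}

# Old spelling, removed throughout this diff (deprecated in recent JAX releases):
#   doubled = jax.tree_map(lambda x: x * 2, tree)

# New spelling, added throughout this diff; jax.tree.map aliases
# jax.tree_util.tree_map, so the result is identical.
doubled = jax.tree.map(lambda x: x * 2, tree)

assert doubled["w"].shape == (4, 3)  # tree structure is preserved; only leaves change
```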
14 changes: 7 additions & 7 deletions examples/anakin_dqn_example.ipynb
@@ -363,9 +363,9 @@
 "):\n",
 " \"\"\"Broadcasts parameters to device shape.\"\"\"\n",
 " broadcast = lambda x: jnp.broadcast_to(x, (cores_count, num_envs) + x.shape)\n",
-" params = jax.tree_map(broadcast, params) # broadcast to cores and batch.\n",
-" opt_state = jax.tree_map(broadcast, opt_state) # broadcast to cores and batch\n",
-" buffer_state = jax.tree_map(broadcast, buffer_state) # broadcast to cores and batch\n",
+" params = jax.tree.map(broadcast, params) # broadcast to cores and batch.\n",
+" opt_state = jax.tree.map(broadcast, opt_state) # broadcast to cores and batch\n",
+" buffer_state = jax.tree.map(broadcast, buffer_state) # broadcast to cores and batch\n",
 " params_state = Params(\n",
 " online=params,\n",
 " target=params,\n",
@@ -494,7 +494,7 @@
 "def eval(params, rng):\n",
 " \"\"\"Evaluates multiple episodes.\"\"\"\n",
 " rngs = random.split(rng, NUM_EVAL_EPISODES)\n",
-" params = jax.tree_map(lambda x: x[0][0], params)\n",
+" params = jax.tree.map(lambda x: x[0][0], params)\n",
 " _, tot_r = jax.lax.scan(eval_one_episode, params, rngs)\n",
 " return tot_r.mean()"
 ]
@@ -535,8 +535,8 @@
 "rng, *env_rngs = jax.random.split(rng, cores_count * NUM_ENVS + 1)\n",
 "env_states, env_timesteps = jax.vmap(env.reset)(jnp.stack(env_rngs)) # init envs.\n",
 "reshape = lambda x: x.reshape((cores_count, NUM_ENVS) + x.shape[1:])\n",
-"env_states = jax.tree_map(reshape, env_states) # add dimension to pmap over.\n",
-"env_timesteps = jax.tree_map(reshape, env_timesteps) # add dimension to pmap over.\n",
+"env_states = jax.tree.map(reshape, env_states) # add dimension to pmap over.\n",
+"env_timesteps = jax.tree.map(reshape, env_timesteps) # add dimension to pmap over.\n",
 "params_state, opt_state, buffer_state, step_rngs, rng = broadcast_to_device_shape(\n",
 " cores_count, NUM_ENVS, params, opt_state, buffer_state, rng\n",
 ")\n",
@@ -562,7 +562,7 @@
 " # Train\n",
 " start = timeit.default_timer()\n",
 " params_state, buffer_state, opt_state, step_rngs, env_states, env_timesteps = learn(params_state, buffer_state, opt_state, step_rngs, env_states, env_timesteps)\n",
-" params_state = jax.tree_map(lambda x: x.block_until_ready(), params_state) # wait for params to be ready so time is accurate.\n",
+" params_state = jax.tree.map(lambda x: x.block_until_ready(), params_state) # wait for params to be ready so time is accurate.\n",
 " total_time += timeit.default_timer() - start\n",
 " # Eval\n",
 " rng, eval_rng = jax.random.split(rng, num=2)\n",
18 changes: 9 additions & 9 deletions examples/anakin_ppo_example.ipynb
@@ -168,7 +168,7 @@
 " step_rngs = random.split(outer_rng, rollout_len)\n",
 " (env_state, env_timestep, params), rollout = lax.scan(step_fn, (env_state, env_timestep, params), step_rngs) \n",
 "\n",
-" rollout = jax.tree_map(lambda x: jnp.expand_dims(x,0), rollout)\n",
+" rollout = jax.tree.map(lambda x: jnp.expand_dims(x,0), rollout)\n",
 " \n",
 " return rollout, env_state, env_timestep\n",
 " \n",
@@ -200,7 +200,7 @@
 " buffer_state = buffer_fn.add(buffer_state, data_rollout) \n",
 " buffer_state, batch = buffer_fn.sample(buffer_state) \n",
 " # We get rid of the batch dimension here\n",
-" batch = jax.tree_map(lambda x: jnp.squeeze(x, 0), batch.experience) \n",
+" batch = jax.tree.map(lambda x: jnp.squeeze(x, 0), batch.experience) \n",
 " \n",
 " def epoch_update(carry, _):\n",
 " \"\"\"Updates the parameters of the agent.\"\"\"\n",
@@ -289,9 +289,9 @@
 "def broadcast_to_device_shape(cores_count, num_envs, params, opt_state, buffer_state, rng):\n",
 " \"\"\"Broadcasts the parameters to the shape of the device.\"\"\"\n",
 " broadcast = lambda x: jnp.broadcast_to(x, (cores_count, num_envs) + x.shape)\n",
-" params = jax.tree_map(broadcast, params) # broadcast to cores and batch.\n",
-" opt_state = jax.tree_map(broadcast, opt_state) # broadcast to cores and batch\n",
-" buffer_state = jax.tree_map(broadcast, buffer_state) # broadcast to cores and batch\n",
+" params = jax.tree.map(broadcast, params) # broadcast to cores and batch.\n",
+" opt_state = jax.tree.map(broadcast, opt_state) # broadcast to cores and batch\n",
+" buffer_state = jax.tree.map(broadcast, buffer_state) # broadcast to cores and batch\n",
 " params_state = Params(online=params, update_count=jnp.zeros(shape=(cores_count, num_envs)))\n",
 " rng, step_rngs = get_rng_keys(cores_count, num_envs, rng)\n",
 " return params_state, opt_state, buffer_state, step_rngs, rng"
@@ -393,7 +393,7 @@
 " \"\"\"Evaluates the agent on multiple episodes.\"\"\"\n",
 "\n",
 " rngs = random.split(rng, NUM_EVAL_EPISODES)\n",
-" params = jax.tree_map(lambda x: x[0][0], params)\n",
+" params = jax.tree.map(lambda x: x[0][0], params)\n",
 " _, tot_r = jax.lax.scan(eval_one_episode, params, rngs)\n",
 " return tot_r.mean()"
 ]
@@ -431,8 +431,8 @@
 "rng, *env_rngs = jax.random.split(rng, cores_count * NUM_ENVS + 1)\n",
 "env_states, env_timesteps = jax.vmap(env.reset)(jnp.stack(env_rngs)) # init envs.\n",
 "reshape = lambda x: x.reshape((cores_count, NUM_ENVS) + x.shape[1:])\n",
-"env_states = jax.tree_map(reshape, env_states) # add dimension to pmap over.\n",
-"env_timesteps = jax.tree_map(reshape, env_timesteps) # add dimension to pmap over.\n",
+"env_states = jax.tree.map(reshape, env_states) # add dimension to pmap over.\n",
+"env_timesteps = jax.tree.map(reshape, env_timesteps) # add dimension to pmap over.\n",
 "params_state, opt_state, buffer_state, step_rngs, rng = broadcast_to_device_shape(cores_count, NUM_ENVS, params, opt_state, buffer_state, rng)\n",
 "\n",
 "\n",
@@ -445,7 +445,7 @@
 " # Train\n",
 " start = timeit.default_timer()\n",
 " params_state, buffer_state, opt_state, step_rngs, env_states, env_timesteps = learn(params_state, buffer_state, opt_state, step_rngs, env_states, env_timesteps)\n",
-" params_state = jax.tree_map(lambda x: x.block_until_ready(), params_state) # wait for params to be ready so time is accurate.\n",
+" params_state = jax.tree.map(lambda x: x.block_until_ready(), params_state) # wait for params to be ready so time is accurate.\n",
 " total_time += timeit.default_timer() - start\n",
 " # Eval\n",
 " rng, eval_rng = jax.random.split(rng, num=2)\n",
14 changes: 7 additions & 7 deletions examples/anakin_prioritised_dqn_example.ipynb
@@ -387,9 +387,9 @@
 "):\n",
 " \"\"\"Broadcasts parameters to device shape.\"\"\"\n",
 " broadcast = lambda x: jnp.broadcast_to(x, (cores_count, num_envs) + x.shape)\n",
-" params = jax.tree_map(broadcast, params) # broadcast to cores and batch.\n",
-" opt_state = jax.tree_map(broadcast, opt_state) # broadcast to cores and batch\n",
-" buffer_state = jax.tree_map(broadcast, buffer_state) # broadcast to cores and batch\n",
+" params = jax.tree.map(broadcast, params) # broadcast to cores and batch.\n",
+" opt_state = jax.tree.map(broadcast, opt_state) # broadcast to cores and batch\n",
+" buffer_state = jax.tree.map(broadcast, buffer_state) # broadcast to cores and batch\n",
 " params_state = Params(\n",
 " online=params,\n",
 " target=params,\n",
@@ -521,7 +521,7 @@
 "def eval(params, rng):\n",
 " \"\"\"Evaluates multiple episodes.\"\"\"\n",
 " rngs = random.split(rng, NUM_EVAL_EPISODES)\n",
-" params = jax.tree_map(lambda x: x[0][0], params)\n",
+" params = jax.tree.map(lambda x: x[0][0], params)\n",
 " _, tot_r = jax.lax.scan(eval_one_episode, params, rngs)\n",
 " return tot_r.mean()"
 ]
@@ -580,8 +580,8 @@
 "rng, *env_rngs = jax.random.split(rng, cores_count * NUM_ENVS + 1)\n",
 "env_states, env_timesteps = jax.vmap(env.reset)(jnp.stack(env_rngs)) # init envs.\n",
 "reshape = lambda x: x.reshape((cores_count, NUM_ENVS) + x.shape[1:])\n",
-"env_states = jax.tree_map(reshape, env_states) # add dimension to pmap over.\n",
-"env_timesteps = jax.tree_map(reshape, env_timesteps) # add dimension to pmap over.\n",
+"env_states = jax.tree.map(reshape, env_states) # add dimension to pmap over.\n",
+"env_timesteps = jax.tree.map(reshape, env_timesteps) # add dimension to pmap over.\n",
 "params_state, opt_state, buffer_state, step_rngs, rng = broadcast_to_device_shape(\n",
 " cores_count, NUM_ENVS, params, opt_state, buffer_state, rng\n",
 ")\n",
@@ -608,7 +608,7 @@
 " # Train\n",
 " start = timeit.default_timer()\n",
 " params_state, buffer_state, opt_state, step_rngs, env_states, env_timesteps = learn(params_state, buffer_state, opt_state, step_rngs, env_states, env_timesteps)\n",
-" params_state = jax.tree_map(lambda x: x.block_until_ready(), params_state) # wait for params to be ready so time is accurate.\n",
+" params_state = jax.tree.map(lambda x: x.block_until_ready(), params_state) # wait for params to be ready so time is accurate.\n",
 " total_time += timeit.default_timer() - start\n",
 " # Eval\n",
 " rng, eval_rng = jax.random.split(rng, num=2)\n",
6 changes: 3 additions & 3 deletions examples/quickstart_flat_buffer.ipynb
@@ -168,7 +168,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"fake_batch = jax.tree_map(lambda x: jnp.stack([x + i for i in range(add_batch_size)]), fake_timestep)\n",
+"fake_batch = jax.tree.map(lambda x: jnp.stack([x + i for i in range(add_batch_size)]), fake_timestep)\n",
 "state = buffer.add(state, fake_batch)\n",
 "assert not buffer.can_sample(state) # Buffer is not ready to sample\n",
 "state = buffer.add(state, fake_batch)\n",
@@ -287,15 +287,15 @@
 "source": [
 "# Define a function to create a fake batch of data\n",
 "def get_fake_batch(fake_timestep: chex.ArrayTree, batch_size) -> chex.ArrayTree:\n",
-" return jax.tree_map(lambda x: jnp.stack([x + i for i in range(batch_size)]), fake_timestep)\n",
+" return jax.tree.map(lambda x: jnp.stack([x + i for i in range(batch_size)]), fake_timestep)\n",
 "\n",
 "add_batch_size = 8\n",
 "\n",
 "# Re-instantiate the buffer\n",
 "buffer = fbx.make_flat_buffer(max_length, min_length, sample_batch_size, add_sequences, add_batch_size)\n",
 "\n",
 "# Initialize the buffer's state with a \"device\" dimension\n",
-"fake_timestep_per_device = jax.tree_map(\n",
+"fake_timestep_per_device = jax.tree.map(\n",
 " lambda x: jnp.stack([x + i for i in range(DEVICE_COUNT_MOCK)]), fake_timestep\n",
 ")\n",
 "state = jax.pmap(buffer.init)(fake_timestep_per_device)\n",
2 changes: 1 addition & 1 deletion examples/quickstart_prioritised_flat_buffer.ipynb
@@ -154,7 +154,7 @@
 "# The add function expects batches of experience - we create a fake batch by stacking\n",
 "# timesteps.\n",
 "# New samples to the buffer have their priority set to the maximum priority within the buffer. \n",
-"fake_batch = jax.tree_map(lambda x: jnp.stack([x + i for i in range(add_batch_size)]),\n",
+"fake_batch = jax.tree.map(lambda x: jnp.stack([x + i for i in range(add_batch_size)]),\n",
 " fake_timestep) \n",
 "state = buffer.add(state, fake_batch)\n",
 "assert buffer.can_sample(state) == False # After one batch the buffer is not yet full.\n",
2 changes: 1 addition & 1 deletion examples/quickstart_trajectory_buffer.ipynb
@@ -210,7 +210,7 @@
 "# The add function expects batches of trajectories.\n",
 "# Thus, we create a fake batch of trajectories by broadcasting the `fake_timestep`.\n",
 "broadcast_fn = lambda x: jnp.broadcast_to(x, (add_batch_size, add_sequence_length, *x.shape))\n",
-"fake_batch_sequence = jax.tree_map(broadcast_fn, fake_timestep)\n",
+"fake_batch_sequence = jax.tree.map(broadcast_fn, fake_timestep)\n",
 "state = buffer.add(state, fake_batch_sequence)\n",
 "assert buffer.can_sample(state) == False # After one batch the buffer is not yet full.\n",
 "state = buffer.add(state, fake_batch_sequence)\n",
2 changes: 1 addition & 1 deletion flashbax/buffers/conftest.py
@@ -49,6 +49,6 @@ def fake_transition() -> chex.ArrayTree:
 
 def get_fake_batch(fake_transition: chex.ArrayTree, batch_size) -> chex.ArrayTree:
     """Create a fake batch with differing values for each transition."""
-    return jax.tree_map(
+    return jax.tree.map(
         lambda x: jnp.stack([x + i for i in range(batch_size)]), fake_transition
     )
8 changes: 4 additions & 4 deletions flashbax/buffers/flat_buffer_test.py
@@ -109,7 +109,7 @@ def test_add_batch_size_none(
     max_length: int,
 ):
     # create a fake batch and ensure there is no batch dimension
-    fake_batch = jax.tree_map(
+    fake_batch = jax.tree.map(
         lambda x: jnp.squeeze(x, 0), get_fake_batch(fake_transition, 1)
     )
 
@@ -149,7 +149,7 @@ def test_add_sequences(
 ):
     add_sequence_size = 5
     # create a fake sequence and ensure there is no batch dimension
-    fake_batch = jax.tree_map(
+    fake_batch = jax.tree.map(
         lambda x: x.repeat(add_sequence_size, axis=0),
         get_fake_batch(fake_transition, 1),
     )
@@ -184,7 +184,7 @@ def test_add_sequences_and_batches(
 ):
     add_sequence_size = 5
     # create a fake batch and sequence
-    fake_batch = jax.tree_map(
+    fake_batch = jax.tree.map(
         lambda x: x[:, jnp.newaxis].repeat(add_sequence_size, axis=1),
         get_fake_batch(fake_transition, add_batch_size),
     )
@@ -228,7 +228,7 @@ def test_flat_replay_buffer_does_not_smoke(
     )
 
     # Initialise the buffer's state.
-    fake_transition_per_device = jax.tree_map(
+    fake_transition_per_device = jax.tree.map(
         lambda x: jnp.stack([x + i for i in range(_DEVICE_COUNT_MOCK)]), fake_transition
     )
     state = jax.pmap(buffer.init)(fake_transition_per_device)
4 changes: 2 additions & 2 deletions flashbax/buffers/item_buffer.py
@@ -90,7 +90,7 @@ def add_fn(
     ) -> TrajectoryBufferState[Experience]:
         """Flattens a batch to add items along single time axis."""
         batch_size, seq_len = utils.get_tree_shape_prefix(batch, n_axes=2)
-        flattened_batch = jax.tree_map(
+        flattened_batch = jax.tree.map(
             lambda x: x.reshape((1, batch_size * seq_len, *x.shape[2:])), batch
         )
         return buffer.add(state, flattened_batch)
@@ -111,7 +111,7 @@ def sample_fn(
     ) -> TrajectoryBufferSample[Experience]:
         """Samples a batch of items from the buffer."""
         sampled_batch = buffer.sample(state, rng_key).experience
-        sampled_batch = jax.tree_map(lambda x: x.squeeze(axis=1), sampled_batch)
+        sampled_batch = jax.tree.map(lambda x: x.squeeze(axis=1), sampled_batch)
         return TrajectoryBufferSample(experience=sampled_batch)
 
     return buffer.replace(add=add_fn, sample=sample_fn)  # type: ignore
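The `squeeze(axis=1)` in `sample_fn` above reflects how the item buffer is built: it wraps a trajectory buffer whose sampled sequences have length one, so every sampled batch carries a singleton time axis that must be dropped. A minimal sketch of that shape handling (the shapes here are hypothetical, not taken from the library's tests):

```python
import jax
import jax.numpy as jnp

# Experience sampled from the underlying trajectory buffer has shape
# (batch, time=1, *item_shape); squeezing the singleton time axis yields items.
sampled = {"obs": jnp.zeros((32, 1, 4)), "reward": jnp.zeros((32, 1))}
items = jax.tree.map(lambda x: x.squeeze(axis=1), sampled)
assert items["obs"].shape == (32, 4)
```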
8 changes: 4 additions & 4 deletions flashbax/buffers/item_buffer_test.py
@@ -110,7 +110,7 @@ def test_add_batch_size_none(
     max_length: int,
 ):
     # create a fake batch and ensure there is no batch dimension
-    fake_batch = jax.tree_map(
+    fake_batch = jax.tree.map(
         lambda x: jnp.squeeze(x, 0), get_fake_batch(fake_transition, 1)
     )
 
@@ -148,7 +148,7 @@ def test_add_sequences(
 ):
     add_sequence_size = 5
     # create a fake sequence and ensure there is no batch dimension
-    fake_batch = jax.tree_map(
+    fake_batch = jax.tree.map(
         lambda x: x.repeat(add_sequence_size, axis=0),
         get_fake_batch(fake_transition, 1),
     )
@@ -183,7 +183,7 @@ def test_add_sequences_and_batches(
 ):
     add_sequence_size = 5
     # create a fake batch and sequence
-    fake_batch = jax.tree_map(
+    fake_batch = jax.tree.map(
         lambda x: x[:, jnp.newaxis].repeat(add_sequence_size, axis=1),
         get_fake_batch(fake_transition, add_batch_size),
     )
@@ -227,7 +227,7 @@ def test_item_replay_buffer_does_not_smoke(
     )
 
     # Initialise the buffer's state.
-    fake_transition_per_device = jax.tree_map(
+    fake_transition_per_device = jax.tree.map(
         lambda x: jnp.stack([x + i for i in range(_DEVICE_COUNT_MOCK)]), fake_transition
     )
     state = jax.pmap(buffer.init)(fake_transition_per_device)
8 changes: 4 additions & 4 deletions flashbax/buffers/prioritised_flat_buffer_test.py
@@ -187,7 +187,7 @@ def test_prioritised_flat_buffer_does_not_smoke(
     )
 
     # Initialise the buffer's state.
-    fake_transition_per_device = jax.tree_map(
+    fake_transition_per_device = jax.tree.map(
         lambda x: jnp.stack([x + i for i in range(_DEVICE_COUNT_MOCK)]), fake_transition
     )
     state = jax.pmap(buffer.init)(fake_transition_per_device)
@@ -230,7 +230,7 @@ def test_add_batch_size_none(
     priority_exponent: float,
 ):
     # create a fake batch and ensure there is no batch dimension
-    fake_batch = jax.tree_map(
+    fake_batch = jax.tree.map(
         lambda x: jnp.squeeze(x, 0), get_fake_batch(fake_transition, 1)
     )
 
@@ -278,7 +278,7 @@ def test_add_sequences(
 ):
     add_sequence_size = 5
     # create a fake sequence and ensure there is no batch dimension
-    fake_batch = jax.tree_map(
+    fake_batch = jax.tree.map(
         lambda x: x.repeat(add_sequence_size, axis=0),
         get_fake_batch(fake_transition, 1),
     )
@@ -321,7 +321,7 @@ def test_add_sequences_and_batches(
 ):
     add_sequence_size = 5
     # create a fake batch and sequence
-    fake_batch = jax.tree_map(
+    fake_batch = jax.tree.map(
         lambda x: x[:, jnp.newaxis].repeat(add_sequence_size, axis=1),
         get_fake_batch(fake_transition, add_batch_size),
     )
4 changes: 2 additions & 2 deletions flashbax/buffers/prioritised_item_buffer.py
@@ -83,7 +83,7 @@ def add_fn(
     ) -> PrioritisedTrajectoryBufferState[Experience]:
         """Flattens a batch to add items along single time axis."""
         batch_size, seq_len = utils.get_tree_shape_prefix(batch, n_axes=2)
-        flattened_batch = jax.tree_map(
+        flattened_batch = jax.tree.map(
             lambda x: x.reshape((1, batch_size * seq_len, *x.shape[2:])), batch
         )
         return buffer.add(state, flattened_batch)
@@ -107,7 +107,7 @@ def sample_fn(
         priorities = sampled_batch.priorities
         indices = sampled_batch.indices
         sampled_batch = sampled_batch.experience
-        sampled_batch = jax.tree_map(lambda x: x.squeeze(axis=1), sampled_batch)
+        sampled_batch = jax.tree.map(lambda x: x.squeeze(axis=1), sampled_batch)
         return PrioritisedTrajectoryBufferSample(
             experience=sampled_batch, indices=indices, priorities=priorities
         )
8 changes: 4 additions & 4 deletions flashbax/buffers/prioritised_item_buffer_test.py
@@ -160,7 +160,7 @@ def test_add_batch_size_none(
     max_length: int,
 ):
     # create a fake batch and ensure there is no batch dimension
-    fake_batch = jax.tree_map(
+    fake_batch = jax.tree.map(
         lambda x: jnp.squeeze(x, 0), get_fake_batch(fake_transition, 1)
     )
 
@@ -201,7 +201,7 @@ def test_add_sequences(
 ):
     add_sequence_size = 5
     # create a fake sequence and ensure there is no batch dimension
-    fake_batch = jax.tree_map(
+    fake_batch = jax.tree.map(
         lambda x: x.repeat(add_sequence_size, axis=0),
         get_fake_batch(fake_transition, 1),
     )
@@ -236,7 +236,7 @@ def test_add_sequences_and_batches(
 ):
     add_sequence_size = 5
     # create a fake batch and sequence
-    fake_batch = jax.tree_map(
+    fake_batch = jax.tree.map(
         lambda x: x[:, jnp.newaxis].repeat(add_sequence_size, axis=1),
         get_fake_batch(fake_transition, add_batch_size),
     )
@@ -281,7 +281,7 @@ def test_item_replay_buffer_does_not_smoke(
     )
 
     # Initialise the buffer's state.
-    fake_transition_per_device = jax.tree_map(
+    fake_transition_per_device = jax.tree.map(
         lambda x: jnp.stack([x + i for i in range(_DEVICE_COUNT_MOCK)]), fake_transition
     )
     state = jax.pmap(buffer.init)(fake_transition_per_device)