Skip to content

Commit

Permalink
Update flashbax/buffers/item_buffer.py
Browse files Browse the repository at this point in the history
Co-authored-by: Callum Tilbury <37700709+callumtilbury@users.noreply.github.com>
  • Loading branch information
EdanToledo and callumtilbury authored Dec 14, 2023
1 parent 0174381 commit fc77281
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion flashbax/buffers/item_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def add_fn(
def sample_fn(
state: TrajectoryBufferState, rng_key: PRNGKey
) -> TrajectoryBufferSample[Experience]:
"""Samples a batch of transitions from the buffer."""
"""Samples a batch of items from the buffer."""
sampled_batch = buffer.sample(state, rng_key).experience
sampled_batch = jax.tree_map(lambda x: x.squeeze(axis=1), sampled_batch)
return TrajectoryBufferSample(experience=sampled_batch)
Expand Down

0 comments on commit fc77281

Please sign in to comment.