Skip to content

Commit

Permalink
Merge pull request #4 from instadeepai/feat/item_buffer
Browse files Browse the repository at this point in the history
Feat/item buffer
  • Loading branch information
callumtilbury authored Dec 14, 2023
2 parents b413ecb + 195a6df commit 1d45803
Show file tree
Hide file tree
Showing 6 changed files with 408 additions and 4 deletions.
10 changes: 7 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ Flashbax provides an implementation of various different types of buffers, such

🗄️ **Flat Buffer**: The Flat Buffer, akin to the transition buffer used in algorithms like DQN, is a core component. It employs a sequence of 2 (i.e. $s_t$, $s_{t+1}$), with a period of 1 for comprehensive transition pair consideration.

🧺 **Item Buffer**: The Item Buffer is a simple buffer that stores individual items. It is useful for storing data that is independent of each other, such as (observation, action, reward, discount, next_observation) tuples, or entire episodes.

🛤️ **Trajectory Buffer**: The Trajectory Buffer facilitates the sampling of multi-step trajectories, catering to algorithms utilising recurrent networks like R2D2 (Kapturowski et al., [2018](https://www.deepmind.com/publications/recurrent-experience-replay-in-distributed-reinforcement-learning)).

🏅 **Prioritised Buffers**: Both Flat and Trajectory Buffers can be prioritised, enabling sampling based on user-defined priorities. The prioritisation mechanism aligns with the principles outlined in the PER paper (Schaul et al, [2016](https://arxiv.org/abs/1511.05952)).
Expand Down Expand Up @@ -76,6 +78,8 @@ buffer = fbx.make_flat_buffer(...)
# OR
buffer = fbx.make_prioritised_flat_buffer(...)
# OR
buffer = fbx.make_item_buffer(...)
# OR
buffer = fbx.make_trajectory_queue(...)

# Initialise
Expand Down Expand Up @@ -155,10 +159,10 @@ Flashbax uses a trajectory buffer as the foundation for all buffer types. This m
When adding batches of data, the buffer is created in a block-like structure. This means that the effective buffer size is dependent on the size of the batch dimension. The trajectory buffer allows a user to specify the add batch dimension and the max length of the time axis. This will create a block structure of (batch, time) allowing the maximum number of timesteps that can be in storage to be batch*time. For ease of use, we provide the max size argument that allows a user to set their total desired number of timesteps and we calculate the max length of the time axis dependent on the add batch dimension that is provided. Due to this, it is important to note that when using the max size argument, the max length of the time axis will be equal to max size // add batch size which will round down thereby reducing the effective buffer size. This means one might think they are increasing the buffer size by a certain amount but in actuality there is no increase. Therefore, to avoid this, we recommend one of two things: Use the max length time axis argument explicitly or increase the max size argument in multiples of the add batch size.

### Handling Episode Truncation
Another critical aspect is episode truncation. When truncating episodes and adding data to the buffer, it's vital to ensure that you set a done flag or a 'discount' value appropriately. Neglecting to do so can introduce challenges into your algorithm's implementation and behavior. As stated previously, it is expected that the algorithm handles these cases appropriately.
Another critical aspect is episode truncation. When truncating episodes and adding data to the buffer, it's vital to ensure that you set a done flag or a 'discount' value appropriately. Neglecting to do so can introduce challenges into your algorithm's implementation and behavior. As stated previously, it is expected that the algorithm handles these cases appropriately. It can be difficult handling truncation when using the flat buffer or trajectory buffer as the algorithm must handle the case of the final timestep in an episode being followed by the first timestep in the next episode. Sacrificing memory efficiency for ease of use, the item buffer can be used to store transitions or entire trajectories independently. This means that the algorithm does not need to handle the case of the final timestep in an episode being followed by the first timestep in the next episode as only the data that is explicitly inserted can be sampled.

### Independent Data Usage
For situations where you intend to utilise buffers with data that lack sequential information, you can leverage a trajectory buffer with specific configurations. By setting a sequence dimension of 1 and a period of 1, each item will be treated as independent. However, when working with independent transition items like (observation, action, reward, discount, next_observation), be mindful that this approach will result in duplicate observations within the buffer, leading to unnecessary memory consumption. It is important to note that the implementation of the flat buffer will be slower than utilising a trajectory buffer in this way due to the inherent speed issues that arise with data indexing on hardware accelerators; however, this trade-off is done to enhance memory efficiency. If speed is largely preferred over memory efficiency then use the trajectory buffer with sequence 1 and period 1 storing full transition data items.
For situations where you intend to utilise buffers with data that lack sequential information, you can leverage the item buffer which is a wrapped trajectory buffer with specific configurations. By setting a sequence dimension of 1 and a period of 1, each item will be treated as independent. However, when working with independent transition items like (observation, action, reward, discount, next_observation), be mindful that this approach will result in duplicate observations within the buffer, leading to unnecessary memory consumption. It is important to note that the implementation of the flat buffer will be slower than utilising the item buffer in this way due to the inherent speed issues that arise with data indexing on hardware accelerators; however, this trade-off is done to enhance memory efficiency. If speed is largely preferred over memory efficiency then use the trajectory buffer with sequence 1 and period 1 storing full transition data items.

### In-place Updating of Buffer State
Since buffers are generally large and occupy a significant portion of device memory, it is beneficial to perform in-place updates. To do this, it is important to specify to the top-level compiled function that you would like to perform this in-place update operation. This is indicated as follows:
Expand Down Expand Up @@ -200,7 +204,7 @@ Here we provide a series of initial benchmarks outlining the performance of the
| Sample Sequence Length | 1 |
| Sample Sequence Period | 1 |

The reason we use a sample sequence length and period of 1 is to directly compare to the other buffers. This essentially means that the trajectory buffers are being used as memory inefficent transition buffers. It is important to note that the Flat Buffer implementations use a sample sequence length of 2. Additionally, one must bear in mind that not all other buffer implementations can efficiently make use of GPUs/TPUs thus they simply run on the CPU and perform device conversions. Lastly, we explicitly make use of python loops to fairly compare to the other buffers. Speeds can be largely improved using scan operations (depending on observation size).
The reason we use a sample sequence length and period of 1 is to directly compare to the other buffers, this means the speeds for the trajectory buffer are comparable to the speeds of the item buffer as the item buffer is simply a wrapped trajectory buffer with this configuration. This essentially means that the trajectory buffers are being used as memory inefficent transition buffers. It is important to note that the Flat Buffer implementations use a sample sequence length of 2. Additionally, one must bear in mind that not all other buffer implementations can efficiently make use of GPUs/TPUs thus they simply run on the CPU and perform device conversions. Lastly, we explicitly make use of python loops to fairly compare to the other buffers. Speeds can be largely improved using scan operations (depending on observation size).

### CPU Speeds

Expand Down
2 changes: 1 addition & 1 deletion examples/quickstart_flat_buffer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"source": [
"# Quickstart: Using the Flat Buffer with Flashbax\n",
"\n",
"This guide demonstrates how to use the Flat Buffer, the simplest of the Flashbax buffers, for experience replay in reinforcement learning tasks. The Flat Buffer operates by saving all experience data in a first-in-first-out (FIFO) queue and returns batches of uniformly sampled experience from it. This is akin to the buffer used in the [original DQN paper](https://arxiv.org/abs/1312.5602). "
"This guide demonstrates how to use the Flat Buffer, for experience replay in reinforcement learning tasks. The Flat Buffer operates by saving all experience data in a first-in-first-out (FIFO) queue and returns batches of uniformly sampled experience from it. This is akin to the buffer used in the [original DQN paper](https://arxiv.org/abs/1312.5602). "
]
},
{
Expand Down
2 changes: 2 additions & 0 deletions flashbax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

from flashbax.buffers import (
flat_buffer,
item_buffer,
make_flat_buffer,
make_item_buffer,
make_prioritised_flat_buffer,
make_prioritised_trajectory_buffer,
make_trajectory_buffer,
Expand Down
1 change: 1 addition & 0 deletions flashbax/buffers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from flashbax.buffers.flat_buffer import make_flat_buffer
from flashbax.buffers.item_buffer import make_item_buffer
from flashbax.buffers.prioritised_flat_buffer import make_prioritised_flat_buffer
from flashbax.buffers.prioritised_trajectory_buffer import (
make_prioritised_trajectory_buffer,
Expand Down
149 changes: 149 additions & 0 deletions flashbax/buffers/item_buffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright 2023 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
from chex import PRNGKey

from flashbax import utils
from flashbax.buffers.trajectory_buffer import (
Experience,
TrajectoryBuffer,
TrajectoryBufferSample,
TrajectoryBufferState,
make_trajectory_buffer,
)
from flashbax.utils import add_dim_to_args


def validate_sample_batch_size(sample_batch_size: int, max_length: int):
if sample_batch_size > max_length:
raise ValueError("sample_batch_size must be less than or equal to max_length")


def validate_min_length(min_length: int, max_length: int):
if min_length > max_length:
raise ValueError("min_length used cannot be larger than max_length.")


def validate_item_buffer_args(
max_length: int,
min_length: int,
sample_batch_size: int,
):
"""Validates the arguments for the item buffer."""

validate_sample_batch_size(sample_batch_size, max_length)
validate_min_length(min_length, max_length)


def create_item_buffer(
max_length: int,
min_length: int,
sample_batch_size: int,
add_sequences: bool,
add_batches: bool,
) -> TrajectoryBuffer:
"""Creates a trajectory buffer that acts as an independent item buffer.
Args:
max_length (int): The maximum length of the buffer.
min_length (int): The minimum length of the buffer.
sample_batch_size (int): The batch size of the samples.
add_sequences (Optional[bool], optional): Whether data is being added in sequences
to the buffer. If False, single items are being added each time add
is called. Defaults to False.
add_batches: (Optional[bool], optional): Whether adding data in batches to the buffer.
If False, single items (or single sequences of items) are being added each time add
is called. Defaults to False.
Returns:
The buffer."""

validate_item_buffer_args(
max_length=max_length,
min_length=min_length,
sample_batch_size=sample_batch_size,
)

buffer = make_trajectory_buffer(
max_length_time_axis=max_length,
min_length_time_axis=min_length,
add_batch_size=1,
sample_batch_size=sample_batch_size,
sample_sequence_length=1,
period=1,
)

def add_fn(
state: TrajectoryBufferState, batch: Experience
) -> TrajectoryBufferState[Experience]:
"""Flattens a batch to add items along single time axis."""
batch_size, seq_len = utils.get_tree_shape_prefix(batch, n_axes=2)
flattened_batch = jax.tree_map(
lambda x: x.reshape((1, batch_size * seq_len, *x.shape[2:])), batch
)
return buffer.add(state, flattened_batch)

if not add_batches:
add_fn = add_dim_to_args(
add_fn, axis=0, starting_arg_index=1, ending_arg_index=2
)

if not add_sequences:
axis = 1 - int(not add_batches) # 1 if add_batches else 0
add_fn = add_dim_to_args(
add_fn, axis=axis, starting_arg_index=1, ending_arg_index=2
)

def sample_fn(
state: TrajectoryBufferState, rng_key: PRNGKey
) -> TrajectoryBufferSample[Experience]:
"""Samples a batch of items from the buffer."""
sampled_batch = buffer.sample(state, rng_key).experience
sampled_batch = jax.tree_map(lambda x: x.squeeze(axis=1), sampled_batch)
return TrajectoryBufferSample(experience=sampled_batch)

return buffer.replace(add=add_fn, sample=sample_fn) # type: ignore


def make_item_buffer(
max_length: int,
min_length: int,
sample_batch_size: int,
add_sequences: bool = False,
add_batches: bool = False,
) -> TrajectoryBuffer:
"""Makes a trajectory buffer act as a independent item buffer.
Args:
max_length (int): The maximum length of the buffer.
min_length (int): The minimum length of the buffer.
sample_batch_size (int): The batch size of the samples.
add_sequences (Optional[bool], optional): Whether data is being added in sequences
to the buffer. If False, single items are being added each time add
is called. Defaults to False.
add_batches: (Optional[bool], optional): Whether adding data in batches to the buffer.
If False, single transitions or single sequences are being added each time add
is called. Defaults to False.
Returns:
The buffer."""

return create_item_buffer(
max_length=max_length,
min_length=min_length,
sample_batch_size=sample_batch_size,
add_sequences=add_sequences,
add_batches=add_batches,
)
Loading

0 comments on commit 1d45803

Please sign in to comment.