From e49b159d0f54d403f0030ffa5bcb6381bdc56a5f Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Fri, 8 Dec 2023 12:15:26 +0000 Subject: [PATCH 01/10] feat: add item buffer wrapper for trajectory buffer --- flashbax/buffers/item_buffer.py | 149 ++++++++++++++++ flashbax/buffers/item_buffer_test.py | 248 +++++++++++++++++++++++++++ 2 files changed, 397 insertions(+) create mode 100644 flashbax/buffers/item_buffer.py create mode 100644 flashbax/buffers/item_buffer_test.py diff --git a/flashbax/buffers/item_buffer.py b/flashbax/buffers/item_buffer.py new file mode 100644 index 0000000..3afc48b --- /dev/null +++ b/flashbax/buffers/item_buffer.py @@ -0,0 +1,149 @@ +# Copyright 2023 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax +from chex import PRNGKey + +from flashbax import utils +from flashbax.buffers.trajectory_buffer import ( + Experience, + TrajectoryBuffer, + TrajectoryBufferSample, + TrajectoryBufferState, + make_trajectory_buffer, +) +from flashbax.utils import add_dim_to_args + + +def validate_sample_batch_size(sample_batch_size: int, max_length: int): + if sample_batch_size > max_length: + raise ValueError("sample_batch_size must be less than or equal to max_length") + + +def validate_min_length(min_length: int, max_length: int): + if min_length > max_length: + raise ValueError("min_length used is too large for the buffer size.") + + +def validate_item_buffer_args( + max_length: int, + min_length: int, + sample_batch_size: int, +): + """Validates the arguments for the flat buffer.""" + + validate_sample_batch_size(sample_batch_size, max_length) + validate_min_length(min_length, max_length) + + +def create_item_buffer( + max_length: int, + min_length: int, + sample_batch_size: int, + add_sequences: bool, + add_batches: bool, +) -> TrajectoryBuffer: + """Creates a trajectory buffer that acts as a independent item buffer. + + Args: + max_length (int): The maximum length of the buffer. + min_length (int): The minimum length of the buffer. + sample_batch_size (int): The batch size of the samples. + add_sequences (Optional[bool], optional): Whether data is being added in sequences + to the buffer. If False, single transitions are being added each time add + is called. Defaults to False. + add_batches: (Optional[bool], optional): Whether adding data in batches to the buffer. + If False, single transitions or single sequences are being added each time add + is called. Defaults to False. + + Returns: + The buffer.""" + + validate_item_buffer_args( + max_length=max_length, + min_length=min_length, + sample_batch_size=sample_batch_size, + ) + + buffer = make_trajectory_buffer( + max_length_time_axis=max_length, + min_length_time_axis=min_length, + add_batch_size=1, + sample_batch_size=sample_batch_size, + sample_sequence_length=1, + period=1, + ) + + def add_fn( + state: TrajectoryBufferState, batch: Experience + ) -> TrajectoryBufferState[Experience]: + """Flattens a batch to add items along single time axis.""" + batch_size, seq_len = utils.get_tree_shape_prefix(batch, n_axes=2) + flattened_batch = jax.tree_map( + lambda x: x.reshape((1, batch_size * seq_len, *x.shape[2:])), batch + ) + return buffer.add(state, flattened_batch) + + if not add_batches: + add_fn = add_dim_to_args( + add_fn, axis=0, starting_arg_index=1, ending_arg_index=2 + ) + + if not add_sequences: + axis = 1 - int(not add_batches) # 1 if add_batches else 0 + add_fn = add_dim_to_args( + add_fn, axis=axis, starting_arg_index=1, ending_arg_index=2 + ) + + def sample_fn( + state: TrajectoryBufferState, rng_key: PRNGKey + ) -> TrajectoryBufferSample[Experience]: + """Samples a batch of transitions from the buffer.""" + sampled_batch = buffer.sample(state, rng_key).experience + sampled_batch = jax.tree_map(lambda x: x.squeeze(axis=1), sampled_batch) + return TrajectoryBufferSample(experience=sampled_batch) + + return buffer.replace(add=add_fn, sample=sample_fn) # type: ignore + + +def make_item_buffer( + max_length: int, + min_length: int, + sample_batch_size: int, + add_sequences: bool = False, + add_batches: bool = False, +) -> TrajectoryBuffer: + """Makes a trajectory buffer act as a independent item buffer. + + Args: + max_length (int): The maximum length of the buffer. + min_length (int): The minimum length of the buffer. + sample_batch_size (int): The batch size of the samples. + add_sequences (Optional[bool], optional): Whether data is being added in sequences + to the buffer. If False, single transitions are being added each time add + is called. Defaults to False. + add_batches: (Optional[bool], optional): Whether adding data in batches to the buffer. + If False, single transitions or single sequences are being added each time add + is called. Defaults to False. + + Returns: + The buffer.""" + + return create_item_buffer( + max_length=max_length, + min_length=min_length, + sample_batch_size=sample_batch_size, + add_sequences=add_sequences, + add_batches=add_batches, + ) diff --git a/flashbax/buffers/item_buffer_test.py b/flashbax/buffers/item_buffer_test.py new file mode 100644 index 0000000..a9633bf --- /dev/null +++ b/flashbax/buffers/item_buffer_test.py @@ -0,0 +1,248 @@ +# Copyright 2023 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from copy import deepcopy + +import chex +import jax +import jax.numpy as jnp +import pytest + +from flashbax.buffers import item_buffer +from flashbax.buffers.conftest import get_fake_batch +from flashbax.conftest import _DEVICE_COUNT_MOCK + + +def test_add_and_can_sample( + fake_transition: chex.ArrayTree, + min_length: int, + max_length: int, + add_batch_size: int, +) -> None: + """Check the `add` function by filling the buffer all + the way to the max_length and checking that it produces the expected behaviour . + """ + fake_batch = get_fake_batch(fake_transition, add_batch_size) + + buffer = item_buffer.make_item_buffer(max_length, min_length, 4, False, True) + state = buffer.init(fake_transition) + + init_state = deepcopy(state) # Save for later checks. + + n_batches_to_fill = max_length // add_batch_size + n_batches_to_sample = min_length // add_batch_size + + for i in range(n_batches_to_fill): + assert not state.is_full + state = buffer.add(state, fake_batch) + assert state.current_index == (((i + 1) * add_batch_size) % max_length) + + # Check that the `can_sample` function behavior is correct. + is_ready_to_sample = buffer.can_sample(state) + if i < (n_batches_to_sample): + assert not is_ready_to_sample + else: + assert is_ready_to_sample + + assert state.is_full + + # Check that the transitions have been updated. + with pytest.raises(AssertionError): + chex.assert_trees_all_close(state.experience, init_state.experience) + + +def test_sample( + fake_transition: chex.ArrayTree, + max_length: int, + rng_key: chex.PRNGKey, +) -> None: + """Test the random sampling from the buffer.""" + + min_length = 40 + sample_batch_size = 20 + rng_key1, rng_key2 = jax.random.split(rng_key) + + # Fill buffer to the point that we can sample + fake_batch = get_fake_batch(fake_transition, sample_batch_size) + + buffer = item_buffer.make_item_buffer( + max_length, min_length, sample_batch_size, False, True + ) + state = buffer.init(fake_transition) + + # Add two batches of items. + state = buffer.add(state, fake_batch) + assert not buffer.can_sample(state) + state = buffer.add(state, fake_batch) + assert buffer.can_sample(state) + + # Sample from the buffer with different keys and check it gives us different batches. + batch1 = buffer.sample(state, rng_key1).experience + batch2 = buffer.sample(state, rng_key2).experience + + # Check that the transitions have been updated. + with pytest.raises(AssertionError): + chex.assert_trees_all_close(batch1, batch2) + + # Check dtypes are correct. + chex.assert_trees_all_equal_dtypes( + fake_transition, + batch1, + batch2, + ) + + +def test_add_batch_size_none( + fake_transition: chex.ArrayTree, + min_length: int, + max_length: int, +): + # create a fake batch and ensure there is no batch dimension + fake_batch = jax.tree_map( + lambda x: jnp.squeeze(x, 0), get_fake_batch(fake_transition, 1) + ) + + buffer = item_buffer.make_item_buffer(max_length, min_length, 4, False, False) + state = buffer.init(fake_transition) + + init_state = deepcopy(state) # Save for later checks. + + n_batches_to_fill = max_length + n_batches_to_sample = min_length + + for i in range(n_batches_to_fill): + assert not state.is_full + state = buffer.add(state, fake_batch) + assert state.current_index == ((i + 1) % (max_length)) + + # Check that the `can_sample` function behavior is correct. + is_ready_to_sample = buffer.can_sample(state) + if i < (n_batches_to_sample - 1): + assert not is_ready_to_sample + else: + assert is_ready_to_sample + + assert state.is_full + + # Check that the transitions have been updated. + with pytest.raises(AssertionError): + chex.assert_trees_all_close(state.experience, init_state.experience) + + +def test_add_sequences( + fake_transition: chex.ArrayTree, + min_length: int, + max_length: int, +): + add_sequence_size = 5 + # create a fake sequence and ensure there is no batch dimension + fake_batch = jax.tree_map( + lambda x: x.repeat(add_sequence_size, axis=0), + get_fake_batch(fake_transition, 1), + ) + assert fake_batch["obs"].shape[0] == add_sequence_size + + buffer = item_buffer.make_item_buffer( + max_length, min_length, 4, add_sequences=True, add_batches=False + ) + state = buffer.init(fake_transition) + + init_state = deepcopy(state) # Save for later checks. + + n_sequences_to_fill = (max_length // add_sequence_size) + 1 + + for i in range(n_sequences_to_fill): + assert not state.is_full + state = buffer.add(state, fake_batch) + assert state.current_index == (((i + 1) * add_sequence_size) % (max_length)) + + assert state.is_full + + # Check that the transitions have been updated. + with pytest.raises(AssertionError): + chex.assert_trees_all_close(state.experience, init_state.experience) + + +def test_add_sequences_and_batches( + fake_transition: chex.ArrayTree, + min_length: int, + max_length: int, + add_batch_size: int, +): + add_sequence_size = 5 + # create a fake batch and sequence + fake_batch = jax.tree_map( + lambda x: x[:, jnp.newaxis].repeat(add_sequence_size, axis=1), + get_fake_batch(fake_transition, add_batch_size), + ) + assert fake_batch["obs"].shape[:2] == (add_batch_size, add_sequence_size) + + buffer = item_buffer.make_item_buffer( + max_length, min_length, 4, add_sequences=True, add_batches=True + ) + state = buffer.init(fake_transition) + + init_state = deepcopy(state) # Save for later checks. + + n_sequences_to_fill = (max_length // (add_batch_size * add_sequence_size)) + 1 + + for i in range(n_sequences_to_fill): + assert not state.is_full + state = buffer.add(state, fake_batch) + assert state.current_index == ( + ((i + 1) * add_sequence_size * add_batch_size) % max_length + ) + + assert state.is_full + + # Check that the transitions have been updated. + with pytest.raises(AssertionError): + chex.assert_trees_all_close(state.experience, init_state.experience) + + +def test_item_replay_buffer_does_not_smoke( + fake_transition: chex.ArrayTree, + min_length: int, + max_length: int, + rng_key: chex.PRNGKey, + sample_batch_size: int, +): + """Create the itemBuffer NamedTuple, and check that it is pmap-able and does not smoke.""" + + add_batch_size = int(min_length + 5) + buffer = item_buffer.make_item_buffer( + max_length, min_length, sample_batch_size, False, True + ) + + # Initialise the buffer's state. + fake_transition_per_device = jax.tree_map( + lambda x: jnp.stack([x + i for i in range(_DEVICE_COUNT_MOCK)]), fake_transition + ) + state = jax.pmap(buffer.init)(fake_transition_per_device) + + # Now fill the buffer above its minimum length. + + fake_batch = jax.pmap(get_fake_batch, static_broadcasted_argnums=1)( + fake_transition_per_device, add_batch_size + ) + # Add two items thereby giving a single transition. + state = jax.pmap(buffer.add)(state, fake_batch) + state = jax.pmap(buffer.add)(state, fake_batch) + assert buffer.can_sample(state).all() + + # Sample from the buffer. + rng_key_per_device = jax.random.split(rng_key, _DEVICE_COUNT_MOCK) + batch = jax.pmap(buffer.sample)(state, rng_key_per_device) + chex.assert_tree_shape_prefix(batch, (_DEVICE_COUNT_MOCK, sample_batch_size)) From 6b58e1b08cc49d9d18d41ed74cf1a3693f351b51 Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Fri, 8 Dec 2023 12:17:12 +0000 Subject: [PATCH 02/10] chore: edit docstring --- flashbax/buffers/item_buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flashbax/buffers/item_buffer.py b/flashbax/buffers/item_buffer.py index 3afc48b..091e316 100644 --- a/flashbax/buffers/item_buffer.py +++ b/flashbax/buffers/item_buffer.py @@ -41,7 +41,7 @@ def validate_item_buffer_args( min_length: int, sample_batch_size: int, ): - """Validates the arguments for the flat buffer.""" + """Validates the arguments for the item buffer.""" validate_sample_batch_size(sample_batch_size, max_length) validate_min_length(min_length, max_length) From bb8d382c5228e15696f1ce31503272b55342282b Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Sat, 9 Dec 2023 14:47:11 +0200 Subject: [PATCH 03/10] chore: added item_buffer imports to init. --- flashbax/__init__.py | 2 ++ flashbax/buffers/__init__.py | 1 + 2 files changed, 3 insertions(+) diff --git a/flashbax/__init__.py b/flashbax/__init__.py index 135b2c7..f685c37 100644 --- a/flashbax/__init__.py +++ b/flashbax/__init__.py @@ -15,7 +15,9 @@ from flashbax.buffers import ( flat_buffer, + item_buffer, make_flat_buffer, + make_item_buffer, make_prioritised_flat_buffer, make_prioritised_trajectory_buffer, make_trajectory_buffer, diff --git a/flashbax/buffers/__init__.py b/flashbax/buffers/__init__.py index 15f3f71..a458d89 100644 --- a/flashbax/buffers/__init__.py +++ b/flashbax/buffers/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from flashbax.buffers.flat_buffer import make_flat_buffer +from flashbax.buffers.item_buffer import make_item_buffer from flashbax.buffers.prioritised_flat_buffer import make_prioritised_flat_buffer from flashbax.buffers.prioritised_trajectory_buffer import ( make_prioritised_trajectory_buffer, From 01743814dcec455565b5fa707a9b5551570fe6cc Mon Sep 17 00:00:00 2001 From: Edan Toledo <42650996+EdanToledo@users.noreply.github.com> Date: Thu, 14 Dec 2023 10:15:24 +0000 Subject: [PATCH 04/10] Update flashbax/buffers/item_buffer.py Co-authored-by: Callum Tilbury <37700709+callumtilbury@users.noreply.github.com> --- flashbax/buffers/item_buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flashbax/buffers/item_buffer.py b/flashbax/buffers/item_buffer.py index 091e316..35f2305 100644 --- a/flashbax/buffers/item_buffer.py +++ b/flashbax/buffers/item_buffer.py @@ -131,7 +131,7 @@ def make_item_buffer( min_length (int): The minimum length of the buffer. sample_batch_size (int): The batch size of the samples. add_sequences (Optional[bool], optional): Whether data is being added in sequences - to the buffer. If False, single transitions are being added each time add + to the buffer. If False, single items are being added each time add is called. Defaults to False. add_batches: (Optional[bool], optional): Whether adding data in batches to the buffer. If False, single transitions or single sequences are being added each time add From fc77281d1e74188a32e223ac68037040dc163c04 Mon Sep 17 00:00:00 2001 From: Edan Toledo <42650996+EdanToledo@users.noreply.github.com> Date: Thu, 14 Dec 2023 10:15:38 +0000 Subject: [PATCH 05/10] Update flashbax/buffers/item_buffer.py Co-authored-by: Callum Tilbury <37700709+callumtilbury@users.noreply.github.com> --- flashbax/buffers/item_buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flashbax/buffers/item_buffer.py b/flashbax/buffers/item_buffer.py index 35f2305..1300ca1 100644 --- a/flashbax/buffers/item_buffer.py +++ b/flashbax/buffers/item_buffer.py @@ -109,7 +109,7 @@ def add_fn( def sample_fn( state: TrajectoryBufferState, rng_key: PRNGKey ) -> TrajectoryBufferSample[Experience]: - """Samples a batch of transitions from the buffer.""" + """Samples a batch of items from the buffer.""" sampled_batch = buffer.sample(state, rng_key).experience sampled_batch = jax.tree_map(lambda x: x.squeeze(axis=1), sampled_batch) return TrajectoryBufferSample(experience=sampled_batch) From 31f2ec6552493dfd55313866393a8be0c3f96280 Mon Sep 17 00:00:00 2001 From: Edan Toledo <42650996+EdanToledo@users.noreply.github.com> Date: Thu, 14 Dec 2023 10:15:51 +0000 Subject: [PATCH 06/10] Update flashbax/buffers/item_buffer.py Co-authored-by: Callum Tilbury <37700709+callumtilbury@users.noreply.github.com> --- flashbax/buffers/item_buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flashbax/buffers/item_buffer.py b/flashbax/buffers/item_buffer.py index 1300ca1..e33614d 100644 --- a/flashbax/buffers/item_buffer.py +++ b/flashbax/buffers/item_buffer.py @@ -54,7 +54,7 @@ def create_item_buffer( add_sequences: bool, add_batches: bool, ) -> TrajectoryBuffer: - """Creates a trajectory buffer that acts as a independent item buffer. + """Creates a trajectory buffer that acts as an independent item buffer. Args: max_length (int): The maximum length of the buffer. From 3b6fd7b19c2ce55e0123aff2528f91b7b7ba4f92 Mon Sep 17 00:00:00 2001 From: Edan Toledo <42650996+EdanToledo@users.noreply.github.com> Date: Thu, 14 Dec 2023 10:15:59 +0000 Subject: [PATCH 07/10] Update flashbax/buffers/item_buffer.py Co-authored-by: Callum Tilbury <37700709+callumtilbury@users.noreply.github.com> --- flashbax/buffers/item_buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flashbax/buffers/item_buffer.py b/flashbax/buffers/item_buffer.py index e33614d..de30a15 100644 --- a/flashbax/buffers/item_buffer.py +++ b/flashbax/buffers/item_buffer.py @@ -61,7 +61,7 @@ def create_item_buffer( min_length (int): The minimum length of the buffer. sample_batch_size (int): The batch size of the samples. add_sequences (Optional[bool], optional): Whether data is being added in sequences - to the buffer. If False, single transitions are being added each time add + to the buffer. If False, single items are being added each time add is called. Defaults to False. add_batches: (Optional[bool], optional): Whether adding data in batches to the buffer. If False, single transitions or single sequences are being added each time add From 4d1f0557a02f38fa955ac209c730a30f1b4cdc6f Mon Sep 17 00:00:00 2001 From: Edan Toledo <42650996+EdanToledo@users.noreply.github.com> Date: Thu, 14 Dec 2023 10:16:15 +0000 Subject: [PATCH 08/10] Update flashbax/buffers/item_buffer.py Co-authored-by: Callum Tilbury <37700709+callumtilbury@users.noreply.github.com> --- flashbax/buffers/item_buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flashbax/buffers/item_buffer.py b/flashbax/buffers/item_buffer.py index de30a15..95cd407 100644 --- a/flashbax/buffers/item_buffer.py +++ b/flashbax/buffers/item_buffer.py @@ -33,7 +33,7 @@ def validate_sample_batch_size(sample_batch_size: int, max_length: int): def validate_min_length(min_length: int, max_length: int): if min_length > max_length: - raise ValueError("min_length used is too large for the buffer size.") + raise ValueError("min_length used cannot be larger than max_length.") def validate_item_buffer_args( From 97fe823c2f4e188139a24f4114fffa8341a16b10 Mon Sep 17 00:00:00 2001 From: Edan Toledo <42650996+EdanToledo@users.noreply.github.com> Date: Thu, 14 Dec 2023 10:16:29 +0000 Subject: [PATCH 09/10] Update flashbax/buffers/item_buffer.py Co-authored-by: Callum Tilbury <37700709+callumtilbury@users.noreply.github.com> --- flashbax/buffers/item_buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flashbax/buffers/item_buffer.py b/flashbax/buffers/item_buffer.py index 95cd407..446b6cc 100644 --- a/flashbax/buffers/item_buffer.py +++ b/flashbax/buffers/item_buffer.py @@ -64,7 +64,7 @@ def create_item_buffer( to the buffer. If False, single items are being added each time add is called. Defaults to False. add_batches: (Optional[bool], optional): Whether adding data in batches to the buffer. - If False, single transitions or single sequences are being added each time add + If False, single items (or single sequences of items) are being added each time add is called. Defaults to False. Returns: From 195a6dfb8f6d4b05594a6b2cdd49c35a508fa10f Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Thu, 14 Dec 2023 13:47:29 +0000 Subject: [PATCH 10/10] feat: add to readme about item buffer --- README.md | 10 +++++++--- examples/quickstart_flat_buffer.ipynb | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 0dd338a..e8c21ae 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,8 @@ Flashbax provides an implementation of various different types of buffers, such ๐Ÿ—„๏ธ **Flat Buffer**: The Flat Buffer, akin to the transition buffer used in algorithms like DQN, is a core component. It employs a sequence of 2 (i.e. $s_t$, $s_{t+1}$), with a period of 1 for comprehensive transition pair consideration. +๐Ÿงบ **Item Buffer**: The Item Buffer is a simple buffer that stores individual items. It is useful for storing data that is independent of each other, such as (observation, action, reward, discount, next_observation) tuples, or entire episodes. + ๐Ÿ›ค๏ธ **Trajectory Buffer**: The Trajectory Buffer facilitates the sampling of multi-step trajectories, catering to algorithms utilising recurrent networks like R2D2 (Kapturowski et al., [2018](https://www.deepmind.com/publications/recurrent-experience-replay-in-distributed-reinforcement-learning)). ๐Ÿ… **Prioritised Buffers**: Both Flat and Trajectory Buffers can be prioritised, enabling sampling based on user-defined priorities. The prioritisation mechanism aligns with the principles outlined in the PER paper (Schaul et al, [2016](https://arxiv.org/abs/1511.05952)). @@ -76,6 +78,8 @@ buffer = fbx.make_flat_buffer(...) # OR buffer = fbx.make_prioritised_flat_buffer(...) # OR +buffer = fbx.make_item_buffer(...) +# OR buffer = fbx.make_trajectory_queue(...) # Initialise @@ -155,10 +159,10 @@ Flashbax uses a trajectory buffer as the foundation for all buffer types. This m When adding batches of data, the buffer is created in a block-like structure. This means that the effective buffer size is dependent on the size of the batch dimension. The trajectory buffer allows a user to specify the add batch dimension and the max length of the time axis. This will create a block structure of (batch, time) allowing the maximum number of timesteps that can be in storage to be batch*time. For ease of use, we provide the max size argument that allows a user to set their total desired number of timesteps and we calculate the max length of the time axis dependent on the add batch dimension that is provided. Due to this, it is important to note that when using the max size argument, the max length of the time axis will be equal to max size // add batch size which will round down thereby reducing the effective buffer size. This means one might think they are increasing the buffer size by a certain amount but in actuality there is no increase. Therefore, to avoid this, we recommend one of two things: Use the max length time axis argument explicitly or increase the max size argument in multiples of the add batch size. ### Handling Episode Truncation -Another critical aspect is episode truncation. When truncating episodes and adding data to the buffer, it's vital to ensure that you set a done flag or a 'discount' value appropriately. Neglecting to do so can introduce challenges into your algorithm's implementation and behavior. As stated previously, it is expected that the algorithm handles these cases appropriately. +Another critical aspect is episode truncation. When truncating episodes and adding data to the buffer, it's vital to ensure that you set a done flag or a 'discount' value appropriately. Neglecting to do so can introduce challenges into your algorithm's implementation and behavior. As stated previously, it is expected that the algorithm handles these cases appropriately. It can be difficult handling truncation when using the flat buffer or trajectory buffer as the algorithm must handle the case of the final timestep in an episode being followed by the first timestep in the next episode. Sacrificing memory efficiency for ease of use, the item buffer can be used to store transitions or entire trajectories independently. This means that the algorithm does not need to handle the case of the final timestep in an episode being followed by the first timestep in the next episode as only the data that is explicitly inserted can be sampled. ### Independent Data Usage -For situations where you intend to utilise buffers with data that lack sequential information, you can leverage a trajectory buffer with specific configurations. By setting a sequence dimension of 1 and a period of 1, each item will be treated as independent. However, when working with independent transition items like (observation, action, reward, discount, next_observation), be mindful that this approach will result in duplicate observations within the buffer, leading to unnecessary memory consumption. It is important to note that the implementation of the flat buffer will be slower than utilising a trajectory buffer in this way due to the inherent speed issues that arise with data indexing on hardware accelerators; however, this trade-off is done to enhance memory efficiency. If speed is largely preferred over memory efficiency then use the trajectory buffer with sequence 1 and period 1 storing full transition data items. +For situations where you intend to utilise buffers with data that lack sequential information, you can leverage the item buffer which is a wrapped trajectory buffer with specific configurations. By setting a sequence dimension of 1 and a period of 1, each item will be treated as independent. However, when working with independent transition items like (observation, action, reward, discount, next_observation), be mindful that this approach will result in duplicate observations within the buffer, leading to unnecessary memory consumption. It is important to note that the implementation of the flat buffer will be slower than utilising the item buffer in this way due to the inherent speed issues that arise with data indexing on hardware accelerators; however, this trade-off is done to enhance memory efficiency. If speed is largely preferred over memory efficiency then use the trajectory buffer with sequence 1 and period 1 storing full transition data items. ### In-place Updating of Buffer State Since buffers are generally large and occupy a significant portion of device memory, it is beneficial to perform in-place updates. To do this, it is important to specify to the top-level compiled function that you would like to perform this in-place update operation. This is indicated as follows: @@ -200,7 +204,7 @@ Here we provide a series of initial benchmarks outlining the performance of the | Sample Sequence Length | 1 | | Sample Sequence Period | 1 | -The reason we use a sample sequence length and period of 1 is to directly compare to the other buffers. This essentially means that the trajectory buffers are being used as memory inefficent transition buffers. It is important to note that the Flat Buffer implementations use a sample sequence length of 2. Additionally, one must bear in mind that not all other buffer implementations can efficiently make use of GPUs/TPUs thus they simply run on the CPU and perform device conversions. Lastly, we explicitly make use of python loops to fairly compare to the other buffers. Speeds can be largely improved using scan operations (depending on observation size). +The reason we use a sample sequence length and period of 1 is to directly compare to the other buffers, this means the speeds for the trajectory buffer are comparable to the speeds of the item buffer as the item buffer is simply a wrapped trajectory buffer with this configuration. This essentially means that the trajectory buffers are being used as memory inefficent transition buffers. It is important to note that the Flat Buffer implementations use a sample sequence length of 2. Additionally, one must bear in mind that not all other buffer implementations can efficiently make use of GPUs/TPUs thus they simply run on the CPU and perform device conversions. Lastly, we explicitly make use of python loops to fairly compare to the other buffers. Speeds can be largely improved using scan operations (depending on observation size). ### CPU Speeds diff --git a/examples/quickstart_flat_buffer.ipynb b/examples/quickstart_flat_buffer.ipynb index 902f344..b47a768 100644 --- a/examples/quickstart_flat_buffer.ipynb +++ b/examples/quickstart_flat_buffer.ipynb @@ -7,7 +7,7 @@ "source": [ "# Quickstart: Using the Flat Buffer with Flashbax\n", "\n", - "This guide demonstrates how to use the Flat Buffer, the simplest of the Flashbax buffers, for experience replay in reinforcement learning tasks. The Flat Buffer operates by saving all experience data in a first-in-first-out (FIFO) queue and returns batches of uniformly sampled experience from it. This is akin to the buffer used in the [original DQN paper](https://arxiv.org/abs/1312.5602). " + "This guide demonstrates how to use the Flat Buffer, for experience replay in reinforcement learning tasks. The Flat Buffer operates by saving all experience data in a first-in-first-out (FIFO) queue and returns batches of uniformly sampled experience from it. This is akin to the buffer used in the [original DQN paper](https://arxiv.org/abs/1312.5602). " ] }, {