Skip to content

Commit

Permalink
feat: init n_step buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
EdanToledo committed Feb 21, 2024
1 parent 53df15b commit 1b9d0b0
Show file tree
Hide file tree
Showing 19 changed files with 327 additions and 75 deletions.
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ buffer = fbx.make_trajectory_buffer(...)
# OR
buffer = fbx.make_prioritised_trajectory_buffer(...)
# OR
buffer = fbx.make_flat_buffer(...)
buffer = fbx.make_n_step_buffer(...)
# OR
buffer = fbx.make_prioritised_flat_buffer(...)
buffer = fbx.make_prioritised_n_step_buffer(...)
# OR
buffer = fbx.make_item_buffer(...)
# OR
Expand All @@ -99,9 +99,9 @@ import jax
import jax.numpy as jnp
import flashbax as fbx

# Instantiate the flat buffer NamedTuple using `make_flat_buffer` using a simple configuration.
# Instantiate the flat buffer NamedTuple with `make_n_step_buffer`, using a simple configuration.
# The returned `buffer` is simply a container for the pure functions needed for using a flat buffer.
buffer = fbx.make_flat_buffer(max_length=32, min_length=2, sample_batch_size=1)
buffer = fbx.make_n_step_buffer(max_length=32, min_length=2, sample_batch_size=1)

# Initialise the buffer's state.
fake_timestep = {"obs": jnp.array([0, 0]), "reward": jnp.array(0.0)}
Expand Down Expand Up @@ -133,9 +133,9 @@ We provide the following Colab examples for a more advanced tutorial on how to u

| Colab Notebook | Description |
|----------------|-------------|
| [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/flashbax/blob/main/examples/quickstart_flat_buffer.ipynb) | Flat Buffer Quickstart|
| [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/flashbax/blob/main/examples/quickstart_n_step_buffer.ipynb) | Flat Buffer Quickstart|
| [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/flashbax/blob/main/examples/quickstart_trajectory_buffer.ipynb) | Trajectory Buffer Quickstart|
| [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/flashbax/blob/main/examples/quickstart_prioritised_flat_buffer.ipynb) | Prioritised Flat Buffer Quickstart|
| [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/flashbax/blob/main/examples/quickstart_prioritised_n_step_buffer.ipynb) | Prioritised Flat Buffer Quickstart|
| [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/flashbax/blob/main/examples/anakin_dqn_example.ipynb) | Anakin DQN |
| [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/flashbax/blob/main/examples/anakin_prioritised_dqn_example.ipynb) | Anakin Prioritised DQN |
| [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/flashbax/blob/main/examples/anakin_ppo_example.ipynb) | Anakin PPO |
Expand Down
4 changes: 2 additions & 2 deletions docs/api/flat_buffer.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
::: flashbax.buffers.flat_buffer
::: flashbax.buffers.n_step_buffer
options:
members:
- make_flat_buffer
- make_n_step_buffer
4 changes: 2 additions & 2 deletions docs/api/prioritised_flat_buffer.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
::: flashbax.buffers.prioritised_flat_buffer
::: flashbax.buffers.prioritised_n_step_buffer
options:
members:
- make_prioritised_flat_buffer
- make_prioritised_n_step_buffer
2 changes: 1 addition & 1 deletion examples/anakin_dqn_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@
" dummy_obs = jnp.expand_dims(obs, 0) # dummy for net init.\n",
" params = network.init(rng_p, dummy_obs) # initialise params.\n",
" opt_state = optim.init(params) # initialise optimiser stats.\n",
" buffer_fn = fbx.make_flat_buffer(\n",
" buffer_fn = fbx.make_n_step_buffer(\n",
" max_length=buffer_size,\n",
" min_length=batch_size,\n",
" sample_batch_size=batch_size,\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/anakin_prioritised_dqn_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@
" dummy_obs = jnp.expand_dims(obs, 0) # dummy for net init.\n",
" params = network.init(rng_p, dummy_obs) # initialise params.\n",
" opt_state = optim.init(params) # initialise optimiser stats.\n",
" buffer_fn = fbx.make_prioritised_flat_buffer(\n",
" buffer_fn = fbx.make_prioritised_n_step_buffer(\n",
" max_length=buffer_size,\n",
" min_length=batch_size,\n",
" sample_batch_size=batch_size,\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/gym_dqn_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1270,7 +1270,7 @@
" # We specify whether we will be adding batches of transitions or individual transitions. \n",
" # If we are using a vectorised environment (n_env > 1), we will add batches of transitions and \n",
" # specify the add batch size as the number of environments\n",
" buffer = fbx.make_flat_buffer(\n",
" buffer = fbx.make_n_step_buffer(\n",
" max_length=buffer_size,\n",
" min_length=sample_batch_size,\n",
" sample_batch_size=sample_batch_size,\n",
Expand Down
4 changes: 2 additions & 2 deletions examples/quickstart_flat_buffer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@
" # If adding data in batches, what is the batch size that is being added each time?\n",
"\n",
"# Instantiate the flat buffer, which is a Dataclass of pure functions.\n",
"buffer = fbx.make_flat_buffer(max_length, min_length, sample_batch_size, add_sequences, add_batch_size)"
"buffer = fbx.make_n_step_buffer(max_length, min_length, sample_batch_size, add_sequences, add_batch_size)"
]
},
{
Expand Down Expand Up @@ -292,7 +292,7 @@
"add_batch_size = 8\n",
"\n",
"# Re-instantiate the buffer\n",
"buffer = fbx.make_flat_buffer(max_length, min_length, sample_batch_size, add_sequences, add_batch_size)\n",
"buffer = fbx.make_n_step_buffer(max_length, min_length, sample_batch_size, add_sequences, add_batch_size)\n",
"\n",
"# Initialize the buffer's state with a \"device\" dimension\n",
"fake_timestep_per_device = jax.tree_map(\n",
Expand Down
6 changes: 3 additions & 3 deletions examples/quickstart_prioritised_flat_buffer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,15 @@
"- **sample**: sample a batch from the buffer's state with probability proportional to the samples priority\n",
"- **set_priorities**: update the priorities of specific experience within the buffer state\n",
"\n",
"below we will go through how each of these can be used. In the below code we use these functions without `jax.pmap`, however they can be easily adapted for this. To see how this can be done we refer to the `examples/quickstart_flat_buffer` notebook and the `test_prioritised_buffer_does_not_smoke` function notebook in `flashbax.buffers.prioritised_buffer_test.py`. "
"Below we will go through how each of these can be used. In the below code we use these functions without `jax.pmap`, however they can be easily adapted for this. To see how this can be done we refer to the `examples/quickstart_n_step_buffer` notebook and the `test_prioritised_buffer_does_not_smoke` function in `flashbax.buffers.prioritised_buffer_test.py`. "
]
},
{
"cell_type": "markdown",
"id": "3d5a48a7",
"metadata": {},
"source": [
"Firstly, we provide the function `make_prioritised_flat_buffer` which returns an instance of the `PrioritisedTrajectoryBuffer` with wrapped sample and add functionality. This is a `Dataclass` containing the aforementioned `init`, `add`, `can_sample`, `sample` and `set_prioritised` pure functions."
"Firstly, we provide the function `make_prioritised_n_step_buffer` which returns an instance of the `PrioritisedTrajectoryBuffer` with wrapped sample and add functionality. This is a `Dataclass` containing the aforementioned `init`, `add`, `can_sample`, `sample` and `set_priorities` pure functions."
]
},
{
Expand Down Expand Up @@ -115,7 +115,7 @@
" # If adding data in batches, what is the batch size that is being added each time?\n",
"\n",
"# Instantiate the prioritised buffer, which is a NamedTuple of pure functions.\n",
"buffer = fbx.make_prioritised_flat_buffer(\n",
"buffer = fbx.make_prioritised_n_step_buffer(\n",
" max_length, min_length, sample_batch_size, add_sequences, add_batch_size, priority_exponent\n",
")"
]
Expand Down
2 changes: 1 addition & 1 deletion examples/quickstart_trajectory_buffer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
"3. `can_sample`: Check if the buffer is ready to be sampled.\n",
"4. `sample`: Sample a batch from the buffer.\n",
"\n",
"Below we will go through how each of these can be used. We note the buffer is compatible with `jax.pmap` - we show how to use flashbax buffers with `jax.pmap` in this `examples.quickstart_flat_buffer.py` tutorial. "
"Below we will go through how each of these can be used. We note the buffer is compatible with `jax.pmap` - we show how to use flashbax buffers with `jax.pmap` in this `examples.quickstart_n_step_buffer.py` tutorial. "
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion examples/vault_demonstration.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
"\n",
"tx = FbxTransition(obs=jnp.zeros(shape=(2,)))\n",
"\n",
"buffer = fbx.make_flat_buffer(\n",
"buffer = fbx.make_n_step_buffer(\n",
" max_length=5,\n",
" min_length=1,\n",
" sample_batch_size=1,\n",
Expand Down
8 changes: 3 additions & 5 deletions flashbax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,14 @@
# limitations under the License.


from flashbax.buffers import (
flat_buffer,
from flashbax.buffers import ( # make_prioritised_n_step_buffer,; prioritised_n_step_buffer,
item_buffer,
make_flat_buffer,
make_item_buffer,
make_prioritised_flat_buffer,
make_n_step_buffer,
make_prioritised_trajectory_buffer,
make_trajectory_buffer,
make_trajectory_queue,
prioritised_flat_buffer,
n_step_buffer,
prioritised_trajectory_buffer,
trajectory_buffer,
trajectory_queue,
Expand Down
5 changes: 3 additions & 2 deletions flashbax/buffers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from flashbax.buffers.flat_buffer import make_flat_buffer
from flashbax.buffers.item_buffer import make_item_buffer
from flashbax.buffers.prioritised_flat_buffer import make_prioritised_flat_buffer
from flashbax.buffers.n_step_buffer import make_n_step_buffer

# from flashbax.buffers.prioritised_n_step_buffer import make_prioritised_n_step_buffer
from flashbax.buffers.prioritised_trajectory_buffer import (
make_prioritised_trajectory_buffer,
)
Expand Down
2 changes: 1 addition & 1 deletion flashbax/buffers/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def fake_transition() -> chex.ArrayTree:
return {
"obs": jnp.array((5, 4), dtype=jnp.float32),
"action": jnp.ones((2,), dtype=jnp.int32),
"reward": jnp.zeros((), dtype=jnp.float16),
"reward": jnp.zeros((), dtype=jnp.float32),
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
# limitations under the License.

import warnings
from typing import TYPE_CHECKING, Generic, Optional
from typing import TYPE_CHECKING, Callable, Dict, Generic, Optional

import chex
from chex import PRNGKey
from typing_extensions import NamedTuple

Expand Down Expand Up @@ -63,32 +64,47 @@ def validate_max_length_add_batch_size(max_length: int, add_batch_size: int):
)


def validate_flat_buffer_args(
def validate_n_step(n_step: int, max_length: int) -> None:
    """Validates that `n_step` can be used with a buffer of size `max_length`.
    Args:
        n_step (int): The number of steps spanned by each sampled transition pair.
        max_length (int): The maximum length of the buffer.
    Raises:
        ValueError: If `n_step` is smaller than 1, or if `n_step` is not strictly
            smaller than `max_length`.
    """
    # The buffer samples sequences of length n_step + 1, so anything below 1
    # cannot describe a (first, second) transition pair. This also catches a
    # boolean passed positionally by mistake (False compares as 0).
    if n_step < 1:
        raise ValueError(f"n_step must be at least 1. It is currently {n_step}")

    if n_step >= max_length:
        raise ValueError(
            "n_step must be less than max_length. It is currently\n"
            f"{n_step} >= {max_length}"
        )


def validate_n_step_buffer_args(
    max_length: int,
    min_length: int,
    sample_batch_size: int,
    add_batch_size: int,
    n_step: int,
):
    """Validates the arguments for the n-step buffer.

    Runs the individual argument checks one after another, in a fixed order;
    the first failing check determines which error the caller sees.
    `validate_n_step` raises a ValueError on a bad `n_step`; the other helpers
    are presumably raising validators as well (their bodies are not shown
    here - confirm against their definitions).

    Args:
        max_length (int): The maximum length of the buffer.
        min_length (int): The minimum length of the buffer.
        sample_batch_size (int): The batch size of the samples.
        add_batch_size (int): The batch size of the data added on each call.
        n_step (int): The number of steps to use for the n-step buffer.
    """

    validate_sample_batch_size(sample_batch_size, max_length)
    validate_min_length(min_length, add_batch_size, max_length)
    validate_max_length_add_batch_size(max_length, add_batch_size)
    validate_n_step(n_step, max_length)


def create_flat_buffer(
def create_n_step_buffer(
max_length: int,
min_length: int,
sample_batch_size: int,
add_sequences: bool,
add_batch_size: Optional[int],
n_step: int,
n_step_functional_map: Optional[
Dict[str, Callable[[chex.Array], chex.Array]]
] = None,
) -> TrajectoryBuffer:
"""Creates a trajectory buffer that acts as a flat buffer.
"""Creates a trajectory buffer that acts as an n-step buffer.
Args:
max_length (int): The maximum length of the buffer.
min_length (int): The minimum length of the buffer.
sample_batch_size (int): The batch size of the samples.
n_step (int): The number of steps to use for the n-step buffer.
add_sequences (Optional[bool], optional): Whether data is being added in sequences
to the buffer. If False, single transitions are being added each time add
is called. Defaults to False.
Expand All @@ -106,11 +122,12 @@ def create_flat_buffer(
else:
add_batches = True

validate_flat_buffer_args(
validate_n_step_buffer_args(
max_length=max_length,
min_length=min_length,
sample_batch_size=sample_batch_size,
add_batch_size=add_batch_size,
n_step=n_step,
)

with warnings.catch_warnings():
Expand All @@ -127,7 +144,7 @@ def create_flat_buffer(
min_length_time_axis=min_length // add_batch_size + 1,
add_batch_size=add_batch_size,
sample_batch_size=sample_batch_size,
sample_sequence_length=2,
sample_sequence_length=n_step + 1,
period=1,
max_size=max_length,
)
Expand All @@ -147,41 +164,73 @@ def create_flat_buffer(

def sample_fn(state: TrajectoryBufferState, rng_key: PRNGKey) -> TransitionSample:
    """Samples a batch of n-step transition pairs from the buffer.

    A batch of sequences is drawn from the underlying trajectory buffer. If an
    `n_step_functional_map` was supplied, each listed attribute of the sampled
    experience is transformed leaf-by-leaf with its function before the first
    and last timestep of every sequence are extracted.

    Args:
        state: The current state of the trajectory buffer.
        rng_key: The random key used for sampling.
    Returns:
        A `TransitionSample` whose experience is an `ExperiencePair` holding the
        first (`x[:, 0]`) and last (`x[:, -1]`) item of every sampled sequence.
    """
    sampled_n_step_sequence_item = buffer.sample(state, rng_key).experience

    if n_step_functional_map is not None:
        # Look the converter up with `get` instead of `pop`: popping would mutate
        # the mapping owned by the caller, so any later call (or re-trace) of this
        # function would silently fall back to `dict`.
        item_type_dict_mapper = n_step_functional_map.get("dict_mapper", dict)
        experience_type = type(sampled_n_step_sequence_item)
        sampled_n_step_sequence_dict = item_type_dict_mapper(
            sampled_n_step_sequence_item
        )
        for key, fun in n_step_functional_map.items():
            # "dict_mapper" is a reserved configuration key, not a data attribute.
            if key == "dict_mapper":
                continue
            sampled_n_step_sequence_dict[key] = jax.tree_util.tree_map(
                fun, sampled_n_step_sequence_dict[key]
            )
        # Rebuild the original container type from the transformed attributes.
        sampled_n_step_sequence_item = experience_type(**sampled_n_step_sequence_dict)

    # Axis 0 is the batch axis and axis 1 the sequence axis of each sampled leaf.
    first = jax.tree_util.tree_map(lambda x: x[:, 0], sampled_n_step_sequence_item)
    second = jax.tree_util.tree_map(lambda x: x[:, -1], sampled_n_step_sequence_item)
    return TransitionSample(experience=ExperiencePair(first=first, second=second))

return buffer.replace(add=add_fn, sample=sample_fn) # type: ignore


def make_flat_buffer(
def make_n_step_buffer(
    max_length: int,
    min_length: int,
    sample_batch_size: int,
    n_step: int = 1,
    add_sequences: bool = False,
    add_batch_size: Optional[int] = None,
    n_step_functional_map: Optional[
        Dict[str, Callable[[chex.Array], chex.Array]]
    ] = None,
) -> TrajectoryBuffer:
    """Makes a trajectory buffer act as an n-step buffer.

    Note that `n_step` comes before `add_sequences` in the signature, so
    `add_sequences` and `add_batch_size` are best passed by keyword.

    Args:
        max_length (int): The maximum length of the buffer.
        min_length (int): The minimum length of the buffer.
        sample_batch_size (int): The batch size of the samples.
        n_step (int): The number of steps to use for the n-step buffer.
        add_sequences (Optional[bool], optional): Whether data is being added in sequences
            to the buffer. If False, single transitions are being added each time add
            is called. Defaults to False.
        add_batch_size (Optional[int], optional): If adding data in batches, what is the
            batch size that is being added each time. If None, single transitions or single
            sequences are being added each time add is called. Defaults to None.
        n_step_functional_map (Optional[Dict[str, Callable[[chex.Array], chex.Array]]], optional):
            A dictionary of functions to apply to the n-step transitions. The keys are the
            names of the data attributes and the values are the functions to apply. Each
            function takes in a sequence of n+1 data items and returns a sequence of n+1 data
            items. However, only the first and last value of the sequence are used and placed
            into first and second respectively. If the dictionary has a key 'dict_mapper', its
            value is a function that takes in the batch data and returns a dictionary; it is
            used to map the data type to a dictionary. For dictionaries and chex.dataclasses
            this does not need to be set, but for flax structs and named tuples it needs to
            be set accordingly. Only data attributes that are in the dictionary are modified.
    Returns:
        The buffer."""

    n_step_buffer = create_n_step_buffer(
        max_length=max_length,
        min_length=min_length,
        sample_batch_size=sample_batch_size,
        n_step=n_step,
        add_sequences=add_sequences,
        add_batch_size=add_batch_size,
        n_step_functional_map=n_step_functional_map,
    )
    return n_step_buffer
Loading

0 comments on commit 1b9d0b0

Please sign in to comment.