Merge pull request #5 from instadeepai/feat/vault
feat: vault
callumtilbury authored Feb 5, 2024
2 parents 1d45803 + 8709c65 commit 21ce0b0
Showing 6 changed files with 1,056 additions and 1 deletion.
11 changes: 11 additions & 0 deletions README.md
@@ -148,6 +148,12 @@ from CleanRL's DQN JAX example.
- 🦎 [Jumanji](https://github.com/instadeepai/jumanji/) - utilise Jumanji's JAX-based environments
like Snake for our fully jitted examples.

## Vault 💾
Vault is an efficient mechanism for saving Flashbax buffers to persistent data storage, e.g. for use in offline reinforcement learning. Consider a Flashbax buffer which has experience data of dimensionality $(B, T, *E)$, where $B$ is a batch dimension (for the sake of recording independent trajectories synchronously), $T$ is a temporal/sequential dimension, and $*E$ indicates the one or more dimensions of the experience data itself. Since large quantities of data may be generated for a given environment, Vault extends the $T$ dimension to a virtually unconstrained degree by reading and writing slices of buffers along this temporal axis. In doing so, gigantic buffer stores can reside on disk, from which sub-buffers can be loaded into RAM/VRAM for efficient offline training. Vault has been tested with the item, flat, and trajectory buffers.

For more information, see the demonstrative notebook: [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/flashbax/blob/main/examples/vault_demonstration.ipynb)
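
For a quick sense of the workflow, here is a minimal sketch adapted from the demonstration notebook (the flat-buffer configuration and transition structure are purely illustrative):

```python
from typing import NamedTuple

import jax.numpy as jnp
from chex import Array

import flashbax as fbx
from flashbax.vault import Vault


class Transition(NamedTuple):
    obs: Array


# An illustrative flat buffer holding 5 timesteps of a (2,)-shaped observation.
buffer = fbx.make_flat_buffer(max_length=5, min_length=1, sample_batch_size=1)
buffer_state = buffer.init(Transition(obs=jnp.zeros((2,))))

# Create a vault whose on-disk structure mirrors the buffer's experience pytree.
v = Vault(
    vault_name="demo",
    experience_structure=buffer_state.experience,
    rel_dir="/tmp",
)

# Write the buffer state to disk (only previously unseen data is appended),
# then read the stored history back into memory.
v.write(buffer_state)
stored = v.read()
print(stored.experience.obs.shape)
```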


## Important Considerations ⚠️

When working with Flashbax buffers, it's crucial to be mindful of certain considerations to ensure the proper functionality of your RL agent.
@@ -188,6 +194,10 @@ train_state, buffer_state = jax.jit(train, donate_argnums=(1,))(

It is important to include `donate_argnums` when calling `jax.jit` to enable JAX to perform an in-place update of the replay buffer state. Omitting `donate_argnums` would force JAX to create a copy of the state for any modifications to the replay buffer state, potentially negating all performance benefits. More information about buffer donation in JAX can be found in the [documentation](https://jax.readthedocs.io/en/latest/faq.html#buffer-donation).
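
The same pattern applies when jitting a buffer's `add` function directly, as done in the demonstration notebook. A minimal sketch (the flat-buffer configuration and timestep structure here are illustrative):

```python
import jax
import jax.numpy as jnp
import flashbax as fbx

buffer = fbx.make_flat_buffer(max_length=32, min_length=2, sample_batch_size=4)
timestep = {"obs": jnp.zeros((2,)), "reward": jnp.float32(0.0)}
buffer_state = buffer.init(timestep)

# Donating argument 0 (the buffer state) allows XLA to update the state's
# memory in place instead of allocating a fresh copy on every add.
add_fn = jax.jit(buffer.add, donate_argnums=(0,))
buffer_state = add_fn(buffer_state, timestep)
```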


### Storing Data with Vault
As mentioned [above](./README.md#vault-💾), Vault stores experience data to disk by extending the temporal axis of a Flashbax buffer state. By default, Vault handles the bookkeeping of this process: it consumes a buffer state and saves only the fresh, previously unseen data. For example, suppose we write 10 timesteps to our Flashbax buffer and then save this state to a Vault; since all of this data is fresh, all of it will be written to disk. If we then add one more timestep and save the state again, only that new timestep will be written, preventing duplication of data that has already been saved. Importantly, remember that Flashbax states are implemented as _ring buffers_, so the Vault must be updated frequently enough that unseen data in the Flashbax buffer state is not overwritten. That is, if our buffer state has a time-axis length of $\tau$, then we must save to the vault at least every $\tau - 1$ steps, lest we overwrite (and lose) unsaved data.
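
A rough sketch of this pattern is given below; the small flat buffer, dictionary-based timestep, and write cadence are illustrative assumptions, while the `Vault` calls mirror the demonstration notebook:

```python
import jax.numpy as jnp
import flashbax as fbx
from flashbax.vault import Vault

# An illustrative flat buffer with a time-axis length (tau) of 8.
buffer = fbx.make_flat_buffer(max_length=8, min_length=1, sample_batch_size=1)
buffer_state = buffer.init({"obs": jnp.zeros((2,))})

vault = Vault(
    vault_name="offline_data",
    experience_structure=buffer_state.experience,
    rel_dir="/tmp",
)

# Save at least every tau - 1 = 7 steps so that unsaved timesteps are never
# overwritten by the ring buffer; here we conservatively write every 4 steps.
save_every = 4

for step in range(1, 33):
    buffer_state = buffer.add(buffer_state, {"obs": step * jnp.ones((2,))})
    if step % save_every == 0:
        vault.write(buffer_state)  # only previously unseen data is appended

print(vault.read().experience["obs"].shape)  # full saved history, e.g. (1, 32, 2)
```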

In summary, understanding and addressing these considerations will help you navigate potential pitfalls and ensure the effectiveness of your reinforcement learning strategies while utilising Flashbax buffers.

## Benchmarks 📈
@@ -242,6 +252,7 @@ Previous benchmarks added only a single timestep at a time, we now evaluate addi

Ultimately, we see improved or comparable performance relative to the benchmarked buffers, whilst providing buffers that are fully JAX-compatible and support additional features such as batched adding and adding sequences of varying length. We do note that, because JAX uses different XLA backends for CPU, GPU, and TPU, the performance of the buffers can vary depending on the device and the specific operation being called.


## Contributing 🤝

Contributions are welcome! See our issue tracker for
354 changes: 354 additions & 0 deletions examples/vault_demonstration.ipynb
@@ -0,0 +1,354 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Vault demonstration"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"try:\n",
" import flashbax as fbx\n",
"except ModuleNotFoundError:\n",
" print('installing flashbax')\n",
" %pip install -q flashbax\n",
" import flashbax as fbx"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"from typing import NamedTuple\n",
"import jax.numpy as jnp\n",
"from flashbax.vault import Vault\n",
"import flashbax as fbx\n",
"from chex import Array"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We create a simple timestep structure, with a corresponding flat buffer."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/claude/flashbax/flashbax/buffers/trajectory_buffer.py:473: UserWarning: Setting max_size dynamically sets the `max_length_time_axis` to be `max_size`//`add_batch_size = 5`.This allows one to control exactly how many timesteps are stored in the buffer.Note that this overrides the `max_length_time_axis` argument.\n",
" warnings.warn(\n"
]
}
],
"source": [
"class FbxTransition(NamedTuple):\n",
" obs: Array\n",
"\n",
"tx = FbxTransition(obs=jnp.zeros(shape=(2,)))\n",
"\n",
"buffer = fbx.make_flat_buffer(\n",
" max_length=5,\n",
" min_length=1,\n",
" sample_batch_size=1,\n",
")\n",
"buffer_state = buffer.init(tx)\n",
"buffer_add = jax.jit(buffer.add, donate_argnums=0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The shape of this buffer is $(B = 1, T = 5, E = 2)$, meaning the buffer can hold 5 timesteps, where each observation is of shape $(2,)$."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1, 5, 2)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"buffer_state.experience.obs.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We create the vault, based on the buffer's experience structure."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"New vault created at /tmp/demo/20240205140817\n"
]
}
],
"source": [
"v = Vault(\n",
" vault_name=\"demo\",\n",
" experience_structure=buffer_state.experience,\n",
" rel_dir=\"/tmp\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We now add 10 timesteps to the buffer, and write that buffer to the vault. We inspect the buffer and vault state after each timestep."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"------------------\n",
"Buffer state:\n",
"[[[0. 0.]\n",
" [0. 0.]\n",
" [0. 0.]\n",
" [0. 0.]\n",
" [0. 0.]]]\n",
"\n",
"Vault state:\n",
"[]\n",
"------------------\n",
"------------------\n",
"Buffer state:\n",
"[[[1. 1.]\n",
" [0. 0.]\n",
" [0. 0.]\n",
" [0. 0.]\n",
" [0. 0.]]]\n",
"\n",
"Vault state:\n",
"[[[1. 1.]]]\n",
"------------------\n",
"------------------\n",
"Buffer state:\n",
"[[[1. 1.]\n",
" [2. 2.]\n",
" [0. 0.]\n",
" [0. 0.]\n",
" [0. 0.]]]\n",
"\n",
"Vault state:\n",
"[[[1. 1.]\n",
" [2. 2.]]]\n",
"------------------\n",
"------------------\n",
"Buffer state:\n",
"[[[1. 1.]\n",
" [2. 2.]\n",
" [3. 3.]\n",
" [0. 0.]\n",
" [0. 0.]]]\n",
"\n",
"Vault state:\n",
"[[[1. 1.]\n",
" [2. 2.]\n",
" [3. 3.]]]\n",
"------------------\n",
"------------------\n",
"Buffer state:\n",
"[[[1. 1.]\n",
" [2. 2.]\n",
" [3. 3.]\n",
" [4. 4.]\n",
" [0. 0.]]]\n",
"\n",
"Vault state:\n",
"[[[1. 1.]\n",
" [2. 2.]\n",
" [3. 3.]\n",
" [4. 4.]]]\n",
"------------------\n",
"------------------\n",
"Buffer state:\n",
"[[[1. 1.]\n",
" [2. 2.]\n",
" [3. 3.]\n",
" [4. 4.]\n",
" [5. 5.]]]\n",
"\n",
"Vault state:\n",
"[[[1. 1.]\n",
" [2. 2.]\n",
" [3. 3.]\n",
" [4. 4.]\n",
" [5. 5.]]]\n",
"------------------\n",
"------------------\n",
"Buffer state:\n",
"[[[6. 6.]\n",
" [2. 2.]\n",
" [3. 3.]\n",
" [4. 4.]\n",
" [5. 5.]]]\n",
"\n",
"Vault state:\n",
"[[[1. 1.]\n",
" [2. 2.]\n",
" [3. 3.]\n",
" [4. 4.]\n",
" [5. 5.]\n",
" [6. 6.]]]\n",
"------------------\n",
"------------------\n",
"Buffer state:\n",
"[[[6. 6.]\n",
" [7. 7.]\n",
" [3. 3.]\n",
" [4. 4.]\n",
" [5. 5.]]]\n",
"\n",
"Vault state:\n",
"[[[1. 1.]\n",
" [2. 2.]\n",
" [3. 3.]\n",
" [4. 4.]\n",
" [5. 5.]\n",
" [6. 6.]\n",
" [7. 7.]]]\n",
"------------------\n",
"------------------\n",
"Buffer state:\n",
"[[[6. 6.]\n",
" [7. 7.]\n",
" [8. 8.]\n",
" [4. 4.]\n",
" [5. 5.]]]\n",
"\n",
"Vault state:\n",
"[[[1. 1.]\n",
" [2. 2.]\n",
" [3. 3.]\n",
" [4. 4.]\n",
" [5. 5.]\n",
" [6. 6.]\n",
" [7. 7.]\n",
" [8. 8.]]]\n",
"------------------\n"
]
}
],
"source": [
"for i in range(1, 10):\n",
" print('------------------')\n",
" print(\"Buffer state:\")\n",
" print(buffer_state.experience.obs)\n",
" print()\n",
"\n",
" v.write(buffer_state)\n",
"\n",
" print(\"Vault state:\")\n",
" print(v.read().experience.obs)\n",
" print('------------------')\n",
"\n",
" buffer_state = buffer_add(\n",
" buffer_state,\n",
" FbxTransition(obs=i * jnp.ones(1))\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Notice that when the buffer (implemented as a ring buffer) wraps around, the vault continues storing the data:\n",
"```\n",
"Buffer state:\n",
"[[[6. 6.]\n",
" [2. 2.]\n",
" [3. 3.]\n",
" [4. 4.]\n",
" [5. 5.]]]\n",
"\n",
"Vault state:\n",
"[[[1. 1.]\n",
" [2. 2.]\n",
" [3. 3.]\n",
" [4. 4.]\n",
" [5. 5.]\n",
" [6. 6.]]]\n",
"```\n",
"\n",
"Note: the vault must be given the buffer state at least every `max_steps` number of timesteps (i.e. before stale data is overwritten in the ring buffer)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "flashbax",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 2
}