-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #5 from instadeepai/feat/vault
feat: vault
- Loading branch information
Showing
6 changed files
with
1,056 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,354 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Vault demonstration" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%%capture\n", | ||
"try:\n", | ||
" import flashbax as fbx\n", | ||
"except ModuleNotFoundError:\n", | ||
" print('installing flashbax')\n", | ||
" %pip install -q flashbax\n", | ||
" import flashbax as fbx" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import jax\n", | ||
"from typing import NamedTuple\n", | ||
"import jax.numpy as jnp\n", | ||
"from flashbax.vault import Vault\n", | ||
"import flashbax as fbx\n", | ||
"from chex import Array" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"We create a simple timestep structure, with a corresponding flat buffer." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n" | ||
] | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/home/claude/flashbax/flashbax/buffers/trajectory_buffer.py:473: UserWarning: Setting max_size dynamically sets the `max_length_time_axis` to be `max_size`//`add_batch_size = 5`.This allows one to control exactly how many timesteps are stored in the buffer.Note that this overrides the `max_length_time_axis` argument.\n", | ||
" warnings.warn(\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"class FbxTransition(NamedTuple):\n", | ||
" obs: Array\n", | ||
"\n", | ||
"tx = FbxTransition(obs=jnp.zeros(shape=(2,)))\n", | ||
"\n", | ||
"buffer = fbx.make_flat_buffer(\n", | ||
" max_length=5,\n", | ||
" min_length=1,\n", | ||
" sample_batch_size=1,\n", | ||
")\n", | ||
"buffer_state = buffer.init(tx)\n", | ||
"buffer_add = jax.jit(buffer.add, donate_argnums=0)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"The shape of this buffer is $(B = 1, T = 5, E = 2)$, meaning the buffer can hold 5 timesteps, where each observation is of shape $(2,)$." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"(1, 5, 2)" | ||
] | ||
}, | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"buffer_state.experience.obs.shape" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"We create the vault, based on the buffer's experience structure." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"New vault created at /tmp/demo/20240205140817\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"v = Vault(\n", | ||
" vault_name=\"demo\",\n", | ||
" experience_structure=buffer_state.experience,\n", | ||
" rel_dir=\"/tmp\"\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"We now add 10 timesteps to the buffer, and write that buffer to the vault. We inspect the buffer and vault state after each timestep." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"------------------\n", | ||
"Buffer state:\n", | ||
"[[[0. 0.]\n", | ||
" [0. 0.]\n", | ||
" [0. 0.]\n", | ||
" [0. 0.]\n", | ||
" [0. 0.]]]\n", | ||
"\n", | ||
"Vault state:\n", | ||
"[]\n", | ||
"------------------\n", | ||
"------------------\n", | ||
"Buffer state:\n", | ||
"[[[1. 1.]\n", | ||
" [0. 0.]\n", | ||
" [0. 0.]\n", | ||
" [0. 0.]\n", | ||
" [0. 0.]]]\n", | ||
"\n", | ||
"Vault state:\n", | ||
"[[[1. 1.]]]\n", | ||
"------------------\n", | ||
"------------------\n", | ||
"Buffer state:\n", | ||
"[[[1. 1.]\n", | ||
" [2. 2.]\n", | ||
" [0. 0.]\n", | ||
" [0. 0.]\n", | ||
" [0. 0.]]]\n", | ||
"\n", | ||
"Vault state:\n", | ||
"[[[1. 1.]\n", | ||
" [2. 2.]]]\n", | ||
"------------------\n", | ||
"------------------\n", | ||
"Buffer state:\n", | ||
"[[[1. 1.]\n", | ||
" [2. 2.]\n", | ||
" [3. 3.]\n", | ||
" [0. 0.]\n", | ||
" [0. 0.]]]\n", | ||
"\n", | ||
"Vault state:\n", | ||
"[[[1. 1.]\n", | ||
" [2. 2.]\n", | ||
" [3. 3.]]]\n", | ||
"------------------\n", | ||
"------------------\n", | ||
"Buffer state:\n", | ||
"[[[1. 1.]\n", | ||
" [2. 2.]\n", | ||
" [3. 3.]\n", | ||
" [4. 4.]\n", | ||
" [0. 0.]]]\n", | ||
"\n", | ||
"Vault state:\n", | ||
"[[[1. 1.]\n", | ||
" [2. 2.]\n", | ||
" [3. 3.]\n", | ||
" [4. 4.]]]\n", | ||
"------------------\n", | ||
"------------------\n", | ||
"Buffer state:\n", | ||
"[[[1. 1.]\n", | ||
" [2. 2.]\n", | ||
" [3. 3.]\n", | ||
" [4. 4.]\n", | ||
" [5. 5.]]]\n", | ||
"\n", | ||
"Vault state:\n", | ||
"[[[1. 1.]\n", | ||
" [2. 2.]\n", | ||
" [3. 3.]\n", | ||
" [4. 4.]\n", | ||
" [5. 5.]]]\n", | ||
"------------------\n", | ||
"------------------\n", | ||
"Buffer state:\n", | ||
"[[[6. 6.]\n", | ||
" [2. 2.]\n", | ||
" [3. 3.]\n", | ||
" [4. 4.]\n", | ||
" [5. 5.]]]\n", | ||
"\n", | ||
"Vault state:\n", | ||
"[[[1. 1.]\n", | ||
" [2. 2.]\n", | ||
" [3. 3.]\n", | ||
" [4. 4.]\n", | ||
" [5. 5.]\n", | ||
" [6. 6.]]]\n", | ||
"------------------\n", | ||
"------------------\n", | ||
"Buffer state:\n", | ||
"[[[6. 6.]\n", | ||
" [7. 7.]\n", | ||
" [3. 3.]\n", | ||
" [4. 4.]\n", | ||
" [5. 5.]]]\n", | ||
"\n", | ||
"Vault state:\n", | ||
"[[[1. 1.]\n", | ||
" [2. 2.]\n", | ||
" [3. 3.]\n", | ||
" [4. 4.]\n", | ||
" [5. 5.]\n", | ||
" [6. 6.]\n", | ||
" [7. 7.]]]\n", | ||
"------------------\n", | ||
"------------------\n", | ||
"Buffer state:\n", | ||
"[[[6. 6.]\n", | ||
" [7. 7.]\n", | ||
" [8. 8.]\n", | ||
" [4. 4.]\n", | ||
" [5. 5.]]]\n", | ||
"\n", | ||
"Vault state:\n", | ||
"[[[1. 1.]\n", | ||
" [2. 2.]\n", | ||
" [3. 3.]\n", | ||
" [4. 4.]\n", | ||
" [5. 5.]\n", | ||
" [6. 6.]\n", | ||
" [7. 7.]\n", | ||
" [8. 8.]]]\n", | ||
"------------------\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"for i in range(1, 10):\n", | ||
" print('------------------')\n", | ||
" print(\"Buffer state:\")\n", | ||
" print(buffer_state.experience.obs)\n", | ||
" print()\n", | ||
"\n", | ||
" v.write(buffer_state)\n", | ||
"\n", | ||
" print(\"Vault state:\")\n", | ||
" print(v.read().experience.obs)\n", | ||
" print('------------------')\n", | ||
"\n", | ||
" buffer_state = buffer_add(\n", | ||
" buffer_state,\n", | ||
" FbxTransition(obs=i * jnp.ones(1))\n", | ||
" )" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Notice that when the buffer (implemented as a ring buffer) wraps around, the vault continues storing the data:\n", | ||
"```\n", | ||
"Buffer state:\n", | ||
"[[[6. 6.]\n", | ||
" [2. 2.]\n", | ||
" [3. 3.]\n", | ||
" [4. 4.]\n", | ||
" [5. 5.]]]\n", | ||
"\n", | ||
"Vault state:\n", | ||
"[[[1. 1.]\n", | ||
" [2. 2.]\n", | ||
" [3. 3.]\n", | ||
" [4. 4.]\n", | ||
" [5. 5.]\n", | ||
" [6. 6.]]]\n", | ||
"```\n", | ||
"\n", | ||
"Note: the vault must be given the buffer state at least every `max_steps` number of timesteps (i.e. before stale data is overwritten in the ring buffer)." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "flashbax", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.18" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.