From f1506374f171a267f5bef69c5fb21f29ac43bc75 Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Sat, 9 Dec 2023 15:26:24 +0200 Subject: [PATCH 01/31] feat: first port of vault into fbx, wip. --- flashbax/__init__.py | 1 + flashbax/vault/__init__.py | 14 ++ flashbax/vault/vault.py | 304 +++++++++++++++++++++++++++++++++++++ pyproject.toml | 3 +- 4 files changed, 321 insertions(+), 1 deletion(-) create mode 100644 flashbax/vault/__init__.py create mode 100644 flashbax/vault/vault.py diff --git a/flashbax/__init__.py b/flashbax/__init__.py index 135b2c7..983627d 100644 --- a/flashbax/__init__.py +++ b/flashbax/__init__.py @@ -25,3 +25,4 @@ trajectory_buffer, trajectory_queue, ) +from flashbax.vault import Vault diff --git a/flashbax/vault/__init__.py b/flashbax/vault/__init__.py new file mode 100644 index 0000000..6da832a --- /dev/null +++ b/flashbax/vault/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from flashbax.vault.vault import Vault diff --git a/flashbax/vault/vault.py b/flashbax/vault/vault.py new file mode 100644 index 0000000..c26e4a2 --- /dev/null +++ b/flashbax/vault/vault.py @@ -0,0 +1,304 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import json +import os +from datetime import datetime +from typing import Any, Optional, Tuple + +import jax +import jax.numpy as jnp +import tensorstore as ts +from chex import Array +from etils import epath + +from flashbax.buffers.trajectory_buffer import TrajectoryBufferState +from flashbax.utils import get_tree_shape_prefix + +# CURRENT LIMITATIONS / TODO LIST +# - Anakin -> extra minibatch dim... +# - Async reading if necessary +# - Only tested with flat buffers +# - Reloading could be nicer, but doing so is tricky! + +DRIVER = "file://" +METADATA_FILE = "metadata.json" +TIME_AXIS_MAX_LENGTH = int(10e12) # Upper bound on the length of the time axis +VERSION = 0.1 + + +class Vault: + def __init__( + self, + init_fbx_state: TrajectoryBufferState, + vault_name: str, + rel_dir: str = "vaults", + vault_uid: Optional[str] = None, + metadata: Optional[dict] = None, + ) -> None: + + vault_str = vault_uid if vault_uid else datetime.now().strftime("%Y%m%d%H%M%S") + self._base_path = os.path.join(os.getcwd(), rel_dir, vault_name, vault_str) + + # We use epath for metadata + metadata_path = epath.Path(os.path.join(self._base_path, METADATA_FILE)) + + # Check if the vault exists, otherwise create the necessary dirs and files + base_path_exists = os.path.exists(self._base_path) + if base_path_exists: + self._metadata = json.loads(metadata_path.read_text()) + # Ensure minor versions match + assert (self._metadata["version"] // 1) == (VERSION // 1) + else: + # Create the necessary dirs for the vault + os.makedirs(self._base_path) + + def 
get_json_ready(obj: Any) -> Any: + """Ensure that the object is json serializable. Convert to string if not. + + Args: + obj (Any): Object to be considered + + Returns: + Any: json serializable object + """ + if not isinstance(obj, (bool, str, int, float, type(None))): + return str(obj) + else: + return obj + + metadata_json_ready = jax.tree_util.tree_map(get_json_ready, metadata) + self._metadata = { + "version": VERSION, + **(metadata_json_ready or {}), # Allow user to save extra metadata + } + metadata_path.write_text(json.dumps(self._metadata)) + + # Keep a data store for the vault index + self._vault_index_ds = ts.open( + self._get_base_spec("vault_index"), + dtype=jnp.int32, + shape=(1,), + create=not base_path_exists, + ).result() + self.vault_index = int(self._vault_index_ds.read().result()[0]) + + # Each leaf of the fbx_state.experience is a data store + self._all_ds = jax.tree_util.tree_map_with_path( + lambda path, x: self._init_leaf( + name=jax.tree_util.keystr(path), # Use the path as the name + leaf=x, + create_checkpoint=not base_path_exists, + ), + init_fbx_state.experience, + ) + + # Just store one timestep for the structure + self._fbx_sample_experience = jax.tree_map( + lambda x: x[:, 0:1, ...], + init_fbx_state.experience, + ) + self._last_received_fbx_index = 0 + + def _get_base_spec(self, name: str) -> dict: + return { + "driver": "zarr", + "kvstore": { + "driver": "ocdbt", + "base": f"{DRIVER}{self._base_path}", # TODO: does this work on other systems? 
+ "path": name, + }, + } + + def _init_leaf(self, name: str, leaf: Array, create_checkpoint: bool = False) -> ts.TensorStore: + spec = self._get_base_spec(name) + leaf_ds = ts.open( + spec, + dtype=leaf.dtype if create_checkpoint else None, + shape=( + leaf.shape[0], # Batch dim + TIME_AXIS_MAX_LENGTH, # Time dim + *leaf.shape[2:], # Experience dim + ) + if create_checkpoint + else None, + create=create_checkpoint, + ).result() + return leaf_ds + + async def _write_leaf( + self, + source_leaf: Array, + dest_leaf: ts.TensorStore, + source_interval: Tuple[int, int], + dest_start: int, + ) -> None: + dest_interval = ( + dest_start, + dest_start + (source_interval[1] - source_interval[0]), # type: ignore + ) + await dest_leaf[:, slice(*dest_interval), ...].write( + source_leaf[:, slice(*source_interval), ...], + ) + + async def _write_chunk( + self, + fbx_state: TrajectoryBufferState, + source_interval: Tuple[int, int], + dest_start: int, + ) -> None: + # Write to each ds + futures_tree = jax.tree_util.tree_map( + lambda x, ds: self._write_leaf( + source_leaf=x, + dest_leaf=ds, + source_interval=source_interval, + dest_start=dest_start, + ), + fbx_state.experience, # x = experience + self._all_ds, # ds = data stores + ) + futures, _ = jax.tree_util.tree_flatten(futures_tree) + await asyncio.gather(*futures) + + def write( + self, + fbx_state: TrajectoryBufferState, + source_interval: Tuple[int, int] = (0, 0), + dest_start: Optional[int] = None, + ) -> None: + # TODO: more than one current_index if B > 1 + fbx_current_index = int(fbx_state.current_index) + + # By default, we write from `last received` to `current index` [CI] + if source_interval == (0, 0): + source_interval = (self._last_received_fbx_index, fbx_current_index) + + if source_interval[1] == source_interval[0]: + # Nothing to write + return + + elif source_interval[1] > source_interval[0]: + # Vanilla write, no wrap around + dest_start = self.vault_index if dest_start is None else dest_start + asyncio.run( 
+ self._write_chunk( + fbx_state=fbx_state, + source_interval=source_interval, + dest_start=dest_start, + ) + ) + written_length = source_interval[1] - source_interval[0] + + elif source_interval[1] < source_interval[0]: + # Wrap around! + + # Get dest start + dest_start = self.vault_index if dest_start is None else dest_start + # Get seq dim + fbx_max_index = get_tree_shape_prefix(fbx_state.experience, n_axes=2)[1] + + # From last received to max + source_interval_a = (source_interval[0], fbx_max_index) + time_length_a = source_interval_a[1] - source_interval_a[0] + + asyncio.run( + self._write_chunk( + fbx_state=fbx_state, + source_interval=source_interval_a, + dest_start=dest_start, + ) + ) + + # From 0 (wrapped) to CI + source_interval_b = (0, source_interval[1]) + time_length_b = source_interval_b[1] - source_interval_b[0] + + asyncio.run( + self._write_chunk( + fbx_state=fbx_state, + source_interval=source_interval_b, + dest_start=dest_start + time_length_a, + ) + ) + + written_length = time_length_a + time_length_b + + # print( + # f"Incoming fbx index was {fbx_current_index}, \ + # vs last received {self._last_received_fbx_index}" + # ) + # print( + # f"Wrote {source_interval} into {(dest_start, dest_start + written_length)}\ + # (steps = {written_length}) to vault" + # ) + # print( + # f"Vault index is now \ + # {self.vault_index + written_length}" + # ) + + # Update vault index, and write this to the ds too + self.vault_index += written_length + self._vault_index_ds.write(self.vault_index).result() + + # Keep track of the last fbx buffer idx received + self._last_received_fbx_index = fbx_current_index + + def _read_leaf( + self, + read_leaf: ts.TensorStore, + read_interval: Tuple[int, int], + ) -> Array: + return read_leaf[:, slice(*read_interval), ...].read().result() + + def read(self, read_interval: Tuple[int, int] = (0, 0)) -> Array: # TODO typing + if read_interval == (0, 0): + read_interval = (0, self.vault_index) # Read all that has been written + 
+ read_result = jax.tree_util.tree_map( + lambda _, ds: self._read_leaf( + read_leaf=ds, + read_interval=read_interval, + ), + self._fbx_sample_experience, # just for structure + self._all_ds, # data stores + ) + return read_result + + def get_full_buffer(self) -> TrajectoryBufferState: + return TrajectoryBufferState( + experience=self.read(), + current_index=self.vault_index, + is_full=True, + ) + + def get_buffer( + self, size: int, key: Array, starting_index: Optional[int] = None + ) -> TrajectoryBufferState: + assert size <= self.vault_index + if starting_index is None: + starting_index = int( + jax.random.randint( + key=key, + shape=(), + minval=0, + maxval=self.vault_index - size, + ) + ) + return TrajectoryBufferState( + experience=self.read((starting_index, starting_index + size)), + current_index=starting_index + size, + is_full=True, + ) diff --git a/pyproject.toml b/pyproject.toml index 2bd4b6e..004d794 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,8 @@ dependencies = [ 'jax>=0.4.10', 'jaxlib>=0.4.10', 'numpy>=1.19.5', - 'typing_extensions<4.6.0' + 'typing_extensions<4.6.0', + 'tensorstore>=0.1.51' ] [project.optional-dependencies] From d6ddadfc9339b9db085b11127eb167cc31f260db Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Mon, 11 Dec 2023 10:47:45 +0200 Subject: [PATCH 02/31] feat: return write_length. --- flashbax/vault/vault.py | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/flashbax/vault/vault.py b/flashbax/vault/vault.py index c26e4a2..57cea1d 100644 --- a/flashbax/vault/vault.py +++ b/flashbax/vault/vault.py @@ -28,7 +28,6 @@ from flashbax.utils import get_tree_shape_prefix # CURRENT LIMITATIONS / TODO LIST -# - Anakin -> extra minibatch dim... # - Async reading if necessary # - Only tested with flat buffers # - Reloading could be nicer, but doing so is tricky! 
@@ -178,7 +177,7 @@ def write( fbx_state: TrajectoryBufferState, source_interval: Tuple[int, int] = (0, 0), dest_start: Optional[int] = None, - ) -> None: + ) -> int: # TODO: more than one current_index if B > 1 fbx_current_index = int(fbx_state.current_index) @@ -188,7 +187,7 @@ def write( if source_interval[1] == source_interval[0]: # Nothing to write - return + return 0 elif source_interval[1] > source_interval[0]: # Vanilla write, no wrap around @@ -236,19 +235,6 @@ def write( written_length = time_length_a + time_length_b - # print( - # f"Incoming fbx index was {fbx_current_index}, \ - # vs last received {self._last_received_fbx_index}" - # ) - # print( - # f"Wrote {source_interval} into {(dest_start, dest_start + written_length)}\ - # (steps = {written_length}) to vault" - # ) - # print( - # f"Vault index is now \ - # {self.vault_index + written_length}" - # ) - # Update vault index, and write this to the ds too self.vault_index += written_length self._vault_index_ds.write(self.vault_index).result() @@ -256,6 +242,8 @@ def write( # Keep track of the last fbx buffer idx received self._last_received_fbx_index = fbx_current_index + return written_length + def _read_leaf( self, read_leaf: ts.TensorStore, From bb96ec1fb065a58520eebe762f6793c5ec9fae3e Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Mon, 11 Dec 2023 14:11:55 +0200 Subject: [PATCH 03/31] feat: save fbx structure alongside vault for better reloading (wip!!!) 
--- flashbax/vault/vault.py | 94 ++++++++++++++++++++++------------------- 1 file changed, 51 insertions(+), 43 deletions(-) diff --git a/flashbax/vault/vault.py b/flashbax/vault/vault.py index 57cea1d..56a5700 100644 --- a/flashbax/vault/vault.py +++ b/flashbax/vault/vault.py @@ -15,6 +15,7 @@ import asyncio import json import os +from ast import literal_eval as make_tuple from datetime import datetime from typing import Any, Optional, Tuple @@ -41,8 +42,8 @@ class Vault: def __init__( self, - init_fbx_state: TrajectoryBufferState, vault_name: str, + init_fbx_state: Optional[TrajectoryBufferState] = None, rel_dir: str = "vaults", vault_uid: Optional[str] = None, metadata: Optional[dict] = None, @@ -58,9 +59,14 @@ def __init__( base_path_exists = os.path.exists(self._base_path) if base_path_exists: self._metadata = json.loads(metadata_path.read_text()) + # Ensure minor versions match assert (self._metadata["version"] // 1) == (VERSION // 1) - else: + + elif init_fbx_state is not None: + # init_fbx_state must be a TrajectoryBufferState + assert isinstance(init_fbx_state, TrajectoryBufferState) + # Create the necessary dirs for the vault os.makedirs(self._base_path) @@ -79,11 +85,18 @@ def get_json_ready(obj: Any) -> Any: return obj metadata_json_ready = jax.tree_util.tree_map(get_json_ready, metadata) + experience_structure = jax.tree_map( + lambda x: [str(x.shape), str(x.dtype)], + init_fbx_state.experience, + ) self._metadata = { "version": VERSION, + "structure": experience_structure, **(metadata_json_ready or {}), # Allow user to save extra metadata } metadata_path.write_text(json.dumps(self._metadata)) + else: + raise ValueError("Vault does not exist and no init_fbx_state provided.") # Keep a data store for the vault index self._vault_index_ds = ts.open( @@ -101,45 +114,46 @@ def get_json_ready(obj: Any) -> Any: leaf=x, create_checkpoint=not base_path_exists, ), - init_fbx_state.experience, + self._metadata['structure'], + is_leaf=lambda x: isinstance(x, list), 
) - # Just store one timestep for the structure - self._fbx_sample_experience = jax.tree_map( - lambda x: x[:, 0:1, ...], - init_fbx_state.experience, - ) self._last_received_fbx_index = 0 + def _get_base_spec(self, name: str) -> dict: return { "driver": "zarr", "kvstore": { "driver": "ocdbt", - "base": f"{DRIVER}{self._base_path}", # TODO: does this work on other systems? + "base": f"{DRIVER}{self._base_path}", "path": name, }, } - def _init_leaf(self, name: str, leaf: Array, create_checkpoint: bool = False) -> ts.TensorStore: + def _init_leaf(self, name: str, leaf: list, create_checkpoint: bool = False) -> ts.TensorStore: spec = self._get_base_spec(name) + leaf_shape = make_tuple(leaf[0]) + leaf_dtype = leaf[1] leaf_ds = ts.open( spec, - dtype=leaf.dtype if create_checkpoint else None, + # Only specify dtype and shape if we are creating a checkpoint + dtype=leaf_dtype if create_checkpoint else None, shape=( - leaf.shape[0], # Batch dim + leaf_shape[0], # Batch dim TIME_AXIS_MAX_LENGTH, # Time dim - *leaf.shape[2:], # Experience dim + *leaf_shape[2:], # Experience dim ) if create_checkpoint else None, + # Only create directory if we are creating a checkpoint create=create_checkpoint, - ).result() + ).result() # Synchronous return leaf_ds async def _write_leaf( self, - source_leaf: Array, + source_leaf: jax.Array, dest_leaf: ts.TensorStore, source_interval: Tuple[int, int], dest_start: int, @@ -251,42 +265,36 @@ def _read_leaf( ) -> Array: return read_leaf[:, slice(*read_interval), ...].read().result() - def read(self, read_interval: Tuple[int, int] = (0, 0)) -> Array: # TODO typing - if read_interval == (0, 0): - read_interval = (0, self.vault_index) # Read all that has been written + def read( + self, + timesteps: Optional[int] = None, + percentiles: Optional[Tuple[int, int]] = None, + ) -> TrajectoryBufferState: + """Read from the vault.""" + + if timesteps is None and percentiles is None: + read_interval = (0, self.vault_index) + elif timesteps is not None: + 
read_interval = (self.vault_index - timesteps, self.vault_index) + elif percentiles is not None: + assert percentiles[0] < percentiles[1], "Percentiles must be in ascending order." + read_interval = ( + int(self.vault_index * (percentiles[0] / 100)), + int(self.vault_index * (percentiles[1] / 100)), + ) read_result = jax.tree_util.tree_map( lambda _, ds: self._read_leaf( read_leaf=ds, read_interval=read_interval, ), - self._fbx_sample_experience, # just for structure + self._metadata['structure'], # just for structure self._all_ds, # data stores - ) - return read_result - - def get_full_buffer(self) -> TrajectoryBufferState: - return TrajectoryBufferState( - experience=self.read(), - current_index=self.vault_index, - is_full=True, + is_leaf=lambda x: isinstance(x, list), ) - def get_buffer( - self, size: int, key: Array, starting_index: Optional[int] = None - ) -> TrajectoryBufferState: - assert size <= self.vault_index - if starting_index is None: - starting_index = int( - jax.random.randint( - key=key, - shape=(), - minval=0, - maxval=self.vault_index - size, - ) - ) return TrajectoryBufferState( - experience=self.read((starting_index, starting_index + size)), - current_index=starting_index + size, - is_full=True, + experience=read_result, + current_index=jnp.array(self.vault_index, dtype=int), + is_full=jnp.array(True, dtype=bool), ) From e9a3e59cd76a21a71d93ea86652895f995abbd22 Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Mon, 11 Dec 2023 16:17:08 +0200 Subject: [PATCH 04/31] chore: unpin typing_extensions from < 4.6.0, as causing issues elsewhere. 
--- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 004d794..b69ac17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ dependencies = [ 'jax>=0.4.10', 'jaxlib>=0.4.10', 'numpy>=1.19.5', - 'typing_extensions<4.6.0', + 'typing_extensions>=4.6.0', 'tensorstore>=0.1.51' ] From 1d87ed2d63ffd89912ec86fc44a63c8c71a7ef49 Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Mon, 11 Dec 2023 16:20:43 +0200 Subject: [PATCH 05/31] chore: unpin typing_extensions from < 4.6.0, as it's causing issues elsewhere, plus some precommit. --- flashbax/vault/vault.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/flashbax/vault/vault.py b/flashbax/vault/vault.py index 56a5700..914bd9c 100644 --- a/flashbax/vault/vault.py +++ b/flashbax/vault/vault.py @@ -1,4 +1,4 @@ -# Copyright 2022 InstaDeep Ltd. All rights reserved. +# Copyright 2023 InstaDeep Ltd. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,17 +21,16 @@ import jax import jax.numpy as jnp -import tensorstore as ts +import tensorstore as ts # type: ignore from chex import Array -from etils import epath +from etils import epath # type: ignore from flashbax.buffers.trajectory_buffer import TrajectoryBufferState from flashbax.utils import get_tree_shape_prefix # CURRENT LIMITATIONS / TODO LIST -# - Async reading if necessary # - Only tested with flat buffers -# - Reloading could be nicer, but doing so is tricky! 
+# - Reloading must be with dicts, not namedtuples DRIVER = "file://" METADATA_FILE = "metadata.json" @@ -48,7 +47,6 @@ def __init__( vault_uid: Optional[str] = None, metadata: Optional[dict] = None, ) -> None: - vault_str = vault_uid if vault_uid else datetime.now().strftime("%Y%m%d%H%M%S") self._base_path = os.path.join(os.getcwd(), rel_dir, vault_name, vault_str) @@ -62,7 +60,7 @@ def __init__( # Ensure minor versions match assert (self._metadata["version"] // 1) == (VERSION // 1) - + elif init_fbx_state is not None: # init_fbx_state must be a TrajectoryBufferState assert isinstance(init_fbx_state, TrajectoryBufferState) @@ -114,13 +112,12 @@ def get_json_ready(obj: Any) -> Any: leaf=x, create_checkpoint=not base_path_exists, ), - self._metadata['structure'], + self._metadata["structure"], is_leaf=lambda x: isinstance(x, list), ) self._last_received_fbx_index = 0 - def _get_base_spec(self, name: str) -> dict: return { "driver": "zarr", @@ -131,7 +128,9 @@ def _get_base_spec(self, name: str) -> dict: }, } - def _init_leaf(self, name: str, leaf: list, create_checkpoint: bool = False) -> ts.TensorStore: + def _init_leaf( + self, name: str, leaf: list, create_checkpoint: bool = False + ) -> ts.TensorStore: spec = self._get_base_spec(name) leaf_shape = make_tuple(leaf[0]) leaf_dtype = leaf[1] @@ -277,7 +276,9 @@ def read( elif timesteps is not None: read_interval = (self.vault_index - timesteps, self.vault_index) elif percentiles is not None: - assert percentiles[0] < percentiles[1], "Percentiles must be in ascending order." + assert ( + percentiles[0] < percentiles[1] + ), "Percentiles must be in ascending order." 
read_interval = ( int(self.vault_index * (percentiles[0] / 100)), int(self.vault_index * (percentiles[1] / 100)), @@ -288,7 +289,7 @@ def read( read_leaf=ds, read_interval=read_interval, ), - self._metadata['structure'], # just for structure + self._metadata["structure"], # just for structure self._all_ds, # data stores is_leaf=lambda x: isinstance(x, list), ) From 64f9d21061aa6d1c29b94483f0184feeab3db97d Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Mon, 11 Dec 2023 16:20:43 +0200 Subject: [PATCH 06/31] chore: precommit --- flashbax/vault/vault.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/flashbax/vault/vault.py b/flashbax/vault/vault.py index 56a5700..914bd9c 100644 --- a/flashbax/vault/vault.py +++ b/flashbax/vault/vault.py @@ -1,4 +1,4 @@ -# Copyright 2022 InstaDeep Ltd. All rights reserved. +# Copyright 2023 InstaDeep Ltd. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,17 +21,16 @@ import jax import jax.numpy as jnp -import tensorstore as ts +import tensorstore as ts # type: ignore from chex import Array -from etils import epath +from etils import epath # type: ignore from flashbax.buffers.trajectory_buffer import TrajectoryBufferState from flashbax.utils import get_tree_shape_prefix # CURRENT LIMITATIONS / TODO LIST -# - Async reading if necessary # - Only tested with flat buffers -# - Reloading could be nicer, but doing so is tricky! 
+# - Reloading must be with dicts, not namedtuples DRIVER = "file://" METADATA_FILE = "metadata.json" @@ -48,7 +47,6 @@ def __init__( vault_uid: Optional[str] = None, metadata: Optional[dict] = None, ) -> None: - vault_str = vault_uid if vault_uid else datetime.now().strftime("%Y%m%d%H%M%S") self._base_path = os.path.join(os.getcwd(), rel_dir, vault_name, vault_str) @@ -62,7 +60,7 @@ def __init__( # Ensure minor versions match assert (self._metadata["version"] // 1) == (VERSION // 1) - + elif init_fbx_state is not None: # init_fbx_state must be a TrajectoryBufferState assert isinstance(init_fbx_state, TrajectoryBufferState) @@ -114,13 +112,12 @@ def get_json_ready(obj: Any) -> Any: leaf=x, create_checkpoint=not base_path_exists, ), - self._metadata['structure'], + self._metadata["structure"], is_leaf=lambda x: isinstance(x, list), ) self._last_received_fbx_index = 0 - def _get_base_spec(self, name: str) -> dict: return { "driver": "zarr", @@ -131,7 +128,9 @@ def _get_base_spec(self, name: str) -> dict: }, } - def _init_leaf(self, name: str, leaf: list, create_checkpoint: bool = False) -> ts.TensorStore: + def _init_leaf( + self, name: str, leaf: list, create_checkpoint: bool = False + ) -> ts.TensorStore: spec = self._get_base_spec(name) leaf_shape = make_tuple(leaf[0]) leaf_dtype = leaf[1] @@ -277,7 +276,9 @@ def read( elif timesteps is not None: read_interval = (self.vault_index - timesteps, self.vault_index) elif percentiles is not None: - assert percentiles[0] < percentiles[1], "Percentiles must be in ascending order." + assert ( + percentiles[0] < percentiles[1] + ), "Percentiles must be in ascending order." 
read_interval = ( int(self.vault_index * (percentiles[0] / 100)), int(self.vault_index * (percentiles[1] / 100)), @@ -288,7 +289,7 @@ def read( read_leaf=ds, read_interval=read_interval, ), - self._metadata['structure'], # just for structure + self._metadata["structure"], # just for structure self._all_ds, # data stores is_leaf=lambda x: isinstance(x, list), ) From 415dbd31eb99a5e34cef0f192103c515d97ee4f4 Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Mon, 11 Dec 2023 16:25:30 +0200 Subject: [PATCH 07/31] chore: remove bottom level init for Vault. Should be imported as . --- flashbax/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/flashbax/__init__.py b/flashbax/__init__.py index 983627d..135b2c7 100644 --- a/flashbax/__init__.py +++ b/flashbax/__init__.py @@ -25,4 +25,3 @@ trajectory_buffer, trajectory_queue, ) -from flashbax.vault import Vault From 6816107d8194c049de822a822a5bfe772ca75f42 Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Tue, 16 Jan 2024 08:55:39 +0200 Subject: [PATCH 08/31] feat: big update and refactor in order to checkpoint namedtuples. 
--- flashbax/vault/vault.py | 108 +++++++++++++++++++++++++++------------- 1 file changed, 74 insertions(+), 34 deletions(-) diff --git a/flashbax/vault/vault.py b/flashbax/vault/vault.py index 914bd9c..58db823 100644 --- a/flashbax/vault/vault.py +++ b/flashbax/vault/vault.py @@ -24,13 +24,13 @@ import tensorstore as ts # type: ignore from chex import Array from etils import epath # type: ignore +from orbax.checkpoint.utils import deserialize_tree, serialize_tree -from flashbax.buffers.trajectory_buffer import TrajectoryBufferState +from flashbax.buffers.trajectory_buffer import Experience, TrajectoryBufferState from flashbax.utils import get_tree_shape_prefix # CURRENT LIMITATIONS / TODO LIST # - Only tested with flat buffers -# - Reloading must be with dicts, not namedtuples DRIVER = "file://" METADATA_FILE = "metadata.json" @@ -38,36 +38,48 @@ VERSION = 0.1 +def _path_to_ds_name(path: str) -> str: + path_str = "" + for p in path: + if isinstance(p, jax.tree_util.DictKey): + path_str += str(p.key) + elif isinstance(p, jax.tree_util.GetAttrKey): + path_str += p.name + path_str += "." + return path_str + + class Vault: def __init__( self, vault_name: str, - init_fbx_state: Optional[TrajectoryBufferState] = None, + experience_structure: Optional[Experience] = None, rel_dir: str = "vaults", vault_uid: Optional[str] = None, metadata: Optional[dict] = None, ) -> None: + + ## --- + # Get the base path for the vault and the metadata path vault_str = vault_uid if vault_uid else datetime.now().strftime("%Y%m%d%H%M%S") self._base_path = os.path.join(os.getcwd(), rel_dir, vault_name, vault_str) - - # We use epath for metadata metadata_path = epath.Path(os.path.join(self._base_path, METADATA_FILE)) + # TODO: logging at each step # Check if the vault exists, otherwise create the necessary dirs and files base_path_exists = os.path.exists(self._base_path) if base_path_exists: + # Vault exists, so we load the metadata to access the structure etc. 
self._metadata = json.loads(metadata_path.read_text()) # Ensure minor versions match assert (self._metadata["version"] // 1) == (VERSION // 1) - elif init_fbx_state is not None: - # init_fbx_state must be a TrajectoryBufferState - assert isinstance(init_fbx_state, TrajectoryBufferState) - + elif experience_structure is not None: # Create the necessary dirs for the vault os.makedirs(self._base_path) + # TODO with serialize_tree? def get_json_ready(obj: Any) -> Any: """Ensure that the object is json serializable. Convert to string if not. @@ -83,41 +95,69 @@ def get_json_ready(obj: Any) -> Any: return obj metadata_json_ready = jax.tree_util.tree_map(get_json_ready, metadata) - experience_structure = jax.tree_map( - lambda x: [str(x.shape), str(x.dtype)], - init_fbx_state.experience, + + # We save the structure of the buffer state + # e.g. [(128, 100, 4), jnp.int32] + # We will use this structure to map over the data stores later + serialised_experience_structure = jax.tree_map( + lambda x: [str(x.shape), x.dtype.name], + serialize_tree( + # Get shape and dtype of each leaf, without serialising the structure itself + jax.eval_shape( + lambda: experience_structure, + ), + ) ) + + # Construct metadata self._metadata = { "version": VERSION, - "structure": experience_structure, + "structure": serialised_experience_structure, **(metadata_json_ready or {}), # Allow user to save extra metadata } + # Dump metadata to file metadata_path.write_text(json.dumps(self._metadata)) else: raise ValueError("Vault does not exist and no init_fbx_state provided.") - # Keep a data store for the vault index - self._vault_index_ds = ts.open( - self._get_base_spec("vault_index"), - dtype=jnp.int32, - shape=(1,), - create=not base_path_exists, - ).result() - self.vault_index = int(self._vault_index_ds.read().result()[0]) + ## --- + # We must now build the tree structure from the metadata, whether created here or loaded from file + if experience_structure is None: + # If an example state is not 
provided, we simply load from the metadata + # and the result will be a dictionary. + self._tree_structure = self._metadata["structure"] + else: + # If an example state is provided, we try deserialise into that structure + self._tree_structure = deserialize_tree( + self._metadata["structure"], + target=experience_structure, + ) - # Each leaf of the fbx_state.experience is a data store + # Each leaf of the fbx_state.experience maps to a data store self._all_ds = jax.tree_util.tree_map_with_path( lambda path, x: self._init_leaf( - name=jax.tree_util.keystr(path), # Use the path as the name + name=_path_to_ds_name(path), leaf=x, - create_checkpoint=not base_path_exists, + create_ds=not base_path_exists, ), - self._metadata["structure"], - is_leaf=lambda x: isinstance(x, list), + self._tree_structure, + is_leaf=lambda x: isinstance(x, list), # The list [shape, dtype] is a leaf ) + # We keep track of the last fbx buffer idx received self._last_received_fbx_index = 0 + # We store and load the vault index from a separate datastore + self._vault_index_ds = ts.open( + self._get_base_spec("vault_index"), + dtype=jnp.int32, + shape=(1,), + create=not base_path_exists, + ).result() + # Just read synchronously as it's one number + self.vault_index = int(self._vault_index_ds.read().result()[0]) + + def _get_base_spec(self, name: str) -> dict: return { "driver": "zarr", @@ -129,7 +169,7 @@ def _get_base_spec(self, name: str) -> dict: } def _init_leaf( - self, name: str, leaf: list, create_checkpoint: bool = False + self, name: str, leaf: list, create_ds: bool = False ) -> ts.TensorStore: spec = self._get_base_spec(name) leaf_shape = make_tuple(leaf[0]) @@ -137,16 +177,16 @@ def _init_leaf( leaf_ds = ts.open( spec, # Only specify dtype and shape if we are creating a checkpoint - dtype=leaf_dtype if create_checkpoint else None, + dtype=leaf_dtype if create_ds else None, shape=( leaf_shape[0], # Batch dim TIME_AXIS_MAX_LENGTH, # Time dim *leaf_shape[2:], # Experience dim ) - if 
create_checkpoint - else None, - # Only create directory if we are creating a checkpoint - create=create_checkpoint, + if create_ds + else None, # Don't impose shape if we are loading a vault + # Only create datastore if we are creating the vault + create=create_ds, ).result() # Synchronous return leaf_ds @@ -262,7 +302,7 @@ def _read_leaf( read_leaf: ts.TensorStore, read_interval: Tuple[int, int], ) -> Array: - return read_leaf[:, slice(*read_interval), ...].read().result() + return jnp.asarray(read_leaf[:, slice(*read_interval), ...].read().result()) def read( self, @@ -289,7 +329,7 @@ def read( read_leaf=ds, read_interval=read_interval, ), - self._metadata["structure"], # just for structure + self._tree_structure, self._all_ds, # data stores is_leaf=lambda x: isinstance(x, list), ) From 99fc9c955db74bdd6d0347bad079175601970efc Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Tue, 16 Jan 2024 11:08:44 +0200 Subject: [PATCH 09/31] docs: big comment update. --- flashbax/vault/vault.py | 207 +++++++++++++++++++++++++++++----------- 1 file changed, 153 insertions(+), 54 deletions(-) diff --git a/flashbax/vault/vault.py b/flashbax/vault/vault.py index 58db823..f1897e9 100644 --- a/flashbax/vault/vault.py +++ b/flashbax/vault/vault.py @@ -17,13 +17,14 @@ import os from ast import literal_eval as make_tuple from datetime import datetime -from typing import Any, Optional, Tuple +from typing import Any, Optional, Tuple, Union import jax import jax.numpy as jnp import tensorstore as ts # type: ignore from chex import Array from etils import epath # type: ignore +from jax.tree_util import DictKey, GetAttrKey from orbax.checkpoint.utils import deserialize_tree, serialize_tree from flashbax.buffers.trajectory_buffer import Experience, TrajectoryBufferState @@ -38,12 +39,24 @@ VERSION = 0.1 -def _path_to_ds_name(path: str) -> str: +def _path_to_ds_name(path: Tuple[Union[DictKey, GetAttrKey], ...]) -> str: + """Utility function to convert a path (as defined by 
jax.tree_util.tree_map_with_path + to a datastore name. The alternative is to use jax.tree_util.keystr, but this has + different behaviour for dictionaries (DictKey) vs. namedtuples (GetAttrKey), which means + we could not save a vault based on a namedtuple structure but later load it as a dict. + Instead, this maps both to a consistent string representation. + + Args: + path: tuple of DictKeys or GetAttrKeys + + Returns: + str: standardised string representation of the path + """ path_str = "" for p in path: - if isinstance(p, jax.tree_util.DictKey): + if isinstance(p, DictKey): path_str += str(p.key) - elif isinstance(p, jax.tree_util.GetAttrKey): + elif isinstance(p, GetAttrKey): path_str += p.name path_str += "." return path_str @@ -58,8 +71,27 @@ def __init__( vault_uid: Optional[str] = None, metadata: Optional[dict] = None, ) -> None: - - ## --- + """Flashbax utility for storing buffers to disk efficiently. + + Args: + vault_name (str): the upper-level name of this vault. + Resulting path is `cwd/rel_dir/vault_name/vault_uid`. + experience_structure (Optional[Experience], optional): + Structure of the experience data, usually given as `buffer_state.experience`. + Defaults to None, which can only be done if reading an existing vault. + rel_dir (str, optional): + Base directory of all vaults. Defaults to "vaults". + vault_uid (Optional[str], optional): Unique identifier for this vault. + Defaults to None, which will use the current timestamp. + metadata (Optional[dict], optional): + Any additional metadata to save. Defaults to None. + + Raises: + ValueError: if the targeted vault does not exist, and no experience_structure is provided. + + Returns: + Vault: a vault object. 
+ """ # Get the base path for the vault and the metadata path vault_str = vault_uid if vault_uid else datetime.now().strftime("%Y%m%d%H%M%S") self._base_path = os.path.join(os.getcwd(), rel_dir, vault_name, vault_str) @@ -79,22 +111,11 @@ def __init__( # Create the necessary dirs for the vault os.makedirs(self._base_path) - # TODO with serialize_tree? - def get_json_ready(obj: Any) -> Any: - """Ensure that the object is json serializable. Convert to string if not. - - Args: - obj (Any): Object to be considered - - Returns: - Any: json serializable object - """ - if not isinstance(obj, (bool, str, int, float, type(None))): - return str(obj) - else: - return obj - - metadata_json_ready = jax.tree_util.tree_map(get_json_ready, metadata) + # Ensure provided metadata is json serialisable + metadata_json_ready = jax.tree_util.tree_map( + lambda obj: str(obj) if not isinstance(obj, (bool, str, int, float, type(None))) else obj, + metadata, + ) # We save the structure of the buffer state # e.g. [(128, 100, 4), jnp.int32] @@ -118,22 +139,25 @@ def get_json_ready(obj: Any) -> Any: # Dump metadata to file metadata_path.write_text(json.dumps(self._metadata)) else: - raise ValueError("Vault does not exist and no init_fbx_state provided.") + # If the vault does not exist already, and no experience_structure is provided to create + # a new vault, we cannot proceed. + raise ValueError("Vault does not exist and no experience_structure was provided.") - ## --- - # We must now build the tree structure from the metadata, whether created here or loaded from file + # We must now build the tree structure from the metadata, whether the metadata was created + # here or loaded from file if experience_structure is None: - # If an example state is not provided, we simply load from the metadata - # and the result will be a dictionary. + # Since the experience structure is not provided, we simply use the metadata as is. + # The result will always be a dictionary. 
self._tree_structure = self._metadata["structure"] else: - # If an example state is provided, we try deserialise into that structure + # If experience structure is provided, we try deserialise into that structure self._tree_structure = deserialize_tree( self._metadata["structure"], target=experience_structure, ) - # Each leaf of the fbx_state.experience maps to a data store + # Each leaf of the fbx_state.experience maps to a data store, so we tree map over the + # tree structure to create each of the data stores. self._all_ds = jax.tree_util.tree_map_with_path( lambda path, x: self._init_leaf( name=_path_to_ds_name(path), @@ -141,7 +165,8 @@ def get_json_ready(obj: Any) -> Any: create_ds=not base_path_exists, ), self._tree_structure, - is_leaf=lambda x: isinstance(x, list), # The list [shape, dtype] is a leaf + # Tree structure uses a list [shape, dtype] as a leaf + is_leaf=lambda x: isinstance(x, list), ) # We keep track of the last fbx buffer idx received @@ -159,6 +184,14 @@ def get_json_ready(obj: Any) -> Any: def _get_base_spec(self, name: str) -> dict: + """Simple common specs for all datastores. + + Args: + name (str): name of the datastore + + Returns: + dict: config for the datastore + """ return { "driver": "zarr", "kvstore": { @@ -171,23 +204,37 @@ def _get_base_spec(self, name: str) -> dict: def _init_leaf( self, name: str, leaf: list, create_ds: bool = False ) -> ts.TensorStore: + """Initialise a datastore for a leaf of the experience tree. + + Args: + name (str): datastore name + leaf (list): leaf of the form ["shape", "dtype"] + create_ds (bool, optional): create the datastore rather than open an existing one. Defaults to False. 
+ + Returns: + ts.TensorStore: the datastore object + """ spec = self._get_base_spec(name) - leaf_shape = make_tuple(leaf[0]) - leaf_dtype = leaf[1] + # Convert shape and dtype from str to tuple and dtype + leaf_shape = make_tuple(leaf[0]) # Must convert to a real tuple + leaf_dtype = leaf[1] # Can leave this as a str + leaf_ds = ts.open( spec, - # Only specify dtype and shape if we are creating a checkpoint + # Only specify dtype and shape if we are creating a vault + # (i.e. don't impose dtype and shape if we are _loading_ a vault) dtype=leaf_dtype if create_ds else None, shape=( leaf_shape[0], # Batch dim - TIME_AXIS_MAX_LENGTH, # Time dim - *leaf_shape[2:], # Experience dim + TIME_AXIS_MAX_LENGTH, # Time dim, which we extend + *leaf_shape[2:], # Experience dim(s) ) if create_ds - else None, # Don't impose shape if we are loading a vault - # Only create datastore if we are creating the vault + else None, + # Only create datastore if we are creating the vault: create=create_ds, - ).result() # Synchronous + ).result() # Do this synchronously + return leaf_ds async def _write_leaf( @@ -197,10 +244,19 @@ async def _write_leaf( source_interval: Tuple[int, int], dest_start: int, ) -> None: + """Asynchronously write a chunk of data to a leaf's datastore. 
+ + Args: + source_leaf (jax.Array): the input fbx_state.experience array + dest_leaf (ts.TensorStore): the destination datastore + source_interval (Tuple[int, int]): read interval from the source leaf + dest_start (int): write start index in the destination leaf + """ dest_interval = ( dest_start, - dest_start + (source_interval[1] - source_interval[0]), # type: ignore + dest_start + (source_interval[1] - source_interval[0]), ) + # Write to the datastore along the time axis await dest_leaf[:, slice(*dest_interval), ...].write( source_leaf[:, slice(*source_interval), ...], ) @@ -211,7 +267,14 @@ async def _write_chunk( source_interval: Tuple[int, int], dest_start: int, ) -> None: - # Write to each ds + """Asynchronous method for writing to all the datastores. + + Args: + fbx_state (TrajectoryBufferState): input buffer state + source_interval (Tuple[int, int]): read interval from the buffer state + dest_start (int): write start index in the vault + """ + # Collect futures for writing to each datastore futures_tree = jax.tree_util.tree_map( lambda x, ds: self._write_leaf( source_leaf=x, @@ -222,6 +285,7 @@ async def _write_chunk( fbx_state.experience, # x = experience self._all_ds, # ds = data stores ) + # Write to all datastores asynchronously futures, _ = jax.tree_util.tree_flatten(futures_tree) await asyncio.gather(*futures) @@ -231,20 +295,34 @@ def write( source_interval: Tuple[int, int] = (0, 0), dest_start: Optional[int] = None, ) -> int: - # TODO: more than one current_index if B > 1 + """Write any new data from the fbx buffer state to the vault. + + Args: + fbx_state (TrajectoryBufferState): input buffer state + source_interval (Tuple[int, int], optional): from where to read in the buffer. + Defaults to (0, 0), which reads from the last received index up to the + current buffer state's index. + dest_start (Optional[int], optional): where to write in the vault. + Defaults to None, which writes from the current vault index. 
+ + Returns: + int: how many elements along the time-axis were written to the vault + """ fbx_current_index = int(fbx_state.current_index) - # By default, we write from `last received` to `current index` [CI] + # By default, we read from `last received` to `current index` if source_interval == (0, 0): source_interval = (self._last_received_fbx_index, fbx_current_index) + # By default, we continue writing from the current vault index + dest_start = self.vault_index if dest_start is None else dest_start + if source_interval[1] == source_interval[0]: # Nothing to write return 0 elif source_interval[1] > source_interval[0]: - # Vanilla write, no wrap around - dest_start = self.vault_index if dest_start is None else dest_start + # Vanilla write, no wrap around in the buffer state asyncio.run( self._write_chunk( fbx_state=fbx_state, @@ -255,14 +333,12 @@ def write( written_length = source_interval[1] - source_interval[0] elif source_interval[1] < source_interval[0]: - # Wrap around! + # Wrap around in the buffer state! - # Get dest start - dest_start = self.vault_index if dest_start is None else dest_start - # Get seq dim + # Get seq dim (i.e. 
the length of the time axis in the fbx buffer state) fbx_max_index = get_tree_shape_prefix(fbx_state.experience, n_axes=2)[1] - # From last received to max + # Read from last received fbx index to max index source_interval_a = (source_interval[0], fbx_max_index) time_length_a = source_interval_a[1] - source_interval_a[0] @@ -274,7 +350,7 @@ def write( ) ) - # From 0 (wrapped) to CI + # Read from the start of the fbx buffer state to the current fbx index source_interval_b = (0, source_interval[1]) time_length_b = source_interval_b[1] - source_interval_b[0] @@ -288,7 +364,7 @@ def write( written_length = time_length_a + time_length_b - # Update vault index, and write this to the ds too + # Update vault index, and write this to its datastore too self.vault_index += written_length self._vault_index_ds.write(self.vault_index).result() @@ -302,6 +378,15 @@ def _read_leaf( read_leaf: ts.TensorStore, read_interval: Tuple[int, int], ) -> Array: + """Read from a leaf of the experience tree. + + Args: + read_leaf (ts.TensorStore): the datastore from which to read + read_interval (Tuple[int, int]): the interval on the time-axis to read + + Returns: + Array: the read data, as a jax array + """ return jnp.asarray(read_leaf[:, slice(*read_interval), ...].read().result()) def read( @@ -309,12 +394,23 @@ def read( timesteps: Optional[int] = None, percentiles: Optional[Tuple[int, int]] = None, ) -> TrajectoryBufferState: - """Read from the vault.""" + """Read synchronously from the vault. + + Args: + timesteps (Optional[int], optional): if given, read only the last `timesteps` elements. Defaults to None. + percentiles (Optional[Tuple[int, int]], optional): (start, end) percentages of the vault to read. Defaults to None. 
+ Returns: + TrajectoryBufferState: the read data as a fbx buffer state + """ + + # By default we read the entire vault if timesteps is None and percentiles is None: read_interval = (0, self.vault_index) + # If time steps are provided, we read the last `timesteps` count of elements elif timesteps is not None: read_interval = (self.vault_index - timesteps, self.vault_index) + # If percentiles are provided, we read the corresponding interval elif percentiles is not None: assert ( percentiles[0] < percentiles[1] @@ -324,16 +420,19 @@ def read( int(self.vault_index * (percentiles[1] / 100)), ) + # read_result = jax.tree_util.tree_map( lambda _, ds: self._read_leaf( read_leaf=ds, read_interval=read_interval, ), - self._tree_structure, - self._all_ds, # data stores + self._tree_structure, # Just used for structure + self._all_ds, # The vault data stores + # Interpret ["shape", "dtype"] from tree_structure as leaves: is_leaf=lambda x: isinstance(x, list), ) + # Return the read result as a fbx buffer state return TrajectoryBufferState( experience=read_result, current_index=jnp.array(self.vault_index, dtype=int), From 4a7d6ba1e117e40f134830ad7e5935f4429187bb Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Tue, 16 Jan 2024 11:24:01 +0200 Subject: [PATCH 10/31] chore: precommit. 
--- flashbax/vault/vault.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/flashbax/vault/vault.py b/flashbax/vault/vault.py index f1897e9..feb5b44 100644 --- a/flashbax/vault/vault.py +++ b/flashbax/vault/vault.py @@ -17,7 +17,7 @@ import os from ast import literal_eval as make_tuple from datetime import datetime -from typing import Any, Optional, Tuple, Union +from typing import Optional, Tuple, Union import jax import jax.numpy as jnp @@ -25,7 +25,7 @@ from chex import Array from etils import epath # type: ignore from jax.tree_util import DictKey, GetAttrKey -from orbax.checkpoint.utils import deserialize_tree, serialize_tree +from orbax.checkpoint.utils import deserialize_tree, serialize_tree # type: ignore from flashbax.buffers.trajectory_buffer import Experience, TrajectoryBufferState from flashbax.utils import get_tree_shape_prefix @@ -87,7 +87,8 @@ def __init__( Any additional metadata to save. Defaults to None. Raises: - ValueError: if the targeted vault does not exist, and no experience_structure is provided. + ValueError: + If the targeted vault does not exist, and no experience_structure is provided. Returns: Vault: a vault object. @@ -113,7 +114,9 @@ def __init__( # Ensure provided metadata is json serialisable metadata_json_ready = jax.tree_util.tree_map( - lambda obj: str(obj) if not isinstance(obj, (bool, str, int, float, type(None))) else obj, + lambda obj: str(obj) + if not isinstance(obj, (bool, str, int, float, type(None))) + else obj, metadata, ) @@ -124,10 +127,10 @@ def __init__( lambda x: [str(x.shape), x.dtype.name], serialize_tree( # Get shape and dtype of each leaf, without serialising the structure itself - jax.eval_shape( + jax.eval_shape( lambda: experience_structure, ), - ) + ), ) # Construct metadata @@ -141,9 +144,11 @@ def __init__( else: # If the vault does not exist already, and no experience_structure is provided to create # a new vault, we cannot proceed. 
- raise ValueError("Vault does not exist and no experience_structure was provided.") + raise ValueError( + "Vault does not exist and no experience_structure was provided." + ) - # We must now build the tree structure from the metadata, whether the metadata was created + # We must now build the tree structure from the metadata, whether the metadata was created # here or loaded from file if experience_structure is None: # Since the experience structure is not provided, we simply use the metadata as is. @@ -182,7 +187,6 @@ def __init__( # Just read synchronously as it's one number self.vault_index = int(self._vault_index_ds.read().result()[0]) - def _get_base_spec(self, name: str) -> dict: """Simple common specs for all datastores. From ce0724b97d282c7f986f98baed49a339f7959f29 Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Tue, 16 Jan 2024 11:27:42 +0200 Subject: [PATCH 11/31] chore: minor docs fix. --- flashbax/vault/vault.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flashbax/vault/vault.py b/flashbax/vault/vault.py index feb5b44..838e91c 100644 --- a/flashbax/vault/vault.py +++ b/flashbax/vault/vault.py @@ -40,11 +40,11 @@ def _path_to_ds_name(path: Tuple[Union[DictKey, GetAttrKey], ...]) -> str: - """Utility function to convert a path (as defined by jax.tree_util.tree_map_with_path - to a datastore name. The alternative is to use jax.tree_util.keystr, but this has + """Utility function to convert a path (yielded by jax.tree_util.tree_map_with_path) + to a datastore name. The alternative is to use jax.tree_util.keystr(...), but this has different behaviour for dictionaries (DictKey) vs. namedtuples (GetAttrKey), which means we could not save a vault based on a namedtuple structure but later load it as a dict. - Instead, this maps both to a consistent string representation. + Instead, this function maps both to a consistent string representation. 
Args: path: tuple of DictKeys or GetAttrKeys From 9c16bee4eee28fcf5da309de7e9fa25e09081ac2 Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Tue, 16 Jan 2024 11:40:40 +0200 Subject: [PATCH 12/31] chore: bump version to first major (hopefully stable) release. --- flashbax/vault/vault.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/flashbax/vault/vault.py b/flashbax/vault/vault.py index 838e91c..b11b44d 100644 --- a/flashbax/vault/vault.py +++ b/flashbax/vault/vault.py @@ -30,13 +30,11 @@ from flashbax.buffers.trajectory_buffer import Experience, TrajectoryBufferState from flashbax.utils import get_tree_shape_prefix -# CURRENT LIMITATIONS / TODO LIST -# - Only tested with flat buffers - +# Constants DRIVER = "file://" METADATA_FILE = "metadata.json" TIME_AXIS_MAX_LENGTH = int(10e12) # Upper bound on the length of the time axis -VERSION = 0.1 +VERSION = 1.0 def _path_to_ds_name(path: Tuple[Union[DictKey, GetAttrKey], ...]) -> str: From aef9e22cb52b1209d6025bb0dcc65b99d04b9501 Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Wed, 17 Jan 2024 12:05:14 +0200 Subject: [PATCH 13/31] feat: first few tests for vault. --- flashbax/vault/vault_test.py | 167 +++++++++++++++++++++++++++++++++++ 1 file changed, 167 insertions(+) create mode 100644 flashbax/vault/vault_test.py diff --git a/flashbax/vault/vault_test.py b/flashbax/vault/vault_test.py new file mode 100644 index 0000000..a7da45c --- /dev/null +++ b/flashbax/vault/vault_test.py @@ -0,0 +1,167 @@ +# Copyright 2023 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from tempfile import TemporaryDirectory +from typing import NamedTuple + +import chex +import jax +import jax.numpy as jnp +import pytest +from chex import Array + +import flashbax as fbx +from flashbax.vault import Vault + + +class CustomObservation(NamedTuple): + x: Array + y: Array + + +class FbxTransition(NamedTuple): + obs: CustomObservation + act: Array + + +@pytest.fixture() +def max_length() -> int: + return 256 + + +@pytest.fixture() +def fake_transition() -> FbxTransition: + return FbxTransition( + obs=CustomObservation( + x=jnp.ones(shape=(1, 2, 3), dtype=jnp.float32), + y=jnp.ones(shape=(4, 5, 6), dtype=jnp.float32), + ), + act=jnp.ones(shape=(7, 8), dtype=jnp.float32), + ) + + +def test_write_to_vault( + fake_transition: FbxTransition, + max_length: int, +): + # Get the buffer pure functions + buffer = fbx.make_flat_buffer( + max_length=max_length, + min_length=1, + sample_batch_size=1, + ) + buffer_add = jax.jit(buffer.add, donate_argnums=0) + buffer_state = buffer.init(fake_transition) # Initialise the state + + # Initialise the vault + v = Vault( + vault_name="test_vault", + experience_structure=buffer_state.experience, + rel_dir="tmp", + ) + + # Add to the vault up to the fbx buffer being full + for i in range(0, max_length): + assert v.vault_index == i + buffer_state = buffer_add( + buffer_state, + fake_transition, + ) + v.write(buffer_state) + + +def test_read_from_vault( + fake_transition: FbxTransition, + max_length: int, +): + # Get the buffer pure functions + buffer = fbx.make_flat_buffer( + max_length=max_length, + min_length=1, 
+ sample_batch_size=1, + ) + buffer_add = jax.jit(buffer.add, donate_argnums=0) + buffer_state = buffer.init(fake_transition) # Initialise the state + + # Initialise the vault + v = Vault( + vault_name="test_vault", + experience_structure=buffer_state.experience, + rel_dir="tmp", + ) + + for _ in range(0, max_length): + buffer_state = buffer_add( + buffer_state, + fake_transition, + ) + v.write(buffer_state) + + # Load the state from the vault + buffer_state_reloaded = v.read() + # Experience of the two should match + chex.assert_trees_all_equal( + buffer_state.experience, + buffer_state_reloaded.experience, + ) + + +def test_reload_vault( + fake_transition: FbxTransition, + max_length: int, +): + # Extend the vault more than the buffer size + n_timesteps = max_length * 5 + + with TemporaryDirectory() as temp_dir_path: + # Get the buffer pure functions + buffer = fbx.make_flat_buffer( + max_length=max_length, + min_length=1, + sample_batch_size=1, + ) + buffer_add = jax.jit(buffer.add, donate_argnums=0) + buffer_state = buffer.init(fake_transition) # Initialise the state + + # Initialise the vault + v = Vault( + vault_name="test_vault", + experience_structure=buffer_state.experience, + rel_dir=temp_dir_path, + vault_uid="test_vault_uid", + ) + + # Add to the vault + for _ in range(0, n_timesteps): + buffer_state = buffer_add( + buffer_state, + fake_transition, + ) + v.write(buffer_state) + + # Ensure we can't access the vault + del v + + # Reload the vault + v_reload = Vault( + vault_name="test_vault", + experience_structure=buffer_state.experience, + rel_dir=temp_dir_path, + vault_uid="test_vault_uid", + ) + buffer_state_reloaded = v_reload.read() + + # We want to check that all the timesteps are there + assert buffer_state_reloaded.experience.obs.x.shape[1] == n_timesteps From 249b296dada649e53eb10e25c5bdfefc51e07b20 Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Wed, 17 Jan 2024 14:20:15 +0200 Subject: [PATCH 14/31] fix: use temp dirs for all of the tests. 
--- flashbax/vault/vault_test.py | 102 ++++++++++++++++++----------------- 1 file changed, 52 insertions(+), 50 deletions(-) diff --git a/flashbax/vault/vault_test.py b/flashbax/vault/vault_test.py index a7da45c..192ce05 100644 --- a/flashbax/vault/vault_test.py +++ b/flashbax/vault/vault_test.py @@ -56,66 +56,68 @@ def test_write_to_vault( fake_transition: FbxTransition, max_length: int, ): - # Get the buffer pure functions - buffer = fbx.make_flat_buffer( - max_length=max_length, - min_length=1, - sample_batch_size=1, - ) - buffer_add = jax.jit(buffer.add, donate_argnums=0) - buffer_state = buffer.init(fake_transition) # Initialise the state - - # Initialise the vault - v = Vault( - vault_name="test_vault", - experience_structure=buffer_state.experience, - rel_dir="tmp", - ) + with TemporaryDirectory() as temp_dir_path: + # Get the buffer pure functions + buffer = fbx.make_flat_buffer( + max_length=max_length, + min_length=1, + sample_batch_size=1, + ) + buffer_add = jax.jit(buffer.add, donate_argnums=0) + buffer_state = buffer.init(fake_transition) # Initialise the state - # Add to the vault up to the fbx buffer being full - for i in range(0, max_length): - assert v.vault_index == i - buffer_state = buffer_add( - buffer_state, - fake_transition, + # Initialise the vault + v = Vault( + vault_name="test_vault", + experience_structure=buffer_state.experience, + rel_dir=temp_dir_path, ) - v.write(buffer_state) + + # Add to the vault up to the fbx buffer being full + for i in range(0, max_length): + assert v.vault_index == i + buffer_state = buffer_add( + buffer_state, + fake_transition, + ) + v.write(buffer_state) def test_read_from_vault( fake_transition: FbxTransition, max_length: int, ): - # Get the buffer pure functions - buffer = fbx.make_flat_buffer( - max_length=max_length, - min_length=1, - sample_batch_size=1, - ) - buffer_add = jax.jit(buffer.add, donate_argnums=0) - buffer_state = buffer.init(fake_transition) # Initialise the state - - # Initialise the 
vault - v = Vault( - vault_name="test_vault", - experience_structure=buffer_state.experience, - rel_dir="tmp", - ) + with TemporaryDirectory() as temp_dir_path: + # Get the buffer pure functions + buffer = fbx.make_flat_buffer( + max_length=max_length, + min_length=1, + sample_batch_size=1, + ) + buffer_add = jax.jit(buffer.add, donate_argnums=0) + buffer_state = buffer.init(fake_transition) # Initialise the state - for _ in range(0, max_length): - buffer_state = buffer_add( - buffer_state, - fake_transition, + # Initialise the vault + v = Vault( + vault_name="test_vault", + experience_structure=buffer_state.experience, + rel_dir=temp_dir_path, + ) + + for _ in range(0, max_length): + buffer_state = buffer_add( + buffer_state, + fake_transition, + ) + v.write(buffer_state) + + # Load the state from the vault + buffer_state_reloaded = v.read() + # Experience of the two should match + chex.assert_trees_all_equal( + buffer_state.experience, + buffer_state_reloaded.experience, ) - v.write(buffer_state) - - # Load the state from the vault - buffer_state_reloaded = v.read() - # Experience of the two should match - chex.assert_trees_all_equal( - buffer_state.experience, - buffer_state_reloaded.experience, - ) def test_reload_vault( From e01752b9eae496e632e07ae36156cba4f269e42a Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Wed, 17 Jan 2024 16:29:47 +0200 Subject: [PATCH 15/31] feat: improved test for reloading vault. --- flashbax/vault/vault_test.py | 58 +++++++++++++++++++++++++++++++----- 1 file changed, 51 insertions(+), 7 deletions(-) diff --git a/flashbax/vault/vault_test.py b/flashbax/vault/vault_test.py index 192ce05..6a2bad7 100644 --- a/flashbax/vault/vault_test.py +++ b/flashbax/vault/vault_test.py @@ -13,6 +13,7 @@ # limitations under the License. 
+from functools import partial from tempfile import TemporaryDirectory from typing import NamedTuple @@ -120,12 +121,12 @@ def test_read_from_vault( ) -def test_reload_vault( +def test_extend_vault( fake_transition: FbxTransition, max_length: int, ): # Extend the vault more than the buffer size - n_timesteps = max_length * 5 + n_timesteps = max_length * 10 with TemporaryDirectory() as temp_dir_path: # Get the buffer pure functions @@ -142,10 +143,10 @@ def test_reload_vault( vault_name="test_vault", experience_structure=buffer_state.experience, rel_dir=temp_dir_path, - vault_uid="test_vault_uid", ) - # Add to the vault + # Add to the vault, wrapping around the circular buffer, + # but writing to the vault each time we add for _ in range(0, n_timesteps): buffer_state = buffer_add( buffer_state, @@ -153,6 +154,46 @@ def test_reload_vault( ) v.write(buffer_state) + # Read in the full vault state --> longer than the buffer + long_buffer_state = v.read() + + # We want to check that all the timesteps are there + assert long_buffer_state.experience.obs.x.shape[1] == n_timesteps + + +def test_reload_vault( + fake_transition: FbxTransition, + max_length: int, +): + with TemporaryDirectory() as temp_dir_path: + # Get the buffer pure functions + buffer = fbx.make_flat_buffer( + max_length=max_length, + min_length=1, + sample_batch_size=1, + ) + buffer_add = jax.jit(buffer.add, donate_argnums=0) + buffer_state = buffer.init(fake_transition) # Initialise the state + + # Initialise the vault + v = Vault( + vault_name="test_vault", + experience_structure=buffer_state.experience, + rel_dir=temp_dir_path, + vault_uid="test_vault_uid", + ) + + def multiplier(x: Array, i: int): + return x * i + + # Add to the vault + for i in range(0, max_length): + buffer_state = buffer_add( + buffer_state, + jax.tree_map(partial(multiplier, i=i), fake_transition), + ) + v.write(buffer_state) + # Ensure we can't access the vault del v @@ -161,9 +202,12 @@ def test_reload_vault( 
vault_name="test_vault", experience_structure=buffer_state.experience, rel_dir=temp_dir_path, - vault_uid="test_vault_uid", + vault_uid="test_vault_uid", # Need to pass the same UID ) buffer_state_reloaded = v_reload.read() - # We want to check that all the timesteps are there - assert buffer_state_reloaded.experience.obs.x.shape[1] == n_timesteps + # We want to check that all the data is correct + chex.assert_trees_all_equal( + buffer_state.experience, + buffer_state_reloaded.experience, + ) From a878bcc4271ee40458c346f1ffd5931844450135 Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Wed, 17 Jan 2024 16:52:34 +0200 Subject: [PATCH 16/31] docs: vault explainer in readme, along with demonstrative notebook. --- README.md | 6 + examples/vault_demonstration.ipynb | 311 +++++++++++++++++++++++++++++ 2 files changed, 317 insertions(+) create mode 100644 examples/vault_demonstration.ipynb diff --git a/README.md b/README.md index bd54e09..e57642a 100644 --- a/README.md +++ b/README.md @@ -242,6 +242,12 @@ Previous benchmarks added only a single timestep at a time, we now evaluate addi Ultimately, we see improved or comparable performance to benchmarked buffers whilst providing buffers that are fully JAX-compatible in addition to other features such as batched adding as well as being able to add sequences of varying length. We do note that due to JAX having different XLA backends for CPU, GPU, and TPU, the performance of the buffers can vary depending on the device and the specific operation being called. +## Vault 💾 +Vault is an efficient mechanism for saving flashbax buffers to persistent data storage, e.g. for use in offline reinforcement learning. Consider a Flashbax buffer which has experience data of dimensionality $(B, T, *E)$, where $B$ is a batch dimension (for the sake of recording independent trajectories synchronously), $T$ is a temporal/sequential dimension, and $*E$ indicates the one or more dimensions of the experience data itself. 
Since large quantities of data may be generated for a given environment, Vault extends the $T$ dimension to a virtually unconstrained degree by reading and writing slices of buffers along this temporal axis. In doing so, gigantic buffer stores can reside on disk, from which sub-buffers can be loaded into RAM/VRAM for efficient offline training. Vault has been tested with the item, flat, and trajectory buffers. + +For more information, see the demonstration notebook: [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/flashbax/blob/main/examples/vault_demonstration.ipynb) + + ## Contributing 🤝 Contributions are welcome! See our issue tracker for diff --git a/examples/vault_demonstration.ipynb b/examples/vault_demonstration.ipynb new file mode 100644 index 0000000..0b23935 --- /dev/null +++ b/examples/vault_demonstration.ipynb @@ -0,0 +1,311 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Vault demonstration" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "from typing import NamedTuple\n", + "import jax.numpy as jnp\n", + "from flashbax.vault import Vault\n", + "import flashbax as fbx\n", + "from chex import Array" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We create a simple timestep structure, with a corresponding flat buffer." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class FbxTransition(NamedTuple):\n", + " obs: Array\n", + "\n", + "tx = FbxTransition(obs=jnp.zeros(1))\n", + "\n", + "buffer = fbx.make_flat_buffer(\n", + " max_length=5,\n", + " min_length=1,\n", + " sample_batch_size=1,\n", + ")\n", + "buffer_state = buffer.init(tx)\n", + "buffer_add = jax.jit(buffer.add, donate_argnums=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The shape of this buffer is $(B = 1, T = 5, E = 1)$, meaning the buffer can hold 5 timesteps." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(1, 5, 1)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "buffer_state.experience.obs.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We create the vault, based on the buffer's experience structure." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "v = Vault(\n", + " vault_name=\"demo\",\n", + " experience_structure=buffer_state.experience,\n", + " rel_dir=\"/tmp\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now add 10 timesteps to the buffer, and write that buffer to the vault. We inspect the buffer and vault state after each timestep." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "------------------\n", + "Buffer state:\n", + "[[[0.]\n", + " [0.]\n", + " [0.]\n", + " [0.]\n", + " [0.]]]\n", + "\n", + "Vault state:\n", + "[]\n", + "------------------\n", + "------------------\n", + "Buffer state:\n", + "[[[1.]\n", + " [0.]\n", + " [0.]\n", + " [0.]\n", + " [0.]]]\n", + "\n", + "Vault state:\n", + "[[[1.]]]\n", + "------------------\n", + "------------------\n", + "Buffer state:\n", + "[[[1.]\n", + " [2.]\n", + " [0.]\n", + " [0.]\n", + " [0.]]]\n", + "\n", + "Vault state:\n", + "[[[1.]\n", + " [2.]]]\n", + "------------------\n", + "------------------\n", + "Buffer state:\n", + "[[[1.]\n", + " [2.]\n", + " [3.]\n", + " [0.]\n", + " [0.]]]\n", + "\n", + "Vault state:\n", + "[[[1.]\n", + " [2.]\n", + " [3.]]]\n", + "------------------\n", + "------------------\n", + "Buffer state:\n", + "[[[1.]\n", + " [2.]\n", + " [3.]\n", + " [4.]\n", + " [0.]]]\n", + "\n", + "Vault state:\n", + "[[[1.]\n", + " [2.]\n", + " [3.]\n", + " [4.]]]\n", + "------------------\n", + "------------------\n", + "Buffer state:\n", + "[[[1.]\n", + " [2.]\n", + " [3.]\n", + " [4.]\n", + " [5.]]]\n", + "\n", + "Vault state:\n", + "[[[1.]\n", + " [2.]\n", + " [3.]\n", + " [4.]\n", + " [5.]]]\n", + "------------------\n", + "------------------\n", + "Buffer state:\n", + "[[[6.]\n", + " [2.]\n", + " [3.]\n", + " [4.]\n", + " [5.]]]\n", + "\n", + "Vault state:\n", + "[[[1.]\n", + " [2.]\n", + " [3.]\n", + " [4.]\n", + " [5.]\n", + " [6.]]]\n", + "------------------\n", + "------------------\n", + "Buffer state:\n", + "[[[6.]\n", + " [7.]\n", + " [3.]\n", + " [4.]\n", + " [5.]]]\n", + "\n", + "Vault state:\n", + "[[[1.]\n", + " [2.]\n", + " [3.]\n", + " [4.]\n", + " [5.]\n", + " [6.]\n", + " [7.]]]\n", + "------------------\n", + "------------------\n", + "Buffer state:\n", + "[[[6.]\n", + " [7.]\n", + " [8.]\n", + " 
[4.]\n", + " [5.]]]\n", + "\n", + "Vault state:\n", + "[[[1.]\n", + " [2.]\n", + " [3.]\n", + " [4.]\n", + " [5.]\n", + " [6.]\n", + " [7.]\n", + " [8.]]]\n", + "------------------\n" + ] + } + ], + "source": [ + "for i in range(1, 10):\n", + " print('------------------')\n", + " print(\"Buffer state:\")\n", + " print(buffer_state.experience.obs)\n", + " print()\n", + "\n", + " v.write(buffer_state)\n", + "\n", + " print(\"Vault state:\")\n", + " print(v.read().experience.obs)\n", + " print('------------------')\n", + "\n", + " buffer_state = buffer_add(\n", + " buffer_state,\n", + " FbxTransition(obs=i * jnp.ones(1))\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice that when the buffer (implemented as a ring buffer) wraps around, the vault continues storing the data:\n", + "```\n", + "Buffer state:\n", + "[[[6.]\n", + " [2.]\n", + " [3.]\n", + " [4.]\n", + " [5.]]]\n", + "\n", + "\n", + "Vault state:\n", + "[[[1.]\n", + " [2.]\n", + " [3.]\n", + " [4.]\n", + " [5.]\n", + " [6.]]]\n", + "```\n", + "\n", + "Note: the vault must be given the buffer state at least every `max_steps` number of timesteps (i.e. before stale data is overwritten in the ring buffer)." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "flashbax", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 37b940e17a6cb2180a77a734546d96c45618d3f1 Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Wed, 17 Jan 2024 16:57:47 +0200 Subject: [PATCH 17/31] fix: add fbx install to vault example notebook. 
--- examples/vault_demonstration.ipynb | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/examples/vault_demonstration.ipynb b/examples/vault_demonstration.ipynb index 0b23935..7ca3836 100644 --- a/examples/vault_demonstration.ipynb +++ b/examples/vault_demonstration.ipynb @@ -7,6 +7,21 @@ "# Vault demonstration" ] }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%%capture\n", + "try:\n", + " import flashbax as fbx\n", + "except ModuleNotFoundError:\n", + " print('installing flashbax')\n", + " %pip install -q flashbax\n", + " import flashbax as fbx" + ] + }, { "cell_type": "code", "execution_count": 1, From cfc30db232f245cf8d8021284969ec1aa3984dad Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Wed, 17 Jan 2024 17:38:16 +0200 Subject: [PATCH 18/31] feat: print messages after loading or creating vault. --- flashbax/vault/vault.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/flashbax/vault/vault.py b/flashbax/vault/vault.py index b11b44d..989ae67 100644 --- a/flashbax/vault/vault.py +++ b/flashbax/vault/vault.py @@ -96,10 +96,10 @@ def __init__( self._base_path = os.path.join(os.getcwd(), rel_dir, vault_name, vault_str) metadata_path = epath.Path(os.path.join(self._base_path, METADATA_FILE)) - # TODO: logging at each step # Check if the vault exists, otherwise create the necessary dirs and files base_path_exists = os.path.exists(self._base_path) if base_path_exists: + print(f"Loading vault found at {self._base_path}") # Vault exists, so we load the metadata to access the structure etc. 
self._metadata = json.loads(metadata_path.read_text()) @@ -107,6 +107,8 @@ def __init__( assert (self._metadata["version"] // 1) == (VERSION // 1) elif experience_structure is not None: + print(f"New vault created at {self._base_path}") + # Create the necessary dirs for the vault os.makedirs(self._base_path) From d8abdc9a0dc3f551d89a7d50545d090c774166d8 Mon Sep 17 00:00:00 2001 From: Callum Tilbury <37700709+callumtilbury@users.noreply.github.com> Date: Wed, 17 Jan 2024 17:49:38 +0200 Subject: [PATCH 19/31] chore: minor --- flashbax/vault/vault.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flashbax/vault/vault.py b/flashbax/vault/vault.py index 989ae67..c111343 100644 --- a/flashbax/vault/vault.py +++ b/flashbax/vault/vault.py @@ -126,7 +126,7 @@ def __init__( serialised_experience_structure = jax.tree_map( lambda x: [str(x.shape), x.dtype.name], serialize_tree( - # Get shape and dtype of each leaf, without serialising the structure itself + # Get shape and dtype of each leaf, without serialising the data itself jax.eval_shape( lambda: experience_structure, ), From f85c70d84cdd29ee23b13c7c5e6f5a860f05f04c Mon Sep 17 00:00:00 2001 From: Callum Tilbury <37700709+callumtilbury@users.noreply.github.com> Date: Wed, 17 Jan 2024 17:53:24 +0200 Subject: [PATCH 20/31] chore: minor docs fix --- flashbax/vault/vault.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flashbax/vault/vault.py b/flashbax/vault/vault.py index c111343..4f9ecef 100644 --- a/flashbax/vault/vault.py +++ b/flashbax/vault/vault.py @@ -213,7 +213,7 @@ def _init_leaf( Args: name (str): datastore name leaf (list): leaf of the form ["shape", "dtype"] - create_ds (bool, optional): _description_. Defaults to False. + create_ds (bool, optional): whether to create the datastore. Defaults to False. 
Returns: ts.TensorStore: the datastore object From e0a6fdc6260d06c5ee8cb042d410a5bc8c9f44d9 Mon Sep 17 00:00:00 2001 From: Callum Tilbury <37700709+callumtilbury@users.noreply.github.com> Date: Wed, 17 Jan 2024 18:03:33 +0200 Subject: [PATCH 21/31] chore: minor docs update --- flashbax/vault/vault.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/flashbax/vault/vault.py b/flashbax/vault/vault.py index 4f9ecef..bcaace2 100644 --- a/flashbax/vault/vault.py +++ b/flashbax/vault/vault.py @@ -401,8 +401,12 @@ def read( """Read synchronously from the vault. Args: - timesteps (Optional[int], optional): _description_. Defaults to None. - percentiles (Optional[Tuple[int, int]], optional): _description_. Defaults to None. + timesteps (Optional[int], optional): + If provided, we read the last `timesteps` count of elements. + Defaults to None. + percentiles (Optional[Tuple[int, int]], optional): + If provided (and timesteps is None) we read the corresponding interval. + Defaults to None. Returns: TrajectoryBufferState: the read data as a fbx buffer state From 804335ab16b0958b3373d3310eb0d6d5effd1de1 Mon Sep 17 00:00:00 2001 From: Callum Tilbury <37700709+callumtilbury@users.noreply.github.com> Date: Wed, 24 Jan 2024 12:52:03 +0200 Subject: [PATCH 22/31] chore: nits from code review Co-authored-by: Sasha --- flashbax/vault/vault.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/flashbax/vault/vault.py b/flashbax/vault/vault.py index bcaace2..a46d2a2 100644 --- a/flashbax/vault/vault.py +++ b/flashbax/vault/vault.py @@ -121,15 +121,13 @@ def __init__( ) # We save the structure of the buffer state - # e.g. [(128, 100, 4), jnp.int32] + # e.g. 
[(128, 100, 4), jnp.int32] # We will use this structure to map over the data stores later serialised_experience_structure = jax.tree_map( lambda x: [str(x.shape), x.dtype.name], serialize_tree( # Get shape and dtype of each leaf, without serialising the data itself - jax.eval_shape( - lambda: experience_structure, - ), + jax.eval_shape(lambda: experience_structure), ), ) @@ -149,10 +147,10 @@ def __init__( ) # We must now build the tree structure from the metadata, whether the metadata was created - # here or loaded from file + # here or loaded from file if experience_structure is None: # Since the experience structure is not provided, we simply use the metadata as is. - # The result will always be a dictionary. + # The result will always be a dictionary. self._tree_structure = self._metadata["structure"] else: # If experience structure is provided, we try deserialise into that structure @@ -162,7 +160,7 @@ def __init__( ) # Each leaf of the fbx_state.experience maps to a data store, so we tree map over the - # tree structure to create each of the data stores. + # tree structure to create each of the data stores. self._all_ds = jax.tree_util.tree_map_with_path( lambda path, x: self._init_leaf( name=_path_to_ds_name(path), From a8c9fbb7b94d556359067b3c900d1f74c014a7be Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Wed, 24 Jan 2024 16:51:30 +0200 Subject: [PATCH 23/31] feat: save structure metadata of shape and dtype in separate trees. --- flashbax/vault/vault.py | 78 +++++++++++++++++++++++------------------ 1 file changed, 43 insertions(+), 35 deletions(-) diff --git a/flashbax/vault/vault.py b/flashbax/vault/vault.py index a46d2a2..d2dbde8 100644 --- a/flashbax/vault/vault.py +++ b/flashbax/vault/vault.py @@ -120,21 +120,24 @@ def __init__( metadata, ) - # We save the structure of the buffer state - # e.g. 
[(128, 100, 4), jnp.int32] - # We will use this structure to map over the data stores later - serialised_experience_structure = jax.tree_map( - lambda x: [str(x.shape), x.dtype.name], - serialize_tree( - # Get shape and dtype of each leaf, without serialising the data itself - jax.eval_shape(lambda: experience_structure), - ), + # We save the structure of the buffer state, storing the shape and dtype of + # each leaf. We will use this structure to map over the data stores later. + # (Note: we use `jax.eval_shape` to get shape and dtype of each leaf, without + # unnecessarily serialising the buffer data itself) + serialised_experience_structure_shape = jax.tree_map( + lambda x: str(x.shape), + serialize_tree(jax.eval_shape(lambda: experience_structure)), + ) + serialised_experience_structure_dtype = jax.tree_map( + lambda x: x.dtype.name, + serialize_tree(jax.eval_shape(lambda: experience_structure)), ) # Construct metadata self._metadata = { "version": VERSION, - "structure": serialised_experience_structure, + "structure_shape": serialised_experience_structure_shape, + "structure_dtype": serialised_experience_structure_dtype, **(metadata_json_ready or {}), # Allow user to save extra metadata } # Dump metadata to file @@ -151,25 +154,32 @@ def __init__( if experience_structure is None: # Since the experience structure is not provided, we simply use the metadata as is. # The result will always be a dictionary. 
- self._tree_structure = self._metadata["structure"] + self._tree_structure_shape = self._metadata["structure_shape"] + self._tree_structure_dtype = self._metadata["structure_dtype"] else: # If experience structure is provided, we try deserialise into that structure - self._tree_structure = deserialize_tree( - self._metadata["structure"], + self._tree_structure_shape = deserialize_tree( + self._metadata["structure_shape"], + target=experience_structure, + ) + self._tree_structure_dtype = deserialize_tree( + self._metadata["structure_dtype"], target=experience_structure, ) # Each leaf of the fbx_state.experience maps to a data store, so we tree map over the # tree structure to create each of the data stores. self._all_ds = jax.tree_util.tree_map_with_path( - lambda path, x: self._init_leaf( + lambda path, shape, dtype: self._init_leaf( name=_path_to_ds_name(path), - leaf=x, + shape=make_tuple( + shape + ), # Must convert to a real tuple from the saved str + dtype=dtype, create_ds=not base_path_exists, ), - self._tree_structure, - # Tree structure uses a list [shape, dtype] as a leaf - is_leaf=lambda x: isinstance(x, list), + self._tree_structure_shape, + self._tree_structure_dtype, ) # We keep track of the last fbx buffer idx received @@ -204,35 +214,36 @@ def _get_base_spec(self, name: str) -> dict: } def _init_leaf( - self, name: str, leaf: list, create_ds: bool = False + self, name: str, shape: Tuple[int, ...], dtype: str, create_ds: bool = False ) -> ts.TensorStore: """Initialise a datastore for a leaf of the experience tree. Args: name (str): datastore name - leaf (list): leaf of the form ["shape", "dtype"] + shape (Tuple[int, ...]): shape of the data for this leaf + dtype (str): dtype of the data for this leaf create_ds (bool, optional): whether to create the datastore. Defaults to False. 
Returns: ts.TensorStore: the datastore object """ spec = self._get_base_spec(name) - # Convert shape and dtype from str to tuple and dtype - leaf_shape = make_tuple(leaf[0]) # Must convert to a real tuple - leaf_dtype = leaf[1] # Can leave this as a str - leaf_ds = ts.open( - spec, + leaf_shape, leaf_dtype = None, None + if create_ds: # Only specify dtype and shape if we are creating a vault # (i.e. don't impose dtype and shape if we are _loading_ a vault) - dtype=leaf_dtype if create_ds else None, - shape=( - leaf_shape[0], # Batch dim + leaf_shape = ( + shape[0], # Batch dim TIME_AXIS_MAX_LENGTH, # Time dim, which we extend - *leaf_shape[2:], # Experience dim(s) + *shape[2:], # Experience dim(s) ) - if create_ds - else None, + leaf_dtype = dtype + + leaf_ds = ts.open( + spec, + shape=leaf_shape, + dtype=leaf_dtype, # Only create datastore if we are creating the vault: create=create_ds, ).result() # Do this synchronously @@ -426,16 +437,13 @@ def read( int(self.vault_index * (percentiles[1] / 100)), ) - # read_result = jax.tree_util.tree_map( lambda _, ds: self._read_leaf( read_leaf=ds, read_interval=read_interval, ), - self._tree_structure, # Just used for structure + self._tree_structure_shape, # Just used to return a valid tree structure self._all_ds, # The vault data stores - # Interpret ["shape", "dtype"] from tree_structure as leaves: - is_leaf=lambda x: isinstance(x, list), ) # Return the read result as a fbx buffer state From dc4a88ea04f322375dfcb67028aff5b83e0aa65f Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Mon, 5 Feb 2024 14:05:10 +0200 Subject: [PATCH 24/31] minor: move print lower down, in case of code block failure. 
--- flashbax/vault/vault.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flashbax/vault/vault.py b/flashbax/vault/vault.py index d2dbde8..6578978 100644 --- a/flashbax/vault/vault.py +++ b/flashbax/vault/vault.py @@ -107,8 +107,6 @@ def __init__( assert (self._metadata["version"] // 1) == (VERSION // 1) elif experience_structure is not None: - print(f"New vault created at {self._base_path}") - # Create the necessary dirs for the vault os.makedirs(self._base_path) @@ -142,6 +140,8 @@ def __init__( } # Dump metadata to file metadata_path.write_text(json.dumps(self._metadata)) + + print(f"New vault created at {self._base_path}") else: # If the vault does not exist already, and no experience_structure is provided to create # a new vault, we cannot proceed. From ff4fac3962ce73be0b80bc1a0d9c36964c9cc914 Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Mon, 5 Feb 2024 14:09:38 +0200 Subject: [PATCH 25/31] feat: use different obs dim in example for clarity's sake. --- examples/vault_demonstration.ipynb | 238 ++++++++++++++++------------- 1 file changed, 133 insertions(+), 105 deletions(-) diff --git a/examples/vault_demonstration.ipynb b/examples/vault_demonstration.ipynb index 7ca3836..b25aeca 100644 --- a/examples/vault_demonstration.ipynb +++ b/examples/vault_demonstration.ipynb @@ -24,7 +24,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -45,14 +45,30 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. 
Falling back to cpu.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/claude/flashbax/flashbax/buffers/trajectory_buffer.py:473: UserWarning: Setting max_size dynamically sets the `max_length_time_axis` to be `max_size`//`add_batch_size = 5`.This allows one to control exactly how many timesteps are stored in the buffer.Note that this overrides the `max_length_time_axis` argument.\n", + " warnings.warn(\n" + ] + } + ], "source": [ "class FbxTransition(NamedTuple):\n", " obs: Array\n", "\n", - "tx = FbxTransition(obs=jnp.zeros(1))\n", + "tx = FbxTransition(obs=jnp.zeros(shape=(2,)))\n", "\n", "buffer = fbx.make_flat_buffer(\n", " max_length=5,\n", @@ -67,21 +83,21 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The shape of this buffer is $(B = 1, T = 5, E = 1)$, meaning the buffer can hold 5 timesteps." + "The shape of this buffer is $(B = 1, T = 5, E = 2)$, meaning the buffer can hold 5 timesteps, where each observation is of shape $(2,)$." ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "(1, 5, 1)" + "(1, 5, 2)" ] }, - "execution_count": 3, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -99,9 +115,17 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "New vault created at /tmp/demo/20240205140817\n" + ] + } + ], "source": [ "v = Vault(\n", " vault_name=\"demo\",\n", @@ -119,7 +143,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -128,130 +152,130 @@ "text": [ "------------------\n", "Buffer state:\n", - "[[[0.]\n", - " [0.]\n", - " [0.]\n", - " [0.]\n", - " [0.]]]\n", + "[[[0. 0.]\n", + " [0. 0.]\n", + " [0. 0.]\n", + " [0. 0.]\n", + " [0. 
0.]]]\n", "\n", "Vault state:\n", "[]\n", "------------------\n", "------------------\n", "Buffer state:\n", - "[[[1.]\n", - " [0.]\n", - " [0.]\n", - " [0.]\n", - " [0.]]]\n", + "[[[1. 1.]\n", + " [0. 0.]\n", + " [0. 0.]\n", + " [0. 0.]\n", + " [0. 0.]]]\n", "\n", "Vault state:\n", - "[[[1.]]]\n", + "[[[1. 1.]]]\n", "------------------\n", "------------------\n", "Buffer state:\n", - "[[[1.]\n", - " [2.]\n", - " [0.]\n", - " [0.]\n", - " [0.]]]\n", + "[[[1. 1.]\n", + " [2. 2.]\n", + " [0. 0.]\n", + " [0. 0.]\n", + " [0. 0.]]]\n", "\n", "Vault state:\n", - "[[[1.]\n", - " [2.]]]\n", + "[[[1. 1.]\n", + " [2. 2.]]]\n", "------------------\n", "------------------\n", "Buffer state:\n", - "[[[1.]\n", - " [2.]\n", - " [3.]\n", - " [0.]\n", - " [0.]]]\n", + "[[[1. 1.]\n", + " [2. 2.]\n", + " [3. 3.]\n", + " [0. 0.]\n", + " [0. 0.]]]\n", "\n", "Vault state:\n", - "[[[1.]\n", - " [2.]\n", - " [3.]]]\n", + "[[[1. 1.]\n", + " [2. 2.]\n", + " [3. 3.]]]\n", "------------------\n", "------------------\n", "Buffer state:\n", - "[[[1.]\n", - " [2.]\n", - " [3.]\n", - " [4.]\n", - " [0.]]]\n", + "[[[1. 1.]\n", + " [2. 2.]\n", + " [3. 3.]\n", + " [4. 4.]\n", + " [0. 0.]]]\n", "\n", "Vault state:\n", - "[[[1.]\n", - " [2.]\n", - " [3.]\n", - " [4.]]]\n", + "[[[1. 1.]\n", + " [2. 2.]\n", + " [3. 3.]\n", + " [4. 4.]]]\n", "------------------\n", "------------------\n", "Buffer state:\n", - "[[[1.]\n", - " [2.]\n", - " [3.]\n", - " [4.]\n", - " [5.]]]\n", + "[[[1. 1.]\n", + " [2. 2.]\n", + " [3. 3.]\n", + " [4. 4.]\n", + " [5. 5.]]]\n", "\n", "Vault state:\n", - "[[[1.]\n", - " [2.]\n", - " [3.]\n", - " [4.]\n", - " [5.]]]\n", + "[[[1. 1.]\n", + " [2. 2.]\n", + " [3. 3.]\n", + " [4. 4.]\n", + " [5. 5.]]]\n", "------------------\n", "------------------\n", "Buffer state:\n", - "[[[6.]\n", - " [2.]\n", - " [3.]\n", - " [4.]\n", - " [5.]]]\n", + "[[[6. 6.]\n", + " [2. 2.]\n", + " [3. 3.]\n", + " [4. 4.]\n", + " [5. 
5.]]]\n", "\n", "Vault state:\n", - "[[[1.]\n", - " [2.]\n", - " [3.]\n", - " [4.]\n", - " [5.]\n", - " [6.]]]\n", + "[[[1. 1.]\n", + " [2. 2.]\n", + " [3. 3.]\n", + " [4. 4.]\n", + " [5. 5.]\n", + " [6. 6.]]]\n", "------------------\n", "------------------\n", "Buffer state:\n", - "[[[6.]\n", - " [7.]\n", - " [3.]\n", - " [4.]\n", - " [5.]]]\n", + "[[[6. 6.]\n", + " [7. 7.]\n", + " [3. 3.]\n", + " [4. 4.]\n", + " [5. 5.]]]\n", "\n", "Vault state:\n", - "[[[1.]\n", - " [2.]\n", - " [3.]\n", - " [4.]\n", - " [5.]\n", - " [6.]\n", - " [7.]]]\n", + "[[[1. 1.]\n", + " [2. 2.]\n", + " [3. 3.]\n", + " [4. 4.]\n", + " [5. 5.]\n", + " [6. 6.]\n", + " [7. 7.]]]\n", "------------------\n", "------------------\n", "Buffer state:\n", - "[[[6.]\n", - " [7.]\n", - " [8.]\n", - " [4.]\n", - " [5.]]]\n", + "[[[6. 6.]\n", + " [7. 7.]\n", + " [8. 8.]\n", + " [4. 4.]\n", + " [5. 5.]]]\n", "\n", "Vault state:\n", - "[[[1.]\n", - " [2.]\n", - " [3.]\n", - " [4.]\n", - " [5.]\n", - " [6.]\n", - " [7.]\n", - " [8.]]]\n", + "[[[1. 1.]\n", + " [2. 2.]\n", + " [3. 3.]\n", + " [4. 4.]\n", + " [5. 5.]\n", + " [6. 6.]\n", + " [7. 7.]\n", + " [8. 8.]]]\n", "------------------\n" ] } @@ -282,24 +306,28 @@ "Notice that when the buffer (implemented as a ring buffer) wraps around, the vault continues storing the data:\n", "```\n", "Buffer state:\n", - "[[[6.]\n", - " [2.]\n", - " [3.]\n", - " [4.]\n", - " [5.]]]\n", - "\n", + "[[[6. 6.]\n", + " [2. 2.]\n", + " [3. 3.]\n", + " [4. 4.]\n", + " [5. 5.]]]\n", "\n", "Vault state:\n", - "[[[1.]\n", - " [2.]\n", - " [3.]\n", - " [4.]\n", - " [5.]\n", - " [6.]]]\n", + "[[[1. 1.]\n", + " [2. 2.]\n", + " [3. 3.]\n", + " [4. 4.]\n", + " [5. 5.]\n", + " [6. 6.]]]\n", "```\n", "\n", "Note: the vault must be given the buffer state at least every `max_steps` number of timesteps (i.e. before stale data is overwritten in the ring buffer)." 
] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] } ], "metadata": { @@ -318,7 +346,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.16" + "version": "3.9.18" } }, "nbformat": 4, From df947465805505b01fb234f143e625b05a4531a5 Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Mon, 5 Feb 2024 14:11:12 +0200 Subject: [PATCH 26/31] chore: minor var rename. --- flashbax/vault/vault.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flashbax/vault/vault.py b/flashbax/vault/vault.py index 6578978..ebfbdb2 100644 --- a/flashbax/vault/vault.py +++ b/flashbax/vault/vault.py @@ -169,7 +169,7 @@ def __init__( # Each leaf of the fbx_state.experience maps to a data store, so we tree map over the # tree structure to create each of the data stores. - self._all_ds = jax.tree_util.tree_map_with_path( + self._all_datastores = jax.tree_util.tree_map_with_path( lambda path, shape, dtype: self._init_leaf( name=_path_to_ds_name(path), shape=make_tuple( @@ -296,7 +296,7 @@ async def _write_chunk( dest_start=dest_start, ), fbx_state.experience, # x = experience - self._all_ds, # ds = data stores + self._all_datastores, # ds = data stores ) # Write to all datastores asynchronously futures, _ = jax.tree_util.tree_flatten(futures_tree) @@ -443,7 +443,7 @@ def read( read_interval=read_interval, ), self._tree_structure_shape, # Just used to return a valid tree structure - self._all_ds, # The vault data stores + self._all_datastores, # The vault data stores ) # Return the read result as a fbx buffer state From 0629f29e501b23f28817566f5379ce6ca7f903b3 Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Mon, 5 Feb 2024 16:29:38 +0200 Subject: [PATCH 27/31] minor: only print after code block success. 
--- flashbax/vault/vault.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flashbax/vault/vault.py b/flashbax/vault/vault.py index ebfbdb2..c1e8839 100644 --- a/flashbax/vault/vault.py +++ b/flashbax/vault/vault.py @@ -99,13 +99,14 @@ def __init__( # Check if the vault exists, otherwise create the necessary dirs and files base_path_exists = os.path.exists(self._base_path) if base_path_exists: - print(f"Loading vault found at {self._base_path}") # Vault exists, so we load the metadata to access the structure etc. self._metadata = json.loads(metadata_path.read_text()) # Ensure minor versions match assert (self._metadata["version"] // 1) == (VERSION // 1) + print(f"Loading vault found at {self._base_path}") + elif experience_structure is not None: # Create the necessary dirs for the vault os.makedirs(self._base_path) From ded832c5ee06b500fdb5931fd2c5eb7e88589123 Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Mon, 5 Feb 2024 16:35:31 +0200 Subject: [PATCH 28/31] docs: add important consideration of vault ring buffer challenge. --- README.md | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index e57642a..72501a0 100644 --- a/README.md +++ b/README.md @@ -148,6 +148,12 @@ from CleanRLs DQN JAX example. - 🦎 [Jumanji](https://github.com/instadeepai/jumanji/) - utilise Jumanji's JAX based environments like Snake for our fully jitted examples. +## Vault 💾 +Vault is an efficient mechanism for saving Flashbax buffers to persistent data storage, e.g. for use in offline reinforcement learning. Consider a Flashbax buffer which has experience data of dimensionality $(B, T, *E)$, where $B$ is a batch dimension (for the sake of recording independent trajectories synchronously), $T$ is a temporal/sequential dimension, and $*E$ indicates the one or more dimensions of the experience data itself. 
Since large quantities of data may be generated for a given environment, Vault extends the $T$ dimension to a virtually unconstrained degree by reading and writing slices of buffers along this temporal axis. In doing so, gigantic buffer stores can reside on disk, from which sub-buffers can be loaded into RAM/VRAM for efficient offline training. Vault has been tested with the item, flat, and trajectory buffers. + +For more information, see the demonstrative notebook: [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/flashbax/blob/main/examples/vault_demonstration.ipynb) + + ## Important Considerations ⚠️ When working with Flashbax buffers, it's crucial to be mindful of certain considerations to ensure the proper functionality of your RL agent. @@ -190,6 +196,9 @@ It is important to include `donate_argnums` when calling `jax.jit` to enable JAX In summary, understanding and addressing these considerations will help you navigate potential pitfalls and ensure the effectiveness of your reinforcement learning strategies while utilising Flashbax buffers. +### Storing Data with Vault +As mentioned [above](./README.md#vault-💾), Vault stores experience data to disk by extending the temporal axis of a Flashbax buffer state. By default, Vault conveniently handles the bookkeeping of this process: consuming a buffer state and saving any fresh, previously unseen data. e.g. Suppose we write 10 timesteps to our Flashbax buffer, and then save this state to a Vault; since all of this data is fresh, all of it will be written to disk. However, if we then write one more timestep and save the state to the Vault, only that new timestep will be written, preventing any duplication of data that has already been saved. 
Importantly, one must remember that Flashbax states are implemented as _ring buffers_, meaning the Vault must be updated sufficiently frequently, before unseen data in the Flashbax buffer state is overwritten. i.e. If our buffer state has a time-axis length of $\tau$, then we must save to the vault every $\tau - 1$ steps, lest we overwrite (and lose) unsaved data. + ## Benchmarks 📈 Here we provide a series of initial benchmarks outlining the performance of the various Flashbax buffers compared against commonly used open-source buffers. In these benchmarks we (unless explicitly stated otherwise) use the following configuration: @@ -242,11 +251,6 @@ Previous benchmarks added only a single timestep at a time, we now evaluate addi Ultimately, we see improved or comparable performance to benchmarked buffers whilst providing buffers that are fully JAX-compatible in addition to other features such as batched adding as well as being able to add sequences of varying length. We do note that due to JAX having different XLA backends for CPU, GPU, and TPU, the performance of the buffers can vary depending on the device and the specific operation being called. -## Vault 💾 -Vault is an efficient mechanism for saving flashbax buffers to persistent data storage, e.g. for use in offline reinforcement learning. Consider a Flashbax buffer which has experience data of dimensionality $(B, T, *E)$, where $B$ is a batch dimension (for the sake of recording independent trajectories synchronously), $T$ is a temporal/sequential dimension, and $*E$ indicates the one or more dimensions of the experience data itself. Since large quantities of data may be generated for a given environment, Vault extends the $T$ dimension to a virtually unconstrained degree by reading and writing slices of buffers along this temporal axis. In doing so, gigantic buffer stores can reside on disk, from which sub-buffers can be loaded into RAM/VRAM for efficient offline training. 
Vault has been tested with the item, flat, and trajectory buffers. - -For more information, see the demonstration notebook: [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/flashbax/blob/main/examples/vault_demonstration.ipynb) - ## Contributing 🤝 From f031f55327ea7f1443637a7f03b42e7c971fc97d Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Mon, 5 Feb 2024 16:37:27 +0200 Subject: [PATCH 29/31] chore: minor text --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 72501a0..68ec26b 100644 --- a/README.md +++ b/README.md @@ -197,7 +197,7 @@ It is important to include `donate_argnums` when calling `jax.jit` to enable JAX In summary, understanding and addressing these considerations will help you navigate potential pitfalls and ensure the effectiveness of your reinforcement learning strategies while utilising Flashbax buffers. ### Storing Data with Vault -As mentioned [above](./README.md#vault-💾), Vault stores experience data to disk by extending the temporal axis of a Flashbax buffer state. By default, Vault conveniently handles the bookkeeping of this process: consuming a buffer state and saving any fresh, previously unseen data. e.g. Suppose we write 10 timesteps to our Flashbax buffer, and then save this state to a Vault; since all of this data is fresh, all of it will be written to disk. However, if we then write one more timestep and save the state to the Vault, only that new timestep will be written, preventing any duplication of data that has already been saved. Importantly, one must remember that Flashbax states are implemented as _ring buffers_, meaning the Vault must be updated sufficiently frequently, before unseen data in the Flashbax buffer state is overwritten. i.e. If our buffer state has a time-axis length of $\tau$, then we must save to the vault every $\tau - 1$ steps, lest we overwrite (and lose) unsaved data. 
+As mentioned [above](./README.md#vault-💾), Vault stores experience data to disk by extending the temporal axis of a Flashbax buffer state. By default, Vault conveniently handles the bookkeeping of this process: consuming a buffer state and saving any fresh, previously unseen data. For example, suppose we write 10 timesteps to our Flashbax buffer, and then save this state to a Vault; since all of this data is fresh, all of it will be written to disk. However, if we then write one more timestep and save the state to the Vault, only that new timestep will be written, preventing any duplication of data that has already been saved. Importantly, one must remember that Flashbax states are implemented as _ring buffers_, meaning the Vault must be updated sufficiently frequently before unseen data in the Flashbax buffer state is overwritten. That is, if our buffer state has a time-axis length of $\tau$, then we must save to the vault at least every $\tau - 1$ steps, lest we overwrite (and lose) unsaved data. ## Benchmarks 📈 From 222cb6dc52d263d22d28ff63964e12b3f7c392a9 Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Mon, 5 Feb 2024 17:41:06 +0200 Subject: [PATCH 30/31] chore: fix location of summary sentence --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 68ec26b..5f0c803 100644 --- a/README.md +++ b/README.md @@ -194,11 +194,12 @@ train_state, buffer_state = jax.jit(train, donate_argnums=(1,))( It is important to include `donate_argnums` when calling `jax.jit` to enable JAX to perform an in-place update of the replay buffer state. Omitting `donate_argnums` would force JAX to create a copy of the state for any modifications to the replay buffer state, potentially negating all performance benefits. More information about buffer donation in JAX can be found in the [documentation](https://jax.readthedocs.io/en/latest/faq.html#buffer-donation).
-In summary, understanding and addressing these considerations will help you navigate potential pitfalls and ensure the effectiveness of your reinforcement learning strategies while utilising Flashbax buffers. ### Storing Data with Vault As mentioned [above](./README.md#vault-💾), Vault stores experience data to disk by extending the temporal axis of a Flashbax buffer state. By default, Vault conveniently handles the bookkeeping of this process: consuming a buffer state and saving any fresh, previously unseen data. For example, suppose we write 10 timesteps to our Flashbax buffer, and then save this state to a Vault; since all of this data is fresh, all of it will be written to disk. However, if we then write one more timestep and save the state to the Vault, only that new timestep will be written, preventing any duplication of data that has already been saved. Importantly, one must remember that Flashbax states are implemented as _ring buffers_, meaning the Vault must be updated sufficiently frequently before unseen data in the Flashbax buffer state is overwritten. That is, if our buffer state has a time-axis length of $\tau$, then we must save to the vault at least every $\tau - 1$ steps, lest we overwrite (and lose) unsaved data. +In summary, understanding and addressing these considerations will help you navigate potential pitfalls and ensure the effectiveness of your reinforcement learning strategies while utilising Flashbax buffers. + ## Benchmarks 📈 Here we provide a series of initial benchmarks outlining the performance of the various Flashbax buffers compared against commonly used open-source buffers.
In these benchmarks we (unless explicitly stated otherwise) use the following configuration: From 8709c6588b379b000f00336a634127ba70dc0f29 Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Mon, 5 Feb 2024 17:42:14 +0200 Subject: [PATCH 31/31] feat: add timesteps overwrite warning when creating vault --- flashbax/vault/vault.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/flashbax/vault/vault.py b/flashbax/vault/vault.py index c1e8839..08e0353 100644 --- a/flashbax/vault/vault.py +++ b/flashbax/vault/vault.py @@ -143,6 +143,13 @@ def __init__( metadata_path.write_text(json.dumps(self._metadata)) print(f"New vault created at {self._base_path}") + + _fbx_shape = jax.tree_util.tree_leaves(experience_structure)[0].shape + print( + f"Since the provided buffer state has a temporal dimension of {_fbx_shape[1]}, " + f"you must write to the vault at least every {_fbx_shape[1] - 1} " + "timesteps to avoid data loss." + ) else: # If the vault does not exist already, and no experience_structure is provided to create # a new vault, we cannot proceed.