From 6816107d8194c049de822a822a5bfe772ca75f42 Mon Sep 17 00:00:00 2001
From: Callum Tilbury
Date: Tue, 16 Jan 2024 08:55:39 +0200
Subject: [PATCH] feat: big update and refactor in order to checkpoint
 namedtuples.

---
 flashbax/vault/vault.py | 108 +++++++++++++++++++++++++++-------------
 1 file changed, 74 insertions(+), 34 deletions(-)

diff --git a/flashbax/vault/vault.py b/flashbax/vault/vault.py
index 914bd9c..58db823 100644
--- a/flashbax/vault/vault.py
+++ b/flashbax/vault/vault.py
@@ -24,13 +24,13 @@
 import tensorstore as ts  # type: ignore
 from chex import Array
 from etils import epath  # type: ignore
+from orbax.checkpoint.utils import deserialize_tree, serialize_tree

-from flashbax.buffers.trajectory_buffer import TrajectoryBufferState
+from flashbax.buffers.trajectory_buffer import Experience, TrajectoryBufferState
 from flashbax.utils import get_tree_shape_prefix

 # CURRENT LIMITATIONS / TODO LIST
 # - Only tested with flat buffers
-# - Reloading must be with dicts, not namedtuples

 DRIVER = "file://"
 METADATA_FILE = "metadata.json"
@@ -38,36 +38,48 @@
 VERSION = 0.1


+def _path_to_ds_name(path: str) -> str:
+    path_str = ""
+    for p in path:
+        if isinstance(p, jax.tree_util.DictKey):
+            path_str += str(p.key)
+        elif isinstance(p, jax.tree_util.GetAttrKey):
+            path_str += p.name
+        path_str += "."
+    return path_str
+
+
 class Vault:
     def __init__(
         self,
         vault_name: str,
-        init_fbx_state: Optional[TrajectoryBufferState] = None,
+        experience_structure: Optional[Experience] = None,
         rel_dir: str = "vaults",
         vault_uid: Optional[str] = None,
         metadata: Optional[dict] = None,
     ) -> None:
+
+        ## ---
+        # Get the base path for the vault and the metadata path
         vault_str = vault_uid if vault_uid else datetime.now().strftime("%Y%m%d%H%M%S")
         self._base_path = os.path.join(os.getcwd(), rel_dir, vault_name, vault_str)
-
-        # We use epath for metadata
         metadata_path = epath.Path(os.path.join(self._base_path, METADATA_FILE))

+        # TODO: logging at each step
         # Check if the vault exists, otherwise create the necessary dirs and files
         base_path_exists = os.path.exists(self._base_path)
         if base_path_exists:
+            # Vault exists, so we load the metadata to access the structure etc.
             self._metadata = json.loads(metadata_path.read_text())

             # Ensure minor versions match
             assert (self._metadata["version"] // 1) == (VERSION // 1)
-        elif init_fbx_state is not None:
-            # init_fbx_state must be a TrajectoryBufferState
-            assert isinstance(init_fbx_state, TrajectoryBufferState)
-
+        elif experience_structure is not None:
             # Create the necessary dirs for the vault
             os.makedirs(self._base_path)

+            # TODO with serialize_tree?
             def get_json_ready(obj: Any) -> Any:
                 """Ensure that the object is json serializable. Convert to string if not.
@@ -83,41 +95,69 @@ def get_json_ready(obj: Any) -> Any:
                 return obj

             metadata_json_ready = jax.tree_util.tree_map(get_json_ready, metadata)
-            experience_structure = jax.tree_map(
-                lambda x: [str(x.shape), str(x.dtype)],
-                init_fbx_state.experience,
+
+            # We save the structure of the buffer state
+            # e.g. [(128, 100, 4), jnp.int32]
+            # We will use this structure to map over the data stores later
+            serialised_experience_structure = jax.tree_map(
+                lambda x: [str(x.shape), x.dtype.name],
+                serialize_tree(
+                    # Get shape and dtype of each leaf, without serialising the structure itself
+                    jax.eval_shape(
+                        lambda: experience_structure,
+                    ),
+                )
             )
+
+            # Construct metadata
             self._metadata = {
                 "version": VERSION,
-                "structure": experience_structure,
+                "structure": serialised_experience_structure,
                 **(metadata_json_ready or {}),  # Allow user to save extra metadata
             }
+            # Dump metadata to file
             metadata_path.write_text(json.dumps(self._metadata))
         else:
             raise ValueError("Vault does not exist and no init_fbx_state provided.")

-        # Keep a data store for the vault index
-        self._vault_index_ds = ts.open(
-            self._get_base_spec("vault_index"),
-            dtype=jnp.int32,
-            shape=(1,),
-            create=not base_path_exists,
-        ).result()
-        self.vault_index = int(self._vault_index_ds.read().result()[0])
+        ## ---
+        # We must now build the tree structure from the metadata, whether created here or loaded from file
+        if experience_structure is None:
+            # If an example state is not provided, we simply load from the metadata
+            # and the result will be a dictionary.
+            self._tree_structure = self._metadata["structure"]
+        else:
+            # If an example state is provided, we try to deserialise into that structure
+            self._tree_structure = deserialize_tree(
+                self._metadata["structure"],
+                target=experience_structure,
+            )

-        # Each leaf of the fbx_state.experience is a data store
+        # Each leaf of the fbx_state.experience maps to a data store
         self._all_ds = jax.tree_util.tree_map_with_path(
             lambda path, x: self._init_leaf(
-                name=jax.tree_util.keystr(path),  # Use the path as the name
+                name=_path_to_ds_name(path),
                 leaf=x,
-                create_checkpoint=not base_path_exists,
+                create_ds=not base_path_exists,
             ),
-            self._metadata["structure"],
+            self._tree_structure,
-            is_leaf=lambda x: isinstance(x, list),
+            is_leaf=lambda x: isinstance(x, list),  # The list [shape, dtype] is a leaf
         )

+        # We keep track of the last fbx buffer idx received
         self._last_received_fbx_index = 0

+        # We store and load the vault index from a separate datastore
+        self._vault_index_ds = ts.open(
+            self._get_base_spec("vault_index"),
+            dtype=jnp.int32,
+            shape=(1,),
+            create=not base_path_exists,
+        ).result()
+        # Just read synchronously as it's one number
+        self.vault_index = int(self._vault_index_ds.read().result()[0])
+
+
     def _get_base_spec(self, name: str) -> dict:
         return {
             "driver": "zarr",
@@ -129,7 +169,7 @@ def _get_base_spec(self, name: str) -> dict:
         }

     def _init_leaf(
-        self, name: str, leaf: list, create_checkpoint: bool = False
+        self, name: str, leaf: list, create_ds: bool = False
     ) -> ts.TensorStore:
         spec = self._get_base_spec(name)
         leaf_shape = make_tuple(leaf[0])
@@ -137,16 +177,16 @@ def _init_leaf(
         leaf_ds = ts.open(
             spec,
             # Only specify dtype and shape if we are creating a checkpoint
-            dtype=leaf_dtype if create_checkpoint else None,
+            dtype=leaf_dtype if create_ds else None,
             shape=(
                 leaf_shape[0],  # Batch dim
                 TIME_AXIS_MAX_LENGTH,  # Time dim
                 *leaf_shape[2:],  # Experience dim
             )
-            if create_checkpoint
-            else None,
+            if create_ds
+            else None,  # Don't impose shape if we are loading a vault
-            # Only create directory if we are creating a checkpoint
-            create=create_checkpoint,
+            # Only create datastore if we are creating the vault
+            create=create_ds,
         ).result()  # Synchronous
         return leaf_ds
@@ -262,7 +302,7 @@ def _read_leaf(
         self,
         read_leaf: ts.TensorStore,
         read_interval: Tuple[int, int],
     ) -> Array:
-        return read_leaf[:, slice(*read_interval), ...].read().result()
+        return jnp.asarray(read_leaf[:, slice(*read_interval), ...].read().result())

     def read(
         self,
@@ -289,7 +329,7 @@ def read(
                 read_leaf=ds,
                 read_interval=read_interval,
             ),
-            self._metadata["structure"],  # just for structure
+            self._tree_structure,
             self._all_ds,  # data stores
             is_leaf=lambda x: isinstance(x, list),
         )
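
Note, not part of the patch: a minimal usage sketch of the refactored interface, under stated assumptions. The Transition namedtuple, buffer sizes, and vault name below are illustrative only; make_flat_buffer is flashbax's flat-buffer constructor, and Vault.write()/Vault.read() are the existing methods that live outside the hunks shown above, so their exact signatures at this commit are assumed.

    from typing import NamedTuple

    import jax.numpy as jnp

    import flashbax as fbx
    from flashbax.vault import Vault


    class Transition(NamedTuple):
        # Experience leaves live in namedtuple fields, which the vault can now checkpoint
        obs: jnp.ndarray
        reward: jnp.ndarray


    # Illustrative buffer sizes only
    buffer = fbx.make_flat_buffer(
        max_length=100, min_length=1, sample_batch_size=1, add_batch_size=2
    )
    state = buffer.init(Transition(obs=jnp.zeros(4), reward=jnp.zeros(())))

    # The experience pytree itself is the structure template; its shapes and
    # dtypes are serialised into metadata.json when the vault is created
    vault = Vault(vault_name="demo_vault", experience_structure=state.experience)

    # ... after buffer.add() calls, flush the state to disk and read it back
    vault.write(state)
    restored = vault.read()  # the restored experience keeps the Transition structure, not a dict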