Skip to content

Commit

Permalink
feat: big update and refactor in order to checkpoint namedtuples.
Browse files Browse the repository at this point in the history
  • Loading branch information
callumtilbury committed Jan 16, 2024
1 parent 7ca7c32 commit 6816107
Showing 1 changed file with 74 additions and 34 deletions.
108 changes: 74 additions & 34 deletions flashbax/vault/vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,50 +24,62 @@
import tensorstore as ts # type: ignore
from chex import Array
from etils import epath # type: ignore
from orbax.checkpoint.utils import deserialize_tree, serialize_tree

from flashbax.buffers.trajectory_buffer import TrajectoryBufferState
from flashbax.buffers.trajectory_buffer import Experience, TrajectoryBufferState
from flashbax.utils import get_tree_shape_prefix

# CURRENT LIMITATIONS / TODO LIST
# - Only tested with flat buffers
# - Reloading must be with dicts, not namedtuples

DRIVER = "file://"
METADATA_FILE = "metadata.json"
TIME_AXIS_MAX_LENGTH = int(10e12) # Upper bound on the length of the time axis
VERSION = 0.1


def _path_to_ds_name(path: str) -> str:
path_str = ""
for p in path:
if isinstance(p, jax.tree_util.DictKey):
path_str += str(p.key)
elif isinstance(p, jax.tree_util.GetAttrKey):
path_str += p.name
path_str += "."
return path_str


class Vault:
def __init__(
self,
vault_name: str,
init_fbx_state: Optional[TrajectoryBufferState] = None,
experience_structure: Optional[Experience] = None,
rel_dir: str = "vaults",
vault_uid: Optional[str] = None,
metadata: Optional[dict] = None,
) -> None:

## ---
# Get the base path for the vault and the metadata path
vault_str = vault_uid if vault_uid else datetime.now().strftime("%Y%m%d%H%M%S")
self._base_path = os.path.join(os.getcwd(), rel_dir, vault_name, vault_str)

# We use epath for metadata
metadata_path = epath.Path(os.path.join(self._base_path, METADATA_FILE))

# TODO: logging at each step
# Check if the vault exists, otherwise create the necessary dirs and files
base_path_exists = os.path.exists(self._base_path)
if base_path_exists:
# Vault exists, so we load the metadata to access the structure etc.
self._metadata = json.loads(metadata_path.read_text())

# Ensure major versions match (floor division by 1 keeps only the integer part)
assert (self._metadata["version"] // 1) == (VERSION // 1)

elif init_fbx_state is not None:
# init_fbx_state must be a TrajectoryBufferState
assert isinstance(init_fbx_state, TrajectoryBufferState)

elif experience_structure is not None:
# Create the necessary dirs for the vault
os.makedirs(self._base_path)

# TODO with serialize_tree?
def get_json_ready(obj: Any) -> Any:
"""Ensure that the object is json serializable. Convert to string if not.
Expand All @@ -83,41 +95,69 @@ def get_json_ready(obj: Any) -> Any:
return obj

metadata_json_ready = jax.tree_util.tree_map(get_json_ready, metadata)
experience_structure = jax.tree_map(
lambda x: [str(x.shape), str(x.dtype)],
init_fbx_state.experience,

# We save the structure of the buffer state as [shape string, dtype name] pairs,
# e.g. ["(128, 100, 4)", "int32"]
# We will use this structure to map over the data stores later
serialised_experience_structure = jax.tree_map(
lambda x: [str(x.shape), x.dtype.name],
serialize_tree(
# Get shape and dtype of each leaf, without serialising the structure itself
jax.eval_shape(
lambda: experience_structure,
),
)
)

# Construct metadata
self._metadata = {
"version": VERSION,
"structure": experience_structure,
"structure": serialised_experience_structure,
**(metadata_json_ready or {}), # Allow user to save extra metadata
}
# Dump metadata to file
metadata_path.write_text(json.dumps(self._metadata))
else:
raise ValueError("Vault does not exist and no init_fbx_state provided.")

# Keep a data store for the vault index
self._vault_index_ds = ts.open(
self._get_base_spec("vault_index"),
dtype=jnp.int32,
shape=(1,),
create=not base_path_exists,
).result()
self.vault_index = int(self._vault_index_ds.read().result()[0])
## ---
# We must now build the tree structure from the metadata, whether created here or loaded from file
if experience_structure is None:
# If an example state is not provided, we simply load from the metadata
# and the result will be a dictionary.
self._tree_structure = self._metadata["structure"]
else:
# If an example state is provided, we try to deserialise into that structure
self._tree_structure = deserialize_tree(
self._metadata["structure"],
target=experience_structure,
)

# Each leaf of the fbx_state.experience is a data store
# Each leaf of the fbx_state.experience maps to a data store
self._all_ds = jax.tree_util.tree_map_with_path(
lambda path, x: self._init_leaf(
name=jax.tree_util.keystr(path), # Use the path as the name
name=_path_to_ds_name(path),
leaf=x,
create_checkpoint=not base_path_exists,
create_ds=not base_path_exists,
),
self._metadata["structure"],
is_leaf=lambda x: isinstance(x, list),
self._tree_structure,
is_leaf=lambda x: isinstance(x, list), # The list [shape, dtype] is a leaf
)

# We keep track of the last fbx buffer idx received
self._last_received_fbx_index = 0

# We store and load the vault index from a separate datastore
self._vault_index_ds = ts.open(
self._get_base_spec("vault_index"),
dtype=jnp.int32,
shape=(1,),
create=not base_path_exists,
).result()
# Just read synchronously as it's one number
self.vault_index = int(self._vault_index_ds.read().result()[0])


def _get_base_spec(self, name: str) -> dict:
return {
"driver": "zarr",
Expand All @@ -129,24 +169,24 @@ def _get_base_spec(self, name: str) -> dict:
}

def _init_leaf(
self, name: str, leaf: list, create_checkpoint: bool = False
self, name: str, leaf: list, create_ds: bool = False
) -> ts.TensorStore:
spec = self._get_base_spec(name)
leaf_shape = make_tuple(leaf[0])
leaf_dtype = leaf[1]
leaf_ds = ts.open(
spec,
# Only specify dtype and shape if we are creating a checkpoint
dtype=leaf_dtype if create_checkpoint else None,
dtype=leaf_dtype if create_ds else None,
shape=(
leaf_shape[0], # Batch dim
TIME_AXIS_MAX_LENGTH, # Time dim
*leaf_shape[2:], # Experience dim
)
if create_checkpoint
else None,
# Only create directory if we are creating a checkpoint
create=create_checkpoint,
if create_ds
else None, # Don't impose shape if we are loading a vault
# Only create datastore if we are creating the vault
create=create_ds,
).result() # Synchronous
return leaf_ds

Expand Down Expand Up @@ -262,7 +302,7 @@ def _read_leaf(
read_leaf: ts.TensorStore,
read_interval: Tuple[int, int],
) -> Array:
return read_leaf[:, slice(*read_interval), ...].read().result()
return jnp.asarray(read_leaf[:, slice(*read_interval), ...].read().result())

def read(
self,
Expand All @@ -289,7 +329,7 @@ def read(
read_leaf=ds,
read_interval=read_interval,
),
self._metadata["structure"], # just for structure
self._tree_structure,
self._all_ds, # data stores
is_leaf=lambda x: isinstance(x, list),
)
Expand Down

0 comments on commit 6816107

Please sign in to comment.