Skip to content

Commit

Permalink
Merge pull request #22 from instadeepai/fix/improved-vault-compressio…
Browse files Browse the repository at this point in the history
…n-api

fix: improve the compression api of vaults
  • Loading branch information
callumtilbury authored Mar 12, 2024
2 parents 0429642 + 88894f8 commit 2b8354d
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 18 deletions.
34 changes: 17 additions & 17 deletions flashbax/vault/vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
"id": "gzip",
"level": 5,
}
VERSION = 1.1
VERSION = 1.2


def _path_to_ds_name(path: Tuple[Union[DictKey, GetAttrKey], ...]) -> str:
Expand Down Expand Up @@ -87,8 +87,8 @@ def __init__( # noqa: CCR001
vault_uid (Optional[str], optional): Unique identifier for this vault.
Defaults to None, which will use the current timestamp.
compression (Optional[dict], optional):
Compression settings for the vault. Defaults to None, which will use
the default settings.
Compression settings used when when creating the vault.
Defaults to None, which will use the default compression.
metadata (Optional[dict], optional):
Any additional metadata to save. Defaults to None.
Expand All @@ -115,6 +115,11 @@ def __init__( # noqa: CCR001

print(f"Loading vault found at {self._base_path}")

if compression is not None:
print(
"Requested compression settings will be ignored as the vault already exists."
)

elif experience_structure is not None:
# Create the necessary dirs for the vault
os.makedirs(self._base_path)
Expand Down Expand Up @@ -145,7 +150,6 @@ def __init__( # noqa: CCR001
"version": VERSION,
"structure_shape": serialised_experience_structure_shape,
"structure_dtype": serialised_experience_structure_dtype,
"compression": compression or COMPRESSION_DEFAULT,
**(metadata_json_ready or {}), # Allow user to save extra metadata
}
# Dump metadata to file
Expand Down Expand Up @@ -184,12 +188,8 @@ def __init__( # noqa: CCR001
target=experience_structure,
)

# Load compression settings from metadata
self._compression = (
self._metadata["compression"]
if "compression" in self._metadata
else COMPRESSION_DEFAULT
)
# Keep the compression settings, to be used in init_leaf, in case we're creating the vault
self._compression = compression

# Each leaf of the fbx_state.experience maps to a data store, so we tree map over the
# tree structure to create each of the data stores.
Expand Down Expand Up @@ -235,11 +235,6 @@ def _get_base_spec(self, name: str) -> dict:
"base": f"{DRIVER}{self._base_path}",
"path": name,
},
"metadata": {
"compressor": {
**self._compression,
}
},
}

def _init_leaf(
Expand All @@ -260,14 +255,19 @@ def _init_leaf(

leaf_shape, leaf_dtype = None, None
if create_ds:
# Only specify dtype and shape if we are creating a vault
# (i.e. don't impose dtype and shape if we are _loading_ a vault)
# Only specify dtype, shape, and compression if we are creating a vault
# (i.e. don't impose these fields if we are _loading_ a vault)
leaf_shape = (
shape[0], # Batch dim
TIME_AXIS_MAX_LENGTH, # Time dim, which we extend
*shape[2:], # Experience dim(s)
)
leaf_dtype = dtype
spec["metadata"] = {
"compressor": COMPRESSION_DEFAULT
if self._compression is None
else self._compression
}

leaf_ds = ts.open(
spec,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ authors = [
{name="InstaDeep" , email = "hello@instadeep.com"},
]
requires-python = ">=3.9"
version = "0.1.1"
version = "0.1.2"
classifiers=[
"Development Status :: 2 - Pre-Alpha",
"Environment :: Console",
Expand Down

0 comments on commit 2b8354d

Please sign in to comment.