Skip to content

Commit

Permalink
Merge branch 'main' into feat/mixed_experience_replay
Browse files Browse the repository at this point in the history
  • Loading branch information
callumtilbury authored Aug 29, 2024
2 parents e16bdab + 0dffe1b commit 3fddc3b
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
24 changes: 24 additions & 0 deletions flashbax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,27 @@ def wrapper(*args, **kwargs):
return func(*args, **kwargs)

return wrapper


def get_timestep_count(buffer_state: chex.ArrayTree) -> int:
"""Utility to compute the total number of timesteps currently in the buffer state.
Args:
buffer_state (BufferStateTypes): the buffer state to compute the total timesteps for.
Returns:
int: the total number of timesteps in the buffer state.
"""
# Ensure the buffer state is a valid buffer state.
assert hasattr(buffer_state, "experience")
assert hasattr(buffer_state, "current_index")
assert hasattr(buffer_state, "is_full")

b_size, t_size_max = get_tree_shape_prefix(buffer_state.experience, 2)
t_size = jax.lax.cond(
buffer_state.is_full,
lambda: t_size_max,
lambda: buffer_state.current_index,
)
timestep_count: int = b_size * t_size
return timestep_count
3 changes: 2 additions & 1 deletion flashbax/vault/vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ def __init__( # noqa: CCR001
"""
# Get the base path for the vault and the metadata path
vault_str = vault_uid if vault_uid else datetime.now().strftime("%Y%m%d%H%M%S")
self._base_path = os.path.join(os.getcwd(), rel_dir, vault_name, vault_str)
base_path_unnorm = os.path.join(os.getcwd(), rel_dir, vault_name, vault_str)
self._base_path = os.path.normpath(base_path_unnorm)
metadata_path = epath.Path(os.path.join(self._base_path, METADATA_FILE))

# Check if the vault exists, otherwise create the necessary dirs and files
Expand Down

0 comments on commit 3fddc3b

Please sign in to comment.