From a34f0485ec6f3d1030b40312cc5943743012a559 Mon Sep 17 00:00:00 2001
From: Callum Tilbury
Date: Tue, 23 Jul 2024 14:25:17 +0200
Subject: [PATCH 1/2] feat: timestep util from discussions in #33.

---
NOTE(review): remarks in this area sit after the three-dash line and are not
  part of the commit message.
NOTE(review): this patch appends `get_timestep_count` to flashbax/utils.py.
  It unpacks two values from `get_tree_shape_prefix(buffer_state.experience, 2)`
  and multiplies the first by either the second (when `is_full` is truthy) or
  `current_index`.
NOTE(review): `get_tree_shape_prefix` is not defined in this hunk; presumably
  it already exists earlier in flashbax/utils.py - confirm.
NOTE(review): the three `hasattr` checks use `assert`, which Python skips when
  run with -O; confirm that is acceptable for this validation.
NOTE(review): the docstring names the argument type `BufferStateTypes` while
  the annotation is `chex.ArrayTree` - confirm which one is intended.

 flashbax/utils.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/flashbax/utils.py b/flashbax/utils.py
index f8a3428..700485c 100644
--- a/flashbax/utils.py
+++ b/flashbax/utils.py
@@ -68,3 +68,22 @@ def wrapper(*args, **kwargs):
         return func(*args, **kwargs)
 
     return wrapper
+
+
+def get_timestep_count(buffer_state: chex.ArrayTree) -> int:
+    """Utility to compute the total number of timesteps currently in the buffer state.
+
+    Args:
+        buffer_state (BufferStateTypes): the buffer state to compute the total timesteps for.
+
+    Returns:
+        int: the total number of timesteps in the buffer state.
+    """
+    # Ensure the buffer state is a valid buffer state.
+    assert hasattr(buffer_state, "experience")
+    assert hasattr(buffer_state, "current_index")
+    assert hasattr(buffer_state, "is_full")
+
+    b_size, t_size_max = get_tree_shape_prefix(buffer_state.experience, 2)
+    t_size = t_size_max if buffer_state.is_full else buffer_state.current_index
+    return int(b_size * t_size)

From 66192004f2f4085c4bbfbcc1c47254b879f73d62 Mon Sep 17 00:00:00 2001
From: Callum Tilbury
Date: Wed, 24 Jul 2024 09:48:02 +0200
Subject: [PATCH 2/2] fix: use lax.cond for jax control flow.

---
NOTE(review): remarks in this area sit after the three-dash line and are not
  part of the commit message.
NOTE(review): this patch replaces the Python conditional expression on
  `buffer_state.is_full` with `jax.lax.cond` and removes the `int(...)` cast
  on the returned product.
NOTE(review): with the cast removed, the returned value is presumably a JAX
  array rather than a Python int, so the `-> int` annotation, the
  `timestep_count: int` annotation and the docstring's "int:" entry may no
  longer be accurate - confirm and update if so.
NOTE(review): the new lines reference `jax`, but no import is added in this
  patch; verify that flashbax/utils.py already imports it.

 flashbax/utils.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/flashbax/utils.py b/flashbax/utils.py
index 700485c..88ad680 100644
--- a/flashbax/utils.py
+++ b/flashbax/utils.py
@@ -85,5 +85,10 @@ def get_timestep_count(buffer_state: chex.ArrayTree) -> int:
     assert hasattr(buffer_state, "is_full")
 
     b_size, t_size_max = get_tree_shape_prefix(buffer_state.experience, 2)
-    t_size = t_size_max if buffer_state.is_full else buffer_state.current_index
-    return int(b_size * t_size)
+    t_size = jax.lax.cond(
+        buffer_state.is_full,
+        lambda: t_size_max,
+        lambda: buffer_state.current_index,
+    )
+    timestep_count: int = b_size * t_size
+    return timestep_count