diff --git a/flashbax/buffers/prioritised_trajectory_buffer.py b/flashbax/buffers/prioritised_trajectory_buffer.py index d93d6b1..88893a3 100644 --- a/flashbax/buffers/prioritised_trajectory_buffer.py +++ b/flashbax/buffers/prioritised_trajectory_buffer.py @@ -801,7 +801,7 @@ def make_prioritised_trajectory_buffer( init_fn = functools.partial( prioritised_init, add_batch_size=add_batch_size, - max_length_time_axis=max_length_time_axis, + max_length_time_axis=max_length_time_axis, # type: ignore period=period, ) add_fn = functools.partial( diff --git a/flashbax/buffers/trajectory_buffer.py b/flashbax/buffers/trajectory_buffer.py index a06ad45..ba005cd 100644 --- a/flashbax/buffers/trajectory_buffer.py +++ b/flashbax/buffers/trajectory_buffer.py @@ -588,7 +588,7 @@ def make_trajectory_buffer( init_fn = functools.partial( init, add_batch_size=add_batch_size, - max_length_time_axis=max_length_time_axis, + max_length_time_axis=max_length_time_axis, # type: ignore ) add_fn = functools.partial( add,