From 1baa1b739ac4a35501cd7fe9f2e47e92eac41795 Mon Sep 17 00:00:00 2001 From: Mick van Gelderen Date: Fri, 8 Nov 2024 13:11:34 -0800 Subject: [PATCH 1/2] Treat warnings as errors --- pyproject.toml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index daf68a5..29fefe7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,15 @@ build-backend = "hatchling.build" [tool.hatch.build] include = ["flashbax/*"] +[tool.pytest.ini_options] +filterwarnings = [ + "error", + "ignore:`sample_sequence_length` greater than `min_length_time_axis`:UserWarning:flashbax", + "ignore:Setting period greater than sample_sequence_length will result in no overlap betweentrajectories:UserWarning:flashbax", + "ignore:Setting max_size dynamically sets the `max_length_time_axis` to be `max_size`//`add_batch_size = .*`:UserWarning:flashbax", + "ignore:jax.tree_map is deprecated:DeprecationWarning:flashbax", +] + [project] name = "flashbax" description = "Flashbax is an experience replay library oriented around JAX. Tailored to integrate seamlessly with JAX's Just-In-Time (JIT) compilation." From d085f0967507bd6c059a8965f5bffa39878e82cd Mon Sep 17 00:00:00 2001 From: Mick van Gelderen Date: Fri, 8 Nov 2024 13:13:50 -0800 Subject: [PATCH 2/2] Pass max_length_time_axis instead of max_size Makes it so that the warning: ``` Setting max_size dynamically sets the `max_length_time_axis` to be `max_size`//`add_batch_size = .*` ``` will no longer be triggered by legitimate use of `create_flat_buffer` and `make_prioritised_flat_buffer`. --- flashbax/buffers/flat_buffer.py | 28 ++++++------------ flashbax/buffers/prioritised_flat_buffer.py | 32 +++++++-------------- pyproject.toml | 1 - 3 files changed, 20 insertions(+), 41 deletions(-) diff --git a/flashbax/buffers/flat_buffer.py b/flashbax/buffers/flat_buffer.py index 41305a7..409f022 100644 --- a/flashbax/buffers/flat_buffer.py +++ b/flashbax/buffers/flat_buffer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import warnings from typing import TYPE_CHECKING, Generic, Optional from chex import PRNGKey @@ -113,24 +112,15 @@ def create_flat_buffer( add_batch_size=add_batch_size, ) - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", - message="Setting max_size dynamically sets the `max_length_time_axis` to " - f"be `max_size`//`add_batch_size = {max_length // add_batch_size}`." - "This allows one to control exactly how many transitions are stored in the buffer." - "Note that this overrides the `max_length_time_axis` argument.", - ) - - buffer = make_trajectory_buffer( - max_length_time_axis=None, # Unused because max_size is specified - min_length_time_axis=min_length // add_batch_size + 1, - add_batch_size=add_batch_size, - sample_batch_size=sample_batch_size, - sample_sequence_length=2, - period=1, - max_size=max_length, - ) + buffer = make_trajectory_buffer( + max_length_time_axis=max_length // add_batch_size, + min_length_time_axis=min_length // add_batch_size + 1, + add_batch_size=add_batch_size, + sample_batch_size=sample_batch_size, + sample_sequence_length=2, + period=1, + max_size=None, + ) add_fn = buffer.add diff --git a/flashbax/buffers/prioritised_flat_buffer.py b/flashbax/buffers/prioritised_flat_buffer.py index a4f6a1d..274d69c 100644 --- a/flashbax/buffers/prioritised_flat_buffer.py +++ b/flashbax/buffers/prioritised_flat_buffer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import warnings from typing import TYPE_CHECKING, Optional from chex import PRNGKey @@ -100,26 +99,17 @@ def make_prioritised_flat_buffer( if not validate_device(device): device = "cpu" - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", - message="Setting max_size dynamically sets the `max_length_time_axis` to " - f"be `max_size`//`add_batch_size = {max_length // add_batch_size}`." - "This allows one to control exactly how many transitions are stored in the buffer." - "Note that this overrides the `max_length_time_axis` argument.", - ) - - buffer = make_prioritised_trajectory_buffer( - max_length_time_axis=None, # Unused because max_size is specified - min_length_time_axis=min_length // add_batch_size + 1, - add_batch_size=add_batch_size, - sample_batch_size=sample_batch_size, - sample_sequence_length=2, - period=1, - max_size=max_length, - priority_exponent=priority_exponent, - device=device, - ) + buffer = make_prioritised_trajectory_buffer( + max_length_time_axis=max_length // add_batch_size, + min_length_time_axis=min_length // add_batch_size + 1, + add_batch_size=add_batch_size, + sample_batch_size=sample_batch_size, + sample_sequence_length=2, + period=1, + max_size=None, + priority_exponent=priority_exponent, + device=device, + ) add_fn = buffer.add diff --git a/pyproject.toml b/pyproject.toml index 29fefe7..5a98cb2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,6 @@ filterwarnings = [ "error", "ignore:`sample_sequence_length` greater than `min_length_time_axis`:UserWarning:flashbax", "ignore:Setting period greater than sample_sequence_length will result in no overlap betweentrajectories:UserWarning:flashbax", - "ignore:Setting max_size dynamically sets the `max_length_time_axis` to be `max_size`//`add_batch_size = .*`:UserWarning:flashbax", "ignore:jax.tree_map is deprecated:DeprecationWarning:flashbax", ]