Skip to content

Commit

Permalink
Pass max_length_time_axis instead of max_size (#43)
Browse files Browse the repository at this point in the history
* Treat warnings as errors

* Pass max_length_time_axis instead of max_size

Makes it so that the warning:

```
Setting max_size dynamically sets the `max_length_time_axis` to be `max_size`//`add_batch_size = .*`
```

will no longer be triggered by legitimate use of `create_flat_buffer` and
`make_prioritised_flat_buffer`.

---------

Co-authored-by: Simon Du Toit <90381208+SimonDuToit@users.noreply.github.com>
  • Loading branch information
mickvangelderen and SimonDuToit authored Dec 10, 2024
1 parent 1b6078d commit 1352bfa
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 41 deletions.
28 changes: 9 additions & 19 deletions flashbax/buffers/flat_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import TYPE_CHECKING, Generic, Optional

from chex import PRNGKey
Expand Down Expand Up @@ -113,24 +112,15 @@ def create_flat_buffer(
add_batch_size=add_batch_size,
)

with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="Setting max_size dynamically sets the `max_length_time_axis` to "
f"be `max_size`//`add_batch_size = {max_length // add_batch_size}`."
"This allows one to control exactly how many transitions are stored in the buffer."
"Note that this overrides the `max_length_time_axis` argument.",
)

buffer = make_trajectory_buffer(
max_length_time_axis=None, # Unused because max_size is specified
min_length_time_axis=min_length // add_batch_size + 1,
add_batch_size=add_batch_size,
sample_batch_size=sample_batch_size,
sample_sequence_length=2,
period=1,
max_size=max_length,
)
buffer = make_trajectory_buffer(
max_length_time_axis=max_length // add_batch_size,
min_length_time_axis=min_length // add_batch_size + 1,
add_batch_size=add_batch_size,
sample_batch_size=sample_batch_size,
sample_sequence_length=2,
period=1,
max_size=None,
)

add_fn = buffer.add

Expand Down
32 changes: 11 additions & 21 deletions flashbax/buffers/prioritised_flat_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import TYPE_CHECKING, Optional

from chex import PRNGKey
Expand Down Expand Up @@ -100,26 +99,17 @@ def make_prioritised_flat_buffer(
if not validate_device(device):
device = "cpu"

with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="Setting max_size dynamically sets the `max_length_time_axis` to "
f"be `max_size`//`add_batch_size = {max_length // add_batch_size}`."
"This allows one to control exactly how many transitions are stored in the buffer."
"Note that this overrides the `max_length_time_axis` argument.",
)

buffer = make_prioritised_trajectory_buffer(
max_length_time_axis=None, # Unused because max_size is specified
min_length_time_axis=min_length // add_batch_size + 1,
add_batch_size=add_batch_size,
sample_batch_size=sample_batch_size,
sample_sequence_length=2,
period=1,
max_size=max_length,
priority_exponent=priority_exponent,
device=device,
)
buffer = make_prioritised_trajectory_buffer(
max_length_time_axis=max_length // add_batch_size,
min_length_time_axis=min_length // add_batch_size + 1,
add_batch_size=add_batch_size,
sample_batch_size=sample_batch_size,
sample_sequence_length=2,
period=1,
max_size=None,
priority_exponent=priority_exponent,
device=device,
)

add_fn = buffer.add

Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ filterwarnings = [
"error",
"ignore:`sample_sequence_length` greater than `min_length_time_axis`:UserWarning:flashbax",
"ignore:Setting period greater than sample_sequence_length will result in no overlap betweentrajectories:UserWarning:flashbax",
"ignore:Setting max_size dynamically sets the `max_length_time_axis` to be `max_size`//`add_batch_size = .*`:UserWarning:flashbax",
"ignore:jax.tree_map is deprecated:DeprecationWarning:flashbax",
]

Expand Down

0 comments on commit 1352bfa

Please sign in to comment.