Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace functools.partial with jax.tree_util.Partial #39

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions flashbax/buffers/mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from jax.tree_util import Partial as partial
from typing import Callable, Sequence, TypeVar

import chex
Expand Down Expand Up @@ -200,13 +200,13 @@ def make_mixer(
# In case of rounding errors, add the remainder to the first buffer's proportion
prop_batch_sizes[0] += sample_batch_size - sum(prop_batch_sizes)

mixer_sample_fn = functools.partial(
mixer_sample_fn = partial(
sample_mixer_fn,
prop_batch_sizes=prop_batch_sizes,
sample_fns=sample_fns,
)

mixer_can_sample_fn = functools.partial(
mixer_can_sample_fn = partial(
can_sample_mixer_fn,
can_sample_fns=can_sample_fns,
)
Expand Down
12 changes: 6 additions & 6 deletions flashbax/buffers/prioritised_trajectory_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"""


import functools
from jax.tree_util import Partial as partial
import warnings
from typing import TYPE_CHECKING, Callable, Generic, Optional, Tuple

Expand Down Expand Up @@ -799,29 +799,29 @@ def make_prioritised_trajectory_buffer(
max_length_time_axis = max_size // add_batch_size

assert max_length_time_axis is not None
init_fn = functools.partial(
init_fn = partial(
prioritised_init,
add_batch_size=add_batch_size,
max_length_time_axis=max_length_time_axis,
period=period,
)
add_fn = functools.partial(
add_fn = partial(
prioritised_add,
sample_sequence_length=sample_sequence_length,
period=period,
device=device,
)
sample_fn = functools.partial(
sample_fn = partial(
prioritised_sample,
batch_size=sample_batch_size,
sequence_length=sample_sequence_length,
period=period,
)
can_sample_fn = functools.partial(
can_sample_fn = partial(
can_sample, min_length_time_axis=min_length_time_axis
)

set_priorities_fn = functools.partial(
set_priorities_fn = partial(
set_priorities, priority_exponent=priority_exponent, device=device
)

Expand Down
10 changes: 5 additions & 5 deletions flashbax/buffers/trajectory_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
This allows for random sampling of the trajectories within the buffer.
"""

import functools
from jax.tree_util import Partial as partial
import warnings
from typing import TYPE_CHECKING, Callable, Generic, Optional, TypeVar

Expand Down Expand Up @@ -586,21 +586,21 @@ def make_trajectory_buffer(
max_length_time_axis = max_size // add_batch_size

assert max_length_time_axis is not None
init_fn = functools.partial(
init_fn = partial(
init,
add_batch_size=add_batch_size,
max_length_time_axis=max_length_time_axis,
)
add_fn = functools.partial(
add_fn = partial(
add,
)
sample_fn = functools.partial(
sample_fn = partial(
sample,
batch_size=sample_batch_size,
sequence_length=sample_sequence_length,
period=period,
)
can_sample_fn = functools.partial(
can_sample_fn = partial(
can_sample, min_length_time_axis=min_length_time_axis
)

Expand Down
12 changes: 6 additions & 6 deletions flashbax/buffers/trajectory_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.


import functools
from jax.tree_util import Partial as partial
import warnings
from typing import TYPE_CHECKING, Callable, Generic, Optional, Tuple, TypeVar

Expand Down Expand Up @@ -386,24 +386,24 @@ def make_trajectory_queue(
if max_size is not None:
max_length_time_axis = max_size // add_batch_size

init_fn = functools.partial(
init_fn = partial(
init,
add_batch_size=add_batch_size,
max_length_time_axis=max_length_time_axis,
)
add_fn = functools.partial(
add_fn = partial(
add,
)
sample_fn = functools.partial(
sample_fn = partial(
sample,
sequence_length=sample_sequence_length,
)
can_sample_fn = functools.partial(
can_sample_fn = partial(
can_sample,
sample_sequence_length=sample_sequence_length,
max_length_time_axis=max_length_time_axis,
)
can_add_fn = functools.partial(
can_add_fn = partial(
can_add,
add_sequence_length=add_sequence_length,
max_length_time_axis=max_length_time_axis,
Expand Down
Loading