feat: a working mixer, with comments and docs
callumtilbury committed Jul 17, 2024
1 parent 00b1668 commit 890f480
Showing 1 changed file: flashbax/buffers/mixer.py (128 additions, 21 deletions)
@@ -13,80 +13,187 @@
# limitations under the License.

import functools
from typing import Callable, List, TypeVar

import chex
import jax
import jax.numpy as jnp
from chex import Numeric, dataclass
from jax.tree_util import tree_map

from flashbax.buffers.flat_buffer import TransitionSample
from flashbax.buffers.prioritised_trajectory_buffer import (
    PrioritisedTrajectoryBuffer,
    PrioritisedTrajectoryBufferSample,
    PrioritisedTrajectoryBufferState,
)
from flashbax.buffers.trajectory_buffer import (
    TrajectoryBuffer,
    TrajectoryBufferSample,
    TrajectoryBufferState,
)

# Support for Trajectory, Flat, Item buffers, and prioritised variants
sample_types = [
    TrajectoryBufferSample,
    PrioritisedTrajectoryBufferSample,
    TransitionSample,
]
SampleTypes = TypeVar(
    "SampleTypes",
    TrajectoryBufferSample,
    PrioritisedTrajectoryBufferSample,
    TransitionSample,
)

state_types = [TrajectoryBufferState, PrioritisedTrajectoryBufferState]
StateTypes = TypeVar(
    "StateTypes", TrajectoryBufferState, PrioritisedTrajectoryBufferState
)

BufferTypes = TypeVar("BufferTypes", TrajectoryBuffer, PrioritisedTrajectoryBuffer)


@dataclass(frozen=True)
class Mixer:
    """Pure functions defining the mixer.

    Attributes:
        sample (Callable): function to sample proportionally from all buffers,
            concatenating along the batch axis
        can_sample (Callable): function to check if all buffers can sample
    """

    sample: Callable
    can_sample: Callable


def _batch_slicer(
    sample: SampleTypes,
    batch_start: int,
    batch_end: int,
) -> SampleTypes:
    """Simple utility function to slice a sample along the batch axis.

    Args:
        sample (SampleTypes): incoming sample
        batch_start (int): batch start index
        batch_end (int): batch end index

    Returns:
        SampleTypes: outgoing sliced sample
    """
    return tree_map(lambda x: x[batch_start:batch_end, ...], sample)
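As a quick illustration (not part of the commit), here is a sketch of _batch_slicer applied to a toy pytree; the field names are made up for the example.

# Illustrative sketch only: slicing a toy sample along the batch axis,
# using the _batch_slicer defined above.
import jax.numpy as jnp

toy_sample = {"obs": jnp.zeros((8, 4)), "reward": jnp.zeros((8,))}
first_three = _batch_slicer(toy_sample, 0, 3)
# first_three["obs"].shape == (3, 4); first_three["reward"].shape == (3,)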


def sample_mixer_fn(
    states: List[StateTypes],
    key: chex.PRNGKey,
    prop_batch_sizes: List[int],
    sample_fns: List[Callable[[StateTypes, chex.PRNGKey], SampleTypes]],
) -> SampleTypes:
    """Perform mixed sampling from provided buffer states, according to provided proportions.

    Each buffer sample needs to be of the same pytree structure, and the samples are concatenated
    along the first axis, i.e. the batch axis. For example, if you are sampling trajectories, then
    all samples need to be sequences of the same sequence length, but batch sizes can differ.

    Args:
        states (List[StateTypes]): list of buffer states
        key (chex.PRNGKey): random key
        prop_batch_sizes (List[Numeric]): list of batch sizes sampled from each buffer, calculated
            according to the proportions of joint sample size
        sample_fns (List[Callable[[StateTypes, chex.PRNGKey], SampleTypes]]): list of pure sample
            functions from each buffer

    Returns:
        SampleTypes: proportionally concatenated samples from all buffers
    """
    keys = jax.random.split(
        key, len(states)
    )  # Split the key for each buffer sampling operation

    # We first sample from each buffer, and get a list of samples
    samples_array = tree_map(
        lambda state, sample, key_in: sample(state, key_in),
        states,
        sample_fns,
        list(keys),
        is_leaf=lambda leaf: type(leaf) in state_types,
    )

    # We then slice the samples according to the proportions
    prop_batch_samples_array = tree_map(
        lambda x, p: _batch_slicer(x, 0, p),
        samples_array,
        prop_batch_sizes,
        is_leaf=lambda leaf: type(leaf) in sample_types,
    )

    # Concatenate the samples along the batch axis
    joint_sample = tree_map(
        lambda *x: jnp.concatenate(x, axis=0),
        *prop_batch_samples_array,
    )

    return joint_sample
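To make the slice-then-concatenate flow concrete, a small sketch (not part of the commit) using toy samples with made-up field names, mixing a joint batch of 32 in a 3:1 ratio.

# Illustrative sketch only: toy per-buffer samples, each with a raw batch of 32.
import jax.numpy as jnp
from jax.tree_util import tree_map

sample_a = {"obs": jnp.ones((32, 4))}   # raw sample from buffer A
sample_b = {"obs": jnp.zeros((32, 4))}  # raw sample from buffer B

# Keep the first 24 items of A and the first 8 of B (a 3:1 split of a joint batch of 32),
# then concatenate along the batch axis, mirroring the function above.
sliced = [
    tree_map(lambda x: x[0:24, ...], sample_a),
    tree_map(lambda x: x[0:8, ...], sample_b),
]
joint = tree_map(lambda *x: jnp.concatenate(x, axis=0), *sliced)
# joint["obs"].shape == (32, 4)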


def can_sample_mixer_fn(
    states: List[StateTypes], can_sample_fns: List[Callable[[StateTypes], bool]]
) -> bool:
    """Check if all buffers can sample.

    Args:
        states (List[StateTypes]): list of buffer states
        can_sample_fns (List[Callable[[StateTypes], bool]]): list of can_sample functions
            from each buffer

    Returns:
        bool: whether all buffers can sample
    """
    each_can_sample = tree_map(
        lambda state, can_sample: can_sample(state),
        states,
        can_sample_fns,
        is_leaf=lambda leaf: type(leaf) in state_types,
    )
    return all(each_can_sample)
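Functionally, with each buffer state treated as a single leaf, the check above amounts to the following plain-Python reading (a sketch, not part of the commit).

all_can_sample = all(can_sample(state) for state, can_sample in zip(states, can_sample_fns))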


def make_mixer(
    buffers: List[BufferTypes],
    sample_batch_size: int,
    proportions: List[Numeric],
) -> Mixer:
    """Create the mixer.

    Args:
        buffers (List[BufferTypes]): list of buffers (pure functions)
        sample_batch_size (int): desired batch size of joint sample
        proportions (List[Numeric]):
            Proportions of joint sample size to be sampled from each buffer, given as a ratio.

    Returns:
        Mixer: a mixer
    """
    assert len(buffers) == len(
        proportions
    ), "Number of buffers and proportions must match"
    assert all(
        isinstance(b, type(buffers[0])) for b in buffers
    ), "All buffers must be of the same type"
    assert sample_batch_size > 0, "Sample batch size must be greater than 0"

    sample_fns = [b.sample for b in buffers]
    can_sample_fns = [b.can_sample for b in buffers]

    # Normalize proportions and calculate resulting integer batch sizes
    props_sum = sum(proportions)
    props_norm = [p / props_sum for p in proportions]
    prop_batch_sizes = [int(p * sample_batch_size) for p in props_norm]
    if sum(prop_batch_sizes) != sample_batch_size:
        # In case of rounding errors, add the remainder to the first buffer's proportion
        prop_batch_sizes[0] += sample_batch_size - sum(prop_batch_sizes)
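    # For example (an illustrative aside, not part of the commit): proportions=[1, 1, 1]
    # with sample_batch_size=64 normalises to [1/3, 1/3, 1/3] and rounds down to
    # [21, 21, 21] (summing to 63), so the leftover 1 is added to the first buffer,
    # giving [22, 21, 21].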

    mixer_sample_fn = functools.partial(
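For context, a usage sketch (not part of this commit). It assumes two compatible flat buffers, that make_flat_buffer's parameter names are as shown, that the transition structure is arbitrary, and that the truncated functools.partial call binds prop_batch_sizes and sample_fns so that the returned Mixer's functions take (states, key) and (states,) respectively.

# Usage sketch; buffer constructor parameters and call signatures here are assumptions.
import jax
import jax.numpy as jnp
import flashbax as fbx
from flashbax.buffers.mixer import make_mixer

# Two flat buffers with identical transition structure (parameter names assumed).
buffer_a = fbx.make_flat_buffer(max_length=10_000, min_length=64, sample_batch_size=64)
buffer_b = fbx.make_flat_buffer(max_length=10_000, min_length=64, sample_batch_size=64)

# Illustrative transition structure; real code would use its own pytree.
example_transition = {"obs": jnp.zeros(4), "reward": jnp.zeros(())}
state_a = buffer_a.init(example_transition)
state_b = buffer_b.init(example_transition)
# (In practice, experience would be added to both states before sampling.)

mixer = make_mixer(
    buffers=[buffer_a, buffer_b],
    sample_batch_size=64,
    proportions=[3, 1],  # 48 samples from buffer_a, 16 from buffer_b
)

key = jax.random.PRNGKey(0)
if mixer.can_sample([state_a, state_b]):
    joint_sample = mixer.sample([state_a, state_b], key)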
