Merge pull request #175 from GFNOrg/prioritized_replay_buffer
Prioritized replay buffer
saleml authored Apr 3, 2024
2 parents f563ce2 + 5472055 commit 4387e5b
Showing 5 changed files with 178 additions and 17 deletions.
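Before the per-file diffs, a rough sketch of the interface this PR introduces, mirroring the `train_hypergrid.py` change further down. The `HyperGrid` environment and its arguments are assumptions for illustration, not part of this diff:

```python
from gfn.containers import PrioritizedReplayBuffer
from gfn.gym import HyperGrid  # assumed environment, as in the tutorial script

env = HyperGrid(ndim=2, height=8)  # hypothetical settings
replay_buffer = PrioritizedReplayBuffer(
    env,
    objects_type="trajectories",  # or "transitions" / "states"
    capacity=1000,
    cutoff_distance=0.0,  # negative values skip the diversity filter entirely
    p_norm_distance=1.0,  # L1 distance between final states
)
# Sampled training objects are then inserted with replay_buffer.add(...),
# which keeps only the highest-reward (and, if the filter is enabled,
# sufficiently diverse) objects once the buffer is full.
```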
2 changes: 1 addition & 1 deletion src/gfn/containers/__init__.py
@@ -1,3 +1,3 @@
from .replay_buffer import ReplayBuffer
from .replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
from .trajectories import Trajectories
from .transitions import Transitions
143 changes: 137 additions & 6 deletions src/gfn/containers/replay_buffer.py
@@ -18,7 +18,6 @@ class ReplayBuffer:
Attributes:
env: the Environment instance.
loss_fn: the Loss instance
capacity: the size of the buffer.
training_objects: the buffer of objects used for training.
terminating_states: a States class representation of $s_f$.
@@ -56,13 +55,12 @@ def __init__(
raise ValueError(f"Unknown objects_type: {objects_type}")

self._is_full = False
self._index = 0

def __repr__(self):
return f"ReplayBuffer(capacity={self.capacity}, containing {len(self)} {self.objects_type})"

def __len__(self):
return self.capacity if self._is_full else self._index
return len(self.training_objects)

def add(self, training_objects: Transitions | Trajectories | tuple[States]):
"""Adds a training object to the buffer."""
@@ -73,8 +71,7 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]):

to_add = len(training_objects)

self._is_full |= self._index + to_add >= self.capacity
self._index = (self._index + to_add) % self.capacity
self._is_full |= len(self) + to_add >= self.capacity

self.training_objects.extend(training_objects)
self.training_objects = self.training_objects[-self.capacity :]
@@ -102,6 +99,140 @@ def save(self, directory: str):
def load(self, directory: str):
"""Loads the buffer from disk."""
self.training_objects.load(os.path.join(directory, "training_objects"))
self._index = len(self.training_objects)
if self.terminating_states is not None:
self.terminating_states.load(os.path.join(directory, "terminating_states"))


class PrioritizedReplayBuffer(ReplayBuffer):
"""A replay buffer of trajectories or transitions.
Attributes:
env: the Environment instance.
capacity: the size of the buffer.
training_objects: the buffer of objects used for training.
terminating_states: a States class representation of $s_f$.
objects_type: the type of buffer (transitions, trajectories, or states).
cutoff_distance: threshold used to determine if new last_states are different
enough from those already contained in the buffer.
p_norm_distance: p-norm distance value to pass to torch.cdist, for the
determination of novel states.
"""
def __init__(
self,
env: Env,
objects_type: Literal["transitions", "trajectories", "states"],
capacity: int = 1000,
cutoff_distance: float = 0.,
p_norm_distance: float = 1.,
):
"""Instantiates a prioritized replay buffer.
Args:
env: the Environment instance.
capacity: the size of the buffer.
objects_type: the type of buffer (transitions, trajectories, or states).
cutoff_distance: threshold used to determine if new last_states are
different enough from those already contained in the buffer. If the
cutoff is negative, all diversity calculations are skipped (since all
norms are >= 0).
p_norm_distance: p-norm distance value to pass to torch.cdist, for the
determination of novel states.
"""
super().__init__(env, objects_type, capacity)
self.cutoff_distance = cutoff_distance
self.p_norm_distance = p_norm_distance

def _add_objs(
self,
training_objects: Transitions | Trajectories | tuple[States],
terminating_states: States | None = None,
):
"""Adds a training object to the buffer."""
# Adds the objects to the buffer.
self.training_objects.extend(training_objects)

# Sort elements by logreward, capping the size at the defined capacity.
ix = torch.argsort(self.training_objects.log_rewards)
self.training_objects = self.training_objects[ix]
self.training_objects = self.training_objects[-self.capacity :]

# Add the terminating states to the buffer.
if self.terminating_states is not None:
assert terminating_states is not None
self.terminating_states.extend(terminating_states)

# Sort terminating states by logreward as well.
self.terminating_states = self.terminating_states[ix]
self.terminating_states = self.terminating_states[-self.capacity :]

def add(self, training_objects: Transitions | Trajectories | tuple[States]):
"""Adds a training object to the buffer."""
terminating_states = None
if isinstance(training_objects, tuple):
assert self.objects_type == "states" and self.terminating_states is not None
training_objects, terminating_states = training_objects

to_add = len(training_objects)
self._is_full |= len(self) + to_add >= self.capacity

# The buffer isn't full yet.
if len(self.training_objects) < self.capacity:
self._add_objs(training_objects, terminating_states)

# Our buffer is full and we will prioritize diverse, high reward additions.
else:
# Sort the incoming elements by their logrewards.
ix = torch.argsort(training_objects.log_rewards, descending=True)
training_objects = training_objects[ix]

# Filter all batch logrewards lower than the smallest logreward in buffer.
min_reward_in_buffer = self.training_objects.log_rewards.min()
idx_bigger_rewards = training_objects.log_rewards >= min_reward_in_buffer
training_objects = training_objects[idx_bigger_rewards]

# TODO: Concatenate input with final state for conditional GFN.
# if self.is_conditional:
# batch = torch.cat(
# [dict_curr_batch["input"], dict_curr_batch["final_state"]],
# dim=-1,
# )
# buffer = torch.cat(
# [self.storage["input"], self.storage["final_state"]],
# dim=-1,
# )

if self.cutoff_distance >= 0:
# Filter the batch for diverse final_states with high reward.
batch = training_objects.last_states.tensor.float()
batch_dim = training_objects.last_states.batch_shape[0]
batch_batch_dist = torch.cdist(
batch.view(batch_dim, -1).unsqueeze(0),
batch.view(batch_dim, -1).unsqueeze(0),
p=self.p_norm_distance,
).squeeze(0)

# Finds the min distance at each row, and removes rows below the cutoff.
r, w = torch.triu_indices(*batch_batch_dist.shape)  # Mask the upper triangle (incl. diagonal).
batch_batch_dist[r, w] = torch.finfo(batch_batch_dist.dtype).max
batch_batch_dist = batch_batch_dist.min(-1)[0]
idx_batch_batch = batch_batch_dist > self.cutoff_distance
training_objects = training_objects[idx_batch_batch]

# Compute all pairwise distances between the remaining batch & buffer.
batch = training_objects.last_states.tensor.float()
buffer = self.training_objects.last_states.tensor.float()
batch_dim = training_objects.last_states.batch_shape[0]
buffer_dim = self.training_objects.last_states.batch_shape[0]
batch_buffer_dist = (
torch.cdist(
batch.view(batch_dim, -1).unsqueeze(0),
buffer.view(buffer_dim, -1).unsqueeze(0),
p=self.p_norm_distance,
)
.squeeze(0)
.min(-1)[0] # Min calculated over rows - the batch elements.
)

# Filter the batch for diverse final_states w.r.t the buffer.
idx_batch_buffer = batch_buffer_dist > self.cutoff_distance
training_objects = training_objects[idx_batch_buffer]

# If any training objects remain after filtering, add them.
if len(training_objects):
self._add_objs(training_objects)
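For reference, a minimal, self-contained sketch of the diversity filter used in the full-buffer branch above: pairwise p-norm distances are computed with `torch.cdist`, the upper triangle (including the diagonal) is masked so each row's minimum ignores itself, and only rows whose nearest neighbour is farther than `cutoff_distance` survive. Names such as `diversity_mask` are illustrative, not part of the library API:

```python
import torch


def diversity_mask(batch: torch.Tensor, buffer: torch.Tensor,
                   cutoff_distance: float, p: float = 1.0) -> torch.Tensor:
    """Boolean mask over rows of `batch` (shape [n, d]) that are far enough
    from every other kept batch row and from every row of `buffer`."""
    # Pairwise distances within the batch; masking the upper triangle
    # (including the diagonal) makes each row's minimum ignore itself and
    # count each pair only once, so the first row of a near-duplicate pair
    # is the one that survives.
    intra = torch.cdist(batch, batch, p=p)
    r, c = torch.triu_indices(*intra.shape)
    intra[r, c] = torch.finfo(intra.dtype).max
    far_within_batch = intra.min(dim=-1).values > cutoff_distance

    # Distance from each batch row to its nearest buffer row.
    far_from_buffer = torch.cdist(batch, buffer, p=p).min(dim=-1).values > cutoff_distance
    return far_within_batch & far_from_buffer


batch = torch.tensor([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0]])  # two near-duplicates, one point near the buffer
buffer = torch.tensor([[5.0, 5.1]])
print(diversity_mask(batch, buffer, cutoff_distance=0.5))   # tensor([ True, False, False])
```

Unlike the sequential filtering in `add`, this sketch combines both conditions with a single logical AND; the results coincide because a row's distance to the buffer does not depend on which other batch rows were kept.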
21 changes: 17 additions & 4 deletions src/gfn/containers/trajectories.py
@@ -90,7 +90,11 @@ def __init__(
if when_is_done is not None
else torch.full(size=(0,), fill_value=-1, dtype=torch.long)
)
self._log_rewards = log_rewards
self._log_rewards = (
log_rewards
if log_rewards is not None
else torch.full(size=(0,), fill_value=0, dtype=torch.float)
)
self.log_probs = (
log_probs
if log_probs is not None
@@ -232,22 +236,31 @@ def extend(self, other: Trajectories) -> None:
self.states.extend(other.states)
self.when_is_done = torch.cat((self.when_is_done, other.when_is_done), dim=0)

# For log_probs, we first need to make the first dimensions of self.log_probs and other.log_probs equal
# (i.e. the number of steps in the trajectories), and then concatenate them
# For log_probs, we first need to make the first dimensions of self.log_probs
# and other.log_probs equal (i.e. the number of steps in the trajectories), and
# then concatenate them.
new_max_length = max(self.log_probs.shape[0], other.log_probs.shape[0])
self.log_probs = self.extend_log_probs(self.log_probs, new_max_length)
other.log_probs = self.extend_log_probs(other.log_probs, new_max_length)

self.log_probs = torch.cat((self.log_probs, other.log_probs), dim=1)

# Concatenate log_rewards of the trajectories.
if self._log_rewards is not None and other._log_rewards is not None:
self._log_rewards = torch.cat(
(self._log_rewards, other._log_rewards),
dim=0,
)
# Only None if one of the two was explicitly None; empty containers now default to an empty tensor.
else:
self._log_rewards = None

# Ensure log_probs/rewards are the correct dimensions. TODO: Remove?
if self.log_probs.numel() > 0:
assert self.log_probs.shape == self.actions.batch_shape

if self.log_rewards is not None:
assert len(self.log_rewards) == self.actions.batch_shape[-1]

# Either set, or append, estimator outputs if they exist in the submitted
# trajectory.
if self.estimator_outputs is None and isinstance(
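The `extend` logic above pads `log_probs` along the step dimension before concatenating along the trajectory dimension, while `log_rewards` (one scalar per trajectory) concatenate along the batch dimension, which now also works for freshly-initialized empty containers thanks to the empty-tensor default. A minimal sketch of that pad-then-concatenate pattern with plain tensors (shapes and the padding value of 0 are illustrative, not the `Trajectories` API):

```python
import torch

a = torch.randn(3, 4)  # log_probs of 4 trajectories, up to 3 steps each
b = torch.randn(5, 2)  # log_probs of 2 trajectories, up to 5 steps each

# Pad the shorter tensor along the step dimension (dim 0) ...
new_max_length = max(a.shape[0], b.shape[0])
a = torch.cat([a, torch.zeros(new_max_length - a.shape[0], a.shape[1])], dim=0)

# ... then concatenate along the trajectory dimension (dim 1).
log_probs = torch.cat([a, b], dim=1)
print(log_probs.shape)  # torch.Size([5, 6])

# log_rewards concatenate along dim 0; an empty float tensor
# (the new default) is a valid no-op operand here.
log_rewards = torch.cat([torch.randn(4), torch.empty(0), torch.randn(2)], dim=0)
print(log_rewards.shape)  # torch.Size([6])
```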
5 changes: 4 additions & 1 deletion src/gfn/containers/transitions.py
@@ -90,7 +90,7 @@ def __init__(
len(self.next_states.batch_shape) == 1
and self.states.batch_shape == self.next_states.batch_shape
)
self._log_rewards = log_rewards
self._log_rewards = log_rewards if log_rewards is not None else torch.zeros(0)
self.log_probs = log_probs if log_probs is not None else torch.zeros(0)

@property
@@ -208,10 +208,13 @@ def extend(self, other: Transitions) -> None:
self.actions.extend(other.actions)
self.is_done = torch.cat((self.is_done, other.is_done), dim=0)
self.next_states.extend(other.next_states)

# Concatenate log_rewards of the trajectories.
if self._log_rewards is not None and other._log_rewards is not None:
self._log_rewards = torch.cat(
(self._log_rewards, other._log_rewards), dim=0
)
# Only None if one of the two was explicitly None; empty containers now default to an empty tensor.
else:
self._log_rewards = None
self.log_probs = torch.cat((self.log_probs, other.log_probs), dim=0)
24 changes: 19 additions & 5 deletions tutorials/examples/train_hypergrid.py
@@ -17,7 +17,7 @@
import wandb
from tqdm import tqdm, trange

from gfn.containers import ReplayBuffer
from gfn.containers import ReplayBuffer, PrioritizedReplayBuffer
from gfn.gflownet import (
DBGFlowNet,
FMGFlowNet,
@@ -185,12 +185,21 @@ def main(args): # noqa: C901
objects_type = "states"
else:
raise NotImplementedError(f"Unknown loss: {args.loss}")
replay_buffer = ReplayBuffer(
env, objects_type=objects_type, capacity=args.replay_buffer_size
)

# 3. Create the optimizer
if args.replay_buffer_prioritized:
replay_buffer = PrioritizedReplayBuffer(
env,
objects_type=objects_type,
capacity=args.replay_buffer_size,
p_norm_distance=1, # Use L1-norm for diversity estimation.
cutoff_distance=0, # -1 turns off diversity-based filtering.
)
else:
replay_buffer = ReplayBuffer(
env, objects_type=objects_type, capacity=args.replay_buffer_size
)

# 3. Create the optimizer
# Policy parameters have their own LR.
params = [
{
@@ -292,6 +301,11 @@ def main(args): # noqa: C901
default=0,
help="If zero, no replay buffer is used. Otherwise, the replay buffer is used.",
)
parser.add_argument(
"--replay_buffer_prioritized",
action="store_true",
help="If set and replay_buffer_size > 0, use a prioritized replay buffer.",
)

parser.add_argument(
"--loss",
