Rnn efficient padding #339

Merged 22 commits on Jun 1, 2023

9 changes: 6 additions & 3 deletions hive/agents/drqn.py
@@ -375,12 +375,15 @@ def update(self, update_info, agent_traj_state=None):
                     dtype=torch.float,
                 )
                 mask[self._burn_frames :] = 1.0
-                mask = mask.view(1, -1)
+                mask = mask.unsqueeze(0).repeat(len(batch["reward"]), 1)
+                mask = mask & batch["mask"]
                 interm_loss *= mask
-                loss = interm_loss.mean()
+                loss = interm_loss.sum() / mask.sum()

             else:
-                loss = self._loss_fn(pred_qvals, q_targets).mean()
+                interm_loss = self._loss_fn(pred_qvals, q_targets)
+                interm_loss *= batch["mask"]
+                loss = interm_loss.sum() / batch["mask"].sum()

             if self._logger.should_log(self._timescale):
                 self._logger.log_scalar("train_loss", loss, self._timescale)
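
Note on the loss change above: the old code zeroed out burn-in steps but still divided by the total number of timesteps, while the new code also zeroes steps outside the valid part of the sampled sequence (batch["mask"]) and divides by the number of steps that remain. A minimal standalone PyTorch sketch of that masked-mean idea (tensor names, shapes, and the burn-in length are assumed for illustration, not taken from the repository):

import torch

# Assumed shapes: per-timestep losses and a 0/1 validity mask, both (batch, seq_len).
interm_loss = torch.rand(4, 10)
valid = torch.zeros(4, 10)
valid[:, 3:] = 1.0  # e.g. the first 3 steps are burn-in or padding

# Plain mean divides by batch * seq_len, so zeroed-out steps still shrink the loss.
naive = (interm_loss * valid).mean()

# Masked mean divides by the number of valid steps instead.
masked = (interm_loss * valid).sum() / valid.sum()
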
49 changes: 49 additions & 0 deletions hive/replays/recurrent_replay.py
@@ -75,6 +75,53 @@ def __init__(
         )
         self._max_seq_len = max_seq_len

+    def add(
+        self,
+        observation,
+        next_observation,
+        action,
+        reward,
+        terminated,
+        truncated,
+        **kwargs,
+    ):
+        """Adds a transition to the buffer.
+        The required components of a transition are given as positional arguments. The
+        user can pass additional components to store in the buffer as kwargs as long as
+        they were defined in the specification in the constructor.
+        """
+
+        done = terminated or truncated
+        transition = {
+            "observation": observation,
+            "action": action,
+            "reward": reward,
+            "done": done,
+            "terminated": terminated,
+        }
+        if not self._optimize_storage:
+            transition["next_observation"] = next_observation
+        transition.update(kwargs)
+        for key in self._specs:
+            obj_type = (
+                transition[key].dtype
+                if hasattr(transition[key], "dtype")
+                else type(transition[key])
+            )
+            if not np.can_cast(obj_type, self._specs[key][0], casting="same_kind"):
+                raise ValueError(
+                    f"Key {key} has wrong dtype. Expected {self._specs[key][0]},"
+                    f" received {type(transition[key])}."
+                )
+        if self._num_players_sharing_buffer is None:
+            self._add_transition(**transition)
+        else:
+            self._episode_storage[kwargs["agent_id"]].append(transition)
+            if done:
+                for transition in self._episode_storage[kwargs["agent_id"]]:
+                    self._add_transition(**transition)
+                self._episode_storage[kwargs["agent_id"]] = []
+
     def _get_from_array(self, array, indices, num_to_access=1):
         """Retrieves consecutive elements in the array, wrapping around if necessary.
         If more than 1 element is being accessed, the elements are concatenated along
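
The dtype check in the new add method relies on np.can_cast with casting="same_kind", falling back to the value's Python type when it has no dtype attribute. A small standalone illustration of what that check accepts and rejects (the dtypes here are chosen for illustration, not taken from the buffer's specs):

import numpy as np

# Casts within a kind, or upward to a wider type, pass...
print(np.can_cast(np.float32, np.float64, casting="same_kind"))  # True
# ...but casts to a lower kind, e.g. float -> int, are rejected.
print(np.can_cast(np.float64, np.int64, casting="same_kind"))    # False
# Plain Python scalars are checked via their type, as in the diff's fallback branch.
print(np.can_cast(type(1.0), np.float32, casting="same_kind"))   # True
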
@@ -196,4 +243,6 @@ def sample(self, batch_size):
             num_to_access=self._max_seq_len,
         )

+        mask = np.cumsum(batch["done"], axis=1, dtype=bool)
+        batch["mask"] = mask
         return batch
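
On the new mask in sample(): np.cumsum over the done flags with dtype=bool turns each row into a step function that is False before the first done in the sampled window and True from that step onward, which is the mask the DRQN loss above multiplies by. A tiny NumPy illustration with toy values (the (batch, seq_len) shape is assumed):

import numpy as np

# One sampled sequence per row; 1 marks a step where an episode ended.
done = np.array([
    [0, 0, 1, 0, 0],   # boundary inside the window
    [0, 0, 0, 0, 0],   # no boundary inside the window
])

mask = np.cumsum(done, axis=1, dtype=bool)
print(mask)
# [[False False  True  True  True]
#  [False False False False False]]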