Fix term trunc #336

Merged
4 commits merged on May 3, 2023
8 changes: 5 additions & 3 deletions hive/agents/dqn.py
@@ -203,9 +203,11 @@ def preprocess_update_info(self, update_info):
)
preprocessed_update_info = {
"observation": update_info["observation"],
"next_observation": update_info["next_observation"],
"action": update_info["action"],
"reward": update_info["reward"],
"done": update_info["terminated"] or update_info["truncated"],
"terminated": update_info["terminated"],
"truncated": update_info["truncated"],
}
if "agent_id" in update_info:
preprocessed_update_info["agent_id"] = int(update_info["agent_id"])
@@ -283,7 +285,7 @@ def update(self, update_info, agent_traj_state=None):
update_info: dictionary containing all the necessary information
from the environment to update the agent. Should contain a full
transition, with keys for "observation", "action", "reward",
"next_observation", and "done".
"next_observation", "terminated", and "truncated".
agent_traj_state: Contains necessary state information for the agent
to process current trajectory. This should be updated and returned.

@@ -323,7 +325,7 @@ def update(self, update_info, agent_traj_state=None):
next_qvals, _ = torch.max(next_qvals, dim=1)

q_targets = batch["reward"] + self._discount_rate * next_qvals * (
1 - batch["done"]
1 - batch["terminated"]
)

loss = self._loss_fn(pred_qvals, q_targets).mean()
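
An aside on the target computation above (not part of the diff): masking the bootstrap term with "terminated" only is the substance of this change, since a transition cut off by a time limit (truncated) should still bootstrap from the next state's value. A minimal standalone sketch with illustrative numbers:

import torch

reward = torch.tensor([1.0, 1.0])
next_qvals = torch.tensor([5.0, 5.0])
terminated = torch.tensor([1.0, 0.0])  # second transition was truncated, not terminated
discount_rate = 0.99

# Same form as the q_targets line above: only termination zeroes the bootstrap.
q_targets = reward + discount_rate * next_qvals * (1 - terminated)
print(q_targets)  # tensor([1.0000, 5.9500]); the truncated transition still bootstraps
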
4 changes: 2 additions & 2 deletions hive/agents/drqn.py
@@ -226,7 +226,7 @@ def update(self, update_info, agent_traj_state=None):
update_info: dictionary containing all the necessary information
from the environment to update the agent. Should contain a full
transition, with keys for "observation", "action", "reward",
"next_observation", and "done".
"next_observation", "terminated", and "truncated".
agent_traj_state: Contains necessary state information for the agent
to process current trajectory. This should be updated and returned.

@@ -274,7 +274,7 @@ def update(self, update_info, agent_traj_state=None):
next_qvals, _ = torch.max(next_qvals, dim=-1)

q_targets = batch["reward"] + self._discount_rate * next_qvals * (
1 - batch["done"]
1 - batch["terminated"]
)

loss = self._loss_fn(pred_qvals, q_targets).mean()
4 changes: 3 additions & 1 deletion hive/agents/legal_moves_rainbow.py
@@ -16,9 +16,11 @@ def create_q_networks(self, representation_net):
def preprocess_update_info(self, update_info):
preprocessed_update_info = {
"observation": update_info["observation"]["observation"],
"next_observation": update_info["next_observation"]["observation"],
"action": update_info["action"],
"reward": update_info["reward"],
"done": update_info["terminated"] or update_info["truncated"],
"terminated": update_info["terminated"],
"truncated": update_info["truncated"],
"action_mask": action_encoding(update_info["observation"]["action_mask"]),
}
if "agent_id" in update_info:
3 changes: 2 additions & 1 deletion hive/agents/ppo.py
@@ -240,7 +240,8 @@ def preprocess_update_info(self, update_info, agent_traj_state):
"observation": update_info["observation"],
"action": update_info["action"],
"reward": update_info["reward"],
"done": done,
"terminated": update_info["terminated"],
"truncated": update_info["truncated"],
"logprob": agent_traj_state["logprob"],
"values": agent_traj_state["value"],
"returns": np.empty(agent_traj_state["value"].shape),
19 changes: 11 additions & 8 deletions hive/agents/rainbow.py
@@ -261,7 +261,7 @@ def update(self, update_info, agent_traj_state=None):
update_info: dictionary containing all the necessary information
from the environment to update the agent. Should contain a full
transition, with keys for "observation", "action", "reward",
"next_observation", and "done".
"next_observation", "terminated", and "truncated".
agent_traj_state: Contains necessary state information for the agent
to process current trajectory. This should be updated and returned.

@@ -308,7 +308,10 @@ def update(self, update_info, agent_traj_state=None):
log_p = torch.log(probs)
with torch.no_grad():
target_prob = self.target_projection(
next_state_inputs, next_action, batch["reward"], batch["done"]
next_state_inputs,
next_action,
batch["reward"],
batch["terminated"],
)

loss = -(target_prob * log_p).sum(-1)
@@ -320,7 +323,7 @@ def update(self, update_info, agent_traj_state=None):
next_qvals = next_qvals[torch.arange(next_qvals.size(0)), next_action]

q_targets = batch["reward"] + self._discount_rate * next_qvals * (
1 - batch["done"]
1 - batch["terminated"]
)

loss = self._loss_fn(pred_qvals, q_targets)
@@ -349,7 +352,7 @@ def update(self, update_info, agent_traj_state=None):
self._update_target()
return agent_traj_state

def target_projection(self, target_net_inputs, next_action, reward, done):
def target_projection(self, target_net_inputs, next_action, reward, terminated):
"""Project distribution of target Q-values.

Args:
@@ -359,17 +362,17 @@ def target_projection(self, target_net_inputs, next_action, reward, done):
next_action (~torch.Tensor): Tensor containing next actions used to
compute target distribution.
reward (~torch.Tensor): Tensor containing rewards for the current batch.
done (~torch.Tensor): Tensor containing whether the states in the current
batch are terminal.
terminated (~torch.Tensor): Tensor containing whether the states in
the current batch are terminal.

"""
reward = reward.reshape(-1, 1)
not_done = 1 - done.reshape(-1, 1)
not_terminated = 1 - terminated.reshape(-1, 1)
batch_size = reward.size(0)
next_dist = self._target_qnet.dist(*target_net_inputs)
next_dist = next_dist[torch.arange(batch_size), next_action]

dist_supports = reward + not_done * self._discount_rate * self._supports
dist_supports = reward + not_terminated * self._discount_rate * self._supports
dist_supports = dist_supports.clamp(min=self._v_min, max=self._v_max)
dist_supports = dist_supports.unsqueeze(1)
dist_supports = dist_supports.tile([1, self._atoms, 1])
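
A standalone illustration (not RLHive code) of the support shift that target_projection performs above: the target support is z' = reward + discount * (1 - terminated) * z, clamped to [v_min, v_max], so on a terminated transition the whole distribution collapses onto the reward. The support values below are assumptions chosen for the example:

import torch

v_min, v_max, atoms = -10.0, 10.0, 5
supports = torch.linspace(v_min, v_max, atoms)  # tensor([-10., -5., 0., 5., 10.])
reward = torch.tensor([[2.0]])
not_terminated = torch.tensor([[0.0]])          # a terminated transition
discount_rate = 0.99

dist_supports = reward + not_terminated * discount_rate * supports
dist_supports = dist_supports.clamp(min=v_min, max=v_max)
print(dist_supports)  # tensor([[2., 2., 2., 2., 2.]]); all mass sits at the reward
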
10 changes: 7 additions & 3 deletions hive/agents/td3.py
@@ -252,9 +252,11 @@ def preprocess_update_info(self, update_info):
)
preprocessed_update_info = {
"observation": update_info["observation"],
"next_observation": update_info["next_observation"],
"action": self.scale_action(update_info["action"]),
"reward": update_info["reward"],
"done": update_info["terminated"] or update_info["truncated"],
"terminated": update_info["terminated"],
"truncated": update_info["truncated"],
}
if "agent_id" in update_info:
preprocessed_update_info["agent_id"] = int(update_info["agent_id"])
@@ -314,7 +316,7 @@ def update(self, update_info, agent_traj_state=None):
update_info: dictionary containing all the necessary information
from the environment to update the agent. Should contain a full
transition, with keys for "observation", "action", "reward",
"next_observation", and "done".
"next_observation", "terminated", and "truncated
agent_traj_state: Contains necessary state information for the agent
to process current trajectory. This should be updated and returned.

@@ -360,7 +362,9 @@ def update(self, update_info, agent_traj_state=None):
next_q_vals, _ = torch.min(next_q_vals, dim=1, keepdim=True)
target_q_values = (
batch["reward"][:, None]
+ (1 - batch["done"][:, None]) * self._discount_rate * next_q_vals
+ (1 - batch["terminated"][:, None])
* self._discount_rate
* next_q_vals
)

# Critic losses
1 change: 0 additions & 1 deletion hive/envs/base.py
@@ -133,7 +133,6 @@ def __init__(self, env_name, num_players, **kwargs):
self._actions = []
self._obs = None
self._info = None
self._done = False
self._termination = False
self._truncation = False

100 changes: 75 additions & 25 deletions hive/replays/circular_replay.py
@@ -5,6 +5,7 @@

from hive.replays.replay_buffer import BaseReplayBuffer
from hive.utils.utils import create_folder, seeder
import copy


class CircularReplayBuffer(BaseReplayBuffer):
@@ -26,6 +27,7 @@ def __init__(
reward_dtype=np.float32,
extra_storage_types=None,
num_players_sharing_buffer: int = None,
optimize_storage: bool = True,
):
"""Constructor for CircularReplayBuffer.

@@ -56,14 +58,23 @@ def __init__(
(type, shape) tuple.
num_players_sharing_buffer (int): Number of agents that share their
buffers. It is used for self-play.
optimize_storage (bool): If True, the buffer will only store each
observation once. Otherwise, next_observation will be stored for
each transition. Note, if optimize_storage is True, the
next_observation for a transition where terminated OR truncated
is True will not be correct.
"""
self._capacity = capacity
self._optimize_storage = optimize_storage
self._specs = {
"observation": (observation_dtype, observation_shape),
"done": (np.uint8, ()),
"terminated": (np.uint8, ()),
"action": (action_dtype, action_shape),
"reward": (reward_dtype, reward_shape),
}
if not optimize_storage:
self._specs["next_observation"] = (observation_dtype, observation_shape)
if extra_storage_types is not None:
self._specs.update(extra_storage_types)
self._storage = self._create_storage(capacity, self._specs)
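
As a usage note (not part of the diff), the constructor gains optimize_storage and add() gains next_observation, terminated, and truncated, as shown further down in this file's diff. A hedged sketch relying on the buffer's default shapes and dtypes (the argument values are illustrative assumptions):

from hive.replays.circular_replay import CircularReplayBuffer

buffer = CircularReplayBuffer(capacity=100, optimize_storage=False)

# With optimize_storage=False the next_observation is stored per transition and
# is exact even when the episode ends here; with the default of True it is
# reconstructed at sample time from the next stored observation instead.
buffer.add(
    observation=0,
    next_observation=1,
    action=0,
    reward=1.0,
    terminated=False,
    truncated=True,
)
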
@@ -123,7 +134,16 @@ def _pad_buffer(self, pad_length):
}
self._add_transition(**transition)

def add(self, observation, action, reward, done, **kwargs):
def add(
self,
observation,
next_observation,
action,
reward,
terminated,
truncated,
**kwargs,
):
"""Adds a transition to the buffer.
The required components of a transition are given as positional arguments. The
user can pass additional components to store in the buffer as kwargs as long as
@@ -133,12 +153,16 @@ def add(self, observation, action, reward, done, **kwargs):
if self._episode_start:
self._pad_buffer(self._stack_size - 1)
self._episode_start = False
done = terminated or truncated
transition = {
"observation": observation,
"action": action,
"reward": reward,
"done": done,
"terminated": terminated,
}
if not self._optimize_storage:
transition["next_observation"] = next_observation
transition.update(kwargs)
for key in self._specs:
obj_type = (
@@ -238,15 +262,17 @@ def sample(self, batch_size):
indices = self._sample_indices(batch_size)
batch = {}
batch["indices"] = indices
terminals = self._get_from_storage("done", indices, self._n_step)
dones = self._get_from_storage("done", indices, self._n_step)
terminated = self._get_from_storage("terminated", indices, self._n_step)

if self._n_step == 1:
is_terminal = terminals
is_terminal = dones
trajectory_lengths = np.ones(batch_size)
else:
is_terminal = terminals.any(axis=1).astype(int)
is_terminal = dones.any(axis=1).astype(int)
terminated = terminated.any(axis=1).astype(int)
trajectory_lengths = (
np.argmax(terminals.astype(bool), axis=1) + 1
np.argmax(dones.astype(bool), axis=1) + 1
) * is_terminal + self._n_step * (1 - is_terminal)
trajectory_lengths = trajectory_lengths.astype(np.int64)

@@ -257,8 +283,17 @@ def sample(self, batch_size):
indices - self._stack_size + 1,
num_to_access=self._stack_size,
)
elif key == "next_observation":
batch[key] = self._get_from_storage(
"next_observation",
indices - self._stack_size + 1,
num_to_access=self._stack_size,
)
elif key == "done":
batch["done"] = is_terminal
pass
elif key == "terminated":
batch["terminated"] = terminated
batch["truncated"] = is_terminal - terminated
elif key == "reward":
rewards = self._get_from_storage("reward", indices, self._n_step)
if self._n_step == 1:
@@ -273,11 +308,12 @@ def sample(self, batch_size):
batch[key] = self._get_from_storage(key, indices)

batch["trajectory_lengths"] = trajectory_lengths
batch["next_observation"] = self._get_from_storage(
"observation",
indices + trajectory_lengths - self._stack_size + 1,
num_to_access=self._stack_size,
)
if "next_observation" not in batch:
batch["next_observation"] = self._get_from_storage(
"observation",
indices + trajectory_lengths - self._stack_size + 1,
num_to_access=self._stack_size,
)
return batch

def save(self, dname):
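
An aside on the sampling change above (standalone sketch, not RLHive code): the buffer keeps storing "done" (terminated or truncated, set in add()) alongside the new "terminated" column, and sample() recovers the truncation flag as their difference:

import numpy as np

dones = np.array([1, 1, 0], dtype=np.uint8)       # stored done = terminated or truncated
terminated = np.array([1, 0, 0], dtype=np.uint8)  # stored terminated flag

truncated = dones - terminated
print(truncated)  # [0 1 0]; the second transition ended by truncation only
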
@@ -343,7 +379,8 @@ def __init__(self, capacity=1e5, compress=False, seed=42, **kwargs):
"action": "int8",
"reward": "int8" if self._compress else "float32",
"next_observation": "int8" if self._compress else "float32",
"done": "int8" if self._compress else "float32",
"truncated": "int8",
"terminated": "int8",
}

self._data = {}
@@ -352,32 +389,45 @@ def __init__(self, capacity=1e5, compress=False, seed=42, **kwargs):

self._write_index = -1
self._n = 0
self._previous_transition = None
self.transition = None

def add(self, observation, action, reward, done, **kwargs):
def add(
self,
observation,
next_observation,
action,
reward,
terminated,
truncated,
**kwargs,
):
"""
Adds transition to the buffer

Args:
observation: The current observation
next_observation: The next observation
action: The action taken on the current observation
reward: The reward from taking action at current observation
done: If current observation was the last observation in the episode
terminated: If the trajectory was terminated at the current
transition
truncated: If the trajectory was truncated at the current transition
"""
if self._previous_transition is not None:
self._previous_transition["next_observation"] = observation
self._write_index = (self._write_index + 1) % self._capacity
self._n = int(min(self._capacity, self._n + 1))
for key in self._data:
self._data[key][self._write_index] = np.asarray(
self._previous_transition[key], dtype=self._dtype[key]
)
self._previous_transition = {
# if self._previous_transition is not None:
transition = {
"observation": observation,
"action": action,
"reward": reward,
"done": done,
"terminated": terminated,
"truncated": truncated,
"next_observation": next_observation,
}
self._write_index = (self._write_index + 1) % self._capacity
self._n = int(min(self._capacity, self._n + 1))
for key in self._data:
self._data[key][self._write_index] = np.asarray(
transition[key], dtype=self._dtype[key]
)

def sample(self, batch_size=32):
"""
3 changes: 2 additions & 1 deletion hive/utils/utils.py
@@ -56,8 +56,9 @@ def get_new_seed(self, group=None):
Args:
group (str): The name of the group to get the seed for.
"""
seed = self._current_seeds[group]
self._current_seeds[group] += 1
return self._current_seeds[group]
return seed


seeder = Seeder()
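
For clarity (standalone sketch, not the RLHive Seeder), the fix above makes get_new_seed return the group's current seed before advancing it, so the first call yields the base seed rather than skipping it:

current_seeds = {"agent": 100}

def get_new_seed(group):
    seed = current_seeds[group]   # capture before incrementing (the fix)
    current_seeds[group] += 1
    return seed

print(get_new_seed("agent"), get_new_seed("agent"))  # 100 101 (previously 101 102)
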