fix qvalue mask_action error for obs_next #310

Merged: 4 commits, Mar 15, 2021
test/discrete/test_qrdqn.py (3 changes: 1 addition & 2 deletions)

@@ -17,7 +17,7 @@
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='CartPole-v0')
-    parser.add_argument('--seed', type=int, default=1)
+    parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--eps-test', type=float, default=0.05)
     parser.add_argument('--eps-train', type=float, default=0.1)
     parser.add_argument('--buffer-size', type=int, default=20000)
@@ -134,7 +134,6 @@ def test_fn(epoch, env_step):
 def test_pqrdqn(args=get_args()):
     args.prioritized_replay = True
     args.gamma = .95
-    args.seed = 1
     test_qrdqn(args)


tianshou/env/venvs.py (4 changes: 2 additions & 2 deletions)

@@ -287,9 +287,9 @@ def close(self) -> None:

     def normalize_obs(self, obs: np.ndarray) -> np.ndarray:
         """Normalize observations by statistics in obs_rms."""
-        clip_max = 10.0  # this magic number is from openai baselines
-        # see baselines/common/vec_env/vec_normalize.py#L10
         if self.obs_rms and self.norm_obs:
+            clip_max = 10.0  # this magic number is from openai baselines
+            # see baselines/common/vec_env/vec_normalize.py#L10
             obs = (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.__eps)
             obs = np.clip(obs, -clip_max, clip_max)
         return obs
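Not part of the diff: a minimal NumPy sketch of what normalize_obs computes after this change, with made-up statistics standing in for obs_rms and __eps.

import numpy as np

# toy stand-ins for self.obs_rms.mean / .var and self.__eps (hypothetical values)
mean = np.array([0.5, -1.0])
var = np.array([4.0, 0.25])
eps = np.finfo(np.float32).eps.item()
clip_max = 10.0  # same magic number as openai baselines

obs = np.array([[2.5, -1.5], [100.0, 0.0]])
obs = (obs - mean) / np.sqrt(var + eps)   # standardize per dimension
obs = np.clip(obs, -clip_max, clip_max)   # keep values in [-10, 10]
print(obs)  # approximately [[ 1. -1.] [10.  2.]]
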
tianshou/policy/modelfree/c51.py (13 changes: 6 additions & 7 deletions)

@@ -1,6 +1,6 @@
 import torch
 import numpy as np
-from typing import Any, Dict
+from typing import Any, Dict, Optional

 from tianshou.policy import DQNPolicy
 from tianshou.data import Batch, ReplayBuffer
@@ -57,14 +57,13 @@ def __init__(
         )
         self.delta_z = (v_max - v_min) / (num_atoms - 1)

-    def _target_q(
-        self, buffer: ReplayBuffer, indice: np.ndarray
-    ) -> torch.Tensor:
+    def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
         return self.support.repeat(len(indice), 1)  # shape: [bsz, num_atoms]

-    def compute_q_value(self, logits: torch.Tensor) -> torch.Tensor:
-        """Compute the q value based on the network's raw output logits."""
-        return (logits * self.support).sum(2)
+    def compute_q_value(
+        self, logits: torch.Tensor, mask: Optional[np.ndarray]
+    ) -> torch.Tensor:
+        return super().compute_q_value((logits * self.support).sum(2), mask)

     def _target_dist(self, batch: Batch) -> torch.Tensor:
         if self._target:
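For reference, a standalone sketch (not Tianshou code) of the value C51 now hands to DQNPolicy.compute_q_value for masking: each action's q value is the expectation of its categorical distribution over the fixed support atoms.

import torch

bsz, num_actions, num_atoms = 2, 3, 5
support = torch.linspace(-10.0, 10.0, num_atoms)  # fixed atom values z_i
probs = torch.softmax(torch.randn(bsz, num_actions, num_atoms), dim=-1)

q = (probs * support).sum(2)   # Q(s, a) = sum_i p_i(s, a) * z_i
print(q.shape)                 # torch.Size([2, 3])
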
tianshou/policy/modelfree/dqn.py (32 changes: 17 additions & 15 deletions)

@@ -95,8 +95,14 @@ def process_fn(
             self._gamma, self._n_step, self._rew_norm)
         return batch

-    def compute_q_value(self, logits: torch.Tensor) -> torch.Tensor:
-        """Compute the q value based on the network's raw output logits."""
+    def compute_q_value(
+        self, logits: torch.Tensor, mask: Optional[np.ndarray]
+    ) -> torch.Tensor:
+        """Compute the q value based on the network's raw output and action mask."""
+        if mask is not None:
+            # the masked q value should be smaller than logits.min()
+            min_value = logits.min() - logits.max() - 1.0
+            logits = logits + to_torch_as(1 - mask, logits) * min_value
         return logits

     def forward(
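A quick standalone check of the new masking arithmetic, with toy logits and mask (the real code casts the NumPy mask via to_torch_as): every illegal action ends up strictly below logits.min(), so it can never win the argmax.

import torch

logits = torch.tensor([[1.0, 5.0, -2.0],
                       [0.5, 0.5,  3.0]])
mask = torch.tensor([[1., 0., 1.],
                     [0., 1., 1.]])  # 1 = legal action

# v + (min - max - 1) <= min - 1 < min for any masked value v
min_value = logits.min() - logits.max() - 1.0
masked = logits + (1 - mask) * min_value

print(masked.argmax(dim=1))  # tensor([0, 2]): only legal actions are selected
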
@@ -140,15 +146,10 @@ def forward(
         obs = batch[input]
         obs_ = obs.obs if hasattr(obs, "obs") else obs
         logits, h = model(obs_, state=state, info=batch.info)
-        q = self.compute_q_value(logits)
+        q = self.compute_q_value(logits, getattr(obs, "mask", None))
         if not hasattr(self, "max_action_num"):
             self.max_action_num = q.shape[1]
-        act: np.ndarray = to_numpy(q.max(dim=1)[1])
-        if hasattr(obs, "mask"):
-            # some of actions are masked, they cannot be selected
-            q_: np.ndarray = to_numpy(q)
-            q_[~obs.mask] = -np.inf
-            act = q_.argmax(axis=1)
+        act = to_numpy(q.max(dim=1)[1])
         return Batch(logits=logits, act=act, state=h)

     def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
@@ -169,10 +170,11 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:

     def exploration_noise(self, act: np.ndarray, batch: Batch) -> np.ndarray:
         if not np.isclose(self.eps, 0.0):
-            for i in range(len(act)):
-                if np.random.rand() < self.eps:
-                    q_ = np.random.rand(self.max_action_num)
-                    if hasattr(batch["obs"], "mask"):
-                        q_[~batch["obs"].mask[i]] = -np.inf
-                    act[i] = q_.argmax()
+            bsz = len(act)
+            rand_mask = np.random.rand(bsz) < self.eps
+            q = np.random.rand(bsz, self.max_action_num)  # [0, 1]
+            if hasattr(batch.obs, "mask"):
+                q += batch.obs.mask
+            rand_act = q.argmax(axis=1)
+            act[rand_mask] = rand_act[rand_mask]
         return act
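Why q += batch.obs.mask is enough in the vectorized version (a toy sketch, not the policy's real batch): the uniform noise lies in [0, 1), so adding the 0/1 mask lifts legal actions into [1, 2) and the row-wise argmax becomes a uniformly random legal action.

import numpy as np

rng = np.random.default_rng(0)
eps, bsz, num_actions = 0.3, 4, 3
act = np.array([2, 0, 1, 2])               # greedy actions from the network
mask = np.array([[1, 1, 0],
                 [0, 1, 1],
                 [1, 0, 1],
                 [1, 1, 1]], dtype=float)  # 1 = legal action

rand_mask = rng.random(bsz) < eps          # rows that explore this step
q = rng.random((bsz, num_actions))         # noise in [0, 1)
q += mask                                  # legal actions now in [1, 2)
rand_act = q.argmax(axis=1)                # random legal action per row
act[rand_mask] = rand_act[rand_mask]
print(act)
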
tianshou/policy/modelfree/qrdqn.py (9 changes: 5 additions & 4 deletions)

@@ -1,8 +1,8 @@
 import torch
 import warnings
 import numpy as np
-from typing import Any, Dict
 import torch.nn.functional as F
+from typing import Any, Dict, Optional

 from tianshou.policy import DQNPolicy
 from tianshou.data import Batch, ReplayBuffer
@@ -61,9 +61,10 @@ def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
         next_dist = next_dist[np.arange(len(a)), a, :]
         return next_dist  # shape: [bsz, num_quantiles]

-    def compute_q_value(self, logits: torch.Tensor) -> torch.Tensor:
-        """Compute the q value based on the network's raw output logits."""
-        return logits.mean(2)
+    def compute_q_value(
+        self, logits: torch.Tensor, mask: Optional[np.ndarray]
+    ) -> torch.Tensor:
+        return super().compute_q_value(logits.mean(2), mask)

     def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
         if self._target and self._iter % self._freq == 0:
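And the QRDQN counterpart of the C51 sketch above (standalone, not Tianshou code): the q value is the mean over the predicted quantiles, and masking is again left to the shared DQNPolicy.compute_q_value.

import torch

bsz, num_actions, num_quantiles = 2, 4, 8
quantiles = torch.randn(bsz, num_actions, num_quantiles)  # per-action quantile estimates

q = quantiles.mean(2)   # Q(s, a): average of the quantiles, shape [bsz, num_actions]
print(q.shape)          # torch.Size([2, 4])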