From ec23c7efe9a013dde69e7fd4a11538651574c8a7 Mon Sep 17 00:00:00 2001 From: n+e Date: Mon, 15 Mar 2021 08:06:24 +0800 Subject: [PATCH] fix qvalue mask_action error for obs_next (#310) * fix #309 * remove for-loop in dqn expl_noise --- test/discrete/test_qrdqn.py | 3 +-- tianshou/env/venvs.py | 4 +-- tianshou/policy/modelfree/c51.py | 13 +++++----- tianshou/policy/modelfree/dqn.py | 40 ++++++++++++++++-------------- tianshou/policy/modelfree/qrdqn.py | 9 ++++--- 5 files changed, 35 insertions(+), 34 deletions(-) diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 2268b63de..59e26b197 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -17,7 +17,7 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument('--task', type=str, default='CartPole-v0') - parser.add_argument('--seed', type=int, default=1) + parser.add_argument('--seed', type=int, default=0) parser.add_argument('--eps-test', type=float, default=0.05) parser.add_argument('--eps-train', type=float, default=0.1) parser.add_argument('--buffer-size', type=int, default=20000) @@ -134,7 +134,6 @@ def test_fn(epoch, env_step): def test_pqrdqn(args=get_args()): args.prioritized_replay = True args.gamma = .95 - args.seed = 1 test_qrdqn(args) diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index adefcf0d8..a15a4e26a 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -287,9 +287,9 @@ def close(self) -> None: def normalize_obs(self, obs: np.ndarray) -> np.ndarray: """Normalize observations by statistics in obs_rms.""" - clip_max = 10.0 # this magic number is from openai baselines - # see baselines/common/vec_env/vec_normalize.py#L10 if self.obs_rms and self.norm_obs: + clip_max = 10.0 # this magic number is from openai baselines + # see baselines/common/vec_env/vec_normalize.py#L10 obs = (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.__eps) obs = np.clip(obs, -clip_max, clip_max) return obs diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index 20ef89c1a..fd4decec3 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -1,6 +1,6 @@ import torch import numpy as np -from typing import Any, Dict +from typing import Any, Dict, Optional from tianshou.policy import DQNPolicy from tianshou.data import Batch, ReplayBuffer @@ -57,14 +57,13 @@ def __init__( ) self.delta_z = (v_max - v_min) / (num_atoms - 1) - def _target_q( - self, buffer: ReplayBuffer, indice: np.ndarray - ) -> torch.Tensor: + def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor: return self.support.repeat(len(indice), 1) # shape: [bsz, num_atoms] - def compute_q_value(self, logits: torch.Tensor) -> torch.Tensor: - """Compute the q value based on the network's raw output logits.""" - return (logits * self.support).sum(2) + def compute_q_value( + self, logits: torch.Tensor, mask: Optional[np.ndarray] + ) -> torch.Tensor: + return super().compute_q_value((logits * self.support).sum(2), mask) def _target_dist(self, batch: Batch) -> torch.Tensor: if self._target: diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 1edeea2fb..5b9f463ed 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -73,13 +73,13 @@ def sync_weight(self) -> None: def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor: batch = buffer[indice] # batch.obs_next: s_{t+n} - # target_Q = Q_old(s_, argmax(Q_new(s_, *))) + result = self(batch, 
input="obs_next") if self._target: - a = self(batch, input="obs_next").act + # target_Q = Q_old(s_, argmax(Q_new(s_, *))) target_q = self(batch, model="model_old", input="obs_next").logits - target_q = target_q[np.arange(len(a)), a] else: - target_q = self(batch, input="obs_next").logits.max(dim=1)[0] + target_q = result.logits + target_q = target_q[np.arange(len(result.act)), result.act] return target_q def process_fn( @@ -95,8 +95,14 @@ def process_fn( self._gamma, self._n_step, self._rew_norm) return batch - def compute_q_value(self, logits: torch.Tensor) -> torch.Tensor: - """Compute the q value based on the network's raw output logits.""" + def compute_q_value( + self, logits: torch.Tensor, mask: Optional[np.ndarray] + ) -> torch.Tensor: + """Compute the q value based on the network's raw output and action mask.""" + if mask is not None: + # the masked q value should be smaller than logits.min() + min_value = logits.min() - logits.max() - 1.0 + logits = logits + to_torch_as(1 - mask, logits) * min_value return logits def forward( @@ -140,15 +146,10 @@ def forward( obs = batch[input] obs_ = obs.obs if hasattr(obs, "obs") else obs logits, h = model(obs_, state=state, info=batch.info) - q = self.compute_q_value(logits) + q = self.compute_q_value(logits, getattr(obs, "mask", None)) if not hasattr(self, "max_action_num"): self.max_action_num = q.shape[1] - act: np.ndarray = to_numpy(q.max(dim=1)[1]) - if hasattr(obs, "mask"): - # some of actions are masked, they cannot be selected - q_: np.ndarray = to_numpy(q) - q_[~obs.mask] = -np.inf - act = q_.argmax(axis=1) + act = to_numpy(q.max(dim=1)[1]) return Batch(logits=logits, act=act, state=h) def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: @@ -169,10 +170,11 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: def exploration_noise(self, act: np.ndarray, batch: Batch) -> np.ndarray: if not np.isclose(self.eps, 0.0): - for i in range(len(act)): - if np.random.rand() < self.eps: - q_ = np.random.rand(self.max_action_num) - if hasattr(batch["obs"], "mask"): - q_[~batch["obs"].mask[i]] = -np.inf - act[i] = q_.argmax() + bsz = len(act) + rand_mask = np.random.rand(bsz) < self.eps + q = np.random.rand(bsz, self.max_action_num) # [0, 1] + if hasattr(batch.obs, "mask"): + q += batch.obs.mask + rand_act = q.argmax(axis=1) + act[rand_mask] = rand_act[rand_mask] return act diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index 7e154e7f7..1a841819c 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -1,8 +1,8 @@ import torch import warnings import numpy as np -from typing import Any, Dict import torch.nn.functional as F +from typing import Any, Dict, Optional from tianshou.policy import DQNPolicy from tianshou.data import Batch, ReplayBuffer @@ -61,9 +61,10 @@ def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor: next_dist = next_dist[np.arange(len(a)), a, :] return next_dist # shape: [bsz, num_quantiles] - def compute_q_value(self, logits: torch.Tensor) -> torch.Tensor: - """Compute the q value based on the network's raw output logits.""" - return logits.mean(2) + def compute_q_value( + self, logits: torch.Tensor, mask: Optional[np.ndarray] + ) -> torch.Tensor: + return super().compute_q_value(logits.mean(2), mask) def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: if self._target and self._iter % self._freq == 0: