fix qvalue mask_action error for obs_next (#310)
* fix #309
* remove for-loop in dqn expl_noise
Trinkle23897 authored Mar 15, 2021
1 parent 243ab43 commit ec23c7e
Showing 5 changed files with 35 additions and 34 deletions.
3 changes: 1 addition & 2 deletions test/discrete/test_qrdqn.py
@@ -17,7 +17,7 @@
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='CartPole-v0')
-    parser.add_argument('--seed', type=int, default=1)
+    parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--eps-test', type=float, default=0.05)
     parser.add_argument('--eps-train', type=float, default=0.1)
     parser.add_argument('--buffer-size', type=int, default=20000)
@@ -134,7 +134,6 @@ def test_fn(epoch, env_step):
 def test_pqrdqn(args=get_args()):
     args.prioritized_replay = True
     args.gamma = .95
-    args.seed = 1
     test_qrdqn(args)

4 changes: 2 additions & 2 deletions tianshou/env/venvs.py
@@ -287,9 +287,9 @@ def close(self) -> None:

     def normalize_obs(self, obs: np.ndarray) -> np.ndarray:
         """Normalize observations by statistics in obs_rms."""
-        clip_max = 10.0  # this magic number is from openai baselines
-        # see baselines/common/vec_env/vec_normalize.py#L10
         if self.obs_rms and self.norm_obs:
+            clip_max = 10.0  # this magic number is from openai baselines
+            # see baselines/common/vec_env/vec_normalize.py#L10
             obs = (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.__eps)
             obs = np.clip(obs, -clip_max, clip_max)
         return obs
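The change only narrows the scope of the clipping constant; the normalization itself is unchanged: standardize by the running mean and variance, then clip to [-10, 10]. A minimal standalone sketch of that math, with an illustrative helper and default eps that are not part of Tianshou's API:

import numpy as np

def normalize(obs: np.ndarray, mean: np.ndarray, var: np.ndarray,
              eps: float = 1e-8, clip_max: float = 10.0) -> np.ndarray:
    # standardize with running statistics, then clip to [-clip_max, clip_max]
    obs = (obs - mean) / np.sqrt(var + eps)
    return np.clip(obs, -clip_max, clip_max)

# one 3-dimensional observation; the second entry is clipped from 20 down to 10
print(normalize(np.array([1.0, 50.0, -3.0]),
                mean=np.array([0.0, 10.0, 0.0]),
                var=np.array([1.0, 4.0, 1.0])))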
13 changes: 6 additions & 7 deletions tianshou/policy/modelfree/c51.py
@@ -1,6 +1,6 @@
 import torch
 import numpy as np
-from typing import Any, Dict
+from typing import Any, Dict, Optional

 from tianshou.policy import DQNPolicy
 from tianshou.data import Batch, ReplayBuffer
@@ -57,14 +57,13 @@ def __init__(
         )
         self.delta_z = (v_max - v_min) / (num_atoms - 1)

-    def _target_q(
-        self, buffer: ReplayBuffer, indice: np.ndarray
-    ) -> torch.Tensor:
+    def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
         return self.support.repeat(len(indice), 1)  # shape: [bsz, num_atoms]

-    def compute_q_value(self, logits: torch.Tensor) -> torch.Tensor:
-        """Compute the q value based on the network's raw output logits."""
-        return (logits * self.support).sum(2)
+    def compute_q_value(
+        self, logits: torch.Tensor, mask: Optional[np.ndarray]
+    ) -> torch.Tensor:
+        return super().compute_q_value((logits * self.support).sum(2), mask)

     def _target_dist(self, batch: Batch) -> torch.Tensor:
         if self._target:
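For context, C51 turns its distributional output into Q-values by taking the expectation over the fixed support, and the new signature simply forwards that Q-value together with the optional action mask to DQNPolicy.compute_q_value, which handles the masking. A small sketch of the expectation step with plain tensors (shapes and values are illustrative):

import torch

# softmax probabilities over num_atoms per action, shape [batch, num_actions, num_atoms]
support = torch.linspace(-10.0, 10.0, 51)
probs = torch.softmax(torch.randn(4, 6, 51), dim=-1)
q = (probs * support).sum(2)   # expected return per action, shape [4, 6]
print(q.shape)                 # torch.Size([4, 6])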
40 changes: 21 additions & 19 deletions tianshou/policy/modelfree/dqn.py
@@ -73,13 +73,13 @@ def sync_weight(self) -> None:

     def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
         batch = buffer[indice]  # batch.obs_next: s_{t+n}
-        # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
+        result = self(batch, input="obs_next")
         if self._target:
-            a = self(batch, input="obs_next").act
+            # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
             target_q = self(batch, model="model_old", input="obs_next").logits
-            target_q = target_q[np.arange(len(a)), a]
         else:
-            target_q = self(batch, input="obs_next").logits.max(dim=1)[0]
+            target_q = result.logits
+        target_q = target_q[np.arange(len(result.act)), result.act]
         return target_q

     def process_fn(
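This hunk is the heart of the #309 fix: the greedy action for obs_next now comes from self(batch, input="obs_next"), whose logits are already masked by compute_q_value, so the bootstrap value can no longer come from an invalid action (previously the no-target-network branch took a plain max over unmasked logits). With a target network, the selected action is then evaluated by model_old, double-DQN style. A rough standalone illustration with plain tensors; the -inf masking here is only for clarity, the actual code uses the finite penalty shown in the next hunk:

import torch

q_online = torch.tensor([[1.0, 5.0, 2.0]])   # Q_new(s', .) from the online net
q_target = torch.tensor([[0.5, 4.0, 3.0]])   # Q_old(s', .) from the target net
legal = torch.tensor([[True, False, True]])  # action 1 is illegal in s'

q_masked = q_online.masked_fill(~legal, float("-inf"))
act = q_masked.argmax(dim=1)                 # action 2, never the illegal action 1
target_q = q_target.gather(1, act.unsqueeze(1)).squeeze(1)
print(act.item(), target_q.item())           # 2 3.0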
@@ -95,8 +95,14 @@ def process_fn(
             self._gamma, self._n_step, self._rew_norm)
         return batch

-    def compute_q_value(self, logits: torch.Tensor) -> torch.Tensor:
-        """Compute the q value based on the network's raw output logits."""
+    def compute_q_value(
+        self, logits: torch.Tensor, mask: Optional[np.ndarray]
+    ) -> torch.Tensor:
+        """Compute the q value based on the network's raw output and action mask."""
+        if mask is not None:
+            # the masked q value should be smaller than logits.min()
+            min_value = logits.min() - logits.max() - 1.0
+            logits = logits + to_torch_as(1 - mask, logits) * min_value
         return logits

     def forward(
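The penalty is chosen so the bound holds for any logits: an invalid entry becomes at most logits.max() + (logits.min() - logits.max() - 1.0) = logits.min() - 1.0, strictly below every valid entry, so a masked action can never win the argmax, and no -inf or NaN enters later arithmetic. A small check of that bound with plain tensors (values are illustrative):

import torch

logits = torch.tensor([[-2.0, 7.0, 0.5]])
mask = torch.tensor([[1.0, 0.0, 1.0]])          # action 1 is invalid

min_value = logits.min() - logits.max() - 1.0   # -10.0
masked = logits + (1.0 - mask) * min_value      # [[-2.0, -3.0, 0.5]]
print(masked.argmax(dim=1))                     # tensor([2]): the invalid action cannot win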
@@ -140,15 +146,10 @@ def forward(
         obs = batch[input]
         obs_ = obs.obs if hasattr(obs, "obs") else obs
         logits, h = model(obs_, state=state, info=batch.info)
-        q = self.compute_q_value(logits)
+        q = self.compute_q_value(logits, getattr(obs, "mask", None))
         if not hasattr(self, "max_action_num"):
             self.max_action_num = q.shape[1]
-        act: np.ndarray = to_numpy(q.max(dim=1)[1])
-        if hasattr(obs, "mask"):
-            # some of actions are masked, they cannot be selected
-            q_: np.ndarray = to_numpy(q)
-            q_[~obs.mask] = -np.inf
-            act = q_.argmax(axis=1)
+        act = to_numpy(q.max(dim=1)[1])
         return Batch(logits=logits, act=act, state=h)

     def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
@@ -169,10 +170,11 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:

     def exploration_noise(self, act: np.ndarray, batch: Batch) -> np.ndarray:
         if not np.isclose(self.eps, 0.0):
-            for i in range(len(act)):
-                if np.random.rand() < self.eps:
-                    q_ = np.random.rand(self.max_action_num)
-                    if hasattr(batch["obs"], "mask"):
-                        q_[~batch["obs"].mask[i]] = -np.inf
-                    act[i] = q_.argmax()
+            bsz = len(act)
+            rand_mask = np.random.rand(bsz) < self.eps
+            q = np.random.rand(bsz, self.max_action_num)  # [0, 1]
+            if hasattr(batch.obs, "mask"):
+                q += batch.obs.mask
+            rand_act = q.argmax(axis=1)
+            act[rand_mask] = rand_act[rand_mask]
         return act
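The vectorized version draws one uniform sample per (row, action) pair; adding the boolean mask lifts every legal entry into [1, 2) while illegal entries stay in [0, 1), so the row-wise argmax is a uniformly random legal action, and rand_mask decides which rows actually explore. A standalone sketch of the same trick (shapes and values are illustrative):

import numpy as np

rng = np.random.default_rng(0)
eps, bsz, num_actions = 0.3, 5, 4
act = np.array([0, 1, 2, 3, 0])                    # greedy actions from the policy
mask = np.array([[1, 1, 0, 1]] * bsz, dtype=bool)  # action 2 is illegal everywhere

rand_mask = rng.random(bsz) < eps                  # which rows explore
q = rng.random((bsz, num_actions))                 # uniform in [0, 1)
q += mask                                          # legal entries land in [1, 2)
rand_act = q.argmax(axis=1)                        # random legal action per row
act[rand_mask] = rand_act[rand_mask]
print(act)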
9 changes: 5 additions & 4 deletions tianshou/policy/modelfree/qrdqn.py
@@ -1,8 +1,8 @@
 import torch
 import warnings
 import numpy as np
-from typing import Any, Dict
 import torch.nn.functional as F
+from typing import Any, Dict, Optional

 from tianshou.policy import DQNPolicy
 from tianshou.data import Batch, ReplayBuffer
@@ -61,9 +61,10 @@ def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
         next_dist = next_dist[np.arange(len(a)), a, :]
         return next_dist  # shape: [bsz, num_quantiles]

-    def compute_q_value(self, logits: torch.Tensor) -> torch.Tensor:
-        """Compute the q value based on the network's raw output logits."""
-        return logits.mean(2)
+    def compute_q_value(
+        self, logits: torch.Tensor, mask: Optional[np.ndarray]
+    ) -> torch.Tensor:
+        return super().compute_q_value(logits.mean(2), mask)

     def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
         if self._target and self._iter % self._freq == 0:
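As in c51.py, only the reduction step differs: QR-DQN's Q-value is the mean over the quantile dimension, and the mask handling is delegated to the parent class. Minimal sketch of the reduction with plain tensors (shapes are illustrative):

import torch

quantiles = torch.randn(4, 6, 200)   # [batch, num_actions, num_quantiles]
q = quantiles.mean(2)                # Q-value per action, shape [4, 6]
print(q.shape)                       # torch.Size([4, 6])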
