From 3efac0163abb78cfdf621c05b026c165fca9b104 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Mon, 4 Jan 2021 16:21:31 +0800 Subject: [PATCH] merge categoricalNet to Net --- test/discrete/test_c51.py | 7 ++-- tianshou/utils/net/common.py | 63 ++++++------------------------------ 2 files changed, 12 insertions(+), 58 deletions(-) diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index eff2c898f..eb97e56b6 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -8,7 +8,7 @@ from tianshou.policy import C51Policy from tianshou.env import DummyVectorEnv -from tianshou.utils.net.common import CategoricalNet +from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer @@ -63,9 +63,8 @@ def test_c51(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = CategoricalNet(args.layer_num, args.state_shape, - args.action_shape, args.device, # dueling=(1, 1) - num_atoms=args.num_atoms).to(args.device) + net = Net(args.layer_num, args.state_shape, args.action_shape, args.device, + softmax=True, num_atoms=args.num_atoms).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = C51Policy(net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max, args.n_step, diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index ea804385a..40f050af8 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -53,6 +53,8 @@ def __init__( self.device = device self.dueling = dueling self.softmax = softmax + self.num_atoms = num_atoms + self.action_num = np.prod(action_shape) input_size = np.prod(state_shape) if concat: input_size += np.prod(action_shape) @@ -66,7 +68,7 @@ def __init__( if dueling is None: if action_shape and not concat: model += [nn.Linear( - hidden_layer_size, num_atoms * np.prod(action_shape))] + hidden_layer_size, num_atoms * self.action_num)] else: # dueling DQN q_layer_num, v_layer_num = dueling Q, V = [], [] @@ -80,7 +82,7 @@ def __init__( if action_shape and not concat: Q += [nn.Linear( - hidden_layer_size, num_atoms * np.prod(action_shape))] + hidden_layer_size, num_atoms * self.action_num)] V += [nn.Linear(hidden_layer_size, num_atoms)] self.Q = nn.Sequential(*Q) @@ -99,7 +101,12 @@ def forward( logits = self.model(s) if self.dueling is not None: # Dueling DQN q, v = self.Q(logits), self.V(logits) + if self.num_atoms > 1: + v = v.view(-1, 1, self.num_atoms) + q = q.view(-1, self.action_num, self.num_atoms) logits = q - q.mean(dim=1, keepdim=True) + v + elif self.num_atoms > 1: + logits = logits.view(-1, self.action_num, self.num_atoms) if self.softmax: logits = torch.softmax(logits, dim=-1) return logits, state @@ -164,55 +171,3 @@ def forward( # please ensure the first dim is batch size: [bsz, len, ...] return s, {"h": h.transpose(0, 1).detach(), "c": c.transpose(0, 1).detach()} - - -class CategoricalNet(Net): - """Simple MLP backbone. - - For advanced usage (how to customize the network), please refer to - :ref:`build_the_network`. - - .. seealso:: - - Please refer to :class:`~tianshou.utils.net.common.Net` for - more detailed explanation. - """ - - def __init__( - self, - layer_num: int, - state_shape: tuple, - action_shape: Optional[Union[tuple, int]] = 0, - device: Union[str, int, torch.device] = "cpu", - concat: bool = False, - hidden_layer_size: int = 128, - dueling: Optional[Tuple[int, int]] = None, - norm_layer: Optional[Callable[[int], nn.modules.Module]] = None, - num_atoms: int = 51, - ) -> None: - super().__init__(layer_num, state_shape, action_shape, - device, True, concat, hidden_layer_size, - dueling, norm_layer, num_atoms) - self.action_shape = action_shape - self.num_atoms = num_atoms - - def forward( - self, - s: Union[np.ndarray, torch.Tensor], - state: Optional[Dict[str, torch.Tensor]] = None, - info: Dict[str, Any] = {}, - ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: - """Mapping: s -> flatten -> logits.""" - s = to_torch(s, device=self.device, dtype=torch.float32) - s = s.reshape(s.size(0), -1) - logits = self.model(s) - if self.dueling is not None: # Dueling DQN - q, v = self.Q(logits), self.V(logits) - v = v.view(-1, 1, self.num_atoms), - q = q.view(-1, np.prod(self.action_shape), self.num_atoms) - logits = q - q.mean(dim=1, keepdim=True) + v - else: - logits = logits.view( - -1, np.prod(self.action_shape), self.num_atoms) - logits = torch.softmax(logits, dim=-1) - return logits, state