Skip to content

Commit

Permalink
merge categoricalNet to Net
Browse files Browse the repository at this point in the history
  • Loading branch information
Trinkle23897 committed Jan 4, 2021
1 parent 22fa78a commit 3efac01
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 58 deletions.
7 changes: 3 additions & 4 deletions test/discrete/test_c51.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from tianshou.policy import C51Policy
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import CategoricalNet
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer

Expand Down Expand Up @@ -63,9 +63,8 @@ def test_c51(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = CategoricalNet(args.layer_num, args.state_shape,
args.action_shape, args.device, # dueling=(1, 1)
num_atoms=args.num_atoms).to(args.device)
net = Net(args.layer_num, args.state_shape, args.action_shape, args.device,
softmax=True, num_atoms=args.num_atoms).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = C51Policy(net, optim, args.gamma, args.num_atoms,
args.v_min, args.v_max, args.n_step,
Expand Down
63 changes: 9 additions & 54 deletions tianshou/utils/net/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def __init__(
self.device = device
self.dueling = dueling
self.softmax = softmax
self.num_atoms = num_atoms
self.action_num = np.prod(action_shape)
input_size = np.prod(state_shape)
if concat:
input_size += np.prod(action_shape)
Expand All @@ -66,7 +68,7 @@ def __init__(
if dueling is None:
if action_shape and not concat:
model += [nn.Linear(
hidden_layer_size, num_atoms * np.prod(action_shape))]
hidden_layer_size, num_atoms * self.action_num)]
else: # dueling DQN
q_layer_num, v_layer_num = dueling
Q, V = [], []
Expand All @@ -80,7 +82,7 @@ def __init__(

if action_shape and not concat:
Q += [nn.Linear(
hidden_layer_size, num_atoms * np.prod(action_shape))]
hidden_layer_size, num_atoms * self.action_num)]
V += [nn.Linear(hidden_layer_size, num_atoms)]

self.Q = nn.Sequential(*Q)
Expand All @@ -99,7 +101,12 @@ def forward(
logits = self.model(s)
if self.dueling is not None: # Dueling DQN
q, v = self.Q(logits), self.V(logits)
if self.num_atoms > 1:
v = v.view(-1, 1, self.num_atoms)
q = q.view(-1, self.action_num, self.num_atoms)
logits = q - q.mean(dim=1, keepdim=True) + v
elif self.num_atoms > 1:
logits = logits.view(-1, self.action_num, self.num_atoms)
if self.softmax:
logits = torch.softmax(logits, dim=-1)
return logits, state
Expand Down Expand Up @@ -164,55 +171,3 @@ def forward(
# please ensure the first dim is batch size: [bsz, len, ...]
return s, {"h": h.transpose(0, 1).detach(),
"c": c.transpose(0, 1).detach()}


class CategoricalNet(Net):
"""Simple MLP backbone.
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
.. seealso::
Please refer to :class:`~tianshou.utils.net.common.Net` for
more detailed explanation.
"""

def __init__(
self,
layer_num: int,
state_shape: tuple,
action_shape: Optional[Union[tuple, int]] = 0,
device: Union[str, int, torch.device] = "cpu",
concat: bool = False,
hidden_layer_size: int = 128,
dueling: Optional[Tuple[int, int]] = None,
norm_layer: Optional[Callable[[int], nn.modules.Module]] = None,
num_atoms: int = 51,
) -> None:
super().__init__(layer_num, state_shape, action_shape,
device, True, concat, hidden_layer_size,
dueling, norm_layer, num_atoms)
self.action_shape = action_shape
self.num_atoms = num_atoms

def forward(
self,
s: Union[np.ndarray, torch.Tensor],
state: Optional[Dict[str, torch.Tensor]] = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""Mapping: s -> flatten -> logits."""
s = to_torch(s, device=self.device, dtype=torch.float32)
s = s.reshape(s.size(0), -1)
logits = self.model(s)
if self.dueling is not None: # Dueling DQN
q, v = self.Q(logits), self.V(logits)
v = v.view(-1, 1, self.num_atoms),
q = q.view(-1, np.prod(self.action_shape), self.num_atoms)
logits = q - q.mean(dim=1, keepdim=True) + v
else:
logits = logits.view(
-1, np.prod(self.action_shape), self.num_atoms)
logits = torch.softmax(logits, dim=-1)
return logits, state

0 comments on commit 3efac01

Please sign in to comment.