Commit

change self.support to nn.Parameter
Trinkle23897 committed Jan 4, 2021
1 parent 3efac01 commit d315052
Showing 3 changed files with 19 additions and 16 deletions.
9 changes: 5 additions & 4 deletions examples/atari/atari_c51.py
@@ -75,12 +75,13 @@ def test_c51(args=get_args()):
     test_envs.seed(args.seed)
     # define model
     net = C51(*args.state_shape, args.action_shape,
-              args.num_atoms, args.device).to(args.device)
+              args.num_atoms, args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
     # define policy
-    policy = C51Policy(net, optim, args.gamma, args.num_atoms,
-                       args.v_min, args.v_max, args.n_step,
-                       target_update_freq=args.target_update_freq)
+    policy = C51Policy(
+        net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max,
+        args.n_step, target_update_freq=args.target_update_freq
+    ).to(args.device)
     # load a previous policy
     if args.resume_path:
         policy.load_state_dict(torch.load(
9 changes: 5 additions & 4 deletions test/discrete/test_c51.py
@@ -64,11 +64,12 @@ def test_c51(args=get_args()):
     test_envs.seed(args.seed)
     # model
     net = Net(args.layer_num, args.state_shape, args.action_shape, args.device,
-              softmax=True, num_atoms=args.num_atoms).to(args.device)
+              softmax=True, num_atoms=args.num_atoms)
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
-    policy = C51Policy(net, optim, args.gamma, args.num_atoms,
-                       args.v_min, args.v_max, args.n_step,
-                       target_update_freq=args.target_update_freq)
+    policy = C51Policy(
+        net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max,
+        args.n_step, target_update_freq=args.target_update_freq
+    ).to(args.device)
     # buffer
     if args.prioritized_replay > 0:
         buf = PrioritizedReplayBuffer(
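In both scripts the .to(args.device) call moves from the bare network to the wrapped policy. The policy object is itself an nn.Module (the scripts call policy.load_state_dict on it), and this commit registers the C51 support on it as a parameter, so moving the policy relocates the network weights and the support tensor in one step. Below is a minimal sketch of the device behaviour this relies on, using hypothetical Inner/Outer stand-ins rather than the tianshou classes:

# Toy sketch (hypothetical Inner/Outer classes, not tianshou code): a tensor
# registered on the outer module only moves when .to() is called on that module.
import torch
import torch.nn as nn


class Inner(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(4, 51)


class Outer(nn.Module):
    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model
        self.support = nn.Parameter(
            torch.linspace(-10.0, 10.0, 51), requires_grad=False)


if __name__ == "__main__" and torch.cuda.is_available():
    net = Inner()
    policy = Outer(net)
    net.to("cuda")                  # old call site: only the inner net moves
    print(policy.support.device)    # cpu -- would mismatch the dist at runtime
    policy.to("cuda")               # new call site: net and support move together
    print(policy.support.device)    # cuda:0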
17 changes: 9 additions & 8 deletions tianshou/policy/modelfree/c51.py
@@ -3,7 +3,7 @@
 from typing import Any, Dict, Union, Optional

 from tianshou.policy import DQNPolicy
-from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
+from tianshou.data import Batch, ReplayBuffer, to_numpy


 class C51Policy(DQNPolicy):
@@ -52,8 +52,10 @@ def __init__(
         self._num_atoms = num_atoms
         self._v_min = v_min
         self._v_max = v_max
-        self.support = torch.linspace(self._v_min, self._v_max,
-                                      self._num_atoms)
+        self.support = torch.nn.Parameter(
+            torch.linspace(self._v_min, self._v_max, self._num_atoms),
+            requires_grad=False,
+        )
         self.delta_z = (v_max - v_min) / (num_atoms - 1)

     def _target_q(
@@ -85,7 +87,7 @@ def forward(
         obs = batch[input]
         obs_ = obs.obs if hasattr(obs, "obs") else obs
         dist, h = model(obs_, state=state, info=batch.info)
-        q = (dist * to_torch_as(self.support, dist)).sum(2)
+        q = (dist * self.support).sum(2)
         act: np.ndarray = to_numpy(q.max(dim=1)[1])
         if hasattr(obs, "mask"):
             # some of actions are masked, they cannot be selected
@@ -113,13 +115,11 @@ def _target_dist(self, batch: Batch) -> torch.Tensor:
         a = next_b.act
         next_dist = next_b.logits
         next_dist = next_dist[np.arange(len(a)), a, :]
-        support = self.support.to(next_dist.device)
-        target_support = batch.returns.clamp(
-            self._v_min, self._v_max).to(next_dist.device)
+        target_support = batch.returns.clamp(self._v_min, self._v_max)
         # An amazing trick for calculating the projection gracefully.
         # ref: https://github.com/ShangtongZhang/DeepRL
         target_dist = (1 - (target_support.unsqueeze(1) -
-                            support.view(1, -1, 1)).abs() / self.delta_z
+                            self.support.view(1, -1, 1)).abs() / self.delta_z
                        ).clamp(0, 1) * next_dist.unsqueeze(1)
         return target_dist.sum(-1)

@@ -135,6 +135,7 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
         curr_dist = curr_dist[np.arange(len(act)), act, :]
         cross_entropy = - (target_dist * torch.log(curr_dist + 1e-8)).sum(1)
         loss = (cross_entropy * weight).mean()
+        # ref: https://github.com/Kaixhin/Rainbow/blob/master/agent.py L94-100
         batch.weight = cross_entropy.detach()  # prio-buffer
         loss.backward()
         self.optim.step()
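The core of the commit: self.support becomes torch.nn.Parameter(..., requires_grad=False) rather than a plain tensor attribute, so it is relocated by .to(device) together with the rest of the module and is included in state_dict, which is why the per-call to_torch_as(...) and .to(next_dist.device) conversions above can be dropped. Below is a minimal self-contained sketch of the same pattern; DistHead is a hypothetical toy module, not tianshou code, and register_buffer would be another standard way to get the same device/state_dict behaviour.

# Minimal sketch of the nn.Parameter(requires_grad=False) pattern; DistHead is a
# hypothetical toy module, not part of tianshou.
import torch
import torch.nn as nn


class DistHead(nn.Module):
    def __init__(self, v_min: float = -10.0, v_max: float = 10.0,
                 num_atoms: int = 51) -> None:
        super().__init__()
        self.fc = nn.Linear(4, num_atoms)
        # non-trainable, but registered on the module: relocated by .to() and
        # saved/restored with state_dict
        self.support = nn.Parameter(
            torch.linspace(v_min, v_max, num_atoms), requires_grad=False)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        dist = torch.softmax(self.fc(obs), dim=-1)
        # support is already on dist's device, no per-call conversion needed
        return (dist * self.support).sum(-1)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    head = DistHead().to(device)   # one .to() moves weights and support together
    q = head(torch.randn(8, 4, device=device))
    print(q.shape, head.support.device)

Note that in the example scripts the optimizer is still built from net.parameters(), so the frozen support parameter on the policy is never handed to Adam.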

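As an aside, the projection trick referenced in _target_dist is untouched by this commit apart from reading self.support directly. A toy illustration of what the vectorized expression computes, with made-up numbers: each clamped per-atom return spreads its probability mass onto the two nearest support atoms with weight 1 - |distance| / delta_z, so the projected distribution still sums to one.

# Toy illustration (made-up numbers, not tianshou code) of the categorical
# projection used in C51Policy._target_dist.
import torch

v_min, v_max, num_atoms = -10.0, 10.0, 5
support = torch.linspace(v_min, v_max, num_atoms)        # [-10, -5, 0, 5, 10]
delta_z = (v_max - v_min) / (num_atoms - 1)              # 5.0

next_dist = torch.tensor([[0.1, 0.2, 0.4, 0.2, 0.1]])    # [bsz=1, num_atoms]
returns = torch.tensor([[-12.0, -4.0, 1.0, 6.0, 20.0]])  # per-atom n-step targets
target_support = returns.clamp(v_min, v_max)

# same expression as in _target_dist: each target atom donates its mass to the
# two nearest support atoms, weighted by 1 - |distance| / delta_z, clamped to [0, 1]
target_dist = (
    1 - (target_support.unsqueeze(1) - support.view(1, -1, 1)).abs() / delta_z
).clamp(0, 1) * next_dist.unsqueeze(1)
target_dist = target_dist.sum(-1)

print(target_dist)          # tensor([[0.1000, 0.1600, 0.3600, 0.2400, 0.1400]])
print(target_dist.sum(-1))  # sums to 1: probability mass is only redistributed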