Skip to content

Commit

Permalink
Merge c8000de into b9d5f5e
Browse files Browse the repository at this point in the history
  • Loading branch information
BartekRoszak authored Jan 30, 2021
2 parents b9d5f5e + c8000de commit cb1aa65
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pl_bolts/models/rl/double_dqn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _) -> OrderedD
"""

# calculates training loss
loss = double_dqn_loss(batch, self.net, self.target_net)
loss = double_dqn_loss(batch, self.net, self.target_net, self.gamma)

if self.trainer.use_dp or self.trainer.use_ddp2:
loss = loss.unsqueeze(0)
Expand Down
2 changes: 1 addition & 1 deletion pl_bolts/models/rl/dqn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _) -> OrderedD
"""

# calculates training loss
loss = dqn_loss(batch, self.net, self.target_net)
loss = dqn_loss(batch, self.net, self.target_net, self.gamma)

if self.trainer.use_dp or self.trainer.use_ddp2:
loss = loss.unsqueeze(0)
Expand Down
2 changes: 1 addition & 1 deletion pl_bolts/models/rl/per_dqn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def training_step(self, batch, _) -> OrderedDict:
indices = indices.cpu().numpy()

# calculates training loss
loss, batch_weights = per_dqn_loss(samples, weights, self.net, self.target_net)
loss, batch_weights = per_dqn_loss(samples, weights, self.net, self.target_net, self.gamma)

if self.trainer.use_dp or self.trainer.use_ddp2:
loss = loss.unsqueeze(0)
Expand Down

0 comments on commit cb1aa65

Please sign in to comment.