diff --git a/pl_bolts/models/rl/double_dqn_model.py b/pl_bolts/models/rl/double_dqn_model.py index 7b9cb68dd5..a7fcacfa95 100644 --- a/pl_bolts/models/rl/double_dqn_model.py +++ b/pl_bolts/models/rl/double_dqn_model.py @@ -56,7 +56,7 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _) -> OrderedD """ # calculates training loss - loss = double_dqn_loss(batch, self.net, self.target_net) + loss = double_dqn_loss(batch, self.net, self.target_net, self.gamma) if self.trainer.use_dp or self.trainer.use_ddp2: loss = loss.unsqueeze(0) diff --git a/pl_bolts/models/rl/dqn_model.py b/pl_bolts/models/rl/dqn_model.py index 0d43bbe6a2..9f5fb2f709 100644 --- a/pl_bolts/models/rl/dqn_model.py +++ b/pl_bolts/models/rl/dqn_model.py @@ -270,7 +270,7 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _) -> OrderedD """ # calculates training loss - loss = dqn_loss(batch, self.net, self.target_net) + loss = dqn_loss(batch, self.net, self.target_net, self.gamma) if self.trainer.use_dp or self.trainer.use_ddp2: loss = loss.unsqueeze(0) diff --git a/pl_bolts/models/rl/per_dqn_model.py b/pl_bolts/models/rl/per_dqn_model.py index b26db8e7fb..18afa87d5b 100644 --- a/pl_bolts/models/rl/per_dqn_model.py +++ b/pl_bolts/models/rl/per_dqn_model.py @@ -114,7 +114,7 @@ def training_step(self, batch, _) -> OrderedDict: indices = indices.cpu().numpy() # calculates training loss - loss, batch_weights = per_dqn_loss(samples, weights, self.net, self.target_net) + loss, batch_weights = per_dqn_loss(samples, weights, self.net, self.target_net, self.gamma) if self.trainer.use_dp or self.trainer.use_ddp2: loss = loss.unsqueeze(0)