diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py
index e9afe742caa..d0d69b1ee58 100644
--- a/parlai/core/torch_generator_agent.py
+++ b/parlai/core/torch_generator_agent.py
@@ -1134,6 +1134,10 @@ def get_prefix_tokens(self, batch: Batch) -> Optional[torch.LongTensor]:
         """
         return None
 
+    def _generation_activation(self, score: torch.Tensor) -> torch.Tensor:
+        """Normalize decoder logits for search: log-softmax over the vocab dim, in fp32."""
+        return F.log_softmax(score, dim=-1, dtype=torch.float32)
+
     def _generate(
         self,
         batch: Batch,
@@ -1214,7 +1218,7 @@ def _generate(
         if self.temperature != 1.0:
             score.div_(self.temperature)
         # force to fp32 to avoid overflow issues during search calculations
-        score = F.log_softmax(score, dim=-1, dtype=torch.float32)  # type: ignore
+        score = self._generation_activation(score)  # type: ignore
         if prefix_tokens is not None and _ts < prefix_tokens.size(1):
             # generate prefix_tokens for every timestep that they exist
             # achieve by setting score of all other tokens to be -inf