This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Moved generation activation during generation of TorchGeneratorAgent to dedicated method for easier adjusting. (#4700)

Co-authored-by: Leonard Adolphs <ladolphs@devfair0791.h2.fair>
leox1v and Leonard Adolphs authored Aug 18, 2022
1 parent 9c51887 commit 949f4e8
Showing 1 changed file with 4 additions and 1 deletion: parlai/core/torch_generator_agent.py
@@ -1134,6 +1134,9 @@ def get_prefix_tokens(self, batch: Batch) -> Optional[torch.LongTensor]:
         """
         return None
 
+    def _generation_activation(self, score: torch.Tensor) -> torch.float32:
+        return F.log_softmax(score, dim=-1, dtype=torch.float32)
+
     def _generate(
         self,
         batch: Batch,
@@ -1214,7 +1217,7 @@ def _generate(
                 if self.temperature != 1.0:
                     score.div_(self.temperature)
                 # force to fp32 to avoid overflow issues during search calculations
-                score = F.log_softmax(score, dim=-1, dtype=torch.float32)  # type: ignore
+                score = self._generation_activation(score)  # type: ignore
                 if prefix_tokens is not None and _ts < prefix_tokens.size(1):
                     # generate prefix_tokens for every timestep that they exist
                     # achieve by setting score of all other tokens to be -inf
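The point of the refactor is that subclasses of TorchGeneratorAgent can now override `_generation_activation` to change how raw decoder scores are turned into (log-)probabilities, without copying the whole `_generate` loop. The default delegates to `F.log_softmax` in fp32 to avoid overflow during search. As a rough pure-Python sketch of what that activation computes (an illustration, not ParlAI code), the standard max-subtraction trick keeps `exp()` from overflowing, which is the same numerical concern the diff's fp32 cast addresses:

```python
import math

def log_softmax(scores):
    """Numerically stable log-softmax over a list of float scores."""
    # Subtract the max before exponentiating so exp() cannot overflow,
    # mirroring why the diff forces fp32 for search calculations.
    m = max(scores)
    log_sum = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_sum for s in scores]

log_probs = log_softmax([1.0, 2.0, 3.0])
```

A hypothetical subclass could then, for example, return temperature-reshaped or sparsified log-probabilities from its own `_generation_activation` override, and every search strategy in `_generate` would pick up the change automatically.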

0 comments on commit 949f4e8
