Include padding mask in generation #2096
Conversation
@@ -48,7 +48,7 @@ def _prepare_decoder_ids_for_generation(
         return torch.ones((batch_size, 1), dtype=torch.long, device=device) * pad_idx

     def greedy_search(
-        self, input_ids: torch.Tensor, max_length: int, eos_idx: int, pad_idx: Optional[int] = None, **model_kwargs
+        self, input_ids: torch.Tensor, max_length: int, eos_idx: int, pad_idx: int, **model_kwargs
Does changing pad_idx from Optional to required break any call sites?
Nope. Only being called from the entry point method atm.
         # Append the next tokens to the previous tokens
-        input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
What does the [:, None] do here?
Same thing as unsqueezing the last dim
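As a quick illustration (the tensor values here are made up), indexing a 1-D tensor with [:, None] inserts a trailing dimension of size one, exactly like unsqueeze(-1):

```python
import torch

next_tokens = torch.tensor([5, 7, 9])  # shape: (batch_size,) == (3,)

# [:, None] inserts a new last dimension ...
a = next_tokens[:, None]               # shape: (3, 1)

# ... which matches unsqueezing the last dim.
b = next_tokens.unsqueeze(-1)          # shape: (3, 1)

assert torch.equal(a, b)
```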
        tokens_for_single_example = generation_model.generate(inputs, num_beams=1, max_length=30)
        generated_text_for_single_example = self.transform.decode(tokens_for_single_example.tolist())

        self.assertEqual(generated_text[0], generated_text_for_single_example[-1])
Why do we do generated_text_for_single_example[-1] instead of generated_text_for_single_example[0]?
I was originally going to pass multiple examples through the second pass, but didn't. Both get the same result though, since with a single example -1 == 0.
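(Illustrative one-liner: for a decode result holding a single example, the last and first entries are the same item.)

```python
decoded = ["only one generated string"]  # hypothetical single-example result
assert decoded[-1] == decoded[0]
```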
LGTM
Bug
Generation for a batched input should match generation for each example passed through on its own. Before this change, the two did not produce the same output. The issue was that the src_key_padding_mask was not being propagated forward.
Fix
Create the padding mask, add it to model_kwargs, and pass it to the forward function.
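A minimal sketch of the idea, assuming the mask is derived from input_ids with input_ids.eq(pad_idx); the tensor values are illustrative, and the src_key_padding_mask kwarg name comes from the description above:

```python
import torch

# Illustrative values; in the real entry point, pad_idx and input_ids
# come from the caller.
pad_idx = 0
input_ids = torch.tensor([
    [11, 12, 13, 14],  # full-length example
    [21, 22,  0,  0],  # padded example
])

# True where input_ids holds padding, so attention can ignore those positions.
padding_mask = input_ids.eq(pad_idx)

# The mask rides along in model_kwargs and is forwarded into the model's
# forward call, ending up as the encoder's src_key_padding_mask.
model_kwargs = {"src_key_padding_mask": padding_mask}
```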