Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small clean up generate #15611

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 6 additions & 11 deletions src/transformers/generation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,12 +525,6 @@ def _prepare_decoder_input_ids_for_generation(
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * decoder_start_token_id

def _get_pad_token_id(self, pad_token_id: int = None, eos_token_id: int = None) -> int:
if pad_token_id is None and eos_token_id is not None:
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
pad_token_id = eos_token_id
return pad_token_id

def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
Expand Down Expand Up @@ -1063,9 +1057,15 @@ def generate(

pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

if eos_token_id is None and hasattr(self.config, "decoder"):
eos_token_id = self.config.decoder.eos_token_id

if pad_token_id is None and eos_token_id is not None:
# special case if pad_token_id is not defined
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
pad_token_id = eos_token_id

output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
Expand All @@ -1075,11 +1075,6 @@ def generate(
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
)

if pad_token_id is None and eos_token_id is not None:
# special case if pad_token_id is not defined
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
pad_token_id = eos_token_id

# 2. Define model inputs
# inputs_tensor has to be defined
# model_input_name is defined if model-specific keyword input is passed
Expand Down