Generation: fix handling of special tokens #31254

Merged (8 commits) on Jun 6, 2024
src/transformers/generation/utils.py: 55 changes (27 additions & 28 deletions)
@@ -1431,23 +1431,6 @@ def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_l
             self._cache.reset()
         return self._cache
 
-    def _get_decoder_start_token_id(
-        self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None
-    ) -> int:
-        decoder_start_token_id = (
-            decoder_start_token_id
-            if decoder_start_token_id is not None
-            else self.generation_config.decoder_start_token_id
-        )
-        bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
-
-        if decoder_start_token_id is not None:
-            return decoder_start_token_id
-        elif bos_token_id is not None:
-            return bos_token_id
-        else:
-            return
-
     def _supports_default_dynamic_cache(self) -> bool:
         """
         Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`.
@@ -1473,25 +1456,32 @@ def _prepare_special_tokens(
         function). However, if called outside `generate`, consider creating a copy of `generation_config` first.
         """
 
-        # Convert special tokens to tensors (if they exist)
-        def _tensor_or_none(token, device=None):
+        # Convert special tokens to tensors (if they exist either in kwargs or in self.config)
+        def _tensor_or_none(token_kwargs, token_self, device=None):
             if device is None:
                 device = self.device
 
+            token = token_kwargs if token_kwargs is not None else token_self
             if token is None or isinstance(token, torch.Tensor):
                 return token
             return torch.tensor(token, device=device, dtype=torch.long)
 
-        # for BC we also try to get `decoder_start_token_id` from model's generation config (#30892)
-        if self.config.is_encoder_decoder:
-            generation_config.decoder_start_token_id = self._get_decoder_start_token_id(
-                generation_config.decoder_start_token_id, generation_config.bos_token_id
-            )
+        bos_token_id = _tensor_or_none(
+            generation_config.bos_token_id, self.generation_config.bos_token_id, device=device
+        )
+        eos_token_id = _tensor_or_none(
+            generation_config.eos_token_id, self.generation_config.eos_token_id, device=device
+        )
+        pad_token_id = _tensor_or_none(
+            generation_config.pad_token_id, self.generation_config.pad_token_id, device=device
+        )
+        decoder_start_token_id = _tensor_or_none(
+            generation_config.decoder_start_token_id, self.generation_config.decoder_start_token_id, device=device
+        )
 
-        bos_token_id = _tensor_or_none(generation_config.bos_token_id, device=device)
-        eos_token_id = _tensor_or_none(generation_config.eos_token_id, device=device)
-        pad_token_id = _tensor_or_none(generation_config.pad_token_id, device=device)
-        decoder_start_token_id = _tensor_or_none(generation_config.decoder_start_token_id, device=device)
+        # for BC we also try to get `decoder_start_token_id` or `bos_token_id` (#30892)
+        if self.config.is_encoder_decoder:
+            decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
 
         # We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
         if eos_token_id is not None and eos_token_id.ndim == 0:
@@ -1507,6 +1497,15 @@ def _tensor_or_none(token, device=None):
             pad_token_id = eos_token_id[0]
             logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_id} for open-end generation.")
 
+        # we can't infer attn mask if pad token is set to be eos token in model's generation config
+        if eos_token_id is not None and torch.isin(elements=eos_token_id, test_elements=pad_token_id).any():
+            if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
+                logger.warning_once(
+                    "The attention mask is not set and cannot be inferred from input because pad token is same as eos token."
+                    "As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` "
+                    "to obtain reliable results."
+                )
+
         # Sanity checks/warnings
         if self.config.is_encoder_decoder and decoder_start_token_id is None:
             raise ValueError(
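
Note on the change above: the removed `_get_decoder_start_token_id` helper is replaced by resolving every special token through `_tensor_or_none`, which prefers the value passed per call and otherwise falls back to the model's `generation_config`; for encoder-decoder models, `decoder_start_token_id` still falls back to `bos_token_id`. The snippet below is a minimal standalone sketch of that resolution order, not the library code itself, and the helper name is illustrative.

from typing import List, Optional, Union

import torch


def resolve_special_token(
    token_kwargs: Optional[Union[int, List[int], torch.Tensor]],
    token_self: Optional[Union[int, List[int], torch.Tensor]],
    device: str = "cpu",
) -> Optional[torch.Tensor]:
    # Per-call value wins; otherwise fall back to the model-level default.
    token = token_kwargs if token_kwargs is not None else token_self
    if token is None or isinstance(token, torch.Tensor):
        return token
    return torch.tensor(token, device=device, dtype=torch.long)


# Per-call override beats the model default.
assert resolve_special_token(2, 50256).item() == 2
# No per-call value: fall back to the model default.
assert resolve_special_token(None, 50256).item() == 50256
# Encoder-decoder BC fallback: use `bos_token_id` when no decoder start token is configured.
decoder_start_token_id = resolve_special_token(None, None)
bos_token_id = resolve_special_token(None, 0)
decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
assert decoder_start_token_id.item() == 0
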
tests/generation/test_framework_agnostic.py: 4 changes (2 additions & 2 deletions)
@@ -161,6 +161,7 @@ def test_transition_scores_greedy_search(self):
         tokenizer.pad_token = tokenizer.eos_token
 
         model = model_cls.from_pretrained("distilbert/distilgpt2")
+        model.generation_config.eos_token_id = None
         input_ids = tokenizer(articles, return_tensors=return_tensors, padding=True).input_ids
         if is_pt:
             model = model.to(torch_device)
@@ -170,7 +171,6 @@ def test_transition_scores_greedy_search(self):
             input_ids=input_ids,
             max_new_tokens=5,
             pad_token_id=tokenizer.eos_token_id,
-            eos_token_id=None,
             return_dict_in_generate=True,
             output_scores=True,
         )
@@ -197,6 +197,7 @@ def test_transition_scores_greedy_search_normalized(self):
         tokenizer.pad_token = tokenizer.eos_token
 
         model = model_cls.from_pretrained("distilbert/distilgpt2")
+        model.generation_config.eos_token_id = None
         input_ids = tokenizer(articles, return_tensors=return_tensors, padding=True).input_ids
         if is_pt:
             model = model.to(torch_device)
@@ -206,7 +207,6 @@ def test_transition_scores_greedy_search_normalized(self):
             input_ids=input_ids,
             max_new_tokens=5,
             pad_token_id=tokenizer.eos_token_id,
-            eos_token_id=None,
             return_dict_in_generate=True,
             output_scores=True,
         )
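
The test edits above follow from the new precedence: passing `eos_token_id=None` to `generate()` now means "no per-call value" and falls back to the model's `generation_config`, so EOS-based stopping is disabled on the model itself instead. A usage sketch under that assumption, mirroring the updated tests (checkpoint and prompt are illustrative):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
# Unset EOS on the model's generation config rather than passing `eos_token_id=None` per call.
model.generation_config.eos_token_id = None

inputs = tokenizer(["Today is a nice day and"], return_tensors="pt", padding=True)
outputs = model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=5,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
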