Skip to content

Commit

Permalink
add exceptions
Browse files Browse the repository at this point in the history
  • Loading branch information
gante committed Jul 21, 2022
1 parent 8893aa2 commit 2cdf84b
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 2 deletions.
5 changes: 5 additions & 0 deletions src/transformers/generation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,11 @@ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
for key in ["decoder_input_ids"]:
model_kwargs.pop(key, None)

# Transfo_XL does not use have "attention_mask" as an argument, and it is harmless (it is being passed in the
# tests, through)
if "transfoxl" in str(self).lower():
model_kwargs.pop("attention_mask", None)

unused_model_args = []
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
# `kwargs`` if often used to handle optional forward pass inputs like `attention_mask`. If
Expand Down
2 changes: 1 addition & 1 deletion tests/generation/test_generation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2052,7 +2052,7 @@ def test_max_new_tokens_decoder_only(self):

# max_new_tokens and max_length serve the same purpose and should not be used together.
with self.assertWarns(UserWarning):
gpt2_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
gpt2_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20)

def test_encoder_decoder_generate_with_inputs_embeds(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def test_generate_fp16(self):
if torch_device == "cuda":
model.half()
model.generate(**input_dict)
model.generate(**input_dict, do_sample=True, early_stopping=False, num_return_sequences=3)
model.generate(**input_dict, do_sample=True, num_return_sequences=3)

@slow
def test_batched_forward_original_full(self):
Expand Down

0 comments on commit 2cdf84b

Please sign in to comment.