
Fix bug about add_special_tokens and so on #31496

Conversation

@hiroshi-matsuda-rit (Contributor) commented Jun 19, 2024

What does this PR do?

When passing the add_special_tokens and/or padding argument(s) to TextGenerationPipeline.__call__(), we get an exception like:

>>> from transformers import pipeline
>>> TARGET = "meta-llama/Llama-2-7b-hf"
>>> pipe = pipeline("text-generation", model=TARGET, max_new_tokens=512, device_map="auto")
>>> pipe("what is 3+4?", add_special_tokens=False)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/model/hng88/llm-jp-eval/llm-jp-eval/lib/python3.10/site-packages/transformers/pipelines/text_generation.py", line 263, in __call__
    return super().__call__(text_inputs, **kwargs)
  File "/model/hng88/llm-jp-eval/llm-jp-eval/lib/python3.10/site-packages/transformers/pipelines/base.py", line 1243, in __call__
    return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
  File "/model/hng88/llm-jp-eval/llm-jp-eval/lib/python3.10/site-packages/transformers/pipelines/base.py", line 1250, in run_single
    model_outputs = self.forward(model_inputs, **forward_params)
  File "/model/hng88/llm-jp-eval/llm-jp-eval/lib/python3.10/site-packages/transformers/pipelines/base.py", line 1150, in forward
    model_outputs = self._forward(model_inputs, **forward_params)
  File "/model/hng88/llm-jp-eval/llm-jp-eval/lib/python3.10/site-packages/transformers/pipelines/text_generation.py", line 350, in _forward
    generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
  File "/model/hng88/llm-jp-eval/llm-jp-eval/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/model/hng88/llm-jp-eval/llm-jp-eval/lib/python3.10/site-packages/transformers/generation/utils.py", line 1542, in generate
    self._validate_model_kwargs(model_kwargs.copy())
  File "/model/hng88/llm-jp-eval/llm-jp-eval/lib/python3.10/site-packages/transformers/generation/utils.py", line 1157, in _validate_model_kwargs
    raise ValueError(
ValueError: The following `model_kwargs` are not used by the model: ['add_special_tokens'] (note: typos in the generate arguments will also show up in this list)

We need to remove both the add_special_tokens and padding fields from generate_kwargs in TextGenerationPipeline's parameter handling (_sanitize_parameters).

Up to transformers v4.40.2, these kwargs were used only for preprocessing, but from v4.41.0 they also end up in forward_params and are therefore forwarded to model.generate(), which rejects them.
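For illustration, the kind of change needed looks roughly like the sketch below (a simplified stand-in for the pipeline's _sanitize_parameters, not the exact patch):

def _sanitize_parameters_sketch(**generate_kwargs):
    # Tokenizer-only options must be routed to preprocess_params and
    # popped from generate_kwargs so they never reach model.generate().
    preprocess_params = {}
    for key in ("add_special_tokens", "padding"):
        if key in generate_kwargs:
            preprocess_params[key] = generate_kwargs.pop(key)
    forward_params = generate_kwargs  # everything that remains is for generate()
    return preprocess_params, forward_params, {}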

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@amyeroberts (Collaborator)

cc @Rocketknight1

@Rocketknight1 (Member) commented Jun 20, 2024

This PR makes sense to me, and sorry for the oversight!

One suggestion, though - we could probably simplify it by including it in the code block above. In other words:

if "add_special_tokens" in generate_kwargs:
    preprocess_params["add_special_tokens"] = generate_kwargs["add_special_tokens"]
    add_special_tokens = generate_kwargs["add_special_tokens"]
if "padding" in generate_kwargs:
    preprocess_params["padding"] = generate_kwargs["padding"]

would become

if "add_special_tokens" in generate_kwargs:
    add_special_tokens = preprocess_params["add_special_tokens"] = generate_kwargs.pop("add_special_tokens")
else:
    add_special_tokens = False
if "padding" in generate_kwargs:
    preprocess_params["padding"] = generate_kwargs.pop("padding")

and then we wouldn't need the extra if statement you added. WDYT?
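Either way, the call from the bug report should then run without the ValueError, e.g. (illustrative):

from transformers import pipeline

# Same call as in the report; add_special_tokens is now consumed by
# preprocessing instead of being forwarded to model.generate().
pipe = pipeline("text-generation", model="meta-llama/Llama-2-7b-hf", max_new_tokens=512, device_map="auto")
pipe("what is 3+4?", add_special_tokens=False)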

@hiroshi-matsuda-rit (Contributor, Author)

@Rocketknight1 It's obviously better than adding the extra for-pop statements. I've revised it as you suggested. Thanks a lot!

@Rocketknight1 (Member) left a review comment

LGTM now! cc @amyeroberts for core maintainer review, and thanks for this bugfix!

@amyeroberts (Collaborator) left a review comment

Thanks for fixing!

The only request before merging is to add a test that would have caught this bug.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@hiroshi-matsuda-rit (Contributor, Author)

I added a test case to test_pipelines_text_generation.py that passes the add_special_tokens and padding args to TextGenerationPipeline.__call__().
Please review it. @amyeroberts
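For reference, a hypothetical sketch of what such a test could look like (the actual test added in the PR may differ; the tiny checkpoint and pad-token workaround below are assumptions):

from transformers import pipeline

def test_text_generation_tokenizer_kwargs():
    # add_special_tokens / padding must be consumed during preprocessing and
    # must not leak into model.generate() as unused model_kwargs.
    generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2")
    generator.tokenizer.pad_token = generator.tokenizer.eos_token  # GPT-2 has no pad token by default
    outputs = generator("This is a test", add_special_tokens=False, padding=True, max_new_tokens=5)
    assert isinstance(outputs[0]["generated_text"], str)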

@amyeroberts (Collaborator) left a review comment

Thanks!
