
Pipeline: use tokenizer pad token at generation time if the model pad token is unset. #29614

Merged
gante merged 6 commits into huggingface:main on Mar 15, 2024

Conversation

gante (Member) commented on Mar 12, 2024:

What does this PR do?

Fixes #29378

The tagged issue describes the problem, the title describes the fix :D

Example of a script that no longer emits a warning after this PR:

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
# The pad token is set on the tokenizer only; the model config's pad_token_id stays unset
tokenizer.pad_token_id = tokenizer.eos_token_id

llm = pipeline(task='text-generation', model=model, tokenizer=tokenizer, framework='pt')
# Before this PR, this call warned about a missing pad_token_id despite the assignment above
response = llm('The capital of France ')
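For context, the behaviour this PR introduces boils down to a fallback along the lines of the sketch below (a simplified illustration, not the actual pipeline code; variable names are illustrative):

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id

inputs = tokenizer("The capital of France ", return_tensors="pt")

generate_kwargs = {"max_new_tokens": 20}
# Fall back to the tokenizer's pad token when the model config does not define one,
# so generate() receives a pad_token_id and no longer warns about it being unset
if model.generation_config.pad_token_id is None and tokenizer.pad_token_id is not None:
    generate_kwargs["pad_token_id"] = tokenizer.pad_token_id

output_ids = model.generate(**inputs, **generate_kwargs)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))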

gante requested a review from amyeroberts on March 12, 2024 at 15:27

amyeroberts (Collaborator) left a comment:

Thanks for working on this!

Just some questions about the implementation.

Resolved review threads (outdated) on:
src/transformers/pipelines/automatic_speech_recognition.py
src/transformers/pipelines/automatic_speech_recognition.py
tests/pipelines/test_pipelines_text_generation.py

@@ -196,9 +196,7 @@ def new_user_input(self):
     build_pipeline_init_args(has_tokenizer=True),
     r"""
         min_length_for_response (`int`, *optional*, defaults to 32):
             The minimum length (in number of tokens) for a response.
-        minimum_tokens (`int`, *optional*, defaults to 10):
gante (Member, Author) commented on Mar 14, 2024:

minimum_tokens is an unused internal variable, probably a legacy version of min_length.

Initially, I removed it from the signature of the private _forward, as I was touching it. Then, I realized we could remove all traces since it is unused :)

@@ -311,14 +311,14 @@ def _sanitize_parameters(

     forward_params = defaultdict(dict)
     if max_new_tokens is not None:
-        forward_params["generate_kwargs"]["max_new_tokens"] = max_new_tokens
+        forward_params["max_new_tokens"] = max_new_tokens
gante (Member, Author) commented on Mar 14, 2024:

Note regarding this file's diff, also applicable to the diff in src/transformers/pipelines/image_to_text.py:

The conventional strategy to pass kwargs to generate is through **forward_params. Previously in this file, the generation kwargs were held as forward_params["generate_kwargs"], which prevented the use of the conventional strategy. There isn't really a reason to hold these kwargs separately: generate is the only sink for kwargs in models that can generate, and models that can't generate should throw an exception regardless of the container for the kwargs. As such, this diff aims at minimizing the differences in generate parameterization across pipelines :)
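For illustration, this is the conventional strategy as it already works in e.g. the text-generation pipeline (a minimal usage sketch; the prompt and values are arbitrary):

from transformers import pipeline

llm = pipeline(task="text-generation", model="openai-community/gpt2")
# max_new_tokens is collected into forward_params and handed over to generate()
response = llm("The capital of France ", max_new_tokens=20)
print(response[0]["generated_text"])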

amyeroberts (Collaborator) replied:

I'm not a fan of this - it's far cleaner to clearly outline which kwargs are generate kwargs and which are not. In the current pipelines the models might be the only sink, but that's not guaranteed.

amyeroberts (Collaborator) replied:

Actually - I realise what I said about the forward kwargs is wrong here - we can just assume they're passed to the model. In this case, my preference is still to have "generate_kwargs" explicitly in the forward kwargs, but I don't feel strongly and don't mind if you leave it as-is.

gante (Member, Author) replied:

I think we can agree that, regardless of the pattern we choose here, it should be applied to all pipelines with generative capabilities for consistency. Based on this premise, enforcing a separation of generate_kwargs in this exact way would break backward compatibility, i.e. the following would no longer be possible:

from transformers import pipeline

llm = pipeline(task='text-generation', model="openai-community/gpt2")
# max_length is a generate kwarg passed directly at call time, without a generate_kwargs dict
response = llm('The capital of France ', max_length=50)

Nevertheless, I am aligned with you -- we should separate them! We can do it through generation_config.update(**kwargs), and perform the required validation with the aid of generation_config.validate(). One of the requirements to do so is to have a single big blob of keyword arguments to untangle, and thus these changes go in this direction.
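A rough sketch of how that follow-up could look (illustrative only; the kwargs blob is made up, and the exact wiring inside the pipelines is not decided here):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")

# A single blob of user-provided keyword arguments, as a pipeline could collect them (made-up values)
forward_params = {"max_new_tokens": 20, "do_sample": True, "temperature": 0.7}

generation_config = model.generation_config
# update() consumes the keys it recognizes and returns the ones it does not
unused_kwargs = generation_config.update(**forward_params)
# validate() complains if the resulting generation configuration is inconsistent
generation_config.validate()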

Let me know if you agree, in which case I'll merge the PR and prepare this follow-up. [My instinct was to merge this PR now, but I've held it back -- I've merged too many not-100%-approved PRs recently 😉 ]

amyeroberts (Collaborator) replied:

Yeah, let's merge atm so this is unblocked and then we can iterate on something different :)

gante requested a review from amyeroberts on March 14, 2024 at 12:48
gante (Member, Author) commented on Mar 14, 2024:

@amyeroberts ready for a re-review :)

amyeroberts (Collaborator) left a comment:

Thanks for iterating on this!

Happy with the changes in general, but I think we should maintain the clear separation between "generate_kwargs" and forward kwargs, as it makes it easier to understand what controls what in the pipeline.

gante merged commit 53d8912 into huggingface:main on Mar 15, 2024
21 checks passed
gante deleted the fix_29378 branch on March 15, 2024 at 13:00
itazap pushed a commit that referenced this pull request May 14, 2024
Closes: Misleading warning message about pad_token_id when passing tokenizer instance to pipeline (#29378)