Pipeline: use tokenizer pad token at generation time if the model pad token is unset. #29614

Merged (6 commits) on Mar 15, 2024
9 changes: 3 additions & 6 deletions src/transformers/pipelines/automatic_speech_recognition.py
@@ -311,14 +311,14 @@ def _sanitize_parameters(

         forward_params = defaultdict(dict)
         if max_new_tokens is not None:
-            forward_params["generate_kwargs"]["max_new_tokens"] = max_new_tokens
+            forward_params["max_new_tokens"] = max_new_tokens
@gante (Member, Author) commented on Mar 14, 2024:
Note regarding this file's diff, also applicable to the diff in src/transformers/pipelines/image_to_text.py:

The conventional strategy to pass kwargs to generate is through **forward_params. Previously in this file, the generation kwargs were held as forward_params["generate_kwargs"], which prevented the use of the conventional strategy. There isn't really a reason to hold these kwargs separately: generate is the only sink for kwargs in models that can generate, and models that can't generate should throw an exception regardless of which container the kwargs live in. As such, this diff aims at minimizing the differences in generate parameterization across pipelines :)
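
(Editorial illustration, not part of the PR thread: a minimal sketch of the convention described above, assuming a Whisper checkpoint and a placeholder audio file.)

from transformers import pipeline

# Assumed example checkpoint; "sample.flac" is a placeholder audio file.
asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")

# Flat style: the kwarg lands in `forward_params` and is forwarded to `generate`.
flat = asr("sample.flac", max_new_tokens=64)

# Explicit container, still accepted: merged via `forward_params.update(generate_kwargs)`.
nested = asr("sample.flac", generate_kwargs={"max_new_tokens": 64})

Both calls end up with max_new_tokens in forward_params, which _forward then splats into generate.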

Collaborator commented:
I'm not a fan of this - it's far cleaner to clearly outline which kwargs are generate kwargs and which are not. In the current pipelines the models might be the only sink, but that's not guaranteed.

Collaborator commented:
Actually, I realise what I said about the forward kwargs is wrong here: we can just assume they're passed to the model. In that case, my preference is still to have "generate_kwargs" explicitly in the forward_kwargs, but I don't feel strongly and don't mind if you leave it as-is.

@gante (Member, Author) replied:
I think we can agree that, regardless of the pattern we choose here, it should be applied to all pipelines with generative capabilities for consistency. Based on that premise, enforcing a separation of generate_kwargs in this exact way would break backward compatibility, i.e. the following would no longer be possible:

from transformers import pipeline

llm = pipeline(task='text-generation', model="openai-community/gpt2")
response = llm('The capital of France ', max_length=50)

Nevertheless, I am aligned with you -- we should separate them! We can do it through generation_config.update(**kwargs) and perform the required validation with the aid of generation_config.validate(). One of the requirements to do so is to have a single big blob of keyword arguments to untangle, and these changes go in that direction.

Let me know if you agree, in which case I'll merge the PR and prepare the follow-up. [My instinct was to merge this PR now, but I've held it back -- I've merged too many not-100%-approved PRs recently 😉]
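
(Editorial sketch of the proposed follow-up, not code from this PR; GenerationConfig.update and GenerationConfig.validate are existing methods, the kwargs shown are only examples.)

from transformers import GenerationConfig

# Fresh config for illustration; in a pipeline this would be `self.model.generation_config`.
generation_config = GenerationConfig()

# A single blob of keyword arguments, as a pipeline call would collect them.
call_kwargs = {"max_new_tokens": 50, "num_beams": 2, "do_sample": False}

# `update` consumes the keys the config recognises and returns the ones it does not,
# which could then be routed to the model instead of `generate`.
unused_kwargs = generation_config.update(**call_kwargs)

# `validate` flags inconsistent combinations of generation flags.
generation_config.validate()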

Collaborator commented:
Yeah, let's merge atm so this is unblocked and then we can iterate on something different :)

         if generate_kwargs is not None:
             if max_new_tokens is not None and "max_new_tokens" in generate_kwargs:
                 raise ValueError(
                     "`max_new_tokens` is defined both as an argument and inside `generate_kwargs` argument, please use"
                     " only 1 version"
                 )
-            forward_params["generate_kwargs"].update(generate_kwargs)
+            forward_params.update(generate_kwargs)

         postprocess_params = {}
         if decoder_kwargs is not None:
@@ -456,10 +456,7 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
                 processed["stride"] = stride
             yield {"is_last": True, **processed, **extra}

-    def _forward(self, model_inputs, return_timestamps=False, generate_kwargs=None):
-        if generate_kwargs is None:
-            generate_kwargs = {}
-
+    def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs):
         attention_mask = model_inputs.pop("attention_mask", None)
         stride = model_inputs.pop("stride", None)
         is_last = model_inputs.pop("is_last")
10 changes: 10 additions & 0 deletions src/transformers/pipelines/base.py
@@ -885,6 +885,16 @@ def __init__(
         self._num_workers = kwargs.pop("num_workers", None)
         self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)

+        # Pipelines calling `generate`: if the tokenizer has a pad token but the model doesn't, set it in the
+        # forward params so that `generate` is aware of the pad token.
+        if (
+            self.tokenizer is not None
+            and self.model.can_generate()
+            and self.tokenizer.pad_token_id is not None
+            and self.model.generation_config.pad_token_id is None
+        ):
+            self._forward_params["pad_token_id"] = self.tokenizer.pad_token_id
+
         if self.image_processor is None and self.feature_extractor is not None:
             if isinstance(self.feature_extractor, BaseImageProcessor):
                 # Backward compatible change, if users called
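(Editorial illustration of the new base.py behaviour, assuming gpt2 as an example of a model whose generation_config defines no pad token; the prompt is arbitrary.)

from transformers import AutoTokenizer, pipeline

tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
tok.pad_token = tok.eos_token  # tokenizer now has a pad token; the model's generation_config still does not

pipe = pipeline("text-generation", model="openai-community/gpt2", tokenizer=tok)

# The tokenizer's pad token is copied into the pipeline's forward params, so `generate`
# receives `pad_token_id` without the caller passing it on every call.
out = pipe("The capital of France is", max_new_tokens=5)
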
12 changes: 3 additions & 9 deletions src/transformers/pipelines/conversational.py
@@ -196,9 +196,7 @@ def new_user_input(self):
     build_pipeline_init_args(has_tokenizer=True),
     r"""
         min_length_for_response (`int`, *optional*, defaults to 32):
-            The minimum length (in number of tokens) for a response.
-        minimum_tokens (`int`, *optional*, defaults to 10):
@gante (Member, Author) commented on Mar 14, 2024:
minimum_tokens is an unused internal variable, probably a legacy version of min_length.

Initially I only removed it from the signature of the private _forward, since I was touching it anyway. Then I realized we could remove all traces of it, since it is unused :)

-            The minimum length of tokens to leave for a response.""",
+            The minimum length (in number of tokens) for a response.""",
 )
 class ConversationalPipeline(Pipeline):
     """
@@ -241,17 +239,13 @@ def __init__(self, *args, **kwargs):
         if self.tokenizer.pad_token_id is None:
             self.tokenizer.pad_token = self.tokenizer.eos_token

-    def _sanitize_parameters(
-        self, min_length_for_response=None, minimum_tokens=None, clean_up_tokenization_spaces=None, **generate_kwargs
-    ):
+    def _sanitize_parameters(self, min_length_for_response=None, clean_up_tokenization_spaces=None, **generate_kwargs):
         preprocess_params = {}
         forward_params = {}
         postprocess_params = {}

         if min_length_for_response is not None:
             preprocess_params["min_length_for_response"] = min_length_for_response
-        if minimum_tokens is not None:
-            forward_params["minimum_tokens"] = minimum_tokens

         if "max_length" in generate_kwargs:
             forward_params["max_length"] = generate_kwargs["max_length"]
@@ -304,7 +298,7 @@ def preprocess(self, conversation: Conversation, min_length_for_response=32) ->
             input_ids = tf.constant([input_ids])
         return {"input_ids": input_ids, "conversation": conversation}

-    def _forward(self, model_inputs, minimum_tokens=10, **generate_kwargs):
+    def _forward(self, model_inputs, **generate_kwargs):
         n = model_inputs["input_ids"].shape[1]
         conversation = model_inputs.pop("conversation")
         if "max_length" not in generate_kwargs and "max_new_tokens" not in generate_kwargs:
4 changes: 2 additions & 2 deletions src/transformers/pipelines/document_question_answering.py
@@ -419,14 +419,14 @@ def preprocess(
"is_last": span_idx == num_spans - 1,
}

def _forward(self, model_inputs):
def _forward(self, model_inputs, **generate_kwargs):
p_mask = model_inputs.pop("p_mask", None)
word_ids = model_inputs.pop("word_ids", None)
words = model_inputs.pop("words", None)
is_last = model_inputs.pop("is_last", False)

if self.model_type == ModelType.VisionEncoderDecoder:
model_outputs = self.model.generate(**model_inputs)
model_outputs = self.model.generate(**model_inputs, **generate_kwargs)
else:
model_outputs = self.model(**model_inputs)

23 changes: 10 additions & 13 deletions src/transformers/pipelines/image_to_text.py
@@ -74,26 +74,25 @@ def __init__(self, *args, **kwargs):
         )

     def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, prompt=None, timeout=None):
-        forward_kwargs = {}
+        forward_params = {}
         preprocess_params = {}

         if prompt is not None:
             preprocess_params["prompt"] = prompt
         if timeout is not None:
             preprocess_params["timeout"] = timeout

-        if generate_kwargs is not None:
-            forward_kwargs["generate_kwargs"] = generate_kwargs
         if max_new_tokens is not None:
-            if "generate_kwargs" not in forward_kwargs:
-                forward_kwargs["generate_kwargs"] = {}
-            if "max_new_tokens" in forward_kwargs["generate_kwargs"]:
+            forward_params["max_new_tokens"] = max_new_tokens
+        if generate_kwargs is not None:
+            if max_new_tokens is not None and "max_new_tokens" in generate_kwargs:
                 raise ValueError(
-                    "'max_new_tokens' is defined twice, once in 'generate_kwargs' and once as a direct parameter,"
-                    " please use only one"
+                    "`max_new_tokens` is defined both as an argument and inside `generate_kwargs` argument, please use"
+                    " only 1 version"
                 )
-            forward_kwargs["generate_kwargs"]["max_new_tokens"] = max_new_tokens
-        return preprocess_params, forward_kwargs, {}
+            forward_params.update(generate_kwargs)
+
+        return preprocess_params, forward_params, {}

     def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs):
         """
@@ -164,7 +163,7 @@ def preprocess(self, image, prompt=None, timeout=None):

         return model_inputs

-    def _forward(self, model_inputs, generate_kwargs=None):
+    def _forward(self, model_inputs, **generate_kwargs):
         # Git model sets `model_inputs["input_ids"] = None` in `preprocess` (when `prompt=None`). In batch model, the
         # pipeline will group them into a list of `None`, which fail `_forward`. Avoid this by checking it first.
         if (
@@ -174,8 +173,6 @@ def _forward(self, model_inputs, generate_kwargs=None):
         ):
             model_inputs["input_ids"] = None

-        if generate_kwargs is None:
-            generate_kwargs = {}
         # FIXME: We need to pop here due to a difference in how `generation.py` and `generation.tf_utils.py`
         # parse inputs. In the Tensorflow version, `generate` raises an error if we don't use `input_ids` whereas
         # the PyTorch version matches it with `self.model.main_input_name` or `self.model.encoder.main_input_name`
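(Editorial illustration of the conflict check in the new image_to_text.py _sanitize_parameters above; the captioning checkpoint and the image path are assumed examples.)

from transformers import pipeline

# Assumed example checkpoint; "cat.png" is a placeholder image path.
captioner = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")

# Either style works on its own:
captioner("cat.png", max_new_tokens=20)
captioner("cat.png", generate_kwargs={"max_new_tokens": 20})

# Defining `max_new_tokens` in both places raises the ValueError shown above.
captioner("cat.png", max_new_tokens=20, generate_kwargs={"max_new_tokens": 20})
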
4 changes: 2 additions & 2 deletions src/transformers/pipelines/table_question_answering.py
@@ -377,7 +377,7 @@ def preprocess(self, pipeline_input, sequential=None, padding=True, truncation=N
         inputs["table"] = table
         return inputs

-    def _forward(self, model_inputs, sequential=False):
+    def _forward(self, model_inputs, sequential=False, **generate_kwargs):
         table = model_inputs.pop("table")

         if self.type == "tapas":
@@ -386,7 +386,7 @@ def _forward(self, model_inputs, sequential=False):
             else:
                 outputs = self.batch_inference(**model_inputs)
         else:
-            outputs = self.model.generate(**model_inputs)
+            outputs = self.model.generate(**model_inputs, **generate_kwargs)
         model_outputs = {"model_inputs": model_inputs, "table": table, "outputs": outputs}
         return model_outputs
