Skip to content

Commit

Permalink
Pipeline: use tokenizer pad token at generation time if the model pad…
Browse files Browse the repository at this point in the history
… token is unset. (huggingface#29614)
  • Loading branch information
gante authored Mar 15, 2024
1 parent c47fcd0 commit 53d8912
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 32 deletions.
9 changes: 3 additions & 6 deletions src/transformers/pipelines/automatic_speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,14 +311,14 @@ def _sanitize_parameters(

forward_params = defaultdict(dict)
if max_new_tokens is not None:
forward_params["generate_kwargs"]["max_new_tokens"] = max_new_tokens
forward_params["max_new_tokens"] = max_new_tokens
if generate_kwargs is not None:
if max_new_tokens is not None and "max_new_tokens" in generate_kwargs:
raise ValueError(
"`max_new_tokens` is defined both as an argument and inside `generate_kwargs` argument, please use"
" only 1 version"
)
forward_params["generate_kwargs"].update(generate_kwargs)
forward_params.update(generate_kwargs)

postprocess_params = {}
if decoder_kwargs is not None:
Expand Down Expand Up @@ -456,10 +456,7 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
processed["stride"] = stride
yield {"is_last": True, **processed, **extra}

def _forward(self, model_inputs, return_timestamps=False, generate_kwargs=None):
if generate_kwargs is None:
generate_kwargs = {}

def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs):
attention_mask = model_inputs.pop("attention_mask", None)
stride = model_inputs.pop("stride", None)
is_last = model_inputs.pop("is_last")
Expand Down
10 changes: 10 additions & 0 deletions src/transformers/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,16 @@ def __init__(
self._num_workers = kwargs.pop("num_workers", None)
self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)

# Pipelines calling `generate`: if the tokenizer has a pad token but the model doesn't, set it in the
# forward params so that `generate` is aware of the pad token.
if (
self.tokenizer is not None
and self.model.can_generate()
and self.tokenizer.pad_token_id is not None
and self.model.generation_config.pad_token_id is None
):
self._forward_params["pad_token_id"] = self.tokenizer.pad_token_id

if self.image_processor is None and self.feature_extractor is not None:
if isinstance(self.feature_extractor, BaseImageProcessor):
# Backward compatible change, if users called
Expand Down
12 changes: 3 additions & 9 deletions src/transformers/pipelines/conversational.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,7 @@ def new_user_input(self):
build_pipeline_init_args(has_tokenizer=True),
r"""
min_length_for_response (`int`, *optional*, defaults to 32):
The minimum length (in number of tokens) for a response.
minimum_tokens (`int`, *optional*, defaults to 10):
The minimum length of tokens to leave for a response.""",
The minimum length (in number of tokens) for a response.""",
)
class ConversationalPipeline(Pipeline):
"""
Expand Down Expand Up @@ -241,17 +239,13 @@ def __init__(self, *args, **kwargs):
if self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token = self.tokenizer.eos_token

def _sanitize_parameters(
self, min_length_for_response=None, minimum_tokens=None, clean_up_tokenization_spaces=None, **generate_kwargs
):
def _sanitize_parameters(self, min_length_for_response=None, clean_up_tokenization_spaces=None, **generate_kwargs):
preprocess_params = {}
forward_params = {}
postprocess_params = {}

if min_length_for_response is not None:
preprocess_params["min_length_for_response"] = min_length_for_response
if minimum_tokens is not None:
forward_params["minimum_tokens"] = minimum_tokens

if "max_length" in generate_kwargs:
forward_params["max_length"] = generate_kwargs["max_length"]
Expand Down Expand Up @@ -304,7 +298,7 @@ def preprocess(self, conversation: Conversation, min_length_for_response=32) ->
input_ids = tf.constant([input_ids])
return {"input_ids": input_ids, "conversation": conversation}

def _forward(self, model_inputs, minimum_tokens=10, **generate_kwargs):
def _forward(self, model_inputs, **generate_kwargs):
n = model_inputs["input_ids"].shape[1]
conversation = model_inputs.pop("conversation")
if "max_length" not in generate_kwargs and "max_new_tokens" not in generate_kwargs:
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/pipelines/document_question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,14 +419,14 @@ def preprocess(
"is_last": span_idx == num_spans - 1,
}

def _forward(self, model_inputs):
def _forward(self, model_inputs, **generate_kwargs):
p_mask = model_inputs.pop("p_mask", None)
word_ids = model_inputs.pop("word_ids", None)
words = model_inputs.pop("words", None)
is_last = model_inputs.pop("is_last", False)

if self.model_type == ModelType.VisionEncoderDecoder:
model_outputs = self.model.generate(**model_inputs)
model_outputs = self.model.generate(**model_inputs, **generate_kwargs)
else:
model_outputs = self.model(**model_inputs)

Expand Down
23 changes: 10 additions & 13 deletions src/transformers/pipelines/image_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,26 +74,25 @@ def __init__(self, *args, **kwargs):
)

def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, prompt=None, timeout=None):
forward_kwargs = {}
forward_params = {}
preprocess_params = {}

if prompt is not None:
preprocess_params["prompt"] = prompt
if timeout is not None:
preprocess_params["timeout"] = timeout

if generate_kwargs is not None:
forward_kwargs["generate_kwargs"] = generate_kwargs
if max_new_tokens is not None:
if "generate_kwargs" not in forward_kwargs:
forward_kwargs["generate_kwargs"] = {}
if "max_new_tokens" in forward_kwargs["generate_kwargs"]:
forward_params["max_new_tokens"] = max_new_tokens
if generate_kwargs is not None:
if max_new_tokens is not None and "max_new_tokens" in generate_kwargs:
raise ValueError(
"'max_new_tokens' is defined twice, once in 'generate_kwargs' and once as a direct parameter,"
" please use only one"
"`max_new_tokens` is defined both as an argument and inside `generate_kwargs` argument, please use"
" only 1 version"
)
forward_kwargs["generate_kwargs"]["max_new_tokens"] = max_new_tokens
return preprocess_params, forward_kwargs, {}
forward_params.update(generate_kwargs)

return preprocess_params, forward_params, {}

def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs):
"""
Expand Down Expand Up @@ -164,7 +163,7 @@ def preprocess(self, image, prompt=None, timeout=None):

return model_inputs

def _forward(self, model_inputs, generate_kwargs=None):
def _forward(self, model_inputs, **generate_kwargs):
# Git model sets `model_inputs["input_ids"] = None` in `preprocess` (when `prompt=None`). In batch model, the
# pipeline will group them into a list of `None`, which fail `_forward`. Avoid this by checking it first.
if (
Expand All @@ -174,8 +173,6 @@ def _forward(self, model_inputs, generate_kwargs=None):
):
model_inputs["input_ids"] = None

if generate_kwargs is None:
generate_kwargs = {}
# FIXME: We need to pop here due to a difference in how `generation.py` and `generation.tf_utils.py`
# parse inputs. In the Tensorflow version, `generate` raises an error if we don't use `input_ids` whereas
# the PyTorch version matches it with `self.model.main_input_name` or `self.model.encoder.main_input_name`
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/pipelines/table_question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ def preprocess(self, pipeline_input, sequential=None, padding=True, truncation=N
inputs["table"] = table
return inputs

def _forward(self, model_inputs, sequential=False):
def _forward(self, model_inputs, sequential=False, **generate_kwargs):
table = model_inputs.pop("table")

if self.type == "tapas":
Expand All @@ -385,7 +385,7 @@ def _forward(self, model_inputs, sequential=False):
else:
outputs = self.batch_inference(**model_inputs)
else:
outputs = self.model.generate(**model_inputs)
outputs = self.model.generate(**model_inputs, **generate_kwargs)
model_outputs = {"model_inputs": model_inputs, "table": table, "outputs": outputs}
return model_outputs

Expand Down

0 comments on commit 53d8912

Please sign in to comment.