Pipeline: use tokenizer pad token at generation time if the model pad token is unset. #29614

Merged · 6 commits · Mar 15, 2024
Changes from 1 commit
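All affected pipelines gain the same guard: if `self.tokenizer.pad_token_id` is set while `self.model.generation_config.pad_token_id` is not, the tokenizer's value is forwarded to `generate`, so generation no longer runs with an unset pad token and emits a warning. A minimal standalone sketch of the behaviour this enables (the `gpt2` checkpoint, prompt, and `max_new_tokens` value here are illustrative, not part of the PR):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# GPT-2 ships without a pad token; give the tokenizer one, but leave the model unset.
tokenizer.pad_token_id = tokenizer.eos_token_id
assert model.generation_config.pad_token_id is None

# The check each pipeline now performs before calling `generate`
generate_kwargs = {}
if tokenizer.pad_token_id is not None and model.generation_config.pad_token_id is None:
    generate_kwargs["pad_token_id"] = tokenizer.pad_token_id

inputs = tokenizer("The capital of France ", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=10, **generate_kwargs)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```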
4 changes: 4 additions & 0 deletions src/transformers/pipelines/automatic_speech_recognition.py
@@ -496,6 +496,10 @@ def _forward(self, model_inputs, return_timestamps=False, generate_kwargs=None):
         else:
             generate_kwargs["encoder_outputs"] = encoder(inputs, attention_mask=attention_mask)
 
+        # If the tokenizer has a pad token but the model doesn't, we add it to the `generate` call
+        if self.tokenizer.pad_token_id is not None and self.model.generation_config.pad_token_id is None:
+            generate_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
+
         tokens = self.model.generate(
             attention_mask=attention_mask,
             **generate_kwargs,
5 changes: 5 additions & 0 deletions src/transformers/pipelines/conversational.py
@@ -309,6 +309,11 @@ def _forward(self, model_inputs, minimum_tokens=10, **generate_kwargs):
         conversation = model_inputs.pop("conversation")
         if "max_length" not in generate_kwargs and "max_new_tokens" not in generate_kwargs:
             generate_kwargs["max_new_tokens"] = 256
+
+        # If the tokenizer has a pad token but the model doesn't, we add it to the `generate` call
+        if self.tokenizer.pad_token_id is not None and self.model.generation_config.pad_token_id is None:
+            generate_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
+
         output_ids = self.model.generate(**model_inputs, **generate_kwargs)
         if self.model.config.is_encoder_decoder:
             start_position = 1
8 changes: 7 additions & 1 deletion src/transformers/pipelines/document_question_answering.py
@@ -426,7 +426,13 @@ def _forward(self, model_inputs):
         is_last = model_inputs.pop("is_last", False)
 
         if self.model_type == ModelType.VisionEncoderDecoder:
-            model_outputs = self.model.generate(**model_inputs)
+            generate_kwargs = {}
+
+            # If the tokenizer has a pad token but the model doesn't, we add it to the `generate` call
+            if self.tokenizer.pad_token_id is not None and self.model.generation_config.pad_token_id is None:
+                generate_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
+
+            model_outputs = self.model.generate(**model_inputs, **generate_kwargs)
         else:
             model_outputs = self.model(**model_inputs)
 
5 changes: 5 additions & 0 deletions src/transformers/pipelines/image_to_text.py
@@ -181,6 +181,11 @@ def _forward(self, model_inputs, generate_kwargs=None):
         # the PyTorch version matches it with `self.model.main_input_name` or `self.model.encoder.main_input_name`
         # in the `_prepare_model_inputs` method.
         inputs = model_inputs.pop(self.model.main_input_name)
+
+        # If the tokenizer has a pad token but the model doesn't, we add it to the `generate` call
+        if self.tokenizer.pad_token_id is not None and self.model.generation_config.pad_token_id is None:
+            generate_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
+
         model_outputs = self.model.generate(inputs, **model_inputs, **generate_kwargs)
         return model_outputs
 
8 changes: 7 additions & 1 deletion src/transformers/pipelines/table_question_answering.py
@@ -386,7 +386,13 @@ def _forward(self, model_inputs, sequential=False):
             else:
                 outputs = self.batch_inference(**model_inputs)
         else:
-            outputs = self.model.generate(**model_inputs)
+            generate_kwargs = {}
+
+            # If the tokenizer has a pad token but the model doesn't, we add it to the `generate` call
+            if self.tokenizer.pad_token_id is not None and self.model.generation_config.pad_token_id is None:
+                generate_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
+
+            outputs = self.model.generate(**model_inputs, **generate_kwargs)
         model_outputs = {"model_inputs": model_inputs, "table": table, "outputs": outputs}
         return model_outputs
 
5 changes: 5 additions & 0 deletions src/transformers/pipelines/text2text_generation.py
@@ -188,6 +188,11 @@ def _forward(self, model_inputs, **generate_kwargs):
             generate_kwargs.get("min_length", self.model.config.min_length),
             generate_kwargs.get("max_length", self.model.config.max_length),
         )
+
+        # If the tokenizer has a pad token but the model doesn't, we add it to the `generate` call
+        if self.tokenizer.pad_token_id is not None and self.model.generation_config.pad_token_id is None:
+            generate_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
+
         output_ids = self.model.generate(**model_inputs, **generate_kwargs)
         out_b = output_ids.shape[0]
         if self.framework == "pt":
4 changes: 4 additions & 0 deletions src/transformers/pipelines/text_generation.py
@@ -323,6 +323,10 @@ def _forward(self, model_inputs, **generate_kwargs):
         if not has_min_new_tokens and "min_length" in generate_kwargs:
             generate_kwargs["min_length"] += prefix_length
 
+        # If the tokenizer has a pad token but the model doesn't, we add it to the `generate` call
+        if self.tokenizer.pad_token_id is not None and self.model.generation_config.pad_token_id is None:
+            generate_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
+
         # BS x SL
         generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
         out_b = generated_sequence.shape[0]
4 changes: 4 additions & 0 deletions src/transformers/pipelines/text_to_audio.py
@@ -140,6 +140,10 @@ def _forward(self, model_inputs, **kwargs):
             # generate_kwargs get priority over forward_params
             forward_params.update(generate_kwargs)
 
+            # If the tokenizer has a pad token but the model doesn't, we add it to the `generate` call
+            if self.tokenizer.pad_token_id is not None and self.model.generation_config.pad_token_id is None:
+                forward_params["pad_token_id"] = self.tokenizer.pad_token_id
+
             output = self.model.generate(**model_inputs, **forward_params)
         else:
             if len(generate_kwargs):
4 changes: 4 additions & 0 deletions src/transformers/pipelines/visual_question_answering.py
@@ -125,6 +125,10 @@ def preprocess(self, inputs, padding=False, truncation=False, timeout=None):
 
     def _forward(self, model_inputs, **generate_kwargs):
         if self.model.can_generate():
+            # If the tokenizer has a pad token but the model doesn't, we add it to the `generate` call
+            if self.tokenizer.pad_token_id is not None and self.model.generation_config.pad_token_id is None:
+                generate_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
+
             model_outputs = self.model.generate(**model_inputs, **generate_kwargs)
         else:
             model_outputs = self.model(**model_inputs)
30 changes: 24 additions & 6 deletions tests/pipelines/test_pipelines_text_generation.py
@@ -17,7 +17,9 @@
 from transformers import (
     MODEL_FOR_CAUSAL_LM_MAPPING,
     TF_MODEL_FOR_CAUSAL_LM_MAPPING,
+    AutoTokenizer,
     TextGenerationPipeline,
+    is_torch_available,
     logging,
     pipeline,
 )
@@ -36,6 +38,12 @@
 from .test_pipelines_common import ANY
 
 
+if is_torch_available():
+    import torch
+
+    from transformers import AutoModelForCausalLM
+
+
 @is_pipeline_test
 @require_torch_or_tf
 class TextGenerationPipelineTests(unittest.TestCase):
@@ -379,8 +387,6 @@ def run_pipeline_test(self, text_generator, _):
     @require_accelerate
     @require_torch_gpu
     def test_small_model_pt_bloom_accelerate(self):
-        import torch
-
         # Classic `model_kwargs`
         pipe = pipeline(
             model="hf-internal-testing/tiny-random-bloom",
@@ -435,8 +441,6 @@ def test_small_model_pt_bloom_accelerate(self):
     @require_torch
     @require_torch_accelerator
     def test_small_model_fp16(self):
-        import torch
-
         pipe = pipeline(
             model="hf-internal-testing/tiny-random-bloom",
             device=torch_device,
@@ -448,8 +452,6 @@ def test_small_model_fp16(self):
     @require_accelerate
     @require_torch_accelerator
     def test_pipeline_accelerate_top_p(self):
-        import torch
-
         pipe = pipeline(
             model="hf-internal-testing/tiny-random-bloom", device_map=torch_device, torch_dtype=torch.float16
         )
@@ -477,3 +479,19 @@ def test_pipeline_length_setting_warning(self):
         with CaptureLogger(logger) as cl:
             _ = text_generator(prompt, max_length=10)
         self.assertNotIn(logger_msg, cl.out)
+
+    @require_torch
+    def test_pipeline_tokenizer_has_pad_but_model_doesnt(self):
+        # When the tokenizer pad_token_id is set but the model pad_token_id is not, we pass the pad_token_id to
+        # `generate`. This prevents a warning from being raised, which this test checks.
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model.config.pad_token_id = None
+        model.generation_config.pad_token_id = None
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+
+        llm = pipeline(task="text-generation", model=model, tokenizer=tokenizer, framework="pt")
+        with self.assertRaises(AssertionError) as exc:
+            with self.assertLogs("transformers", level="WARNING"):
+                llm("The capital of France ")
+        self.assertIn("no logs of level WARNING or higher triggered", str(exc.exception))
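For reference, the assertion pattern in the new test relies on `unittest.TestCase.assertLogs` raising an AssertionError when nothing at the requested level is logged, so wrapping it in `assertRaises(AssertionError)` asserts the absence of warnings. A minimal, PR-independent illustration of that idiom (the logger name and test class are arbitrary examples):

```python
import logging
import unittest


class NoWarningIdiomTest(unittest.TestCase):
    def test_no_warning_emitted(self):
        # assertLogs fails with AssertionError if no WARNING-or-higher record is
        # emitted on the named logger, so the outer assertRaises passes only when
        # the code under test stays silent at WARNING level.
        with self.assertRaises(AssertionError) as exc:
            with self.assertLogs("example.logger", level="WARNING"):
                logging.getLogger("example.logger").info("quiet code path")
        self.assertIn("no logs of level WARNING or higher triggered", str(exc.exception))


if __name__ == "__main__":
    unittest.main()
```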