Generate: text generation pipeline no longer emits max_length warning when it is not set (huggingface#23139)
gante authored and novice03 committed Jun 23, 2023
1 parent 809638e commit e3f058f
Showing 5 changed files with 56 additions and 14 deletions.
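For context, a minimal sketch of the fixed behavior, assuming the tiny test model used in the test suite below:

from transformers import pipeline

generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2")

# Before this commit, the pipeline always injected a max_length into
# generate_kwargs, so generate() saw both max_length and max_new_tokens and
# warned even when the user never set max_length. After it, only user-supplied
# length settings reach generate():
generator("Hello world", max_new_tokens=1)  # no warning anymore

# The warning still fires when both limits are genuinely set by the user:
generator("Hello world", max_length=10, max_new_tokens=1)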
2 changes: 1 addition & 1 deletion src/transformers/generation/flax_utils.py
@@ -385,14 +385,14 @@ def generate(
                 UserWarning,
             )
         elif generation_config.max_new_tokens is not None:
-            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
             if not has_default_max_length:
                 logger.warning(
                     f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                     f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                     "Please refer to the documentation for more information. "
                     "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                 )
+            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
 
         if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
             raise ValueError(
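The same one-line move is applied to all three frameworks (Flax above, TensorFlow and PyTorch below). Its visible effect is in the warning itself: the message interpolates generation_config.max_length, so it now reports the value the user passed instead of the value just recomputed from max_new_tokens. A small sketch with hypothetical numbers:

# Hypothetical values: a 3-token prompt, user passes max_length=10, max_new_tokens=1.
max_length, max_new_tokens, input_ids_seq_length = 10, 1, 3

# Old order: the assignment ran before the warning, so the warning reported
# max_length=4, a value the user never set. New order: the warning reports the
# user's max_length=10, and only then is the limit recomputed:
max_length = max_new_tokens + input_ids_seq_length
print(max_length)  # 4, since max_new_tokens takes precedence either way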
2 changes: 1 addition & 1 deletion src/transformers/generation/tf_utils.py
@@ -858,14 +858,14 @@ def generate(
                 UserWarning,
             )
         elif generation_config.max_new_tokens is not None:
-            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
             if not has_default_max_length:
                 logger.warning(
                     f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                     f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                     "Please refer to the documentation for more information. "
                     "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                 )
+            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
 
         # If the input length is a tensor (i.e. dynamic length), skip length checks
         if not isinstance(input_ids_seq_length, tf.Tensor):
2 changes: 1 addition & 1 deletion src/transformers/generation/utils.py
@@ -1348,14 +1348,14 @@ def generate(
                 UserWarning,
             )
         elif generation_config.max_new_tokens is not None:
-            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
             if not has_default_max_length:
                 logger.warning(
                     f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                     f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                     "Please refer to the documentation for more information. "
                     "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                 )
+            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
 
         if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
             raise ValueError(
32 changes: 22 additions & 10 deletions src/transformers/pipelines/text_generation.py
@@ -1,3 +1,4 @@
+import copy
 import enum
 import warnings

@@ -105,17 +106,8 @@ def _sanitize_parameters(
             prefix_inputs = self.tokenizer(
                 prefix, padding=False, add_special_tokens=False, return_tensors=self.framework
             )
-            prefix_length = prefix_inputs["input_ids"].shape[-1]
+            generate_kwargs["prefix_length"] = prefix_inputs["input_ids"].shape[-1]
 
-            if "max_new_tokens" in generate_kwargs:
-                pass
-            elif "max_length" in generate_kwargs:
-                generate_kwargs["max_length"] += prefix_length
-            else:
-                generate_kwargs["max_length"] = self.model.config.max_length + prefix_length
-
-            if "min_length" in generate_kwargs:
-                generate_kwargs["min_length"] += prefix_length
         if handle_long_generation is not None:
             if handle_long_generation not in {"hole"}:
                 raise ValueError(
@@ -247,6 +239,26 @@ def _forward(self, model_inputs, **generate_kwargs):
         else:
             in_b = input_ids.shape[0]
         prompt_text = model_inputs.pop("prompt_text")
+
+        # If there is a prefix, we may need to adjust the generation length. Do so without permanently modifying
+        # generate_kwargs, as some of the parameterization may come from the initialization of the pipeline.
+        generate_kwargs = copy.deepcopy(generate_kwargs)
+        prefix_length = generate_kwargs.pop("prefix_length", 0)
+        if prefix_length > 0:
+            has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
+                "generation_config" in generate_kwargs
+                and generate_kwargs["generation_config"].max_new_tokens is not None
+            )
+            if not has_max_new_tokens:
+                generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
+                generate_kwargs["max_length"] += prefix_length
+            has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
+                "generation_config" in generate_kwargs
+                and generate_kwargs["generation_config"].min_new_tokens is not None
+            )
+            if not has_min_new_tokens and "min_length" in generate_kwargs:
+                generate_kwargs["min_length"] += prefix_length
+
         # BS x SL
         generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
         out_b = generated_sequence.shape[0]
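A hedged sketch of the resulting prefix handling, with an illustrative prefix string: _sanitize_parameters now only records the tokenized prefix length, and _forward widens the length settings on a per-call deep copy, so generate_kwargs set at pipeline construction are never permanently mutated:

from transformers import pipeline

generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2")

# max_length counts the prefix tokens, so _forward adds prefix_length to it
# (on a deepcopy of generate_kwargs, leaving the pipeline's own kwargs intact).
generator("Hello world", prefix="My prefix: ", max_length=20)

# max_new_tokens already excludes the prompt and prefix, so no adjustment is
# made and max_length stays unset, which is what silences the old warning.
generator("Hello world", prefix="My prefix: ", max_new_tokens=5)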
32 changes: 31 additions & 1 deletion tests/pipelines/test_pipelines_text_generation.py
@@ -14,8 +14,15 @@

 import unittest
 
-from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING, TextGenerationPipeline, pipeline
+from transformers import (
+    MODEL_FOR_CAUSAL_LM_MAPPING,
+    TF_MODEL_FOR_CAUSAL_LM_MAPPING,
+    TextGenerationPipeline,
+    logging,
+    pipeline,
+)
 from transformers.testing_utils import (
+    CaptureLogger,
     is_pipeline_test,
     require_accelerate,
     require_tf,
@@ -323,3 +330,26 @@ def test_pipeline_accelerate_top_p(self):

         pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto", torch_dtype=torch.float16)
         pipe("This is a test", do_sample=True, top_p=0.5)
+
+    def test_pipeline_length_setting_warning(self):
+        prompt = """Hello world"""
+        text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2")
+        if text_generator.model.framework == "tf":
+            logger = logging.get_logger("transformers.generation.tf_utils")
+        else:
+            logger = logging.get_logger("transformers.generation.utils")
+        logger_msg = "Both `max_new_tokens`"  # The beginning of the message to be checked in this test
+
+        # Both are set by the user -> log warning
+        with CaptureLogger(logger) as cl:
+            _ = text_generator(prompt, max_length=10, max_new_tokens=1)
+        self.assertIn(logger_msg, cl.out)
+
+        # The user only sets one -> no warning
+        with CaptureLogger(logger) as cl:
+            _ = text_generator(prompt, max_new_tokens=1)
+        self.assertNotIn(logger_msg, cl.out)
+
+        with CaptureLogger(logger) as cl:
+            _ = text_generator(prompt, max_length=10)
+        self.assertNotIn(logger_msg, cl.out)
