Skip to content

Commit

Permalink
Pipeline: no side-effects on model.config and `model.generation_con…
Browse files Browse the repository at this point in the history
…fig` 🔫 (huggingface#33480)
  • Loading branch information
gante authored and itazap committed Sep 20, 2024
1 parent 787b83a commit 7ca9681
Show file tree
Hide file tree
Showing 13 changed files with 132 additions and 30 deletions.
4 changes: 4 additions & 0 deletions src/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1229,6 +1229,10 @@ def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig"
"""
config_dict = model_config.to_dict()
config_dict.pop("_from_model_config", None)

# Removes all `None` from the model config dict -- this lets the generation config defaults to take hold
config_dict = {key: value for key, value in config_dict.items() if value is not None}

generation_config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True)

# Special case: some models have generation attributes set in the decoder. Use them if still unset in the
Expand Down
13 changes: 8 additions & 5 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1334,23 +1334,26 @@ def _prepare_generation_config(
# the following conditions must be met
# 1) the generation config must have been created from the model config (`_from_model_config` field);
# 2) the generation config must have seen no modification since its creation (the hash is the same);
# 3) the user must have set generation parameters in the model config.
# 3) there are non-default generation parameters in the model config.
# 4) the user must have set new generation parameters in the model config.
# NOTE: `torch.compile` can't compile `hash`, this legacy support is disabled with compilation.
if (
not is_torchdynamo_compiling()
and self.generation_config._from_model_config # 1)
and self.generation_config._original_object_hash == hash(self.generation_config) # 2)
and len(self.config._get_non_default_generation_parameters()) > 0 # 3)
):
new_generation_config = GenerationConfig.from_model_config(self.config)
if new_generation_config != self.generation_config: # 3)
if new_generation_config != self.generation_config: # 4)
warnings.warn(
"You have modified the pretrained model configuration to control generation. This is a"
" deprecated strategy to control generation and will be removed soon, in a future version."
" deprecated strategy to control generation and will be removed in v5."
" Please use and modify the model generation configuration (see"
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )",
UserWarning,
)
self.generation_config = new_generation_config
using_model_generation_config = True

generation_config = self.generation_config
using_model_generation_config = True

Expand Down
4 changes: 4 additions & 0 deletions src/transformers/pipelines/automatic_speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,10 @@ def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs):
else:
generate_kwargs["num_frames"] = num_frames

# User-defined `generation_config` passed to the pipeline call take precedence
if "generation_config" not in generate_kwargs:
generate_kwargs["generation_config"] = self.generation_config

tokens = self.model.generate(
inputs=inputs,
attention_mask=attention_mask,
Expand Down
37 changes: 21 additions & 16 deletions src/transformers/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import copy
import csv
import importlib
import json
Expand Down Expand Up @@ -899,22 +900,26 @@ def __init__(
):
self.model.to(self.device)

# Update config and generation_config with task specific parameters
task_specific_params = self.model.config.task_specific_params
if task_specific_params is not None and task in task_specific_params:
self.model.config.update(task_specific_params.get(task))
if self.model.can_generate():
self.model.generation_config.update(**task_specific_params.get(task))

# Pipelines calling `generate`: if the tokenizer has a pad token but the model doesn't, set it in the
# forward params so that `generate` is aware of the pad token.
if (
self.tokenizer is not None
and self.model.can_generate()
and self.tokenizer.pad_token_id is not None
and self.model.generation_config.pad_token_id is None
):
self.model.generation_config.pad_token_id = self.tokenizer.pad_token_id
# If the model can generate, create a local generation config. This is done to avoid side-effects on the model
# as we apply local tweaks to the generation config.
if self.model.can_generate():
self.prefix = self.model.config.prefix if hasattr(self.model.config, "prefix") else None
self.generation_config = copy.deepcopy(self.model.generation_config)
# Update the generation config with task specific params if they exist
# NOTE: `prefix` is pipeline-specific and doesn't exist in the generation config.
task_specific_params = self.model.config.task_specific_params
if task_specific_params is not None and task in task_specific_params:
this_task_params = task_specific_params.get(task)
if "prefix" in this_task_params:
self.prefix = this_task_params.pop("prefix")
self.generation_config.update(**this_task_params)
# If the tokenizer has a pad token but the model doesn't, set it so that `generate` is aware of it.
if (
self.tokenizer is not None
and self.tokenizer.pad_token_id is not None
and self.generation_config.pad_token_id is None
):
self.generation_config.pad_token_id = self.tokenizer.pad_token_id

self.call_count = 0
self._batch_size = kwargs.pop("batch_size", None)
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/pipelines/document_question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,10 @@ def _forward(self, model_inputs, **generate_kwargs):
is_last = model_inputs.pop("is_last", False)

if self.model_type == ModelType.VisionEncoderDecoder:
# User-defined `generation_config` passed to the pipeline call take precedence
if "generation_config" not in generate_kwargs:
generate_kwargs["generation_config"] = self.generation_config

model_outputs = self.model.generate(**model_inputs, **generate_kwargs)
else:
model_outputs = self.model(**model_inputs)
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/pipelines/image_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,10 @@ def _forward(self, model_inputs, **generate_kwargs):
):
model_inputs["input_ids"] = None

# User-defined `generation_config` passed to the pipeline call take precedence
if "generation_config" not in generate_kwargs:
generate_kwargs["generation_config"] = self.generation_config

# FIXME: We need to pop here due to a difference in how `generation.py` and `generation.tf_utils.py`
# parse inputs. In the Tensorflow version, `generate` raises an error if we don't use `input_ids` whereas
# the PyTorch version matches it with `self.model.main_input_name` or `self.model.encoder.main_input_name`
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/pipelines/table_question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,10 @@ def _forward(self, model_inputs, sequential=False, **generate_kwargs):
else:
outputs = self.batch_inference(**model_inputs)
else:
# User-defined `generation_config` passed to the pipeline call take precedence
if "generation_config" not in generate_kwargs:
generate_kwargs["generation_config"] = self.generation_config

outputs = self.model.generate(**model_inputs, **generate_kwargs)
model_outputs = {"model_inputs": model_inputs, "table": table, "outputs": outputs}
return model_outputs
Expand Down
11 changes: 8 additions & 3 deletions src/transformers/pipelines/text2text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def check_inputs(self, input_length: int, min_length: int, max_length: int):
return True

def _parse_and_tokenize(self, *args, truncation):
prefix = self.model.config.prefix if self.model.config.prefix is not None else ""
prefix = self.prefix if self.prefix is not None else ""
if isinstance(args[0], list):
if self.tokenizer.pad_token_id is None:
raise ValueError("Please make sure that the tokenizer has a pad_token_id when using a batch input")
Expand Down Expand Up @@ -185,9 +185,14 @@ def _forward(self, model_inputs, **generate_kwargs):

self.check_inputs(
input_length,
generate_kwargs.get("min_length", self.model.config.min_length),
generate_kwargs.get("max_length", self.model.config.max_length),
generate_kwargs.get("min_length", self.generation_config.min_length),
generate_kwargs.get("max_length", self.generation_config.max_length),
)

# User-defined `generation_config` passed to the pipeline call take precedence
if "generation_config" not in generate_kwargs:
generate_kwargs["generation_config"] = self.generation_config

output_ids = self.model.generate(**model_inputs, **generate_kwargs)
out_b = output_ids.shape[0]
if self.framework == "pt":
Expand Down
13 changes: 8 additions & 5 deletions src/transformers/pipelines/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ def __init__(self, *args, **kwargs):
# It also defines both some preprocess_kwargs and generate_kwargs
# which is why we cannot put them in their respective methods.
prefix = None
if self.model.config.prefix is not None:
prefix = self.model.config.prefix
if self.prefix is not None:
prefix = self.prefix
if prefix is None and self.model.__class__.__name__ in [
"XLNetLMHeadModel",
"TransfoXLLMHeadModel",
Expand Down Expand Up @@ -316,7 +316,7 @@ def preprocess(
if "max_new_tokens" in generate_kwargs:
new_tokens = generate_kwargs["max_new_tokens"]
else:
new_tokens = generate_kwargs.get("max_length", self.model.config.max_length) - cur_len
new_tokens = generate_kwargs.get("max_length", self.generation_config.max_length) - cur_len
if new_tokens < 0:
raise ValueError("We cannot infer how many new tokens are expected")
if cur_len + new_tokens > self.tokenizer.model_max_length:
Expand Down Expand Up @@ -354,7 +354,7 @@ def _forward(self, model_inputs, **generate_kwargs):
and generate_kwargs["generation_config"].max_new_tokens is not None
)
if not has_max_new_tokens:
generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.generation_config.max_length
generate_kwargs["max_length"] += prefix_length
has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
"generation_config" in generate_kwargs
Expand All @@ -363,7 +363,10 @@ def _forward(self, model_inputs, **generate_kwargs):
if not has_min_new_tokens and "min_length" in generate_kwargs:
generate_kwargs["min_length"] += prefix_length

# BS x SL
# User-defined `generation_config` passed to the pipeline call take precedence
if "generation_config" not in generate_kwargs:
generate_kwargs["generation_config"] = self.generation_config

generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
out_b = generated_sequence.shape[0]
if self.framework == "pt":
Expand Down
6 changes: 5 additions & 1 deletion src/transformers/pipelines/text_to_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def preprocess(self, text, **kwargs):
if self.model.config.model_type == "bark":
# bark Tokenizer is called with BarkProcessor which uses those kwargs
new_kwargs = {
"max_length": self.model.generation_config.semantic_config.get("max_input_semantic_length", 256),
"max_length": self.generation_config.semantic_config.get("max_input_semantic_length", 256),
"add_special_tokens": False,
"return_attention_mask": True,
"return_token_type_ids": False,
Expand All @@ -137,6 +137,10 @@ def _forward(self, model_inputs, **kwargs):
# we expect some kwargs to be additional tensors which need to be on the right device
generate_kwargs = self._ensure_tensor_on_device(generate_kwargs, device=self.device)

# User-defined `generation_config` passed to the pipeline call take precedence
if "generation_config" not in generate_kwargs:
generate_kwargs["generation_config"] = self.generation_config

# generate_kwargs get priority over forward_params
forward_params.update(generate_kwargs)

Expand Down
4 changes: 4 additions & 0 deletions src/transformers/pipelines/visual_question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ def preprocess(self, inputs, padding=False, truncation=False, timeout=None):

def _forward(self, model_inputs, **generate_kwargs):
if self.model.can_generate():
# User-defined `generation_config` passed to the pipeline call take precedence
if "generation_config" not in generate_kwargs:
generate_kwargs["generation_config"] = self.generation_config

model_outputs = self.model.generate(**model_inputs, **generate_kwargs)
else:
model_outputs = self.model(**model_inputs)
Expand Down
26 changes: 26 additions & 0 deletions tests/pipelines/test_pipelines_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
AutoTokenizer,
DistilBertForSequenceClassification,
MaskGenerationPipeline,
T5ForConditionalGeneration,
TextClassificationPipeline,
TextGenerationPipeline,
TFAutoModelForSequenceClassification,
Expand Down Expand Up @@ -234,6 +235,31 @@ def test_auto_model_pipeline_registration_from_local_dir(self):

self.assertIsInstance(pipe, TextGenerationPipeline) # Assert successful load

@require_torch
def test_pipeline_with_task_parameters_no_side_effects(self):
"""
Regression test: certain pipeline flags, like `task`, modified the model configuration, causing unexpected
side-effects
"""
# This checkpoint has task-specific parameters that will modify the behavior of the pipeline
model = T5ForConditionalGeneration.from_pretrained("t5-small")
self.assertTrue(model.config.num_beams == 1)

# The task-specific parameters used to cause side-effects on `model.config` -- not anymore
pipe = pipeline(model=model, tokenizer=AutoTokenizer.from_pretrained("t5-small"), task="translation_en_to_de")
self.assertTrue(model.config.num_beams == 1)
self.assertTrue(model.generation_config.num_beams == 1)

# Under the hood: we now store a generation config in the pipeline. This generation config stores the
# task-specific paremeters.
self.assertTrue(pipe.generation_config.num_beams == 4)

# We can confirm that the task-specific parameters have an effect. (In this case, the default is `num_beams=1`,
# which would crash when `num_return_sequences=4` is passed.)
pipe("Hugging Face doesn't sell hugs.", num_return_sequences=4)
with self.assertRaises(ValueError):
pipe("Hugging Face doesn't sell hugs.", num_return_sequences=4, num_beams=1)


@is_pipeline_test
class PipelineScikitCompatTest(unittest.TestCase):
Expand Down
32 changes: 32 additions & 0 deletions tests/utils/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1715,6 +1715,38 @@ def test_isin_mps_friendly(self):
torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
)

def test_save_and_load_config_with_custom_generation(self):
"""
Regression test for the ability to save and load a config with a custom generation kwarg (i.e. a parameter
that gets moved to the generation config and reset on the model config)
"""
model = T5ForConditionalGeneration.from_pretrained(TINY_T5)

# The default for `num_beams` is 1 and `early_stopping` is False
self.assertTrue(model.config.num_beams == 1)
self.assertTrue(model.config.early_stopping is False)

# When we save the model, this custom parameter should be moved to the generation config AND the model
# config should contain `None`
model.config.num_beams = 2
model.config.early_stopping = True
self.assertTrue(model.generation_config.num_beams == 1) # unmodified generation config
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
new_model = T5ForConditionalGeneration.from_pretrained(tmp_dir)
# moved to generation config
self.assertTrue(new_model.generation_config.num_beams == 2)
self.assertTrue(new_model.generation_config.early_stopping is True)
# reset in the model config
self.assertTrue(new_model.config.num_beams is None)
self.assertTrue(new_model.config.early_stopping is None)

# Sanity check: We can run `generate` with the new model without any warnings
random_ids = torch.randint(0, 100, (1, 5))
with warnings.catch_warnings(record=True) as w:
new_model.generate(random_ids, max_new_tokens=3)
self.assertTrue(len(w) == 0)


@slow
@require_torch
Expand Down

0 comments on commit 7ca9681

Please sign in to comment.