From 312f91862e2bb072d3e5a9a79f79cf4e9171e89f Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 18 Sep 2024 15:43:06 +0100 Subject: [PATCH] =?UTF-8?q?Pipeline:=20no=20side-effects=20on=20`model.con?= =?UTF-8?q?fig`=20and=20`model.generation=5Fconfig`=20=F0=9F=94=AB=20=20(#?= =?UTF-8?q?33480)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../generation/configuration_utils.py | 4 ++ src/transformers/generation/utils.py | 13 ++++--- .../pipelines/automatic_speech_recognition.py | 4 ++ src/transformers/pipelines/base.py | 37 +++++++++++-------- .../pipelines/document_question_answering.py | 4 ++ src/transformers/pipelines/image_to_text.py | 4 ++ .../pipelines/table_question_answering.py | 4 ++ .../pipelines/text2text_generation.py | 11 ++++-- src/transformers/pipelines/text_generation.py | 13 ++++--- src/transformers/pipelines/text_to_audio.py | 6 ++- .../pipelines/visual_question_answering.py | 4 ++ tests/pipelines/test_pipelines_common.py | 26 +++++++++++++ tests/utils/test_modeling_utils.py | 32 ++++++++++++++++ 13 files changed, 132 insertions(+), 30 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index e2585b1b9ed49c..5e9ac835c19d6d 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -1229,6 +1229,10 @@ def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig" """ config_dict = model_config.to_dict() config_dict.pop("_from_model_config", None) + + # Removes all `None` from the model config dict -- this lets the generation config defaults to take hold + config_dict = {key: value for key, value in config_dict.items() if value is not None} + generation_config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True) # Special case: some models have generation attributes set in the decoder. Use them if still unset in the diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 019eb6c27f18cc..d8896f91267d7b 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1334,23 +1334,26 @@ def _prepare_generation_config( # the following conditions must be met # 1) the generation config must have been created from the model config (`_from_model_config` field); # 2) the generation config must have seen no modification since its creation (the hash is the same); - # 3) the user must have set generation parameters in the model config. + # 3) there are non-default generation parameters in the model config. + # 4) the user must have set new generation parameters in the model config. # NOTE: `torch.compile` can't compile `hash`, this legacy support is disabled with compilation. if ( not is_torchdynamo_compiling() and self.generation_config._from_model_config # 1) and self.generation_config._original_object_hash == hash(self.generation_config) # 2) + and len(self.config._get_non_default_generation_parameters()) > 0 # 3) ): new_generation_config = GenerationConfig.from_model_config(self.config) - if new_generation_config != self.generation_config: # 3) + if new_generation_config != self.generation_config: # 4) warnings.warn( "You have modified the pretrained model configuration to control generation. This is a" - " deprecated strategy to control generation and will be removed soon, in a future version." + " deprecated strategy to control generation and will be removed in v5." " Please use and modify the model generation configuration (see" - " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )" + " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )", + UserWarning, ) self.generation_config = new_generation_config - using_model_generation_config = True + generation_config = self.generation_config using_model_generation_config = True diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index f3de341d88954c..7c122bed5437cc 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -501,6 +501,10 @@ def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs): else: generate_kwargs["num_frames"] = num_frames + # User-defined `generation_config` passed to the pipeline call take precedence + if "generation_config" not in generate_kwargs: + generate_kwargs["generation_config"] = self.generation_config + tokens = self.model.generate( inputs=inputs, attention_mask=attention_mask, diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 7db33ab5bd1a01..40a91a0d484b8e 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import collections +import copy import csv import importlib import json @@ -899,22 +900,26 @@ def __init__( ): self.model.to(self.device) - # Update config and generation_config with task specific parameters - task_specific_params = self.model.config.task_specific_params - if task_specific_params is not None and task in task_specific_params: - self.model.config.update(task_specific_params.get(task)) - if self.model.can_generate(): - self.model.generation_config.update(**task_specific_params.get(task)) - - # Pipelines calling `generate`: if the tokenizer has a pad token but the model doesn't, set it in the - # forward params so that `generate` is aware of the pad token. - if ( - self.tokenizer is not None - and self.model.can_generate() - and self.tokenizer.pad_token_id is not None - and self.model.generation_config.pad_token_id is None - ): - self.model.generation_config.pad_token_id = self.tokenizer.pad_token_id + # If the model can generate, create a local generation config. This is done to avoid side-effects on the model + # as we apply local tweaks to the generation config. + if self.model.can_generate(): + self.prefix = self.model.config.prefix if hasattr(self.model.config, "prefix") else None + self.generation_config = copy.deepcopy(self.model.generation_config) + # Update the generation config with task specific params if they exist + # NOTE: `prefix` is pipeline-specific and doesn't exist in the generation config. + task_specific_params = self.model.config.task_specific_params + if task_specific_params is not None and task in task_specific_params: + this_task_params = task_specific_params.get(task) + if "prefix" in this_task_params: + self.prefix = this_task_params.pop("prefix") + self.generation_config.update(**this_task_params) + # If the tokenizer has a pad token but the model doesn't, set it so that `generate` is aware of it. + if ( + self.tokenizer is not None + and self.tokenizer.pad_token_id is not None + and self.generation_config.pad_token_id is None + ): + self.generation_config.pad_token_id = self.tokenizer.pad_token_id self.call_count = 0 self._batch_size = kwargs.pop("batch_size", None) diff --git a/src/transformers/pipelines/document_question_answering.py b/src/transformers/pipelines/document_question_answering.py index aa4fb48aae6a40..9198f432263822 100644 --- a/src/transformers/pipelines/document_question_answering.py +++ b/src/transformers/pipelines/document_question_answering.py @@ -429,6 +429,10 @@ def _forward(self, model_inputs, **generate_kwargs): is_last = model_inputs.pop("is_last", False) if self.model_type == ModelType.VisionEncoderDecoder: + # User-defined `generation_config` passed to the pipeline call take precedence + if "generation_config" not in generate_kwargs: + generate_kwargs["generation_config"] = self.generation_config + model_outputs = self.model.generate(**model_inputs, **generate_kwargs) else: model_outputs = self.model(**model_inputs) diff --git a/src/transformers/pipelines/image_to_text.py b/src/transformers/pipelines/image_to_text.py index 88dce8e591ae41..91d44c46d25c10 100644 --- a/src/transformers/pipelines/image_to_text.py +++ b/src/transformers/pipelines/image_to_text.py @@ -181,6 +181,10 @@ def _forward(self, model_inputs, **generate_kwargs): ): model_inputs["input_ids"] = None + # User-defined `generation_config` passed to the pipeline call take precedence + if "generation_config" not in generate_kwargs: + generate_kwargs["generation_config"] = self.generation_config + # FIXME: We need to pop here due to a difference in how `generation.py` and `generation.tf_utils.py` # parse inputs. In the Tensorflow version, `generate` raises an error if we don't use `input_ids` whereas # the PyTorch version matches it with `self.model.main_input_name` or `self.model.encoder.main_input_name` diff --git a/src/transformers/pipelines/table_question_answering.py b/src/transformers/pipelines/table_question_answering.py index 702a47b7c3cbed..77c95432c7218f 100644 --- a/src/transformers/pipelines/table_question_answering.py +++ b/src/transformers/pipelines/table_question_answering.py @@ -385,6 +385,10 @@ def _forward(self, model_inputs, sequential=False, **generate_kwargs): else: outputs = self.batch_inference(**model_inputs) else: + # User-defined `generation_config` passed to the pipeline call take precedence + if "generation_config" not in generate_kwargs: + generate_kwargs["generation_config"] = self.generation_config + outputs = self.model.generate(**model_inputs, **generate_kwargs) model_outputs = {"model_inputs": model_inputs, "table": table, "outputs": outputs} return model_outputs diff --git a/src/transformers/pipelines/text2text_generation.py b/src/transformers/pipelines/text2text_generation.py index 42d97f4d11b919..75ded8ac085ca5 100644 --- a/src/transformers/pipelines/text2text_generation.py +++ b/src/transformers/pipelines/text2text_generation.py @@ -115,7 +115,7 @@ def check_inputs(self, input_length: int, min_length: int, max_length: int): return True def _parse_and_tokenize(self, *args, truncation): - prefix = self.model.config.prefix if self.model.config.prefix is not None else "" + prefix = self.prefix if self.prefix is not None else "" if isinstance(args[0], list): if self.tokenizer.pad_token_id is None: raise ValueError("Please make sure that the tokenizer has a pad_token_id when using a batch input") @@ -185,9 +185,14 @@ def _forward(self, model_inputs, **generate_kwargs): self.check_inputs( input_length, - generate_kwargs.get("min_length", self.model.config.min_length), - generate_kwargs.get("max_length", self.model.config.max_length), + generate_kwargs.get("min_length", self.generation_config.min_length), + generate_kwargs.get("max_length", self.generation_config.max_length), ) + + # User-defined `generation_config` passed to the pipeline call take precedence + if "generation_config" not in generate_kwargs: + generate_kwargs["generation_config"] = self.generation_config + output_ids = self.model.generate(**model_inputs, **generate_kwargs) out_b = output_ids.shape[0] if self.framework == "pt": diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index 8bd1017ffc6696..9bffca522d5f2e 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -103,8 +103,8 @@ def __init__(self, *args, **kwargs): # It also defines both some preprocess_kwargs and generate_kwargs # which is why we cannot put them in their respective methods. prefix = None - if self.model.config.prefix is not None: - prefix = self.model.config.prefix + if self.prefix is not None: + prefix = self.prefix if prefix is None and self.model.__class__.__name__ in [ "XLNetLMHeadModel", "TransfoXLLMHeadModel", @@ -316,7 +316,7 @@ def preprocess( if "max_new_tokens" in generate_kwargs: new_tokens = generate_kwargs["max_new_tokens"] else: - new_tokens = generate_kwargs.get("max_length", self.model.config.max_length) - cur_len + new_tokens = generate_kwargs.get("max_length", self.generation_config.max_length) - cur_len if new_tokens < 0: raise ValueError("We cannot infer how many new tokens are expected") if cur_len + new_tokens > self.tokenizer.model_max_length: @@ -354,7 +354,7 @@ def _forward(self, model_inputs, **generate_kwargs): and generate_kwargs["generation_config"].max_new_tokens is not None ) if not has_max_new_tokens: - generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length + generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.generation_config.max_length generate_kwargs["max_length"] += prefix_length has_min_new_tokens = "min_new_tokens" in generate_kwargs or ( "generation_config" in generate_kwargs @@ -363,7 +363,10 @@ def _forward(self, model_inputs, **generate_kwargs): if not has_min_new_tokens and "min_length" in generate_kwargs: generate_kwargs["min_length"] += prefix_length - # BS x SL + # User-defined `generation_config` passed to the pipeline call take precedence + if "generation_config" not in generate_kwargs: + generate_kwargs["generation_config"] = self.generation_config + generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs) out_b = generated_sequence.shape[0] if self.framework == "pt": diff --git a/src/transformers/pipelines/text_to_audio.py b/src/transformers/pipelines/text_to_audio.py index 81653f14d6d878..d17d18205920b0 100644 --- a/src/transformers/pipelines/text_to_audio.py +++ b/src/transformers/pipelines/text_to_audio.py @@ -111,7 +111,7 @@ def preprocess(self, text, **kwargs): if self.model.config.model_type == "bark": # bark Tokenizer is called with BarkProcessor which uses those kwargs new_kwargs = { - "max_length": self.model.generation_config.semantic_config.get("max_input_semantic_length", 256), + "max_length": self.generation_config.semantic_config.get("max_input_semantic_length", 256), "add_special_tokens": False, "return_attention_mask": True, "return_token_type_ids": False, @@ -137,6 +137,10 @@ def _forward(self, model_inputs, **kwargs): # we expect some kwargs to be additional tensors which need to be on the right device generate_kwargs = self._ensure_tensor_on_device(generate_kwargs, device=self.device) + # User-defined `generation_config` passed to the pipeline call take precedence + if "generation_config" not in generate_kwargs: + generate_kwargs["generation_config"] = self.generation_config + # generate_kwargs get priority over forward_params forward_params.update(generate_kwargs) diff --git a/src/transformers/pipelines/visual_question_answering.py b/src/transformers/pipelines/visual_question_answering.py index e5849cbdec1955..89988c0cba2b1b 100644 --- a/src/transformers/pipelines/visual_question_answering.py +++ b/src/transformers/pipelines/visual_question_answering.py @@ -162,6 +162,10 @@ def preprocess(self, inputs, padding=False, truncation=False, timeout=None): def _forward(self, model_inputs, **generate_kwargs): if self.model.can_generate(): + # User-defined `generation_config` passed to the pipeline call take precedence + if "generation_config" not in generate_kwargs: + generate_kwargs["generation_config"] = self.generation_config + model_outputs = self.model.generate(**model_inputs, **generate_kwargs) else: model_outputs = self.model(**model_inputs) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index ea36ae5728d161..1fec4be3d95ca0 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -31,6 +31,7 @@ AutoTokenizer, DistilBertForSequenceClassification, MaskGenerationPipeline, + T5ForConditionalGeneration, TextClassificationPipeline, TextGenerationPipeline, TFAutoModelForSequenceClassification, @@ -234,6 +235,31 @@ def test_auto_model_pipeline_registration_from_local_dir(self): self.assertIsInstance(pipe, TextGenerationPipeline) # Assert successful load + @require_torch + def test_pipeline_with_task_parameters_no_side_effects(self): + """ + Regression test: certain pipeline flags, like `task`, modified the model configuration, causing unexpected + side-effects + """ + # This checkpoint has task-specific parameters that will modify the behavior of the pipeline + model = T5ForConditionalGeneration.from_pretrained("t5-small") + self.assertTrue(model.config.num_beams == 1) + + # The task-specific parameters used to cause side-effects on `model.config` -- not anymore + pipe = pipeline(model=model, tokenizer=AutoTokenizer.from_pretrained("t5-small"), task="translation_en_to_de") + self.assertTrue(model.config.num_beams == 1) + self.assertTrue(model.generation_config.num_beams == 1) + + # Under the hood: we now store a generation config in the pipeline. This generation config stores the + # task-specific paremeters. + self.assertTrue(pipe.generation_config.num_beams == 4) + + # We can confirm that the task-specific parameters have an effect. (In this case, the default is `num_beams=1`, + # which would crash when `num_return_sequences=4` is passed.) + pipe("Hugging Face doesn't sell hugs.", num_return_sequences=4) + with self.assertRaises(ValueError): + pipe("Hugging Face doesn't sell hugs.", num_return_sequences=4, num_beams=1) + @is_pipeline_test class PipelineScikitCompatTest(unittest.TestCase): diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index f78285fdb90d90..2130ed4b7c887f 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -1715,6 +1715,38 @@ def test_isin_mps_friendly(self): torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor)) ) + def test_save_and_load_config_with_custom_generation(self): + """ + Regression test for the ability to save and load a config with a custom generation kwarg (i.e. a parameter + that gets moved to the generation config and reset on the model config) + """ + model = T5ForConditionalGeneration.from_pretrained(TINY_T5) + + # The default for `num_beams` is 1 and `early_stopping` is False + self.assertTrue(model.config.num_beams == 1) + self.assertTrue(model.config.early_stopping is False) + + # When we save the model, this custom parameter should be moved to the generation config AND the model + # config should contain `None` + model.config.num_beams = 2 + model.config.early_stopping = True + self.assertTrue(model.generation_config.num_beams == 1) # unmodified generation config + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir) + new_model = T5ForConditionalGeneration.from_pretrained(tmp_dir) + # moved to generation config + self.assertTrue(new_model.generation_config.num_beams == 2) + self.assertTrue(new_model.generation_config.early_stopping is True) + # reset in the model config + self.assertTrue(new_model.config.num_beams is None) + self.assertTrue(new_model.config.early_stopping is None) + + # Sanity check: We can run `generate` with the new model without any warnings + random_ids = torch.randint(0, 100, (1, 5)) + with warnings.catch_warnings(record=True) as w: + new_model.generate(random_ids, max_new_tokens=3) + self.assertTrue(len(w) == 0) + @slow @require_torch