Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pipeline: no side-effects on model.config and model.generation_config 🔫 #33480

Merged
merged 7 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1229,6 +1229,10 @@ def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig"
"""
config_dict = model_config.to_dict()
config_dict.pop("_from_model_config", None)

# Removes all `None` values from the model config dict -- this lets the generation config defaults take hold
config_dict = {key: value for key, value in config_dict.items() if value is not None}

generation_config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True)

# Special case: some models have generation attributes set in the decoder. Use them if still unset in the
Expand Down
13 changes: 8 additions & 5 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1334,23 +1334,26 @@ def _prepare_generation_config(
# the following conditions must be met
# 1) the generation config must have been created from the model config (`_from_model_config` field);
# 2) the generation config must have seen no modification since its creation (the hash is the same);
# 3) the user must have set generation parameters in the model config.
# 3) there are non-default generation parameters in the model config.
# 4) the user must have set new generation parameters in the model config.
# NOTE: `torch.compile` can't compile `hash`, this legacy support is disabled with compilation.
if (
not is_torchdynamo_compiling()
and self.generation_config._from_model_config # 1)
and self.generation_config._original_object_hash == hash(self.generation_config) # 2)
and len(self.config._get_non_default_generation_parameters()) > 0 # 3)
):
new_generation_config = GenerationConfig.from_model_config(self.config)
if new_generation_config != self.generation_config: # 3)
if new_generation_config != self.generation_config: # 4)
warnings.warn(
"You have modified the pretrained model configuration to control generation. This is a"
" deprecated strategy to control generation and will be removed soon, in a future version."
" deprecated strategy to control generation and will be removed in v5."
" Please use and modify the model generation configuration (see"
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )",
UserWarning,
)
self.generation_config = new_generation_config
using_model_generation_config = True

generation_config = self.generation_config
using_model_generation_config = True

Expand Down
13 changes: 12 additions & 1 deletion src/transformers/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import copy
import csv
import importlib
import json
Expand Down Expand Up @@ -825,13 +826,23 @@ def __init__(
framework, model = infer_framework_load_model(model, config=model.config)

self.task = task
self.model = model
self.tokenizer = tokenizer
self.feature_extractor = feature_extractor
self.image_processor = image_processor
self.modelcard = modelcard
self.framework = framework

# TODO (joao): Keras models don't support `copy(model)` as of writing, fix me
if framework == "pt":
# Create a shallow copy of the model with deep copies of the configs. A pipeline may change the config of
# the model and we don't want side-effects on the original object.
self.model = copy.copy(model)
self.model.config = copy.deepcopy(model.config)
if self.model.can_generate():
self.model.generation_config = copy.deepcopy(model.generation_config)
else:
self.model = model

# `accelerate` device map
hf_device_map = getattr(self.model, "hf_device_map", None)

Expand Down
19 changes: 19 additions & 0 deletions tests/pipelines/test_pipelines_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
AutoTokenizer,
DistilBertForSequenceClassification,
MaskGenerationPipeline,
T5ForConditionalGeneration,
TextClassificationPipeline,
TextGenerationPipeline,
TFAutoModelForSequenceClassification,
Expand Down Expand Up @@ -234,6 +235,24 @@ def test_auto_model_pipeline_registration_from_local_dir(self):

self.assertIsInstance(pipe, TextGenerationPipeline) # Assert successful load

@require_torch
def test_pipeline_with_task_parameters_no_side_effects(self):
    """
    Regression test: certain pipeline flags, like `task`, modified the model configuration, causing unexpected
    side-effects
    """
    # Baseline: the t5-small checkpoint's model config has the default `num_beams` (1).
    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    self.assertTrue(model.config.num_beams == 1)

    # The `task` parameter used to cause side-effects on `model.config` -- not anymore.
    # The translation task applies task-specific parameters (e.g. `num_beams=4`) to the pipeline's copy of the
    # config, while the caller's original `model.config` must remain untouched.
    pipe = pipeline(model=model, tokenizer=AutoTokenizer.from_pretrained("t5-small"), task="translation_en_to_de")
    self.assertTrue(model.config.num_beams == 1)
    self.assertTrue(pipe.model.config.num_beams == 4)

    # Under the hood: the pipeline holds a shallow copy of the model (weights shared) with a deep copy of the
    # configs (so config mutations don't propagate back).
    self.assertTrue(id(model._parameters) == id(pipe.model._parameters))  # same model reference
    self.assertTrue(id(model.config) != id(pipe.model.config))  # different config reference


@is_pipeline_test
class PipelineScikitCompatTest(unittest.TestCase):
Expand Down
32 changes: 32 additions & 0 deletions tests/utils/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1715,6 +1715,38 @@ def test_isin_mps_friendly(self):
torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
)

def test_save_and_load_config_with_custom_generation(self):
    """
    Regression test for the ability to save and load a config with a custom generation kwarg (i.e. a parameter
    that gets moved to the generation config and reset on the model config)
    """
    # NOTE(review): `TINY_T5` is presumably a module-level checkpoint-name constant defined elsewhere in this
    # file -- confirm against the full test module.
    model = T5ForConditionalGeneration.from_pretrained(TINY_T5)

    # The default for `num_beams` is 1 and `early_stopping` is False
    self.assertTrue(model.config.num_beams == 1)
    self.assertTrue(model.config.early_stopping is False)

    # When we save the model, this custom parameter should be moved to the generation config AND the model
    # config should contain `None`
    model.config.num_beams = 2
    model.config.early_stopping = True
    self.assertTrue(model.generation_config.num_beams == 1)  # unmodified generation config
    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir)
        new_model = T5ForConditionalGeneration.from_pretrained(tmp_dir)
        # moved to generation config
        self.assertTrue(new_model.generation_config.num_beams == 2)
        self.assertTrue(new_model.generation_config.early_stopping is True)
        # reset in the model config
        self.assertTrue(new_model.config.num_beams is None)
        self.assertTrue(new_model.config.early_stopping is None)

    # Sanity check: we can run `generate` with the new model without any warnings (i.e. the round-trip left no
    # stale generation parameters on the model config that would trigger the deprecation warning).
    random_ids = torch.randint(0, 100, (1, 5))
    with warnings.catch_warnings(record=True) as w:
        new_model.generate(random_ids, max_new_tokens=3)
    self.assertTrue(len(w) == 0)


@slow
@require_torch
Expand Down
Loading