From ea7b0a539a92a79b829cfc7d41d28f33f993e820 Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Mon, 17 Apr 2023 11:36:29 -0400
Subject: [PATCH] Use code on the Hub from another repo (#22698)

* initial work

* Add other classes

* Refactor code

* Move warning and fix dynamic pipeline

* Issue warning when necessary

* Add test

---
 src/transformers/configuration_utils.py      |  5 ++
 src/transformers/dynamic_module_utils.py     | 48 ++++++++++++++++---
 src/transformers/models/auto/auto_factory.py | 15 +++---
 .../models/auto/configuration_auto.py        | 11 +----
 .../models/auto/feature_extraction_auto.py   | 10 +---
 .../models/auto/image_processing_auto.py     | 10 +---
 .../models/auto/processing_auto.py           |  9 +---
 .../models/auto/tokenization_auto.py         | 12 +----
 src/transformers/pipelines/__init__.py       |  3 +-
 src/transformers/tokenization_utils_base.py  | 12 ++++-
 src/transformers/utils/__init__.py           |  1 +
 tests/models/auto/test_modeling_auto.py      | 28 +++++++++++
 12 files changed, 98 insertions(+), 66 deletions(-)

diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py
index 718d2d8d0f1de9..ab0df58008f398 100755
--- a/src/transformers/configuration_utils.py
+++ b/src/transformers/configuration_utils.py
@@ -667,6 +667,11 @@ def _get_config_dict(
         else:
             logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")
 
+        if "auto_map" in config_dict and not is_local:
+            config_dict["auto_map"] = {
+                k: (f"{pretrained_model_name_or_path}--{v}" if "--" not in v else v)
+                for k, v in config_dict["auto_map"].items()
+            }
         return config_dict, kwargs
 
     @classmethod
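The `_get_config_dict` hunk above establishes the convention the rest of this patch builds on: when a config comes from the Hub rather than a local directory, every `auto_map` entry is prefixed with the repo it was downloaded from, using `--` as a separator, unless it already carries one. For illustration, a minimal standalone sketch of that rewriting rule; the helper name and the toy values below are assumptions, not part of the patch:

```python
def add_repo_prefix(auto_map: dict, repo_id: str) -> dict:
    # Same rule as the dict comprehension added to `_get_config_dict`: pin each
    # class reference to the repo it came from, unless a repo is already pinned
    # (in which case the value already contains the `--` separator).
    return {k: (f"{repo_id}--{v}" if "--" not in v else v) for k, v in auto_map.items()}


auto_map = {
    "AutoConfig": "configuration.NewConfig",
    "AutoModel": "user/other-repo--modeling.NewModel",  # already pinned, left untouched
}
print(add_repo_prefix(auto_map, "user/test-repo"))
# {'AutoConfig': 'user/test-repo--configuration.NewConfig',
#  'AutoModel': 'user/other-repo--modeling.NewModel'}
```

This is why the downstream consumers in this patch only ever need a single reference string: the repo information travels inside the reference itself.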
diff --git a/src/transformers/dynamic_module_utils.py b/src/transformers/dynamic_module_utils.py
index 62a124f7d38995..8d0ff2c34f2b8c 100644
--- a/src/transformers/dynamic_module_utils.py
+++ b/src/transformers/dynamic_module_utils.py
@@ -29,6 +29,7 @@
     extract_commit_hash,
     is_offline_mode,
     logging,
+    try_to_load_from_cache,
 )
 
 
@@ -222,11 +223,16 @@ def get_cached_module_file(
     # Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file.
     pretrained_model_name_or_path = str(pretrained_model_name_or_path)
-    if os.path.isdir(pretrained_model_name_or_path):
+    is_local = os.path.isdir(pretrained_model_name_or_path)
+    if is_local:
         submodule = pretrained_model_name_or_path.split(os.path.sep)[-1]
     else:
         submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
+        cached_module = try_to_load_from_cache(
+            pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash
+        )
 
+    new_files = []
     try:
         # Load from URL or cache if already cached
         resolved_module_file = cached_file(
@@ -241,6 +247,8 @@ def get_cached_module_file(
             revision=revision,
             _commit_hash=_commit_hash,
         )
+        if not is_local and cached_module != resolved_module_file:
+            new_files.append(module_file)
 
     except EnvironmentError:
         logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
@@ -284,7 +292,7 @@ def get_cached_module_file(
         importlib.invalidate_caches()
     # Make sure we also have every file with relative imports
     for module_needed in modules_needed:
-        if not (submodule_path / module_needed).exists():
+        if not (submodule_path / f"{module_needed}.py").exists():
             get_cached_module_file(
                 pretrained_model_name_or_path,
                 f"{module_needed}.py",
@@ -295,14 +303,24 @@ def get_cached_module_file(
                 use_auth_token=use_auth_token,
                 revision=revision,
                 local_files_only=local_files_only,
+                _commit_hash=commit_hash,
             )
+            new_files.append(f"{module_needed}.py")
+
+    if len(new_files) > 0:
+        new_files = "\n".join([f"- {f}" for f in new_files])
+        logger.warning(
+            f"A new version of the following files was downloaded from {pretrained_model_name_or_path}:\n{new_files}"
+            "\nMake sure to double-check they do not contain any added malicious code. To avoid downloading new "
+            "versions of the code file, you can pin a revision."
+        )
+
     return os.path.join(full_submodule, module_file)
 
 
 def get_class_from_dynamic_module(
+    class_reference: str,
     pretrained_model_name_or_path: Union[str, os.PathLike],
-    module_file: str,
-    class_name: str,
     cache_dir: Optional[Union[str, os.PathLike]] = None,
     force_download: bool = False,
     resume_download: bool = False,
@@ -323,6 +341,8 @@ def get_class_from_dynamic_module(
 
     Args:
+        class_reference (`str`):
+            The full name of the class to load, including its module and optionally its repo.
         pretrained_model_name_or_path (`str` or `os.PathLike`):
             This can be either:
 
@@ -332,6 +352,7 @@ def get_class_from_dynamic_module(
             - a path to a *directory* containing a configuration file saved using the
               [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
 
+            This is used when `class_reference` does not specify another repo.
         module_file (`str`):
             The name of the module file containing the class to look for.
         class_name (`str`):
@@ -371,12 +392,25 @@ def get_class_from_dynamic_module(
 
     ```python
     # Download module `modeling.py` from huggingface.co and cache it, then extract the class `MyBertModel` from this
     # module.
-    cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel")
+    cls = get_class_from_dynamic_module("modeling.MyBertModel", "sgugger/my-bert-model")
+
+    # Download module `modeling.py` from a given repo and cache it, then extract the class `MyBertModel` from this
+    # module.
+    cls = get_class_from_dynamic_module("sgugger/my-bert-model--modeling.MyBertModel", "sgugger/another-bert-model")
     ```"""
+    # Catch the name of the repo if it's specified in `class_reference`
+    if "--" in class_reference:
+        repo_id, class_reference = class_reference.split("--")
+        # Invalidate revision since it's not relevant for this repo
+        revision = "main"
+    else:
+        repo_id = pretrained_model_name_or_path
+    module_file, class_name = class_reference.split(".")
+
     # And lastly we get the class inside our newly created module
     final_module = get_cached_module_file(
-        pretrained_model_name_or_path,
-        module_file,
+        repo_id,
+        module_file + ".py",
         cache_dir=cache_dir,
         force_download=force_download,
         resume_download=resume_download,
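The new `get_class_from_dynamic_module` signature folds the old `(repo, module_file, class_name)` triple into one reference string that can optionally pin its own repo. A standalone sketch of how such a reference decomposes, mirroring the parsing added above (`parse_class_reference` is a hypothetical name; the patch inlines this logic in the function body):

```python
def parse_class_reference(class_reference, default_repo):
    # `repo--module.Class` pins the repo hosting the code; a bare
    # `module.Class` falls back to the repo the object was loaded from.
    if "--" in class_reference:
        repo_id, class_reference = class_reference.split("--")
    else:
        repo_id = default_repo
    module_file, class_name = class_reference.split(".")
    return repo_id, f"{module_file}.py", class_name


print(parse_class_reference("modeling.MyBertModel", "sgugger/my-bert-model"))
# ('sgugger/my-bert-model', 'modeling.py', 'MyBertModel')
print(parse_class_reference("sgugger/my-bert-model--modeling.MyBertModel", "sgugger/another-bert-model"))
# ('sgugger/my-bert-model', 'modeling.py', 'MyBertModel')
```

Note the design choice visible in the diff: when the reference pins another repo, the user-supplied `revision` is reset to `"main"`, since a revision of the loading repo would be meaningless for the repo hosting the code.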
+ cls = get_class_from_dynamic_module("sgugger/my-bert-model--modeling.MyBertModel", "sgugger/another-bert-model") ```""" + # Catch the name of the repo if it's specified in `class_reference` + if "--" in class_reference: + repo_id, class_reference = class_reference.split("--") + # Invalidate revision since it's not relevant for this repo + revision = "main" + else: + repo_id = pretrained_model_name_or_path + module_file, class_name = class_reference.split(".") + # And lastly we get the class inside our newly created module final_module = get_cached_module_file( - pretrained_model_name_or_path, - module_file, + repo_id, + module_file + ".py", cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index eb87bb1ff7dbdd..f8bc266fe8325f 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -403,8 +403,12 @@ def from_config(cls, config, **kwargs): "no malicious code has been contributed in a newer revision." ) class_ref = config.auto_map[cls.__name__] + if "--" in class_ref: + repo_id, class_ref = class_ref.split("--") + else: + repo_id = config.name_or_path module_file, class_name = class_ref.split(".") - model_class = get_class_from_dynamic_module(config.name_or_path, module_file + ".py", class_name, **kwargs) + model_class = get_class_from_dynamic_module(repo_id, module_file + ".py", class_name, **kwargs) return model_class._from_config(config, **kwargs) elif type(config) in cls._model_mapping.keys(): model_class = _get_model_class(config, cls._model_mapping) @@ -452,17 +456,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): "on your local machine. Make sure you have read the code there to avoid malicious use, then set " "the option `trust_remote_code=True` to remove this error." ) - if hub_kwargs.get("revision", None) is None: - logger.warning( - "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure " - "no malicious code has been contributed in a newer revision." - ) class_ref = config.auto_map[cls.__name__] - module_file, class_name = class_ref.split(".") model_class = get_class_from_dynamic_module( - pretrained_model_name_or_path, module_file + ".py", class_name, **hub_kwargs, **kwargs + class_ref, pretrained_model_name_or_path, **hub_kwargs, **kwargs ) - model_class.register_for_auto_class(cls.__name__) return model_class.from_pretrained( pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs ) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 225fc739eda58c..06e562097715ee 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -921,17 +921,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): " repo on your local machine. Make sure you have read the code there to avoid malicious use, then" " set the option `trust_remote_code=True` to remove this error." ) - if kwargs.get("revision", None) is None: - logger.warning( - "Explicitly passing a `revision` is encouraged when loading a configuration with custom code to " - "ensure no malicious code has been contributed in a newer revision." 
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 225fc739eda58c..06e562097715ee 100755
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -921,17 +921,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
                 " repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
                 " set the option `trust_remote_code=True` to remove this error."
             )
-            if kwargs.get("revision", None) is None:
-                logger.warning(
-                    "Explicitly passing a `revision` is encouraged when loading a configuration with custom code to "
-                    "ensure no malicious code has been contributed in a newer revision."
-                )
             class_ref = config_dict["auto_map"]["AutoConfig"]
-            module_file, class_name = class_ref.split(".")
-            config_class = get_class_from_dynamic_module(
-                pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
-            )
-            config_class.register_for_auto_class()
+            config_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
             return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
         elif "model_type" in config_dict:
             config_class = CONFIG_MAPPING[config_dict["model_type"]]
diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py
index 90218d137f8d28..0a527ee151759c 100644
--- a/src/transformers/models/auto/feature_extraction_auto.py
+++ b/src/transformers/models/auto/feature_extraction_auto.py
@@ -333,17 +333,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
                 "in that repo on your local machine. Make sure you have read the code there to avoid "
                 "malicious use, then set the option `trust_remote_code=True` to remove this error."
             )
-            if kwargs.get("revision", None) is None:
-                logger.warning(
-                    "Explicitly passing a `revision` is encouraged when loading a feature extractor with custom "
-                    "code to ensure no malicious code has been contributed in a newer revision."
-                )
-
-            module_file, class_name = feature_extractor_auto_map.split(".")
             feature_extractor_class = get_class_from_dynamic_module(
-                pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
+                feature_extractor_auto_map, pretrained_model_name_or_path, **kwargs
             )
-            feature_extractor_class.register_for_auto_class()
         else:
             feature_extractor_class = feature_extractor_class_from_name(feature_extractor_class)
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index c092dbf16f4675..2dae53019f1968 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -355,17 +355,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
                 "in that repo on your local machine. Make sure you have read the code there to avoid "
                 "malicious use, then set the option `trust_remote_code=True` to remove this error."
             )
-            if kwargs.get("revision", None) is None:
-                logger.warning(
-                    "Explicitly passing a `revision` is encouraged when loading a image processor with custom "
-                    "code to ensure no malicious code has been contributed in a newer revision."
-                )
-
-            module_file, class_name = image_processor_auto_map.split(".")
             image_processor_class = get_class_from_dynamic_module(
-                pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
+                image_processor_auto_map, pretrained_model_name_or_path, **kwargs
             )
-            image_processor_class.register_for_auto_class()
         else:
             image_processor_class = image_processor_class_from_name(image_processor_class)
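These hunks all drop the same blanket "please pass a revision" warning; the advice now surfaces only when `get_cached_module_file` actually detects newly downloaded code files. The mitigation it points to is unchanged and worth a sketch: pinning `revision` freezes both the config and the dynamic code. The repo name and hash below are placeholders for illustration, not values from this patch:

```python
from transformers import AutoConfig

# With a pinned revision, later pushes to the repo can neither change the
# loaded config nor trigger the "new version of the following files was
# downloaded" warning added earlier in this patch.
config = AutoConfig.from_pretrained(
    "user/model-with-custom-code",  # placeholder repo
    trust_remote_code=True,
    revision="0000000000000000000000000000000000000000",  # placeholder: pin to an audited commit
)
```

One interaction to keep in mind: a pinned `revision` applies to the repo being loaded, while a `--` reference that pins another repo deliberately resets `revision` to `main` for that code repo.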
diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py
index 9e6edc0ae16f79..8c9236130c2b75 100644
--- a/src/transformers/models/auto/processing_auto.py
+++ b/src/transformers/models/auto/processing_auto.py
@@ -254,17 +254,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
                 "in that repo on your local machine. Make sure you have read the code there to avoid "
                 "malicious use, then set the option `trust_remote_code=True` to remove this error."
             )
-            if kwargs.get("revision", None) is None:
-                logger.warning(
-                    "Explicitly passing a `revision` is encouraged when loading a feature extractor with custom "
-                    "code to ensure no malicious code has been contributed in a newer revision."
-                )
-            module_file, class_name = processor_auto_map.split(".")
             processor_class = get_class_from_dynamic_module(
-                pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
+                processor_auto_map, pretrained_model_name_or_path, **kwargs
             )
-            processor_class.register_for_auto_class()
         else:
             processor_class = processor_class_from_name(processor_class)
diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py
index 4fee20f50b371b..de954e206ae194 100644
--- a/src/transformers/models/auto/tokenization_auto.py
+++ b/src/transformers/models/auto/tokenization_auto.py
@@ -671,22 +671,12 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
                 " repo on your local machine. Make sure you have read the code there to avoid malicious use,"
                 " then set the option `trust_remote_code=True` to remove this error."
             )
-            if kwargs.get("revision", None) is None:
-                logger.warning(
-                    "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure"
-                    " no malicious code has been contributed in a newer revision."
-                )
             if use_fast and tokenizer_auto_map[1] is not None:
                 class_ref = tokenizer_auto_map[1]
             else:
                 class_ref = tokenizer_auto_map[0]
-
-            module_file, class_name = class_ref.split(".")
-            tokenizer_class = get_class_from_dynamic_module(
-                pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
-            )
-            tokenizer_class.register_for_auto_class()
+            tokenizer_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
 
         elif use_fast and not config_tokenizer_class.endswith("Fast"):
             tokenizer_class_candidate = f"{config_tokenizer_class}Fast"
diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py
index b1d3bc43e8d9d8..b4e696613889e8 100755
--- a/src/transformers/pipelines/__init__.py
+++ b/src/transformers/pipelines/__init__.py
@@ -727,9 +727,8 @@ def pipeline(
             " set the option `trust_remote_code=True` to remove this error."
         )
         class_ref = targeted_task["impl"]
-        module_file, class_name = class_ref.split(".")
         pipeline_class = get_class_from_dynamic_module(
-            model, module_file + ".py", class_name, revision=revision, use_auth_token=use_auth_token
+            class_ref, model, revision=revision, use_auth_token=use_auth_token
        )
    else:
        normalized_task, targeted_task, task_options = check_task(task)
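Tokenizers are the one consumer whose `auto_map` values are not plain strings: the `AutoTokenizer` entry is a `[slow, fast]` pair in which either item may be `None`, which is why the selection above checks `tokenizer_auto_map[1] is not None`, and why the rewriting in the next file must walk lists as well as strings. A toy sketch of that selection; the class names are illustrative, not from a real repo:

```python
# An `AutoTokenizer` entry in `auto_map` is a [slow_ref, fast_ref] pair.
tokenizer_auto_map = ["tokenization.MyTokenizer", "tokenization_fast.MyTokenizerFast"]
use_fast = True

# Mirrors the branch kept in `AutoTokenizer.from_pretrained` above: prefer the
# fast tokenizer class when requested and available, else fall back to slow.
if use_fast and tokenizer_auto_map[1] is not None:
    class_ref = tokenizer_auto_map[1]
else:
    class_ref = tokenizer_auto_map[0]

print(class_ref)  # tokenization_fast.MyTokenizerFast
```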
diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index 3045e7f7cb9eef..df132fa7ae78df 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -1817,6 +1817,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
             cache_dir=cache_dir,
             local_files_only=local_files_only,
             _commit_hash=commit_hash,
+            _is_local=is_local,
             **kwargs,
         )
 
@@ -1831,6 +1832,7 @@ def _from_pretrained(
         cache_dir=None,
         local_files_only=False,
         _commit_hash=None,
+        _is_local=False,
         **kwargs,
     ):
         # We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json
@@ -1861,7 +1863,6 @@ def _from_pretrained(
             # First attempt. We get tokenizer_class from tokenizer_config to check mismatch between tokenizers.
             config_tokenizer_class = init_kwargs.get("tokenizer_class")
             init_kwargs.pop("tokenizer_class", None)
-            init_kwargs.pop("auto_map", None)
             saved_init_inputs = init_kwargs.pop("init_inputs", ())
             if not init_inputs:
                 init_inputs = saved_init_inputs
@@ -1869,6 +1870,18 @@ def _from_pretrained(
             config_tokenizer_class = None
             init_kwargs = init_configuration
 
+        if "auto_map" in init_kwargs and not _is_local:
+            new_auto_map = {}
+            for key, value in init_kwargs["auto_map"].items():
+                if isinstance(value, (list, tuple)):
+                    # Tokenizer entries are [slow, fast] pairs in which either item may be None.
+                    new_auto_map[key] = [
+                        f"{pretrained_model_name_or_path}--{v}" if v is not None and "--" not in v else v
+                        for v in value
+                    ]
+                else:
+                    new_auto_map[key] = f"{pretrained_model_name_or_path}--{value}" if "--" not in value else value
+            init_kwargs["auto_map"] = new_auto_map
+
         if config_tokenizer_class is None:
             from .models.auto.configuration_auto import AutoConfig  # tests_ignore
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 1f04ca73bfc13a..f91b6c7748c195 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -83,6 +83,7 @@
     is_remote_url,
     move_cache,
     send_example_telemetry,
+    try_to_load_from_cache,
 )
 from .import_utils import (
     ENV_VARS_TRUE_AND_AUTO_VALUES,
diff --git a/tests/models/auto/test_modeling_auto.py b/tests/models/auto/test_modeling_auto.py
index 9fb982c0f0e14e..26eecd54299c0c 100644
--- a/tests/models/auto/test_modeling_auto.py
+++ b/tests/models/auto/test_modeling_auto.py
@@ -298,6 +298,34 @@ def test_from_pretrained_dynamic_model_distant(self):
         for p1, p2 in zip(model.parameters(), reloaded_model.parameters()):
             self.assertTrue(torch.equal(p1, p2))
 
+    def test_from_pretrained_dynamic_model_distant_with_ref(self):
+        model = AutoModel.from_pretrained("hf-internal-testing/ref_to_test_dynamic_model", trust_remote_code=True)
+        self.assertEqual(model.__class__.__name__, "NewModel")
+
+        # Test model can be reloaded.
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir)
+            reloaded_model = AutoModel.from_pretrained(tmp_dir, trust_remote_code=True)
+
+        self.assertEqual(reloaded_model.__class__.__name__, "NewModel")
+        for p1, p2 in zip(model.parameters(), reloaded_model.parameters()):
+            self.assertTrue(torch.equal(p1, p2))
+
+        # This one uses a relative import to a util file; this checks that it is downloaded and used properly.
+        model = AutoModel.from_pretrained(
+            "hf-internal-testing/ref_to_test_dynamic_model_with_util", trust_remote_code=True
+        )
+        self.assertEqual(model.__class__.__name__, "NewModel")
+
+        # Test model can be reloaded.
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir)
+            reloaded_model = AutoModel.from_pretrained(tmp_dir, trust_remote_code=True)
+
+        self.assertEqual(reloaded_model.__class__.__name__, "NewModel")
+        for p1, p2 in zip(model.parameters(), reloaded_model.parameters()):
+            self.assertTrue(torch.equal(p1, p2))
+
     def test_new_model_registration(self):
         AutoConfig.register("custom", CustomConfig)
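The reload halves of the new tests hinge on the prefixed `auto_map` surviving `save_pretrained`: once the model is saved locally, its config still records which repo hosts the code. A sketch of how one might verify that by hand, reusing the repo from the tests above; the exact contents of `auto_map` belong to the repo author, so the printed value is indicative only:

```python
import json
import os
import tempfile

from transformers import AutoModel

model = AutoModel.from_pretrained("hf-internal-testing/ref_to_test_dynamic_model", trust_remote_code=True)
with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir)
    with open(os.path.join(tmp_dir, "config.json")) as f:
        saved = json.load(f)
    # Each value should be a full `repo--module.Class` reference, which is what
    # lets `AutoModel.from_pretrained(tmp_dir, trust_remote_code=True)` find
    # the remote code again from a plain local directory.
    print(saved["auto_map"])
```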