Remote code improvements (huggingface#23959)
* Fix model load when it has both code on the Hub and locally

* Add input check with timeout

* Add tests

* Apply suggestions from code review

Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>

* Some non-saved stuff

* Add feature extractors

* Add image processor

* Add model

* Add processor and tokenizer

* Reduce timeout

---------

Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
2 people authored and novice03 committed Jun 23, 2023
1 parent 852c07c commit b516c02
Showing 13 changed files with 435 additions and 103 deletions.
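
As a rough illustration of the user-facing behavior the changes below introduce (the repository name `some-user/custom-model` is hypothetical and used only for this sketch):

```python
from transformers import AutoModel

# Hypothetical repo that ships its own modeling code via `auto_map` in its config.
# With this commit, leaving `trust_remote_code` unset no longer raises outright:
# an interactive prompt asks whether to run the repo's code and, if unanswered,
# times out after 15 seconds (TIME_OUT_REMOTE_CODE).
model = AutoModel.from_pretrained("some-user/custom-model")

# Passing the flag explicitly skips the prompt entirely.
model = AutoModel.from_pretrained("some-user/custom-model", trust_remote_code=True)
```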
44 changes: 44 additions & 0 deletions src/transformers/dynamic_module_utils.py
@@ -18,6 +18,7 @@
import os
import re
import shutil
import signal
import sys
from pathlib import Path
from typing import Dict, Optional, Union
@@ -513,3 +514,46 @@ def _set_auto_map_in_config(_config):
result.append(dest_file)

return result


def _raise_timeout_error(signum, frame):
raise ValueError(
"Loading this model requires you to execute the configuration file in that repo on your local machine. We "
"asked if it was okay but did not get an answer. Make sure you have read the code there to avoid malicious "
"use, then set the option `trust_remote_code=True` to remove this error."
)


TIME_OUT_REMOTE_CODE = 15


def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code):
if trust_remote_code is None:
if has_local_code:
trust_remote_code = False
elif has_remote_code and TIME_OUT_REMOTE_CODE > 0:
signal.signal(signal.SIGALRM, _raise_timeout_error)
signal.alarm(TIME_OUT_REMOTE_CODE)
while trust_remote_code is None:
answer = input(
f"Loading {model_name} requires to execute some code in that repo, you can inspect the content of "
f"the repository at https://hf.co/{model_name}. You can dismiss this prompt by passing "
"`trust_remote_code=True`.\nDo you accept? [y/N] "
)
if answer.lower() in ["yes", "y", "1"]:
trust_remote_code = True
elif answer.lower() in ["no", "n", "0", ""]:
trust_remote_code = False
signal.alarm(0)
elif has_remote_code:
# For the CI which puts the timeout at 0
_raise_timeout_error(None, None)

if has_remote_code and not has_local_code and not trust_remote_code:
raise ValueError(
f"Loading {model_name} requires you to execute the configuration file in that"
" repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
" set the option `trust_remote_code=True` to remove this error."
)

return trust_remote_code
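
A minimal sketch of how the new helper resolves the flag, assuming it is imported from `transformers.dynamic_module_utils` as defined above; the model name is purely illustrative:

```python
from transformers.dynamic_module_utils import resolve_trust_remote_code

# Flag left unset but a local implementation exists: default to the local code.
assert resolve_trust_remote_code(None, "my-model", has_local_code=True, has_remote_code=True) is False

# An explicit value always wins and never prompts.
assert resolve_trust_remote_code(True, "my-model", has_local_code=False, has_remote_code=True) is True

# Flag unset with remote-only code: prompts on stdin, guarded by a
# SIGALRM-based timeout of TIME_OUT_REMOTE_CODE (15) seconds.
# Flag explicitly False with remote-only code: raises a ValueError.
```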
38 changes: 16 additions & 22 deletions src/transformers/models/auto/auto_factory.py
@@ -18,7 +18,7 @@
from collections import OrderedDict

from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...utils import copy_func, logging, requires_backends
from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings

@@ -404,19 +404,14 @@ def __init__(self, *args, **kwargs):

@classmethod
def from_config(cls, config, **kwargs):
trust_remote_code = kwargs.pop("trust_remote_code", False)
if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
if not trust_remote_code:
raise ValueError(
"Loading this model requires you to execute the modeling file in that repo "
"on your local machine. Make sure you have read the code there to avoid malicious use, then set "
"the option `trust_remote_code=True` to remove this error."
)
if kwargs.get("revision", None) is None:
logger.warning(
"Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure "
"no malicious code has been contributed in a newer revision."
)
trust_remote_code = kwargs.pop("trust_remote_code", None)
has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
has_local_code = type(config) in cls._model_mapping.keys()
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, config._name_or_path, has_local_code, has_remote_code
)

if has_remote_code and trust_remote_code:
class_ref = config.auto_map[cls.__name__]
if "--" in class_ref:
repo_id, class_ref = class_ref.split("--")
@@ -437,7 +432,7 @@ def from_config(cls, config, **kwargs):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
config = kwargs.pop("config", None)
trust_remote_code = kwargs.pop("trust_remote_code", False)
trust_remote_code = kwargs.pop("trust_remote_code", None)
kwargs["_from_auto"] = True
hub_kwargs_names = [
"cache_dir",
@@ -470,13 +465,12 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
if kwargs_orig.get("torch_dtype", None) == "auto":
kwargs["torch_dtype"] = "auto"

if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
if not trust_remote_code:
raise ValueError(
f"Loading {pretrained_model_name_or_path} requires you to execute the modeling file in that repo "
"on your local machine. Make sure you have read the code there to avoid malicious use, then set "
"the option `trust_remote_code=True` to remove this error."
)
has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
has_local_code = type(config) in cls._model_mapping.keys()
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
)
if has_remote_code and trust_remote_code:
class_ref = config.auto_map[cls.__name__]
model_class = get_class_from_dynamic_module(
class_ref, pretrained_model_name_or_path, **hub_kwargs, **kwargs
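
The `from_config` and `from_pretrained` changes above also cover the first bullet of the commit message, a checkpoint whose model type is supported locally but whose repo additionally ships custom code. A sketch with a hypothetical repository name:

```python
from transformers import AutoModel

# Hypothetical repo whose `model_type` is already in the local `_model_mapping`
# *and* whose config defines an `auto_map` entry. Previously this raised unless
# `trust_remote_code=True`; now, leaving the flag unset quietly falls back to
# the local implementation.
model = AutoModel.from_pretrained("some-user/model-with-both")

# Opting in still loads the repo's own modeling code instead.
model = AutoModel.from_pretrained("some-user/model-with-both", trust_remote_code=True)
```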
18 changes: 9 additions & 9 deletions src/transformers/models/auto/configuration_auto.py
@@ -20,7 +20,7 @@
from typing import List, Union

from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...utils import CONFIG_NAME, logging


@@ -943,15 +943,15 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
```"""
kwargs["_from_auto"] = True
kwargs["name_or_path"] = pretrained_model_name_or_path
trust_remote_code = kwargs.pop("trust_remote_code", False)
trust_remote_code = kwargs.pop("trust_remote_code", None)
config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
if "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]:
if not trust_remote_code:
raise ValueError(
f"Loading {pretrained_model_name_or_path} requires you to execute the configuration file in that"
" repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
" set the option `trust_remote_code=True` to remove this error."
)
has_remote_code = "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]
has_local_code = "model_type" in config_dict and config_dict["model_type"] in CONFIG_MAPPING
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
)

if has_remote_code and trust_remote_code:
class_ref = config_dict["auto_map"]["AutoConfig"]
config_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
_ = kwargs.pop("code_revision", None)
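
For configurations, a short sketch of the same resolution (the repo name is hypothetical); per the diff above, `has_remote_code` is read from `auto_map["AutoConfig"]` in the downloaded config dict and `has_local_code` from whether its `model_type` is registered in `CONFIG_MAPPING`:

```python
from transformers import AutoConfig

# The prompt only appears when the repo's config class has no local equivalent;
# an explicit flag, as here, bypasses it. `some-user/custom-model` is hypothetical.
config = AutoConfig.from_pretrained("some-user/custom-model", trust_remote_code=True)
```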
32 changes: 16 additions & 16 deletions src/transformers/models/auto/feature_extraction_auto.py
@@ -21,7 +21,7 @@

# Build the list of all feature extractors
from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...feature_extraction_utils import FeatureExtractionMixin
from ...utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, get_file_from_repo, logging
from .auto_factory import _LazyAutoMapping
@@ -307,7 +307,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
>>> # feature_extractor = AutoFeatureExtractor.from_pretrained("./test/saved_model/")
```"""
config = kwargs.pop("config", None)
trust_remote_code = kwargs.pop("trust_remote_code", False)
trust_remote_code = kwargs.pop("trust_remote_code", None)
kwargs["_from_auto"] = True

config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
@@ -326,21 +326,21 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
feature_extractor_auto_map = config.auto_map["AutoFeatureExtractor"]

if feature_extractor_class is not None:
# If we have custom code for a feature extractor, we get the proper class.
if feature_extractor_auto_map is not None:
if not trust_remote_code:
raise ValueError(
f"Loading {pretrained_model_name_or_path} requires you to execute the feature extractor file "
"in that repo on your local machine. Make sure you have read the code there to avoid "
"malicious use, then set the option `trust_remote_code=True` to remove this error."
)
feature_extractor_class = get_class_from_dynamic_module(
feature_extractor_auto_map, pretrained_model_name_or_path, **kwargs
)
_ = kwargs.pop("code_revision", None)
else:
feature_extractor_class = feature_extractor_class_from_name(feature_extractor_class)
feature_extractor_class = feature_extractor_class_from_name(feature_extractor_class)

has_remote_code = feature_extractor_auto_map is not None
has_local_code = feature_extractor_class is not None or type(config) in FEATURE_EXTRACTOR_MAPPING
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
)

if has_remote_code and trust_remote_code:
feature_extractor_class = get_class_from_dynamic_module(
feature_extractor_auto_map, pretrained_model_name_or_path, **kwargs
)
_ = kwargs.pop("code_revision", None)
return feature_extractor_class.from_dict(config_dict, **kwargs)
elif feature_extractor_class is not None:
return feature_extractor_class.from_dict(config_dict, **kwargs)
# Last try: we use the FEATURE_EXTRACTOR_MAPPING.
elif type(config) in FEATURE_EXTRACTOR_MAPPING:
32 changes: 16 additions & 16 deletions src/transformers/models/auto/image_processing_auto.py
@@ -21,7 +21,7 @@

# Build the list of all image processors
from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...image_processing_utils import ImageProcessingMixin
from ...utils import CONFIG_NAME, IMAGE_PROCESSOR_NAME, get_file_from_repo, logging
from .auto_factory import _LazyAutoMapping
@@ -314,7 +314,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
>>> # image_processor = AutoImageProcessor.from_pretrained("./test/saved_model/")
```"""
config = kwargs.pop("config", None)
trust_remote_code = kwargs.pop("trust_remote_code", False)
trust_remote_code = kwargs.pop("trust_remote_code", None)
kwargs["_from_auto"] = True

config_dict, _ = ImageProcessingMixin.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
@@ -351,21 +351,21 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
image_processor_auto_map = config.auto_map["AutoImageProcessor"]

if image_processor_class is not None:
# If we have custom code for a image processor, we get the proper class.
if image_processor_auto_map is not None:
if not trust_remote_code:
raise ValueError(
f"Loading {pretrained_model_name_or_path} requires you to execute the image processor file "
"in that repo on your local machine. Make sure you have read the code there to avoid "
"malicious use, then set the option `trust_remote_code=True` to remove this error."
)
image_processor_class = get_class_from_dynamic_module(
image_processor_auto_map, pretrained_model_name_or_path, **kwargs
)
_ = kwargs.pop("code_revision", None)
else:
image_processor_class = image_processor_class_from_name(image_processor_class)
image_processor_class = image_processor_class_from_name(image_processor_class)

has_remote_code = image_processor_auto_map is not None
has_local_code = image_processor_class is not None or type(config) in IMAGE_PROCESSOR_MAPPING
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
)

if has_remote_code and trust_remote_code:
image_processor_class = get_class_from_dynamic_module(
image_processor_auto_map, pretrained_model_name_or_path, **kwargs
)
_ = kwargs.pop("code_revision", None)
return image_processor_class.from_dict(config_dict, **kwargs)
elif image_processor_class is not None:
return image_processor_class.from_dict(config_dict, **kwargs)
# Last try: we use the IMAGE_PROCESSOR_MAPPING.
elif type(config) in IMAGE_PROCESSOR_MAPPING:
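
The feature extractor and image processor paths follow the same pattern: code counts as local when either the class named in the config resolves locally or the config type is in the corresponding mapping. A brief sketch with a hypothetical repo name:

```python
from transformers import AutoImageProcessor

# Custom image processing code shipped via `auto_map["AutoImageProcessor"]`;
# the explicit flag skips the interactive prompt.
image_processor = AutoImageProcessor.from_pretrained(
    "some-user/custom-vision-model", trust_remote_code=True
)
```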
38 changes: 19 additions & 19 deletions src/transformers/models/auto/processing_auto.py
@@ -20,7 +20,7 @@

# Build the list of all feature extractors
from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...feature_extraction_utils import FeatureExtractionMixin
from ...image_processing_utils import ImageProcessingMixin
from ...tokenization_utils import TOKENIZER_CONFIG_FILE
@@ -194,7 +194,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
>>> # processor = AutoProcessor.from_pretrained("./test/saved_model/")
```"""
config = kwargs.pop("config", None)
trust_remote_code = kwargs.pop("trust_remote_code", False)
trust_remote_code = kwargs.pop("trust_remote_code", None)
kwargs["_from_auto"] = True

processor_class = None
@@ -248,28 +248,28 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
processor_auto_map = config.auto_map["AutoProcessor"]

if processor_class is not None:
# If we have custom code for a feature extractor, we get the proper class.
if processor_auto_map is not None:
if not trust_remote_code:
raise ValueError(
f"Loading {pretrained_model_name_or_path} requires you to execute the feature extractor file "
"in that repo on your local machine. Make sure you have read the code there to avoid "
"malicious use, then set the option `trust_remote_code=True` to remove this error."
)

processor_class = get_class_from_dynamic_module(
processor_auto_map, pretrained_model_name_or_path, **kwargs
)
_ = kwargs.pop("code_revision", None)
else:
processor_class = processor_class_from_name(processor_class)
processor_class = processor_class_from_name(processor_class)

has_remote_code = processor_auto_map is not None
has_local_code = processor_class is not None or type(config) in PROCESSOR_MAPPING
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
)

if has_remote_code and trust_remote_code:
processor_class = get_class_from_dynamic_module(
processor_auto_map, pretrained_model_name_or_path, **kwargs
)
_ = kwargs.pop("code_revision", None)
return processor_class.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
)
elif processor_class is not None:
return processor_class.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
)

# Last try: we use the PROCESSOR_MAPPING.
if type(config) in PROCESSOR_MAPPING:
elif type(config) in PROCESSOR_MAPPING:
return PROCESSOR_MAPPING[type(config)].from_pretrained(pretrained_model_name_or_path, **kwargs)

# At this stage, there doesn't seem to be a `Processor` class available for this model, so let's try a
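
For processors, the resolved flag is forwarded to the underlying `from_pretrained` call, so downstream components see an explicit value rather than `None`. A brief sketch with a hypothetical repo name:

```python
from transformers import AutoProcessor

# The resolved `trust_remote_code` is passed through to the processor class's
# own `from_pretrained`, covering its tokenizer / feature extractor pieces.
processor = AutoProcessor.from_pretrained("some-user/custom-multimodal", trust_remote_code=True)
```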
(Diffs for the remaining 7 changed files are not shown.)
