diff --git a/src/transformers/models/albert/configuration_albert.py b/src/transformers/models/albert/configuration_albert.py
index 33f4d6212a97f0..e1e2d4547cc4e2 100644
--- a/src/transformers/models/albert/configuration_albert.py
+++ b/src/transformers/models/albert/configuration_albert.py
@@ -20,10 +20,8 @@
 from ...configuration_utils import PretrainedConfig
 from ...onnx import OnnxConfig
-from ...utils.import_utils import register


-@register()
 class AlbertConfig(PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`AlbertModel`] or a [`TFAlbertModel`]. It is used
@@ -153,7 +151,6 @@ def __init__(


 # Copied from transformers.models.bert.configuration_bert.BertOnnxConfig with Roberta->Albert
-@register()
 class AlbertOnnxConfig(OnnxConfig):
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py
index ac624542870d60..ad54e322188af4 100755
--- a/src/transformers/models/albert/modeling_albert.py
+++ b/src/transformers/models/albert/modeling_albert.py
@@ -43,7 +43,6 @@
     logging,
     replace_return_docstrings,
 )
-from ...utils.import_utils import register
 from .configuration_albert import AlbertConfig


@@ -53,7 +52,6 @@
 _CONFIG_FOR_DOC = "AlbertConfig"


-@register()
 def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
     """Load tf checkpoints in a pytorch model."""
     try:
@@ -489,7 +487,6 @@ def forward(
     )


-@register(backends=("torch",))
 class AlbertPreTrainedModel(PreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -621,7 +618,6 @@ class AlbertForPreTrainingOutput(ModelOutput):
     "The bare ALBERT Model transformer outputting raw hidden-states without any specific head on top.",
     ALBERT_START_DOCSTRING,
 )
-@register(backends=("torch",))
 class AlbertModel(AlbertPreTrainedModel):
     config_class = AlbertConfig
     base_model_prefix = "albert"
@@ -751,7 +747,6 @@ def forward(
     """,
     ALBERT_START_DOCSTRING,
 )
-@register(backends=("torch",))
 class AlbertForPreTraining(AlbertPreTrainedModel):
     _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]

@@ -904,7 +899,6 @@ def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
     "Albert Model with a `language modeling` head on top.",
     ALBERT_START_DOCSTRING,
 )
-@register(backends=("torch",))
 class AlbertForMaskedLM(AlbertPreTrainedModel):
     _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]

@@ -1020,7 +1014,6 @@ def forward(
     """,
     ALBERT_START_DOCSTRING,
 )
-@register(backends=("torch",))
 class AlbertForSequenceClassification(AlbertPreTrainedModel):
     def __init__(self, config: AlbertConfig):
         super().__init__(config)
@@ -1122,7 +1115,6 @@ def forward(
     """,
     ALBERT_START_DOCSTRING,
 )
-@register(backends=("torch",))
 class AlbertForTokenClassification(AlbertPreTrainedModel):
     def __init__(self, config: AlbertConfig):
         super().__init__(config)
@@ -1206,7 +1198,6 @@ def forward(
     """,
     ALBERT_START_DOCSTRING,
 )
-@register(backends=("torch",))
 class AlbertForQuestionAnswering(AlbertPreTrainedModel):
     def __init__(self, config: AlbertConfig):
         super().__init__(config)
@@ -1310,7 +1301,6 @@ def forward(
     """,
     ALBERT_START_DOCSTRING,
 )
-@register(backends=("torch",))
 class AlbertForMultipleChoice(AlbertPreTrainedModel):
     def __init__(self, config: AlbertConfig):
         super().__init__(config)
diff --git a/src/transformers/models/albert/modeling_flax_albert.py b/src/transformers/models/albert/modeling_flax_albert.py
index 1bc5f0f4ff07ca..b5b49219aebf63 100644
--- a/src/transformers/models/albert/modeling_flax_albert.py
+++ b/src/transformers/models/albert/modeling_flax_albert.py
@@ -42,7 +42,6 @@
     overwrite_call_docstring,
 )
 from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
-from ...utils.import_utils import register
 from .configuration_albert import AlbertConfig


@@ -506,7 +505,6 @@ def __call__(self, pooled_output, deterministic=True):
         return logits


-@register(backends=("flax",))
 class FlaxAlbertPreTrainedModel(FlaxPreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -676,7 +674,6 @@ def __call__(
     "The bare Albert Model transformer outputting raw hidden-states without any specific head on top.",
     ALBERT_START_DOCSTRING,
 )
-@register(backends=("flax",))
 class FlaxAlbertModel(FlaxAlbertPreTrainedModel):
     module_class = FlaxAlbertModule

@@ -745,7 +742,6 @@ def __call__(
     """,
     ALBERT_START_DOCSTRING,
 )
-@register(backends=("flax",))
 class FlaxAlbertForPreTraining(FlaxAlbertPreTrainedModel):
     module_class = FlaxAlbertForPreTrainingModule

@@ -829,7 +825,6 @@ def __call__(


 @add_start_docstrings("""Albert Model with a `language modeling` head on top.""", ALBERT_START_DOCSTRING)
-@register(backends=("flax",))
 class FlaxAlbertForMaskedLM(FlaxAlbertPreTrainedModel):
     module_class = FlaxAlbertForMaskedLMModule

@@ -900,7 +895,6 @@ def __call__(
     """,
     ALBERT_START_DOCSTRING,
 )
-@register(backends=("flax",))
 class FlaxAlbertForSequenceClassification(FlaxAlbertPreTrainedModel):
     module_class = FlaxAlbertForSequenceClassificationModule

@@ -974,7 +968,6 @@ def __call__(
     """,
     ALBERT_START_DOCSTRING,
 )
-@register(backends=("flax",))
 class FlaxAlbertForMultipleChoice(FlaxAlbertPreTrainedModel):
     module_class = FlaxAlbertForMultipleChoiceModule

@@ -1048,7 +1041,6 @@ def __call__(
     """,
     ALBERT_START_DOCSTRING,
 )
-@register(backends=("flax",))
 class FlaxAlbertForTokenClassification(FlaxAlbertPreTrainedModel):
     module_class = FlaxAlbertForTokenClassificationModule

@@ -1117,7 +1109,6 @@ def __call__(
     """,
     ALBERT_START_DOCSTRING,
 )
-@register(backends=("flax",))
 class FlaxAlbertForQuestionAnswering(FlaxAlbertPreTrainedModel):
     module_class = FlaxAlbertForQuestionAnsweringModule
diff --git a/src/transformers/models/albert/modeling_tf_albert.py b/src/transformers/models/albert/modeling_tf_albert.py
index c443421b9b9092..24a25658a4d41a 100644
--- a/src/transformers/models/albert/modeling_tf_albert.py
+++ b/src/transformers/models/albert/modeling_tf_albert.py
@@ -56,7 +56,6 @@
     logging,
     replace_return_docstrings,
 )
-from ...utils.import_utils import register
 from .configuration_albert import AlbertConfig


@@ -511,7 +510,6 @@ def build(self, input_shape=None):
                 layer.build(None)


-@register(backends=("tf",))
 class TFAlbertPreTrainedModel(TFPreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -587,7 +585,6 @@ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:


 @keras_serializable
-@register(backends=("tf",))
 class TFAlbertMainLayer(keras.layers.Layer):
     config_class = AlbertConfig

@@ -861,7 +858,6 @@ class TFAlbertForPreTrainingOutput(ModelOutput):
     "The bare Albert Model transformer outputting raw hidden-states without any specific head on top.",
     ALBERT_START_DOCSTRING,
 )
-@register(backends=("tf",))
 class TFAlbertModel(TFAlbertPreTrainedModel):
     def __init__(self, config: AlbertConfig, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -919,7 +915,6 @@ def build(self, input_shape=None):
     """,
     ALBERT_START_DOCSTRING,
 )
-@register(backends=("tf",))
 class TFAlbertForPreTraining(TFAlbertPreTrainedModel, TFAlbertPreTrainingLoss):
     # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
     _keys_to_ignore_on_load_unexpected = [r"predictions.decoder.weight"]
@@ -1051,7 +1046,6 @@ def build(self, input_shape=None):


 @add_start_docstrings("""Albert Model with a `language modeling` head on top.""", ALBERT_START_DOCSTRING)
-@register(backends=("tf",))
 class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss):
     # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
     _keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions.decoder.weight"]
@@ -1165,7 +1159,6 @@ def build(self, input_shape=None):
     """,
     ALBERT_START_DOCSTRING,
 )
-@register(backends=("tf",))
 class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClassificationLoss):
     # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
     _keys_to_ignore_on_load_unexpected = [r"predictions"]
@@ -1260,7 +1253,6 @@ def build(self, input_shape=None):
     """,
     ALBERT_START_DOCSTRING,
 )
-@register(backends=("tf",))
 class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificationLoss):
     # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
     _keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"]
@@ -1356,7 +1348,6 @@ def build(self, input_shape=None):
     """,
     ALBERT_START_DOCSTRING,
 )
-@register(backends=("tf",))
 class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringLoss):
     # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
     _keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"]
@@ -1464,7 +1455,6 @@ def build(self, input_shape=None):
     """,
     ALBERT_START_DOCSTRING,
 )
-@register(backends=("tf",))
 class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
     # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
     _keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"]
diff --git a/src/transformers/models/albert/tokenization_albert.py b/src/transformers/models/albert/tokenization_albert.py
index f8d1a38eaee8df..4971d0511f47bd 100644
--- a/src/transformers/models/albert/tokenization_albert.py
+++ b/src/transformers/models/albert/tokenization_albert.py
@@ -23,7 +23,7 @@
 from ...tokenization_utils import AddedToken, PreTrainedTokenizer
 from ...utils import logging
-from ...utils.import_utils import register
+from ...utils.import_utils import export


 logger = logging.get_logger(__name__)
@@ -33,7 +33,7 @@
 SPIECE_UNDERLINE = "▁"


-@register(backends=("sentencepiece",))
+@export(backends=("sentencepiece",))
 class AlbertTokenizer(PreTrainedTokenizer):
     """
     Construct an ALBERT tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
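As a note on the rename visible above: `export()` (previously `register()`) is deliberately thin. Its full body is only partially visible in the `import_utils.py` hunk further down, so the following is a rough sketch of the mechanism rather than the exact implementation; in particular, the attribute name and the error message are assumptions:

```python
def export(*, backends=()):
    # Sketch: validate the backends tuple and tag the decorated object with it.
    # The literal "@export" string in a module's source is what
    # define_import_structure() scans for when building the import structure.
    if not isinstance(backends, tuple):
        raise ValueError("backends should be a tuple")  # assumed error message

    def inner_fn(fun):
        fun.__backends = backends  # assumed attribute name
        return fun

    return inner_fn


@export(backends=("sentencepiece",))
class DummyTokenizer:
    """Placeholder standing in for e.g. the AlbertTokenizer above."""
```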
diff --git a/src/transformers/models/albert/tokenization_albert_fast.py b/src/transformers/models/albert/tokenization_albert_fast.py
index 264fe4ebdf166e..6e7b110b0afad7 100644
--- a/src/transformers/models/albert/tokenization_albert_fast.py
+++ b/src/transformers/models/albert/tokenization_albert_fast.py
@@ -21,7 +21,6 @@
 from ...tokenization_utils import AddedToken
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
 from ...utils import is_sentencepiece_available, logging
-from ...utils.import_utils import register


 if is_sentencepiece_available():
@@ -36,7 +35,6 @@
 SPIECE_UNDERLINE = "▁"


-@register(backends=("tokenizers",))
 class AlbertTokenizerFast(PreTrainedTokenizerFast):
     """
     Construct a "fast" ALBERT tokenizer (backed by HuggingFace's *tokenizers* library). Based on
diff --git a/src/transformers/models/align/configuration_align.py b/src/transformers/models/align/configuration_align.py
index 55bf04aeb84c8e..99fa81b4a9350d 100644
--- a/src/transformers/models/align/configuration_align.py
+++ b/src/transformers/models/align/configuration_align.py
@@ -17,8 +17,6 @@
 import os
 from typing import TYPE_CHECKING, List, Union

-from ...utils.import_utils import register
-

 if TYPE_CHECKING:
     pass
@@ -30,7 +28,6 @@
 logger = logging.get_logger(__name__)


-@register()
 class AlignTextConfig(PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`AlignTextModel`]. It is used to instantiate a
@@ -155,7 +152,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
         return cls.from_dict(config_dict, **kwargs)


-@register()
 class AlignVisionConfig(PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`AlignVisionModel`]. It is used to instantiate a
@@ -295,7 +291,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
         return cls.from_dict(config_dict, **kwargs)


-@register()
 class AlignConfig(PretrainedConfig):
     r"""
     [`AlignConfig`] is the configuration class to store the configuration of a [`AlignModel`]. It is used to
diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py
index fee92dbfe7fe60..dcaa38be57501e 100644
--- a/src/transformers/models/align/modeling_align.py
+++ b/src/transformers/models/align/modeling_align.py
@@ -38,7 +39,6 @@
     logging,
     replace_return_docstrings,
 )
-from ...utils.import_utils import register
 from .configuration_align import AlignConfig, AlignTextConfig, AlignVisionConfig


@@ -1166,7 +1166,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return pooled_output


-@register(backends=("torch",))
 class AlignPreTrainedModel(PreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -1200,7 +1199,6 @@ def _init_weights(self, module):
     """The text model from ALIGN without any head or projection on top.""",
     ALIGN_START_DOCSTRING,
 )
-@register(backends=("torch",))
 class AlignTextModel(AlignPreTrainedModel):
     config_class = AlignTextConfig
     _no_split_modules = ["AlignTextEmbeddings"]
@@ -1328,7 +1326,6 @@ def forward(
     """The vision model from ALIGN without any head or projection on top.""",
     ALIGN_START_DOCSTRING,
 )
-@register(backends=("torch",))
 class AlignVisionModel(AlignPreTrainedModel):
     config_class = AlignVisionConfig
     main_input_name = "pixel_values"
@@ -1415,7 +1412,6 @@ def forward(


 @add_start_docstrings(ALIGN_START_DOCSTRING)
-@register(backends=("torch",))
 class AlignModel(AlignPreTrainedModel):
     config_class = AlignConfig

diff --git a/src/transformers/models/align/processing_align.py b/src/transformers/models/align/processing_align.py
index 546805b30ce163..923daee965fbf9 100644
--- a/src/transformers/models/align/processing_align.py
+++ b/src/transformers/models/align/processing_align.py
@@ -16,8 +16,6 @@
 Image/Text processor class for ALIGN
 """

-from ...utils.import_utils import register
-

 try:
     from typing import Unpack
@@ -42,7 +40,6 @@ class AlignProcessorKwargs(ProcessingKwargs, total=False):
     }


-@register()
 class AlignProcessor(ProcessorMixin):
     r"""
     Constructs an ALIGN processor which wraps [`EfficientNetImageProcessor`] and
diff --git a/src/transformers/models/altclip/configuration_altclip.py b/src/transformers/models/altclip/configuration_altclip.py
index 836b5707cf19e6..7333fa63a35280 100755
--- a/src/transformers/models/altclip/configuration_altclip.py
+++ b/src/transformers/models/altclip/configuration_altclip.py
@@ -19,13 +19,11 @@
 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
-from ...utils.import_utils import register


 logger = logging.get_logger(__name__)


-@register()
 class AltCLIPTextConfig(PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`AltCLIPTextModel`]. It is used to instantiate a
@@ -144,7 +142,6 @@ def __init__(
         self.project_dim = project_dim


-@register()
 class AltCLIPVisionConfig(PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`AltCLIPModel`]. It is used to instantiate an
@@ -255,7 +252,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
         return cls.from_dict(config_dict, **kwargs)


-@register()
 class AltCLIPConfig(PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`AltCLIPModel`]. It is used to instantiate an
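Note that the `@register(backends=("torch",))` decorators removed from the ALIGN and AltCLIP files above are not being replaced with `@export` one by one: as the `import_utils.py` changes below show, any object listed in a module's `__all__` is exported with the backends inferred from the file name (`modeling_*.py` implies `torch`, `modeling_tf_*.py` implies `tf`, and so on). A hypothetical module-level declaration illustrating that convention — the actual `__all__` contents of these files are not part of this diff:

```python
# src/transformers/models/align/modeling_align.py (illustrative only)
# The "modeling_" file name already implies the "torch" backend, so no
# per-class @export decorator is needed; listing the public objects suffices.

__all__ = [
    "AlignModel",
    "AlignPreTrainedModel",
    "AlignTextModel",
    "AlignVisionModel",
]
```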
diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py
index 5b73d951faa9e9..2fda84fca02661 100755
--- a/src/transformers/models/altclip/modeling_altclip.py
+++ b/src/transformers/models/altclip/modeling_altclip.py
@@ -33,7 +33,6 @@
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
-from ...utils.import_utils import register
 from .configuration_altclip import AltCLIPConfig, AltCLIPTextConfig, AltCLIPVisionConfig


@@ -1022,7 +1021,6 @@ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         return embeddings


-@register(backends=("torch",))
 class AltCLIPPreTrainedModel(PreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -1138,7 +1136,6 @@ def forward(
 )


-@register(backends=("torch",))
 class AltCLIPVisionModel(AltCLIPPreTrainedModel):
     config_class = AltCLIPVisionConfig
     main_input_name = "pixel_values"
@@ -1370,7 +1367,6 @@ def forward(
 )


-@register(backends=("torch",))
 class AltCLIPTextModel(AltCLIPPreTrainedModel):
     config_class = AltCLIPTextConfig

@@ -1463,7 +1459,6 @@ def forward(
 )


-@register(backends=("torch",))
 class AltCLIPModel(AltCLIPPreTrainedModel):
     config_class = AltCLIPConfig

diff --git a/src/transformers/models/altclip/processing_altclip.py b/src/transformers/models/altclip/processing_altclip.py
index e787217dc77c93..5343498842832c 100644
--- a/src/transformers/models/altclip/processing_altclip.py
+++ b/src/transformers/models/altclip/processing_altclip.py
@@ -20,10 +20,8 @@
 from ...processing_utils import ProcessorMixin
 from ...tokenization_utils_base import BatchEncoding
-from ...utils.import_utils import register


-@register()
 class AltCLIPProcessor(ProcessorMixin):
     r"""
     Constructs a AltCLIP processor which wraps a CLIP image processor and a XLM-Roberta tokenizer into a single
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index a1577a622b4680..b3803e59b86c95 100755
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -27,8 +27,8 @@
 from functools import lru_cache
 from itertools import chain
 from types import ModuleType
-from typing import Any, Tuple, Union
+from typing import Any, Tuple, Union, List

 from packaging import version

 from . import logging
@@ -1583,7 +1584,7 @@ def __init__(self, name, module_file, import_structure, module_spec=None, extra_
                     if key not in _import_structure:
                         _import_structure[key] = values
                     else:
-                        _import_structure[key].extend(values)
+                        _import_structure[key].update(values)

                 # Needed for autocompletion in an IDE
                 self.__all__.extend(list(item.keys()) + list(chain(*item.values())))
@@ -1601,7 +1602,7 @@ def __init__(self, name, module_file, import_structure, module_spec=None, extra_
         self._name = name
         self._import_structure = _import_structure

-        # This can be removed once every exportable object has a `register()` export.
+        # This can be removed once every exportable object has an `export()` export.
         if not PER_BACKEND_SPLIT:
             self._modules = set(import_structure.keys())
             self._class_to_module = {}
@@ -1691,12 +1692,12 @@ def direct_transformers_import(path: str, file="__init__.py") -> ModuleType:
     return module


-def register(*, backends=()):
+def export(*, backends=()):
     """
     This method enables two things:
     - Attaching a `__backends` tuple to an object to see what are the necessary backends for it to execute correctly
       without instantiating it
-    - The '@register' string is used to dynamically import objects
+    - The '@export' string is used to dynamically import objects
     """

     if not isinstance(backends, tuple):
@@ -1709,6 +1710,51 @@ def inner_fn(fun):
     return inner_fn


+BASE_FILE_REQUIREMENTS = {
+    lambda e: 'modeling_tf_' in e: ('tf',),
+    lambda e: 'modeling_flax_' in e: ('flax',),
+    lambda e: 'modeling_' in e: ('torch',),
+    lambda e: e.startswith('tokenization_') and e.endswith('_fast'): ('tokenizers',),
+}
+
+
+def fetch__all__(file_content):
+    """
+    Returns the content of the __all__ variable in the file content.
+    Returns an empty list if it is not defined, otherwise returns a list of strings.
+    """
+
+    if '__all__' not in file_content:
+        return []
+
+    lines = file_content.splitlines()
+    for index, line in enumerate(lines):
+        if line.startswith("__all__"):
+            start_index = index
+
+    lines = lines[start_index:]
+
+    if not lines[0].startswith('__all__'):
+        raise ValueError(
+            "fetch__all__ accepts a list of lines, with the first line being the __all__ variable declaration"
+        )
+
+    # __all__ is defined on a single line
+    if lines[0].endswith("]"):
+        return [obj.strip("\"' ") for obj in lines[0].split("=")[1].strip(" []").split(",")]
+
+    # __all__ is defined on multiple lines
+    else:
+        _all = []
+        for __all__line_index in range(1, len(lines)):
+            if lines[__all__line_index].strip() == "]":
+                return _all
+            else:
+                _all.append(lines[__all__line_index].strip("\"', "))
+
+        return _all
+
+
 @lru_cache()
 def define_import_structure(module_path):
     import_structure = {}
@@ -1722,6 +1768,9 @@ def define_import_structure(module_path):
         adjacent_modules = [f for f in os.listdir(directory) if not os.path.isdir(os.path.join(directory, f))]

+        # We're only taking a look at files different from __init__.py
+        # We could theoretically export things directly from the __init__.py
+        # files, but this is not supported at this time.
         if "__init__.py" in adjacent_modules:
             adjacent_modules.remove("__init__.py")

@@ -1737,78 +1786,108 @@ def define_import_structure(module_path):
             previous_line = ""
             previous_index = 0

-            lines = file_content.split("\n")
-            for index, line in enumerate(lines):
-                # This allows registering items with other decorators. We'll take a look
-                # at the line that follows at the same indentation level.
-                if line.startswith((" ", "\t", "@", ")")) and not line.startswith("@register"):
-                    continue
+            # Some files have some requirements by default.
+            # For example, any file named `modeling_tf_xxx.py`
+            # should have TensorFlow as a required backend.
+            base_requirements = ()
+            for string_check, requirements in BASE_FILE_REQUIREMENTS.items():
+                if string_check(module_name):
+                    base_requirements = requirements
+                    break
+
+            # Objects that have a `@export` assigned to them will get exported
+            # with the backends specified in the decorator as well as the file backends.
+            registered_objects = set()
+            if '@export' in file_content:
+                lines = file_content.split("\n")
+                for index, line in enumerate(lines):
+
+                    # This allows exporting items with other decorators. We'll take a look
+                    # at the line that follows at the same indentation level.
+                    if line.startswith((" ", "\t", "@", ")")) and not line.startswith("@export"):
+                        continue
+
+                    # Skipping line enables putting whatever we want between the
+                    # export() call and the actual class/method definition.
+                    # This is what enables having # Copied from statements, docs, etc.
+                    skip_line = False
-                # Skipping line enables putting whatever we want between the
-                # register() call and the actuall class/method definition.
-                # This is what enables having # Copied from statements, docs, etc.
-                skip_line = False
+                    if "@export" in previous_line:
+                        skip_line = False
+
+                        # Backends are defined on the same line as export
+                        if "backends" in previous_line:
+                            backends_string = previous_line.split("backends=")[1].split("(")[1].split(")")[0]
+                            backends = tuple(sorted([b.strip("'\",") for b in backends_string.split(", ") if b]))
+
+                        # Backends are defined in the lines following export, for example such as:
+                        # @export(
+                        #     backends=(
+                        #         "sentencepiece",
+                        #         "torch",
+                        #         "tf",
+                        #     )
+                        # )
+                        #
+                        # or
+                        #
+                        # @export(
+                        #     backends=(
+                        #         "sentencepiece", "tf"
+                        #     )
+                        # )
+                        elif "backends" in lines[previous_index + 1]:
+                            backends = []
+                            for backend_line in lines[previous_index:index]:
+                                if "backends" in backend_line:
+                                    backend_line = backend_line.split("=")[1]
+                                    if '"' in backend_line or "'" in backend_line:
+                                        if ", " in backend_line:
+                                            backends.extend(backend.strip("()\"', ") for backend in backend_line.split(", "))
+                                        else:
+                                            backends.append(backend_line.strip("()\"', "))
+
+                                # If the line is only a ')', then we reached the end of the backends and we break.
+                                if backend_line.strip() == ")":
+                                    break
+                            backends = tuple(backends)
+
+                        # No backends are registered for export
+                        else:
+                            backends = ()
-                if "@register" in previous_line:
-                    skip_line = False
+                        backends = frozenset(backends + base_requirements)
+                        if backends not in module_requirements:
+                            module_requirements[backends] = {}
+                        if module_name not in module_requirements[backends]:
+                            module_requirements[backends][module_name] = set()
-                    # Backends are defined on the same line as register
-                    if "backends" in previous_line:
-                        backends_string = previous_line.split("backends=")[1].split("(")[1].split(")")[0]
-                        backends = tuple(sorted([b.strip("'\",") for b in backends_string.split(", ")]))
-
-                    # Backends are defined in the lines following register, for example such as:
-                    # @register(
-                    #     backends=(
-                    #         "sentencepiece",
-                    #         "torch",
-                    #         "tf",
-                    #     )
-                    # )
-                    #
-                    # or
-                    #
-                    # @register(
-                    #     backends=(
-                    #         "sentencepiece", "tf"
-                    #     )
-                    # )
-                    elif "backends" in lines[previous_index + 1]:
-                        backends = []
-                        for backend_line in lines[previous_index:index]:
-                            if "backends" in backend_line:
-                                backend_line = backend_line.split("=")[1]
-                                if '"' in backend_line or "'" in backend_line:
-                                    if ", " in backend_line:
-                                        backends.extend(backend.strip("()\"', ") for backend in backend_line.split(", "))
-                                    else:
-                                        backends.append(backend_line.strip("()\"', "))
-
-                            # If the line is only a ')', then we reached the end of the backends and we break.
-                            if backend_line.strip() == ")":
-                                break
-                        backends = tuple(backends)
-
-                    # No backends are registered
-                    else:
-                        backends = ()
-
-                    backends = frozenset(backends)
-                    if backends not in module_requirements:
-                        module_requirements[backends] = {}
-                    if module_name not in module_requirements[backends]:
-                        module_requirements[backends][module_name] = []
-
-                    if not line.startswith("class") and not line.startswith("def"):
-                        skip_line = True
-                    else:
-                        start_index = 6 if line.startswith("class") else 4
-                        object_name = line[start_index:].split("(")[0].strip(":")
-                        module_requirements[backends][module_name].append(object_name)
-
-                if not skip_line:
-                    previous_line = line
-                    previous_index = index
+                        if not line.startswith("class") and not line.startswith("def"):
+                            skip_line = True
+                        else:
+                            start_index = 6 if line.startswith("class") else 4
+                            object_name = line[start_index:].split("(")[0].strip(":")
+                            module_requirements[backends][module_name].add(object_name)
+                            registered_objects.add(object_name)
+
+                    if not skip_line:
+                        previous_line = line
+                        previous_index = index
+
+            # All objects that are in __all__ should be exported by default.
+            # These objects are exported with the file backends.
+            if '__all__' in file_content:
+                _all = fetch__all__(file_content)
+
+                backends = frozenset(base_requirements)
+                if backends not in module_requirements:
+                    module_requirements[backends] = {}
+                if module_name not in module_requirements[backends]:
+                    module_requirements[backends][module_name] = set()
+
+                for _all_object in _all:
+                    if _all_object not in registered_objects:
+                        module_requirements[backends][module_name].add(_all_object)

     import_structure = {**module_requirements, **import_structure}
     return import_structure
diff --git a/tests/utils/import_structures/import_structure_raw_register.py b/tests/utils/import_structures/import_structure_raw_register.py
index 0a74438a0173f8..3f838ccdaa6456 100644
--- a/tests/utils/import_structures/import_structure_raw_register.py
+++ b/tests/utils/import_structures/import_structure_raw_register.py
@@ -1,31 +1,31 @@
 # fmt: off
-from transformers.utils.import_utils import register
+from transformers.utils.import_utils import export


-@register()
+@export()
 class A0:
     def __init__(self):
         pass


-@register()
+@export()
 def a0():
     pass


-@register(backends=("torch", "tf"))
+@export(backends=("torch", "tf"))
 class A1:
     def __init__(self):
         pass


-@register(backends=("torch", "tf"))
+@export(backends=("torch", "tf"))
 def a1():
     pass


-@register(
+@export(
     backends=("torch", "tf")
 )
 class A2:
@@ -33,14 +33,14 @@ def __init__(self):
         pass


-@register(
+@export(
     backends=("torch", "tf")
 )
 def a2():
     pass


-@register(
+@export(
     backends=(
         "torch",
         "tf"
@@ -51,7 +51,7 @@ def __init__(self):
     pass


-@register(
+@export(
     backends=(
         "torch",
         "tf"
@@ -59,3 +59,8 @@ def __init__(self):
 )
 def a3():
     pass
+
+@export(backends=())
+class A4:
+    def __init__(self):
+        pass
diff --git a/tests/utils/import_structures/import_structure_register_with_comments.py b/tests/utils/import_structures/import_structure_register_with_comments.py
index 21828d19d79572..e716f0ebca056b 100644
--- a/tests/utils/import_structures/import_structure_register_with_comments.py
+++ b/tests/utils/import_structures/import_structure_register_with_comments.py
@@ -1,55 +1,51 @@
 # fmt: off
-from transformers.utils.import_utils import register
+from transformers.utils.import_utils import export


-@register()
+@export()
 # That's a statement
 class B0:
     def __init__(self):
         pass


-@register()
+@export()
 # That's a statement
 def b0():
     pass


-@register(backends=("torch", "tf"))
+@export(backends=("torch", "tf"))
 # That's a statement
 class B1:
     def __init__(self):
         pass


-@register(backends=("torch", "tf"))
+@export(backends=("torch", "tf"))
 # That's a statement
 def b1():
     pass


-@register(
-    backends=("torch", "tf")
-)
+@export(backends=("torch", "tf"))
 # That's a statement
 class B2:
     def __init__(self):
         pass


-@register(
-    backends=("torch", "tf")
-)
+@export(backends=("torch", "tf"))
 # That's a statement
 def b2():
     pass


-@register(
+@export(
     backends=(
-            "torch",
-            "tf"
+        "torch",
+        "tf"
     )
 )
 # That's a statement
@@ -58,10 +54,10 @@ def __init__(self):
     pass


-@register(
+@export(
     backends=(
-            "torch",
-            "tf"
+        "torch",
+        "tf"
     )
 )
 # That's a statement
diff --git a/tests/utils/import_structures/import_structure_register_with_duplicates.py b/tests/utils/import_structures/import_structure_register_with_duplicates.py
index 13f79d051e1d25..509c8f64052b8c 100644
--- a/tests/utils/import_structures/import_structure_register_with_duplicates.py
+++ b/tests/utils/import_structures/import_structure_register_with_duplicates.py
@@ -1,50 +1,46 @@
 # fmt: off
-from transformers.utils.import_utils import register
+from transformers.utils.import_utils import export


-@register(backends=("torch", "torch"))
+@export(backends=("torch", "torch"))
 class C0:
     def __init__(self):
         pass


-@register(backends=("torch", "torch"))
+@export(backends=("torch", "torch"))
 def c0():
     pass


-@register(backends=("torch", "torch"))
+@export(backends=("torch", "torch"))
 # That's a statement
 class C1:
     def __init__(self):
         pass


-@register(backends=("torch", "torch"))
+@export(backends=("torch", "torch"))
 # That's a statement
 def c1():
     pass


-@register(
-    backends=("torch", "torch")
-)
+@export(backends=("torch", "torch"))
 # That's a statement
 class C2:
     def __init__(self):
         pass


-@register(
-    backends=("torch", "torch")
-)
+@export(backends=("torch", "torch"))
 # That's a statement
 def c2():
     pass


-@register(
+@export(
     backends=(
         "torch",
         "torch"
@@ -56,7 +52,7 @@ def __init__(self):
     pass


-@register(
+@export(
     backends=(
         "torch",
         "torch"
diff --git a/tests/utils/test_import_structure.py b/tests/utils/test_import_structure.py
index b1e100dde62641..ae28f31b2559dc 100644
--- a/tests/utils/test_import_structure.py
+++ b/tests/utils/test_import_structure.py
@@ -40,15 +40,15 @@ def test_definition(self):
         import_structure = define_import_structure(import_structures)
         import_structure_definition = {
             frozenset(()): {
-                "import_structure_raw_register": ["A0", "a0"],
-                "import_structure_register_with_comments": ["B0", "b0"],
+                "import_structure_raw_register": {"A0", "a0", "A4"},
+                "import_structure_register_with_comments": {"B0", "b0"},
             },
             frozenset(("tf", "torch")): {
-                "import_structure_raw_register": ["A1", "a1", "A2", "a2", "A3", "a3"],
-                "import_structure_register_with_comments": ["B1", "b1", "B2", "b2", "B3", "b3"],
+                "import_structure_raw_register": {"A1", "a1", "A2", "a2", "A3", "a3"},
+                "import_structure_register_with_comments": {"B1", "b1", "B2", "b2", "B3", "b3"},
             },
             frozenset(("torch",)): {
-                "import_structure_register_with_duplicates": ["C0", "c0", "C1", "c1", "C2", "c2", "C3", "c3"],
+                "import_structure_register_with_duplicates": {"C0", "c0", "C1", "c1", "C2", "c2", "C3", "c3"},
             },
         }