diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index 0789b6428e4ab1..220ae97f5073c6 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -17,7 +17,6 @@ import copy import importlib import json -import os import warnings from collections import OrderedDict @@ -427,11 +426,7 @@ def from_config(cls, config, **kwargs): else: repo_id = config.name_or_path model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs) - if os.path.isdir(config._name_or_path): - model_class.register_for_auto_class(cls.__name__) - cls.register(config.__class__, model_class, exist_ok=True) - else: - cls.register(config.__class__, model_class, exist_ok=True) + cls.register(config.__class__, model_class, exist_ok=True) _ = kwargs.pop("code_revision", None) return model_class._from_config(config, **kwargs) elif type(config) in cls._model_mapping.keys(): @@ -553,11 +548,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): class_ref, pretrained_model_name_or_path, code_revision=code_revision, **hub_kwargs, **kwargs ) _ = hub_kwargs.pop("code_revision", None) - if os.path.isdir(pretrained_model_name_or_path): - model_class.register_for_auto_class(cls.__name__) - cls.register(config.__class__, model_class, exist_ok=True) - else: - cls.register(config.__class__, model_class, exist_ok=True) + cls.register(config.__class__, model_class, exist_ok=True) return model_class.from_pretrained( pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs )