Skip to content

Commit

Permalink
Extremely experimental fix!
Browse files Browse the repository at this point in the history
  • Loading branch information
Rocketknight1 committed Aug 23, 2024
1 parent 0a7af19 commit e9e154b
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/transformers/models/auto/auto_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ def from_config(cls, config, **kwargs):
model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
if os.path.isdir(config._name_or_path):
model_class.register_for_auto_class(cls.__name__)
cls.register(config.__class__, model_class, exist_ok=True)
else:
cls.register(config.__class__, model_class, exist_ok=True)
_ = kwargs.pop("code_revision", None)
Expand Down Expand Up @@ -554,6 +555,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
_ = hub_kwargs.pop("code_revision", None)
if os.path.isdir(pretrained_model_name_or_path):
model_class.register_for_auto_class(cls.__name__)
cls.register(config.__class__, model_class, exist_ok=True)
else:
cls.register(config.__class__, model_class, exist_ok=True)
return model_class.from_pretrained(
Expand Down

0 comments on commit e9e154b

Please sign in to comment.