diff --git a/src/huggingface_hub/hub_mixin.py b/src/huggingface_hub/hub_mixin.py index 93944ebcb2..de346a5eec 100644 --- a/src/huggingface_hub/hub_mixin.py +++ b/src/huggingface_hub/hub_mixin.py @@ -33,6 +33,9 @@ class ModelHubMixin: of mixin integration with the Hub. Check out our [integration guide](../guides/integrations) for more instructions. """ + config: Optional[Union[dict, "DataclassInstance"]] = None + # ^ optional config attribute automatically set in `from_pretrained` (if not already set by the subclass) + def save_pretrained( self, save_directory: Union[str, Path], @@ -160,6 +163,7 @@ def from_pretrained( except HfHubHTTPError as e: logger.info(f"{CONFIG_NAME} not found on the HuggingFace Hub: {str(e)}") + config = None if config_file is not None: # Read config with open(config_file, "r", encoding="utf-8") as f: @@ -184,7 +188,7 @@ def from_pretrained( # Forward config to model initialization model_kwargs["config"] = config - return cls._from_pretrained( + instance = cls._from_pretrained( model_id=str(model_id), revision=revision, cache_dir=cache_dir, @@ -196,6 +200,13 @@ def from_pretrained( **model_kwargs, ) + # Implicitly set the config as instance attribute if not already set by the class + # This way `config` will be available when calling `save_pretrained` or `push_to_hub`. + if config is not None and instance.config is None: + instance.config = config + + return instance + @classmethod def _from_pretrained( cls: Type[T],