Skip to content

Commit

Permalink
always set self.config
Browse files Browse the repository at this point in the history
  • Loading branch information
Wauplin committed Jan 26, 2024
1 parent b5d561e commit 26e1c57
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion src/huggingface_hub/hub_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ class ModelHubMixin:
of mixin integration with the Hub. Check out our [integration guide](../guides/integrations) for more instructions.
"""

config: Optional[Union[dict, "DataclassInstance"]] = None
# ^ optional config attribute automatically set in `from_pretrained` (if not already set by the subclass)

def save_pretrained(
self,
save_directory: Union[str, Path],
Expand Down Expand Up @@ -160,6 +163,7 @@ def from_pretrained(
except HfHubHTTPError as e:
logger.info(f"{CONFIG_NAME} not found on the HuggingFace Hub: {str(e)}")

config = None
if config_file is not None:
# Read config
with open(config_file, "r", encoding="utf-8") as f:
Expand All @@ -184,7 +188,7 @@ def from_pretrained(
# Forward config to model initialization
model_kwargs["config"] = config

return cls._from_pretrained(
instance = cls._from_pretrained(
model_id=str(model_id),
revision=revision,
cache_dir=cache_dir,
Expand All @@ -196,6 +200,13 @@ def from_pretrained(
**model_kwargs,
)

# Implicitly set the config as instance attribute if not already set by the class
# This way `config` will be available when calling `save_pretrained` or `push_to_hub`.
if config is not None and instance.config is None:
instance.config = config

return instance

@classmethod
def _from_pretrained(
cls: Type[T],
Expand Down

0 comments on commit 26e1c57

Please sign in to comment.