diff --git a/src/huggingface_hub/hub_mixin.py b/src/huggingface_hub/hub_mixin.py index 28e347d434..abe4ce72e8 100644 --- a/src/huggingface_hub/hub_mixin.py +++ b/src/huggingface_hub/hub_mixin.py @@ -374,6 +374,7 @@ def save_pretrained( config: Optional[Union[dict, "DataclassInstance"]] = None, repo_id: Optional[str] = None, push_to_hub: bool = False, + model_card_kwargs: Optional[Dict[str, Any]] = None, **push_to_hub_kwargs, ) -> Optional[str]: """ @@ -389,7 +390,9 @@ def save_pretrained( repo_id (`str`, *optional*): ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if not provided. - kwargs: + model_card_kwargs (`Dict[str, Any]`, *optional*): + Additional arguments passed to the model card template to customize the model card. + push_to_hub_kwargs: Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method. Returns: `str` or `None`: url of the commit on the Hub if `push_to_hub=True`, `None` otherwise. @@ -418,8 +421,9 @@ def save_pretrained( # save model card model_card_path = save_directory / "README.md" + model_card_kwargs = model_card_kwargs if model_card_kwargs is not None else {} if not model_card_path.exists(): # do not overwrite if already exists - self.generate_model_card().save(save_directory / "README.md") + self.generate_model_card(**model_card_kwargs).save(save_directory / "README.md") # push to the Hub if required if push_to_hub: @@ -428,7 +432,7 @@ def save_pretrained( kwargs["config"] = config if repo_id is None: repo_id = save_directory.name # Defaults to `save_directory` name - return self.push_to_hub(repo_id=repo_id, **kwargs) + return self.push_to_hub(repo_id=repo_id, model_card_kwargs=model_card_kwargs, **kwargs) return None def _save_pretrained(self, save_directory: Path) -> None: @@ -637,6 +641,7 @@ def push_to_hub( allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None, delete_patterns: Optional[Union[List[str], str]] = None, + model_card_kwargs: Optional[Dict[str, Any]] = None, ) -> str: """ Upload model checkpoint to the Hub. @@ -667,6 +672,8 @@ def push_to_hub( If provided, files matching any of the patterns are not pushed. delete_patterns (`List[str]` or `str`, *optional*): If provided, remote files matching any of the patterns will be deleted from the repo. + model_card_kwargs (`Dict[str, Any]`, *optional*): + Additional arguments passed to the model card template to customize the model card. Returns: The url of the commit of your model in the given repository. @@ -677,7 +684,7 @@ def push_to_hub( # Push the files to the repo in a single commit with SoftTemporaryDirectory() as tmp: saved_path = Path(tmp) / repo_id - self.save_pretrained(saved_path, config=config) + self.save_pretrained(saved_path, config=config, model_card_kwargs=model_card_kwargs) return api.upload_folder( repo_id=repo_id, repo_type="model", @@ -696,6 +703,7 @@ def generate_model_card(self, *args, **kwargs) -> ModelCard: template_str=self._hub_mixin_info.model_card_template, repo_url=self._hub_mixin_info.repo_url, docs_url=self._hub_mixin_info.docs_url, + **kwargs, ) return card diff --git a/tests/test_hub_mixin.py b/tests/test_hub_mixin.py index 9b711adb78..4247443e36 100644 --- a/tests/test_hub_mixin.py +++ b/tests/test_hub_mixin.py @@ -293,11 +293,11 @@ def test_save_pretrained_with_push_to_hub(self): # Push to hub with repo_id (config is pushed) mocked_model.save_pretrained(save_directory, push_to_hub=True, repo_id="CustomID") - mocked_model.push_to_hub.assert_called_with(repo_id="CustomID", config=CONFIG_AS_DICT) + mocked_model.push_to_hub.assert_called_with(repo_id="CustomID", config=CONFIG_AS_DICT, model_card_kwargs={}) # Push to hub with default repo_id (based on dir name) mocked_model.save_pretrained(save_directory, push_to_hub=True) - mocked_model.push_to_hub.assert_called_with(repo_id=repo_id, config=CONFIG_AS_DICT) + mocked_model.push_to_hub.assert_called_with(repo_id=repo_id, config=CONFIG_AS_DICT, model_card_kwargs={}) @patch.object(DummyModelNoConfig, "_from_pretrained") def test_from_pretrained_model_id_only(self, from_pretrained_mock: Mock) -> None: diff --git a/tests/test_hub_mixin_pytorch.py b/tests/test_hub_mixin_pytorch.py index 9e208a2a07..816f3e6ffd 100644 --- a/tests/test_hub_mixin_pytorch.py +++ b/tests/test_hub_mixin_pytorch.py @@ -28,6 +28,17 @@ Arxiv ID: 1234.56789 """ +DUMMY_MODEL_CARD_TEMPLATE_WITH_CUSTOM_KWARGS = """ +--- +{{ card_data }} +--- + +This is a dummy model card with kwargs. +Arxiv ID: 1234.56789 + +{{ custom_data }} +""" + if is_torch_available(): import torch import torch.nn as nn @@ -76,11 +87,20 @@ class DummyModelWithConfigAndKwargs(nn.Module, PyTorchModelHubMixin): def __init__(self, num_classes: int = 42, state: str = "layernorm", config: Optional[Dict] = None, **kwargs): super().__init__() + class DummyModelWithModelCardAndCustomKwargs( + nn.Module, + PyTorchModelHubMixin, + model_card_template=DUMMY_MODEL_CARD_TEMPLATE_WITH_CUSTOM_KWARGS, + ): + def __init__(self, linear_layer: int = 4): + super().__init__() + else: DummyModel = None DummyModelWithModelCard = None DummyModelNoConfig = None DummyModelWithConfigAndKwargs = None + DummyModelWithModelCardAndCustomKwargs = None @requires("torch") @@ -130,11 +150,11 @@ def test_save_pretrained_with_push_to_hub(self): # Push to hub with repo_id mocked_model.save_pretrained(save_directory, push_to_hub=True, repo_id="CustomID", config=config) - mocked_model.push_to_hub.assert_called_with(repo_id="CustomID", config=config) + mocked_model.push_to_hub.assert_called_with(repo_id="CustomID", config=config, model_card_kwargs={}) # Push to hub with default repo_id (based on dir name) mocked_model.save_pretrained(save_directory, push_to_hub=True, config=config) - mocked_model.push_to_hub.assert_called_with(repo_id=repo_id, config=config) + mocked_model.push_to_hub.assert_called_with(repo_id=repo_id, config=config, model_card_kwargs={}) @patch.object(DummyModel, "_from_pretrained") def test_from_pretrained_model_id_only(self, from_pretrained_mock: Mock) -> None: @@ -386,3 +406,16 @@ def test_save_pretrained_when_config_and_kwargs_are_passed(self): reloaded = DummyModelWithConfigAndKwargs.from_pretrained(self.cache_dir) assert reloaded._hub_mixin_config == model._hub_mixin_config + + def test_model_card_with_custom_kwargs(self): + model_card_kwargs = {"custom_data": "This is a model custom data: 42."} + + # Test creating model with custom kwargs => custom data is saved in model card + model = DummyModelWithModelCardAndCustomKwargs() + card = model.generate_model_card(**model_card_kwargs) + assert model_card_kwargs["custom_data"] in str(card) + + # Test saving card => model card is saved and restored with custom data + model.save_pretrained(self.cache_dir, model_card_kwargs=model_card_kwargs) + card_reloaded = ModelCard.load(self.cache_dir / "README.md") + assert str(card) == str(card_reloaded)