Skip to content

Commit

Permalink
Support custom kwargs for model card in save_pretrained (#2310)
Browse files Browse the repository at this point in the history
* Support custom kwargs for model card in save_pretrained

* Fix failing test

* Fix test for pytorch mixin

* Add test for model_card_kwargs

* Fix style
  • Loading branch information
qubvel authored Jun 4, 2024
1 parent 919ce7d commit 54515da
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 8 deletions.
16 changes: 12 additions & 4 deletions src/huggingface_hub/hub_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ def save_pretrained(
config: Optional[Union[dict, "DataclassInstance"]] = None,
repo_id: Optional[str] = None,
push_to_hub: bool = False,
model_card_kwargs: Optional[Dict[str, Any]] = None,
**push_to_hub_kwargs,
) -> Optional[str]:
"""
Expand All @@ -340,7 +341,9 @@ def save_pretrained(
repo_id (`str`, *optional*):
ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
not provided.
kwargs:
model_card_kwargs (`Dict[str, Any]`, *optional*):
Additional arguments passed to the model card template to customize the model card.
push_to_hub_kwargs:
Additional keyword arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.
Returns:
`str` or `None`: url of the commit on the Hub if `push_to_hub=True`, `None` otherwise.
Expand Down Expand Up @@ -369,8 +372,9 @@ def save_pretrained(

# save model card
model_card_path = save_directory / "README.md"
model_card_kwargs = model_card_kwargs if model_card_kwargs is not None else {}
if not model_card_path.exists(): # do not overwrite if already exists
self.generate_model_card().save(save_directory / "README.md")
self.generate_model_card(**model_card_kwargs).save(save_directory / "README.md")

# push to the Hub if required
if push_to_hub:
Expand All @@ -379,7 +383,7 @@ def save_pretrained(
kwargs["config"] = config
if repo_id is None:
repo_id = save_directory.name # Defaults to `save_directory` name
return self.push_to_hub(repo_id=repo_id, **kwargs)
return self.push_to_hub(repo_id=repo_id, model_card_kwargs=model_card_kwargs, **kwargs)
return None

def _save_pretrained(self, save_directory: Path) -> None:
Expand Down Expand Up @@ -588,6 +592,7 @@ def push_to_hub(
allow_patterns: Optional[Union[List[str], str]] = None,
ignore_patterns: Optional[Union[List[str], str]] = None,
delete_patterns: Optional[Union[List[str], str]] = None,
model_card_kwargs: Optional[Dict[str, Any]] = None,
) -> str:
"""
Upload model checkpoint to the Hub.
Expand Down Expand Up @@ -618,6 +623,8 @@ def push_to_hub(
If provided, files matching any of the patterns are not pushed.
delete_patterns (`List[str]` or `str`, *optional*):
If provided, remote files matching any of the patterns will be deleted from the repo.
model_card_kwargs (`Dict[str, Any]`, *optional*):
Additional arguments passed to the model card template to customize the model card.
Returns:
The url of the commit of your model in the given repository.
Expand All @@ -628,7 +635,7 @@ def push_to_hub(
# Push the files to the repo in a single commit
with SoftTemporaryDirectory() as tmp:
saved_path = Path(tmp) / repo_id
self.save_pretrained(saved_path, config=config)
self.save_pretrained(saved_path, config=config, model_card_kwargs=model_card_kwargs)
return api.upload_folder(
repo_id=repo_id,
repo_type="model",
Expand All @@ -647,6 +654,7 @@ def generate_model_card(self, *args, **kwargs) -> ModelCard:
template_str=self._hub_mixin_info.model_card_template,
repo_url=self._hub_mixin_info.repo_url,
docs_url=self._hub_mixin_info.docs_url,
**kwargs,
)
return card

Expand Down
4 changes: 2 additions & 2 deletions tests/test_hub_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,11 +274,11 @@ def test_save_pretrained_with_push_to_hub(self):

# Push to hub with repo_id (config is pushed)
mocked_model.save_pretrained(save_directory, push_to_hub=True, repo_id="CustomID")
mocked_model.push_to_hub.assert_called_with(repo_id="CustomID", config=CONFIG_AS_DICT)
mocked_model.push_to_hub.assert_called_with(repo_id="CustomID", config=CONFIG_AS_DICT, model_card_kwargs={})

# Push to hub with default repo_id (based on dir name)
mocked_model.save_pretrained(save_directory, push_to_hub=True)
mocked_model.push_to_hub.assert_called_with(repo_id=repo_id, config=CONFIG_AS_DICT)
mocked_model.push_to_hub.assert_called_with(repo_id=repo_id, config=CONFIG_AS_DICT, model_card_kwargs={})

@patch.object(DummyModelNoConfig, "_from_pretrained")
def test_from_pretrained_model_id_only(self, from_pretrained_mock: Mock) -> None:
Expand Down
37 changes: 35 additions & 2 deletions tests/test_hub_mixin_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,17 @@
Arxiv ID: 1234.56789
"""

DUMMY_MODEL_CARD_TEMPLATE_WITH_CUSTOM_KWARGS = """
---
{{ card_data }}
---
This is a dummy model card with kwargs.
Arxiv ID: 1234.56789
{{ custom_data }}
"""

if is_torch_available():
import torch
import torch.nn as nn
Expand Down Expand Up @@ -76,11 +87,20 @@ class DummyModelWithConfigAndKwargs(nn.Module, PyTorchModelHubMixin):
def __init__(self, num_classes: int = 42, state: str = "layernorm", config: Optional[Dict] = None, **kwargs):
super().__init__()

class DummyModelWithModelCardAndCustomKwargs(
nn.Module,
PyTorchModelHubMixin,
model_card_template=DUMMY_MODEL_CARD_TEMPLATE_WITH_CUSTOM_KWARGS,
):
def __init__(self, linear_layer: int = 4):
super().__init__()

else:
DummyModel = None
DummyModelWithModelCard = None
DummyModelNoConfig = None
DummyModelWithConfigAndKwargs = None
DummyModelWithModelCardAndCustomKwargs = None


@requires("torch")
Expand Down Expand Up @@ -130,11 +150,11 @@ def test_save_pretrained_with_push_to_hub(self):

# Push to hub with repo_id
mocked_model.save_pretrained(save_directory, push_to_hub=True, repo_id="CustomID", config=config)
mocked_model.push_to_hub.assert_called_with(repo_id="CustomID", config=config)
mocked_model.push_to_hub.assert_called_with(repo_id="CustomID", config=config, model_card_kwargs={})

# Push to hub with default repo_id (based on dir name)
mocked_model.save_pretrained(save_directory, push_to_hub=True, config=config)
mocked_model.push_to_hub.assert_called_with(repo_id=repo_id, config=config)
mocked_model.push_to_hub.assert_called_with(repo_id=repo_id, config=config, model_card_kwargs={})

@patch.object(DummyModel, "_from_pretrained")
def test_from_pretrained_model_id_only(self, from_pretrained_mock: Mock) -> None:
Expand Down Expand Up @@ -386,3 +406,16 @@ def test_save_pretrained_when_config_and_kwargs_are_passed(self):

reloaded = DummyModelWithConfigAndKwargs.from_pretrained(self.cache_dir)
assert reloaded._hub_mixin_config == model._hub_mixin_config

def test_model_card_with_custom_kwargs(self):
model_card_kwargs = {"custom_data": "This is a model custom data: 42."}

# Test generating the model card with custom kwargs => custom data is rendered in the card
model = DummyModelWithModelCardAndCustomKwargs()
card = model.generate_model_card(**model_card_kwargs)
assert model_card_kwargs["custom_data"] in str(card)

# Test saving card => model card is saved and restored with custom data
model.save_pretrained(self.cache_dir, model_card_kwargs=model_card_kwargs)
card_reloaded = ModelCard.load(self.cache_dir / "README.md")
assert str(card) == str(card_reloaded)

0 comments on commit 54515da

Please sign in to comment.