Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support custom kwargs for model card in save_pretrained #2310

Merged
merged 5 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions src/huggingface_hub/hub_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ def save_pretrained(
config: Optional[Union[dict, "DataclassInstance"]] = None,
repo_id: Optional[str] = None,
push_to_hub: bool = False,
model_card_kwargs: Optional[Dict[str, Any]] = None,
**push_to_hub_kwargs,
) -> Optional[str]:
"""
Expand All @@ -340,7 +341,9 @@ def save_pretrained(
repo_id (`str`, *optional*):
ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
not provided.
kwargs:
model_card_kwargs (`Dict[str, Any]`, *optional*):
Additional arguments passed to the model card template to customize the model card.
push_to_hub_kwargs:
Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.
Returns:
`str` or `None`: url of the commit on the Hub if `push_to_hub=True`, `None` otherwise.
Expand Down Expand Up @@ -369,8 +372,9 @@ def save_pretrained(

# save model card
model_card_path = save_directory / "README.md"
model_card_kwargs = model_card_kwargs if model_card_kwargs is not None else {}
if not model_card_path.exists(): # do not overwrite if already exists
self.generate_model_card().save(save_directory / "README.md")
self.generate_model_card(**model_card_kwargs).save(save_directory / "README.md")

# push to the Hub if required
if push_to_hub:
Expand All @@ -379,7 +383,7 @@ def save_pretrained(
kwargs["config"] = config
if repo_id is None:
repo_id = save_directory.name # Defaults to `save_directory` name
return self.push_to_hub(repo_id=repo_id, **kwargs)
return self.push_to_hub(repo_id=repo_id, model_card_kwargs=model_card_kwargs, **kwargs)
return None

def _save_pretrained(self, save_directory: Path) -> None:
Expand Down Expand Up @@ -588,6 +592,7 @@ def push_to_hub(
allow_patterns: Optional[Union[List[str], str]] = None,
ignore_patterns: Optional[Union[List[str], str]] = None,
delete_patterns: Optional[Union[List[str], str]] = None,
model_card_kwargs: Optional[Dict[str, Any]] = None,
) -> str:
"""
Upload model checkpoint to the Hub.
Expand Down Expand Up @@ -618,6 +623,8 @@ def push_to_hub(
If provided, files matching any of the patterns are not pushed.
delete_patterns (`List[str]` or `str`, *optional*):
If provided, remote files matching any of the patterns will be deleted from the repo.
model_card_kwargs (`Dict[str, Any]`, *optional*):
Additional arguments passed to the model card template to customize the model card.

Returns:
The url of the commit of your model in the given repository.
Expand All @@ -628,7 +635,7 @@ def push_to_hub(
# Push the files to the repo in a single commit
with SoftTemporaryDirectory() as tmp:
saved_path = Path(tmp) / repo_id
self.save_pretrained(saved_path, config=config)
self.save_pretrained(saved_path, config=config, model_card_kwargs=model_card_kwargs)
return api.upload_folder(
repo_id=repo_id,
repo_type="model",
Expand All @@ -647,6 +654,7 @@ def generate_model_card(self, *args, **kwargs) -> ModelCard:
template_str=self._hub_mixin_info.model_card_template,
repo_url=self._hub_mixin_info.repo_url,
docs_url=self._hub_mixin_info.docs_url,
**kwargs,
)
return card

Expand Down
4 changes: 2 additions & 2 deletions tests/test_hub_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,11 +274,11 @@ def test_save_pretrained_with_push_to_hub(self):

# Push to hub with repo_id (config is pushed)
mocked_model.save_pretrained(save_directory, push_to_hub=True, repo_id="CustomID")
mocked_model.push_to_hub.assert_called_with(repo_id="CustomID", config=CONFIG_AS_DICT)
mocked_model.push_to_hub.assert_called_with(repo_id="CustomID", config=CONFIG_AS_DICT, model_card_kwargs={})

# Push to hub with default repo_id (based on dir name)
mocked_model.save_pretrained(save_directory, push_to_hub=True)
mocked_model.push_to_hub.assert_called_with(repo_id=repo_id, config=CONFIG_AS_DICT)
mocked_model.push_to_hub.assert_called_with(repo_id=repo_id, config=CONFIG_AS_DICT, model_card_kwargs={})

@patch.object(DummyModelNoConfig, "_from_pretrained")
def test_from_pretrained_model_id_only(self, from_pretrained_mock: Mock) -> None:
Expand Down
37 changes: 35 additions & 2 deletions tests/test_hub_mixin_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,17 @@
Arxiv ID: 1234.56789
"""

DUMMY_MODEL_CARD_TEMPLATE_WITH_CUSTOM_KWARGS = """
---
{{ card_data }}
---

This is a dummy model card with kwargs.
Arxiv ID: 1234.56789

{{ custom_data }}
"""

if is_torch_available():
import torch
import torch.nn as nn
Expand Down Expand Up @@ -76,11 +87,20 @@ class DummyModelWithConfigAndKwargs(nn.Module, PyTorchModelHubMixin):
def __init__(self, num_classes: int = 42, state: str = "layernorm", config: Optional[Dict] = None, **kwargs):
super().__init__()

class DummyModelWithModelCardAndCustomKwargs(
nn.Module,
PyTorchModelHubMixin,
model_card_template=DUMMY_MODEL_CARD_TEMPLATE_WITH_CUSTOM_KWARGS,
):
def __init__(self, linear_layer: int = 4):
super().__init__()

else:
DummyModel = None
DummyModelWithModelCard = None
DummyModelNoConfig = None
DummyModelWithConfigAndKwargs = None
DummyModelWithModelCardAndCustomKwargs = None


@requires("torch")
Expand Down Expand Up @@ -130,11 +150,11 @@ def test_save_pretrained_with_push_to_hub(self):

# Push to hub with repo_id
mocked_model.save_pretrained(save_directory, push_to_hub=True, repo_id="CustomID", config=config)
mocked_model.push_to_hub.assert_called_with(repo_id="CustomID", config=config)
mocked_model.push_to_hub.assert_called_with(repo_id="CustomID", config=config, model_card_kwargs={})

# Push to hub with default repo_id (based on dir name)
mocked_model.save_pretrained(save_directory, push_to_hub=True, config=config)
mocked_model.push_to_hub.assert_called_with(repo_id=repo_id, config=config)
mocked_model.push_to_hub.assert_called_with(repo_id=repo_id, config=config, model_card_kwargs={})

@patch.object(DummyModel, "_from_pretrained")
def test_from_pretrained_model_id_only(self, from_pretrained_mock: Mock) -> None:
Expand Down Expand Up @@ -386,3 +406,16 @@ def test_save_pretrained_when_config_and_kwargs_are_passed(self):

reloaded = DummyModelWithConfigAndKwargs.from_pretrained(self.cache_dir)
assert reloaded._hub_mixin_config == model._hub_mixin_config

def test_model_card_with_custom_kwargs(self):
model_card_kwargs = {"custom_data": "This is a model custom data: 42."}

# Test creating model with custom kwargs => custom data is saved in model card
model = DummyModelWithModelCardAndCustomKwargs()
card = model.generate_model_card(**model_card_kwargs)
assert model_card_kwargs["custom_data"] in str(card)

# Test saving card => model card is saved and restored with custom data
model.save_pretrained(self.cache_dir, model_card_kwargs=model_card_kwargs)
card_reloaded = ModelCard.load(self.cache_dir / "README.md")
assert str(card) == str(card_reloaded)
Loading