Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introducing list of tags to Keras model card #806

Merged
merged 17 commits into from
Apr 22, 2022
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 33 additions & 5 deletions src/huggingface_hub/keras_mixin.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import json
import os
import warnings
from pathlib import Path
from shutil import copytree, rmtree
from typing import Any, Dict, Optional, Union

import yaml
from huggingface_hub import ModelHubMixin
from huggingface_hub.file_download import (
is_graphviz_available,
Expand Down Expand Up @@ -100,7 +102,7 @@ def _create_model_card(
model,
repo_dir: Path,
plot_model: Optional[bool] = True,
task_name: Optional[str] = None,
metadata: Optional[dict] = None,
merveenoyan marked this conversation as resolved.
Show resolved Hide resolved
):
"""
Creates a model card for the repository.
Expand All @@ -109,10 +111,10 @@ def _create_model_card(
if plot_model and is_graphviz_available() and is_pydot_available():
_plot_network(model, repo_dir)
readme_path = f"{repo_dir}/README.md"
metadata["library_name"] = "keras"
model_card = "---\n"
if task_name is not None:
model_card += f"tags:\n- {task_name}\n"
model_card += "library_name: keras\n---\n"
model_card += yaml.dump(metadata, default_flow_style=False)
model_card += "---\n"
model_card += "\n## Model description\n\nMore information needed\n"
model_card += "\n## Intended uses & limitations\n\nMore information needed\n"
model_card += "\n## Training and evaluation data\n\nMore information needed\n"
Expand Down Expand Up @@ -149,6 +151,7 @@ def save_pretrained_keras(
config: Optional[Dict[str, Any]] = None,
include_optimizer: Optional[bool] = False,
plot_model: Optional[bool] = True,
tags: Optional[Union[list, str]] = None,
merveenoyan marked this conversation as resolved.
Show resolved Hide resolved
task_name: Optional[str] = None,
**model_save_kwargs,
):
Expand All @@ -170,9 +173,13 @@ def save_pretrained_keras(
plot_model (`bool`, *optional*, defaults to `True`):
Setting this to `True` will plot the model and put it in the model
card. Requires graphviz and pydot to be installed.
tags (`dict`, *optional*):
osanseviero marked this conversation as resolved.
Show resolved Hide resolved
List of tags that are related to model or string of a single tag. See example tags
[here](https://github.com/huggingface/hub-docs/blame/main/modelcard.md).
task_name (`str`, *optional*):
Name of the task the model was trained on. Available tasks
[here](https://github.com/huggingface/hub-docs/blob/main/js/src/lib/interfaces/Types.ts).
This is deprecated in favor of `tags` and will be removed in v0.7.
model_save_kwargs(`dict`, *optional*):
model_save_kwargs will be passed to
[`tf.keras.models.save_model()`](https://www.tensorflow.org/api_docs/python/tf/keras/models/save_model).
Expand All @@ -189,6 +196,12 @@ def save_pretrained_keras(

os.makedirs(save_directory, exist_ok=True)

if task_name:
warnings.warn(
"`task_name` input argument is deprecated and "
"will be removed in v0.7. Pass `tags` instead.",
FutureWarning,
)
# saving config
if config:
if not isinstance(config, dict):
Expand All @@ -199,7 +212,17 @@ def save_pretrained_keras(
with open(path, "w") as f:
json.dump(config, f)

_create_model_card(model, save_directory, plot_model, task_name)
metadata = {}
if isinstance(tags, list):
metadata["tags"] = tags
elif isinstance(tags, str):
metadata["tags"] = [tags]
if task_name is not None and "tags" in metadata:
metadata["tags"].append(task_name)
elif task_name is not None and metadata == {}:
metadata["tags"] = [task_name]

_create_model_card(model, save_directory, plot_model, metadata)
tf.keras.models.save_model(
model, save_directory, include_optimizer=include_optimizer, **model_save_kwargs
)
Expand Down Expand Up @@ -278,6 +301,7 @@ def push_to_hub_keras(
git_email: Optional[str] = None,
config: Optional[dict] = None,
include_optimizer: Optional[bool] = False,
tags: Optional[Union[list, str]] = None,
osanseviero marked this conversation as resolved.
Show resolved Hide resolved
task_name: Optional[str] = None,
plot_model: Optional[bool] = True,
**model_save_kwargs,
Expand Down Expand Up @@ -328,6 +352,9 @@ def push_to_hub_keras(
Configuration object to be saved alongside the model weights.
include_optimizer (`bool`, *optional*, defaults to `False`):
Whether or not to include optimizer during serialization.
tags (`dict`, *optional*):
merveenoyan marked this conversation as resolved.
Show resolved Hide resolved
List of tags that are related to model or string of a single tag. See example tags
[here](https://github.com/huggingface/hub-docs/blame/main/modelcard.md).
task_name (`str`, *optional*):
Name of the task the model was trained on. Available tasks
[here](https://github.com/huggingface/huggingface_hub/blob/main/js/src/lib/interfaces/Types.ts).
Expand Down Expand Up @@ -390,6 +417,7 @@ def push_to_hub_keras(
config=config,
include_optimizer=include_optimizer,
plot_model=plot_model,
tags=tags,
task_name=task_name,
**model_save_kwargs,
)
Expand Down
14 changes: 13 additions & 1 deletion tests/test_keras_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,18 @@ def test_from_pretrained_weights(self):
.item()
)

def test_save_pretrained_task_name_deprecation(self):
merveenoyan marked this conversation as resolved.
Show resolved Hide resolved
REPO_NAME = repo_name("save")
model = self.model_init()
model.build((None, 2))

with pytest.warns(
FutureWarning, match="`task_name` input argument is deprecated*"
):
save_pretrained_keras(
model, f"{WORKING_REPO_DIR}/{REPO_NAME}", task_name="test"
)

def test_rel_path_from_pretrained(self):
model = self.model_init()
model.build((None, 2))
Expand Down Expand Up @@ -300,7 +312,7 @@ def test_abs_path_from_pretrained(self):
f"{WORKING_REPO_DIR}/{REPO_NAME}",
config={"num": 10, "act": "gelu_fast"},
plot_model=True,
task_name=None,
tags=None,
merveenoyan marked this conversation as resolved.
Show resolved Hide resolved
)

new_model = from_pretrained_keras(f"{WORKING_REPO_DIR}/{REPO_NAME}")
Expand Down