Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose RepoCard at top level + few qol improvements #1354

Merged
merged 9 commits into from
Feb 22, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/source/package_reference/cards.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,20 @@ get a feel for how you would use these utilities in your own projects.

## Repo Card

The `RepoCard` object is the parent class of [`ModelCard`] and [`DatasetCard`].
The `RepoCard` object is the parent class of [`ModelCard`], [`DatasetCard`] and `SpaceCard`.

[[autodoc]] huggingface_hub.repocard.RepoCard
- __init__
- all

## Card Data

The [`CardData`] object is the parent class of [`ModelCardData`] and [`DatasetCardData`].

[[autodoc]] huggingface_hub.repocard_data.CardData

## Model Cards

### ModelCard

[[autodoc]] ModelCard
Expand Down
4 changes: 4 additions & 0 deletions src/huggingface_hub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@
"repocard": [
"DatasetCard",
"ModelCard",
"RepoCard",
"SpaceCard",
"metadata_eval_result",
"metadata_load",
"metadata_save",
Expand Down Expand Up @@ -432,6 +434,8 @@ def __dir__():
from .repocard import (
DatasetCard, # noqa: F401
ModelCard, # noqa: F401
RepoCard, # noqa: F401
SpaceCard, # noqa: F401
metadata_eval_result, # noqa: F401
metadata_load, # noqa: F401
metadata_save, # noqa: F401
Expand Down
2 changes: 1 addition & 1 deletion src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3455,7 +3455,7 @@ def create_pull_request(

Creating a Pull Request with changes can also be done at once with [`HfApi.create_commit`];

This is a wrapper around [`HfApi.create_discusssion`].
This is a wrapper around [`HfApi.create_discussion`].

Args:
repo_id (`str`):
Expand Down
20 changes: 13 additions & 7 deletions src/huggingface_hub/repocard.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,17 @@ def from_template( # type: ignore # violates Liskov property but easier to use
return super().from_template(card_data, template_path, **template_kwargs)


class SpaceCard:
"""Space card is an alias for [`RepoCard`].

At the moment, it does not implement any specific logic. `SpaceCard` is defined for
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think it makes sense to have Space-specific attributes in RepoCard under this instead? (e.g. like the one that tells us about duplicate)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes!!

Copy link
Contributor Author

@Wauplin Wauplin Feb 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Defintely. Not useless to have a separate class in the end 😄

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consistency purposes. It might get extended in the future."""

card_data_class = CardData
default_template_path = TEMPLATE_MODELCARD_PATH
repo_type = "space"


def _detect_line_ending(content: str) -> Literal["\r", "\n", "\r\n", None]: # noqa: F722
"""Detect the line ending of a string. Used by RepoCard to avoid making huge diff on newlines.

Expand Down Expand Up @@ -787,18 +798,13 @@ def metadata_update(
card.data.eval_results.append(new_result)
else:
# Any metadata that is not a result metric
if (
hasattr(card.data, key)
and getattr(card.data, key) is not None
and not overwrite
and getattr(card.data, key) != value
):
if card.data.get(key) is not None and not overwrite and card.data.get(key) != value:
nateraw marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
f"You passed a new value for the existing meta data field '{key}'."
" Set `overwrite=True` to overwrite existing metadata."
)
else:
setattr(card.data, key, value)
card.data[key] = value

return card.push_to_hub(
repo_id,
Expand Down
29 changes: 29 additions & 0 deletions src/huggingface_hub/repocard_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,15 @@ def is_equal_except_value(self, other: "EvalResult") -> bool:

@dataclass
class CardData:
"""Structure containing metadata from a RepoCard.

[`CardData`] is the parent class of [`ModelCardData`] and [`DatasetCardData`].

Metadata can be exported as a dictionary or YAML. Export can be customized to alter the representation of the data
(example: flatten evaluation results). `CardData` behaves as a dictionary (can get, pop, set values) but do not
inherit from `dict` to allow this export step.
"""

def __init__(self, **kwargs):
self.__dict__.update(kwargs)

Expand Down Expand Up @@ -187,6 +196,26 @@ def to_yaml(self, line_break=None) -> str:
def __repr__(self):
return self.to_yaml()

def get(self, key: str, default: Any = None) -> Any:
"""Get value for a given metadata key."""
return self.__dict__.get(key, default)

def pop(self, key: str, default: Any = None) -> Any:
"""Pop value for a given metadata key."""
return self.__dict__.pop(key, default)

def __getitem__(self, key: str) -> Any:
"""Get value for a given metadata key."""
return self.__dict__[key]

def __setitem__(self, key: str, value: Any) -> None:
"""Set value for a given metadata key."""
self.__dict__[key] = value

def __contains__(self, key: str) -> bool:
"""Check if a given metadata key is set."""
return key in self.__dict__


class ModelCardData(CardData):
"""Model Card Metadata that is used by Hugging Face Hub when included at the top of your README.md
Expand Down
3 changes: 2 additions & 1 deletion tests/test_repocard.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
EvalResult,
ModelCard,
ModelCardData,
RepoCard,
metadata_eval_result,
metadata_load,
metadata_save,
Expand All @@ -37,7 +38,7 @@
from huggingface_hub.constants import REPOCARD_NAME
from huggingface_hub.file_download import hf_hub_download
from huggingface_hub.hf_api import HfApi
from huggingface_hub.repocard import REGEX_YAML_BLOCK, RepoCard
from huggingface_hub.repocard import REGEX_YAML_BLOCK
from huggingface_hub.repocard_data import CardData
from huggingface_hub.repository import Repository
from huggingface_hub.utils import SoftTemporaryDirectory, is_jinja_available, logging
Expand Down
30 changes: 29 additions & 1 deletion tests/test_repocard_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import yaml

from huggingface_hub.repocard_data import (
CardData,
DatasetCardData,
EvalResult,
ModelCardData,
Expand Down Expand Up @@ -37,6 +38,33 @@
"""


class BaseCardDataTest(unittest.TestCase):
def test_metadata_behave_as_dict(self):
metadata = CardData(foo="bar")

# .get and __getitem__
self.assertEqual(metadata.get("foo"), "bar")
self.assertEqual(metadata.get("FOO"), None) # case sensitive
self.assertEqual(metadata["foo"], "bar")
with self.assertRaises(KeyError): # case sensitive
_ = metadata["FOO"]

# __setitem__
metadata["foo"] = "BAR"
self.assertEqual(metadata.get("foo"), "BAR")
self.assertEqual(metadata["foo"], "BAR")

# __contains__
self.assertTrue("foo" in metadata)
self.assertFalse("FOO" in metadata)

# export
self.assertEqual(str(metadata), "foo: BAR")

# .pop
self.assertEqual(metadata.pop("foo"), "BAR")


class ModelCardDataTest(unittest.TestCase):
def test_eval_results_to_model_index(self):
expected_results = yaml.safe_load(DUMMY_METADATA_WITH_MODEL_INDEX)
Expand Down Expand Up @@ -143,7 +171,7 @@ def test_card_data_requires_model_name_for_eval_results(self):
self.assertEqual(model_index[0]["name"], "my-cool-model")
self.assertEqual(model_index[0]["results"][0]["task"]["type"], "image-classification")

def test_abitrary_incoming_card_data(self):
def test_arbitrary_incoming_card_data(self):
data = ModelCardData(
model_name="my-cool-model",
eval_results=[
Expand Down