Skip to content

Commit

Permalink
Add SpaceCardData with attributes + docstring + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Wauplin committed Feb 21, 2023
1 parent b7a6027 commit d5e71bb
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 7 deletions.
10 changes: 10 additions & 0 deletions docs/source/package_reference/cards.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ Dataset cards are also known as Data Cards in the ML Community.

[[autodoc]] DatasetCardData

## Space Cards

### SpaceCard

[[autodoc]] SpaceCard

### SpaceCardData

[[autodoc]] SpaceCardData

## Utilities

### EvalResult
Expand Down
2 changes: 2 additions & 0 deletions src/huggingface_hub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@
"DatasetCardData",
"EvalResult",
"ModelCardData",
"SpaceCardData",
],
"repository": [
"Repository",
Expand Down Expand Up @@ -446,6 +447,7 @@ def __dir__():
DatasetCardData, # noqa: F401
EvalResult, # noqa: F401
ModelCardData, # noqa: F401
SpaceCardData, # noqa: F401
)
from .repository import Repository # noqa: F401
from .utils import (
Expand Down
10 changes: 3 additions & 7 deletions src/huggingface_hub/repocard.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
DatasetCardData,
EvalResult,
ModelCardData,
SpaceCardData,
eval_results_to_model_index,
model_index_to_eval_results,
)
Expand Down Expand Up @@ -463,13 +464,8 @@ def from_template( # type: ignore # violates Liskov property but easier to use
return super().from_template(card_data, template_path, **template_kwargs)


class SpaceCard:
"""Space card is an alias for [`RepoCard`].
At the moment, it does not implement any specific logic. `SpaceCard` is defined for
consistency purposes. It might get extended in the future."""

card_data_class = CardData
class SpaceCard(RepoCard):
card_data_class = SpaceCardData
default_template_path = TEMPLATE_MODELCARD_PATH
repo_type = "space"

Expand Down
77 changes: 77 additions & 0 deletions src/huggingface_hub/repocard_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,83 @@ def _to_dict(self, data_dict):
data_dict["train-eval-index"] = data_dict.pop("train_eval_index")


class SpaceCardData(CardData):
    """Space Card Metadata that is used by Hugging Face Hub when included at the top of your README.md

    To get an exhaustive reference of Spaces configuration, please visit
    https://huggingface.co/docs/hub/spaces-config-reference#spaces-configuration-reference.

    All arguments are keyword-only. Each named argument is stored as an attribute of
    the same name (`None` when not provided).

    Args:
        title (`str`, *optional*)
            Title of the Space.
        sdk (`str`, *optional*)
            SDK of the Space (one of `gradio`, `streamlit`, `docker`, or `static`).
        sdk_version (`str`, *optional*)
            Version of the used SDK (if Gradio/Streamlit sdk).
        python_version (`str`, *optional*)
            Python version used in the Space (if Gradio/Streamlit sdk).
        app_file (`str`, *optional*)
            Path to your main application file (which contains either gradio or streamlit Python code,
            or static html code). Path is relative to the root of the repository.
        app_port (`int`, *optional*)
            Port on which your application is running. Used only if sdk is `docker`.
        license (`str`, *optional*)
            License of this Space. Example: apache-2.0 or any license from
            https://huggingface.co/docs/hub/repositories-licenses.
        duplicated_from (`str`, *optional*)
            ID of the original Space if this is a duplicated Space.
        models (`List[str]`, *optional*)
            List of models related to this Space. Each should be a model ID found on https://hf.co/models.
        datasets (`List[str]`, *optional*)
            List of datasets related to this Space. Each should be a dataset ID found on
            https://hf.co/datasets.
        tags (`List[str]`, *optional*)
            List of tags to add to your Space that can be used when filtering on the Hub.
        kwargs (`dict`, *optional*)
            Additional metadata that will be added to the space card.

    Example:
        ```python
        >>> from huggingface_hub import SpaceCardData
        >>> card_data = SpaceCardData(
        ...     title="Dreambooth Training",
        ...     license="mit",
        ...     sdk="gradio",
        ...     duplicated_from="multimodalart/dreambooth-training"
        ... )
        >>> card_data.to_dict()
        {'title': 'Dreambooth Training', 'sdk': 'gradio', 'license': 'mit', 'duplicated_from': 'multimodalart/dreambooth-training'}
        ```
    """

    def __init__(
        self,
        *,
        title: Optional[str] = None,
        sdk: Optional[str] = None,
        sdk_version: Optional[str] = None,
        python_version: Optional[str] = None,
        app_file: Optional[str] = None,
        app_port: Optional[int] = None,
        license: Optional[str] = None,
        duplicated_from: Optional[str] = None,
        models: Optional[List[str]] = None,
        datasets: Optional[List[str]] = None,
        tags: Optional[List[str]] = None,
        **kwargs,
    ):
        # Known Space configuration fields are always set as attributes, even when
        # left at None, so they can be read without an AttributeError.
        self.title = title
        self.sdk = sdk
        self.sdk_version = sdk_version
        self.python_version = python_version
        self.app_file = app_file
        self.app_port = app_port
        self.license = license
        self.duplicated_from = duplicated_from
        self.models = models
        self.datasets = datasets
        self.tags = tags
        # Any extra, unrecognized metadata is forwarded to the CardData base class.
        super().__init__(**kwargs)


def model_index_to_eval_results(model_index: List[Dict[str, Any]]) -> Tuple[str, List[EvalResult]]:
"""Takes in a model index and returns the model name and a list of `huggingface_hub.EvalResult` objects.
Expand Down
13 changes: 13 additions & 0 deletions tests/test_repocard.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
ModelCard,
ModelCardData,
RepoCard,
SpaceCard,
SpaceCardData,
metadata_eval_result,
metadata_load,
metadata_save,
Expand All @@ -54,6 +56,7 @@
repo_name,
retry_endpoint,
rmtree_with_retry,
with_production_testing,
)


Expand Down Expand Up @@ -946,3 +949,13 @@ def test_dataset_card_from_custom_template(self):

# some_data is at the bottom of the template, so should end with whatever we passed to it
self.assertTrue(card.text.strip().endswith("asdf"))


@with_production_testing
class SpaceCardTest(TestCaseWithCapLog):
    """Checks that a Space card fetched by repo id is parsed into the Space-specific types."""

    def test_load_spacecard_from_hub(self) -> None:
        """Load the `multimodalart/dreambooth-training` card and verify its type and metadata."""
        space_card = SpaceCard.load("multimodalart/dreambooth-training")

        # The loader must return the Space-specific card class...
        self.assertIsInstance(space_card, SpaceCard)

        # ...whose metadata is parsed into the Space-specific data class.
        metadata = space_card.data
        self.assertIsInstance(metadata, SpaceCardData)
        self.assertEqual(metadata.title, "Dreambooth Training")

        # `app_port` is expected to be absent from this card's metadata and default to None.
        self.assertIsNone(metadata.app_port)
21 changes: 21 additions & 0 deletions tests/test_repocard_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
import yaml

from huggingface_hub import SpaceCardData
from huggingface_hub.repocard_data import (
CardData,
DatasetCardData,
Expand Down Expand Up @@ -223,3 +224,23 @@ def test_train_eval_index_keys_updated(self):
self.assertTrue(card_data.to_dict().get("train_eval_index") is None)
# And train-eval-index should be in the dict
self.assertEqual(card_data.to_dict()["train-eval-index"], train_eval_index)


class SpaceCardDataTest(unittest.TestCase):
    """Unit tests for the `SpaceCardData` metadata container."""

    def test_space_card_data(self) -> None:
        """Only explicitly provided fields are serialized; unset known fields stay `None`."""
        expected_dict = {
            "title": "Dreambooth Training",
            "sdk": "gradio",
            "license": "mit",
            "duplicated_from": "multimodalart/dreambooth-training",
        }

        space_data = SpaceCardData(
            title="Dreambooth Training",
            license="mit",
            sdk="gradio",
            duplicated_from="multimodalart/dreambooth-training",
        )

        self.assertEqual(space_data.to_dict(), expected_dict)
        # SpaceCardData has some default attributes: `tags` exists even though it was not passed.
        self.assertIsNone(space_data.tags)

0 comments on commit d5e71bb

Please sign in to comment.