Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Include cardData in list_models and list_datasets #639

Merged
merged 9 commits into from
Feb 4, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@ def list_models(
direction: Optional[Literal[-1]] = None,
limit: Optional[int] = None,
full: Optional[bool] = None,
cardData: Optional[bool] = True,
fetch_config: Optional[bool] = None,
) -> List[ModelInfo]:
"""
Expand Down Expand Up @@ -590,6 +591,9 @@ def list_models(
full (:obj:`bool`, `optional`):
Whether to fetch all model data, including the `lastModified`, the `sha`, the files and the `tags`.
This is set to `True` by default when using a filter.
cardData (:obj:`bool`, `optional`):
Whether to grab the metadata for the model as well. Can contain useful information such as carbon emissions,
metrics, and datasets trained on.
fetch_config (:obj:`bool`, `optional`):
Whether to fetch the model configs as well. This is not included in `full` due to its size.

Expand Down Expand Up @@ -619,6 +623,8 @@ def list_models(
del params["full"]
if fetch_config is not None:
params.update({"config": fetch_config})
if cardData is not None:
params.update({"cardData": cardData})
r = requests.get(path, params=params)
r.raise_for_status()
d = r.json()
Expand Down Expand Up @@ -692,6 +698,7 @@ def list_datasets(
sort: Union[Literal["lastModified"], str, None] = None,
direction: Optional[Literal[-1]] = None,
limit: Optional[int] = None,
cardData: Optional[bool] = True,
full: Optional[bool] = None,
) -> List[DatasetInfo]:
"""
Expand Down Expand Up @@ -762,6 +769,8 @@ def list_datasets(
sort by ascending order.
limit (:obj:`int`, `optional`):
The limit on the number of datasets fetched. Leaving this option to `None` fetches all datasets.
cardData (:obj:`bool`, `optional`):
Whether to grab the metadata for the dataset as well. Can contain useful information such as the PapersWithCode ID.
full (:obj:`bool`, `optional`):
Whether to fetch all dataset data, including the `lastModified` and the `cardData`.

Expand All @@ -786,6 +795,8 @@ def list_datasets(
if full is not None:
if full:
params.update({"full": True})
if cardData is not None:
params.update({"cardData": True})
muellerzr marked this conversation as resolved.
Show resolved Hide resolved
r = requests.get(path, params=params)
r.raise_for_status()
d = r.json()
Expand Down
16 changes: 16 additions & 0 deletions tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,14 @@ def test_list_datasets_search(self):
self.assertGreater(len(datasets), 10)
self.assertIsInstance(datasets[0], DatasetInfo)

@with_production_testing
def test_filter_datasets_with_cardData(self):
_api = HfApi()
datasets = _api.list_datasets(cardData=True)
self.assertTrue([hasattr(dataset, "cardData") for dataset in datasets])
datasets = _api.list_datasets(cardData=False)
self.assertTrue([not hasattr(dataset, "cardData") for dataset in datasets])
muellerzr marked this conversation as resolved.
Show resolved Hide resolved

@with_production_testing
def test_dataset_info(self):
_api = HfApi()
Expand Down Expand Up @@ -780,6 +788,14 @@ def test_filter_models_with_complex_query(self):
["pytorch" in model.tags and "tf" in model.tags for model in models]
)

@with_production_testing
def test_filter_models_with_cardData(self):
_api = HfApi()
models = _api.list_models("co2_eq_emissions", cardData=True)
self.assertTrue([hasattr(model, "cardData") for model in models])
models = _api.list_models("co2_eq_emissions", cardData=False)
self.assertTrue([not hasattr(model, "cardData") for model in models])
muellerzr marked this conversation as resolved.
Show resolved Hide resolved
muellerzr marked this conversation as resolved.
Show resolved Hide resolved


class HfApiPrivateTest(HfApiCommonTestWithLogin):
def setUp(self) -> None:
Expand Down