diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index fc11433b54..9e44b7defd 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -517,6 +517,7 @@ def list_models( direction: Optional[Literal[-1]] = None, limit: Optional[int] = None, full: Optional[bool] = None, + cardData: Optional[bool] = None, fetch_config: Optional[bool] = None, use_auth_token: Optional[Union[bool, str]] = None, ) -> List[ModelInfo]: @@ -591,6 +592,9 @@ def list_models( full (:obj:`bool`, `optional`): Whether to fetch all model data, including the `lastModified`, the `sha`, the files and the `tags`. This is set to `True` by default when using a filter. + cardData (:obj:`bool`, `optional`): + Whether to grab the metadata for the model as well. Can contain useful information such as carbon emissions, + metrics, and datasets trained on. fetch_config (:obj:`bool`, `optional`): Whether to fetch the model configs as well. This is not included in `full` due to its size. use_auth_token (:obj:`bool` or :obj:`str`, `optional`): @@ -635,7 +639,9 @@ def list_models( del params["full"] if fetch_config is not None: params.update({"config": fetch_config}) - r = requests.get(path, headers=headers, params=params) + if cardData is not None: + params.update({"cardData": cardData}) + r = requests.get(path, params=params, headers=headers) r.raise_for_status() d = r.json() return [ModelInfo(**x) for x in d] @@ -708,6 +714,7 @@ def list_datasets( sort: Union[Literal["lastModified"], str, None] = None, direction: Optional[Literal[-1]] = None, limit: Optional[int] = None, + cardData: Optional[bool] = None, full: Optional[bool] = None, use_auth_token: Optional[str] = None, ) -> List[DatasetInfo]: @@ -779,6 +786,8 @@ def list_datasets( sort by ascending order. limit (:obj:`int`, `optional`): The limit on the number of datasets fetched. Leaving this option to `None` fetches all datasets. + cardData (:obj:`bool`, `optional`): + Whether to grab the metadata for the dataset as well. Can contain useful information such as the PapersWithCode ID. full (:obj:`bool`, `optional`): Whether to fetch all dataset data, including the `lastModified` and the `cardData`. use_auth_token (:obj:`bool` or :obj:`str`, `optional`): @@ -818,7 +827,10 @@ def list_datasets( if full is not None: if full: params.update({"full": True}) - r = requests.get(path, headers=headers, params=params) + if cardData is not None: + if cardData: + params.update({"full": True}) + r = requests.get(path, params=params, headers=headers) r.raise_for_status() d = r.json() return [DatasetInfo(**x) for x in d] diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index a4bbec425d..b7ab619b53 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -672,6 +672,21 @@ def test_list_datasets_search(self): self.assertGreater(len(datasets), 10) self.assertIsInstance(datasets[0], DatasetInfo) + @with_production_testing + def test_filter_datasets_with_cardData(self): + _api = HfApi() + datasets = _api.list_datasets(cardData=True) + self.assertGreater( + sum( + [getattr(dataset, "cardData", None) is not None for dataset in datasets] + ), + 0, + ) + datasets = _api.list_datasets() + self.assertTrue( + all([getattr(dataset, "cardData", None) is None for dataset in datasets]) + ) + @with_production_testing def test_dataset_info(self): _api = HfApi() @@ -780,6 +795,14 @@ def test_filter_models_with_complex_query(self): ["pytorch" in model.tags and "tf" in model.tags for model in models] ) + @with_production_testing + def test_filter_models_with_cardData(self): + _api = HfApi() + models = _api.list_models("co2_eq_emissions", cardData=True) + self.assertTrue([hasattr(model, "cardData") for model in models]) + models = _api.list_models("co2_eq_emissions") + self.assertTrue(all([not hasattr(model, "cardData") for model in models])) + class HfApiPrivateTest(HfApiCommonTestWithLogin): def setUp(self) -> None: