From bf81df939e2d51f2388e43302fc922295122137b Mon Sep 17 00:00:00 2001 From: muellerzr Date: Tue, 1 Feb 2022 19:54:32 +0000 Subject: [PATCH 1/8] Include cardData --- src/huggingface_hub/hf_api.py | 11 +++++++++++ tests/test_hf_api.py | 20 ++++++++++++++++++-- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index a2054a7f01..a230418d64 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -517,6 +517,7 @@ def list_models( direction: Optional[Literal[-1]] = None, limit: Optional[int] = None, full: Optional[bool] = None, + cardData: Optional[bool] = True, fetch_config: Optional[bool] = None, ) -> List[ModelInfo]: """ @@ -590,6 +591,9 @@ def list_models( full (:obj:`bool`, `optional`): Whether to fetch all model data, including the `lastModified`, the `sha`, the files and the `tags`. This is set to `True` by default when using a filter. + cardData (:obj:`bool`, `optional`): + Whether to grab the metadata for the model as well. Can contain useful information such as carbon emissions, + metrics, and datasets trained on. fetch_config (:obj:`bool`, `optional`): Whether to fetch the model configs as well. This is not included in `full` due to its size. @@ -619,6 +623,8 @@ def list_models( del params["full"] if fetch_config is not None: params.update({"config": fetch_config}) + if cardData is not None: + params.update({"cardData": cardData}) r = requests.get(path, params=params) r.raise_for_status() d = r.json() @@ -692,6 +698,7 @@ def list_datasets( sort: Union[Literal["lastModified"], str, None] = None, direction: Optional[Literal[-1]] = None, limit: Optional[int] = None, + cardData: Optional[bool] = True, full: Optional[bool] = None, ) -> List[DatasetInfo]: """ @@ -762,6 +769,8 @@ def list_datasets( sort by ascending order. limit (:obj:`int`, `optional`): The limit on the number of datasets fetched. Leaving this option to `None` fetches all datasets. + cardData (:obj:`bool`, `optional`): + Whether to grab the metadata for the dataset as well. Can contain useful information such as the PapersWithCode ID. full (:obj:`bool`, `optional`): Whether to fetch all dataset data, including the `lastModified` and the `cardData`. @@ -786,6 +795,8 @@ def list_datasets( if full is not None: if full: params.update({"full": True}) + if cardData is not None: + params.update({"cardData": True}) r = requests.get(path, params=params) r.raise_for_status() d = r.json() diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index b2ac9fd85b..14c9c49e45 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -672,6 +672,14 @@ def test_list_datasets_search(self): self.assertGreater(len(datasets), 10) self.assertIsInstance(datasets[0], DatasetInfo) + @with_production_testing + def test_filter_datasets_with_cardData(self): + _api = HfApi() + datasets = _api.list_datasets(cardData=True) + self.assertTrue(["cardData" in dataset for dataset in datasets]) + datasets = _api.list_datasets(cardData=False) + self.assertTrue(["cardData" not in dataset for dataset in datasets]) + @with_production_testing def test_dataset_info(self): _api = HfApi() @@ -780,6 +788,14 @@ def test_filter_models_with_complex_query(self): ["pytorch" in model.tags and "tf" in model.tags for model in models] ) + @with_production_testing + def test_filter_models_with_cardData(self): + _api = HfApi() + models = _api.list_models(search="albert", cardData=True) + self.assertTrue(["cardData" in model for model in models]) + models = _api.list_models(search="albert", cardData=False) + self.assertTrue(["cardData" not in model for model in models]) + class HfApiPrivateTest(HfApiCommonTestWithLogin): def setUp(self) -> None: @@ -858,7 +874,7 @@ def test_end_to_end_thresh_6M(self): REMOTE_URL = self._api.create_repo( name=self.REPO_NAME_LARGE_FILE, token=self._token, - lfsmultipartthresh=6 * 10 ** 6, + lfsmultipartthresh=6 * 10**6, ) self.setup_local_clone(REMOTE_URL) @@ -911,7 +927,7 @@ def test_end_to_end_thresh_16M(self): REMOTE_URL = self._api.create_repo( name=self.REPO_NAME_LARGE_FILE, token=self._token, - lfsmultipartthresh=16 * 10 ** 6, + lfsmultipartthresh=16 * 10**6, ) self.setup_local_clone(REMOTE_URL) From 219e36ba701a956a3a1b43021559468f255486d5 Mon Sep 17 00:00:00 2001 From: muellerzr Date: Tue, 1 Feb 2022 20:14:32 +0000 Subject: [PATCH 2/8] Fix tests --- tests/test_hf_api.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index 14c9c49e45..7104091cd6 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -676,9 +676,9 @@ def test_list_datasets_search(self): def test_filter_datasets_with_cardData(self): _api = HfApi() datasets = _api.list_datasets(cardData=True) - self.assertTrue(["cardData" in dataset for dataset in datasets]) + self.assertTrue([hasattr(dataset, "cardData") for dataset in datasets]) datasets = _api.list_datasets(cardData=False) - self.assertTrue(["cardData" not in dataset for dataset in datasets]) + self.assertTrue([not hasattr(dataset, "cardData") for dataset in datasets]) @with_production_testing def test_dataset_info(self): @@ -791,10 +791,10 @@ def test_filter_models_with_complex_query(self): @with_production_testing def test_filter_models_with_cardData(self): _api = HfApi() - models = _api.list_models(search="albert", cardData=True) - self.assertTrue(["cardData" in model for model in models]) - models = _api.list_models(search="albert", cardData=False) - self.assertTrue(["cardData" not in model for model in models]) + models = _api.list_models("co2_eq_emissions", cardData=True) + self.assertTrue([hasattr(model, "cardData") for model in models]) + models = _api.list_models("co2_eq_emissions", cardData=False) + self.assertTrue([not hasattr(model, "cardData") for model in models]) class HfApiPrivateTest(HfApiCommonTestWithLogin): From 58857ca9f0d6918d727750e333d6531735c7a043 Mon Sep 17 00:00:00 2001 From: muellerzr Date: Wed, 2 Feb 2022 11:54:41 -0500 Subject: [PATCH 3/8] Black version --- tests/test_hf_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index 7104091cd6..d148197378 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -874,7 +874,7 @@ def test_end_to_end_thresh_6M(self): REMOTE_URL = self._api.create_repo( name=self.REPO_NAME_LARGE_FILE, token=self._token, - lfsmultipartthresh=6 * 10**6, + lfsmultipartthresh=6 * 10 ** 6, ) self.setup_local_clone(REMOTE_URL) @@ -927,7 +927,7 @@ def test_end_to_end_thresh_16M(self): REMOTE_URL = self._api.create_repo( name=self.REPO_NAME_LARGE_FILE, token=self._token, - lfsmultipartthresh=16 * 10**6, + lfsmultipartthresh=16 * 10 ** 6, ) self.setup_local_clone(REMOTE_URL) From 5ffdbc9e96ed5d287c06b7b73bb9cc3f0c697126 Mon Sep 17 00:00:00 2001 From: Zachary Mueller Date: Fri, 4 Feb 2022 12:24:57 -0500 Subject: [PATCH 4/8] Make it opt-in Co-authored-by: Lysandre Debut --- src/huggingface_hub/hf_api.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index a230418d64..18c673bd5b 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -796,7 +796,8 @@ def list_datasets( if full: params.update({"full": True}) if cardData is not None: - params.update({"cardData": True}) + if cardData: + params.update({"cardData": True}) r = requests.get(path, params=params) r.raise_for_status() d = r.json() From b5e6c5219a6cf3acc797738688f12fe5820e1b56 Mon Sep 17 00:00:00 2001 From: muellerzr Date: Fri, 4 Feb 2022 12:31:52 -0500 Subject: [PATCH 5/8] Keep as optional, improve datasets --- src/huggingface_hub/hf_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index 18c673bd5b..487547a867 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -517,7 +517,7 @@ def list_models( direction: Optional[Literal[-1]] = None, limit: Optional[int] = None, full: Optional[bool] = None, - cardData: Optional[bool] = True, + cardData: Optional[bool] = None, fetch_config: Optional[bool] = None, ) -> List[ModelInfo]: """ From 8160410fd7cde8e0ccac6199b038a2cb58be5ea8 Mon Sep 17 00:00:00 2001 From: muellerzr Date: Fri, 4 Feb 2022 12:39:48 -0500 Subject: [PATCH 6/8] Add headers --- src/huggingface_hub/hf_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index 26b1423c22..cfb493b754 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -641,7 +641,7 @@ def list_models( params.update({"config": fetch_config}) if cardData is not None: params.update({"cardData": cardData}) - r = requests.get(path, params=params) + r = requests.get(path, params=params, headers=headers) r.raise_for_status() d = r.json() return [ModelInfo(**x) for x in d] @@ -830,7 +830,7 @@ def list_datasets( if cardData is not None: if cardData: params.update({"cardData": True}) - r = requests.get(path, params=params) + r = requests.get(path, params=params, headers=headers) r.raise_for_status() d = r.json() return [DatasetInfo(**x) for x in d] From 22e0b087d846bd5483f8e03afbd6bd390ef5728b Mon Sep 17 00:00:00 2001 From: Zachary Mueller Date: Fri, 4 Feb 2022 16:35:36 -0500 Subject: [PATCH 7/8] Adjust tests Co-authored-by: Lysandre Debut --- tests/test_hf_api.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index 4c12c32c55..302bf92d58 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -677,8 +677,8 @@ def test_filter_datasets_with_cardData(self): _api = HfApi() datasets = _api.list_datasets(cardData=True) self.assertTrue([hasattr(dataset, "cardData") for dataset in datasets]) - datasets = _api.list_datasets(cardData=False) - self.assertTrue([not hasattr(dataset, "cardData") for dataset in datasets]) + datasets = _api.list_datasets() + self.assertTrue(all([not hasattr(dataset, "cardData") for dataset in datasets])) @with_production_testing def test_dataset_info(self): @@ -793,8 +793,8 @@ def test_filter_models_with_cardData(self): _api = HfApi() models = _api.list_models("co2_eq_emissions", cardData=True) self.assertTrue([hasattr(model, "cardData") for model in models]) - models = _api.list_models("co2_eq_emissions", cardData=False) - self.assertTrue([not hasattr(model, "cardData") for model in models]) + models = _api.list_models("co2_eq_emissions") + self.assertTrue(all([not hasattr(model, "cardData") for model in models])) class HfApiPrivateTest(HfApiCommonTestWithLogin): From 1c6a2f0ab64f8a172b40e3f235582a4e0b327ee6 Mon Sep 17 00:00:00 2001 From: muellerzr Date: Fri, 4 Feb 2022 17:48:40 -0500 Subject: [PATCH 8/8] Fix tests and functionality --- src/huggingface_hub/hf_api.py | 4 ++-- tests/test_hf_api.py | 11 +++++++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index cfb493b754..9e44b7defd 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -714,7 +714,7 @@ def list_datasets( sort: Union[Literal["lastModified"], str, None] = None, direction: Optional[Literal[-1]] = None, limit: Optional[int] = None, - cardData: Optional[bool] = True, + cardData: Optional[bool] = None, full: Optional[bool] = None, use_auth_token: Optional[str] = None, ) -> List[DatasetInfo]: @@ -829,7 +829,7 @@ def list_datasets( params.update({"full": True}) if cardData is not None: if cardData: - params.update({"cardData": True}) + params.update({"full": True}) r = requests.get(path, params=params, headers=headers) r.raise_for_status() d = r.json() diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index 302bf92d58..b7ab619b53 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -676,9 +676,16 @@ def test_list_datasets_search(self): def test_filter_datasets_with_cardData(self): _api = HfApi() datasets = _api.list_datasets(cardData=True) - self.assertTrue([hasattr(dataset, "cardData") for dataset in datasets]) + self.assertGreater( + sum( + [getattr(dataset, "cardData", None) is not None for dataset in datasets] + ), + 0, + ) datasets = _api.list_datasets() - self.assertTrue(all([not hasattr(dataset, "cardData") for dataset in datasets])) + self.assertTrue( + all([getattr(dataset, "cardData", None) is None for dataset in datasets]) + ) @with_production_testing def test_dataset_info(self):