From ec12cccdaa1a52a5cd3af661532b15924579d283 Mon Sep 17 00:00:00 2001 From: Lucain Pouget Date: Wed, 26 Apr 2023 12:00:14 +0200 Subject: [PATCH 1/3] Remove deprecated code + adapt tests --- src/huggingface_hub/__init__.py | 2 +- src/huggingface_hub/hf_api.py | 69 ++++---- src/huggingface_hub/utils/endpoint_helpers.py | 56 ++++--- tests/test_hf_api.py | 147 ++++++------------ 4 files changed, 111 insertions(+), 163 deletions(-) diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py index 37c2d1fa1b..600fc80cb7 100644 --- a/src/huggingface_hub/__init__.py +++ b/src/huggingface_hub/__init__.py @@ -46,7 +46,7 @@ from typing import TYPE_CHECKING -__version__ = "0.14.0.dev0" +__version__ = "0.15.0.dev0" # Alphabetical order of definitions is ensured in tests # WARNING: any comment added in this dictionary definition will be lost when diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index 5ee78ebca9..d998e9ff0f 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -90,7 +90,6 @@ ) from .utils._deprecation import ( _deprecate_arguments, - _deprecate_list_output, ) from .utils._typing import Literal, TypedDict from .utils.endpoint_helpers import ( @@ -891,7 +890,6 @@ def get_dataset_tags(self) -> DatasetTags: d = r.json() return DatasetTags(d) - @_deprecate_list_output(version="0.14") @validate_hf_hub_args def list_models( self, @@ -907,9 +905,9 @@ def list_models( cardData: bool = False, fetch_config: bool = False, token: Optional[Union[bool, str]] = None, - ) -> List[ModelInfo]: + ) -> Iterable[ModelInfo]: """ - Get the list of all the models on huggingface.co + List models hosted on the Huggingface Hub, given some filters. Args: filter ([`ModelFilter`] or `str` or `Iterable`, *optional*): @@ -950,9 +948,7 @@ def list_models( If `False`, token is not sent in the request header. Returns: - `List[ModelInfo]`: a list of [`huggingface_hub.hf_api.ModelInfo`] objects. - To anticipate future pagination, please consider the return value to be a - simple iterator. + `Iterable[ModelInfo]`: an iterable of [`huggingface_hub.hf_api.ModelInfo`] objects. Example usage with the `filter` argument: @@ -1002,6 +998,9 @@ def list_models( >>> api.list_models(search="bert", author="google") ``` """ + if emissions_thresholds is not None and cardData is None: + raise ValueError("`emissions_thresholds` were passed without setting `cardData=True`.") + path = f"{self.endpoint}/api/models" headers = self._build_hf_headers(token=token) params = {} @@ -1031,18 +1030,14 @@ def list_models( if cardData: params.update({"cardData": True}) - data = paginate(path, params=params, headers=headers) + # `items` is a generator + items = paginate(path, params=params, headers=headers) if limit is not None: - data = islice(data, limit) # Do not iterate over all pages - items = [ModelInfo(**x) for x in data] - + items = islice(items, limit) # Do not iterate over all pages if emissions_thresholds is not None: - if cardData is None: - raise ValueError("`emissions_thresholds` were passed without setting `cardData=True`.") - else: - return _filter_emissions(items, *emissions_thresholds) - - return items + items = _filter_emissions(items, *emissions_thresholds) + for item in items: + yield ModelInfo(**item) def _unpack_model_filter(self, model_filter: ModelFilter): """ @@ -1096,12 +1091,6 @@ def _unpack_model_filter(self, model_filter: ModelFilter): query_dict["filter"] = tuple(filter_list) return query_dict - @_deprecate_arguments( - version="0.14", - deprecated_args={"cardData"}, - custom_message="Use 'full' instead.", - ) - @_deprecate_list_output(version="0.14") @validate_hf_hub_args def list_datasets( self, @@ -1112,12 +1101,11 @@ def list_datasets( sort: Union[Literal["lastModified"], str, None] = None, direction: Optional[Literal[-1]] = None, limit: Optional[int] = None, - cardData: Optional[bool] = None, # deprecated full: Optional[bool] = None, token: Optional[str] = None, - ) -> List[DatasetInfo]: + ) -> Iterable[DatasetInfo]: """ - Get the list of all the datasets on huggingface.co + List datasets hosted on the Huggingface Hub, given some filters. Args: filter ([`DatasetFilter`] or `str` or `Iterable`, *optional*): @@ -1147,9 +1135,7 @@ def list_datasets( If `False`, token is not sent in the request header. Returns: - `List[DatasetInfo]`: a list of [`huggingface_hub.hf_api.DatasetInfo`] objects. - To anticipate future pagination, please consider the return value to be a - simple iterator. + `Iterable[DatasetInfo]`: an iterable of [`huggingface_hub.hf_api.DatasetInfo`] objects. Example usage with the `filter` argument: @@ -1218,13 +1204,14 @@ def list_datasets( params.update({"direction": direction}) if limit is not None: params.update({"limit": limit}) - if full or cardData: + if full: params.update({"full": True}) - data = paginate(path, params=params, headers=headers) + items = paginate(path, params=params, headers=headers) if limit is not None: - data = islice(data, limit) # Do not iterate over all pages - return [DatasetInfo(**x) for x in data] + items = islice(items, limit) # Do not iterate over all pages + for item in items: + yield DatasetInfo(**item) def _unpack_dataset_filter(self, dataset_filter: DatasetFilter): """ @@ -1280,7 +1267,6 @@ def list_metrics(self) -> List[MetricInfo]: d = r.json() return [MetricInfo(**x) for x in d] - @_deprecate_list_output(version="0.14") @validate_hf_hub_args def list_spaces( self, @@ -1296,9 +1282,9 @@ def list_spaces( linked: bool = False, full: Optional[bool] = None, token: Optional[str] = None, - ) -> List[SpaceInfo]: + ) -> Iterable[SpaceInfo]: """ - Get the public list of all Spaces on huggingface.co + List spaces hosted on the Huggingface Hub, given some filters. Args: filter (`str` or `Iterable`, *optional*): @@ -1334,9 +1320,7 @@ def list_spaces( If `False`, token is not sent in the request header. Returns: - `List[SpaceInfo]`: a list of [`huggingface_hub.hf_api.SpaceInfo`] objects. - To anticipate future pagination, please consider the return value to be a - simple iterator. + `Iterable[SpaceInfo]`: an iterable of [`huggingface_hub.hf_api.SpaceInfo`] objects. """ path = f"{self.endpoint}/api/spaces" headers = self._build_hf_headers(token=token) @@ -1362,10 +1346,11 @@ def list_spaces( if models is not None: params.update({"models": models}) - data = paginate(path, params=params, headers=headers) + items = paginate(path, params=params, headers=headers) if limit is not None: - data = islice(data, limit) # Do not iterate over all pages - return [SpaceInfo(**x) for x in data] + items = islice(items, limit) # Do not iterate over all pages + for item in items: + yield SpaceInfo(**item) @validate_hf_hub_args def like( diff --git a/src/huggingface_hub/utils/endpoint_helpers.py b/src/huggingface_hub/utils/endpoint_helpers.py index 4286feb9d3..ac614d6a51 100644 --- a/src/huggingface_hub/utils/endpoint_helpers.py +++ b/src/huggingface_hub/utils/endpoint_helpers.py @@ -16,20 +16,23 @@ import math import re from dataclasses import dataclass -from typing import List, Optional, Union +from typing import TYPE_CHECKING, Iterable, List, Optional, Union + + +if TYPE_CHECKING: + from ..hf_api import ModelInfo def _filter_emissions( - models, + models: Iterable["ModelInfo"], minimum_threshold: Optional[float] = None, maximum_threshold: Optional[float] = None, -): - """Filters a list of models for those that include an emission tag - and limit them to between two thresholds +) -> Iterable["ModelInfo"]: + """Filters a list of models for those that include an emission tag and limit them to between two thresholds Args: - models (`ModelInfo` or `List`): - A list of `ModelInfo`'s to filter by. + models (Iterable of `ModelInfo`): + A list of models to filter. minimum_threshold (`float`, *optional*): A minimum carbon threshold to filter by, such as 1. maximum_threshold (`float`, *optional*): @@ -41,23 +44,28 @@ def _filter_emissions( minimum_threshold = -1 if maximum_threshold is None: maximum_threshold = math.inf - emissions = [] - for i, model in enumerate(models): - if hasattr(model, "cardData"): - if isinstance(model.cardData, dict): - emission = model.cardData.get("co2_eq_emissions", None) - if isinstance(emission, dict): - emission = emission["emissions"] - if emission: - emission = str(emission) - matched = re.search(r"\d+\.\d+|\d+", emission) - if matched is not None: - emissions.append((i, float(matched.group(0)))) - filtered_results = [] - for idx, emission in emissions: - if emission >= minimum_threshold and emission <= maximum_threshold: - filtered_results.append(models[idx]) - return filtered_results + + for model in models: + # Check ModelInfo format + if not hasattr(model, "cardData"): + continue + if not isinstance(model.cardData, dict): + continue + + # Get CO2 emission metadata + emission = model.cardData.get("co2_eq_emissions", None) + if isinstance(emission, dict): + emission = emission["emissions"] + if not emission: + continue + + matched = re.search(r"\d+\.\d+|\d+", str(emission)) + if matched is None: + continue + + emission_value = float(matched.group(0)) + if emission_value >= minimum_threshold and emission_value <= maximum_threshold: + yield model @dataclass diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index da7c55517c..04d466ebe4 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -49,7 +49,6 @@ from huggingface_hub.hf_api import ( CommitInfo, DatasetInfo, - DatasetSearchArguments, HfApi, MetricInfo, ModelInfo, @@ -91,7 +90,6 @@ DUMMY_MODEL_ID, DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT, SAMPLE_DATASET_IDENTIFIER, - expect_deprecation, repo_name, require_git_lfs, retry_endpoint, @@ -1384,23 +1382,20 @@ class HfApiPublicProductionTest(unittest.TestCase): def setUp(self) -> None: self._api = HfApi() - @expect_deprecation("list_models") def test_list_models(self): - models = self._api.list_models() + models = list(self._api.list_models(limit=500)) self.assertGreater(len(models), 100) self.assertIsInstance(models[0], ModelInfo) - @expect_deprecation("list_models") def test_list_models_author(self): - models = self._api.list_models(author="google") + models = list(self._api.list_models(author="google")) self.assertGreater(len(models), 10) self.assertIsInstance(models[0], ModelInfo) for model in models: self.assertTrue(model.modelId.startswith("google/")) - @expect_deprecation("list_models") def test_list_models_search(self): - models = self._api.list_models(search="bert") + models = list(self._api.list_models(search="bert")) self.assertGreater(len(models), 10) self.assertIsInstance(models[0], ModelInfo) for model in models[:10]: @@ -1409,12 +1404,11 @@ def test_list_models_search(self): # (and changes it in the future) but for now it should do the trick. self.assertTrue("bert" in model.modelId.lower()) - @expect_deprecation("list_models") def test_list_models_complex_query(self): # Let's list the 10 most recent models # with tags "bert" and "jax", # ordered by last modified date. - models = self._api.list_models(filter=("bert", "jax"), sort="lastModified", direction=-1, limit=10) + models = list(self._api.list_models(filter=("bert", "jax"), sort="lastModified", direction=-1, limit=10)) # we have at least 1 models self.assertGreater(len(models), 1) self.assertLessEqual(len(models), 10) @@ -1479,108 +1473,84 @@ def test_list_repo_files(self): ] self.assertListEqual(files, expected_files) - @expect_deprecation("list_datasets") def test_list_datasets_no_filter(self): - datasets = self._api.list_datasets() + datasets = list(self._api.list_datasets(limit=500)) self.assertGreater(len(datasets), 100) self.assertIsInstance(datasets[0], DatasetInfo) - @expect_deprecation("list_datasets") def test_filter_datasets_by_author_and_name(self): f = DatasetFilter(author="huggingface", dataset_name="DataMeasurementsFiles") - datasets = self._api.list_datasets(filter=f) + datasets = list(self._api.list_datasets(filter=f)) self.assertEqual(len(datasets), 1) self.assertTrue("huggingface" in datasets[0].author) self.assertTrue("DataMeasurementsFiles" in datasets[0].id) - @unittest.skip( - "DatasetFilter is currently broken. See" - " https://github.com/huggingface/huggingface_hub/pull/1250. Skip test until" - " it's fixed." - ) - @expect_deprecation("list_datasets") def test_filter_datasets_by_benchmark(self): f = DatasetFilter(benchmark="raft") - datasets = self._api.list_datasets(filter=f) + datasets = list(self._api.list_datasets(filter=f)) self.assertGreater(len(datasets), 0) self.assertTrue("benchmark:raft" in datasets[0].tags) - @expect_deprecation("list_datasets") def test_filter_datasets_by_language_creator(self): f = DatasetFilter(language_creators="crowdsourced") - datasets = self._api.list_datasets(filter=f) + datasets = list(self._api.list_datasets(filter=f)) self.assertGreater(len(datasets), 0) self.assertTrue("language_creators:crowdsourced" in datasets[0].tags) - @unittest.skip( - "DatasetFilter is currently broken. See" - " https://github.com/huggingface/huggingface_hub/pull/1250. Skip test until" - " it's fixed." - ) - @expect_deprecation("list_datasets") def test_filter_datasets_by_language_only(self): - datasets = self._api.list_datasets(filter=DatasetFilter(language="en")) + datasets = list(self._api.list_datasets(filter=DatasetFilter(language="en"), limit=100)) self.assertGreater(len(datasets), 0) self.assertTrue("language:en" in datasets[0].tags) - args = DatasetSearchArguments(api=self._api) - datasets = self._api.list_datasets(filter=DatasetFilter(language=(args.language.en, args.language.fr))) + datasets = list(self._api.list_datasets(filter=DatasetFilter(language=("en", "fr")), limit=100)) self.assertGreater(len(datasets), 0) self.assertTrue("language:en" in datasets[0].tags) self.assertTrue("language:fr" in datasets[0].tags) - @expect_deprecation("list_datasets") def test_filter_datasets_by_multilinguality(self): - datasets = self._api.list_datasets(filter=DatasetFilter(multilinguality="multilingual")) + datasets = list(self._api.list_datasets(filter=DatasetFilter(multilinguality="multilingual"))) self.assertGreater(len(datasets), 0) self.assertTrue("multilinguality:multilingual" in datasets[0].tags) - @expect_deprecation("list_datasets") def test_filter_datasets_by_size_categories(self): - datasets = self._api.list_datasets(filter=DatasetFilter(size_categories="100K Date: Wed, 26 Apr 2023 12:25:27 +0200 Subject: [PATCH 2/3] code quality --- src/huggingface_hub/utils/endpoint_helpers.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/huggingface_hub/utils/endpoint_helpers.py b/src/huggingface_hub/utils/endpoint_helpers.py index ac614d6a51..decc0f03d9 100644 --- a/src/huggingface_hub/utils/endpoint_helpers.py +++ b/src/huggingface_hub/utils/endpoint_helpers.py @@ -46,19 +46,18 @@ def _filter_emissions( maximum_threshold = math.inf for model in models: - # Check ModelInfo format - if not hasattr(model, "cardData"): - continue - if not isinstance(model.cardData, dict): + card_data = getattr(model, "cardData", None) + if card_data is None or not isinstance(card_data, dict): continue # Get CO2 emission metadata - emission = model.cardData.get("co2_eq_emissions", None) + emission = card_data.get("co2_eq_emissions", None) if isinstance(emission, dict): emission = emission["emissions"] if not emission: continue + # Filter out if value is missing or out of range matched = re.search(r"\d+\.\d+|\d+", str(emission)) if matched is None: continue From 9498580e5fce131d1770086fea154c0d94e99645 Mon Sep 17 00:00:00 2001 From: Lucain Pouget Date: Wed, 26 Apr 2023 12:27:16 +0200 Subject: [PATCH 3/3] fix test --- tests/test_hf_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index 04d466ebe4..40d54f7890 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -1545,7 +1545,7 @@ def test_list_datasets_search(self): self.assertIsInstance(datasets[0], DatasetInfo) def test_filter_datasets_with_cardData(self): - datasets = list(self._api.list_datasets(cardData=True, limit=500)) + datasets = list(self._api.list_datasets(full=True, limit=500)) self.assertGreater( sum([getattr(dataset, "cardData", None) is not None for dataset in datasets]), 0,