Skip to content

Commit

Permalink
Remove deprecated code + adapt tests (#1450)
Browse files Browse the repository at this point in the history
* Remove deprecated code + adapt tests

* code quality

* fix test
  • Loading branch information
Wauplin authored May 15, 2023
1 parent ab01690 commit 10b1cb2
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 163 deletions.
2 changes: 1 addition & 1 deletion src/huggingface_hub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from typing import TYPE_CHECKING


__version__ = "0.14.0.dev0"
__version__ = "0.15.0.dev0"

# Alphabetical order of definitions is ensured in tests
# WARNING: any comment added in this dictionary definition will be lost when
Expand Down
69 changes: 27 additions & 42 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@
)
from .utils._deprecation import (
_deprecate_arguments,
_deprecate_list_output,
)
from .utils._typing import Literal, TypedDict
from .utils.endpoint_helpers import (
Expand Down Expand Up @@ -891,7 +890,6 @@ def get_dataset_tags(self) -> DatasetTags:
d = r.json()
return DatasetTags(d)

@_deprecate_list_output(version="0.14")
@validate_hf_hub_args
def list_models(
self,
Expand All @@ -907,9 +905,9 @@ def list_models(
cardData: bool = False,
fetch_config: bool = False,
token: Optional[Union[bool, str]] = None,
) -> List[ModelInfo]:
) -> Iterable[ModelInfo]:
"""
Get the list of all the models on huggingface.co
List models hosted on the Hugging Face Hub, given some filters.
Args:
filter ([`ModelFilter`] or `str` or `Iterable`, *optional*):
Expand Down Expand Up @@ -950,9 +948,7 @@ def list_models(
If `False`, token is not sent in the request header.
Returns:
`List[ModelInfo]`: a list of [`huggingface_hub.hf_api.ModelInfo`] objects.
To anticipate future pagination, please consider the return value to be a
simple iterator.
`Iterable[ModelInfo]`: an iterable of [`huggingface_hub.hf_api.ModelInfo`] objects.
Example usage with the `filter` argument:
Expand Down Expand Up @@ -1002,6 +998,9 @@ def list_models(
>>> api.list_models(search="bert", author="google")
```
"""
if emissions_thresholds is not None and cardData is None:
raise ValueError("`emissions_thresholds` were passed without setting `cardData=True`.")

path = f"{self.endpoint}/api/models"
headers = self._build_hf_headers(token=token)
params = {}
Expand Down Expand Up @@ -1031,18 +1030,14 @@ def list_models(
if cardData:
params.update({"cardData": True})

data = paginate(path, params=params, headers=headers)
# `items` is a generator
items = paginate(path, params=params, headers=headers)
if limit is not None:
data = islice(data, limit) # Do not iterate over all pages
items = [ModelInfo(**x) for x in data]

items = islice(items, limit) # Do not iterate over all pages
if emissions_thresholds is not None:
if cardData is None:
raise ValueError("`emissions_thresholds` were passed without setting `cardData=True`.")
else:
return _filter_emissions(items, *emissions_thresholds)

return items
items = _filter_emissions(items, *emissions_thresholds)
for item in items:
yield ModelInfo(**item)

def _unpack_model_filter(self, model_filter: ModelFilter):
"""
Expand Down Expand Up @@ -1096,12 +1091,6 @@ def _unpack_model_filter(self, model_filter: ModelFilter):
query_dict["filter"] = tuple(filter_list)
return query_dict

@_deprecate_arguments(
version="0.14",
deprecated_args={"cardData"},
custom_message="Use 'full' instead.",
)
@_deprecate_list_output(version="0.14")
@validate_hf_hub_args
def list_datasets(
self,
Expand All @@ -1112,12 +1101,11 @@ def list_datasets(
sort: Union[Literal["lastModified"], str, None] = None,
direction: Optional[Literal[-1]] = None,
limit: Optional[int] = None,
cardData: Optional[bool] = None, # deprecated
full: Optional[bool] = None,
token: Optional[str] = None,
) -> List[DatasetInfo]:
) -> Iterable[DatasetInfo]:
"""
Get the list of all the datasets on huggingface.co
List datasets hosted on the Hugging Face Hub, given some filters.
Args:
filter ([`DatasetFilter`] or `str` or `Iterable`, *optional*):
Expand Down Expand Up @@ -1147,9 +1135,7 @@ def list_datasets(
If `False`, token is not sent in the request header.
Returns:
`List[DatasetInfo]`: a list of [`huggingface_hub.hf_api.DatasetInfo`] objects.
To anticipate future pagination, please consider the return value to be a
simple iterator.
`Iterable[DatasetInfo]`: an iterable of [`huggingface_hub.hf_api.DatasetInfo`] objects.
Example usage with the `filter` argument:
Expand Down Expand Up @@ -1218,13 +1204,14 @@ def list_datasets(
params.update({"direction": direction})
if limit is not None:
params.update({"limit": limit})
if full or cardData:
if full:
params.update({"full": True})

data = paginate(path, params=params, headers=headers)
items = paginate(path, params=params, headers=headers)
if limit is not None:
data = islice(data, limit) # Do not iterate over all pages
return [DatasetInfo(**x) for x in data]
items = islice(items, limit) # Do not iterate over all pages
for item in items:
yield DatasetInfo(**item)

def _unpack_dataset_filter(self, dataset_filter: DatasetFilter):
"""
Expand Down Expand Up @@ -1280,7 +1267,6 @@ def list_metrics(self) -> List[MetricInfo]:
d = r.json()
return [MetricInfo(**x) for x in d]

@_deprecate_list_output(version="0.14")
@validate_hf_hub_args
def list_spaces(
self,
Expand All @@ -1296,9 +1282,9 @@ def list_spaces(
linked: bool = False,
full: Optional[bool] = None,
token: Optional[str] = None,
) -> List[SpaceInfo]:
) -> Iterable[SpaceInfo]:
"""
Get the public list of all Spaces on huggingface.co
List spaces hosted on the Hugging Face Hub, given some filters.
Args:
filter (`str` or `Iterable`, *optional*):
Expand Down Expand Up @@ -1334,9 +1320,7 @@ def list_spaces(
If `False`, token is not sent in the request header.
Returns:
`List[SpaceInfo]`: a list of [`huggingface_hub.hf_api.SpaceInfo`] objects.
To anticipate future pagination, please consider the return value to be a
simple iterator.
`Iterable[SpaceInfo]`: an iterable of [`huggingface_hub.hf_api.SpaceInfo`] objects.
"""
path = f"{self.endpoint}/api/spaces"
headers = self._build_hf_headers(token=token)
Expand All @@ -1362,10 +1346,11 @@ def list_spaces(
if models is not None:
params.update({"models": models})

data = paginate(path, params=params, headers=headers)
items = paginate(path, params=params, headers=headers)
if limit is not None:
data = islice(data, limit) # Do not iterate over all pages
return [SpaceInfo(**x) for x in data]
items = islice(items, limit) # Do not iterate over all pages
for item in items:
yield SpaceInfo(**item)

@validate_hf_hub_args
def like(
Expand Down
55 changes: 31 additions & 24 deletions src/huggingface_hub/utils/endpoint_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,23 @@
import math
import re
from dataclasses import dataclass
from typing import List, Optional, Union
from typing import TYPE_CHECKING, Iterable, List, Optional, Union


if TYPE_CHECKING:
from ..hf_api import ModelInfo


def _filter_emissions(
models,
models: Iterable["ModelInfo"],
minimum_threshold: Optional[float] = None,
maximum_threshold: Optional[float] = None,
):
"""Filters a list of models for those that include an emission tag
and limit them to between two thresholds
) -> Iterable["ModelInfo"]:
"""Filters a list of models for those that include an emission tag and limit them to between two thresholds
Args:
models (`ModelInfo` or `List`):
A list of `ModelInfo`'s to filter by.
models (Iterable of `ModelInfo`):
An iterable of models to filter.
minimum_threshold (`float`, *optional*):
A minimum carbon threshold to filter by, such as 1.
maximum_threshold (`float`, *optional*):
Expand All @@ -41,23 +44,27 @@ def _filter_emissions(
minimum_threshold = -1
if maximum_threshold is None:
maximum_threshold = math.inf
emissions = []
for i, model in enumerate(models):
if hasattr(model, "cardData"):
if isinstance(model.cardData, dict):
emission = model.cardData.get("co2_eq_emissions", None)
if isinstance(emission, dict):
emission = emission["emissions"]
if emission:
emission = str(emission)
matched = re.search(r"\d+\.\d+|\d+", emission)
if matched is not None:
emissions.append((i, float(matched.group(0))))
filtered_results = []
for idx, emission in emissions:
if emission >= minimum_threshold and emission <= maximum_threshold:
filtered_results.append(models[idx])
return filtered_results

for model in models:
card_data = getattr(model, "cardData", None)
if card_data is None or not isinstance(card_data, dict):
continue

# Get CO2 emission metadata
emission = card_data.get("co2_eq_emissions", None)
if isinstance(emission, dict):
emission = emission["emissions"]
if not emission:
continue

# Filter out if value is missing or out of range
matched = re.search(r"\d+\.\d+|\d+", str(emission))
if matched is None:
continue

emission_value = float(matched.group(0))
if emission_value >= minimum_threshold and emission_value <= maximum_threshold:
yield model


@dataclass
Expand Down
Loading

0 comments on commit 10b1cb2

Please sign in to comment.