Skip to content

Commit

Permalink
Deprecate output in list_models (#1143)
Browse files Browse the repository at this point in the history
* Deprecate output in list_models

* remove assertions

* fix tests
  • Loading branch information
Wauplin authored Nov 4, 2022
1 parent 91fe43c commit c2dbfd7
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 7 deletions.
27 changes: 22 additions & 5 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,11 @@
validate_hf_hub_args,
write_to_credential_store,
)
from .utils._deprecation import _deprecate_method, _deprecate_positional_args
from .utils._deprecation import (
_deprecate_list_output,
_deprecate_method,
_deprecate_positional_args,
)
from .utils._typing import Literal, TypedDict
from .utils.endpoint_helpers import (
AttributeDictionary,
Expand Down Expand Up @@ -662,6 +666,7 @@ def get_dataset_tags(self) -> DatasetTags:
d = r.json()
return DatasetTags(d)

@_deprecate_list_output(version="0.14")
@validate_hf_hub_args
def list_models(
self,
Expand All @@ -679,7 +684,7 @@ def list_models(
token: Optional[Union[bool, str]] = None,
) -> List[ModelInfo]:
"""
Get the public list of all the models on huggingface.co
Get the list of all the models on huggingface.co
Args:
filter ([`ModelFilter`] or `str` or `Iterable`, *optional*):
Expand Down Expand Up @@ -720,7 +725,10 @@ def list_models(
or [`~huggingface_hub.login`]), token will be retrieved from the cache.
If `False`, token is not sent in the request header.
Returns: List of [`huggingface_hub.hf_api.ModelInfo`] objects
Returns:
`List[ModelInfo]`: a list of [`huggingface_hub.hf_api.ModelInfo`] objects.
To anticipate future pagination, please consider the return value to be a
simple iterator.
Example usage with the `filter` argument:
Expand Down Expand Up @@ -874,6 +882,7 @@ def _unpack_model_filter(self, model_filter: ModelFilter):
query_dict["filter"] = tuple(filter_list)
return query_dict

@_deprecate_list_output(version="0.14")
@validate_hf_hub_args
def list_datasets(
self,
Expand All @@ -889,7 +898,7 @@ def list_datasets(
token: Optional[str] = None,
) -> List[DatasetInfo]:
"""
Get the public list of all the datasets on huggingface.co
Get the list of all the datasets on huggingface.co
Args:
filter ([`DatasetFilter`] or `str` or `Iterable`, *optional*):
Expand Down Expand Up @@ -920,6 +929,11 @@ def list_datasets(
or [`~huggingface_hub.login`]), token will be retrieved from the cache.
If `False`, token is not sent in the request header.
Returns:
`List[DatasetInfo]`: a list of [`huggingface_hub.hf_api.DatasetInfo`] objects.
To anticipate future pagination, please consider the return value to be a
simple iterator.
Example usage with the `filter` argument:
```python
Expand Down Expand Up @@ -1052,6 +1066,7 @@ def list_metrics(self) -> List[MetricInfo]:
d = r.json()
return [MetricInfo(**x) for x in d]

@_deprecate_list_output(version="0.14")
@validate_hf_hub_args
def list_spaces(
self,
Expand Down Expand Up @@ -1105,7 +1120,9 @@ def list_spaces(
If `False`, token is not sent in the request header.
Returns:
`List[SpaceInfo]`: a list of [`huggingface_hub.hf_api.SpaceInfo`] objects
`List[SpaceInfo]`: a list of [`huggingface_hub.hf_api.SpaceInfo`] objects.
To anticipate future pagination, please consider the return value to be a
simple iterator.
"""
path = f"{self.endpoint}/api/spaces"
headers = self._build_hf_headers(token=token)
Expand Down
102 changes: 101 additions & 1 deletion src/huggingface_hub/utils/_deprecation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import warnings
from functools import wraps
from inspect import Parameter, signature
from typing import Iterable, Optional
from typing import Generator, Iterable, Optional


def _deprecate_positional_args(*, version: str):
Expand Down Expand Up @@ -131,3 +131,103 @@ def inner_f(*args, **kwargs):
return inner_f

return _inner_deprecate_method


def _deprecate_list_output(*, version: str):
"""Decorator to deprecate the usage as a list of the output of a method.
To be used when a method currently returns a list of objects but is planned to return
an generator instead in the future. Output is still a list but tweaked to issue a
warning message when it is specifically used as a list (e.g. get/set/del item, get
length,...).
Args:
version (`str`):
The version when output will start to be an generator.
"""

def _inner_deprecate_method(f):
@wraps(f)
def inner_f(*args, **kwargs):
list_value = f(*args, **kwargs)
return DeprecatedList(
list_value,
warning_message=(
"'{f.__name__}' currently returns a list of objects but is planned"
" to be a generator starting from version {version} in order to"
" implement pagination. Please avoid to use"
" `{f.__name__}(...).{attr_name}` or explicitly convert the output"
" to a list first with `list(iter({f.__name__})(...))`.".format(
f=f,
version=version,
# Dumb but working workaround to render `attr_name` later
# Taken from https://stackoverflow.com/a/35300723
attr_name="{attr_name}",
)
),
)

return inner_f

return _inner_deprecate_method


def _empty_gen() -> Generator:
# Create an empty generator
# Taken from https://stackoverflow.com/a/13243870
return
yield


# Build the set of attributes that are specific to a List object (and will be deprecated)
_LIST_ONLY_ATTRS = frozenset(set(dir([])) - set(dir(_empty_gen())))


class DeprecateListMetaclass(type):
"""Metaclass that overwrites all list-only methods, including magic ones."""

def __new__(cls, clsname, bases, attrs):
# Check consistency
if "_deprecate" not in attrs:
raise TypeError(
"A `_deprecate` method must be implemented to use"
" `DeprecateListMetaclass`."
)
if list not in bases:
raise TypeError(
"Class must inherit from `list` to use `DeprecateListMetaclass`."
)

# Create decorator to deprecate list-only methods, including magic ones
def _with_deprecation(f, name):
@wraps(f)
def _inner(self, *args, **kwargs):
self._deprecate(name) # Use the `_deprecate`
return f(self, *args, **kwargs)

return _inner

# Deprecate list-only methods
for attr in _LIST_ONLY_ATTRS:
attrs[attr] = _with_deprecation(getattr(list, attr), attr)

return super().__new__(cls, clsname, bases, attrs)


class DeprecatedList(list, metaclass=DeprecateListMetaclass):
"""Custom List class for which all calls to a list-specific method is deprecated.
Methods that are shared with a generator are not deprecated.
See `_deprecate_list_output` for more details.
"""

def __init__(self, iterable, warning_message: str):
"""Initialize the list with a default warning message.
Warning message will be formatted at runtime with a "{attr_name}" value.
"""
super().__init__(iterable)
self._deprecation_msg = warning_message

def _deprecate(self, attr_name: str) -> None:
warnings.warn(self._deprecation_msg.format(attr_name=attr_name), FutureWarning)
Loading

0 comments on commit c2dbfd7

Please sign in to comment.