Fix list_models bool parameters (#1152)
* Fix list_models bool parameter

* remove cardData deprecation

* quality

* expect deprecation is code

* switch back add_to_git_credential default value see #1138
Wauplin authored Nov 7, 2022
1 parent c2dbfd7 · commit 9ccdf02
Showing 3 changed files with 35 additions and 34 deletions.
src/huggingface_hub/_login.py: 2 changes (1 addition & 1 deletion)
@@ -35,7 +35,7 @@
 logger = logging.get_logger(__name__)


-def login(token: Optional[str] = None, add_to_git_credential: bool = True) -> None:
+def login(token: Optional[str] = None, add_to_git_credential: bool = False) -> None:
     """Login the machine to access the Hub.

     The `token` is persisted in cache and set as a git credential. Once done, the machine
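A minimal usage sketch of the restored default (not part of the diff; the token value is a placeholder): with `add_to_git_credential` back to `False`, `login()` only caches the token unless the caller opts in to the git credential store.

    from huggingface_hub import login

    # Default after this commit: the token is cached locally, but NOT written
    # to the git credential store.
    login(token="hf_xxx")  # placeholder token

    # Opt in explicitly to also store the token as a git credential.
    login(token="hf_xxx", add_to_git_credential=True)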
src/huggingface_hub/hf_api.py: 48 changes (23 additions & 25 deletions)
@@ -65,6 +65,7 @@
     write_to_credential_store,
 )
 from .utils._deprecation import (
+    _deprecate_arguments,
     _deprecate_list_output,
     _deprecate_method,
     _deprecate_positional_args,
@@ -510,7 +511,7 @@ def __init__(self):
         self._process_models()

     def _process_models(self):
-        def clean(s: str):
+        def clean(s: str) -> str:
             return s.replace(" ", "").replace("-", "_").replace(".", "_")

         models = self._api.list_models()
@@ -679,8 +680,8 @@ def list_models(
         direction: Optional[Literal[-1]] = None,
         limit: Optional[int] = None,
         full: Optional[bool] = None,
-        cardData: Optional[bool] = None,
-        fetch_config: Optional[bool] = None,
+        cardData: bool = False,
+        fetch_config: bool = False,
         token: Optional[Union[bool, str]] = None,
     ) -> List[ModelInfo]:
         """
@@ -802,10 +803,10 @@ def list_models(
                 params.update({"full": True})
             elif "full" in params:
                 del params["full"]
-        if fetch_config is not None:
-            params.update({"config": fetch_config})
-        if cardData is not None:
-            params.update({"cardData": cardData})
+        if fetch_config:
+            params.update({"config": True})
+        if cardData:
+            params.update({"cardData": True})
         r = requests.get(path, params=params, headers=headers)
         hf_raise_for_status(r)
         d = r.json()
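A short sketch of the corrected behaviour (not part of the diff; assumes network access to the production Hub): `fetch_config` and `cardData` are now plain booleans defaulting to `False`, and the corresponding query parameters are only sent when a flag is explicitly set to `True`.

    from huggingface_hub import HfApi

    api = HfApi()

    # Default call: no "config" or "cardData" query parameters are sent.
    models = api.list_models(filter="adapter-transformers", limit=20)

    # Only an explicit fetch_config=True adds "config=True" to the request.
    models_with_config = api.list_models(
        filter="adapter-transformers", fetch_config=True, limit=20
    )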
@@ -882,6 +883,11 @@ def _unpack_model_filter(self, model_filter: ModelFilter):
         query_dict["filter"] = tuple(filter_list)
         return query_dict

+    @_deprecate_arguments(
+        version="0.14",
+        deprecated_args={"cardData"},
+        custom_message="Use 'full' instead.",
+    )
     @_deprecate_list_output(version="0.14")
     @validate_hf_hub_args
     def list_datasets(
@@ -893,7 +899,7 @@ def list_datasets(
         sort: Union[Literal["lastModified"], str, None] = None,
         direction: Optional[Literal[-1]] = None,
         limit: Optional[int] = None,
-        cardData: Optional[bool] = None,
+        cardData: Optional[bool] = None,  # deprecated
         full: Optional[bool] = None,
         token: Optional[str] = None,
     ) -> List[DatasetInfo]:
@@ -917,12 +923,10 @@ def list_datasets(
             limit (`int`, *optional*):
                 The limit on the number of datasets fetched. Leaving this option
                 to `None` fetches all datasets.
-            cardData (`bool`, *optional*):
-                Whether to grab the metadata for the dataset as well. Can
-                contain useful information such as the PapersWithCode ID.
             full (`bool`, *optional*):
                 Whether to fetch all dataset data, including the `lastModified`
-                and the `cardData`.
+                and the `cardData`. Can contain useful information such as the
+                PapersWithCode ID.
             token (`bool` or `str`, *optional*):
                 A valid authentication token (see https://huggingface.co/settings/token).
                 If `None` or `True` and machine is logged in (through `huggingface-cli login`
@@ -1001,12 +1005,8 @@ def list_datasets(
             params.update({"direction": direction})
         if limit is not None:
             params.update({"limit": limit})
-        if full is not None:
-            if full:
-                params.update({"full": True})
-        if cardData is not None:
-            if cardData:
-                params.update({"full": True})
+        if full or cardData:
+            params.update({"full": True})
         r = requests.get(path, params=params, headers=headers)
         hf_raise_for_status(r)
         d = r.json()
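A sketch of the simplified `list_datasets` handling (not part of the diff; assumes Hub access): `full` and the deprecated `cardData` flag now map to the same single `full=True` query parameter.

    from huggingface_hub import HfApi

    api = HfApi()

    # Preferred: fetch full dataset data, including lastModified and cardData.
    datasets = api.list_datasets(full=True, limit=10)

    # Deprecated spelling; still accepted until 0.14 but flagged by the
    # _deprecate_arguments decorator added above.
    datasets = api.list_datasets(cardData=True, limit=10)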
@@ -1079,7 +1079,7 @@ def list_spaces(
         limit: Optional[int] = None,
         datasets: Union[str, Iterable[str], None] = None,
         models: Union[str, Iterable[str], None] = None,
-        linked: Optional[bool] = None,
+        linked: bool = False,
         full: Optional[bool] = None,
         token: Optional[str] = None,
     ) -> List[SpaceInfo]:
@@ -1139,12 +1139,10 @@ def list_spaces(
             params.update({"direction": direction})
         if limit is not None:
             params.update({"limit": limit})
-        if full is not None:
-            if full:
-                params.update({"full": True})
-        if linked is not None:
-            if linked:
-                params.update({"linked": True})
+        if full:
+            params.update({"full": True})
+        if linked:
+            params.update({"linked": True})
         if datasets is not None:
             params.update({"datasets": datasets})
         if models is not None:
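The same pattern applied to `list_spaces`, as a brief sketch (not part of the diff; assumes Hub access): `linked` is now a plain boolean defaulting to `False`, and `full`/`linked` are only forwarded when truthy.

    from huggingface_hub import HfApi

    api = HfApi()

    # Sends only "linked=True"; "full" is omitted from the query entirely.
    spaces = api.list_spaces(linked=True, limit=10)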
tests/test_hf_api.py: 19 changes (11 additions & 8 deletions)
@@ -1226,15 +1226,17 @@ def test_list_models_complex_query(self):

     @with_production_testing
     def test_list_models_with_config(self):
-        _api = HfApi()
-        models = _api.list_models(
+        for model in HfApi().list_models(
             filter="adapter-transformers", fetch_config=True, limit=20
-        )
-        found_configs = 0
-        for model in models:
-            if model.config:
-                found_configs = found_configs + 1
-        self.assertGreater(found_configs, 0)
+        ):
+            self.assertIsNotNone(model.config)
+
+    @with_production_testing
+    def test_list_models_without_config(self):
+        for model in HfApi().list_models(
+            filter="adapter-transformers", fetch_config=False, limit=20
+        ):
+            self.assertIsNone(model.config)

     @with_production_testing
     def test_model_info(self):
@@ -1412,6 +1414,7 @@ def test_list_datasets_search(self):
         self.assertGreater(len(datasets), 10)
         self.assertIsInstance(datasets[0], DatasetInfo)

+    @expect_deprecation("list_datasets")
     @with_production_testing
     def test_filter_datasets_with_cardData(self):
         _api = HfApi()
