Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix list_models bool parameters #1152

Merged
merged 8 commits into from
Nov 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/huggingface_hub/_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
logger = logging.get_logger(__name__)


def login(token: Optional[str] = None, add_to_git_credential: bool = True) -> None:
def login(token: Optional[str] = None, add_to_git_credential: bool = False) -> None:
Wauplin marked this conversation as resolved.
Show resolved Hide resolved
"""Login the machine to access the Hub.

The `token` is persisted in cache and set as a git credential. Once done, the machine
Expand Down
48 changes: 23 additions & 25 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
write_to_credential_store,
)
from .utils._deprecation import (
_deprecate_arguments,
_deprecate_list_output,
_deprecate_method,
_deprecate_positional_args,
Expand Down Expand Up @@ -510,7 +511,7 @@ def __init__(self):
self._process_models()

def _process_models(self):
def clean(s: str):
def clean(s: str) -> str:
return s.replace(" ", "").replace("-", "_").replace(".", "_")

models = self._api.list_models()
Expand Down Expand Up @@ -679,8 +680,8 @@ def list_models(
direction: Optional[Literal[-1]] = None,
limit: Optional[int] = None,
full: Optional[bool] = None,
cardData: Optional[bool] = None,
fetch_config: Optional[bool] = None,
cardData: bool = False,
fetch_config: bool = False,
token: Optional[Union[bool, str]] = None,
) -> List[ModelInfo]:
"""
Expand Down Expand Up @@ -802,10 +803,10 @@ def list_models(
params.update({"full": True})
elif "full" in params:
del params["full"]
if fetch_config is not None:
params.update({"config": fetch_config})
if cardData is not None:
params.update({"cardData": cardData})
if fetch_config:
params.update({"config": True})
if cardData:
params.update({"cardData": True})
r = requests.get(path, params=params, headers=headers)
hf_raise_for_status(r)
d = r.json()
Expand Down Expand Up @@ -882,6 +883,11 @@ def _unpack_model_filter(self, model_filter: ModelFilter):
query_dict["filter"] = tuple(filter_list)
return query_dict

@_deprecate_arguments(
version="0.14",
deprecated_args={"cardData"},
custom_message="Use 'full' instead.",
)
Comment on lines +886 to +890
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cardData and full had exactly the same behavior so I squeezed cardData.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, I tried it as well and it returns the same values. I'm not sure if this is intended though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the codebase it was literally:

if full is not None:
    if full:
        params.update({"full": True})
if cardData is not None:
    if cardData:
        params.update({"full": True})

And I tested in on the API, "cardData" doesn't seem to be a valid parameter for the /datasets endpoint. That's why I just removed the cardData param. I tend to think that it comes from a copy-paste of the arguments from list_models and not adapted correctly.

https://huggingface.co/api/datasets?limit=1 => base
https://huggingface.co/api/datasets?limit=1&cardData=1 => base
https://huggingface.co/api/datasets?limit=1&full=1 => full

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah lol indeed!

Copy link
Contributor Author

@Wauplin Wauplin Nov 4, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apart from that, are you ok with the changes in the PR @LysandreJik ? 🙂

@_deprecate_list_output(version="0.14")
@validate_hf_hub_args
def list_datasets(
Expand All @@ -893,7 +899,7 @@ def list_datasets(
sort: Union[Literal["lastModified"], str, None] = None,
direction: Optional[Literal[-1]] = None,
limit: Optional[int] = None,
cardData: Optional[bool] = None,
cardData: Optional[bool] = None, # deprecated
full: Optional[bool] = None,
token: Optional[str] = None,
) -> List[DatasetInfo]:
Expand All @@ -917,12 +923,10 @@ def list_datasets(
limit (`int`, *optional*):
The limit on the number of datasets fetched. Leaving this option
to `None` fetches all datasets.
cardData (`bool`, *optional*):
Whether to grab the metadata for the dataset as well. Can
contain useful information such as the PapersWithCode ID.
full (`bool`, *optional*):
Whether to fetch all dataset data, including the `lastModified`
and the `cardData`.
and the `cardData`. Can contain useful information such as the
PapersWithCode ID.
token (`bool` or `str`, *optional*):
A valid authentication token (see https://huggingface.co/settings/token).
If `None` or `True` and machine is logged in (through `huggingface-cli login`
Expand Down Expand Up @@ -1001,12 +1005,8 @@ def list_datasets(
params.update({"direction": direction})
if limit is not None:
params.update({"limit": limit})
if full is not None:
if full:
params.update({"full": True})
if cardData is not None:
if cardData:
params.update({"full": True})
if full or cardData:
params.update({"full": True})
r = requests.get(path, params=params, headers=headers)
hf_raise_for_status(r)
d = r.json()
Expand Down Expand Up @@ -1079,7 +1079,7 @@ def list_spaces(
limit: Optional[int] = None,
datasets: Union[str, Iterable[str], None] = None,
models: Union[str, Iterable[str], None] = None,
linked: Optional[bool] = None,
linked: bool = False,
full: Optional[bool] = None,
token: Optional[str] = None,
) -> List[SpaceInfo]:
Expand Down Expand Up @@ -1139,12 +1139,10 @@ def list_spaces(
params.update({"direction": direction})
if limit is not None:
params.update({"limit": limit})
if full is not None:
if full:
params.update({"full": True})
if linked is not None:
if linked:
params.update({"linked": True})
if full:
params.update({"full": True})
if linked:
params.update({"linked": True})
if datasets is not None:
params.update({"datasets": datasets})
if models is not None:
Expand Down
19 changes: 11 additions & 8 deletions tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,15 +1226,17 @@ def test_list_models_complex_query(self):

@with_production_testing
def test_list_models_with_config(self):
_api = HfApi()
models = _api.list_models(
for model in HfApi().list_models(
filter="adapter-transformers", fetch_config=True, limit=20
)
found_configs = 0
for model in models:
if model.config:
found_configs = found_configs + 1
self.assertGreater(found_configs, 0)
):
self.assertIsNotNone(model.config)

@with_production_testing
def test_list_models_without_config(self):
for model in HfApi().list_models(
filter="adapter-transformers", fetch_config=False, limit=20
):
self.assertIsNone(model.config)

@with_production_testing
def test_model_info(self):
Expand Down Expand Up @@ -1412,6 +1414,7 @@ def test_list_datasets_search(self):
self.assertGreater(len(datasets), 10)
self.assertIsInstance(datasets[0], DatasetInfo)

@expect_deprecation("list_datasets")
@with_production_testing
def test_filter_datasets_with_cardData(self):
_api = HfApi()
Expand Down