Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add 'gated' search parameter #2448

Merged
merged 1 commit into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1595,6 +1595,7 @@ def list_models(
# Search-query parameter
filter: Union[str, Iterable[str], None] = None,
author: Optional[str] = None,
gated: Optional[bool] = None,
library: Optional[Union[str, List[str]]] = None,
language: Optional[Union[str, List[str]]] = None,
model_name: Optional[str] = None,
Expand Down Expand Up @@ -1624,6 +1625,10 @@ def list_models(
author (`str`, *optional*):
A string which identifies the author (user or organization) of the
returned models
gated (`bool`, *optional*):
A boolean to filter models on the Hub that are gated or not. By default, all models are returned.
If `gated=True` is passed, only gated models are returned.
If `gated=False` is passed, only non-gated models are returned.
library (`str` or `List`, *optional*):
A string or list of strings of foundational libraries models were
originally trained from, such as pytorch, tensorflow, or allennlp.
Expand Down Expand Up @@ -1749,6 +1754,8 @@ def list_models(
# Handle other query params
if author:
params["author"] = author
if gated is not None:
params["gated"] = gated
if pipeline_tag:
params["pipeline_tag"] = pipeline_tag
search_list = []
Expand Down Expand Up @@ -1795,6 +1802,7 @@ def list_datasets(
author: Optional[str] = None,
benchmark: Optional[Union[str, List[str]]] = None,
dataset_name: Optional[str] = None,
gated: Optional[bool] = None,
language_creators: Optional[Union[str, List[str]]] = None,
language: Optional[Union[str, List[str]]] = None,
multilinguality: Optional[Union[str, List[str]]] = None,
Expand Down Expand Up @@ -1826,6 +1834,10 @@ def list_datasets(
dataset_name (`str`, *optional*):
A string or list of strings that can be used to identify datasets on
the Hub by their name, such as `SQAC` or `wikineural`
gated (`bool`, *optional*):
A boolean to filter datasets on the Hub that are gated or not. By default, all datasets are returned.
If `gated=True` is passed, only gated datasets are returned.
If `gated=False` is passed, only non-gated datasets are returned.
language_creators (`str` or `List`, *optional*):
A string or list of strings that can be used to identify datasets on
the Hub with how the data was curated, such as `crowdsourced` or
Expand Down Expand Up @@ -1954,6 +1966,8 @@ def list_datasets(
# Handle other query params
if author:
params["author"] = author
if gated is not None:
params["gated"] = gated
search_list = []
if dataset_name:
search_list.append(dataset_name)
Expand Down
16 changes: 16 additions & 0 deletions tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1785,6 +1785,14 @@ def test_list_models_expand_cannot_be_used_with_other_params(self):
with self.assertRaises(ValueError):
next(self._api.list_models(expand=["author"], cardData=True))

def test_list_models_gated_only(self):
    """Check that `list_models(gated=True)` yields only gated models.

    Each returned model must report `gated` as either `"auto"` or `"manual"`.
    """
    models = list(self._api.list_models(expand=["gated"], gated=True, limit=5))
    # Without this guard an empty result would let the loop below pass
    # without checking anything.
    assert models, "expected at least one gated model to be returned"
    for model in models:
        assert model.gated in ("auto", "manual")

def test_list_models_non_gated_only(self):
    """Check that `list_models(gated=False)` yields only non-gated models.

    Each returned model must report `gated` as exactly `False`.
    """
    models = list(self._api.list_models(expand=["gated"], gated=False, limit=5))
    # Without this guard an empty result would let the loop below pass
    # without checking anything.
    assert models, "expected at least one non-gated model to be returned"
    for model in models:
        assert model.gated is False

def test_model_info(self):
model = self._api.model_info(repo_id=DUMMY_MODEL_ID)
self.assertIsInstance(model, ModelInfo)
Expand Down Expand Up @@ -2009,6 +2017,14 @@ def test_list_datasets_expand_cannot_be_used_with_full(self):
with self.assertRaises(ValueError):
next(self._api.list_datasets(expand=["author"], full=True))

def test_list_datasets_gated_only(self):
    """Check that `list_datasets(gated=True)` yields only gated datasets.

    Each returned dataset must report `gated` as either `"auto"` or `"manual"`.
    """
    datasets = list(self._api.list_datasets(expand=["gated"], gated=True, limit=5))
    # Without this guard an empty result would let the loop below pass
    # without checking anything.
    assert datasets, "expected at least one gated dataset to be returned"
    for dataset in datasets:
        assert dataset.gated in ("auto", "manual")

def test_list_datasets_non_gated_only(self):
    """Check that `list_datasets(gated=False)` yields only non-gated datasets.

    Each returned dataset must report `gated` as exactly `False`.
    """
    datasets = list(self._api.list_datasets(expand=["gated"], gated=False, limit=5))
    # Without this guard an empty result would let the loop below pass
    # without checking anything.
    assert datasets, "expected at least one non-gated dataset to be returned"
    for dataset in datasets:
        assert dataset.gated is False

def test_filter_datasets_with_card_data(self):
assert any(dataset.card_data is not None for dataset in self._api.list_datasets(full=True, limit=50))
assert all(dataset.card_data is None for dataset in self._api.list_datasets(full=False, limit=50))
Expand Down
Loading