Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Search by authors and string #531

Merged
merged 14 commits into from
Jan 5, 2022
60 changes: 60 additions & 0 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,8 @@ def unset_access_token():
def list_models(
self,
filter: Union[str, Iterable[str], None] = None,
author: Optional[str] = None,
muellerzr marked this conversation as resolved.
Show resolved Hide resolved
search: Optional[str] = None,
sort: Union[Literal["lastModified"], str, None] = None,
direction: Optional[Literal[-1]] = None,
limit: Optional[int] = None,
Expand Down Expand Up @@ -451,6 +453,30 @@ def list_models(

>>> # List only the models from the AllenNLP library
>>> api.list_models(filter="allennlp")
author (:obj:`str`, `optional`):
A string which identify the author (user or organization) of the returned models
Example usage:

>>> from huggingface_hub import HfApi
>>> api = HfApi()

>>> # List all models from google
>>> api.list_models(author="google")

>>> # List only the text classification models from google
>>> api.list_models(filter="text-classification", author="google")
search (:obj:`str`, `optional`):
A string that will be contained in the returned models
Example usage:

>>> from huggingface_hub import HfApi
>>> api = HfApi()

>>> # List all models with "bert" in their name
>>> api.list_models(search="bert")

>>> #List all models with "bert" in their name made by google
>>> api.list_models(search="bert", author="google")
sort (:obj:`Literal["lastModified"]` or :obj:`str`, `optional`):
The key with which to sort the resulting models. Possible values are the properties of the `ModelInfo`
class.
Expand All @@ -471,6 +497,10 @@ def list_models(
if filter is not None:
params.update({"filter": filter})
params.update({"full": True})
if author is not None:
params.update({"author": author})
if search is not None:
params.update({"search": search})
if sort is not None:
params.update({"sort": sort})
if direction is not None:
Expand All @@ -492,6 +522,8 @@ def list_models(
def list_datasets(
self,
filter: Union[str, Iterable[str], None] = None,
author: Optional[str] = None,
search: Optional[str] = None,
sort: Union[Literal["lastModified"], str, None] = None,
direction: Optional[Literal[-1]] = None,
limit: Optional[int] = None,
Expand All @@ -516,6 +548,30 @@ def list_datasets(

>>> # List only the datasets in russian for language modeling
>>> api.list_datasets(filter=("languages:ru", "task_ids:language-modeling"))
author (:obj:`str`, `optional`):
A string which identify the author of the returned models
Example usage:

>>> from huggingface_hub import HfApi
>>> api = HfApi()

>>> # List all datasets from google
>>> api.list_datasets(author="google")

>>> # List only the text classification datasets from google
>>> api.list_datasets(filter="text-classification", author="google")
search (:obj:`str`, `optional`):
A string that will be contained in the returned models
Example usage:

>>> from huggingface_hub import HfApi
>>> api = HfApi()

>>> # List all datasets with "text" in their name
>>> api.list_datasets(search="text")

>>> #List all datasets with "text" in their name made by google
>>> api.list_datasets(search="text", author="google")
sort (:obj:`Literal["lastModified"]` or :obj:`str`, `optional`):
The key with which to sort the resulting datasets. Possible values are the properties of the `DatasetInfo`
class.
Expand All @@ -532,6 +588,10 @@ def list_datasets(
params = {}
if filter is not None:
params.update({"filter": filter})
if author is not None:
params.update({"author": author})
if search is not None:
params.update({"search": search})
if sort is not None:
params.update({"sort": sort})
if direction is not None:
Expand Down
30 changes: 30 additions & 0 deletions tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,22 @@ def test_list_models(self):
self.assertGreater(len(models), 100)
self.assertIsInstance(models[0], ModelInfo)

@with_production_testing
def test_list_models_author(self):
_api = HfApi()
models = _api.list_models(author="google")
FrancescoSaverioZuppichini marked this conversation as resolved.
Show resolved Hide resolved
self.assertGreater(len(models), 10)
self.assertIsInstance(models[0], ModelInfo)
[self.assertTrue("google" in model.author for model in models)]

@with_production_testing
def test_list_models_search(self):
_api = HfApi()
models = _api.list_models(search="bert")
self.assertGreater(len(models), 10)
self.assertIsInstance(models[0], ModelInfo)
[self.assertTrue("bert" in model.modelId.lower()) for model in models]

@with_production_testing
def test_list_models_complex_query(self):
# Let's list the 10 most recent models
Expand Down Expand Up @@ -549,6 +565,20 @@ def test_list_datasets_full(self):
self.assertIsInstance(dataset, DatasetInfo)
self.assertTrue(any(dataset.cardData for dataset in datasets))

@with_production_testing
def test_list_datasets_author(self):
_api = HfApi()
datasets = _api.list_datasets(author="huggingface")
self.assertGreater(len(datasets), 1)
self.assertIsInstance(datasets[0], DatasetInfo)

@with_production_testing
def test_list_datasets_search(self):
_api = HfApi()
datasets = _api.list_datasets(search="wikipedia")
self.assertGreater(len(datasets), 10)
self.assertIsInstance(datasets[0], DatasetInfo)

@with_production_testing
def test_dataset_info(self):
_api = HfApi()
Expand Down