Skip to content

Commit

Permalink
Support filtering datasets by tags (#2266)
Browse files Browse the repository at this point in the history
  • Loading branch information
Wauplin authored May 3, 2024
1 parent 2a98fb5 commit 5ff2d15
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 11 deletions.
28 changes: 17 additions & 11 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1540,13 +1540,13 @@ def list_models(
>>> api = HfApi()
>>> # List all models
# List all models
>>> api.list_models()
>>> # List only the text classification models
# List only the text classification models
>>> api.list_models(filter="text-classification")
>>> # List only models from the AllenNLP library
# List only models from the AllenNLP library
>>> api.list_models(filter="allennlp")
```
Expand All @@ -1557,10 +1557,10 @@ def list_models(
>>> api = HfApi()
>>> # List all models with "bert" in their name
# List all models with "bert" in their name
>>> api.list_models(search="bert")
>>> # List all models with "bert" in their name made by google
# List all models with "bert" in their name made by google
>>> api.list_models(search="bert", author="google")
```
"""
Expand Down Expand Up @@ -1698,6 +1698,7 @@ def list_datasets(
language: Optional[Union[str, List[str]]] = None,
multilinguality: Optional[Union[str, List[str]]] = None,
size_categories: Optional[Union[str, List[str]]] = None,
tags: Optional[Union[str, List[str]]] = None,
task_categories: Optional[Union[str, List[str]]] = None,
task_ids: Optional[Union[str, List[str]]] = None,
search: Optional[str] = None,
Expand Down Expand Up @@ -1736,6 +1737,8 @@ def list_datasets(
A string or list of strings that can be used to identify datasets on
the Hub by the size of the dataset such as `100K<n<1M` or
`1M<n<10M`.
tags (`str` or `List`, *optional*):
A string tag or a list of tags to filter datasets on the Hub.
task_categories (`str` or `List`, *optional*):
A string or list of strings that can be used to identify datasets on
the Hub by the designed task, such as `audio_classification` or
Expand Down Expand Up @@ -1775,20 +1778,21 @@ def list_datasets(
>>> api = HfApi()
>>> # List all datasets
# List all datasets
>>> api.list_datasets()
>>> # List only the text classification datasets
# List only the text classification datasets
>>> api.list_datasets(filter="task_categories:text-classification")
>>> # List only the datasets in russian for language modeling
# List only the datasets in russian for language modeling
>>> api.list_datasets(
... filter=("language:ru", "task_ids:language-modeling")
... )
>>> api.list_datasets(filter=filt)
# List FiftyOne datasets (identified by the tag "fiftyone" in dataset card)
>>> api.list_datasets(tags="fiftyone")
```
Example usage with the `search` argument:
Expand All @@ -1798,10 +1802,10 @@ def list_datasets(
>>> api = HfApi()
>>> # List all datasets with "text" in their name
# List all datasets with "text" in their name
>>> api.list_datasets(search="text")
>>> # List all datasets with "text" in their name made by google
# List all datasets with "text" in their name made by google
>>> api.list_datasets(search="text", author="google")
```
"""
Expand Down Expand Up @@ -1839,6 +1843,8 @@ def list_datasets(
data = f"{attr}:{data}"
filter_list.append(data)

if tags is not None:
filter_list.extend([tags] if isinstance(tags, str) else tags)
if search:
params.update({"search": search})
if sort is not None:
Expand Down
5 changes: 5 additions & 0 deletions tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1713,6 +1713,11 @@ def test_filter_datasets_with_card_data(self):
datasets = list(self._api.list_datasets(limit=500))
self.assertTrue(all([getattr(dataset, "card_data", None) is None for dataset in datasets]))

def test_filter_datasets_by_tag(self):
datasets = list(self._api.list_datasets(tags="fiftyone", limit=5))
for dataset in datasets:
assert "fiftyone" in dataset.tags

def test_dataset_info(self):
dataset = self._api.dataset_info(repo_id=DUMMY_DATASET_ID)
self.assertTrue(isinstance(dataset.card_data, DatasetCardData) and len(dataset.card_data) > 0)
Expand Down

0 comments on commit 5ff2d15

Please sign in to comment.