Skip to content

Commit

Permalink
Fixes filtering by tags with list_models and adds test case (#1673)
Browse files Browse the repository at this point in the history
* Fix tags filter in list_models and add test case

* Remove key tags in query dict

---------

Co-authored-by: Lucain <lucainp@gmail.com>
  • Loading branch information
martinbrose and Wauplin authored Sep 19, 2023
1 parent 7f25246 commit adef26d
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
5 changes: 1 addition & 4 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1153,7 +1153,6 @@ def _unpack_model_filter(self, model_filter: ModelFilter):
Unpacks a [`ModelFilter`] into something readable for `list_models`
"""
model_str = ""
tags = []

# Handling author
if model_filter.author is not None:
Expand Down Expand Up @@ -1186,13 +1185,11 @@ def _unpack_model_filter(self, model_filter: ModelFilter):

# Handling tags
if model_filter.tags:
tags.extend([model_filter.tags] if isinstance(model_filter.tags, str) else model_filter.tags)
filter_list.extend([model_filter.tags] if isinstance(model_filter.tags, str) else model_filter.tags)

query_dict: Dict[str, Any] = {}
if model_str is not None:
query_dict["search"] = model_str
if len(tags) > 0:
query_dict["tags"] = tags
if isinstance(model_filter.language, list):
filter_list.extend(model_filter.language)
elif isinstance(model_filter.language, str):
Expand Down
8 changes: 8 additions & 0 deletions tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1751,6 +1751,14 @@ def test_filter_models_by_language(self):
res_en = list(self._api.list_models(filter=ModelFilter(language="en")))
self.assertGreater(len(res_en), len(res_fr))

def test_filter_models_with_tag(self):
models = list(self._api.list_models(filter=ModelFilter(author="HuggingFaceBR4", tags=["tensorboard"])))
self.assertTrue("HuggingFaceBR4" == models[0].author)
self.assertTrue("tensorboard" in models[0].tags)

models = list(self._api.list_models(filter=ModelFilter(tags="dummytag")))
self.assertEqual(len(models), 0)

def test_filter_models_with_complex_query(self):
args = ModelSearchArguments(api=self._api)
f = ModelFilter(
Expand Down

0 comments on commit adef26d

Please sign in to comment.