diff --git a/src/ai4_cli/cli/modules.py b/src/ai4_cli/cli/modules.py index 85d49e2..ec62aae 100644 --- a/src/ai4_cli/cli/modules.py +++ b/src/ai4_cli/cli/modules.py @@ -1,6 +1,7 @@ """Handle CLI commands for modules.""" from typing_extensions import Annotated +from typing import List, Optional import typer @@ -22,6 +23,43 @@ def list( help="Show more details.", ), ] = False, + tags: Annotated[ + Optional[List[str]], + typer.Option( + "--tags", + help="Filter modules by tags. The given tags must all be present on a " + "module to be included in the results. Boolean expression is " + "t1 AND t2.", + ), + ] = None, + not_tags: Annotated[ + Optional[List[str]], + typer.Option( + "--not-tags", + help="Filter modules by tags. Only the modules that do not have any of the " + "given tags will be included in the results. Boolean expression is " + "NOT (t1 AND t2).", + ), + ] = None, + tags_any: Annotated[ + Optional[List[str]], + typer.Option( + "--tags-any", + help="Filter modules by tags. If any of the given tags is present on a " + "module it will be included in the results. Boolean expression is " + "t1 OR t2.", + ), + ] = None, + not_tags_any: Annotated[ + Optional[List[str]], + typer.Option( + "--not-tags-any", + help="Filter modules by tags. Only the modules that do not have at least " + "any of the given tags will be included in the results. " + "Boolean expression is " + "NOT (t1 OR t2).", + ), + ] = None, ): """List all modules.""" endpoint = ctx.obj.endpoint @@ -29,7 +67,13 @@ def list( debug = ctx.obj.debug cli = client.AI4Client(endpoint, version, http_debug=debug) - _, content = cli.modules.list() + filters = { + "tags": tags, + "not_tags": not_tags, + "tags_any": tags_any, + "not_tags_any": not_tags_any, + } + _, content = cli.modules.list(filters=filters) if long: rows = [ diff --git a/src/ai4_cli/client/modules.py b/src/ai4_cli/client/modules.py index dbfc79c..d9a439a 100644 --- a/src/ai4_cli/client/modules.py +++ b/src/ai4_cli/client/modules.py @@ -11,9 +11,14 @@ def __init__(self, client): """ self.client = client - def list(self): + def list(self, filters=None): """List all modules.""" - return self.client.request("catalog/modules/detail", "GET") + params = {} + for key, value in filters.items(): + if value is None: + continue + params[key] = value + return self.client.request("catalog/modules/detail", "GET", params=params) def show(self, module_id): """Show details of a module.""" diff --git a/src/ai4_cli/tests/test_client.py b/src/ai4_cli/tests/test_client.py index 1e0d4d4..9c9872c 100644 --- a/src/ai4_cli/tests/test_client.py +++ b/src/ai4_cli/tests/test_client.py @@ -49,6 +49,13 @@ def test_client_get(mock_request, ai4_client): mock_request.assert_called_with("foo", "GET") +@mock.patch("ai4_cli.client.client.AI4Client.request") +def test_client_get_with_params(mock_request, ai4_client): + """Test the AI4Client.get method with parameters.""" + ai4_client.get("foo", params={"bar": "baz"}) + mock_request.assert_called_with("foo", "GET", params={"bar": "baz"}) + + @mock.patch("ai4_cli.client.client.AI4Client.request") def test_client_post(mock_request, ai4_client): """Test the AI4Client.post method."""