diff --git a/superset/charts/api.py b/superset/charts/api.py index fb78cd9d739e3..74776c1934a52 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -36,7 +36,7 @@ ChartUpdateFailedError, ) from superset.charts.commands.update import UpdateChartCommand -from superset.charts.filters import ChartFilter +from superset.charts.filters import ChartFilter, ChartNameOrDescriptionFilter from superset.charts.schemas import ( ChartPostSchema, ChartPutSchema, @@ -111,6 +111,7 @@ class ChartRestApi(BaseSupersetModelRestApi): ) base_order = ("changed_on", "desc") base_filters = [["id", ChartFilter, lambda: []]] + search_filters = {"slice_name": [ChartNameOrDescriptionFilter]} # Will just affect _info endpoint edit_columns = ["slice_name"] diff --git a/superset/charts/filters.py b/superset/charts/filters.py index a35ba2912b073..94ae2ad1747ed 100644 --- a/superset/charts/filters.py +++ b/superset/charts/filters.py @@ -16,13 +16,33 @@ # under the License. from typing import Any +from flask_babel import lazy_gettext as _ from sqlalchemy import or_ from sqlalchemy.orm.query import Query from superset import security_manager +from superset.models.slice import Slice from superset.views.base import BaseFilter +class ChartNameOrDescriptionFilter( + BaseFilter +): # pylint: disable=too-few-public-methods + name = _("Name or Description") + arg_name = "name_or_description" + + def apply(self, query: Query, value: Any) -> Query: + if not value: + return query + ilike_value = f"%{value}%" + return query.filter( + or_( + Slice.slice_name.ilike(ilike_value), + Slice.description.ilike(ilike_value), + ) + ) + + class ChartFilter(BaseFilter): # pylint: disable=too-few-public-methods def apply(self, query: Query, value: Any) -> Query: if security_manager.all_datasource_access(): diff --git a/tests/charts/api_tests.py b/tests/charts/api_tests.py index 450e99379b0b1..257b89ba37df0 100644 --- a/tests/charts/api_tests.py +++ b/tests/charts/api_tests.py @@ -566,6 +566,56 @@ def test_get_charts_filter(self): data = json.loads(rv.data.decode("utf-8")) self.assertEqual(data["count"], 5) + def test_get_charts_custom_filter(self): + """ + Chart API: Test get charts custom filter + """ + admin = self.get_user("admin") + chart1 = self.insert_chart("foo", [admin.id], 1, description="ZY_bar") + chart2 = self.insert_chart("zy_foo", [admin.id], 1, description="desc1") + chart3 = self.insert_chart("foo", [admin.id], 1, description="desc1zy_") + chart4 = self.insert_chart("bar", [admin.id], 1, description="foo") + + arguments = { + "filters": [ + {"col": "slice_name", "opr": "name_or_description", "value": "zy_"} + ], + "order_column": "slice_name", + "order_direction": "asc", + } + self.login(username="admin") + uri = f"api/v1/chart/?q={prison.dumps(arguments)}" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 200) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual(data["count"], 3) + + expected_response = [ + {"description": "ZY_bar", "slice_name": "foo",}, + {"description": "desc1zy_", "slice_name": "foo",}, + {"description": "desc1", "slice_name": "zy_foo",}, + ] + for index, item in enumerate(data["result"]): + self.assertEqual( + item["description"], expected_response[index]["description"] + ) + self.assertEqual(item["slice_name"], expected_response[index]["slice_name"]) + + self.logout() + self.login(username="gamma") + uri = f"api/v1/chart/?q={prison.dumps(arguments)}" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 200) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual(data["count"], 0) + + # rollback changes + db.session.delete(chart1) + db.session.delete(chart2) + db.session.delete(chart3) + db.session.delete(chart4) + db.session.commit() + def test_get_charts_page(self): """ Chart API: Test get charts filter