diff --git a/superset-frontend/src/features/tags/BulkTagModal.tsx b/superset-frontend/src/features/tags/BulkTagModal.tsx index adacef1f47c7a..3fff056f41329 100644 --- a/superset-frontend/src/features/tags/BulkTagModal.tsx +++ b/superset-frontend/src/features/tags/BulkTagModal.tsx @@ -45,13 +45,19 @@ const BulkTagModal: React.FC = ({ addDangerToast, }) => { useEffect(() => {}, []); + const [tags, setTags] = useState([]); const onSave = async () => { await SupersetClient.post({ endpoint: `/api/v1/tag/bulk_create`, jsonPayload: { - tags: tags.map(tag => tag.value), - objects_to_tag: selected.map(item => [resourceName, +item.original.id]), + tags: tags.map(tag => ({ + name: tag.value, + objects_to_tag: selected.map(item => [ + resourceName, + +item.original.id, + ]), + })), }, }) .then(({ json = {} }) => { @@ -66,8 +72,6 @@ const BulkTagModal: React.FC = ({ setTags([]); }; - const [tags, setTags] = useState([]); - return ( Response: try: for tag in item.get("tags"): tagged_item: dict[str, Any] = self.add_model_schema.load( - {"name": tag, "objects_to_tag": item.get("objects_to_tag")} + { + "name": tag.get("name"), + "objects_to_tag": tag.get("objects_to_tag"), + } ) CreateCustomTagWithRelationshipsCommand( tagged_item, bulk_create=True diff --git a/superset/tags/commands/create.py b/superset/tags/commands/create.py index 3f05ccd23e93e..e8311ad520be4 100644 --- a/superset/tags/commands/create.py +++ b/superset/tags/commands/create.py @@ -17,12 +17,13 @@ import logging from typing import Any -from superset import db +from superset import db, security_manager from superset.commands.base import BaseCommand, CreateMixin from superset.daos.exceptions import DAOCreateFailedError from superset.daos.tag import TagDAO +from superset.exceptions import SupersetSecurityException from superset.tags.commands.exceptions import TagCreateFailedError, TagInvalidError -from superset.tags.commands.utils import to_object_type +from superset.tags.commands.utils import to_object_model, to_object_type from superset.tags.models import ObjectTypes, TagTypes logger = logging.getLogger(__name__) @@ -73,6 +74,7 @@ def __init__(self, data: dict[str, Any], bulk_create: bool = False): def run(self) -> None: self.validate() + try: tag = TagDAO.get_by_name(self._tag.strip(), TagTypes.custom) if self._objects_to_tag: @@ -84,7 +86,8 @@ def run(self) -> None: if self._description: tag.description = self._description - db.session.commit() + + db.session.commit() except DAOCreateFailedError as ex: logger.exception(ex.exception) @@ -98,12 +101,25 @@ def validate(self) -> None: exceptions.append(TagInvalidError()) # Validate object type + skipped_tagged_objects: list[tuple[str, int]] = [] for obj_type, obj_id in self._objects_to_tag: + skipped_tagged_objects = [] object_type = to_object_type(obj_type) + if not object_type: exceptions.append( TagInvalidError(f"invalid object type {object_type}") ) + try: + model = to_object_model(object_type, obj_id) # type: ignore + security_manager.raise_for_ownership(model) + except SupersetSecurityException: + # skip the object if the user doesn't have access + skipped_tagged_objects.append((obj_type, obj_id)) + + self._objects_to_tag = set(self._objects_to_tag) - set( + skipped_tagged_objects + ) if exceptions: raise TagInvalidError(exceptions=exceptions) diff --git a/superset/tags/commands/utils.py b/superset/tags/commands/utils.py index 2993365b7ac75..028465d83a4ae 100644 --- a/superset/tags/commands/utils.py +++ b/superset/tags/commands/utils.py @@ -17,6 +17,12 @@ from typing import Optional, Union +from superset.daos.chart import ChartDAO +from superset.daos.dashboard import DashboardDAO +from superset.daos.query import SavedQueryDAO +from superset.models.dashboard import Dashboard +from superset.models.slice import Slice +from superset.models.sql_lab import SavedQuery from superset.tags.models import ObjectTypes @@ -27,3 +33,15 @@ def to_object_type(object_type: Union[ObjectTypes, int, str]) -> Optional[Object if object_type in [type_.value, type_.name]: return type_ return None + + +def to_object_model( + object_type: ObjectTypes, object_id: int +) -> Optional[Union[Dashboard, SavedQuery, Slice]]: + if ObjectTypes.dashboard == object_type: + return DashboardDAO.find_by_id(object_id) + if ObjectTypes.query == object_type: + return SavedQueryDAO.find_by_id(object_id) + if ObjectTypes.chart == object_type: + return ChartDAO.find_by_id(object_id) + return None diff --git a/superset/tags/schemas.py b/superset/tags/schemas.py index 8aafbb76b59bf..571a2a03c9e74 100644 --- a/superset/tags/schemas.py +++ b/superset/tags/schemas.py @@ -54,27 +54,21 @@ class TagGetResponseSchema(Schema): type = fields.String() -class TagPostSchema(Schema): +class TagObjectSchema(Schema): name = fields.String() description = fields.String(required=False, allow_none=True) - # resource id's to tag with tag objects_to_tag = fields.List( fields.Tuple((fields.String(), fields.Int())), required=False ) class TagPostBulkSchema(Schema): - tags = fields.List(fields.String()) - # resource id's to tag with tag - objects_to_tag = fields.List( - fields.Tuple((fields.String(), fields.Int())), required=False - ) + tags = fields.List(fields.Nested(TagObjectSchema)) -class TagPutSchema(Schema): - name = fields.String() - description = fields.String(required=False, allow_none=True) - # resource id's to tag with tag - objects_to_tag = fields.List( - fields.Tuple((fields.String(), fields.Int())), required=False - ) +class TagPostSchema(TagObjectSchema): + pass + + +class TagPutSchema(TagObjectSchema): + pass diff --git a/tests/integration_tests/tags/api_tests.py b/tests/integration_tests/tags/api_tests.py index 06e4a73e19130..444d52078e7ca 100644 --- a/tests/integration_tests/tags/api_tests.py +++ b/tests/integration_tests/tags/api_tests.py @@ -530,8 +530,23 @@ def test_post_bulk_tag(self): rv = self.client.post( uri, json={ - "tags": ["tag1", "tag2", "tag3"], - "objects_to_tag": [["dashboard", dashboard.id], ["chart", chart.id]], + "tags": [ + { + "name": "tag1", + "objects_to_tag": [ + ["dashboard", dashboard.id], + ["chart", chart.id], + ], + }, + { + "name": "tag2", + "objects_to_tag": [["dashboard", dashboard.id]], + }, + { + "name": "tag3", + "objects_to_tag": [["chart", chart.id]], + }, + ] }, ) @@ -547,11 +562,10 @@ def test_post_bulk_tag(self): TaggedObject.object_id == dashboard.id, TaggedObject.object_type == ObjectTypes.dashboard, ) - assert tagged_objects.count() == 3 + assert tagged_objects.count() == 2 tagged_objects = db.session.query(TaggedObject).filter( - # TaggedObject.tag_id.in_([tag.id for tag in tags]), TaggedObject.object_id == chart.id, TaggedObject.object_type == ObjectTypes.chart, ) - assert tagged_objects.count() == 3 + assert tagged_objects.count() == 2 diff --git a/tests/unit_tests/dao/tag_test.py b/tests/unit_tests/dao/tag_test.py index 476c51e45db31..065ed756628cc 100644 --- a/tests/unit_tests/dao/tag_test.py +++ b/tests/unit_tests/dao/tag_test.py @@ -169,6 +169,3 @@ def test_create_tag_relationship(mocker): # Verify that the correct number of TaggedObjects are added to the session assert mock_session.add_all.call_count == 1 assert len(mock_session.add_all.call_args[0][0]) == len(objects_to_tag) - - # Verify that commit is called - mock_session.commit.assert_called_once() diff --git a/tests/unit_tests/tags/commands/create_test.py b/tests/unit_tests/tags/commands/create_test.py index a188625b403f5..639372a70fef6 100644 --- a/tests/unit_tests/tags/commands/create_test.py +++ b/tests/unit_tests/tags/commands/create_test.py @@ -1,4 +1,5 @@ import pytest +from pytest_mock import MockFixture from sqlalchemy.orm.session import Session from superset.utils.core import DatasourceType @@ -47,7 +48,7 @@ def session_with_data(session: Session): yield session -def test_create_command_success(session_with_data: Session): +def test_create_command_success(session_with_data: Session, mocker: MockFixture): from superset.connectors.sqla.models import SqlaTable from superset.daos.tag import TagDAO from superset.models.dashboard import Dashboard @@ -61,6 +62,12 @@ def test_create_command_success(session_with_data: Session): chart = session_with_data.query(Slice).first() dashboard = session_with_data.query(Dashboard).first() + mocker.patch( + "superset.security.SupersetSecurityManager.is_admin", return_value=True + ) + mocker.patch("superset.daos.chart.ChartDAO.find_by_id", return_value=chart) + mocker.patch("superset.daos.query.SavedQueryDAO.find_by_id", return_value=query) + objects_to_tag = [ (ObjectTypes.query, query.id), (ObjectTypes.chart, chart.id), @@ -84,7 +91,9 @@ def test_create_command_success(session_with_data: Session): ) -def test_create_command_failed_validate(session_with_data: Session): +def test_create_command_failed_validate( + session_with_data: Session, mocker: MockFixture +): from superset.connectors.sqla.models import SqlaTable from superset.daos.tag import TagDAO from superset.models.dashboard import Dashboard @@ -98,6 +107,12 @@ def test_create_command_failed_validate(session_with_data: Session): chart = session_with_data.query(Slice).first() dashboard = session_with_data.query(Dashboard).first() + mocker.patch( + "superset.security.SupersetSecurityManager.is_admin", return_value=True + ) + mocker.patch("superset.daos.chart.ChartDAO.find_by_id", return_value=query) + mocker.patch("superset.daos.query.SavedQueryDAO.find_by_id", return_value=chart) + objects_to_tag = [ (ObjectTypes.query, query.id), (ObjectTypes.chart, chart.id), diff --git a/tests/unit_tests/tags/commands/update_test.py b/tests/unit_tests/tags/commands/update_test.py index 2c2454547eb17..84007fbb685d2 100644 --- a/tests/unit_tests/tags/commands/update_test.py +++ b/tests/unit_tests/tags/commands/update_test.py @@ -1,4 +1,5 @@ import pytest +from pytest_mock import MockFixture from sqlalchemy.orm.session import Session from superset.utils.core import DatasourceType @@ -56,13 +57,19 @@ def session_with_data(session: Session): yield session -def test_update_command_success(session_with_data: Session): +def test_update_command_success(session_with_data: Session, mocker: MockFixture): from superset.daos.tag import TagDAO from superset.models.dashboard import Dashboard from superset.tags.commands.update import UpdateTagCommand from superset.tags.models import ObjectTypes, TaggedObject dashboard = session_with_data.query(Dashboard).first() + mocker.patch( + "superset.security.SupersetSecurityManager.is_admin", return_value=True + ) + mocker.patch( + "superset.daos.dashboard.DashboardDAO.find_by_id", return_value=dashboard + ) objects_to_tag = [ (ObjectTypes.dashboard, dashboard.id), @@ -84,7 +91,9 @@ def test_update_command_success(session_with_data: Session): assert len(session_with_data.query(TaggedObject).all()) == len(objects_to_tag) -def test_update_command_success_duplicates(session_with_data: Session): +def test_update_command_success_duplicates( + session_with_data: Session, mocker: MockFixture +): from superset.daos.tag import TagDAO from superset.models.dashboard import Dashboard from superset.models.slice import Slice @@ -95,6 +104,14 @@ def test_update_command_success_duplicates(session_with_data: Session): dashboard = session_with_data.query(Dashboard).first() chart = session_with_data.query(Slice).first() + mocker.patch( + "superset.security.SupersetSecurityManager.is_admin", return_value=True + ) + mocker.patch("superset.daos.chart.ChartDAO.find_by_id", return_value=chart) + mocker.patch( + "superset.daos.dashboard.DashboardDAO.find_by_id", return_value=dashboard + ) + objects_to_tag = [ (ObjectTypes.dashboard, dashboard.id), ] @@ -124,14 +141,16 @@ def test_update_command_success_duplicates(session_with_data: Session): assert changed_model.objects[0].object_id == chart.id -def test_update_command_failed_validation(session_with_data: Session): +def test_update_command_failed_validation( + session_with_data: Session, mocker: MockFixture +): from superset.daos.tag import TagDAO from superset.models.dashboard import Dashboard from superset.models.slice import Slice from superset.tags.commands.create import CreateCustomTagWithRelationshipsCommand from superset.tags.commands.exceptions import TagInvalidError from superset.tags.commands.update import UpdateTagCommand - from superset.tags.models import ObjectTypes, TaggedObject + from superset.tags.models import ObjectTypes dashboard = session_with_data.query(Dashboard).first() chart = session_with_data.query(Slice).first() @@ -139,6 +158,14 @@ def test_update_command_failed_validation(session_with_data: Session): (ObjectTypes.chart, chart.id), ] + mocker.patch( + "superset.security.SupersetSecurityManager.is_admin", return_value=True + ) + mocker.patch("superset.daos.chart.ChartDAO.find_by_id", return_value=chart) + mocker.patch( + "superset.daos.dashboard.DashboardDAO.find_by_id", return_value=dashboard + ) + CreateCustomTagWithRelationshipsCommand( data={"name": "test_tag", "objects_to_tag": objects_to_tag} ).run()