diff --git a/UPDATING.md b/UPDATING.md
index 41a120f31078d..89fda5580acf8 100644
--- a/UPDATING.md
+++ b/UPDATING.md
@@ -24,6 +24,7 @@ assists people when migrating to a new version.

 ## Next

+- [24488](https://github.com/apache/superset/pull/24488): Augments the foreign key constraints for the `sql_metrics`, `sqlatable_user`, and `table_columns` tables which reference the `tables` table to include an explicit CASCADE ON DELETE, ensuring the relevant records are deleted when a dataset is deleted. Scheduled downtime may be advised.
 - [24335](https://github.com/apache/superset/pull/24335): Removed deprecated API `/superset/filter/<datasource_type>/<int:datasource_id>/<column>/`
 - [24185](https://github.com/apache/superset/pull/24185): `/api/v1/database/test_connection` and `api/v1/database/validate_parameters` permissions changed from `can_read` to `can_write`. Only Admin users have access.
 - [24256](https://github.com/apache/superset/pull/24256): `Flask-Login` session validation is now set to `strong` by default. Previous setting was `basic`.
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 4eebec6be7a16..e5d791fc7e5c0 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -196,7 +196,7 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
     __tablename__ = "table_columns"
     __table_args__ = (UniqueConstraint("table_id", "column_name"),)

-    table_id = Column(Integer, ForeignKey("tables.id"))
+    table_id = Column(Integer, ForeignKey("tables.id", ondelete="CASCADE"))
     table: Mapped[SqlaTable] = relationship(
         "SqlaTable",
         back_populates="columns",
@@ -400,7 +400,7 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
     __tablename__ = "sql_metrics"
     __table_args__ = (UniqueConstraint("table_id", "metric_name"),)

-    table_id = Column(Integer, ForeignKey("tables.id"))
+    table_id = Column(Integer, ForeignKey("tables.id", ondelete="CASCADE"))
     table: Mapped[SqlaTable] = relationship(
         "SqlaTable",
         back_populates="metrics",
@@ -469,8 +469,8 @@ def data(self) -> dict[str, Any]:
     "sqlatable_user",
     metadata,
     Column("id", Integer, primary_key=True),
-    Column("user_id", Integer, ForeignKey("ab_user.id")),
-    Column("table_id", Integer, ForeignKey("tables.id")),
+    Column("user_id", Integer, ForeignKey("ab_user.id", ondelete="CASCADE")),
+    Column("table_id", Integer, ForeignKey("tables.id", ondelete="CASCADE")),
 )


@@ -507,11 +507,13 @@ class SqlaTable(
         TableColumn,
         back_populates="table",
         cascade="all, delete-orphan",
+        passive_deletes=True,
     )
     metrics: Mapped[list[SqlMetric]] = relationship(
         SqlMetric,
         back_populates="table",
         cascade="all, delete-orphan",
+        passive_deletes=True,
     )
     metric_class = SqlMetric
     column_class = TableColumn
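For readers unfamiliar with the pattern in the hunk above: `ondelete="CASCADE"` puts the cascade into the database's foreign key definition, and `passive_deletes=True` tells SQLAlchemy to trust that cascade instead of loading every child row and emitting per-row DELETEs. A minimal, self-contained sketch with hypothetical `Parent`/`Child` models (not Superset code; assumes SQLAlchemy 1.4+):

```python
from sqlalchemy import Column, ForeignKey, Integer, create_engine, event
from sqlalchemy.orm import Session, declarative_base, relationship

Base = declarative_base()


class Parent(Base):
    __tablename__ = "parents"
    id = Column(Integer, primary_key=True)
    children = relationship(
        "Child",
        cascade="all, delete-orphan",
        passive_deletes=True,  # trust the FK to cascade; skip per-row child DELETEs
    )


class Child(Base):
    __tablename__ = "children"
    id = Column(Integer, primary_key=True)
    parent_id = Column(Integer, ForeignKey("parents.id", ondelete="CASCADE"))


engine = create_engine("sqlite://")


# SQLite only honors ON DELETE CASCADE when foreign keys are enabled; the
# utils/core.py hunk later in this diff does the equivalent for Superset.
@event.listens_for(engine, "connect")
def _enable_fks(dbapi_connection, connection_record):
    dbapi_connection.execute("PRAGMA foreign_keys=ON")


Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Parent(id=1, children=[Child(), Child()]))
    session.commit()
    session.delete(session.get(Parent, 1))  # one DELETE; the DB cascades the rest
    session.commit()
    assert session.query(Child).count() == 0
```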
diff --git a/superset/daos/dataset.py b/superset/daos/dataset.py
index 4634a7e46f96e..74b7fa50eb1e7 100644
--- a/superset/daos/dataset.py
+++ b/superset/daos/dataset.py
@@ -19,6 +19,7 @@

 from sqlalchemy.exc import SQLAlchemyError

+from superset import security_manager
 from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
 from superset.daos.base import BaseDAO
 from superset.extensions import db
@@ -361,25 +362,24 @@ def create_metric(
         """
         return DatasetMetricDAO.create(properties, commit=commit)

-    @staticmethod
-    def bulk_delete(models: Optional[list[SqlaTable]], commit: bool = True) -> None:
+    @classmethod
+    def bulk_delete(
+        cls, models: Optional[list[SqlaTable]], commit: bool = True
+    ) -> None:
         item_ids = [model.id for model in models] if models else []
-        # bulk delete, first delete related data
-        if models:
-            for model in models:
-                model.owners = []
-                db.session.merge(model)
-            db.session.query(SqlMetric).filter(SqlMetric.table_id.in_(item_ids)).delete(
-                synchronize_session="fetch"
-            )
-            db.session.query(TableColumn).filter(
-                TableColumn.table_id.in_(item_ids)
-            ).delete(synchronize_session="fetch")
         # bulk delete itself
         try:
             db.session.query(SqlaTable).filter(SqlaTable.id.in_(item_ids)).delete(
                 synchronize_session="fetch"
             )
+
+            if models:
+                connection = db.session.connection()
+                mapper = next(iter(cls.model_cls.registry.mappers))  # type: ignore
+
+                for model in models:
+                    security_manager.dataset_after_delete(mapper, connection, model)
+
             if commit:
                 db.session.commit()
         except SQLAlchemyError as ex:
diff --git a/superset/datasets/commands/bulk_delete.py b/superset/datasets/commands/bulk_delete.py
index 9733aa21e8c98..6937dd20b75bf 100644
--- a/superset/datasets/commands/bulk_delete.py
+++ b/superset/datasets/commands/bulk_delete.py
@@ -28,7 +28,6 @@
     DatasetNotFoundError,
 )
 from superset.exceptions import SupersetSecurityException
-from superset.extensions import db

 logger = logging.getLogger(__name__)

@@ -40,35 +39,10 @@ def __init__(self, model_ids: list[int]):

     def run(self) -> None:
         self.validate()
-        if not self._models:
-            return None
+        assert self._models
+
         try:
             DatasetDAO.bulk_delete(self._models)
-            for model in self._models:
-                view_menu = (
-                    security_manager.find_view_menu(model.get_perm()) if model else None
-                )
-
-                if view_menu:
-                    permission_views = (
-                        db.session.query(security_manager.permissionview_model)
-                        .filter_by(view_menu=view_menu)
-                        .all()
-                    )
-
-                    for permission_view in permission_views:
-                        db.session.delete(permission_view)
-                    if view_menu:
-                        db.session.delete(view_menu)
-                else:
-                    if not view_menu:
-                        logger.error(
-                            "Could not find the data access permission for the dataset",
-                            exc_info=True,
-                        )
-            db.session.commit()
-
-            return None
         except DeleteFailedError as ex:
             logger.exception(ex.exception)
             raise DatasetBulkDeleteFailedError() from ex
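The `daos/dataset.py` change above leans on a documented SQLAlchemy caveat: a bulk `Query.delete()` bypasses the ORM unit of work, so mapper-level `after_delete` listeners (which Superset's security manager uses to clean up dataset permissions) never fire. That is why `bulk_delete` now invokes `security_manager.dataset_after_delete` by hand for each deleted model. A minimal sketch with a toy model, not Superset's:

```python
from sqlalchemy import Column, Integer, create_engine, event
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Dataset(Base):
    __tablename__ = "datasets"
    id = Column(Integer, primary_key=True)


fired = []


@event.listens_for(Dataset, "after_delete")
def _after_delete(mapper, connection, target):
    fired.append(target.id)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Dataset(id=1), Dataset(id=2)])
    session.commit()

    # ORM-level delete: the listener fires.
    session.delete(session.get(Dataset, 1))
    session.commit()
    assert fired == [1]

    # Bulk delete: no listener fires; cleanup hooks must be called manually.
    session.query(Dataset).filter(Dataset.id == 2).delete(synchronize_session="fetch")
    session.commit()
    assert fired == [1]
```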
-            self._model.owners = []
-            dataset = DatasetDAO.delete(self._model, commit=False)
-            db.session.commit()
-        except (SQLAlchemyError, DAODeleteFailedError) as ex:
+            return DatasetDAO.delete(self._model)
+        except DAODeleteFailedError as ex:
             logger.exception(ex)
-            db.session.rollback()
             raise DatasetDeleteFailedError() from ex
-        return dataset

     def validate(self) -> None:
         # Validate/populate model exists
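A side note on the `cast(...)` → `assert` swap above (a general typing practice, not specific to this PR): `typing.cast` is a static-only claim that does nothing at runtime, whereas an assert both narrows the `Optional` for type checkers and fails fast if `validate()` did not populate the model. A toy illustration with a hypothetical class:

```python
from typing import Optional


class Command:
    _model: Optional[str] = None

    def run(self) -> str:
        # Narrows Optional[str] -> str for mypy; raises AssertionError at
        # runtime if validation never set the attribute.
        assert self._model
        return self._model.upper()
```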
+
+    :param ondelete: If set, emit ON DELETE <value> when issuing DDL for this constraint
+    """
+
+    bind = op.get_bind()
+    insp = sa.engine.reflection.Inspector.from_engine(bind)
+
+    conv = {
+        "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
+    }
+
+    for table in ("sql_metrics", "table_columns"):
+        with op.batch_alter_table(table, naming_convention=conv) as batch_op:
+            if constraint := generic_find_fk_constraint_name(
+                table=table,
+                columns={"id"},
+                referenced="tables",
+                insp=insp,
+            ):
+                batch_op.drop_constraint(constraint, type_="foreignkey")
+
+            batch_op.create_foreign_key(
+                constraint_name=f"fk_{table}_table_id_tables",
+                referent_table="tables",
+                local_cols=["table_id"],
+                remote_cols=["id"],
+                ondelete=ondelete,
+            )
+
+    with op.batch_alter_table("sqlatable_user", naming_convention=conv) as batch_op:
+        for table, column in zip(("ab_user", "tables"), ("user_id", "table_id")):
+            if constraint := generic_find_fk_constraint_name(
+                table="sqlatable_user",
+                columns={"id"},
+                referenced=table,
+                insp=insp,
+            ):
+                batch_op.drop_constraint(constraint, type_="foreignkey")
+
+            batch_op.create_foreign_key(
+                constraint_name=f"fk_sqlatable_user_{column}_{table}",
+                referent_table=table,
+                local_cols=[column],
+                remote_cols=["id"],
+                ondelete=ondelete,
+            )
+
+
+def upgrade():
+    migrate(ondelete="CASCADE")
+
+
+def downgrade():
+    migrate(ondelete=None)
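The `conv` naming convention in the migration matters because these foreign keys were originally created unnamed, and Alembic's batch mode (required on SQLite, where the table is rebuilt) must address a constraint by name to drop and re-create it. A rough sketch of how such a convention yields deterministic names, per SQLAlchemy's documented `naming_convention` support; the table names mirror Superset's but the snippet is purely illustrative:

```python
from sqlalchemy import Column, ForeignKey, Integer, MetaData, Table
from sqlalchemy.schema import CreateTable

conv = {"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s"}
metadata = MetaData(naming_convention=conv)

Table("tables", metadata, Column("id", Integer, primary_key=True))
sql_metrics = Table(
    "sql_metrics",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("table_id", Integer, ForeignKey("tables.id", ondelete="CASCADE")),
)

# The convention is applied when DDL is emitted, producing a predictable name.
ddl = str(CreateTable(sql_metrics))
assert "fk_sql_metrics_table_id_tables" in ddl
assert "ON DELETE CASCADE" in ddl
```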
+
+            :param connection: The SQLite connection
+            :param \*args: Additional positional arguments
+            :see: https://docs.sqlalchemy.org/en/latest/dialects/sqlite.html
+            """
+
+            with closing(connection.cursor()) as cursor:
+                cursor.execute("PRAGMA foreign_keys=ON")
+

 def send_email_smtp(  # pylint: disable=invalid-name,too-many-arguments,too-many-locals
     to: str,
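Why the listener above is needed at all: SQLite parses `ON DELETE CASCADE` but ignores it unless foreign key enforcement is switched on for each connection. A quick demonstration in plain `sqlite3`, independent of Superset (`isolation_level=None` keeps the connection in autocommit so the PRAGMA takes effect immediately):

```python
import sqlite3
from contextlib import closing

with closing(sqlite3.connect(":memory:", isolation_level=None)) as conn:
    conn.executescript(
        """
        CREATE TABLE tables (id INTEGER PRIMARY KEY);
        CREATE TABLE table_columns (
            id INTEGER PRIMARY KEY,
            table_id INTEGER REFERENCES tables (id) ON DELETE CASCADE
        );
        INSERT INTO tables (id) VALUES (1);
        INSERT INTO table_columns (id, table_id) VALUES (1, 1);
        """
    )

    # Enforcement is OFF by default: the child row is orphaned, not removed.
    conn.execute("DELETE FROM tables")
    assert conn.execute("SELECT COUNT(*) FROM table_columns").fetchone()[0] == 1

    # With the pragma enabled, the same DELETE cascades to table_columns.
    conn.execute("INSERT INTO tables (id) VALUES (1)")
    conn.execute("PRAGMA foreign_keys=ON")
    conn.execute("DELETE FROM tables")
    assert conn.execute("SELECT COUNT(*) FROM table_columns").fetchone()[0] == 0
```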
             db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
         )
         dataset = database.tables[0]
-        dataset.owners = []
         db.session.delete(dataset)
         db.session.commit()
         db.session.delete(database)
diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py
index c0d86a876abb4..d31b8f0b28ed1 100644
--- a/tests/integration_tests/datasets/api_tests.py
+++ b/tests/integration_tests/datasets/api_tests.py
@@ -2099,7 +2099,6 @@ def test_import_dataset(self):
         assert dataset.table_name == "imported_dataset"
         assert str(dataset.uuid) == dataset_config["uuid"]

-        dataset.owners = []
         db.session.delete(dataset)
         db.session.commit()
         db.session.delete(database)
@@ -2201,7 +2200,6 @@ def test_import_dataset_overwrite(self):
         )

         dataset = database.tables[0]
-        dataset.owners = []
         db.session.delete(dataset)
         db.session.commit()
         db.session.delete(database)
diff --git a/tests/integration_tests/datasets/commands_tests.py b/tests/integration_tests/datasets/commands_tests.py
index 34a0625b36926..e43e861b59322 100644
--- a/tests/integration_tests/datasets/commands_tests.py
+++ b/tests/integration_tests/datasets/commands_tests.py
@@ -397,7 +397,6 @@ def test_import_v1_dataset(self, sm_g, utils_g):
         assert column.description is None
         assert column.python_date_format is None

-        dataset.owners = []
         dataset.database.owners = []
         db.session.delete(dataset)
         db.session.delete(dataset.database)
@@ -526,7 +525,6 @@ def test_import_v1_dataset_existing_database(self, mock_g):
         )
         assert len(database.tables) == 1

-        database.tables[0].owners = []
         database.owners = []
         db.session.delete(database.tables[0])
         db.session.delete(database)
diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py
index 854a0c9be020b..91419010b742d 100644
--- a/tests/integration_tests/sqla_models_tests.py
+++ b/tests/integration_tests/sqla_models_tests.py
@@ -189,11 +189,6 @@ def test_extra_cache_keys(self, flask_g):
         self.assertTrue(table3.has_extra_cache_key_calls(query_obj))
         assert extra_cache_keys == ["abc"]

-        # Cleanup
-        for table in [table1, table2, table3]:
-            db.session.delete(table)
-            db.session.commit()
-
     @patch("superset.jinja_context.g")
     def test_jinja_metrics_and_calc_columns(self, flask_g):
         flask_g.user.username = "abc"
@@ -430,7 +425,7 @@ def test_multiple_sql_statements_raises_exception(self):
         }

         table = SqlaTable(
-            table_name="test_has_extra_cache_keys_table",
+            table_name="test_multiple_sql_statements",
             sql="SELECT 'foo' as grp, 1 as num; SELECT 'bar' as grp, 2 as num",
             database=get_example_database(),
         )
@@ -451,7 +446,7 @@ def test_dml_statement_raises_exception(self):
         }

         table = SqlaTable(
-            table_name="test_has_extra_cache_keys_table",
+            table_name="test_dml_statement",
             sql="DELETE FROM foo",
             database=get_example_database(),
         )