From 6b0eb405d151026ee9b49ef5283b8d32228dc0ba Mon Sep 17 00:00:00 2001 From: John Bodley Date: Mon, 4 Dec 2023 13:20:28 -0800 Subject: [PATCH] chore: Leverage Flask-SQLAlchemy session --- scripts/benchmark_migration.py | 17 +- superset/cachekeys/api.py | 1 - superset/commands/dashboard/importers/v0.py | 32 +- superset/commands/explore/get.py | 3 +- superset/commands/utils.py | 3 +- superset/common/query_context_factory.py | 5 +- superset/common/query_object_factory.py | 6 - superset/connectors/sqla/models.py | 14 +- superset/daos/dashboard.py | 7 +- superset/daos/datasource.py | 6 +- superset/datasource/api.py | 4 +- superset/models/dashboard.py | 57 +-- superset/security/manager.py | 12 +- superset/utils/database.py | 5 +- superset/utils/log.py | 12 +- superset/utils/mock_data.py | 11 +- superset/views/core.py | 10 +- superset/views/datasource/utils.py | 3 +- superset/views/datasource/views.py | 6 +- superset/views/utils.py | 4 +- tests/integration_tests/datasource_tests.py | 4 +- .../integration_tests/query_context_tests.py | 2 - tests/integration_tests/security_tests.py | 461 +++++++++--------- .../common/test_query_object_factory.py | 9 +- tests/unit_tests/conftest.py | 1 - tests/unit_tests/datasource/dao_tests.py | 9 +- 26 files changed, 315 insertions(+), 389 deletions(-) diff --git a/scripts/benchmark_migration.py b/scripts/benchmark_migration.py index 466fab6f130e6..90d94853dccb8 100644 --- a/scripts/benchmark_migration.py +++ b/scripts/benchmark_migration.py @@ -142,8 +142,6 @@ def main( filepath: str, limit: int = 1000, force: bool = False, no_auto_cleanup: bool = False ) -> None: auto_cleanup = not no_auto_cleanup - session = db.session() - print(f"Importing migration script: {filepath}") module = import_migration_script(Path(filepath)) @@ -174,10 +172,9 @@ def main( models = find_models(module) model_rows: dict[type[Model], int] = {} for model in models: - rows = session.query(model).count() + rows = db.session.query(model).count() print(f"- {model.__name__} ({rows} rows in table {model.__tablename__})") model_rows[model] = rows - session.close() print("Benchmarking migration") results: dict[str, float] = {} @@ -199,16 +196,16 @@ def main( print(f"- Adding {missing} entities to the {model.__name__} model") bar = ChargingBar("Processing", max=missing) try: - for entity in add_sample_rows(session, model, missing): + for entity in add_sample_rows(model, missing): entities.append(entity) bar.next() except Exception: - session.rollback() + db.session.rollback() raise bar.finish() model_rows[model] = min_entities - session.add_all(entities) - session.commit() + db.session.add_all(entities) + db.session.commit() if auto_cleanup: new_models[model].extend(entities) @@ -227,10 +224,10 @@ def main( print("Cleaning up DB") # delete in reverse order of creation to handle relationships for model, entities in list(new_models.items())[::-1]: - session.query(model).filter( + db.session.query(model).filter( model.id.in_(entity.id for entity in entities) ).delete(synchronize_session=False) - session.commit() + db.session.commit() if current_revision != revision and not force: click.confirm(f"\nRevert DB to {revision}?", abort=True) diff --git a/superset/cachekeys/api.py b/superset/cachekeys/api.py index 9efc3d4c7a27e..40d3830e8bbbb 100644 --- a/superset/cachekeys/api.py +++ b/superset/cachekeys/api.py @@ -84,7 +84,6 @@ def invalidate(self) -> Response: datasource_uids = set(datasources.get("datasource_uids", [])) for ds in datasources.get("datasources", []): ds_obj = 
SqlaTable.get_datasource_by_name( - session=db.session, datasource_name=ds.get("datasource_name"), schema=ds.get("schema"), database_name=ds.get("database_name"), diff --git a/superset/commands/dashboard/importers/v0.py b/superset/commands/dashboard/importers/v0.py index 4c2a18e5cc694..bd7aaa4c90381 100644 --- a/superset/commands/dashboard/importers/v0.py +++ b/superset/commands/dashboard/importers/v0.py @@ -22,7 +22,7 @@ from typing import Any, Optional from flask_babel import lazy_gettext as _ -from sqlalchemy.orm import make_transient, Session +from sqlalchemy.orm import make_transient from superset import db from superset.commands.base import BaseCommand @@ -55,7 +55,6 @@ def import_chart( :returns: The resulting id for the imported slice :rtype: int """ - session = db.session make_transient(slc_to_import) slc_to_import.dashboards = [] slc_to_import.alter_params(remote_id=slc_to_import.id, import_time=import_time) @@ -64,7 +63,6 @@ def import_chart( slc_to_import.reset_ownership() params = slc_to_import.params_dict datasource = SqlaTable.get_datasource_by_name( - session=session, datasource_name=params["datasource_name"], database_name=params["database_name"], schema=params["schema"], @@ -72,11 +70,11 @@ def import_chart( slc_to_import.datasource_id = datasource.id # type: ignore if slc_to_override: slc_to_override.override(slc_to_import) - session.flush() + db.session.flush() return slc_to_override.id - session.add(slc_to_import) + db.session.add(slc_to_import) logger.info("Final slice: %s", str(slc_to_import.to_json())) - session.flush() + db.session.flush() return slc_to_import.id @@ -156,7 +154,6 @@ def alter_native_filters(dashboard: Dashboard) -> None: dashboard.json_metadata = json.dumps(json_metadata) logger.info("Started import of the dashboard: %s", dashboard_to_import.to_json()) - session = db.session logger.info("Dashboard has %d slices", len(dashboard_to_import.slices)) # copy slices object as Slice.import_slice will mutate the slice # and will remove the existing dashboard - slice association @@ -173,7 +170,7 @@ def alter_native_filters(dashboard: Dashboard) -> None: i_params_dict = dashboard_to_import.params_dict remote_id_slice_map = { slc.params_dict["remote_id"]: slc - for slc in session.query(Slice).all() + for slc in db.session.query(Slice).all() if "remote_id" in slc.params_dict } for slc in slices: @@ -224,7 +221,7 @@ def alter_native_filters(dashboard: Dashboard) -> None: # override the dashboard existing_dashboard = None - for dash in session.query(Dashboard).all(): + for dash in db.session.query(Dashboard).all(): if ( "remote_id" in dash.params_dict and dash.params_dict["remote_id"] == dashboard_to_import.id @@ -253,18 +250,20 @@ def alter_native_filters(dashboard: Dashboard) -> None: alter_native_filters(dashboard_to_import) new_slices = ( - session.query(Slice).filter(Slice.id.in_(old_to_new_slc_id_dict.values())).all() + db.session.query(Slice) + .filter(Slice.id.in_(old_to_new_slc_id_dict.values())) + .all() ) if existing_dashboard: existing_dashboard.override(dashboard_to_import) existing_dashboard.slices = new_slices - session.flush() + db.session.flush() return existing_dashboard.id dashboard_to_import.slices = new_slices - session.add(dashboard_to_import) - session.flush() + db.session.add(dashboard_to_import) + db.session.flush() return dashboard_to_import.id # type: ignore @@ -291,7 +290,6 @@ def decode_dashboards(o: dict[str, Any]) -> Any: def import_dashboards( - session: Session, content: str, database_id: Optional[int] = None, import_time: 
Optional[int] = None, @@ -308,10 +306,10 @@ def import_dashboards( params = json.loads(table.params) dataset_id_mapping[params["remote_id"]] = new_dataset_id - session.commit() + db.session.commit() for dashboard in data["dashboards"]: import_dashboard(dashboard, dataset_id_mapping, import_time=import_time) - session.commit() + db.session.commit() class ImportDashboardsCommand(BaseCommand): @@ -334,7 +332,7 @@ def run(self) -> None: for file_name, content in self.contents.items(): logger.info("Importing dashboard from file %s", file_name) - import_dashboards(db.session, content, self.database_id) + import_dashboards(content, self.database_id) def validate(self) -> None: # ensure all files are JSON diff --git a/superset/commands/explore/get.py b/superset/commands/explore/get.py index bb8f5a85e9e8a..9d715bd63dc5f 100644 --- a/superset/commands/explore/get.py +++ b/superset/commands/explore/get.py @@ -24,7 +24,6 @@ from flask_babel import lazy_gettext as _ from sqlalchemy.exc import SQLAlchemyError -from superset import db from superset.commands.base import BaseCommand from superset.commands.explore.form_data.get import GetFormDataCommand from superset.commands.explore.form_data.parameters import ( @@ -114,7 +113,7 @@ def run(self) -> Optional[dict[str, Any]]: if self._datasource_id is not None: with contextlib.suppress(DatasourceNotFound): datasource = DatasourceDAO.get_datasource( - db.session, cast(str, self._datasource_type), self._datasource_id + cast(str, self._datasource_type), self._datasource_id ) datasource_name = datasource.name if datasource else _("[Missing Dataset]") viz_type = form_data.get("viz_type") diff --git a/superset/commands/utils.py b/superset/commands/utils.py index 8cfeab3c1148d..b7121ec89f0e7 100644 --- a/superset/commands/utils.py +++ b/superset/commands/utils.py @@ -29,7 +29,6 @@ ) from superset.daos.datasource import DatasourceDAO from superset.daos.exceptions import DatasourceNotFound -from superset.extensions import db from superset.utils.core import DatasourceType, get_user_id if TYPE_CHECKING: @@ -80,7 +79,7 @@ def populate_roles(role_ids: list[int] | None = None) -> list[Role]: def get_datasource_by_id(datasource_id: int, datasource_type: str) -> BaseDatasource: try: return DatasourceDAO.get_datasource( - db.session, DatasourceType(datasource_type), datasource_id + DatasourceType(datasource_type), datasource_id ) except DatasourceNotFound as ex: raise DatasourceNotFoundValidationError() from ex diff --git a/superset/common/query_context_factory.py b/superset/common/query_context_factory.py index 708907d4a91ab..fd18b8f90a10f 100644 --- a/superset/common/query_context_factory.py +++ b/superset/common/query_context_factory.py @@ -18,7 +18,7 @@ from typing import Any, TYPE_CHECKING -from superset import app, db +from superset import app from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType from superset.common.query_context import QueryContext from superset.common.query_object import QueryObject @@ -35,7 +35,7 @@ def create_query_object_factory() -> QueryObjectFactory: - return QueryObjectFactory(config, DatasourceDAO(), db.session) + return QueryObjectFactory(config, DatasourceDAO()) class QueryContextFactory: # pylint: disable=too-few-public-methods @@ -95,7 +95,6 @@ def create( def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource: return DatasourceDAO.get_datasource( - session=db.session, datasource_type=DatasourceType(datasource["type"]), datasource_id=int(datasource["id"]), ) diff --git 
a/superset/common/query_object_factory.py b/superset/common/query_object_factory.py index d2aa140dfe933..fe4cca3f4889c 100644 --- a/superset/common/query_object_factory.py +++ b/superset/common/query_object_factory.py @@ -33,8 +33,6 @@ ) if TYPE_CHECKING: - from sqlalchemy.orm import sessionmaker - from superset.connectors.sqla.models import BaseDatasource from superset.daos.datasource import DatasourceDAO @@ -42,17 +40,14 @@ class QueryObjectFactory: # pylint: disable=too-few-public-methods _config: dict[str, Any] _datasource_dao: DatasourceDAO - _session_maker: sessionmaker def __init__( self, app_configurations: dict[str, Any], _datasource_dao: DatasourceDAO, - session_maker: sessionmaker, ): self._config = app_configurations self._datasource_dao = _datasource_dao - self._session_maker = session_maker def create( # pylint: disable=too-many-arguments self, @@ -91,7 +86,6 @@ def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource: return self._datasource_dao.get_datasource( datasource_type=DatasourceType(datasource["type"]), datasource_id=int(datasource["id"]), - session=self._session_maker(), ) def _process_extras( diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 55abaaf68bf07..624eb2ce5a530 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -699,7 +699,7 @@ def raise_for_access(self) -> None: @classmethod def get_datasource_by_name( - cls, session: Session, datasource_name: str, schema: str, database_name: str + cls, datasource_name: str, schema: str, database_name: str ) -> BaseDatasource | None: raise NotImplementedError() @@ -1238,14 +1238,13 @@ def database_name(self) -> str: @classmethod def get_datasource_by_name( cls, - session: Session, datasource_name: str, schema: str | None, database_name: str, ) -> SqlaTable | None: schema = schema or None query = ( - session.query(cls) + db.session.query(cls) .join(Database) .filter(cls.table_name == datasource_name) .filter(Database.database_name == database_name) @@ -1939,12 +1938,10 @@ def query_datasources_by_permissions( # pylint: disable=invalid-name ) @classmethod - def get_eager_sqlatable_datasource( - cls, session: Session, datasource_id: int - ) -> SqlaTable: + def get_eager_sqlatable_datasource(cls, datasource_id: int) -> SqlaTable: """Returns SqlaTable with columns and metrics.""" return ( - session.query(cls) + db.session.query(cls) .options( sa.orm.subqueryload(cls.columns), sa.orm.subqueryload(cls.metrics), @@ -2037,8 +2034,7 @@ def update_column( # pylint: disable=unused-argument :param connection: Unused. :param target: The metric or column that was updated. """ - inspector = inspect(target) - session = inspector.session + session = inspect(target).session # Forces an update to the table's changed_on value when a metric or column on the # table is updated. This busts the cache key for all charts that use the table. 
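
Every hunk in this patch follows the same refactor: drop the explicitly managed `Session` / `sessionmaker` plumbing and use Flask-SQLAlchemy's request-scoped `db.session` directly. A minimal before/after sketch of that shape is below — illustrative only, not part of the patch; `count_slices` is a hypothetical helper and assumes a running Superset app context.

```python
# Sketch of the pattern applied throughout this patch (assumes a Superset
# application context; `count_slices` is a hypothetical example helper).
from superset import db
from superset.models.slice import Slice


def count_slices() -> int:
    # Before: call sites created or threaded their own session, e.g.
    #     session = db.session()
    #     return session.query(Slice).count()
    # or accepted a `session: Session` argument.
    # After: they use the Flask-SQLAlchemy scoped session directly.
    return db.session.query(Slice).count()
```
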
diff --git a/superset/daos/dashboard.py b/superset/daos/dashboard.py index e0dffa73c3f64..eef46362e2d9a 100644 --- a/superset/daos/dashboard.py +++ b/superset/daos/dashboard.py @@ -170,7 +170,7 @@ def validate_update_slug_uniqueness(dashboard_id: int, slug: str | None) -> bool return True @staticmethod - def set_dash_metadata( # pylint: disable=too-many-locals + def set_dash_metadata( dashboard: Dashboard, data: dict[Any, Any], old_to_new_slice_ids: dict[int, int] | None = None, @@ -187,8 +187,9 @@ def set_dash_metadata( # pylint: disable=too-many-locals if isinstance(value, dict) ] - session = db.session() - current_slices = session.query(Slice).filter(Slice.id.in_(slice_ids)).all() + current_slices = ( + db.session.query(Slice).filter(Slice.id.in_(slice_ids)).all() + ) dashboard.slices = current_slices diff --git a/superset/daos/datasource.py b/superset/daos/datasource.py index 2bdf4ca21fb7a..0e6058d6abc88 100644 --- a/superset/daos/datasource.py +++ b/superset/daos/datasource.py @@ -18,8 +18,7 @@ import logging from typing import Union -from sqlalchemy.orm import Session - +from superset import db from superset.connectors.sqla.models import SqlaTable from superset.daos.base import BaseDAO from superset.daos.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError @@ -45,7 +44,6 @@ class DatasourceDAO(BaseDAO[Datasource]): @classmethod def get_datasource( cls, - session: Session, datasource_type: Union[DatasourceType, str], datasource_id: int, ) -> Datasource: @@ -53,7 +51,7 @@ def get_datasource( raise DatasourceTypeNotSupportedError() datasource = ( - session.query(cls.sources[datasource_type]) + db.session.query(cls.sources[datasource_type]) .filter_by(id=datasource_id) .one_or_none() ) diff --git a/superset/datasource/api.py b/superset/datasource/api.py index 6943d00bc75ec..31e8c503ee0fd 100644 --- a/superset/datasource/api.py +++ b/superset/datasource/api.py @@ -18,7 +18,7 @@ from flask_appbuilder.api import expose, protect, safe -from superset import app, db, event_logger +from superset import app, event_logger from superset.daos.datasource import DatasourceDAO from superset.daos.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError from superset.exceptions import SupersetSecurityException @@ -100,7 +100,7 @@ def get_column_values( """ try: datasource = DatasourceDAO.get_datasource( - db.session, DatasourceType(datasource_type), datasource_id + DatasourceType(datasource_type), datasource_id ) datasource.raise_for_access() except ValueError: diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py index 48ee403c4841a..1e9b73e0bfe87 100644 --- a/superset/models/dashboard.py +++ b/superset/models/dashboard.py @@ -39,7 +39,7 @@ UniqueConstraint, ) from sqlalchemy.engine.base import Connection -from sqlalchemy.orm import relationship, sessionmaker, subqueryload +from sqlalchemy.orm import relationship, subqueryload from sqlalchemy.orm.mapper import Mapper from sqlalchemy.sql import join, select from sqlalchemy.sql.elements import BinaryExpression @@ -62,38 +62,33 @@ logger = logging.getLogger(__name__) -def copy_dashboard(_mapper: Mapper, connection: Connection, target: Dashboard) -> None: +def copy_dashboard(_mapper: Mapper, _connection: Connection, target: Dashboard) -> None: dashboard_id = config["DASHBOARD_TEMPLATE_ID"] if dashboard_id is None: return - session_class = sessionmaker(autoflush=False) - session = session_class(bind=connection) - - try: - new_user = session.query(User).filter_by(id=target.id).first() - - # copy template 
dashboard to user - template = session.query(Dashboard).filter_by(id=int(dashboard_id)).first() - dashboard = Dashboard( - dashboard_title=template.dashboard_title, - position_json=template.position_json, - description=template.description, - css=template.css, - json_metadata=template.json_metadata, - slices=template.slices, - owners=[new_user], - ) - session.add(dashboard) + session = sqla.inspect(target).session + new_user = session.query(User).filter_by(id=target.id).first() + + # copy template dashboard to user + template = session.query(Dashboard).filter_by(id=int(dashboard_id)).first() + dashboard = Dashboard( + dashboard_title=template.dashboard_title, + position_json=template.position_json, + description=template.description, + css=template.css, + json_metadata=template.json_metadata, + slices=template.slices, + owners=[new_user], + ) + session.add(dashboard) - # set dashboard as the welcome dashboard - extra_attributes = UserAttribute( - user_id=target.id, welcome_dashboard_id=dashboard.id - ) - session.add(extra_attributes) - session.commit() - finally: - session.close() + # set dashboard as the welcome dashboard + extra_attributes = UserAttribute( + user_id=target.id, welcome_dashboard_id=dashboard.id + ) + session.add(extra_attributes) + session.commit() sqla.event.listen(User, "after_insert", copy_dashboard) @@ -397,7 +392,7 @@ def export_dashboards( # pylint: disable=too-many-locals if id_ is None: continue datasource = DatasourceDAO.get_datasource( - db.session, utils.DatasourceType.TABLE, id_ + utils.DatasourceType.TABLE, id_ ) datasource_ids.add((datasource.id, datasource.type)) @@ -406,9 +401,7 @@ def export_dashboards( # pylint: disable=too-many-locals eager_datasources = [] for datasource_id, _ in datasource_ids: - eager_datasource = SqlaTable.get_eager_sqlatable_datasource( - db.session, datasource_id - ) + eager_datasource = SqlaTable.get_eager_sqlatable_datasource(datasource_id) copied_datasource = eager_datasource.copy() copied_datasource.alter_params( remote_id=eager_datasource.id, diff --git a/superset/security/manager.py b/superset/security/manager.py index 5eb1afdda99a2..a303c8e90e106 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -48,7 +48,7 @@ from jwt.api_jwt import _jwt_global_obj from sqlalchemy import and_, inspect, or_ from sqlalchemy.engine.base import Connection -from sqlalchemy.orm import eagerload, Session +from sqlalchemy.orm import eagerload from sqlalchemy.orm.mapper import Mapper from sqlalchemy.orm.query import Query as SqlaQuery @@ -545,8 +545,7 @@ def get_user_datasources(self) -> list["BaseDatasource"]: ) # group all datasources by database - session = self.get_session - all_datasources = SqlaTable.get_all_datasources(session) + all_datasources = SqlaTable.get_all_datasources(self.get_session) datasources_by_database: dict["Database", set["SqlaTable"]] = defaultdict(set) for datasource in all_datasources: datasources_by_database[datasource.database].add(datasource) @@ -2001,17 +2000,14 @@ def raise_for_access( self.get_dashboard_access_error_object(dashboard) ) - def get_user_by_username( - self, username: str, session: Session = None - ) -> Optional[User]: + def get_user_by_username(self, username: str) -> Optional[User]: """ Retrieves a user by it's username case sensitive. 
Optional session parameter utility method normally useful for celery tasks where the session need to be scoped """ - session = session or self.get_session return ( - session.query(self.user_model) + self.get_session.query(self.user_model) .filter(self.user_model.username == username) .one_or_none() ) diff --git a/superset/utils/database.py b/superset/utils/database.py index b34dda1164a45..073e58ffda6fb 100644 --- a/superset/utils/database.py +++ b/superset/utils/database.py @@ -79,6 +79,5 @@ def remove_database(database: Database) -> None: # pylint: disable=import-outside-toplevel from superset import db - session = db.session - session.delete(database) - session.commit() + db.session.delete(database) + db.session.commit() diff --git a/superset/utils/log.py b/superset/utils/log.py index 5430accb43ac2..1de599bf08233 100644 --- a/superset/utils/log.py +++ b/superset/utils/log.py @@ -27,7 +27,7 @@ from datetime import datetime, timedelta from typing import Any, Callable, cast, Literal, TYPE_CHECKING -from flask import current_app, g, request +from flask import g, request from flask_appbuilder.const import API_URI_RIS_KEY from sqlalchemy.exc import SQLAlchemyError @@ -139,6 +139,7 @@ def log_with_context( # pylint: disable=too-many-locals **payload_override: dict[str, Any] | None, ) -> None: # pylint: disable=import-outside-toplevel + from superset import db from superset.views.core import get_form_data referrer = request.referrer[:1000] if request and request.referrer else None @@ -152,8 +153,7 @@ def log_with_context( # pylint: disable=too-many-locals # need to add them back before logging to capture user_id if user_id is None: try: - session = current_app.appbuilder.get_session - session.add(g.user) + db.session.add(g.user) user_id = get_user_id() except Exception as ex: # pylint: disable=broad-except logging.warning(ex) @@ -332,6 +332,7 @@ def log( # pylint: disable=too-many-arguments,too-many-locals **kwargs: Any, ) -> None: # pylint: disable=import-outside-toplevel + from superset import db from superset.models.core import Log records = kwargs.get("records", []) @@ -353,9 +354,8 @@ def log( # pylint: disable=too-many-arguments,too-many-locals ) logs.append(log) try: - sesh = current_app.appbuilder.get_session - sesh.bulk_save_objects(logs) - sesh.commit() + db.session.bulk_save_objects(logs) + db.session.commit() except SQLAlchemyError as ex: logging.error("DBEventLogger failed to log event(s)") logging.exception(ex) diff --git a/superset/utils/mock_data.py b/superset/utils/mock_data.py index fd4961421585d..67bd9ad73e106 100644 --- a/superset/utils/mock_data.py +++ b/superset/utils/mock_data.py @@ -31,7 +31,6 @@ from flask_appbuilder import Model from sqlalchemy import Column, inspect, MetaData, Table from sqlalchemy.dialects import postgresql -from sqlalchemy.orm import Session from sqlalchemy.sql import func from sqlalchemy.sql.visitors import VisitableType @@ -231,12 +230,10 @@ def generate_column_data(column: ColumnInfo, num_rows: int) -> list[Any]: return [gen() for _ in range(num_rows)] -def add_sample_rows( - session: Session, model: type[Model], count: int -) -> Iterator[Model]: +def add_sample_rows(model: type[Model], count: int) -> Iterator[Model]: """ Add entities of a given model. 
- :param Session session: an SQLAlchemy session + :param Model model: a Superset/FAB model :param int count: how many entities to generate and insert """ @@ -244,7 +241,7 @@ def add_sample_rows( # select samples to copy relationship values relationships = inspector.relationships.items() - samples = session.query(model).limit(count).all() if relationships else [] + samples = db.session.query(model).limit(count).all() if relationships else [] max_primary_key: Optional[int] = None for i in range(count): @@ -255,7 +252,7 @@ def add_sample_rows( if column.primary_key: if max_primary_key is None: max_primary_key = ( - session.query(func.max(getattr(model, column.name))).scalar() + db.session.query(func.max(getattr(model, column.name))).scalar() or 0 ) max_primary_key += 1 diff --git a/superset/views/core.py b/superset/views/core.py index 9ad2f63fdc680..5d23164fabb37 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -510,7 +510,6 @@ def explore( if datasource_id is not None: with contextlib.suppress(DatasetNotFoundError): datasource = DatasourceDAO.get_datasource( - db.session, DatasourceType("table"), datasource_id, ) @@ -751,7 +750,6 @@ def warm_up_cache(self) -> FlaskResponse: In terms of the `extra_filters` these can be obtained from records in the JSON encoded `logs.json` column associated with the `explore_json` action. """ - session = db.session() slice_id = request.args.get("slice_id") dashboard_id = request.args.get("dashboard_id") table_name = request.args.get("table_name") @@ -768,14 +766,14 @@ def warm_up_cache(self) -> FlaskResponse: status=400, ) if slice_id: - slices = session.query(Slice).filter_by(id=slice_id).all() + slices = db.session.query(Slice).filter_by(id=slice_id).all() if not slices: return json_error_response( __("Chart %(id)s not found", id=slice_id), status=404 ) elif table_name and db_name: table = ( - session.query(SqlaTable) + db.session.query(SqlaTable) .join(Database) .filter( Database.database_name == db_name @@ -792,7 +790,7 @@ def warm_up_cache(self) -> FlaskResponse: status=404, ) slices = ( - session.query(Slice) + db.session.query(Slice) .filter_by(datasource_id=table.id, datasource_type=table.type) .all() ) @@ -919,7 +917,7 @@ def fetch_datasource_metadata(self) -> FlaskResponse: """ datasource_id, datasource_type = request.args["datasourceKey"].split("__") datasource = DatasourceDAO.get_datasource( - db.session, DatasourceType(datasource_type), int(datasource_id) + DatasourceType(datasource_type), int(datasource_id) ) # Check if datasource exists if not datasource: diff --git a/superset/views/datasource/utils.py b/superset/views/datasource/utils.py index b08d1ccc1528d..6bad2370c8758 100644 --- a/superset/views/datasource/utils.py +++ b/superset/views/datasource/utils.py @@ -16,7 +16,7 @@ # under the License. 
from typing import Any, Optional -from superset import app, db +from superset import app from superset.commands.dataset.exceptions import DatasetSamplesFailedError from superset.common.chart_data import ChartDataResultType from superset.common.query_context_factory import QueryContextFactory @@ -52,7 +52,6 @@ def get_samples( # pylint: disable=too-many-arguments payload: Optional[SamplesPayloadSchema] = None, ) -> dict[str, Any]: datasource = DatasourceDAO.get_datasource( - session=db.session, datasource_type=datasource_type, datasource_id=datasource_id, ) diff --git a/superset/views/datasource/views.py b/superset/views/datasource/views.py index b911d2ea3f116..2e46faf0af9dc 100644 --- a/superset/views/datasource/views.py +++ b/superset/views/datasource/views.py @@ -83,7 +83,7 @@ def save(self) -> FlaskResponse: datasource_type = datasource_dict.get("type") database_id = datasource_dict["database"].get("id") orm_datasource = DatasourceDAO.get_datasource( - db.session, DatasourceType(datasource_type), datasource_id + DatasourceType(datasource_type), datasource_id ) orm_datasource.database_id = database_id @@ -126,7 +126,7 @@ def save(self) -> FlaskResponse: @deprecated(new_target="/api/v1/dataset/") def get(self, datasource_type: str, datasource_id: int) -> FlaskResponse: datasource = DatasourceDAO.get_datasource( - db.session, DatasourceType(datasource_type), datasource_id + DatasourceType(datasource_type), datasource_id ) return self.json_response(sanitize_datasource_data(datasource.data)) @@ -139,7 +139,6 @@ def external_metadata( ) -> FlaskResponse: """Gets column info from the source system""" datasource = DatasourceDAO.get_datasource( - db.session, DatasourceType(datasource_type), datasource_id, ) @@ -164,7 +163,6 @@ def external_metadata_by_name(self, **kwargs: Any) -> FlaskResponse: return json_error_response(str(err), status=400) datasource = SqlaTable.get_datasource_by_name( - session=db.session, database_name=params["database_name"], schema=params["schema_name"], datasource_name=params["table_name"], diff --git a/superset/views/utils.py b/superset/views/utils.py index db5b3b53467f5..574fedb66b7d0 100644 --- a/superset/views/utils.py +++ b/superset/views/utils.py @@ -129,7 +129,6 @@ def get_viz( ) -> BaseViz: viz_type = form_data.get("viz_type", "table") datasource = DatasourceDAO.get_datasource( - db.session, DatasourceType(datasource_type), datasource_id, ) @@ -312,8 +311,7 @@ def apply_display_max_row_limit( def get_dashboard_extra_filters( slice_id: int, dashboard_id: int ) -> list[dict[str, Any]]: - session = db.session() - dashboard = session.query(Dashboard).filter_by(id=dashboard_id).one_or_none() + dashboard = db.session.query(Dashboard).filter_by(id=dashboard_id).one_or_none() # is chart in this dashboard? 
if ( diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py index 5ab81b58d12cd..dce33ea2ccaea 100644 --- a/tests/integration_tests/datasource_tests.py +++ b/tests/integration_tests/datasource_tests.py @@ -474,7 +474,7 @@ def test_get_datasource_failed(self): pytest.raises( DatasourceNotFound, - lambda: DatasourceDAO.get_datasource(db.session, "table", 9999999), + lambda: DatasourceDAO.get_datasource("table", 9999999), ) self.login(username="admin") @@ -486,7 +486,7 @@ def test_get_datasource_invalid_datasource_failed(self): pytest.raises( DatasourceTypeNotSupportedError, - lambda: DatasourceDAO.get_datasource(db.session, "druid", 9999999), + lambda: DatasourceDAO.get_datasource("druid", 9999999), ) self.login(username="admin") diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py index 30cd160d7ee58..0d6d69e4ce9e7 100644 --- a/tests/integration_tests/query_context_tests.py +++ b/tests/integration_tests/query_context_tests.py @@ -145,7 +145,6 @@ def test_query_cache_key_changes_when_datasource_is_updated(self): # make temporary change and revert it to refresh the changed_on property datasource = DatasourceDAO.get_datasource( - session=db.session, datasource_type=DatasourceType(payload["datasource"]["type"]), datasource_id=payload["datasource"]["id"], ) @@ -169,7 +168,6 @@ def test_query_cache_key_changes_when_metric_is_updated(self): # make temporary change and revert it to refresh the changed_on property datasource = DatasourceDAO.get_datasource( - session=db.session, datasource_type=DatasourceType(payload["datasource"]["type"]), datasource_id=payload["datasource"]["id"], ) diff --git a/tests/integration_tests/security_tests.py b/tests/integration_tests/security_tests.py index 9eaabf3680ec4..5b51488e02223 100644 --- a/tests/integration_tests/security_tests.py +++ b/tests/integration_tests/security_tests.py @@ -108,9 +108,8 @@ class TestRolePermission(SupersetTestCase): def setUp(self): schema = get_example_default_schema() - session = db.session security_manager.add_role(SCHEMA_ACCESS_ROLE) - session.commit() + db.session.commit() ds = ( db.session.query(SqlaTable) @@ -121,7 +120,7 @@ def setUp(self): ds.schema_perm = ds.get_schema_perm() ds_slices = ( - session.query(Slice) + db.session.query(Slice) .filter_by(datasource_type=DatasourceType.TABLE) .filter_by(datasource_id=ds.id) .all() @@ -131,12 +130,11 @@ def setUp(self): create_schema_perm("[examples].[temp_schema]") gamma_user = security_manager.find_user(username="gamma") gamma_user.roles.append(security_manager.find_role(SCHEMA_ACCESS_ROLE)) - session.commit() + db.session.commit() def tearDown(self): - session = db.session ds = ( - session.query(SqlaTable) + db.session.query(SqlaTable) .filter_by(table_name="wb_health_population", schema="temp_schema") .first() ) @@ -144,7 +142,7 @@ def tearDown(self): ds.schema = get_example_default_schema() ds.schema_perm = None ds_slices = ( - session.query(Slice) + db.session.query(Slice) .filter_by(datasource_type=DatasourceType.TABLE) .filter_by(datasource_id=ds.id) .all() @@ -153,26 +151,25 @@ def tearDown(self): s.schema_perm = None delete_schema_perm(schema_perm) - session.delete(security_manager.find_role(SCHEMA_ACCESS_ROLE)) - session.commit() + db.session.delete(security_manager.find_role(SCHEMA_ACCESS_ROLE)) + db.session.commit() def test_after_insert_dataset(self): security_manager.on_view_menu_after_insert = Mock() security_manager.on_permission_view_after_insert = Mock() - session 
= db.session tmp_db1 = Database(database_name="tmp_db1", sqlalchemy_uri="sqlite://") - session.add(tmp_db1) + db.session.add(tmp_db1) table = SqlaTable( schema="tmp_schema", table_name="tmp_perm_table", database=tmp_db1, ) - session.add(table) - session.commit() + db.session.add(table) + db.session.commit() - table = session.query(SqlaTable).filter_by(table_name="tmp_perm_table").one() + table = db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table").one() self.assertEqual(table.perm, f"[tmp_db1].[tmp_perm_table](id:{table.id})") pvm_dataset = security_manager.find_permission_view_menu( @@ -200,54 +197,54 @@ def test_after_insert_dataset(self): ) # Cleanup - session.delete(table) - session.delete(tmp_db1) - session.commit() + db.session.delete(table) + db.session.delete(tmp_db1) + db.session.commit() def test_after_insert_dataset_rollback(self): - session = db.session tmp_db1 = Database(database_name="tmp_db1", sqlalchemy_uri="sqlite://") - session.add(tmp_db1) - session.commit() + db.session.add(tmp_db1) + db.session.commit() table = SqlaTable( schema="tmp_schema", table_name="tmp_table", database=tmp_db1, ) - session.add(table) - session.flush() + db.session.add(table) + db.session.flush() pvm_dataset = security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db1].[tmp_table](id:{table.id})" ) self.assertIsNotNone(pvm_dataset) table_id = table.id - session.rollback() + db.session.rollback() - table = session.query(SqlaTable).filter_by(table_name="tmp_table").one_or_none() + table = ( + db.session.query(SqlaTable).filter_by(table_name="tmp_table").one_or_none() + ) self.assertIsNone(table) pvm_dataset = security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db1].[tmp_table](id:{table_id})" ) self.assertIsNone(pvm_dataset) - session.delete(tmp_db1) - session.commit() + db.session.delete(tmp_db1) + db.session.commit() def test_after_insert_dataset_table_none(self): - session = db.session table = SqlaTable( schema="tmp_schema", table_name="tmp_perm_table", # Setting database_id instead of database will skip permission creation database_id=get_example_database().id, ) - session.add(table) - session.commit() + db.session.add(table) + db.session.commit() stored_table = ( - session.query(SqlaTable).filter_by(table_name="tmp_perm_table").one() + db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table").one() ) # Assert permission is created self.assertIsNotNone( @@ -263,17 +260,16 @@ def test_after_insert_dataset_table_none(self): ) # Cleanup - session.delete(table) - session.commit() + db.session.delete(table) + db.session.commit() def test_after_insert_database(self): security_manager.on_permission_view_after_insert = Mock() - session = db.session tmp_db1 = Database(database_name="tmp_db1", sqlalchemy_uri="sqlite://") - session.add(tmp_db1) + db.session.add(tmp_db1) - tmp_db1 = session.query(Database).filter_by(database_name="tmp_db1").one() + tmp_db1 = db.session.query(Database).filter_by(database_name="tmp_db1").one() self.assertEqual(tmp_db1.perm, f"[tmp_db1].(id:{tmp_db1.id})") tmp_db1_pvm = security_manager.find_permission_view_menu( "database_access", tmp_db1.perm @@ -288,20 +284,19 @@ def test_after_insert_database(self): ) call_args = security_manager.on_permission_view_after_insert.call_args assert call_args.args[2].id == tmp_db1_pvm.id - session.delete(tmp_db1) - session.commit() + db.session.delete(tmp_db1) + db.session.commit() def test_after_insert_database_rollback(self): - session = db.session tmp_db1 = 
Database(database_name="tmp_db1", sqlalchemy_uri="sqlite://") - session.add(tmp_db1) - session.flush() + db.session.add(tmp_db1) + db.session.flush() pvm_database = security_manager.find_permission_view_menu( "database_access", f"[tmp_db1].(id:{tmp_db1.id})" ) self.assertIsNotNone(pvm_database) - session.rollback() + db.session.rollback() pvm_database = security_manager.find_permission_view_menu( "database_access", f"[tmp_db1](id:{tmp_db1.id})" @@ -311,18 +306,17 @@ def test_after_insert_database_rollback(self): def test_after_update_database__perm_database_access(self): security_manager.on_view_menu_after_update = Mock() - session = db.session tmp_db1 = Database(database_name="tmp_db1", sqlalchemy_uri="sqlite://") - session.add(tmp_db1) - session.commit() - tmp_db1 = session.query(Database).filter_by(database_name="tmp_db1").one() + db.session.add(tmp_db1) + db.session.commit() + tmp_db1 = db.session.query(Database).filter_by(database_name="tmp_db1").one() self.assertIsNotNone( security_manager.find_permission_view_menu("database_access", tmp_db1.perm) ) tmp_db1.database_name = "tmp_db2" - session.commit() + db.session.commit() # Assert that the old permission was updated self.assertIsNone( @@ -347,22 +341,21 @@ def test_after_update_database__perm_database_access(self): ] ) - session.delete(tmp_db1) - session.commit() + db.session.delete(tmp_db1) + db.session.commit() def test_after_update_database_rollback(self): - session = db.session tmp_db1 = Database(database_name="tmp_db1", sqlalchemy_uri="sqlite://") - session.add(tmp_db1) - session.commit() - tmp_db1 = session.query(Database).filter_by(database_name="tmp_db1").one() + db.session.add(tmp_db1) + db.session.commit() + tmp_db1 = db.session.query(Database).filter_by(database_name="tmp_db1").one() self.assertIsNotNone( security_manager.find_permission_view_menu("database_access", tmp_db1.perm) ) tmp_db1.database_name = "tmp_db2" - session.flush() + db.session.flush() # Assert that the old permission was updated self.assertIsNone( @@ -377,7 +370,7 @@ def test_after_update_database_rollback(self): ) ) - session.rollback() + db.session.rollback() self.assertIsNotNone( security_manager.find_permission_view_menu( "database_access", f"[tmp_db1].(id:{tmp_db1.id})" @@ -390,19 +383,18 @@ def test_after_update_database_rollback(self): ) ) - session.delete(tmp_db1) - session.commit() + db.session.delete(tmp_db1) + db.session.commit() def test_after_update_database__perm_database_access_exists(self): security_manager.on_permission_view_after_delete = Mock() - session = db.session # Add a bogus existing permission before the change tmp_db1 = Database(database_name="tmp_db1", sqlalchemy_uri="sqlite://") - session.add(tmp_db1) - session.commit() - tmp_db1 = session.query(Database).filter_by(database_name="tmp_db1").one() + db.session.add(tmp_db1) + db.session.commit() + tmp_db1 = db.session.query(Database).filter_by(database_name="tmp_db1").one() security_manager.add_permission_view_menu( "database_access", f"[tmp_db2].(id:{tmp_db1.id})" ) @@ -412,7 +404,7 @@ def test_after_update_database__perm_database_access_exists(self): ) tmp_db1.database_name = "tmp_db2" - session.commit() + db.session.commit() # Assert that the old permission was updated self.assertIsNone( @@ -433,41 +425,40 @@ def test_after_update_database__perm_database_access_exists(self): ] ) - session.delete(tmp_db1) - session.commit() + db.session.delete(tmp_db1) + db.session.commit() def test_after_update_database__perm_datasource_access(self): 
security_manager.on_view_menu_after_update = Mock() - session = db.session tmp_db1 = Database(database_name="tmp_db1", sqlalchemy_uri="sqlite://") - session.add(tmp_db1) - session.commit() + db.session.add(tmp_db1) + db.session.commit() table1 = SqlaTable( schema="tmp_schema", table_name="tmp_table1", database=tmp_db1, ) - session.add(table1) + db.session.add(table1) table2 = SqlaTable( schema="tmp_schema", table_name="tmp_table2", database=tmp_db1, ) - session.add(table2) - session.commit() + db.session.add(table2) + db.session.commit() slice1 = Slice( datasource_id=table1.id, datasource_type=DatasourceType.TABLE, datasource_name="tmp_table1", slice_name="tmp_slice1", ) - session.add(slice1) - session.commit() - slice1 = session.query(Slice).filter_by(slice_name="tmp_slice1").one() - table1 = session.query(SqlaTable).filter_by(table_name="tmp_table1").one() - table2 = session.query(SqlaTable).filter_by(table_name="tmp_table2").one() + db.session.add(slice1) + db.session.commit() + slice1 = db.session.query(Slice).filter_by(slice_name="tmp_slice1").one() + table1 = db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one() + table2 = db.session.query(SqlaTable).filter_by(table_name="tmp_table2").one() # assert initial perms self.assertIsNotNone( @@ -485,9 +476,9 @@ def test_after_update_database__perm_datasource_access(self): self.assertEqual(table2.perm, f"[tmp_db1].[tmp_table2](id:{table2.id})") # Refresh and update the database name - tmp_db1 = session.query(Database).filter_by(database_name="tmp_db1").one() + tmp_db1 = db.session.query(Database).filter_by(database_name="tmp_db1").one() tmp_db1.database_name = "tmp_db2" - session.commit() + db.session.commit() # Assert that the old permissions were updated self.assertIsNone( @@ -534,18 +525,17 @@ def test_after_update_database__perm_datasource_access(self): ] ) - session.delete(slice1) - session.delete(table1) - session.delete(table2) - session.delete(tmp_db1) - session.commit() + db.session.delete(slice1) + db.session.delete(table1) + db.session.delete(table2) + db.session.delete(tmp_db1) + db.session.commit() def test_after_delete_database(self): - session = db.session tmp_db1 = Database(database_name="tmp_db1", sqlalchemy_uri="sqlite://") - session.add(tmp_db1) - session.commit() - tmp_db1 = session.query(Database).filter_by(database_name="tmp_db1").one() + db.session.add(tmp_db1) + db.session.commit() + tmp_db1 = db.session.query(Database).filter_by(database_name="tmp_db1").one() database_pvm = security_manager.find_permission_view_menu( "database_access", tmp_db1.perm @@ -553,11 +543,11 @@ def test_after_delete_database(self): self.assertIsNotNone(database_pvm) role1 = Role(name="tmp_role1") role1.permissions.append(database_pvm) - session.add(role1) - session.commit() + db.session.add(role1) + db.session.commit() - session.delete(tmp_db1) - session.commit() + db.session.delete(tmp_db1) + db.session.commit() # Assert that PVM is removed from Role role1 = security_manager.find_role("tmp_role1") @@ -571,15 +561,14 @@ def test_after_delete_database(self): ) # Cleanup - session.delete(role1) - session.commit() + db.session.delete(role1) + db.session.commit() def test_after_delete_database_rollback(self): - session = db.session tmp_db1 = Database(database_name="tmp_db1", sqlalchemy_uri="sqlite://") - session.add(tmp_db1) - session.commit() - tmp_db1 = session.query(Database).filter_by(database_name="tmp_db1").one() + db.session.add(tmp_db1) + db.session.commit() + tmp_db1 = 
db.session.query(Database).filter_by(database_name="tmp_db1").one() database_pvm = security_manager.find_permission_view_menu( "database_access", tmp_db1.perm @@ -587,11 +576,11 @@ def test_after_delete_database_rollback(self): self.assertIsNotNone(database_pvm) role1 = Role(name="tmp_role1") role1.permissions.append(database_pvm) - session.add(role1) - session.commit() + db.session.add(role1) + db.session.commit() - session.delete(tmp_db1) - session.flush() + db.session.delete(tmp_db1) + db.session.flush() role1 = security_manager.find_role("tmp_role1") self.assertEqual(role1.permissions, []) @@ -602,7 +591,7 @@ def test_after_delete_database_rollback(self): ) ) - session.rollback() + db.session.rollback() # Test a rollback reverts everything database_pvm = security_manager.find_permission_view_menu( @@ -613,25 +602,24 @@ def test_after_delete_database_rollback(self): self.assertEqual(role1.permissions, [database_pvm]) # Cleanup - session.delete(role1) - session.delete(tmp_db1) - session.commit() + db.session.delete(role1) + db.session.delete(tmp_db1) + db.session.commit() def test_after_delete_dataset(self): security_manager.on_permission_view_after_delete = Mock() - session = db.session tmp_db = Database(database_name="tmp_db", sqlalchemy_uri="sqlite://") - session.add(tmp_db) - session.commit() + db.session.add(tmp_db) + db.session.commit() table1 = SqlaTable( schema="tmp_schema", table_name="tmp_table1", database=tmp_db, ) - session.add(table1) - session.commit() + db.session.add(table1) + db.session.commit() table1_pvm = security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db].[tmp_table1](id:{table1.id})" @@ -640,15 +628,15 @@ def test_after_delete_dataset(self): role1 = Role(name="tmp_role1") role1.permissions.append(table1_pvm) - session.add(role1) - session.commit() + db.session.add(role1) + db.session.commit() # refresh - table1 = session.query(SqlaTable).filter_by(table_name="tmp_table1").one() + table1 = db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one() # Test delete - session.delete(table1) - session.commit() + db.session.delete(table1) + db.session.commit() role1 = security_manager.find_role("tmp_role1") self.assertEqual(role1.permissions, []) @@ -670,23 +658,22 @@ def test_after_delete_dataset(self): ) # cleanup - session.delete(role1) - session.delete(tmp_db) - session.commit() + db.session.delete(role1) + db.session.delete(tmp_db) + db.session.commit() def test_after_delete_dataset_rollback(self): - session = db.session tmp_db = Database(database_name="tmp_db", sqlalchemy_uri="sqlite://") - session.add(tmp_db) - session.commit() + db.session.add(tmp_db) + db.session.commit() table1 = SqlaTable( schema="tmp_schema", table_name="tmp_table1", database=tmp_db, ) - session.add(table1) - session.commit() + db.session.add(table1) + db.session.commit() table1_pvm = security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db].[tmp_table1](id:{table1.id})" @@ -695,15 +682,15 @@ def test_after_delete_dataset_rollback(self): role1 = Role(name="tmp_role1") role1.permissions.append(table1_pvm) - session.add(role1) - session.commit() + db.session.add(role1) + db.session.commit() # refresh - table1 = session.query(SqlaTable).filter_by(table_name="tmp_table1").one() + table1 = db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one() # Test delete, permissions are correctly deleted - session.delete(table1) - session.flush() + db.session.delete(table1) + db.session.flush() role1 = security_manager.find_role("tmp_role1") 
self.assertEqual(role1.permissions, []) @@ -714,7 +701,7 @@ def test_after_delete_dataset_rollback(self): self.assertIsNone(table1_pvm) # Test rollback, permissions exist everything is correctly rollback - session.rollback() + db.session.rollback() role1 = security_manager.find_role("tmp_role1") table1_pvm = security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db].[tmp_table1](id:{table1.id})" @@ -723,26 +710,25 @@ def test_after_delete_dataset_rollback(self): self.assertEqual(role1.permissions, [table1_pvm]) # cleanup - session.delete(table1) - session.delete(role1) - session.delete(tmp_db) - session.commit() + db.session.delete(table1) + db.session.delete(role1) + db.session.delete(tmp_db) + db.session.commit() def test_after_update_dataset__name_changes(self): security_manager.on_view_menu_after_update = Mock() - session = db.session tmp_db = Database(database_name="tmp_db", sqlalchemy_uri="sqlite://") - session.add(tmp_db) - session.commit() + db.session.add(tmp_db) + db.session.commit() table1 = SqlaTable( schema="tmp_schema", table_name="tmp_table1", database=tmp_db, ) - session.add(table1) - session.commit() + db.session.add(table1) + db.session.commit() slice1 = Slice( datasource_id=table1.id, @@ -750,8 +736,8 @@ def test_after_update_dataset__name_changes(self): datasource_name="tmp_table1", slice_name="tmp_slice1", ) - session.add(slice1) - session.commit() + db.session.add(slice1) + db.session.commit() table1_pvm = security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db].[tmp_table1](id:{table1.id})" @@ -759,10 +745,10 @@ def test_after_update_dataset__name_changes(self): self.assertIsNotNone(table1_pvm) # refresh - table1 = session.query(SqlaTable).filter_by(table_name="tmp_table1").one() + table1 = db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one() # Test update table1.table_name = "tmp_table1_changed" - session.commit() + db.session.commit() # Test old permission does not exist old_table1_pvm = security_manager.find_permission_view_menu( @@ -778,14 +764,14 @@ def test_after_update_dataset__name_changes(self): # test dataset permission changed changed_table1 = ( - session.query(SqlaTable).filter_by(table_name="tmp_table1_changed").one() + db.session.query(SqlaTable).filter_by(table_name="tmp_table1_changed").one() ) self.assertEqual( changed_table1.perm, f"[tmp_db].[tmp_table1_changed](id:{table1.id})" ) # Test Chart permission changed - slice1 = session.query(Slice).filter_by(slice_name="tmp_slice1").one() + slice1 = db.session.query(Slice).filter_by(slice_name="tmp_slice1").one() self.assertEqual(slice1.perm, f"[tmp_db].[tmp_table1_changed](id:{table1.id})") # Assert hook is called @@ -798,24 +784,23 @@ def test_after_update_dataset__name_changes(self): ] ) # cleanup - session.delete(slice1) - session.delete(table1) - session.delete(tmp_db) - session.commit() + db.session.delete(slice1) + db.session.delete(table1) + db.session.delete(tmp_db) + db.session.commit() def test_after_update_dataset_rollback(self): - session = db.session tmp_db = Database(database_name="tmp_db", sqlalchemy_uri="sqlite://") - session.add(tmp_db) - session.commit() + db.session.add(tmp_db) + db.session.commit() table1 = SqlaTable( schema="tmp_schema", table_name="tmp_table1", database=tmp_db, ) - session.add(table1) - session.commit() + db.session.add(table1) + db.session.commit() slice1 = Slice( datasource_id=table1.id, @@ -823,14 +808,14 @@ def test_after_update_dataset_rollback(self): datasource_name="tmp_table1", slice_name="tmp_slice1", ) - 
session.add(slice1)
-        session.commit()
+        db.session.add(slice1)
+        db.session.commit()
 
         # refresh
-        table1 = session.query(SqlaTable).filter_by(table_name="tmp_table1").one()
+        table1 = db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one()
         # Test update
         table1.table_name = "tmp_table1_changed"
-        session.flush()
+        db.session.flush()
 
         # Test old permission does not exist
@@ -845,7 +830,7 @@ def test_after_update_dataset_rollback(self):
         self.assertIsNotNone(new_table1_pvm)
 
         # Test rollback
-        session.rollback()
+        db.session.rollback()
 
         old_table1_pvm = security_manager.find_permission_view_menu(
             "datasource_access", f"[tmp_db].[tmp_table1](id:{table1.id})"
@@ -853,26 +838,25 @@ def test_after_update_dataset_rollback(self):
         self.assertIsNotNone(old_table1_pvm)
 
         # cleanup
-        session.delete(slice1)
-        session.delete(table1)
-        session.delete(tmp_db)
-        session.commit()
+        db.session.delete(slice1)
+        db.session.delete(table1)
+        db.session.delete(tmp_db)
+        db.session.commit()
 
     def test_after_update_dataset__db_changes(self):
-        session = db.session
         tmp_db1 = Database(database_name="tmp_db1", sqlalchemy_uri="sqlite://")
         tmp_db2 = Database(database_name="tmp_db2", sqlalchemy_uri="sqlite://")
-        session.add(tmp_db1)
-        session.add(tmp_db2)
-        session.commit()
+        db.session.add(tmp_db1)
+        db.session.add(tmp_db2)
+        db.session.commit()
 
         table1 = SqlaTable(
             schema="tmp_schema",
             table_name="tmp_table1",
             database=tmp_db1,
         )
-        session.add(table1)
-        session.commit()
+        db.session.add(table1)
+        db.session.commit()
 
         slice1 = Slice(
             datasource_id=table1.id,
@@ -880,8 +864,8 @@ def test_after_update_dataset__db_changes(self):
             datasource_name="tmp_table1",
             slice_name="tmp_slice1",
         )
-        session.add(slice1)
-        session.commit()
+        db.session.add(slice1)
+        db.session.commit()
 
         table1_pvm = security_manager.find_permission_view_menu(
             "datasource_access", f"[tmp_db1].[tmp_table1](id:{table1.id})"
@@ -889,10 +873,10 @@ def test_after_update_dataset__db_changes(self):
         self.assertIsNotNone(table1_pvm)
 
         # refresh
-        table1 = session.query(SqlaTable).filter_by(table_name="tmp_table1").one()
+        table1 = db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one()
         # Test update
         table1.database = tmp_db2
-        session.commit()
+        db.session.commit()
 
         # Test old permission does not exist
         table1_pvm = security_manager.find_permission_view_menu(
@@ -908,36 +892,35 @@ def test_after_update_dataset__db_changes(self):
 
         # test dataset permission and schema permission changed
         changed_table1 = (
-            session.query(SqlaTable).filter_by(table_name="tmp_table1").one()
+            db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one()
         )
         self.assertEqual(changed_table1.perm, f"[tmp_db2].[tmp_table1](id:{table1.id})")
         self.assertEqual(changed_table1.schema_perm, f"[tmp_db2].[tmp_schema]")
 
         # Test Chart permission changed
-        slice1 = session.query(Slice).filter_by(slice_name="tmp_slice1").one()
+        slice1 = db.session.query(Slice).filter_by(slice_name="tmp_slice1").one()
         self.assertEqual(slice1.perm, f"[tmp_db2].[tmp_table1](id:{table1.id})")
         self.assertEqual(slice1.schema_perm, f"[tmp_db2].[tmp_schema]")
 
         # cleanup
-        session.delete(slice1)
-        session.delete(table1)
-        session.delete(tmp_db1)
-        session.delete(tmp_db2)
-        session.commit()
+        db.session.delete(slice1)
+        db.session.delete(table1)
+        db.session.delete(tmp_db1)
+        db.session.delete(tmp_db2)
+        db.session.commit()
 
     def test_after_update_dataset__schema_changes(self):
-        session = db.session
         tmp_db1 = Database(database_name="tmp_db1", sqlalchemy_uri="sqlite://")
-        session.add(tmp_db1)
-        session.commit()
+        db.session.add(tmp_db1)
+        db.session.commit()
 
         table1 = SqlaTable(
             schema="tmp_schema",
             table_name="tmp_table1",
             database=tmp_db1,
         )
-        session.add(table1)
-        session.commit()
+        db.session.add(table1)
+        db.session.commit()
 
         slice1 = Slice(
             datasource_id=table1.id,
@@ -945,8 +928,8 @@ def test_after_update_dataset__schema_changes(self):
             datasource_name="tmp_table1",
             slice_name="tmp_slice1",
         )
-        session.add(slice1)
-        session.commit()
+        db.session.add(slice1)
+        db.session.commit()
 
         table1_pvm = security_manager.find_permission_view_menu(
             "datasource_access", f"[tmp_db1].[tmp_table1](id:{table1.id})"
@@ -954,10 +937,10 @@ def test_after_update_dataset__schema_changes(self):
         self.assertIsNotNone(table1_pvm)
 
         # refresh
-        table1 = session.query(SqlaTable).filter_by(table_name="tmp_table1").one()
+        table1 = db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one()
         # Test update
         table1.schema = "tmp_schema_changed"
-        session.commit()
+        db.session.commit()
 
         # Test permission still exists
         table1_pvm = security_manager.find_permission_view_menu(
@@ -967,35 +950,34 @@ def test_after_update_dataset__schema_changes(self):
 
         # test dataset schema permission changed
         changed_table1 = (
-            session.query(SqlaTable).filter_by(table_name="tmp_table1").one()
+            db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one()
         )
         self.assertEqual(changed_table1.perm, f"[tmp_db1].[tmp_table1](id:{table1.id})")
         self.assertEqual(changed_table1.schema_perm, f"[tmp_db1].[tmp_schema_changed]")
 
         # Test Chart schema permission changed
-        slice1 = session.query(Slice).filter_by(slice_name="tmp_slice1").one()
+        slice1 = db.session.query(Slice).filter_by(slice_name="tmp_slice1").one()
         self.assertEqual(slice1.perm, f"[tmp_db1].[tmp_table1](id:{table1.id})")
         self.assertEqual(slice1.schema_perm, f"[tmp_db1].[tmp_schema_changed]")
 
         # cleanup
-        session.delete(slice1)
-        session.delete(table1)
-        session.delete(tmp_db1)
-        session.commit()
+        db.session.delete(slice1)
+        db.session.delete(table1)
+        db.session.delete(tmp_db1)
+        db.session.commit()
 
     def test_after_update_dataset__schema_none(self):
-        session = db.session
         tmp_db1 = Database(database_name="tmp_db1", sqlalchemy_uri="sqlite://")
-        session.add(tmp_db1)
-        session.commit()
+        db.session.add(tmp_db1)
+        db.session.commit()
 
         table1 = SqlaTable(
             schema="tmp_schema",
             table_name="tmp_table1",
             database=tmp_db1,
         )
-        session.add(table1)
-        session.commit()
+        db.session.add(table1)
+        db.session.commit()
 
         slice1 = Slice(
             datasource_id=table1.id,
@@ -1003,8 +985,8 @@ def test_after_update_dataset__schema_none(self):
             datasource_name="tmp_table1",
             slice_name="tmp_slice1",
         )
-        session.add(slice1)
-        session.commit()
+        db.session.add(slice1)
+        db.session.commit()
 
         table1_pvm = security_manager.find_permission_view_menu(
             "datasource_access", f"[tmp_db1].[tmp_table1](id:{table1.id})"
@@ -1012,38 +994,37 @@ def test_after_update_dataset__schema_none(self):
         self.assertIsNotNone(table1_pvm)
 
         # refresh
-        table1 = session.query(SqlaTable).filter_by(table_name="tmp_table1").one()
+        table1 = db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one()
         # Test update
         table1.schema = None
-        session.commit()
+        db.session.commit()
 
         # refresh
-        table1 = session.query(SqlaTable).filter_by(table_name="tmp_table1").one()
+        table1 = db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one()
 
         self.assertEqual(table1.perm, f"[tmp_db1].[tmp_table1](id:{table1.id})")
         self.assertIsNone(table1.schema_perm)
 
         # cleanup
-        session.delete(slice1)
-        session.delete(table1)
-        session.delete(tmp_db1)
-        session.commit()
+        db.session.delete(slice1)
+        db.session.delete(table1)
+        db.session.delete(tmp_db1)
+        db.session.commit()
 
     def test_after_update_dataset__name_db_changes(self):
-        session = db.session
         tmp_db1 = Database(database_name="tmp_db1", sqlalchemy_uri="sqlite://")
         tmp_db2 = Database(database_name="tmp_db2", sqlalchemy_uri="sqlite://")
-        session.add(tmp_db1)
-        session.add(tmp_db2)
-        session.commit()
+        db.session.add(tmp_db1)
+        db.session.add(tmp_db2)
+        db.session.commit()
 
         table1 = SqlaTable(
             schema="tmp_schema",
             table_name="tmp_table1",
             database=tmp_db1,
         )
-        session.add(table1)
-        session.commit()
+        db.session.add(table1)
+        db.session.commit()
 
         slice1 = Slice(
             datasource_id=table1.id,
@@ -1051,8 +1032,8 @@ def test_after_update_dataset__name_db_changes(self):
             datasource_name="tmp_table1",
             slice_name="tmp_slice1",
         )
-        session.add(slice1)
-        session.commit()
+        db.session.add(slice1)
+        db.session.commit()
 
         table1_pvm = security_manager.find_permission_view_menu(
             "datasource_access", f"[tmp_db1].[tmp_table1](id:{table1.id})"
@@ -1060,11 +1041,11 @@ def test_after_update_dataset__name_db_changes(self):
         self.assertIsNotNone(table1_pvm)
 
         # refresh
-        table1 = session.query(SqlaTable).filter_by(table_name="tmp_table1").one()
+        table1 = db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one()
         # Test update
         table1.table_name = "tmp_table1_changed"
         table1.database = tmp_db2
-        session.commit()
+        db.session.commit()
 
         # Test old permission does not exist
         table1_pvm = security_manager.find_permission_view_menu(
@@ -1080,7 +1061,7 @@ def test_after_update_dataset__name_db_changes(self):
 
         # test dataset permission and schema permission changed
         changed_table1 = (
-            session.query(SqlaTable).filter_by(table_name="tmp_table1_changed").one()
+            db.session.query(SqlaTable).filter_by(table_name="tmp_table1_changed").one()
         )
         self.assertEqual(
             changed_table1.perm, f"[tmp_db2].[tmp_table1_changed](id:{table1.id})"
@@ -1088,16 +1069,16 @@ def test_after_update_dataset__name_db_changes(self):
         self.assertEqual(changed_table1.schema_perm, f"[tmp_db2].[tmp_schema]")
 
         # Test Chart permission changed
-        slice1 = session.query(Slice).filter_by(slice_name="tmp_slice1").one()
+        slice1 = db.session.query(Slice).filter_by(slice_name="tmp_slice1").one()
         self.assertEqual(slice1.perm, f"[tmp_db2].[tmp_table1_changed](id:{table1.id})")
         self.assertEqual(slice1.schema_perm, f"[tmp_db2].[tmp_schema]")
 
         # cleanup
-        session.delete(slice1)
-        session.delete(table1)
-        session.delete(tmp_db1)
-        session.delete(tmp_db2)
-        session.commit()
+        db.session.delete(slice1)
+        db.session.delete(table1)
+        db.session.delete(tmp_db1)
+        db.session.delete(tmp_db2)
+        db.session.commit()
 
     def test_hybrid_perm_database(self):
         database = Database(database_name="tmp_database3", sqlalchemy_uri="sqlite://")
@@ -1123,12 +1104,11 @@ def test_hybrid_perm_database(self):
         db.session.commit()
 
     def test_set_perm_slice(self):
-        session = db.session
         database = Database(database_name="tmp_database", sqlalchemy_uri="sqlite://")
         table = SqlaTable(table_name="tmp_perm_table", database=database)
-        session.add(database)
-        session.add(table)
-        session.commit()
+        db.session.add(database)
+        db.session.add(table)
+        db.session.commit()
 
         # no schema permission
         slice = Slice(
@@ -1137,10 +1117,10 @@ def test_set_perm_slice(self):
             datasource_name="tmp_perm_table",
             slice_name="slice_name",
         )
-        session.add(slice)
-        session.commit()
+        db.session.add(slice)
+        db.session.commit()
 
-        slice = session.query(Slice).filter_by(slice_name="slice_name").one()
+        slice = db.session.query(Slice).filter_by(slice_name="slice_name").one()
         self.assertEqual(slice.perm, table.perm)
         self.assertEqual(slice.perm, f"[tmp_database].[tmp_perm_table](id:{table.id})")
         self.assertEqual(slice.schema_perm, table.schema_perm)
@@ -1148,8 +1128,10 @@ def test_set_perm_slice(self):
 
         table.schema = "tmp_perm_schema"
         table.table_name = "tmp_perm_table_v2"
-        session.commit()
-        table = session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
+        db.session.commit()
+        table = (
+            db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
+        )
         self.assertEqual(slice.perm, table.perm)
         self.assertEqual(
             slice.perm, f"[tmp_database].[tmp_perm_table_v2](id:{table.id})"
@@ -1160,11 +1142,11 @@ def test_set_perm_slice(self):
         self.assertEqual(slice.schema_perm, table.schema_perm)
         self.assertEqual(slice.schema_perm, "[tmp_database].[tmp_perm_schema]")
 
-        session.delete(slice)
-        session.delete(table)
-        session.delete(database)
+        db.session.delete(slice)
+        db.session.delete(table)
+        db.session.delete(database)
 
-        session.commit()
+        db.session.commit()
 
     @patch("superset.utils.core.g")
     @patch("superset.security.manager.g")
@@ -1279,10 +1261,11 @@ def test_public_sync_role_builtin_perms(self):
 
     @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
     def test_sqllab_gamma_user_schema_access_to_sqllab(self):
-        session = db.session
-        example_db = session.query(Database).filter_by(database_name="examples").one()
+        example_db = (
+            db.session.query(Database).filter_by(database_name="examples").one()
+        )
         example_db.expose_in_sqllab = True
-        session.commit()
+        db.session.commit()
 
         arguments = {
             "keys": ["none"],
diff --git a/tests/unit_tests/common/test_query_object_factory.py b/tests/unit_tests/common/test_query_object_factory.py
index 02304828dca82..590ace3f10fb6 100644
--- a/tests/unit_tests/common/test_query_object_factory.py
+++ b/tests/unit_tests/common/test_query_object_factory.py
@@ -38,11 +38,6 @@ def app_config() -> dict[str, Any]:
     return create_app_config().copy()
 
 
-@fixture
-def session_factory() -> Mock:
-    return Mock()
-
-
 @fixture
 def connector_registry() -> Mock:
     return Mock(spec=["get_datasource"])
@@ -58,12 +53,12 @@ def apply_max_row_limit(limit: int, max_limit: Optional[int] = None) -> int:
 
 @fixture
 def query_object_factory(
-    app_config: dict[str, Any], connector_registry: Mock, session_factory: Mock
+    app_config: dict[str, Any], connector_registry: Mock
 ) -> QueryObjectFactory:
     import superset.common.query_object_factory as mod
 
     mod.apply_max_row_limit = apply_max_row_limit
-    return QueryObjectFactory(app_config, connector_registry, session_factory)
+    return QueryObjectFactory(app_config, connector_registry)
 
 
 @fixture
diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py
index 4444fdc8c7564..beb4e99472c19 100644
--- a/tests/unit_tests/conftest.py
+++ b/tests/unit_tests/conftest.py
@@ -172,7 +172,6 @@ def dummy_query_object(request, app_context):
             "ROW_LIMIT": 100,
         },
         _datasource_dao=unittest.mock.Mock(),
-        session_maker=unittest.mock.Mock(),
     ).create(parent_result_type=result_type, **query_object)
 
 
diff --git a/tests/unit_tests/datasource/dao_tests.py b/tests/unit_tests/datasource/dao_tests.py
index 0af2cbf0200bf..b4ce162c0c0c9 100644
--- a/tests/unit_tests/datasource/dao_tests.py
+++ b/tests/unit_tests/datasource/dao_tests.py
@@ -106,7 +106,6 @@ def test_get_datasource_sqlatable(session_with_data: Session) -> None:
     result = DatasourceDAO.get_datasource(
         datasource_type=DatasourceType.TABLE,
         datasource_id=1,
-        session=session_with_data,
     )
 
     assert 1 == result.id
@@ -119,7 +118,7 @@ def test_get_datasource_query(session_with_data: Session) -> None:
     from superset.models.sql_lab import Query
 
     result = DatasourceDAO.get_datasource(
-        datasource_type=DatasourceType.QUERY, datasource_id=1, session=session_with_data
+        datasource_type=DatasourceType.QUERY, datasource_id=1
     )
 
     assert result.id == 1
@@ -133,7 +132,6 @@ def test_get_datasource_saved_query(session_with_data: Session) -> None:
     result = DatasourceDAO.get_datasource(
         datasource_type=DatasourceType.SAVEDQUERY,
         datasource_id=1,
-        session=session_with_data,
     )
 
     assert result.id == 1
@@ -147,7 +145,6 @@ def test_get_datasource_sl_table(session_with_data: Session) -> None:
     result = DatasourceDAO.get_datasource(
         datasource_type=DatasourceType.SLTABLE,
         datasource_id=1,
-        session=session_with_data,
     )
 
     assert result.id == 1
@@ -161,7 +158,6 @@ def test_get_datasource_sl_dataset(session_with_data: Session) -> None:
     result = DatasourceDAO.get_datasource(
         datasource_type=DatasourceType.DATASET,
         datasource_id=1,
-        session=session_with_data,
     )
 
     assert result.id == 1
@@ -178,7 +174,6 @@ def test_get_datasource_w_str_param(session_with_data: Session) -> None:
         DatasourceDAO.get_datasource(
             datasource_type="table",
             datasource_id=1,
-            session=session_with_data,
         ),
         SqlaTable,
     )
@@ -187,7 +182,6 @@ def test_get_datasource_w_str_param(session_with_data: Session) -> None:
         DatasourceDAO.get_datasource(
             datasource_type="sl_table",
             datasource_id=1,
-            session=session_with_data,
         ),
         Table,
     )
@@ -208,5 +202,4 @@ def test_not_found_datasource(session_with_data: Session) -> None:
         DatasourceDAO.get_datasource(
             datasource_type="table",
             datasource_id=500000,
-            session=session_with_data,
         )