Skip to content

Commit

Permalink
refactor: Ensure Flask framework leverages the Flask-SQLAlchemy session (apache#26200)
Browse files Browse the repository at this point in the history
  • Loading branch information
john-bodley authored Jan 17, 2024
1 parent f8a24f2 commit 6e2733e
Show file tree
Hide file tree
Showing 26 changed files with 315 additions and 389 deletions.
17 changes: 7 additions & 10 deletions scripts/benchmark_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,6 @@ def main(
filepath: str, limit: int = 1000, force: bool = False, no_auto_cleanup: bool = False
) -> None:
auto_cleanup = not no_auto_cleanup
session = db.session()

print(f"Importing migration script: {filepath}")
module = import_migration_script(Path(filepath))

Expand Down Expand Up @@ -174,10 +172,9 @@ def main(
models = find_models(module)
model_rows: dict[type[Model], int] = {}
for model in models:
rows = session.query(model).count()
rows = db.session.query(model).count()
print(f"- {model.__name__} ({rows} rows in table {model.__tablename__})")
model_rows[model] = rows
session.close()

print("Benchmarking migration")
results: dict[str, float] = {}
Expand All @@ -199,16 +196,16 @@ def main(
print(f"- Adding {missing} entities to the {model.__name__} model")
bar = ChargingBar("Processing", max=missing)
try:
for entity in add_sample_rows(session, model, missing):
for entity in add_sample_rows(model, missing):
entities.append(entity)
bar.next()
except Exception:
session.rollback()
db.session.rollback()
raise
bar.finish()
model_rows[model] = min_entities
session.add_all(entities)
session.commit()
db.session.add_all(entities)
db.session.commit()

if auto_cleanup:
new_models[model].extend(entities)
Expand All @@ -227,10 +224,10 @@ def main(
print("Cleaning up DB")
# delete in reverse order of creation to handle relationships
for model, entities in list(new_models.items())[::-1]:
session.query(model).filter(
db.session.query(model).filter(
model.id.in_(entity.id for entity in entities)
).delete(synchronize_session=False)
session.commit()
db.session.commit()

if current_revision != revision and not force:
click.confirm(f"\nRevert DB to {revision}?", abort=True)
Expand Down
1 change: 0 additions & 1 deletion superset/cachekeys/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def invalidate(self) -> Response:
datasource_uids = set(datasources.get("datasource_uids", []))
for ds in datasources.get("datasources", []):
ds_obj = SqlaTable.get_datasource_by_name(
session=db.session,
datasource_name=ds.get("datasource_name"),
schema=ds.get("schema"),
database_name=ds.get("database_name"),
Expand Down
32 changes: 15 additions & 17 deletions superset/commands/dashboard/importers/v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from typing import Any, Optional

from flask_babel import lazy_gettext as _
from sqlalchemy.orm import make_transient, Session
from sqlalchemy.orm import make_transient

from superset import db
from superset.commands.base import BaseCommand
Expand Down Expand Up @@ -55,7 +55,6 @@ def import_chart(
:returns: The resulting id for the imported slice
:rtype: int
"""
session = db.session
make_transient(slc_to_import)
slc_to_import.dashboards = []
slc_to_import.alter_params(remote_id=slc_to_import.id, import_time=import_time)
Expand All @@ -64,19 +63,18 @@ def import_chart(
slc_to_import.reset_ownership()
params = slc_to_import.params_dict
datasource = SqlaTable.get_datasource_by_name(
session=session,
datasource_name=params["datasource_name"],
database_name=params["database_name"],
schema=params["schema"],
)
slc_to_import.datasource_id = datasource.id # type: ignore
if slc_to_override:
slc_to_override.override(slc_to_import)
session.flush()
db.session.flush()
return slc_to_override.id
session.add(slc_to_import)
db.session.add(slc_to_import)
logger.info("Final slice: %s", str(slc_to_import.to_json()))
session.flush()
db.session.flush()
return slc_to_import.id


Expand Down Expand Up @@ -156,7 +154,6 @@ def alter_native_filters(dashboard: Dashboard) -> None:
dashboard.json_metadata = json.dumps(json_metadata)

logger.info("Started import of the dashboard: %s", dashboard_to_import.to_json())
session = db.session
logger.info("Dashboard has %d slices", len(dashboard_to_import.slices))
# copy slices object as Slice.import_slice will mutate the slice
# and will remove the existing dashboard - slice association
Expand All @@ -173,7 +170,7 @@ def alter_native_filters(dashboard: Dashboard) -> None:
i_params_dict = dashboard_to_import.params_dict
remote_id_slice_map = {
slc.params_dict["remote_id"]: slc
for slc in session.query(Slice).all()
for slc in db.session.query(Slice).all()
if "remote_id" in slc.params_dict
}
for slc in slices:
Expand Down Expand Up @@ -224,7 +221,7 @@ def alter_native_filters(dashboard: Dashboard) -> None:

# override the dashboard
existing_dashboard = None
for dash in session.query(Dashboard).all():
for dash in db.session.query(Dashboard).all():
if (
"remote_id" in dash.params_dict
and dash.params_dict["remote_id"] == dashboard_to_import.id
Expand Down Expand Up @@ -253,18 +250,20 @@ def alter_native_filters(dashboard: Dashboard) -> None:
alter_native_filters(dashboard_to_import)

new_slices = (
session.query(Slice).filter(Slice.id.in_(old_to_new_slc_id_dict.values())).all()
db.session.query(Slice)
.filter(Slice.id.in_(old_to_new_slc_id_dict.values()))
.all()
)

if existing_dashboard:
existing_dashboard.override(dashboard_to_import)
existing_dashboard.slices = new_slices
session.flush()
db.session.flush()
return existing_dashboard.id

dashboard_to_import.slices = new_slices
session.add(dashboard_to_import)
session.flush()
db.session.add(dashboard_to_import)
db.session.flush()
return dashboard_to_import.id # type: ignore


Expand All @@ -291,7 +290,6 @@ def decode_dashboards(o: dict[str, Any]) -> Any:


def import_dashboards(
session: Session,
content: str,
database_id: Optional[int] = None,
import_time: Optional[int] = None,
Expand All @@ -308,10 +306,10 @@ def import_dashboards(
params = json.loads(table.params)
dataset_id_mapping[params["remote_id"]] = new_dataset_id

session.commit()
db.session.commit()
for dashboard in data["dashboards"]:
import_dashboard(dashboard, dataset_id_mapping, import_time=import_time)
session.commit()
db.session.commit()


class ImportDashboardsCommand(BaseCommand):
Expand All @@ -334,7 +332,7 @@ def run(self) -> None:

for file_name, content in self.contents.items():
logger.info("Importing dashboard from file %s", file_name)
import_dashboards(db.session, content, self.database_id)
import_dashboards(content, self.database_id)

def validate(self) -> None:
# ensure all files are JSON
Expand Down
3 changes: 1 addition & 2 deletions superset/commands/explore/get.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from flask_babel import lazy_gettext as _
from sqlalchemy.exc import SQLAlchemyError

from superset import db
from superset.commands.base import BaseCommand
from superset.commands.explore.form_data.get import GetFormDataCommand
from superset.commands.explore.form_data.parameters import (
Expand Down Expand Up @@ -114,7 +113,7 @@ def run(self) -> Optional[dict[str, Any]]:
if self._datasource_id is not None:
with contextlib.suppress(DatasourceNotFound):
datasource = DatasourceDAO.get_datasource(
db.session, cast(str, self._datasource_type), self._datasource_id
cast(str, self._datasource_type), self._datasource_id
)
datasource_name = datasource.name if datasource else _("[Missing Dataset]")
viz_type = form_data.get("viz_type")
Expand Down
3 changes: 1 addition & 2 deletions superset/commands/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
)
from superset.daos.datasource import DatasourceDAO
from superset.daos.exceptions import DatasourceNotFound
from superset.extensions import db
from superset.utils.core import DatasourceType, get_user_id

if TYPE_CHECKING:
Expand Down Expand Up @@ -80,7 +79,7 @@ def populate_roles(role_ids: list[int] | None = None) -> list[Role]:
def get_datasource_by_id(datasource_id: int, datasource_type: str) -> BaseDatasource:
try:
return DatasourceDAO.get_datasource(
db.session, DatasourceType(datasource_type), datasource_id
DatasourceType(datasource_type), datasource_id
)
except DatasourceNotFound as ex:
raise DatasourceNotFoundValidationError() from ex
5 changes: 2 additions & 3 deletions superset/common/query_context_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from typing import Any, TYPE_CHECKING

from superset import app, db
from superset import app
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.query_context import QueryContext
from superset.common.query_object import QueryObject
Expand All @@ -35,7 +35,7 @@


def create_query_object_factory() -> QueryObjectFactory:
return QueryObjectFactory(config, DatasourceDAO(), db.session)
return QueryObjectFactory(config, DatasourceDAO())


class QueryContextFactory: # pylint: disable=too-few-public-methods
Expand Down Expand Up @@ -95,7 +95,6 @@ def create(

def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource:
return DatasourceDAO.get_datasource(
session=db.session,
datasource_type=DatasourceType(datasource["type"]),
datasource_id=int(datasource["id"]),
)
Expand Down
6 changes: 0 additions & 6 deletions superset/common/query_object_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,26 +33,21 @@
)

if TYPE_CHECKING:
from sqlalchemy.orm import sessionmaker

from superset.connectors.sqla.models import BaseDatasource
from superset.daos.datasource import DatasourceDAO


class QueryObjectFactory: # pylint: disable=too-few-public-methods
_config: dict[str, Any]
_datasource_dao: DatasourceDAO
_session_maker: sessionmaker

def __init__(
self,
app_configurations: dict[str, Any],
_datasource_dao: DatasourceDAO,
session_maker: sessionmaker,
):
self._config = app_configurations
self._datasource_dao = _datasource_dao
self._session_maker = session_maker

def create( # pylint: disable=too-many-arguments
self,
Expand Down Expand Up @@ -91,7 +86,6 @@ def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource:
return self._datasource_dao.get_datasource(
datasource_type=DatasourceType(datasource["type"]),
datasource_id=int(datasource["id"]),
session=self._session_maker(),
)

def _process_extras(
Expand Down
14 changes: 5 additions & 9 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ def raise_for_access(self) -> None:

@classmethod
def get_datasource_by_name(
cls, session: Session, datasource_name: str, schema: str, database_name: str
cls, datasource_name: str, schema: str, database_name: str
) -> BaseDatasource | None:
raise NotImplementedError()

Expand Down Expand Up @@ -1238,14 +1238,13 @@ def database_name(self) -> str:
@classmethod
def get_datasource_by_name(
cls,
session: Session,
datasource_name: str,
schema: str | None,
database_name: str,
) -> SqlaTable | None:
schema = schema or None
query = (
session.query(cls)
db.session.query(cls)
.join(Database)
.filter(cls.table_name == datasource_name)
.filter(Database.database_name == database_name)
Expand Down Expand Up @@ -1939,12 +1938,10 @@ def query_datasources_by_permissions( # pylint: disable=invalid-name
)

@classmethod
def get_eager_sqlatable_datasource(
cls, session: Session, datasource_id: int
) -> SqlaTable:
def get_eager_sqlatable_datasource(cls, datasource_id: int) -> SqlaTable:
"""Returns SqlaTable with columns and metrics."""
return (
session.query(cls)
db.session.query(cls)
.options(
sa.orm.subqueryload(cls.columns),
sa.orm.subqueryload(cls.metrics),
Expand Down Expand Up @@ -2037,8 +2034,7 @@ def update_column( # pylint: disable=unused-argument
:param connection: Unused.
:param target: The metric or column that was updated.
"""
inspector = inspect(target)
session = inspector.session
session = inspect(target).session

# Forces an update to the table's changed_on value when a metric or column on the
# table is updated. This busts the cache key for all charts that use the table.
Expand Down
7 changes: 4 additions & 3 deletions superset/daos/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def validate_update_slug_uniqueness(dashboard_id: int, slug: str | None) -> bool
return True

@staticmethod
def set_dash_metadata( # pylint: disable=too-many-locals
def set_dash_metadata(
dashboard: Dashboard,
data: dict[Any, Any],
old_to_new_slice_ids: dict[int, int] | None = None,
Expand All @@ -187,8 +187,9 @@ def set_dash_metadata( # pylint: disable=too-many-locals
if isinstance(value, dict)
]

session = db.session()
current_slices = session.query(Slice).filter(Slice.id.in_(slice_ids)).all()
current_slices = (
db.session.query(Slice).filter(Slice.id.in_(slice_ids)).all()
)

dashboard.slices = current_slices

Expand Down
6 changes: 2 additions & 4 deletions superset/daos/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
import logging
from typing import Union

from sqlalchemy.orm import Session

from superset import db
from superset.connectors.sqla.models import SqlaTable
from superset.daos.base import BaseDAO
from superset.daos.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError
Expand All @@ -45,15 +44,14 @@ class DatasourceDAO(BaseDAO[Datasource]):
@classmethod
def get_datasource(
cls,
session: Session,
datasource_type: Union[DatasourceType, str],
datasource_id: int,
) -> Datasource:
if datasource_type not in cls.sources:
raise DatasourceTypeNotSupportedError()

datasource = (
session.query(cls.sources[datasource_type])
db.session.query(cls.sources[datasource_type])
.filter_by(id=datasource_id)
.one_or_none()
)
Expand Down
4 changes: 2 additions & 2 deletions superset/datasource/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from flask_appbuilder.api import expose, protect, safe

from superset import app, db, event_logger
from superset import app, event_logger
from superset.daos.datasource import DatasourceDAO
from superset.daos.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError
from superset.exceptions import SupersetSecurityException
Expand Down Expand Up @@ -100,7 +100,7 @@ def get_column_values(
"""
try:
datasource = DatasourceDAO.get_datasource(
db.session, DatasourceType(datasource_type), datasource_id
DatasourceType(datasource_type), datasource_id
)
datasource.raise_for_access()
except ValueError:
Expand Down
Loading

0 comments on commit 6e2733e

Please sign in to comment.