refactor: Ensure Celery leverages the Flask-SQLAlchemy session (apach…
john-bodley authored and sfirke committed Mar 22, 2024
1 parent e6dd9d3 commit 0a8c430
Showing 19 changed files with 933 additions and 349 deletions.
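The recurring pattern in the diff below: Celery-side code previously opened its own SQLAlchemy session via `session_scope(nullpool=True)` and threaded it through commands and DAOs; after this commit it relies on the Flask-SQLAlchemy scoped session (`db.session`), which is available because the worker runs inside a Flask app context. A minimal before/after sketch of that pattern (illustrative only; `mark_active_*` are made-up helpers, not code from this commit, and the "before" half uses the pre-commit `session_scope` helper):

from superset import db
from superset.reports.models import ReportSchedule
from superset.utils.celery import session_scope  # pattern removed by this commit


def mark_active_old(report_id: int) -> None:
    # Before: Celery code opened its own session and managed it explicitly.
    with session_scope(nullpool=True) as session:
        report = session.query(ReportSchedule).filter_by(id=report_id).one()
        report.active = True
        session.commit()


def mark_active_new(report_id: int) -> None:
    # After: rely on the Flask-SQLAlchemy scoped session; the Celery worker
    # runs inside a Flask app context, so db.session is already bound.
    report = db.session.query(ReportSchedule).filter_by(id=report_id).one()
    report.active = True
    db.session.commit()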
688 changes: 688 additions & 0 deletions

Large diffs are not rendered by default.

70 changes: 30 additions & 40 deletions superset/commands/report/execute.py
@@ -22,9 +22,8 @@

import pandas as pd
from celery.exceptions import SoftTimeLimitExceeded
from sqlalchemy.orm import Session

from superset import app, security_manager
from superset import app, db, security_manager
from superset.commands.base import BaseCommand
from superset.commands.dashboard.permalink.create import CreateDashboardPermalinkCommand
from superset.commands.exceptions import CommandException
@@ -68,7 +67,6 @@
from superset.reports.notifications.base import NotificationContent
from superset.reports.notifications.exceptions import NotificationError
from superset.tasks.utils import get_executor
from superset.utils.celery import session_scope
from superset.utils.core import HeaderDataType, override_user
from superset.utils.csv import get_chart_csv_data, get_chart_dataframe
from superset.utils.decorators import logs_context
@@ -85,12 +83,10 @@ class BaseReportState:
@logs_context()
def __init__(
self,
session: Session,
report_schedule: ReportSchedule,
scheduled_dttm: datetime,
execution_id: UUID,
) -> None:
self._session = session
self._report_schedule = report_schedule
self._scheduled_dttm = scheduled_dttm
self._start_dttm = datetime.utcnow()
@@ -123,7 +119,7 @@ def update_report_schedule(self, state: ReportState) -> None:

self._report_schedule.last_state = state
self._report_schedule.last_eval_dttm = datetime.utcnow()
self._session.commit()
db.session.commit()

def create_log(self, error_message: Optional[str] = None) -> None:
"""
@@ -140,8 +136,8 @@ def create_log(self, error_message: Optional[str] = None) -> None:
report_schedule=self._report_schedule,
uuid=self._execution_id,
)
self._session.add(log)
self._session.commit()
db.session.add(log)
db.session.commit()

def _get_url(
self,
@@ -485,9 +481,7 @@ def is_in_grace_period(self) -> bool:
"""
Checks if an alert is in it's grace period
"""
last_success = ReportScheduleDAO.find_last_success_log(
self._report_schedule, session=self._session
)
last_success = ReportScheduleDAO.find_last_success_log(self._report_schedule)
return (
last_success is not None
and self._report_schedule.grace_period
@@ -501,7 +495,7 @@ def is_in_error_grace_period(self) -> bool:
Checks if an alert/report on error is in it's notification grace period
"""
last_success = ReportScheduleDAO.find_last_error_notification(
self._report_schedule, session=self._session
self._report_schedule
)
if not last_success:
return False
@@ -518,7 +512,7 @@ def is_on_working_timeout(self) -> bool:
Checks if an alert is in a working timeout
"""
last_working = ReportScheduleDAO.find_last_entered_working_log(
self._report_schedule, session=self._session
self._report_schedule
)
if not last_working:
return False
@@ -668,12 +662,10 @@ class ReportScheduleStateMachine: # pylint: disable=too-few-public-methods

def __init__(
self,
session: Session,
task_uuid: UUID,
report_schedule: ReportSchedule,
scheduled_dttm: datetime,
):
self._session = session
self._execution_id = task_uuid
self._report_schedule = report_schedule
self._scheduled_dttm = scheduled_dttm
@@ -684,7 +676,6 @@ def run(self) -> None:
self._report_schedule.last_state in state_cls.current_states
):
state_cls(
self._session,
self._report_schedule,
self._scheduled_dttm,
self._execution_id,
@@ -708,39 +699,38 @@ def __init__(self, task_id: str, model_id: int, scheduled_dttm: datetime):
self._execution_id = UUID(task_id)

def run(self) -> None:
with session_scope(nullpool=True) as session:
try:
self.validate(session=session)
if not self._model:
raise ReportScheduleExecuteUnexpectedError()
_, username = get_executor(
executor_types=app.config["ALERT_REPORTS_EXECUTE_AS"],
model=self._model,
try:
self.validate()
if not self._model:
raise ReportScheduleExecuteUnexpectedError()
_, username = get_executor(
executor_types=app.config["ALERT_REPORTS_EXECUTE_AS"],
model=self._model,
)
user = security_manager.find_user(username)
with override_user(user):
logger.info(
"Running report schedule %s as user %s",
self._execution_id,
username,
)
user = security_manager.find_user(username)
with override_user(user):
logger.info(
"Running report schedule %s as user %s",
self._execution_id,
username,
)
ReportScheduleStateMachine(
session, self._execution_id, self._model, self._scheduled_dttm
).run()
except CommandException as ex:
raise ex
except Exception as ex:
raise ReportScheduleUnexpectedError(str(ex)) from ex
ReportScheduleStateMachine(
self._execution_id, self._model, self._scheduled_dttm
).run()
except CommandException as ex:
raise ex
except Exception as ex:
raise ReportScheduleUnexpectedError(str(ex)) from ex

def validate(self, session: Session = None) -> None:
def validate(self) -> None:
# Validate/populate model exists
logger.info(
"session is validated: id %s, executionid: %s",
self._model_id,
self._execution_id,
)
self._model = (
session.query(ReportSchedule).filter_by(id=self._model_id).one_or_none()
db.session.query(ReportSchedule).filter_by(id=self._model_id).one_or_none()
)
if not self._model:
raise ReportScheduleNotFoundError()
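With the session parameter gone from the command and state-machine constructors, the worker-side invocation becomes simpler. A hedged sketch of calling the execute command after this change (the surrounding Celery task wiring and the literal id values are assumptions for illustration, not code from this diff):

from datetime import datetime
from uuid import uuid4

from superset.commands.report.execute import AsyncExecuteReportScheduleCommand

# Assumes the Celery worker has pushed a Flask app context, so db.session
# inside the command resolves to the Flask-SQLAlchemy scoped session.
AsyncExecuteReportScheduleCommand(
    task_id=str(uuid4()),
    model_id=42,
    scheduled_dttm=datetime.utcnow(),
).run()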
43 changes: 21 additions & 22 deletions superset/commands/report/log_prune.py
@@ -17,12 +17,12 @@
import logging
from datetime import datetime, timedelta

from superset import db
from superset.commands.base import BaseCommand
from superset.commands.report.exceptions import ReportSchedulePruneLogError
from superset.daos.exceptions import DAODeleteFailedError
from superset.daos.report import ReportScheduleDAO
from superset.reports.models import ReportSchedule
from superset.utils.celery import session_scope

logger = logging.getLogger(__name__)

@@ -36,28 +36,27 @@ def __init__(self, worker_context: bool = True):
self._worker_context = worker_context

def run(self) -> None:
with session_scope(nullpool=True) as session:
self.validate()
prune_errors = []

for report_schedule in session.query(ReportSchedule).all():
if report_schedule.log_retention is not None:
from_date = datetime.utcnow() - timedelta(
days=report_schedule.log_retention
self.validate()
prune_errors = []

for report_schedule in db.session.query(ReportSchedule).all():
if report_schedule.log_retention is not None:
from_date = datetime.utcnow() - timedelta(
days=report_schedule.log_retention
)
try:
row_count = ReportScheduleDAO.bulk_delete_logs(
report_schedule, from_date, commit=False
)
try:
row_count = ReportScheduleDAO.bulk_delete_logs(
report_schedule, from_date, session=session, commit=False
)
logger.info(
"Deleted %s logs for report schedule id: %s",
str(row_count),
str(report_schedule.id),
)
except DAODeleteFailedError as ex:
prune_errors.append(str(ex))
if prune_errors:
raise ReportSchedulePruneLogError(";".join(prune_errors))
logger.info(
"Deleted %s logs for report schedule id: %s",
str(row_count),
str(report_schedule.id),
)
except DAODeleteFailedError as ex:
prune_errors.append(str(ex))
if prune_errors:
raise ReportSchedulePruneLogError(";".join(prune_errors))

def validate(self) -> None:
pass
33 changes: 12 additions & 21 deletions superset/daos/report.py
@@ -22,7 +22,6 @@
from typing import Any

from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session

from superset.daos.base import BaseDAO
from superset.daos.exceptions import DAODeleteFailedError
@@ -204,27 +203,25 @@ def update(
return super().update(item, attributes, commit)

@staticmethod
def find_active(session: Session | None = None) -> list[ReportSchedule]:
def find_active() -> list[ReportSchedule]:
"""
Find all active reports. If session is passed it will be used instead of the
default `db.session`, this is useful when on a celery worker session context
Find all active reports.
"""
session = session or db.session
return (
session.query(ReportSchedule).filter(ReportSchedule.active.is_(True)).all()
db.session.query(ReportSchedule)
.filter(ReportSchedule.active.is_(True))
.all()
)

@staticmethod
def find_last_success_log(
report_schedule: ReportSchedule,
session: Session | None = None,
) -> ReportExecutionLog | None:
"""
Finds last success execution log for a given report
"""
session = session or db.session
return (
session.query(ReportExecutionLog)
db.session.query(ReportExecutionLog)
.filter(
ReportExecutionLog.state == ReportState.SUCCESS,
ReportExecutionLog.report_schedule == report_schedule,
@@ -236,14 +233,12 @@
@staticmethod
def find_last_entered_working_log(
report_schedule: ReportSchedule,
session: Session | None = None,
) -> ReportExecutionLog | None:
"""
Finds last success execution log for a given report
"""
session = session or db.session
return (
session.query(ReportExecutionLog)
db.session.query(ReportExecutionLog)
.filter(
ReportExecutionLog.state == ReportState.WORKING,
ReportExecutionLog.report_schedule == report_schedule,
@@ -256,14 +251,12 @@
@staticmethod
def find_last_error_notification(
report_schedule: ReportSchedule,
session: Session | None = None,
) -> ReportExecutionLog | None:
"""
Finds last error email sent
"""
session = session or db.session
last_error_email_log = (
session.query(ReportExecutionLog)
db.session.query(ReportExecutionLog)
.filter(
ReportExecutionLog.error_message
== REPORT_SCHEDULE_ERROR_NOTIFICATION_MARKER,
Expand All @@ -276,7 +269,7 @@ def find_last_error_notification(
return None
# Checks that only errors have occurred since the last email
report_from_last_email = (
session.query(ReportExecutionLog)
db.session.query(ReportExecutionLog)
.filter(
ReportExecutionLog.state.notin_(
[ReportState.ERROR, ReportState.WORKING]
@@ -276,7 +269,7 @@
def bulk_delete_logs(
model: ReportSchedule,
from_date: datetime,
session: Session | None = None,
commit: bool = True,
) -> int | None:
session = session or db.session
try:
row_count = (
session.query(ReportExecutionLog)
db.session.query(ReportExecutionLog)
.filter(
ReportExecutionLog.report_schedule == model,
ReportExecutionLog.end_dttm < from_date,
)
.delete(synchronize_session="fetch")
)
if commit:
session.commit()
db.session.commit()
return row_count
except SQLAlchemyError as ex:
session.rollback()
db.session.rollback()
raise DAODeleteFailedError(str(ex)) from ex
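Callers of these DAO helpers no longer pass a session. A minimal sketch of a consumer after the change (the function and its inputs are illustrative and mirror the grace-period check in execute.py, not code from this diff):

from datetime import datetime, timedelta

from superset.daos.report import ReportScheduleDAO
from superset.reports.models import ReportSchedule


def in_grace_period(report_schedule: ReportSchedule) -> bool:
    # The DAO now reads from db.session implicitly; nothing is threaded through.
    last_success = ReportScheduleDAO.find_last_success_log(report_schedule)
    return (
        last_success is not None
        and bool(report_schedule.grace_period)
        and datetime.utcnow() - timedelta(seconds=report_schedule.grace_period)
        < last_success.end_dttm
    )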
11 changes: 4 additions & 7 deletions superset/db_engine_specs/base.py
@@ -50,7 +50,6 @@
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import Session
from sqlalchemy.sql import literal_column, quoted_name, text
from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom, TextClause
from sqlalchemy.types import TypeEngine
@@ -1071,7 +1070,7 @@ def convert_dttm( # pylint: disable=unused-argument
return None

@classmethod
def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None:
def handle_cursor(cls, cursor: Any, query: Query) -> None:
"""Handle a live cursor between the execute and fetchall calls
The flow works without this method doing anything, but it allows
@@ -1080,9 +1079,7 @@ def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None:
# TODO: Fix circular import error caused by importing sql_lab.Query

@classmethod
def execute_with_cursor(
cls, cursor: Any, sql: str, query: Query, session: Session
) -> None:
def execute_with_cursor(cls, cursor: Any, sql: str, query: Query) -> None:
"""
Trigger execution of a query and handle the resulting cursor.
@@ -1095,7 +1092,7 @@ def execute_with_cursor(
logger.debug("Query %d: Running query: %s", query.id, sql)
cls.execute(cursor, sql, async_=True)
logger.debug("Query %d: Handling cursor", query.id)
cls.handle_cursor(cursor, query, session)
cls.handle_cursor(cursor, query)

@classmethod
def extract_error_message(cls, ex: Exception) -> str:
@@ -1841,7 +1838,7 @@ def get_sqla_column_type(

# pylint: disable=unused-argument
@classmethod
def prepare_cancel_query(cls, query: Query, session: Session) -> None:
def prepare_cancel_query(cls, query: Query) -> None:
"""
Some databases may acquire the query cancelation id after the query
cancelation request has been received. For those cases, the db engine spec
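Engine specs that override these hooks now receive only the cursor and the query. A minimal sketch of a third-party spec written against the new signature (the class and its polling logic are hypothetical, not part of Superset):

from typing import Any

from superset.db_engine_specs.base import BaseEngineSpec
from superset.models.sql_lab import Query


class MyEngineSpec(BaseEngineSpec):  # hypothetical engine spec
    engine = "mydb"
    engine_name = "MyDB"

    @classmethod
    def handle_cursor(cls, cursor: Any, query: Query) -> None:
        # No session argument any more: poll the cursor and, if progress needs
        # to be persisted, update `query` and commit via db.session.
        ...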