6 changes: 0 additions & 6 deletions airflow/dag_processing/processor.py
@@ -63,14 +63,8 @@ def _parse_file_entrypoint():
import structlog

from airflow.sdk.execution_time import task_runner
from airflow.settings import configure_orm

# Parse DAG file, send JSON back up!

# We need to reconfigure the orm here, as DagFileProcessorManager does db queries for bundles, and
# the session across forks blows things up.
configure_orm()

comms_decoder = task_runner.CommsDecoder[ToDagProcessor, ToManager](
input=sys.stdin,
decoder=TypeAdapter[ToDagProcessor](ToDagProcessor),
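The comment removed here captures the original motivation: SQLAlchemy sessions and pooled connections must not be shared across os.fork(), so the forked parsing process used to rebuild the ORM itself via configure_orm(). A minimal, self-contained sketch of that hazard and the manual per-child cleanup it implies (illustrative only; the SQLite URL is a placeholder, not Airflow's configuration — the airflow/settings.py hunk below is where this PR centralizes the cleanup):

import os

from sqlalchemy import create_engine, text

# Placeholder engine; Airflow builds its own in airflow.settings.configure_orm().
engine = create_engine("sqlite:///example.db")

with engine.connect() as conn:
    conn.execute(text("SELECT 1"))  # the pool now holds a live DBAPI connection

pid = os.fork()
if pid == 0:
    # Child: discard connections inherited from the parent instead of reusing
    # them, but leave the parent's sockets untouched (close=False).
    engine.dispose(close=False)
    with engine.connect() as conn:
        conn.execute(text("SELECT 1"))  # fresh connection, safe to use
    os._exit(0)
os.waitpid(pid, 0)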
50 changes: 22 additions & 28 deletions airflow/settings.py
@@ -29,7 +29,7 @@

import pluggy
from packaging.version import Version
from sqlalchemy import create_engine, exc, text
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession as SAAsyncSession, create_async_engine
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.pool import NullPool
@@ -46,7 +46,6 @@

if TYPE_CHECKING:
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session as SASession

log = logging.getLogger(__name__)

@@ -101,12 +100,12 @@
"""

engine: Engine
Session: Callable[..., SASession]
Session: scoped_session
# NonScopedSession creates global sessions and is not safe to use in multi-threaded environment without
# additional precautions. The only use case is when the session lifecycle needs
# custom handling. Most of the time we only want one unique thread local session object,
# this is achieved by the Session factory above.
NonScopedSession: Callable[..., SASession]
NonScopedSession: sessionmaker
async_engine: AsyncEngine
AsyncSession: Callable[..., SAAsyncSession]

@@ -389,6 +388,12 @@ def _session_maker(_engine):
NonScopedSession = _session_maker(engine)
Session = scoped_session(NonScopedSession)

from sqlalchemy.orm.session import close_all_sessions

os.register_at_fork(after_in_child=close_all_sessions)
# https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
os.register_at_fork(after_in_child=lambda: engine.dispose(close=False))


DEFAULT_ENGINE_ARGS = {
"postgresql": {
@@ -479,14 +484,23 @@ def prepare_engine_args(disable_connection_pool=False, pool_class=None):

def dispose_orm():
"""Properly close pooled database connections."""
global Session, engine, NonScopedSession

_globals = globals()
if "engine" not in _globals and "Session" not in _globals:
return

log.debug("Disposing DB connection pool (PID %s)", os.getpid())
global engine
global Session

if Session is not None: # type: ignore[truthy-function]
if "Session" in _globals and Session is not None:
from sqlalchemy.orm.session import close_all_sessions

Session.remove()
Session = None
if engine:
NonScopedSession = None
close_all_sessions()

if "engine" in _globals:
engine.dispose()
engine = None

@@ -529,26 +543,6 @@ def configure_adapters():
pass


def validate_session():
"""Validate ORM Session."""
global engine

worker_precheck = conf.getboolean("celery", "worker_precheck")
if not worker_precheck:
return True
else:
check_session = sessionmaker(bind=engine)
session = check_session()
try:
session.execute(text("select 1"))
conn_status = True
except exc.DBAPIError as err:
log.error(err)
conn_status = False
session.close()
return conn_status


def configure_action_logging() -> None:
"""Any additional configuration (register callback) for airflow.utils.action_loggers module."""

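Registering the two at-fork hooks once in the parent means every child created with fork() automatically starts with no inherited ORM sessions and an empty connection pool, which is why the explicit configure_orm() call could be dropped from the DAG processor above. A stand-alone sketch of that pattern, following the SQLAlchemy pooling guide linked in the diff (URL and names are illustrative, not Airflow's real wiring):

import os

from sqlalchemy import create_engine, text
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.orm.session import close_all_sessions

engine = create_engine("sqlite:///example.db")  # placeholder URL
Session = scoped_session(sessionmaker(bind=engine))

# Registered once in the parent; both callbacks run automatically in every
# child produced by os.fork() (or multiprocessing's "fork" start method).
os.register_at_fork(after_in_child=close_all_sessions)
os.register_at_fork(after_in_child=lambda: engine.dispose(close=False))

pid = os.fork()
if pid == 0:
    # No explicit reconfiguration needed here: the hooks above have already
    # dropped inherited sessions and pooled connections.
    Session().execute(text("SELECT 1"))
    Session.remove()
    os._exit(0)
os.waitpid(pid, 0)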
7 changes: 0 additions & 7 deletions providers/celery/provider.yaml
@@ -308,13 +308,6 @@ config:
type: integer
example: ~
default: "3"
worker_precheck:
description: |
Worker initialisation check to validate Metadata Database connection
version_added: ~
type: string
example: ~
default: "False"
extra_celery_config:
description: |
Extra celery configs to include in the celery worker.
@@ -197,11 +197,11 @@ def worker(args):
from airflow.sdk.log import configure_logging

configure_logging(output=sys.stdout.buffer)

# Disable connection pool so that celery worker does not hold an unnecessary db connection
settings.reconfigure_orm(disable_connection_pool=True)
if not settings.validate_session():
raise SystemExit("Worker exiting, database connection precheck failed.")
else:
# Disable connection pool so that celery worker does not hold an unnecessary db connection
settings.reconfigure_orm(disable_connection_pool=True)
if not settings.validate_session():
raise SystemExit("Worker exiting, database connection precheck failed.")

autoscale = args.autoscale
skip_serve_logs = args.skip_serve_logs
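With the worker_precheck gone, the worker always reconfigures the ORM with disable_connection_pool=True. In Airflow's settings that roughly amounts to building the engine with NullPool (already imported in airflow/settings.py above), so an idle Celery worker keeps no open database connection. A rough sketch under that assumption (the URL is a placeholder):

from sqlalchemy import create_engine, text
from sqlalchemy.pool import NullPool

# Placeholder URL; Airflow reads its own from [database] sql_alchemy_conn.
engine = create_engine("sqlite:///example.db", poolclass=NullPool)

with engine.connect() as conn:
    conn.execute(text("SELECT 1"))
# NullPool closes the DBAPI connection on checkin, so between uses the
# worker process holds no database connection at all.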
@@ -266,13 +266,6 @@ def get_provider_info():
"example": None,
"default": "3",
},
"worker_precheck": {
"description": "Worker initialisation check to validate Metadata Database connection\n",
"version_added": None,
"type": "string",
"example": None,
"default": "False",
},
"extra_celery_config": {
"description": 'Extra celery configs to include in the celery worker.\nAny of the celery config can be added to this config and it\nwill be applied while starting the celery worker. e.g. {"worker_max_tasks_per_child": 10}\nSee also:\nhttps://docs.celeryq.dev/en/stable/userguide/configuration.html#configuration-and-defaults\n',
"version_added": None,
34 changes: 0 additions & 34 deletions providers/celery/tests/unit/celery/cli/test_celery_command.py
@@ -19,14 +19,11 @@

import importlib
import os
from argparse import Namespace
from unittest import mock
from unittest.mock import patch

import pytest
import sqlalchemy

import airflow
from airflow.cli import cli_parser
from airflow.configuration import conf
from airflow.executors import executor_loader
@@ -39,37 +36,6 @@
pytestmark = pytest.mark.db_test


@conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): 0})
class TestWorkerPrecheck:
@mock.patch("airflow.settings.validate_session")
def test_error(self, mock_validate_session):
"""
Test to verify the exit mechanism of airflow-worker cli
by mocking validate_session method
"""
mock_validate_session.return_value = False
with pytest.raises(SystemExit) as ctx, conf_vars({("core", "executor"): "CeleryExecutor"}):
celery_command.worker(Namespace(queues=1, concurrency=1))
assert str(ctx.value) == "Worker exiting, database connection precheck failed."

@conf_vars({("celery", "worker_precheck"): "False"})
def test_worker_precheck_exception(self):
"""
Test to check the behaviour of validate_session method
when worker_precheck is absent in airflow configuration
"""
assert airflow.settings.validate_session()

@mock.patch("sqlalchemy.orm.session.Session.execute")
@conf_vars({("celery", "worker_precheck"): "True"})
def test_validate_session_dbapi_exception(self, mock_session):
"""
Test to validate connection failure scenario on SELECT 1 query
"""
mock_session.side_effect = sqlalchemy.exc.OperationalError("m1", "m2", "m3", "m4")
assert airflow.settings.validate_session() is False


@pytest.mark.backend("mysql", "postgres")
@conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): 0})
class TestCeleryStopCommand:
58 changes: 58 additions & 0 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -206,6 +206,62 @@ def _get_last_chance_stderr() -> TextIO:
return stream


class BlockedDBSession:
""":meta private:""" # noqa: D400

def __init__(self):
raise RuntimeError("Direct database access via the ORM is not allowed in Airflow 3.0")

def remove(*args, **kwargs):
pass

def get_bind(
self,
mapper=None,
clause=None,
bind=None,
_sa_skip_events=None,
_sa_skip_for_implicit_returning=False,
):
pass


def block_orm_access():
"""
Disable direct DB access as best as possible from task code.

While we still don't have 100% code separation between TaskSDK and "core" Airflow, it is still possible to
import the models and use them. This does what it can to disable that if it is not blocked at the network
level.
"""
# A fake URL scheme that might give users some clue what's going on, hopefully.
conn = "airflow-db-not-allowed:///"
if "airflow.settings" in sys.modules:
from airflow import settings
from airflow.configuration import conf

settings.dispose_orm()

for attr in ("engine", "async_engine", "Session", "AsyncSession", "NonScopedSession"):
if hasattr(settings, attr):
delattr(settings, attr)

def configure_orm(*args, **kwargs):
raise RuntimeError("Database access is disabled from DAGs and Triggers")

settings.configure_orm = configure_orm
settings.Session = BlockedDBSession
if conf.has_section("database"):
conf.set("database", "sql_alchemy_conn", conn)
conf.set("database", "sql_alchemy_conn_cmd", "/bin/false")
conf.set("database", "sql_alchemy_conn_secret", "db-access-blocked")

settings.SQL_ALCHEMY_CONN = conn
settings.SQL_ALCHEMY_CONN_ASYNC = conn

os.environ["AIRFLOW__DATABASE__SQL_ALCHEMY_CONN"] = conn


def _fork_main(
child_stdin: socket,
child_stdout: socket,
@@ -261,6 +317,8 @@ def exit(n: int) -> NoReturn:
base_exit(n)

try:
block_orm_access()

target()
exit(0)
except SystemExit as e:
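Once _fork_main() has called block_orm_access(), the usual ORM entry points fail fast instead of silently reaching the metadata database. A hypothetical snippet (not a test added by this PR) of what DAG or trigger code would then observe:

# Hypothetical: assumes block_orm_access() has already run in this process.
from airflow import settings

try:
    settings.Session()  # BlockedDBSession refuses to be instantiated
except RuntimeError as err:
    print(err)  # "Direct database access via the ORM is not allowed in Airflow 3.0"

try:
    settings.configure_orm()  # re-enabling the ORM is blocked as well
except RuntimeError as err:
    print(err)  # "Database access is disabled from DAGs and Triggers"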
5 changes: 5 additions & 0 deletions task_sdk/tests/conftest.py
@@ -28,6 +28,7 @@

# Task SDK does not need access to the Airflow database
os.environ["_AIRFLOW_SKIP_DB_TESTS"] = "true"
os.environ["_AIRFLOW__AS_LIBRARY"] = "true"

if TYPE_CHECKING:
from datetime import datetime
@@ -56,6 +57,10 @@ def pytest_configure(config: pytest.Config) -> None:
# Always skip looking for tests in these folders!
config.addinivalue_line("norecursedirs", "tests/test_dags")

import airflow.settings

airflow.settings.configure_policy_plugin_manager()


@pytest.hookimpl(tryfirst=True)
def pytest_runtest_setup(item):
3 changes: 3 additions & 0 deletions task_sdk/tests/execution_time/test_supervisor.py
@@ -24,6 +24,7 @@
import selectors
import signal
import sys
import time
from io import BytesIO
from operator import attrgetter
from pathlib import Path
@@ -850,7 +851,9 @@ def _handler(sig, frame):
client=MagicMock(spec=sdk_client.Client),
target=subprocess_main,
)

# Ensure we get one normal run, to give the proc time to register its custom sighandler
time.sleep(0.1)
proc._service_subprocess(max_wait_time=1)
proc.kill(signal_to_send=signal_to_send, escalation_delay=0.5, force=True)

11 changes: 7 additions & 4 deletions tests/dag_processing/test_manager.py
@@ -165,16 +165,18 @@ def test_remove_file_clears_import_error(self, tmp_path, configure_testing_dag_b
processor_timeout=365 * 86_400,
)

with create_session() as session:
manager.run()
manager.run()

with create_session() as session:
import_errors = session.query(ParseImportError).all()
assert len(import_errors) == 1

path_to_parse.unlink()

# Rerun the parser once the dag file has been removed
manager.run()
# Rerun the parser once the dag file has been removed
manager.run()

with create_session() as session:
import_errors = session.query(ParseImportError).all()

assert len(import_errors) == 0
@@ -658,6 +660,7 @@ def test_refresh_dags_dir_deactivates_deleted_zipped_dags(
shutil.copy(source_location, zip_dag_path)

with configure_testing_dag_bundle(bundle_path):
session.commit()
manager = DagFileProcessorManager(max_runs=1)
manager.run()

6 changes: 4 additions & 2 deletions tests/jobs/test_triggerer_job.py
@@ -179,11 +179,13 @@ def test_trigger_lifecycle(spy_agency: SpyAgency, session):
trigger = TimeDeltaTrigger(datetime.timedelta(days=7))
dag_model, run, trigger_orm, task_instance = create_trigger_in_db(session, trigger)
# Make a TriggererJobRunner and have it retrieve DB tasks
trigger_runner_supervisor = TriggerRunnerSupervisor.start(job=Job(), capacity=10)
trigger_runner_supervisor = TriggerRunnerSupervisor.start(job=Job(id=12345), capacity=10)

try:
# Spy on it so we can see what gets sent, but also call the original.
send_spy = spy_agency.spy_on(TriggerRunnerSupervisor._send, owner=TriggerRunnerSupervisor)

trigger_runner_supervisor._service_subprocess(0.1)
trigger_runner_supervisor.load_triggers()
# Make sure it turned up in TriggerRunner's queue
assert trigger_runner_supervisor.running_triggers == {1}
@@ -431,7 +433,7 @@ def is_alive(self):


@pytest.mark.execution_timeout(5)
def test_trigger_runner_exception_stops_triggerer(session):
def test_trigger_runner_exception_stops_triggerer():
"""
Checks that if an exception occurs when creating triggers, that the triggerer
process stops