Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion airflow-core/docs/img/airflow_erd.sha256
Original file line number Diff line number Diff line change
@@ -1 +1 @@
7e97c71ee6da77d758087a5d7129ac8d63b52dcc3bf4c46b7faf5977faa53ef4
538840eb3b7cd14f3a9434c10c6b34442a7059ac59100c6f9d4680f7974a1b03
318 changes: 159 additions & 159 deletions airflow-core/docs/img/airflow_erd.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from typing import Annotated

from fastapi import Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy import and_, select

from airflow.api_fastapi.app import get_auth_manager
from airflow.api_fastapi.auth.managers.models.batch_apis import IsAuthorizedDagRequest
Expand Down Expand Up @@ -163,8 +163,10 @@ def get_import_errors(
select(ParseImportError, visible_files_cte.c.dag_id)
.join(
visible_files_cte,
ParseImportError.filename == visible_files_cte.c.relative_fileloc,
ParseImportError.bundle_name == visible_files_cte.c.bundle_name,
and_(
ParseImportError.filename == visible_files_cte.c.relative_fileloc,
# ParseImportError.bundle_name == visible_files_cte.c.bundle_name, # apparently not needed
),
)
.order_by(ParseImportError.id)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,13 @@ def rotate_items_in_batches_v2(session, model_class, filter_condition=None, batc

This function is taking advantage of yield_per available in SQLAlchemy 2.x.
"""
while True:
query = select(model_class)
if filter_condition is not None:
query = query.where(filter_condition)
query = select(model_class)
if filter_condition is not None:
query = query.where(filter_condition)

with session.no_autoflush: # Temporarily disable autoflush while iterating to prevent deadlocks.
items = session.scalars(query).yield_per(batch_size)
for item in items:
item.rotate_fernet_key()

# The dirty items will be flushed later by the session's transaction management.
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@

_STRING_COLUMN_TYPE = sa.String(length=1500).with_variant(
sa.String(length=1500, collation="latin1_general_cs"),
dialect_name="mysql",
"mysql",
)


Expand Down Expand Up @@ -128,7 +128,7 @@ def downgrade():
"uri",
type_=sa.String(length=3000).with_variant(
sa.String(length=3000, collation="latin1_general_cs"),
dialect_name="mysql",
"mysql",
),
nullable=False,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

_STRING_COLUMN_TYPE = sa.String(length=1500).with_variant(
sa.String(length=1500, collation="latin1_general_cs"),
dialect_name="mysql",
"mysql",
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@

_STRING_COLUMN_TYPE = sa.String(length=1500).with_variant(
sa.String(length=1500, collation="latin1_general_cs"),
dialect_name="mysql",
"mysql",
)


Expand All @@ -77,7 +77,7 @@ def downgrade():
"name",
type_=sa.String(length=3000).with_variant(
sa.String(length=3000, collation="latin1_general_cs"),
dialect_name="mysql",
"mysql",
),
nullable=False,
)
5 changes: 5 additions & 0 deletions airflow-core/src/airflow/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from sqlalchemy.orm import registry

from airflow.configuration import conf
from airflow.utils.sqlalchemy import is_sqlalchemy_v1

SQL_ALCHEMY_SCHEMA = conf.get("database", "SQL_ALCHEMY_SCHEMA")

Expand Down Expand Up @@ -52,6 +53,10 @@ def _get_schema():
Base = Any
else:
Base = mapper_registry.generate_base()
# TEMPORARY workaround to allow using unmapped (v1.4) models in SQLAlchemy 2.0. It is intended only to
# unblock the development of SQLA2 support.
if not is_sqlalchemy_v1():
Base.__allow_unmapped__ = True

ID_LEN = 250

Expand Down
1 change: 1 addition & 0 deletions airflow-core/src/airflow/models/dag_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def write_dag(
)
log.debug("Writing DagVersion %s to the DB", dag_version)
session.add(dag_version)
session.commit()
log.debug("DagVersion %s written to the DB", dag_version)
return dag_version

Expand Down
14 changes: 6 additions & 8 deletions airflow-core/src/airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,16 +387,14 @@ def duration(cls, session: Session = NEW_SESSION) -> Case:
dialect_name = session.bind.dialect.name
if dialect_name == "mysql":
return func.timestampdiff(text("SECOND"), cls.start_date, cls.end_date)
return case(
[
(
(cls.end_date != None) & (cls.start_date != None), # noqa: E711
func.extract("epoch", cls.end_date - cls.start_date),
)
],
else_=None,

when_condition = (
(cls.end_date != None) & (cls.start_date != None), # noqa: E711
func.extract("epoch", cls.end_date - cls.start_date),
)

return case(when_condition, else_=None)

@provide_session
def check_version_id_exists_in_dr(self, dag_version_id: UUIDType, session: Session = NEW_SESSION):
select_stmt = (
Expand Down
7 changes: 7 additions & 0 deletions airflow-core/src/airflow/utils/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
from airflow.utils.timezone import make_naive, utc

if TYPE_CHECKING:
from collections.abc import Iterable

from kubernetes.client.models.v1_pod import V1Pod
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import Query, Session
Expand Down Expand Up @@ -448,3 +450,8 @@ def get_orm_mapper():

def is_sqlalchemy_v1() -> bool:
return version.parse(metadata.version("sqlalchemy")).major == 1


def make_dialect_kwarg(dialect: str) -> dict[str, str | Iterable[str]]:
"""Create an SQLAlchemy-version-aware dialect keyword argument."""
return {"dialect_name": dialect} if is_sqlalchemy_v1() else {"dialect_names": (dialect,)}
8 changes: 7 additions & 1 deletion airflow-core/tests/unit/always/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
from airflow.hooks.base import BaseHook
from airflow.models import Connection, crypto

from tests_common.test_utils.version_compat import SQLALCHEMY_V_1_4

sqlite = pytest.importorskip("airflow.providers.sqlite.hooks.sqlite")

from tests_common.test_utils.config import conf_vars
Expand Down Expand Up @@ -683,8 +685,12 @@ def test_dbapi_get_sqlalchemy_engine(self):
conn = BaseHook.get_connection(conn_id="test_uri")
hook = conn.get_hook()
engine = hook.get_sqlalchemy_engine()
expected = "postgresql://username:password@ec2.compute.com:5432/the_database"
assert isinstance(engine, sqlalchemy.engine.Engine)
assert str(engine.url) == "postgresql://username:password@ec2.compute.com:5432/the_database"
if SQLALCHEMY_V_1_4:
assert str(engine.url) == expected
else:
assert engine.url.render_as_string(hide_password=False) == expected

@mock.patch.dict(
"os.environ",
Expand Down
19 changes: 18 additions & 1 deletion airflow-core/tests/unit/api_fastapi/common/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.db import clear_db_connections, clear_db_dags, clear_db_pools, clear_db_runs
from tests_common.test_utils.version_compat import SQLALCHEMY_V_1_4

pytestmark = pytest.mark.db_test

Expand Down Expand Up @@ -182,11 +183,16 @@ def test_handle_single_column_unique_constraint_error(
expected_exception,
) -> None:
# Take Pool and Variable tables as test cases
# Note: SQLA2 uses a more optimized bulk insert strategy when multiple objects are added to the
# session. Instead of individual INSERT statements, a single INSERT with the SELECT FROM VALUES
# pattern is used.
if table == "Pool":
session.add(Pool(pool=TEST_POOL, slots=1, description="test pool", include_deferred=False))
session.flush() # Avoid SQLA2.0 bulk insert optimization
session.add(Pool(pool=TEST_POOL, slots=1, description="test pool", include_deferred=False))
elif table == "Variable":
session.add(Variable(key=TEST_VARIABLE_KEY, val="test_val"))
session.flush()
session.add(Variable(key=TEST_VARIABLE_KEY, val="test_val"))

with pytest.raises(IntegrityError) as exeinfo_integrity_error:
Expand Down Expand Up @@ -264,4 +270,15 @@ def test_handle_multiple_columns_unique_constraint_error(
self.unique_constraint_error_handler.exception_handler(None, exeinfo_integrity_error.value) # type: ignore

assert exeinfo_response_error.value.status_code == expected_exception.status_code
assert exeinfo_response_error.value.detail == expected_exception.detail
if SQLALCHEMY_V_1_4:
assert exeinfo_response_error.value.detail == expected_exception.detail
else:
# The SQL statement is an implementation detail, so we match on the statement pattern (contains
# the table name and is an INSERT) instead of insisting on an exact match.
response_detail = exeinfo_response_error.value.detail
expected_detail = expected_exception.detail
actual_statement = response_detail.pop("statement", None) # type: ignore[attr-defined]
expected_detail.pop("statement", None)

assert response_detail == expected_detail
assert "INSERT INTO dag_run" in actual_statement
21 changes: 15 additions & 6 deletions airflow-core/tests/unit/core/test_sqlalchemy_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from airflow.exceptions import AirflowConfigException

from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.version_compat import SQLALCHEMY_V_1_4

SQL_ALCHEMY_CONNECT_ARGS = {"test": 43503, "dict": {"is": 1, "supported": "too"}}

Expand Down Expand Up @@ -54,19 +55,23 @@ def test_configure_orm_with_default_values(
self, mock_create_engine, mock_sessionmaker, mock_scoped_session, mock_setup_event_handlers
):
settings.configure_orm()
mock_create_engine.assert_called_once_with(
settings.SQL_ALCHEMY_CONN,
expected_kwargs = dict(
connect_args={}
if not settings.SQL_ALCHEMY_CONN.startswith("sqlite")
else {"check_same_thread": False},
encoding="utf-8",
max_overflow=10,
pool_pre_ping=True,
pool_recycle=1800,
pool_size=5,
isolation_level="READ COMMITTED",
future=True,
)
if SQLALCHEMY_V_1_4:
expected_kwargs["encoding"] = "utf-8"
mock_create_engine.assert_called_once_with(
settings.SQL_ALCHEMY_CONN,
**expected_kwargs,
)

@patch("airflow.settings.setup_event_handlers")
@patch("airflow.settings.scoped_session")
Expand All @@ -88,14 +93,18 @@ def test_sql_alchemy_connect_args(
engine_args = {"arg": 1}
if settings.SQL_ALCHEMY_CONN.startswith("mysql"):
engine_args["isolation_level"] = "READ COMMITTED"
mock_create_engine.assert_called_once_with(
settings.SQL_ALCHEMY_CONN,
expected_kwargs = dict(
connect_args=SQL_ALCHEMY_CONNECT_ARGS,
poolclass=NullPool,
encoding="utf-8",
future=True,
**engine_args,
)
if SQLALCHEMY_V_1_4:
expected_kwargs["encoding"] = "utf-8"
mock_create_engine.assert_called_once_with(
settings.SQL_ALCHEMY_CONN,
**expected_kwargs,
)

@patch("airflow.settings.setup_event_handlers")
@patch("airflow.settings.scoped_session")
Expand Down
4 changes: 4 additions & 0 deletions airflow-core/tests/unit/dag_processing/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,7 @@ def test_serialized_dags_are_written_to_db_on_sync(self, testing_dag_bundle, ses

@patch.object(ParseImportError, "full_file_path")
@patch.object(SerializedDagModel, "write_dag")
@pytest.mark.usefixtures("clean_db")
def test_serialized_dag_errors_are_import_errors(
self, mock_serialize, mock_full_path, caplog, session, dag_import_error_listener, testing_dag_bundle
):
Expand Down Expand Up @@ -492,6 +493,7 @@ def test_serialized_dag_errors_are_import_errors(
assert dag_import_error_listener.new["abc.py"] == import_error.stacktrace

@patch.object(ParseImportError, "full_file_path")
@pytest.mark.usefixtures("clean_db")
def test_new_import_error_replaces_old(
self, mock_full_file_path, session, dag_import_error_listener, testing_dag_bundle
):
Expand Down Expand Up @@ -536,6 +538,7 @@ def test_new_import_error_replaces_old(
assert len(dag_import_error_listener.existing) == 1
assert dag_import_error_listener.existing["abc.py"] == prev_error.stacktrace

@pytest.mark.usefixtures("clean_db")
def test_remove_error_clears_import_error(self, testing_dag_bundle, session):
# Pre-condition: there is an import error for the dag file
bundle_name = "testing"
Expand Down Expand Up @@ -577,6 +580,7 @@ def test_remove_error_clears_import_error(self, testing_dag_bundle, session):

assert import_errors == {("def.py", bundle_name)}

@pytest.mark.usefixtures("clean_db")
def test_remove_error_updates_loaded_dag_model(self, testing_dag_bundle, session):
bundle_name = "testing"
filename = "abc.py"
Expand Down
28 changes: 19 additions & 9 deletions airflow-core/tests/unit/jobs/test_scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
)
from tests_common.test_utils.mock_executor import MockExecutor
from tests_common.test_utils.mock_operators import CustomOperator
from tests_common.test_utils.version_compat import SQLALCHEMY_V_1_4, SQLALCHEMY_V_2_0
from unit.listeners import dag_listener
from unit.listeners.test_listeners import get_listener_manager
from unit.models import TEST_DAGS_FOLDER
Expand Down Expand Up @@ -3354,7 +3355,7 @@ def test_verify_integrity_if_dag_changed(self, dag_maker):

# Now let's say the DAG got updated (new task got added)
BashOperator(task_id="bash_task_1", dag=dag, bash_command="echo hi")
SerializedDagModel.write_dag(dag=dag, bundle_name="testing")
SerializedDagModel.write_dag(dag=dag, bundle_name="testing", session=session)

dag_version_2 = DagVersion.get_latest_version(dr.dag_id, session=session)
assert dag_version_2 != dag_version_1
Expand All @@ -3368,15 +3369,24 @@ def test_verify_integrity_if_dag_changed(self, dag_maker):
assert dr.dag_versions[-1].id == dag_version_2.id
assert len(self.job_runner.scheduler_dag_bag.get_dag(dr, session).tasks) == 2

tis_count = (
session.query(func.count(TaskInstance.task_id))
.filter(
TaskInstance.dag_id == dr.dag_id,
TaskInstance.logical_date == dr.logical_date,
TaskInstance.state == State.SCHEDULED,
if SQLALCHEMY_V_1_4:
tis_count = (
session.query(func.count(TaskInstance.task_id))
.filter(
TaskInstance.dag_id == dr.dag_id,
TaskInstance.logical_date == dr.logical_date,
TaskInstance.state == State.SCHEDULED,
)
.scalar()
)
if SQLALCHEMY_V_2_0:
tis_count = session.scalar(
select(func.count(TaskInstance.task_id)).where(
TaskInstance.dag_id == dr.dag_id,
TaskInstance.logical_date == dr.logical_date,
TaskInstance.state == State.SCHEDULED,
)
)
.scalar()
)
assert tis_count == 2

latest_dag_version = DagVersion.get_latest_version(dr.dag_id, session=session)
Expand Down
12 changes: 12 additions & 0 deletions devel-common/src/tests_common/test_utils/version_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,15 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
AIRFLOW_V_3_0_1 = get_base_airflow_version_tuple() == (3, 0, 1)
AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
AIRFLOW_V_3_1_PLUS = get_base_airflow_version_tuple() >= (3, 1, 0)


def get_sqlalchemy_version_tuple() -> tuple[int, int, int]:
import sqlalchemy
from packaging.version import Version

sqlalchemy_version = Version(sqlalchemy.__version__)
return sqlalchemy_version.major, sqlalchemy_version.minor, sqlalchemy_version.micro


SQLALCHEMY_V_1_4 = (1, 4, 0) <= get_sqlalchemy_version_tuple() < (2, 0, 0)
SQLALCHEMY_V_2_0 = (2, 0, 0) <= get_sqlalchemy_version_tuple() < (2, 1, 0)