2 changes: 1 addition & 1 deletion airflow-core/docs/img/airflow_erd.sha256
@@ -1 +1 @@
d79700d79f51a3f70709131183df2e80e6be0f0e73ffdbcc21731890a0a030fd
da000ad784f974dad63f6db08942d8e968242380f468bc43e35de5634960dcfc
3,757 changes: 1,884 additions & 1,873 deletions airflow-core/docs/img/airflow_erd.svg
95 changes: 78 additions & 17 deletions airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -58,9 +58,9 @@
from airflow.models.backfill import Backfill
from airflow.models.dag import DAG, DagModel
from airflow.models.dag_version import DagVersion
from airflow.models.dagbag import DagBag
from airflow.models.dagrun import DagRun
from airflow.models.dagwarning import DagWarning, DagWarningType
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import TaskInstance
from airflow.models.trigger import TRIGGER_FAIL_REPR, TriggerFailureReason
from airflow.stats import Stats
@@ -102,6 +102,55 @@
""":meta private:"""


class SchedulerDagBag:
    """
    Internal class for retrieving and caching dags in the scheduler.

    :meta private:
    """

    def __init__(self):
        self._dags: dict[str, DAG] = {}  # dag_version_id to dag

    def _get_dag(self, version_id: str, session: Session) -> DAG | None:
        if dag := self._dags.get(version_id):
            return dag
        dag_version = session.get(DagVersion, version_id, options=[joinedload(DagVersion.serialized_dag)])
        if not dag_version:
            return None
        serdag = dag_version.serialized_dag
        if not serdag:
            return None
        serdag.load_op_links = False
        dag = serdag.dag
        if not dag:
            return None
        self._dags[version_id] = dag
        return dag

    @staticmethod
    def _version_from_dag_run(dag_run, session):
        if dag_run.bundle_version:
            dag_version = dag_run.created_dag_version
        else:
            dag_version = DagVersion.get_latest_version(dag_id=dag_run.dag_id, session=session)
        return dag_version

    def get_dag(self, dag_run: DagRun, session: Session) -> DAG | None:
        version = self._version_from_dag_run(dag_run=dag_run, session=session)
        if not version:
            return None
        return self._get_dag(version_id=version.id, session=session)


def _get_current_dag(dag_id: str, session: Session) -> DAG | None:
    serdag = SerializedDagModel.get(dag_id=dag_id, session=session)  # grabs the latest version
    if not serdag:
        return None
    serdag.load_op_links = False
    return serdag.dag


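To make the caching contract above concrete, here is a small self-contained sketch (stub objects only, not the Airflow classes, and not part of the diff): DAGs are cached per dag_version_id, so repeated lookups that resolve to the same version skip re-deserializing the serialized DAG.

    # Illustrative stand-ins; the real lookups go through DagVersion / SerializedDagModel.
    class StubSerDag:
        def __init__(self, dag):
            self.dag = dag
            self.loads = 0

    class StubBag:
        def __init__(self, store):
            self._store = store  # version_id -> StubSerDag
            self._dags = {}      # version_id -> deserialized DAG (the cache)

        def get_dag(self, version_id):
            if dag := self._dags.get(version_id):
                return dag
            serdag = self._store.get(version_id)
            if not serdag:
                return None
            serdag.loads += 1  # stands in for the expensive deserialization
            self._dags[version_id] = serdag.dag
            return serdag.dag

    store = {"v1": StubSerDag(dag="dag-object")}
    bag = StubBag(store)
    assert bag.get_dag("v1") == "dag-object"
    assert bag.get_dag("v1") == "dag-object"
    assert store["v1"].loads == 1  # second call hit the cache
    assert bag.get_dag("missing") is None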
class ConcurrencyMap:
"""
Dataclass to represent concurrency maps.
@@ -199,7 +248,7 @@ def __init__(
if log:
self._log = log

self.dagbag = DagBag(read_dags_from_db=True, load_op_links=False)
self.scheduler_dag_bag = SchedulerDagBag()

@provide_session
def heartbeat_callback(self, session: Session = NEW_SESSION) -> None:
@@ -490,7 +539,9 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -
if task_instance.dag_model.has_task_concurrency_limits:
# Many dags don't have a task_concurrency, so where we can avoid loading the full
# serialized DAG the better.
serialized_dag = self.dagbag.get_dag(dag_id, session=session)
serialized_dag = self.scheduler_dag_bag.get_dag(
dag_run=task_instance.dag_run, session=session
)
# If the dag is missing, fail the task and continue to the next task.
if not serialized_dag:
self.log.error(
@@ -730,12 +781,15 @@ def _process_task_event_logs(log_records: deque[Log], session: Session):

def _process_executor_events(self, executor: BaseExecutor, session: Session) -> int:
return SchedulerJobRunner.process_executor_events(
executor=executor, dag_bag=self.dagbag, job_id=self.job.id, session=session
executor=executor,
job_id=self.job.id,
scheduler_dag_bag=self.scheduler_dag_bag,
session=session,
)

@classmethod
def process_executor_events(
cls, executor: BaseExecutor, dag_bag: DagBag, job_id: str | None, session: Session
cls, executor: BaseExecutor, job_id: str | None, scheduler_dag_bag: SchedulerDagBag, session: Session
) -> int:
"""
Respond to executor events.
@@ -867,7 +921,14 @@ def process_executor_events(

# Get task from the Serialized DAG
try:
dag = dag_bag.get_dag(ti.dag_id)
dag = scheduler_dag_bag.get_dag(dag_run=ti.dag_run, session=session)
if not dag:
    cls.logger().error(
        "DAG '%s' for task instance %s not found in serialized_dag table",
        ti.dag_id,
        ti,
    )
if TYPE_CHECKING:
    assert dag
task = dag.get_task(ti.task_id)
except Exception:
cls.logger().exception("Marking task instance %s as %s", ti, state)
@@ -985,7 +1046,7 @@ def _update_dag_run_state_for_paused_dags(self, session: Session = NEW_SESSION)
.group_by(DagRun)
)
for dag_run in paused_runs:
dag = self.dagbag.get_dag(dag_run.dag_id, session=session)
dag = self.scheduler_dag_bag.get_dag(dag_run=dag_run, session=session)
if dag is not None:
dag_run.dag = dag
_, callback_to_run = dag_run.update_state(execute_callbacks=False, session=session)
@@ -1336,11 +1397,11 @@ def _do_scheduling(self, session: Session) -> int:

# Send the callbacks after we commit to ensure the context is up to date when it gets run
# cache saves time during scheduling of many dag_runs for same dag
cached_get_dag: Callable[[str], DAG | None] = lru_cache()(
partial(self.dagbag.get_dag, session=session)
cached_get_dag: Callable[[DagRun], DAG | None] = lru_cache()(
partial(self.scheduler_dag_bag.get_dag, session=session)
)
for dag_run, callback_to_run in callback_tuples:
dag = cached_get_dag(dag_run.dag_id)
dag = cached_get_dag(dag_run)
if dag:
# Sending callbacks to the database, so it must be done outside of prohibit_commit.
self._send_dag_callbacks_to_processor(dag, callback_to_run)
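The cache key changes here from dag_id to the DagRun itself, so two runs of the same DAG can resolve to different versions within a single scheduling pass. A minimal, self-contained sketch of the lru_cache-over-partial pattern with stub objects (illustrative only, not part of the diff; it assumes the cached argument is hashable, which SQLAlchemy ORM instances are by default):

    from functools import lru_cache, partial

    class StubDagBag:
        """Stand-in for SchedulerDagBag; counts how often get_dag is really invoked."""

        def __init__(self):
            self.calls = 0

        def get_dag(self, dag_run, session=None):
            self.calls += 1
            return f"dag-for-{dag_run}"

    bag = StubDagBag()
    cached_get_dag = lru_cache()(partial(bag.get_dag, session=None))

    run = "run-1"  # stands in for a DagRun object; it only needs to be hashable
    assert cached_get_dag(run) == "dag-for-run-1"
    assert cached_get_dag(run) == "dag-for-run-1"
    assert bag.calls == 1  # the second lookup was served from the cache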
@@ -1457,7 +1518,7 @@ def _create_dag_runs(self, dag_models: Collection[DagModel], session: Session) -
)

for dag_model in dag_models:
dag = self.dagbag.get_dag(dag_model.dag_id, session=session)
dag = _get_current_dag(dag_id=dag_model.dag_id, session=session)
if not dag:
self.log.error("DAG '%s' not found in serialized_dag table", dag_model.dag_id)
continue
@@ -1520,7 +1581,7 @@ def _create_dag_runs_asset_triggered(
}

for dag_model in dag_models:
dag = self.dagbag.get_dag(dag_model.dag_id, session=session)
dag = _get_current_dag(dag_id=dag_model.dag_id, session=session)
if not dag:
self.log.error("DAG '%s' not found in serialized_dag table", dag_model.dag_id)
continue
@@ -1671,8 +1732,8 @@ def _update_state(dag: DAG, dag_run: DagRun):
)

# cache saves time during scheduling of many dag_runs for same dag
cached_get_dag: Callable[[str], DAG | None] = lru_cache()(
partial(self.dagbag.get_dag, session=session)
cached_get_dag: Callable[[DagRun], DAG | None] = lru_cache()(
partial(self.scheduler_dag_bag.get_dag, session=session)
)

span = Trace.get_current_span()
@@ -1681,7 +1742,7 @@
run_id = dag_run.run_id
backfill_id = dag_run.backfill_id
backfill = dag_run.backfill
dag = dag_run.dag = cached_get_dag(dag_id)
dag = dag_run.dag = cached_get_dag(dag_run)
if not dag:
self.log.error("DAG '%s' not found in serialized_dag table", dag_run.dag_id)
continue
@@ -1762,7 +1823,7 @@ def _schedule_dag_run(
)
callback: DagCallbackRequest | None = None

dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session)
dag = dag_run.dag = self.scheduler_dag_bag.get_dag(dag_run=dag_run, session=session)
dag_model = DM.get_dagmodel(dag_run.dag_id, session)

if not dag or not dag_model:
@@ -1876,7 +1937,7 @@ def _verify_integrity_if_dag_changed(self, dag_run: DagRun, session: Session) ->
self.log.debug("DAG %s not changed structure, skipping dagrun.verify_integrity", dag_run.dag_id)
return True
# Refresh the DAG
dag_run.dag = self.dagbag.get_dag(dag_id=dag_run.dag_id, session=session)
dag_run.dag = self.scheduler_dag_bag.get_dag(dag_run=dag_run, session=session)
if not dag_run.dag:
return False
# Select all TIs in State.unfinished and update the dag_version_id
Alembic migration for the dag_run table (file name not shown in this capture)
@@ -284,6 +284,14 @@ def upgrade():

with op.batch_alter_table("dag_run", schema=None) as batch_op:
batch_op.drop_column("dag_hash")
batch_op.add_column(sa.Column("created_dag_version_id", UUIDType(binary=False), nullable=True))
batch_op.create_foreign_key(
"created_dag_version_id_fkey",
"dag_version",
["created_dag_version_id"],
["id"],
ondelete="SET NULL",
)


def downgrade():
@@ -375,6 +383,8 @@ def downgrade():

with op.batch_alter_table("dag_run", schema=None) as batch_op:
batch_op.add_column(sa.Column("dag_hash", sa.String(length=32), autoincrement=False, nullable=True))
batch_op.drop_constraint("created_dag_version_id_fkey", type_="foreignkey")
batch_op.drop_column("created_dag_version_id")

# Update dag_run dag_hash with dag_hash from serialized_dag where dag_id matches
if conn.dialect.name == "mysql":
7 changes: 4 additions & 3 deletions airflow-core/src/airflow/models/dag.py
@@ -266,6 +266,7 @@ def _create_orm_dagrun(
bundle_version = session.scalar(
select(DagModel.bundle_version).where(DagModel.dag_id == dag.dag_id),
)
dag_version = DagVersion.get_latest_version(dag.dag_id, session=session)
run = DagRun(
dag_id=dag.dag_id,
run_id=run_id,
@@ -283,13 +284,13 @@
)
# Load defaults into the following two fields to ensure result can be serialized detached
run.log_template_id = int(session.scalar(select(func.max(LogTemplate.__table__.c.id))))
run.created_dag_version = dag_version
run.consumed_asset_events = []
session.add(run)
session.flush()
run.dag = dag
# create the associated task instances
# state is None at the moment of creation
dag_version = DagVersion.get_latest_version(dag.dag_id, session=session)
run.verify_integrity(session=session, dag_version_id=dag_version.id if dag_version else None)
return run

@@ -1771,10 +1772,10 @@ def add_logger_if_needed(ti: TaskInstance):
self.log.exception("Task failed; ti=%s", ti)
if use_executor:
executor.heartbeat()
from airflow.jobs.scheduler_job_runner import SchedulerJobRunner
from airflow.jobs.scheduler_job_runner import SchedulerDagBag, SchedulerJobRunner

SchedulerJobRunner.process_executor_events(
executor=executor, dag_bag=dag_bag, job_id=None, session=session
executor=executor, job_id=None, scheduler_dag_bag=SchedulerDagBag(), session=session
)
if use_executor:
executor.end()
22 changes: 21 additions & 1 deletion airflow-core/src/airflow/models/dagrun.py
@@ -58,6 +58,7 @@
from sqlalchemy.orm import declared_attr, joinedload, relationship, synonym, validates
from sqlalchemy.sql.expression import case, false, select
from sqlalchemy.sql.functions import coalesce
from sqlalchemy_utils import UUIDType

from airflow.callbacks.callback_requests import DagCallbackRequest
from airflow.configuration import conf as airflow_conf
@@ -93,7 +94,6 @@

from opentelemetry.sdk.trace import Span
from sqlalchemy.orm import Query, Session
from sqlalchemy_utils import UUIDType

from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DAG
@@ -189,6 +189,15 @@ class DagRun(Base, LoggingMixin):
# Span context carrier, used for context propagation.
context_carrier = Column(MutableDict.as_mutable(ExtendedJSON))
span_status = Column(String(250), server_default=SpanStatus.NOT_STARTED, nullable=False)
created_dag_version_id = Column(
UUIDType(binary=False),
ForeignKey("dag_version.id", name="created_dag_version_id_fkey", ondelete="set null"),
nullable=True,
)
"""The id of the dag version column that was in effect at dag run creation time.

:meta private:
"""

# Remove this `if` after upgrading Sphinx-AutoAPI
if not TYPE_CHECKING and "BUILDING_AIRFLOW_DOCS" in os.environ:
@@ -244,6 +253,14 @@ class DagRun(Base, LoggingMixin):
uselist=False,
cascade="all, delete, delete-orphan",
)

created_dag_version = relationship("DagVersion", uselist=False, passive_deletes=True)
"""
The dag version that was active when the dag run was created, if available.

:meta private:
"""

backfill = relationship(Backfill, uselist=False)
backfill_max_active_runs = association_proxy("backfill", "max_active_runs")
max_active_runs = association_proxy("dag_model", "max_active_runs")
@@ -329,6 +346,9 @@ def validate_run_id(self, key: str, run_id: str) -> str | None:
@property
def dag_versions(self) -> list[DagVersion]:
"""Return the DAG versions associated with the TIs of this DagRun."""
# when the dag is in a versioned bundle, we keep the dag version fixed
if self.bundle_version:
return [self.created_dag_version]
dag_versions = [
dv
for dv in dict.fromkeys(list(self._tih_dag_versions) + list(self._ti_dag_versions))
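The new branch above pins a run created from a versioned bundle to the DAG version captured at creation time, while other runs keep deriving their versions from task instances. A simplified, self-contained illustration of that rule (a stand-in dataclass, not the ORM model, and not part of the diff):

    from __future__ import annotations

    from dataclasses import dataclass, field

    @dataclass
    class StubDagRun:
        bundle_version: str | None = None
        created_dag_version: str | None = None
        ti_versions: list[str] = field(default_factory=list)

        @property
        def dag_versions(self) -> list[str | None]:
            # Versioned-bundle runs stay pinned to the version recorded at creation;
            # everything else keeps following the task instances' versions.
            if self.bundle_version:
                return [self.created_dag_version]
            return self.ti_versions

    pinned = StubDagRun(bundle_version="bundle-v1", created_dag_version="dv-1", ti_versions=["dv-2"])
    assert pinned.dag_versions == ["dv-1"]

    unpinned = StubDagRun(ti_versions=["dv-2"])
    assert unpinned.dag_versions == ["dv-2"]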
6 changes: 3 additions & 3 deletions airflow-core/tests/unit/api_fastapi/common/test_exceptions.py
@@ -186,23 +186,23 @@ def test_handle_single_column_unique_constraint_error(self, session, table, expe
status_code=status.HTTP_409_CONFLICT,
detail={
"reason": "Unique constraint violation",
"statement": "INSERT INTO dag_run (dag_id, queued_at, logical_date, start_date, end_date, state, run_id, creating_job_id, run_type, triggered_by, conf, data_interval_start, data_interval_end, run_after, last_scheduling_decision, log_template_id, updated_at, clear_number, backfill_id, bundle_version, scheduled_by_job_id, context_carrier) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, (SELECT max(log_template.id) AS max_1 \nFROM log_template), ?, ?, ?, ?, ?, ?)",
"statement": "INSERT INTO dag_run (dag_id, queued_at, logical_date, start_date, end_date, state, run_id, creating_job_id, run_type, triggered_by, conf, data_interval_start, data_interval_end, run_after, last_scheduling_decision, log_template_id, updated_at, clear_number, backfill_id, bundle_version, scheduled_by_job_id, context_carrier, created_dag_version_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, (SELECT max(log_template.id) AS max_1 \nFROM log_template), ?, ?, ?, ?, ?, ?, ?)",
"orig_error": "UNIQUE constraint failed: dag_run.dag_id, dag_run.run_id",
},
),
HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail={
"reason": "Unique constraint violation",
"statement": "INSERT INTO dag_run (dag_id, queued_at, logical_date, start_date, end_date, state, run_id, creating_job_id, run_type, triggered_by, conf, data_interval_start, data_interval_end, run_after, last_scheduling_decision, log_template_id, updated_at, clear_number, backfill_id, bundle_version, scheduled_by_job_id, context_carrier) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, (SELECT max(log_template.id) AS max_1 \nFROM log_template), %s, %s, %s, %s, %s, %s)",
"statement": "INSERT INTO dag_run (dag_id, queued_at, logical_date, start_date, end_date, state, run_id, creating_job_id, run_type, triggered_by, conf, data_interval_start, data_interval_end, run_after, last_scheduling_decision, log_template_id, updated_at, clear_number, backfill_id, bundle_version, scheduled_by_job_id, context_carrier, created_dag_version_id) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, (SELECT max(log_template.id) AS max_1 \nFROM log_template), %s, %s, %s, %s, %s, %s, %s)",
"orig_error": "(1062, \"Duplicate entry 'test_dag_id-test_run_id' for key 'dag_run.dag_run_dag_id_run_id_key'\")",
},
),
HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail={
"reason": "Unique constraint violation",
"statement": "INSERT INTO dag_run (dag_id, queued_at, logical_date, start_date, end_date, state, run_id, creating_job_id, run_type, triggered_by, conf, data_interval_start, data_interval_end, run_after, last_scheduling_decision, log_template_id, updated_at, clear_number, backfill_id, bundle_version, scheduled_by_job_id, context_carrier) VALUES (%(dag_id)s, %(queued_at)s, %(logical_date)s, %(start_date)s, %(end_date)s, %(state)s, %(run_id)s, %(creating_job_id)s, %(run_type)s, %(triggered_by)s, %(conf)s, %(data_interval_start)s, %(data_interval_end)s, %(run_after)s, %(last_scheduling_decision)s, (SELECT max(log_template.id) AS max_1 \nFROM log_template), %(updated_at)s, %(clear_number)s, %(backfill_id)s, %(bundle_version)s, %(scheduled_by_job_id)s, %(context_carrier)s) RETURNING dag_run.id",
"statement": "INSERT INTO dag_run (dag_id, queued_at, logical_date, start_date, end_date, state, run_id, creating_job_id, run_type, triggered_by, conf, data_interval_start, data_interval_end, run_after, last_scheduling_decision, log_template_id, updated_at, clear_number, backfill_id, bundle_version, scheduled_by_job_id, context_carrier, created_dag_version_id) VALUES (%(dag_id)s, %(queued_at)s, %(logical_date)s, %(start_date)s, %(end_date)s, %(state)s, %(run_id)s, %(creating_job_id)s, %(run_type)s, %(triggered_by)s, %(conf)s, %(data_interval_start)s, %(data_interval_end)s, %(run_after)s, %(last_scheduling_decision)s, (SELECT max(log_template.id) AS max_1 \nFROM log_template), %(updated_at)s, %(clear_number)s, %(backfill_id)s, %(bundle_version)s, %(scheduled_by_job_id)s, %(context_carrier)s, %(created_dag_version_id)s) RETURNING dag_run.id",
"orig_error": 'duplicate key value violates unique constraint "dag_run_dag_id_run_id_key"\nDETAIL: Key (dag_id, run_id)=(test_dag_id, test_run_id) already exists.\n',
},
),