Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions airflow-core/src/airflow/api/common/trigger_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def _trigger_dag(
:return: list of triggered dags
"""
dag = dag_bag.get_dag(dag_id, session=session) # prefetch dag if it is stored serialized

if dag is None or dag_id not in dag_bag.dags:
dag_info = dag_id, None # Using None since we don't have bundle_version information
if dag is None or dag_info not in dag_bag.dags:
raise DagNotFound(f"Dag id {dag_id} not found")

run_after = run_after or timezone.coerce_datetime(timezone.utcnow())
Expand Down
4 changes: 2 additions & 2 deletions airflow-core/src/airflow/cli/commands/dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ def dag_report(args) -> None:
if bundle.name not in bundles_to_reserialize:
continue
bundle.initialize()
dagbag = DagBag(bundle.path, include_examples=False)
dagbag = DagBag(bundle.path, include_examples=False, bundle_version=bundle.version)
all_dagbag_stats.extend(dagbag.dagbag_stats)

AirflowConsole().print_as(
Expand Down Expand Up @@ -629,7 +629,7 @@ def _parse_and_get_dag(dag_id: str) -> DAG | None:
bag = DagBag(
dag_folder=dag_absolute_path, include_examples=False, safe_mode=False, load_op_links=False
)
return bag.dags.get(dag_id)
return bag.dags.get((dag_id, bundle.version))


@cli_utils.action_cli
Expand Down
1 change: 1 addition & 0 deletions airflow-core/src/airflow/dag_processing/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,7 @@ def _create_process(self, dag_file: DagFileInfo) -> DagFileProcessorProcess:
id=id,
path=dag_file.absolute_path,
bundle_path=cast("Path", dag_file.bundle_path),
bundle_version=dag_file.bundle_version,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the purpose of this? I thought that with DAG file processing we always track the latest version and intentionally don't specify it here — and I'm not even sure where this argument goes in the interface.

Copy link
Contributor Author

@ephraimbuddy ephraimbuddy Apr 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is needed so that we specifically store dag_id and the version in dagbag.dags.

If you have two dag runs running with different numbers of task instances, you will notice that when the Scheduler does dag_run.update_state, some tasks will be marked as removed because the DAG retrieved from the dagbag has fewer tasks.

callbacks=callback_to_execute_for_file,
selector=self.selector,
logger=logger,
Expand Down
9 changes: 7 additions & 2 deletions airflow-core/src/airflow/dag_processing/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def _parse_file(msg: DagFileParseRequest, log: FilteringBoundLogger) -> DagFileP
bag = DagBag(
dag_folder=msg.file,
bundle_path=msg.bundle_path,
bundle_version=msg.bundle_version,
include_examples=False,
safe_mode=True,
load_op_links=False,
Expand Down Expand Up @@ -149,7 +150,7 @@ def _execute_callbacks(


def _execute_dag_callbacks(dagbag: DagBag, request: DagCallbackRequest, log: FilteringBoundLogger) -> None:
dag = dagbag.dags[request.dag_id]
dag = dagbag.dags[(request.dag_id, request.bundle_version)]

callbacks = dag.on_failure_callback if request.is_failure_callback else dag.on_success_callback
if not callbacks:
Expand Down Expand Up @@ -189,6 +190,7 @@ class DagFileParseRequest(BaseModel):
requests_fd: int
callback_requests: list[CallbackRequest] = Field(default_factory=list)
type: Literal["DagFileParseRequest"] = "DagFileParseRequest"
bundle_version: str | None = None


class DagFileParsingResult(BaseModel):
Expand Down Expand Up @@ -237,21 +239,24 @@ def start( # type: ignore[override]
bundle_path: Path,
callbacks: list[CallbackRequest],
target: Callable[[], None] = _parse_file_entrypoint,
bundle_version: str | None = None,
**kwargs,
) -> Self:
proc: Self = super().start(target=target, **kwargs)
proc._on_child_started(callbacks, path, bundle_path)
proc._on_child_started(callbacks, path, bundle_path, bundle_version=bundle_version)
return proc

def _on_child_started(
self,
callbacks: list[CallbackRequest],
path: str | os.PathLike[str],
bundle_path: Path,
bundle_version: str | None = None,
) -> None:
msg = DagFileParseRequest(
file=os.fspath(path),
bundle_path=bundle_path,
bundle_version=bundle_version,
requests_fd=self._requests_fd,
callback_requests=callbacks,
)
Expand Down
39 changes: 25 additions & 14 deletions airflow-core/src/airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from datetime import date, timedelta
from functools import lru_cache, partial
from itertools import groupby
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any, Protocol

from sqlalchemy import and_, delete, exists, func, select, text, tuple_, update
from sqlalchemy.exc import OperationalError
Expand Down Expand Up @@ -141,6 +141,12 @@ def _is_parent_process() -> bool:
return multiprocessing.current_process().name == "MainProcess"


class GetDag(Protocol):
    """Typing for cached_get_dag.

    Structural type for the ``lru_cache``-wrapped partial of
    ``self.dagbag.get_dag`` built in the scheduler: a callable that looks up
    a serialized DAG by ``dag_id``, optionally pinned to a specific bundle
    version, and returns ``None`` when no matching DAG is found.
    """

    def __call__(self, dag_id: str, bundle_version: str | None = None) -> DAG | None: ...


class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
"""
SchedulerJobRunner runs for a specific time interval and schedules jobs that are ready to run.
Expand Down Expand Up @@ -489,7 +495,9 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -
if task_instance.dag_model.has_task_concurrency_limits:
# Many dags don't have a task_concurrency, so where we can avoid loading the full
# serialized DAG the better.
serialized_dag = self.dagbag.get_dag(dag_id, session=session)
serialized_dag = self.dagbag.get_dag(
dag_id, session=session, bundle_version=task_instance.dag_run.bundle_version
)
# If the dag is missing, fail the task and continue to the next task.
if not serialized_dag:
self.log.error(
Expand Down Expand Up @@ -773,6 +781,7 @@ def process_executor_events(
.where(filter_for_tis)
.options(selectinload(TI.dag_model))
.options(joinedload(TI.dag_version))
.options(joinedload(TI.dag_run))
)
# row lock this entire set of taskinstances to make sure the scheduler doesn't fail when we have
# multi-schedulers
Expand Down Expand Up @@ -867,7 +876,7 @@ def process_executor_events(

# Get task from the Serialized DAG
try:
dag = dag_bag.get_dag(ti.dag_id)
dag = dag_bag.get_dag(ti.dag_id, bundle_version=ti.dag_run.bundle_version)
task = dag.get_task(ti.task_id)
except Exception:
cls.logger().exception("Marking task instance %s as %s", ti, state)
Expand Down Expand Up @@ -985,7 +994,9 @@ def _update_dag_run_state_for_paused_dags(self, session: Session = NEW_SESSION)
.group_by(DagRun)
)
for dag_run in paused_runs:
dag = self.dagbag.get_dag(dag_run.dag_id, session=session)
dag = self.dagbag.get_dag(
dag_run.dag_id, session=session, bundle_version=dag_run.bundle_version
)
if dag is not None:
dag_run.dag = dag
_, callback_to_run = dag_run.update_state(execute_callbacks=False, session=session)
Expand Down Expand Up @@ -1336,11 +1347,9 @@ def _do_scheduling(self, session: Session) -> int:

# Send the callbacks after we commit to ensure the context is up to date when it gets run
# cache saves time during scheduling of many dag_runs for same dag
cached_get_dag: Callable[[str], DAG | None] = lru_cache()(
partial(self.dagbag.get_dag, session=session)
)
cached_get_dag: GetDag = lru_cache()(partial(self.dagbag.get_dag, session=session))
for dag_run, callback_to_run in callback_tuples:
dag = cached_get_dag(dag_run.dag_id)
dag = cached_get_dag(dag_run.dag_id, bundle_version=dag_run.bundle_version)
if dag:
# Sending callbacks to the database, so it must be done outside of prohibit_commit.
self._send_dag_callbacks_to_processor(dag, callback_to_run)
Expand Down Expand Up @@ -1677,17 +1686,15 @@ def _update_state(dag: DAG, dag_run: DagRun):
)

# cache saves time during scheduling of many dag_runs for same dag
cached_get_dag: Callable[[str], DAG | None] = lru_cache()(
partial(self.dagbag.get_dag, session=session)
)
cached_get_dag: GetDag = lru_cache()(partial(self.dagbag.get_dag, session=session))

span = Trace.get_current_span()
for dag_run in dag_runs:
dag_id = dag_run.dag_id
run_id = dag_run.run_id
backfill_id = dag_run.backfill_id
backfill = dag_run.backfill
dag = dag_run.dag = cached_get_dag(dag_id)
dag = dag_run.dag = cached_get_dag(dag_id, bundle_version=dag_run.bundle_version)
if not dag:
self.log.error("DAG '%s' not found in serialized_dag table", dag_run.dag_id)
continue
Expand Down Expand Up @@ -1768,7 +1775,9 @@ def _schedule_dag_run(
)
callback: DagCallbackRequest | None = None

dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session)
dag = dag_run.dag = self.dagbag.get_dag(
dag_run.dag_id, session=session, bundle_version=dag_run.bundle_version
)
dag_model = DM.get_dagmodel(dag_run.dag_id, session)

if not dag or not dag_model:
Expand Down Expand Up @@ -1882,7 +1891,9 @@ def _verify_integrity_if_dag_changed(self, dag_run: DagRun, session: Session) ->
self.log.debug("DAG %s not changed structure, skipping dagrun.verify_integrity", dag_run.dag_id)
return True
# Refresh the DAG
dag_run.dag = self.dagbag.get_dag(dag_id=dag_run.dag_id, session=session)
dag_run.dag = self.dagbag.get_dag(
dag_id=dag_run.dag_id, session=session, bundle_version=dag_run.bundle_version
)
if not dag_run.dag:
return False
# Select all TIs in State.unfinished and update the dag_version_id
Expand Down
4 changes: 4 additions & 0 deletions airflow-core/src/airflow/models/dag_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,19 +140,23 @@ def get_version(
dag_id: str,
version_number: int | None = None,
*,
bundle_version: str | None = None,
session: Session = NEW_SESSION,
) -> DagVersion | None:
"""
Get the version of the DAG.

:param dag_id: The DAG ID.
:param version_number: The version number.
:param bundle_version: The bundle version.
:param session: The database session.
:return: The version of the DAG or None if not found.
"""
version_select_obj = select(cls).where(cls.dag_id == dag_id)
if version_number:
version_select_obj = version_select_obj.where(cls.version_number == version_number)
if bundle_version:
version_select_obj = version_select_obj.where(cls.bundle_version == bundle_version)

return session.scalar(version_select_obj.order_by(cls.id.desc()).limit(1))

Expand Down
Loading