diff --git a/airflow-core/src/airflow/api/common/trigger_dag.py b/airflow-core/src/airflow/api/common/trigger_dag.py
index 8378876f6760c..971d2665796e8 100644
--- a/airflow-core/src/airflow/api/common/trigger_dag.py
+++ b/airflow-core/src/airflow/api/common/trigger_dag.py
@@ -63,8 +63,8 @@ def _trigger_dag(
     :return: list of triggered dags
     """
     dag = dag_bag.get_dag(dag_id, session=session)  # prefetch dag if it is stored serialized
-
-    if dag is None or dag_id not in dag_bag.dags:
+    dag_info = (dag_id, None)  # Using None since we don't have bundle_version information
+    if dag is None or dag_info not in dag_bag.dags:
         raise DagNotFound(f"Dag id {dag_id} not found")
 
     run_after = run_after or timezone.coerce_datetime(timezone.utcnow())
diff --git a/airflow-core/src/airflow/cli/commands/dag_command.py b/airflow-core/src/airflow/cli/commands/dag_command.py
index 44e14352ad185..8b9c8c85d5636 100644
--- a/airflow-core/src/airflow/cli/commands/dag_command.py
+++ b/airflow-core/src/airflow/cli/commands/dag_command.py
@@ -523,7 +523,7 @@ def dag_report(args) -> None:
         if bundle.name not in bundles_to_reserialize:
             continue
         bundle.initialize()
-        dagbag = DagBag(bundle.path, include_examples=False)
+        dagbag = DagBag(bundle.path, include_examples=False, bundle_version=bundle.version)
         all_dagbag_stats.extend(dagbag.dagbag_stats)
 
     AirflowConsole().print_as(
@@ -629,7 +629,7 @@ def _parse_and_get_dag(dag_id: str) -> DAG | None:
         bag = DagBag(
             dag_folder=dag_absolute_path, include_examples=False, safe_mode=False, load_op_links=False
         )
-        return bag.dags.get(dag_id)
+        return bag.dags.get((dag_id, bundle.version))
 
 
 @cli_utils.action_cli
diff --git a/airflow-core/src/airflow/dag_processing/manager.py b/airflow-core/src/airflow/dag_processing/manager.py
index 0f5c521df129f..cc37a4b95905a 100644
--- a/airflow-core/src/airflow/dag_processing/manager.py
+++ b/airflow-core/src/airflow/dag_processing/manager.py
@@ -871,6 +871,7 @@ def _create_process(self, dag_file: DagFileInfo) -> DagFileProcessorProcess:
             id=id,
             path=dag_file.absolute_path,
             bundle_path=cast("Path", dag_file.bundle_path),
+            bundle_version=dag_file.bundle_version,
             callbacks=callback_to_execute_for_file,
             selector=self.selector,
             logger=logger,
diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py
index 489b916240eb4..5bd069d00c9cf 100644
--- a/airflow-core/src/airflow/dag_processing/processor.py
+++ b/airflow-core/src/airflow/dag_processing/processor.py
@@ -94,6 +94,7 @@ def _parse_file(msg: DagFileParseRequest, log: FilteringBoundLogger) -> DagFileP
     bag = DagBag(
         dag_folder=msg.file,
         bundle_path=msg.bundle_path,
+        bundle_version=msg.bundle_version,
         include_examples=False,
         safe_mode=True,
         load_op_links=False,
@@ -149,7 +150,7 @@ def _execute_callbacks(
 
 
 def _execute_dag_callbacks(dagbag: DagBag, request: DagCallbackRequest, log: FilteringBoundLogger) -> None:
-    dag = dagbag.dags[request.dag_id]
+    dag = dagbag.dags[(request.dag_id, request.bundle_version)]
 
     callbacks = dag.on_failure_callback if request.is_failure_callback else dag.on_success_callback
     if not callbacks:
@@ -189,6 +190,7 @@ class DagFileParseRequest(BaseModel):
     requests_fd: int
     callback_requests: list[CallbackRequest] = Field(default_factory=list)
     type: Literal["DagFileParseRequest"] = "DagFileParseRequest"
+    bundle_version: str | None = None
 
 
 class DagFileParsingResult(BaseModel):
@@ -237,10 +239,11 @@ def start(  # type: ignore[override]
         bundle_path: Path,
         callbacks: list[CallbackRequest],
         target: Callable[[], None] = _parse_file_entrypoint,
+        bundle_version: str | None = None,
         **kwargs,
     ) -> Self:
         proc: Self = super().start(target=target, **kwargs)
-        proc._on_child_started(callbacks, path, bundle_path)
+        proc._on_child_started(callbacks, path, bundle_path, bundle_version=bundle_version)
         return proc
 
     def _on_child_started(
@@ -248,10 +251,12 @@ def _on_child_started(
         callbacks: list[CallbackRequest],
         path: str | os.PathLike[str],
         bundle_path: Path,
+        bundle_version: str | None = None,
    ) -> None:
         msg = DagFileParseRequest(
             file=os.fspath(path),
             bundle_path=bundle_path,
+            bundle_version=bundle_version,
             requests_fd=self._requests_fd,
             callback_requests=callbacks,
         )
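The new field rides along in the existing pydantic request model from the manager to the child process. A minimal sketch of the wire format, with a made-up file path and version string:

from pathlib import Path

from airflow.dag_processing.processor import DagFileParseRequest

msg = DagFileParseRequest(
    file="/opt/airflow/dags/my_dag.py",  # hypothetical dag file
    bundle_path=Path("/opt/airflow/dags"),
    bundle_version="abc123",  # stays None for unversioned bundles
    requests_fd=4,
    callback_requests=[],
)
# Serializes with the field appended after "type", which is what the
# byte-string expectations updated in test_manager.py below check:
# ...,"type":"DagFileParseRequest","bundle_version":"abc123"}
print(msg.model_dump_json())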
diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index 43fc54386c524..dae606eb3f55d 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -30,7 +30,7 @@
 from datetime import date, timedelta
 from functools import lru_cache, partial
 from itertools import groupby
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any, Protocol
 
 from sqlalchemy import and_, delete, exists, func, select, text, tuple_, update
 from sqlalchemy.exc import OperationalError
@@ -141,6 +141,12 @@ def _is_parent_process() -> bool:
     return multiprocessing.current_process().name == "MainProcess"
 
 
+class GetDag(Protocol):
+    """Typing for cached_get_dag."""
+
+    def __call__(self, dag_id: str, bundle_version: str | None = None) -> DAG | None: ...
+
+
 class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
     """
     SchedulerJobRunner runs for a specific time interval and schedules jobs that are ready to run.
@@ -489,7 +495,9 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -
             if task_instance.dag_model.has_task_concurrency_limits:
                 # Many dags don't have a task_concurrency, so where we can avoid loading the full
                 # serialized DAG the better.
-                serialized_dag = self.dagbag.get_dag(dag_id, session=session)
+                serialized_dag = self.dagbag.get_dag(
+                    dag_id, session=session, bundle_version=task_instance.dag_run.bundle_version
+                )
                 # If the dag is missing, fail the task and continue to the next task.
                 if not serialized_dag:
                     self.log.error(
@@ -773,6 +781,7 @@ def process_executor_events(
             .where(filter_for_tis)
             .options(selectinload(TI.dag_model))
             .options(joinedload(TI.dag_version))
+            .options(joinedload(TI.dag_run))
         )
         # row lock this entire set of taskinstances to make sure the scheduler doesn't fail when we have
         # multi-schedulers
@@ -867,7 +876,7 @@ def process_executor_events(
 
                 # Get task from the Serialized DAG
                 try:
-                    dag = dag_bag.get_dag(ti.dag_id)
+                    dag = dag_bag.get_dag(ti.dag_id, bundle_version=ti.dag_run.bundle_version)
                     task = dag.get_task(ti.task_id)
                 except Exception:
                     cls.logger().exception("Marking task instance %s as %s", ti, state)
@@ -985,7 +994,9 @@ def _update_dag_run_state_for_paused_dags(self, session: Session = NEW_SESSION)
             .group_by(DagRun)
         )
         for dag_run in paused_runs:
-            dag = self.dagbag.get_dag(dag_run.dag_id, session=session)
+            dag = self.dagbag.get_dag(
+                dag_run.dag_id, session=session, bundle_version=dag_run.bundle_version
+            )
             if dag is not None:
                 dag_run.dag = dag
                 _, callback_to_run = dag_run.update_state(execute_callbacks=False, session=session)
@@ -1336,11 +1347,9 @@ def _do_scheduling(self, session: Session) -> int:
 
             # Send the callbacks after we commit to ensure the context is up to date when it gets run
             # cache saves time during scheduling of many dag_runs for same dag
-            cached_get_dag: Callable[[str], DAG | None] = lru_cache()(
-                partial(self.dagbag.get_dag, session=session)
-            )
+            cached_get_dag: GetDag = lru_cache()(partial(self.dagbag.get_dag, session=session))
             for dag_run, callback_to_run in callback_tuples:
-                dag = cached_get_dag(dag_run.dag_id)
+                dag = cached_get_dag(dag_run.dag_id, bundle_version=dag_run.bundle_version)
                 if dag:
                     # Sending callbacks to the database, so it must be done outside of prohibit_commit.
                     self._send_dag_callbacks_to_processor(dag, callback_to_run)
@@ -1677,9 +1686,7 @@ def _update_state(dag: DAG, dag_run: DagRun):
         )
 
         # cache saves time during scheduling of many dag_runs for same dag
-        cached_get_dag: Callable[[str], DAG | None] = lru_cache()(
-            partial(self.dagbag.get_dag, session=session)
-        )
+        cached_get_dag: GetDag = lru_cache()(partial(self.dagbag.get_dag, session=session))
 
         span = Trace.get_current_span()
         for dag_run in dag_runs:
@@ -1687,7 +1694,7 @@ def _update_state(dag: DAG, dag_run: DagRun):
             run_id = dag_run.run_id
             backfill_id = dag_run.backfill_id
             backfill = dag_run.backfill
-            dag = dag_run.dag = cached_get_dag(dag_id)
+            dag = dag_run.dag = cached_get_dag(dag_id, bundle_version=dag_run.bundle_version)
             if not dag:
                 self.log.error("DAG '%s' not found in serialized_dag table", dag_run.dag_id)
                 continue
@@ -1768,7 +1775,9 @@ def _schedule_dag_run(
         )
 
         callback: DagCallbackRequest | None = None
-        dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session)
+        dag = dag_run.dag = self.dagbag.get_dag(
+            dag_run.dag_id, session=session, bundle_version=dag_run.bundle_version
+        )
         dag_model = DM.get_dagmodel(dag_run.dag_id, session)
 
         if not dag or not dag_model:
@@ -1882,7 +1891,9 @@ def _verify_integrity_if_dag_changed(self, dag_run: DagRun, session: Session) ->
             self.log.debug("DAG %s not changed structure, skipping dagrun.verify_integrity", dag_run.dag_id)
             return True
         # Refresh the DAG
-        dag_run.dag = self.dagbag.get_dag(dag_id=dag_run.dag_id, session=session)
+        dag_run.dag = self.dagbag.get_dag(
+            dag_id=dag_run.dag_id, session=session, bundle_version=dag_run.bundle_version
+        )
         if not dag_run.dag:
             return False
         # Select all TIs in State.unfinished and update the dag_version_id
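The GetDag protocol replaces the old Callable[[str], DAG | None] annotation because Callable cannot express an optional keyword argument. A self-contained sketch of the pattern; _get_dag is an invented stand-in for DagBag.get_dag:

from __future__ import annotations

from functools import lru_cache, partial
from typing import Protocol


class GetDag(Protocol):
    def __call__(self, dag_id: str, bundle_version: str | None = None) -> object | None: ...


def _get_dag(dag_id: str, bundle_version: str | None = None, session: object = None) -> object | None:
    return None  # stand-in body; the real callable is DagBag.get_dag


cached_get_dag: GetDag = lru_cache()(partial(_get_dag, session=None))

# lru_cache keys on the exact call arguments, so every (dag_id, bundle_version)
# pair gets its own cache entry. Pass bundle_version as a keyword consistently:
# f("a", "b") and f("a", bundle_version="b") would be cached separately.
cached_get_dag("my_dag", bundle_version="v1")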
diff --git a/airflow-core/src/airflow/models/dag_version.py b/airflow-core/src/airflow/models/dag_version.py
index 2b6483c9d8375..b796adf7d35a9 100644
--- a/airflow-core/src/airflow/models/dag_version.py
+++ b/airflow-core/src/airflow/models/dag_version.py
@@ -140,6 +140,7 @@ def get_version(
         dag_id: str,
         version_number: int | None = None,
         *,
+        bundle_version: str | None = None,
         session: Session = NEW_SESSION,
     ) -> DagVersion | None:
         """
@@ -147,12 +148,15 @@ def get_version(
 
         :param dag_id: The DAG ID.
         :param version_number: The version number.
+        :param bundle_version: The bundle version.
         :param session: The database session.
         :return: The version of the DAG or None if not found.
         """
         version_select_obj = select(cls).where(cls.dag_id == dag_id)
         if version_number:
             version_select_obj = version_select_obj.where(cls.version_number == version_number)
+        if bundle_version:
+            version_select_obj = version_select_obj.where(cls.bundle_version == bundle_version)
 
         return session.scalar(version_select_obj.order_by(cls.id.desc()).limit(1))
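With the extra filter, resolving the serialized dag for a pinned bundle reduces to a single lookup. A usage sketch; the dag id and version string are invented:

from airflow.models.dag_version import DagVersion
from airflow.utils.session import create_session

with create_session() as session:
    # Latest DagVersion row matching both dag_id and bundle_version; with
    # bundle_version=None the filter is skipped and the overall latest wins.
    dag_version = DagVersion.get_version("my_dag", bundle_version="abc123", session=session)
    row = dag_version.serialized_dag if dag_version else None  # what _add_dag_from_db reads below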
""" version_select_obj = select(cls).where(cls.dag_id == dag_id) if version_number: version_select_obj = version_select_obj.where(cls.version_number == version_number) + if bundle_version: + version_select_obj = version_select_obj.where(cls.bundle_version == bundle_version) return session.scalar(version_select_obj.order_by(cls.id.desc()).limit(1)) diff --git a/airflow-core/src/airflow/models/dagbag.py b/airflow-core/src/airflow/models/dagbag.py index 23b87a9cf2e05..95e2aaac82506 100644 --- a/airflow-core/src/airflow/models/dagbag.py +++ b/airflow-core/src/airflow/models/dagbag.py @@ -69,6 +69,7 @@ from airflow.models.dag import DAG from airflow.models.dagwarning import DagWarning + from airflow.models.serialized_dag import SerializedDagModel from airflow.utils.types import ArgNotSet @@ -127,9 +128,11 @@ def __init__( collect_dags: bool = True, known_pools: set[str] | None = None, bundle_path: Path | None = None, + bundle_version: str | None = None, ): super().__init__() self.bundle_path: Path | None = bundle_path + self.bundle_version: str | None = bundle_version include_examples = ( include_examples if isinstance(include_examples, bool) @@ -141,7 +144,7 @@ def __init__( dag_folder = dag_folder or settings.DAGS_FOLDER self.dag_folder = dag_folder - self.dags: dict[str, DAG] = {} + self.dags: dict[tuple[str, str | None], DAG] = {} # the file's last modified timestamp when we last read it self.file_last_changed: dict[str, datetime] = {} self.import_errors: dict[str, str] = {} @@ -149,7 +152,7 @@ def __init__( self.has_logged = False self.read_dags_from_db = read_dags_from_db # Only used by read_dags_from_db=True - self.dags_last_fetched: dict[str, datetime] = {} + self.dags_last_fetched: dict[tuple[str, str | None], datetime] = {} # Only used by SchedulerJob to compare the dag_hash to identify change in DAGs self.dags_hash: dict[str, str] = {} @@ -178,10 +181,10 @@ def dag_ids(self) -> list[str]: :return: a list of DAG IDs in this bag """ - return list(self.dags) + return list([key[0] for key in self.dags]) @provide_session - def get_dag(self, dag_id, session: Session = None): + def get_dag(self, dag_id, session: Session = None, bundle_version: str | None = None): """ Get the DAG out of the dictionary, and refreshes it if expired. @@ -190,14 +193,16 @@ def get_dag(self, dag_id, session: Session = None): # Avoid circular import from airflow.models.dag import DagModel + dag_info = (dag_id, bundle_version or self.bundle_version) + if self.read_dags_from_db: # Import here so that serialized dag is only imported when serialization is enabled from airflow.models.serialized_dag import SerializedDagModel - if dag_id not in self.dags: + if dag_info not in self.dags: # Load from DB if not (yet) in the bag - self._add_dag_from_db(dag_id=dag_id, session=session) - return self.dags.get(dag_id) + self._add_dag_from_db(dag_id=dag_id, session=session, bundle_version=bundle_version) + return self.dags.get(dag_info) # If DAG is in the DagBag, check the following # 1. if time has come to check if DAG is updated (controlled by min_serialized_dag_fetch_secs) @@ -208,8 +213,8 @@ def get_dag(self, dag_id, session: Session = None): # if it exists and return None. 
diff --git a/airflow-core/src/airflow/models/serialized_dag.py b/airflow-core/src/airflow/models/serialized_dag.py
index f78d9492830e0..88fe27ca5239d 100644
--- a/airflow-core/src/airflow/models/serialized_dag.py
+++ b/airflow-core/src/airflow/models/serialized_dag.py
@@ -28,7 +28,7 @@
 import sqlalchemy_jsonfield
 import uuid6
 from sqlalchemy import Column, ForeignKey, LargeBinary, String, exc, select, tuple_
-from sqlalchemy.orm import backref, foreign, relationship
+from sqlalchemy.orm import backref, foreign, joinedload, relationship
 from sqlalchemy.sql.expression import func, literal
 from sqlalchemy_utils import UUIDType
@@ -479,7 +479,7 @@ def get_latest_serialized_dags(
 
     @classmethod
     @provide_session
-    def read_all_dags(cls, session: Session = NEW_SESSION) -> dict[str, SerializedDAG]:
+    def read_all_dags(cls, session: Session = NEW_SESSION) -> dict[tuple[str, str | None], SerializedDAG]:
         """
         Read all DAGs in serialized_dag table.
 
@@ -492,11 +492,13 @@ def read_all_dags(cls, session: Session = NEW_SESSION) -> dict[str, SerializedDA
             .subquery()
         )
         serialized_dags = session.scalars(
-            select(cls).join(
+            select(cls)
+            .join(
                 latest_serialized_dag_subquery,
                 (cls.dag_id == latest_serialized_dag_subquery.c.dag_id)
                 and (cls.created_at == latest_serialized_dag_subquery.c.max_created),
             )
+            .options(joinedload(cls.dag_version))
         )
 
         dags = {}
@@ -506,7 +508,7 @@ def read_all_dags(cls, session: Session = NEW_SESSION) -> dict[str, SerializedDA
 
             # Coherence check
             if dag.dag_id == row.dag_id:
-                dags[row.dag_id] = dag
+                dags[(row.dag_id, row.dag_version.bundle_version)] = dag
             else:
                 log.warning(
                     "dag_id Mismatch in DB: Row with dag_id '%s' has Serialised DAG with '%s' dag_id",
diff --git a/airflow-core/src/airflow/utils/cli.py b/airflow-core/src/airflow/utils/cli.py
index be27cd47677c4..5408018fe09b6 100644
--- a/airflow-core/src/airflow/utils/cli.py
+++ b/airflow-core/src/airflow/utils/cli.py
@@ -234,7 +234,7 @@ def get_dag_by_file_location(dag_id: str):
             f"Dag {dag_id!r} could not be found; either it does not exist or it failed to parse."
         )
     dagbag = DagBag(dag_folder=dag_model.fileloc)
-    return dagbag.dags[dag_id]
+    return dagbag.dags[(dag_id, None)]
 
 
 def _search_for_dag_file(val: str | None) -> str | None:
@@ -275,7 +275,7 @@ def get_dag(subdir: str | None, dag_id: str, from_db: bool = False) -> DAG:
     else:
         first_path = process_subdir(subdir)
         dagbag = DagBag(first_path)
-        dag = dagbag.dags.get(dag_id)  # avoids db calls made in get_dag
+        dag = dagbag.dags.get((dag_id, None))  # avoids db calls made in get_dag
         # Create a SchedulerDAG since some of the CLI commands rely on DB access
         dag = DAG.from_sdk_dag(dag)
     if not dag:
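Because read_all_dags now returns tuple keys, consumers that only need ids have to unpack them, which is the pattern the CLI and test changes below follow. A short sketch:

from airflow.models.serialized_dag import SerializedDagModel

dags = SerializedDagModel.read_all_dags()  # dict[(dag_id, bundle_version), SerializedDAG]
dag_ids = {dag_id for dag_id, _bundle_version in dags}  # collapse back to plain dag ids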
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_extra_links.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_extra_links.py
index 24a7de4556c2b..83d6c36c40b36 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_extra_links.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_extra_links.py
@@ -92,7 +92,7 @@ def setup(self, test_client, dag_maker, request, session) -> None:
         DagBundlesManager().sync_bundles_to_db()
 
         dag_bag = DagBag(os.devnull, include_examples=False)
-        dag_bag.dags = {self.dag.dag_id: self.dag}
+        dag_bag.dags = {(self.dag.dag_id, None): self.dag}
         test_client.app.state.dag_bag = dag_bag
 
         dag_bag.sync_to_db("dags-folder", None)
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_log.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_log.py
index 11f3846cacd78..c4bc9b010e67c 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_log.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_log.py
@@ -259,7 +259,7 @@ def test_get_logs_of_removed_task(self, request_url, expected_filename, extra_qu
         # Recreate DAG without tasks
         dagbag = self.app.state.dag_bag
         dag = DAG(self.DAG_ID, schedule=None, start_date=timezone.parse(self.default_time))
-        del dagbag.dags[self.DAG_ID]
+        del dagbag.dags[(self.DAG_ID, None)]
         dagbag.bag_dag(dag=dag)
 
         key = self.app.state.secret_key
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_tasks.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_tasks.py
index 7052aeaedb631..237f6ef61fcff 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_tasks.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_tasks.py
@@ -66,9 +66,9 @@ def create_dags(self, test_client):
         task4 >> task5
         dag_bag = DagBag(os.devnull, include_examples=False)
         dag_bag.dags = {
-            dag.dag_id: dag,
-            mapped_dag.dag_id: mapped_dag,
-            unscheduled_dag.dag_id: unscheduled_dag,
+            (dag.dag_id, None): dag,
+            (mapped_dag.dag_id, None): mapped_dag,
+            (unscheduled_dag.dag_id, None): unscheduled_dag,
         }
         test_client.app.state.dag_bag = dag_bag
diff --git a/airflow-core/tests/unit/cli/commands/test_task_command.py b/airflow-core/tests/unit/cli/commands/test_task_command.py
index b3dd69912cedb..3cc168674cd2d 100644
--- a/airflow-core/tests/unit/cli/commands/test_task_command.py
+++ b/airflow-core/tests/unit/cli/commands/test_task_command.py
@@ -111,8 +111,8 @@ def teardown_class(cls) -> None:
 
     @pytest.mark.execution_timeout(120)
     def test_cli_list_tasks(self):
-        for dag_id in self.dagbag.dags:
-            args = self.parser.parse_args(["tasks", "list", dag_id])
+        for dag_info in self.dagbag.dags:
+            args = self.parser.parse_args(["tasks", "list", dag_info[0]])
             task_command.task_list(args)
 
     def test_test(self):
@@ -338,7 +338,7 @@ def test_task_state(self):
         )
 
     def test_task_states_for_dag_run(self):
-        dag2 = DagBag().dags["example_python_operator"]
+        dag2 = DagBag().dags[("example_python_operator", None)]
         task2 = dag2.get_task(task_id="print_the_context")
 
         dag2 = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag2))
diff --git a/airflow-core/tests/unit/dag_processing/test_manager.py b/airflow-core/tests/unit/dag_processing/test_manager.py
index 387b45b619f6a..6c59a1b27b9e4 100644
--- a/airflow-core/tests/unit/dag_processing/test_manager.py
+++ b/airflow-core/tests/unit/dag_processing/test_manager.py
@@ -559,7 +559,8 @@ def test_kill_timed_out_processors_no_kill(self):
                 b'"bundle_path":"/opt/airflow/dags",'
                 b'"requests_fd":123,'
                 b'"callback_requests":[],'
-                b'"type":"DagFileParseRequest"'
+                b'"type":"DagFileParseRequest",'
+                b'"bundle_version":null'
                 b"}\n",
             ),
             pytest.param(
@@ -590,7 +591,8 @@ def test_kill_timed_out_processors_no_kill(self):
                 b'"type":"DagCallbackRequest"'
                 b"}"
                 b"],"
-                b'"type":"DagFileParseRequest"'
+                b'"type":"DagFileParseRequest",'
+                b'"bundle_version":null'
                 b"}\n",
             ),
         ],
@@ -883,6 +885,7 @@ def test_callback_queue(self, mock_get_logger, configure_testing_dag_bundle):
             id=mock.ANY,
             path=Path(dag2_path.bundle_path, dag2_path.rel_path),
             bundle_path=dag2_path.bundle_path,
+            bundle_version=None,
             callbacks=[dag2_req1],
             selector=mock.ANY,
             logger=mock_logger,
@@ -892,6 +895,7 @@ def test_callback_queue(self, mock_get_logger, configure_testing_dag_bundle):
             id=mock.ANY,
             path=Path(dag1_path.bundle_path, dag1_path.rel_path),
             bundle_path=dag1_path.bundle_path,
+            bundle_version=None,
             callbacks=[dag1_req1, dag1_req2],
             selector=mock.ANY,
             logger=mock_logger,
diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py
index ca9670f81ae4a..0a20b07c1f2d2 100644
--- a/airflow-core/tests/unit/dag_processing/test_processor.py
+++ b/airflow-core/tests/unit/dag_processing/test_processor.py
@@ -305,7 +305,7 @@ def on_failure(context):
         dag = DAG(dag_id="a", on_failure_callback=on_failure)
 
         def fake_collect_dags(self, *args, **kwargs):
-            self.dags[dag.dag_id] = dag
+            self.dags[(dag.dag_id, None)] = dag
 
         spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag)
 
@@ -341,7 +341,7 @@ def on_failure(context):
             BaseOperator(task_id="b", on_failure_callback=on_failure)
 
         def fake_collect_dags(self, *args, **kwargs):
-            self.dags[dag.dag_id] = dag
+            self.dags[(dag.dag_id, None)] = dag
 
         spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag)
 
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index 1c6488cb93676..e61426a52fdf9 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -3342,8 +3342,8 @@ def test_verify_integrity_if_dag_changed(self, dag_maker):
         dag_version_1 = DagVersion.get_latest_version(dr.dag_id, session=session)
         assert dr.dag_versions[-1].id == dag_version_1.id
-        assert self.job_runner.dagbag.dags == {"test_verify_integrity_if_dag_changed": dag}
-        assert len(self.job_runner.dagbag.dags.get("test_verify_integrity_if_dag_changed").tasks) == 1
+        assert self.job_runner.dagbag.dags == {("test_verify_integrity_if_dag_changed", None): dag}
+        assert len(self.job_runner.dagbag.dags.get(("test_verify_integrity_if_dag_changed", None)).tasks) == 1
 
         # Now let's say the DAG got updated (new task got added)
         BashOperator(task_id="bash_task_1", dag=dag, bash_command="echo hi")
@@ -3359,8 +3359,8 @@ def test_verify_integrity_if_dag_changed(self, dag_maker):
         assert len(drs) == 1
         dr = drs[0]
         assert dr.dag_versions[-1].id == dag_version_2.id
-        assert self.job_runner.dagbag.dags == {"test_verify_integrity_if_dag_changed": dag}
-        assert len(self.job_runner.dagbag.dags.get("test_verify_integrity_if_dag_changed").tasks) == 2
+        assert self.job_runner.dagbag.dags == {("test_verify_integrity_if_dag_changed", None): dag}
+        assert len(self.job_runner.dagbag.dags.get(("test_verify_integrity_if_dag_changed", None)).tasks) == 2
 
         tis_count = (
             session.query(func.count(TaskInstance.task_id))
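The updated byte strings in test_manager.py follow from pydantic's default serialization: an optional field is emitted as null rather than omitted. A minimal reproduction with a stand-in model, not the real class:

from __future__ import annotations

from pydantic import BaseModel


class Req(BaseModel):  # stand-in for DagFileParseRequest
    type: str = "DagFileParseRequest"
    bundle_version: str | None = None


print(Req().model_dump_json())  # {"type":"DagFileParseRequest","bundle_version":null}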
""" with time_machine.travel((tz.datetime(2020, 1, 5, 0, 0, 0)), tick=False): - example_bash_op_dag = DagBag(include_examples=True).dags.get("example_bash_operator") + example_bash_op_dag = DagBag(include_examples=True).dags.get(("example_bash_operator", None)) DAG.from_sdk_dag(example_bash_op_dag).sync_to_db() SerializedDagModel.write_dag(dag=example_bash_op_dag, bundle_name="testing") dag_bag = DagBag(read_dags_from_db=True) ser_dag_1 = dag_bag.get_dag("example_bash_operator") - ser_dag_1_update_time = dag_bag.dags_last_fetched["example_bash_operator"] + ser_dag_1_update_time = dag_bag.dags_last_fetched[("example_bash_operator", None)] assert example_bash_op_dag.tags == ser_dag_1.tags assert ser_dag_1_update_time == tz.datetime(2020, 1, 5, 0, 0, 0) @@ -705,9 +706,9 @@ def test_get_dag_with_dag_serialization(self): # Since min_serialized_dag_fetch_interval is passed verify that calling 'dag_bag.get_dag' # fetches the Serialized DAG from DB with time_machine.travel((tz.datetime(2020, 1, 5, 0, 0, 8)), tick=False): - with assert_queries_count(2): + with assert_queries_count(3): updated_ser_dag_1 = dag_bag.get_dag("example_bash_operator") - updated_ser_dag_1_update_time = dag_bag.dags_last_fetched["example_bash_operator"] + updated_ser_dag_1_update_time = dag_bag.dags_last_fetched[("example_bash_operator", None)] assert set(updated_ser_dag_1.tags) == {"example", "example2", "new_tag"} assert updated_ser_dag_1_update_time > ser_dag_1_update_time @@ -722,7 +723,7 @@ def test_get_dag_refresh_race_condition(self, session, testing_dag_bundle): db_clean_up() # serialize the initial version of the DAG with time_machine.travel((tz.datetime(2020, 1, 5, 0, 0, 0)), tick=False): - example_bash_op_dag = DagBag(include_examples=True).dags.get("example_bash_operator") + example_bash_op_dag = DagBag(include_examples=True).dags.get(("example_bash_operator", None)) DAG.from_sdk_dag(example_bash_op_dag).sync_to_db() SerializedDagModel.write_dag(dag=example_bash_op_dag, bundle_name="testing") @@ -730,10 +731,10 @@ def test_get_dag_refresh_race_condition(self, session, testing_dag_bundle): with time_machine.travel((tz.datetime(2020, 1, 5, 1, 0, 10)), tick=False): dag_bag = DagBag(read_dags_from_db=True) - with assert_queries_count(2): + with assert_queries_count(3): ser_dag = dag_bag.get_dag("example_bash_operator") - ser_dag_update_time = dag_bag.dags_last_fetched["example_bash_operator"] + ser_dag_update_time = dag_bag.dags_last_fetched[("example_bash_operator", None)] assert ser_dag.tags == {"example", "example2"} assert ser_dag_update_time == tz.datetime(2020, 1, 5, 1, 0, 10) @@ -755,9 +756,9 @@ def test_get_dag_refresh_race_condition(self, session, testing_dag_bundle): # Since min_serialized_dag_fetch_interval is passed verify that calling 'dag_bag.get_dag' # fetches the Serialized DAG from DB with time_machine.travel((tz.datetime(2020, 1, 5, 1, 0, 30)), tick=False): - with assert_queries_count(2): + with assert_queries_count(3): updated_ser_dag = dag_bag.get_dag("example_bash_operator") - updated_ser_dag_update_time = dag_bag.dags_last_fetched["example_bash_operator"] + updated_ser_dag_update_time = dag_bag.dags_last_fetched[("example_bash_operator", None)] assert set(updated_ser_dag.tags) == {"example", "example2", "new_tag"} assert updated_ser_dag_update_time > ser_dag_update_time @@ -777,8 +778,8 @@ def test_collect_dags_from_db(self, testing_dag_bundle): new_dagbag.collect_dags_from_db() new_dags = new_dagbag.dags assert len(example_dags) == len(new_dags) - for dag_id, dag in example_dags.items(): - 
serialized_dag = new_dags[dag_id] + for dag_info, dag in example_dags.items(): + serialized_dag = new_dags[dag_info] assert serialized_dag.dag_id == dag.dag_id assert set(serialized_dag.task_dict) == set(dag.task_dict) diff --git a/airflow-core/tests/unit/models/test_dagcode.py b/airflow-core/tests/unit/models/test_dagcode.py index 818fb4915fd78..59b1979e60a26 100644 --- a/airflow-core/tests/unit/models/test_dagcode.py +++ b/airflow-core/tests/unit/models/test_dagcode.py @@ -73,13 +73,13 @@ def teardown_method(self): def _write_two_example_dags(self, session): example_dags = make_example_dags(example_dags_module) - bash_dag = example_dags["example_bash_operator"] + bash_dag = example_dags[("example_bash_operator", None)] SDM.write_dag(bash_dag, bundle_name="testing") dag_version = DagVersion.get_latest_version("example_bash_operator") x = DagCode(dag_version, bash_dag.fileloc) session.add(x) session.commit() - xcom_dag = example_dags["example_xcom"] + xcom_dag = example_dags[("example_xcom", None)] SDM.write_dag(xcom_dag, bundle_name="testing") dag_version = DagVersion.get_latest_version("example_xcom") x = DagCode(dag_version, xcom_dag.fileloc) @@ -129,7 +129,7 @@ def test_code_can_be_read_when_no_access_to_file(self, testing_dag_bundle): Test that code can be retrieved from DB when you do not have access to Code file. Source Code should at least exist in one of DB or File. """ - example_dag = make_example_dags(example_dags_module).get("example_bash_operator") + example_dag = make_example_dags(example_dags_module).get(("example_bash_operator", None)) SDM.write_dag(example_dag, bundle_name="testing") # Mock that there is no access to the Dag File @@ -142,7 +142,7 @@ def test_code_can_be_read_when_no_access_to_file(self, testing_dag_bundle): def test_db_code_created_on_serdag_change(self, session, testing_dag_bundle): """Test new DagCode is created in DB when ser dag is changed""" - example_dag = make_example_dags(example_dags_module).get("example_bash_operator") + example_dag = make_example_dags(example_dags_module).get(("example_bash_operator", None)) SDM.write_dag(example_dag, bundle_name="testing") dag = DAG.from_sdk_dag(example_dag) diff --git a/airflow-core/tests/unit/models/test_serialized_dag.py b/airflow-core/tests/unit/models/test_serialized_dag.py index 189904105ab05..6134407de173c 100644 --- a/airflow-core/tests/unit/models/test_serialized_dag.py +++ b/airflow-core/tests/unit/models/test_serialized_dag.py @@ -141,7 +141,7 @@ def my_callable2(): def test_serialized_dag_is_updated_if_dag_is_changed(self, testing_dag_bundle): """Test Serialized DAG is updated if DAG is changed""" example_dags = make_example_dags(example_dags_module) - example_bash_op_dag = example_dags.get("example_bash_operator") + example_bash_op_dag = example_dags.get(("example_bash_operator", None)) dag_updated = SDM.write_dag(dag=example_bash_op_dag, bundle_name="testing") assert dag_updated is True @@ -194,7 +194,7 @@ def test_read_all_dags_only_picks_the_latest_serdags(self, session): serialized_dags = SDM.read_all_dags() assert len(example_dags) == len(serialized_dags) - dag = example_dags.get("example_bash_operator") + dag = example_dags.get(("example_bash_operator", None)) # DAGs are serialized and deserialized to access create_dagrun object sdag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag=dag)) @@ -219,7 +219,7 @@ def test_order_of_dag_params_is_stable(self): the serialized DAG JSON. 
""" example_dags = make_example_dags(example_dags_module) - example_params_trigger_ui = example_dags.get("example_params_trigger_ui") + example_params_trigger_ui = example_dags.get(("example_params_trigger_ui", None)) before = list(example_params_trigger_ui.params.keys()) SDM.write_dag(example_params_trigger_ui, bundle_name="testing") @@ -285,7 +285,7 @@ def get_hash_set(): example_dags = self._write_example_dags() ordered_example_dags = dict(sorted(example_dags.items())) hashes = set() - for dag_id in ordered_example_dags.keys(): + for dag_id, _ in ordered_example_dags.keys(): smd = session.execute(select(SDM.dag_hash).where(SDM.dag_id == dag_id)).one() hashes.add(smd.dag_hash) return hashes diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py index fe52ea522a5c8..1b748e2b313d3 100644 --- a/airflow-core/tests/unit/serialization/test_dag_serialization.py +++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py @@ -339,7 +339,7 @@ def compute_last_dagrun(dag: DAG): bash_command='echo "{{ last_dagrun(dag) }}"', dag=dag, ) - return {dag.dag_id: dag} + return {(dag.dag_id, None): dag} def get_excluded_patterns() -> Generator[str, None, None]: @@ -357,7 +357,7 @@ def collect_dags(dag_folder=None): """Collects DAGs to test.""" dags = {} import_errors = {} - dags.update({"simple_dag": make_simple_dag()}) + dags.update({("simple_dag", None): make_simple_dag()}) dags.update(make_user_defined_macro_filter_dag()) if dag_folder is None: @@ -593,7 +593,7 @@ def test_deserialization_across_process(self): break dag = SerializedDAG.from_json(v) assert isinstance(dag, DAG) - stringified_dags[dag.dag_id] = dag + stringified_dags[(dag.dag_id, None)] = dag dags, _ = collect_dags("airflow/example_dags") assert set(stringified_dags.keys()) == set(dags.keys()) @@ -1613,7 +1613,7 @@ def test_basic_mapped_dag(self, dag_maker): "airflow-core/src/airflow/example_dags/example_dynamic_task_mapping.py", include_examples=False ) assert not dagbag.import_errors - dag = dagbag.dags["example_dynamic_task_mapping"] + dag = dagbag.dags[("example_dynamic_task_mapping", None)] ser_dag = SerializedDAG.to_dict(dag) # We should not include `_is_sensor` most of the time (as it would be wasteful). 
Check we don't assert "_is_sensor" not in ser_dag["dag"]["tasks"][0]["__var"] diff --git a/devel-common/src/tests_common/test_utils/db.py b/devel-common/src/tests_common/test_utils/db.py index 7eb5a0c7fc7cb..5622d8c7486a2 100644 --- a/devel-common/src/tests_common/test_utils/db.py +++ b/devel-common/src/tests_common/test_utils/db.py @@ -82,7 +82,8 @@ def _bootstrap_dagbag(): dagbag.sync_to_db(session=session) # Deactivate the unknown ones - DAG.deactivate_unknown_dags(dagbag.dags.keys(), session=session) + dag_ids = [key[0] for key in dagbag.dags.keys()] + DAG.deactivate_unknown_dags(dag_ids, session=session) def initial_db_init(): diff --git a/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler_system.py b/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler_system.py index e88c6fbeb3d14..50ca9b84b93b4 100644 --- a/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler_system.py +++ b/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler_system.py @@ -79,7 +79,7 @@ def test_should_read_logs(self, session): assert subprocess.Popen(["airflow", "scheduler", "--num-runs", "1"]).wait() == 0 ti = session.query(TaskInstance).filter(TaskInstance.task_id == "create_entry_group").first() - dag = DagBag(dag_folder=example_complex.__file__).dags["example_complex"] + dag = DagBag(dag_folder=example_complex.__file__).dags[("example_complex", None)] task = dag.task_dict["create_entry_group"] ti.task = task self.assert_remote_logs("INFO - Task exited with return code 0", ti) diff --git a/providers/openlineage/tests/unit/openlineage/plugins/test_execution.py b/providers/openlineage/tests/unit/openlineage/plugins/test_execution.py index 9f12994e72842..9371fbd2a731b 100644 --- a/providers/openlineage/tests/unit/openlineage/plugins/test_execution.py +++ b/providers/openlineage/tests/unit/openlineage/plugins/test_execution.py @@ -96,7 +96,10 @@ def setup_job(self, task_name, run_id): dag_folder=TEST_DAG_FOLDER, include_examples=False, ) - dag = dagbag.dags.get("test_openlineage_execution") + key = "test_openlineage_execution" + if AIRFLOW_V_3_0_PLUS: + key = "test_openlineage_execution", None + dag = dagbag.dags.get(key) task = dag.get_task(task_name) if AIRFLOW_V_3_0_PLUS: @@ -206,7 +209,10 @@ def test_success_overtime_kills_tasks(self): dag_folder=TEST_DAG_FOLDER, include_examples=False, ) - dag = dagbag.dags.get("test_openlineage_execution") + key = "test_openlineage_execution" + if AIRFLOW_V_3_0_PLUS: + key = "test_openlineage_execution", None + dag = dagbag.dags.get(key) task = dag.get_task("execute_long_stall") if AIRFLOW_V_3_0_PLUS: diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 9cecdc5547d44..e58a61622ce40 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -518,11 +518,12 @@ def parse(what: StartupDetails) -> RuntimeTaskInstance: include_examples=False, safe_mode=False, load_op_links=False, + bundle_version=bundle_info.version, ) if TYPE_CHECKING: assert what.ti.dag_id - dag = bag.dags[what.ti.dag_id] + dag = bag.dags[(what.ti.dag_id, bundle_info.version)] # install_loader()
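End to end, the invariant is that the version used to build the bag is the one used to index it; parse() above gets both from bundle_info. A condensed, self-contained sketch with invented values:

# bundle_version would come from what.bundle_info.version at runtime.
bundle_version: str | None = "abc123"
dags: dict[tuple[str, str | None], object] = {("my_dag", bundle_version): "parsed DAG"}
dag = dags[("my_dag", bundle_version)]  # KeyError if keyed with a different version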