diff --git a/airflow-core/docs/img/airflow_erd.sha256 b/airflow-core/docs/img/airflow_erd.sha256 index b77295d3f4325..c53cc36d331db 100644 --- a/airflow-core/docs/img/airflow_erd.sha256 +++ b/airflow-core/docs/img/airflow_erd.sha256 @@ -1 +1 @@ -72b7bb2d4e109d8f786d229e68c83d2b6f7442e48a6617ea3d47842f1bfa33eb \ No newline at end of file +e0de73aab81a28995b99be21dd25c8ca31c4e0f4a5a0a26df8aff412e5067fd5 \ No newline at end of file diff --git a/airflow-core/docs/img/airflow_erd.svg b/airflow-core/docs/img/airflow_erd.svg index 77f46bc34005c..5565970e5573f 100644 --- a/airflow-core/docs/img/airflow_erd.svg +++ b/airflow-core/docs/img/airflow_erd.svg @@ -776,6 +776,7 @@ dag_version_id [UUID] + NOT NULL duration @@ -1812,7 +1813,7 @@ dag_version--task_instance 0..N -{0,1} +1 diff --git a/airflow-core/docs/migrations-ref.rst b/airflow-core/docs/migrations-ref.rst index d6645ebdc4ad8..0a2d4ea6a89f4 100644 --- a/airflow-core/docs/migrations-ref.rst +++ b/airflow-core/docs/migrations-ref.rst @@ -39,7 +39,9 @@ Here's the list of all the Database Migrations that are executed via when you ru +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | Revision ID | Revises ID | Airflow Version | Description | +=========================+==================+===================+==============================================================+ -| ``ffdb0566c7c0`` (head) | ``66a7743fe20e`` | ``3.1.0`` | Add dag_favorite table. | +| ``5d3072c51bac`` (head) | ``ffdb0566c7c0`` | ``3.1.0`` | Make dag_version_id non-nullable in TaskInstance. | ++-------------------------+------------------+-------------------+--------------------------------------------------------------+ +| ``ffdb0566c7c0`` | ``66a7743fe20e`` | ``3.1.0`` | Add dag_favorite table. 
| +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | ``66a7743fe20e`` | ``583e80dfcef4`` | ``3.1.0`` | Add triggering user to dag_run. | +-------------------------+------------------+-------------------+--------------------------------------------------------------+ diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_instances.py b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_instances.py index 392fcc598df71..e2470119a1c80 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_instances.py @@ -44,6 +44,7 @@ class TaskInstanceResponse(BaseModel): id: str task_id: str dag_id: str + dag_version: DagVersionResponse run_id: str = Field(alias="dag_run_id") map_index: int logical_date: datetime | None @@ -76,7 +77,6 @@ class TaskInstanceResponse(BaseModel): ) trigger: TriggerResponse | None queued_by_job: JobResponse | None = Field(alias="triggerer_job") - dag_version: DagVersionResponse | None class TaskInstanceCollectionResponse(BaseModel): diff --git a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml index 37e80e9a28aab..d00a383f5b875 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml +++ b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml @@ -10377,6 +10377,8 @@ components: dag_id: type: string title: Dag Id + dag_version: + $ref: '#/components/schemas/DagVersionResponse' dag_run_id: type: string title: Dag Run Id @@ -10504,15 +10506,12 @@ components: anyOf: - $ref: '#/components/schemas/JobResponse' - type: 'null' - dag_version: - anyOf: - - $ref: '#/components/schemas/DagVersionResponse' - - type: 'null' type: object required: - id - task_id - dag_id + - 
dag_version - dag_run_id - map_index - logical_date @@ -10541,7 +10540,6 @@ components: - rendered_map_index - trigger - triggerer_job - - dag_version title: TaskInstanceResponse description: TaskInstance serializer for responses. TaskInstanceState: diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index 5afc2e75a7193..54b6414fbf2c9 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -241,6 +241,7 @@ class TaskInstance(BaseModel): dag_id: str run_id: str try_number: int + dag_version_id: uuid.UUID map_index: int = -1 hostname: str | None = None context_carrier: dict | None = None diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py index 5462f10297495..ccef71f7bfb14 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py @@ -21,9 +21,11 @@ from airflow.api_fastapi.execution_api.versions.v2025_04_28 import AddRenderedMapIndexField from airflow.api_fastapi.execution_api.versions.v2025_05_20 import DowngradeUpstreamMapIndexes +from airflow.api_fastapi.execution_api.versions.v2025_08_10 import AddDagVersionIdField bundle = VersionBundle( HeadVersion(), + Version("2025-08-10", AddDagVersionIdField), Version("2025-05-20", DowngradeUpstreamMapIndexes), Version("2025-04-28", AddRenderedMapIndexField), Version("2025-04-11"), diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_08_10.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_08_10.py new file mode 100644 index 0000000000000..dcea9a6a0e857 --- /dev/null +++ 
b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_08_10.py @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from cadwyn import VersionChange, schema + +from airflow.api_fastapi.execution_api.datamodels.taskinstance import TaskInstance + + +class AddDagVersionIdField(VersionChange): + """Add the `dag_version_id` field to the TaskInstance model.""" + + description = __doc__ + + instructions_to_migrate_to_previous_version = (schema(TaskInstance).field("dag_version_id").didnt_exist,) diff --git a/airflow-core/src/airflow/cli/commands/task_command.py b/airflow-core/src/airflow/cli/commands/task_command.py index 9eca6aeeae247..010f884ca655b 100644 --- a/airflow-core/src/airflow/cli/commands/task_command.py +++ b/airflow-core/src/airflow/cli/commands/task_command.py @@ -33,6 +33,7 @@ from airflow.exceptions import AirflowConfigException, DagRunNotFound, TaskInstanceNotFound from airflow.models import TaskInstance from airflow.models.dag import DAG as SchedulerDAG, _get_or_create_dagrun +from airflow.models.dag_version import DagVersion from airflow.models.dagrun import DagRun from airflow.sdk.definitions.dag import DAG, _run_task from airflow.sdk.definitions.param 
import ParamsDict @@ -200,7 +201,13 @@ def _get_ti( f"run_id or logical_date of {logical_date_or_run_id!r} not found" ) # TODO: Validate map_index is in range? - ti = TaskInstance(task, run_id=dag_run.run_id, map_index=map_index) + dag_version = DagVersion.get_latest_version(dag.dag_id, session=session) + if not dag_version: + # TODO: Remove this once DagVersion.get_latest_version is guaranteed to return a DagVersion/raise + raise ValueError( + f"Cannot create TaskInstance for {dag.dag_id} because the Dag is not serialized." + ) + ti = TaskInstance(task, run_id=dag_run.run_id, map_index=map_index, dag_version_id=dag_version.id) if dag_run in session: session.add(ti) ti.dag_run = dag_run diff --git a/airflow-core/src/airflow/executors/workloads.py b/airflow-core/src/airflow/executors/workloads.py index bca3a777b3655..43a4aab1dbc47 100644 --- a/airflow-core/src/airflow/executors/workloads.py +++ b/airflow-core/src/airflow/executors/workloads.py @@ -55,7 +55,7 @@ class TaskInstance(BaseModel): """Schema for TaskInstance with minimal required fields needed for Executors and Task SDK.""" id: uuid.UUID - + dag_version_id: uuid.UUID task_id: str dag_id: str run_id: str diff --git a/airflow-core/src/airflow/migrations/versions/0076_3_1_0_make_dag_version_id_non_nullable_in_.py b/airflow-core/src/airflow/migrations/versions/0076_3_1_0_make_dag_version_id_non_nullable_in_.py new file mode 100644 index 0000000000000..cbd183f7b8f43 --- /dev/null +++ b/airflow-core/src/airflow/migrations/versions/0076_3_1_0_make_dag_version_id_non_nullable_in_.py @@ -0,0 +1,81 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Make dag_version_id non-nullable in TaskInstance. + +Revision ID: 5d3072c51bac +Revises: ffdb0566c7c0 +Create Date: 2025-05-20 10:38:25.635779 + +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op +from sqlalchemy_utils import UUIDType + +# revision identifiers, used by Alembic. +revision = "5d3072c51bac" +down_revision = "ffdb0566c7c0" +branch_labels = None +depends_on = None +airflow_version = "3.1.0" + + +def upgrade(): + """Apply make dag_version_id non-nullable in TaskInstance.""" + conn = op.get_bind() + if conn.dialect.name == "postgresql": + update_query = sa.text(""" + UPDATE task_instance + SET dag_version_id = latest_versions.id + FROM ( + SELECT DISTINCT ON (dag_id) dag_id, id + FROM dag_version + ORDER BY dag_id, created_at DESC + ) latest_versions + WHERE task_instance.dag_id = latest_versions.dag_id + AND task_instance.dag_version_id IS NULL + """) + else: + update_query = sa.text(""" + UPDATE task_instance + SET dag_version_id = ( + SELECT id FROM ( + SELECT id, dag_id, + ROW_NUMBER() OVER (PARTITION BY dag_id ORDER BY created_at DESC) as rn + FROM dag_version + ) ranked_versions + WHERE ranked_versions.dag_id = task_instance.dag_id + AND ranked_versions.rn = 1 + ) + WHERE task_instance.dag_version_id IS NULL + """) + + op.execute(update_query) + + with op.batch_alter_table("task_instance", schema=None) as batch_op: + batch_op.alter_column("dag_version_id", existing_type=UUIDType(binary=False), nullable=False) + + +def downgrade(): + """Unapply make dag_version_id non-nullable in 
TaskInstance.""" + with op.batch_alter_table("task_instance", schema=None) as batch_op: + batch_op.alter_column("dag_version_id", existing_type=UUIDType(binary=False), nullable=True) diff --git a/airflow-core/src/airflow/models/dag.py b/airflow-core/src/airflow/models/dag.py index 0332a2e825e24..aa73159b7b294 100644 --- a/airflow-core/src/airflow/models/dag.py +++ b/airflow-core/src/airflow/models/dag.py @@ -245,6 +245,9 @@ def _create_orm_dagrun( select(DagModel.bundle_version).where(DagModel.dag_id == dag.dag_id), ) dag_version = DagVersion.get_latest_version(dag.dag_id, session=session) + if not dag_version: + raise AirflowException(f"Cannot create DagRun for DAG {dag.dag_id} because the dag is not serialized") + run = DagRun( dag_id=dag.dag_id, run_id=run_id, @@ -270,7 +273,7 @@ def _create_orm_dagrun( run.dag = dag # create the associated task instances # state is None at the moment of creation - run.verify_integrity(session=session, dag_version_id=dag_version.id if dag_version else None) + run.verify_integrity(session=session, dag_version_id=dag_version.id) return run diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index 9415a3b0f7a34..d32146222fe37 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -1452,7 +1452,11 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None: # It's enough to revise map index once per task id, # checking the map index for each mapped task significantly slows down scheduling if schedulable.task.task_id not in revised_map_index_task_ids: - ready_tis.extend(self._revise_map_indexes_if_mapped(schedulable.task, session=session)) + ready_tis.extend( + self._revise_map_indexes_if_mapped( + schedulable.task, dag_version_id=schedulable.dag_version_id, session=session + ) + ) revised_map_index_task_ids.add(schedulable.task.task_id) ready_tis.append(schedulable) @@ -1555,9 +1559,7 @@ def 
_emit_duration_stats_for_finished_state(self): Stats.timing(f"dagrun.duration.{self.state}", **timer_params) @provide_session - def verify_integrity( - self, *, session: Session = NEW_SESSION, dag_version_id: UUIDType | None = None - ) -> None: + def verify_integrity(self, *, session: Session = NEW_SESSION, dag_version_id: UUIDType) -> None: """ Verify the DagRun by checking for removed tasks or tasks that are not in the database yet. @@ -1687,7 +1689,7 @@ def _get_task_creator( created_counts: dict[str, int], ti_mutation_hook: Callable, hook_is_noop: Literal[True], - dag_version_id: UUIDType | None, + dag_version_id: UUIDType, ) -> Callable[[Operator, Iterable[int]], Iterator[dict[str, Any]]]: ... @overload @@ -1696,7 +1698,7 @@ def _get_task_creator( created_counts: dict[str, int], ti_mutation_hook: Callable, hook_is_noop: Literal[False], - dag_version_id: UUIDType | None, + dag_version_id: UUIDType, ) -> Callable[[Operator, Iterable[int]], Iterator[TI]]: ... def _get_task_creator( @@ -1704,7 +1706,7 @@ def _get_task_creator( created_counts: dict[str, int], ti_mutation_hook: Callable, hook_is_noop: Literal[True, False], - dag_version_id: UUIDType | None, + dag_version_id: UUIDType, ) -> Callable[[Operator, Iterable[int]], Iterator[dict[str, Any]] | Iterator[TI]]: """ Get the task creator function. @@ -1815,7 +1817,9 @@ def _create_task_instances( # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive. session.rollback() - def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) -> Iterator[TI]: + def _revise_map_indexes_if_mapped( + self, task: Operator, *, dag_version_id: UUIDType, session: Session + ) -> Iterator[TI]: """ Check if task increased or reduced in length and handle appropriately. 
@@ -1861,7 +1865,7 @@ def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) -> for index in range(total_length): if index in existing_indexes: continue - ti = TI(task, run_id=self.run_id, map_index=index, state=None) + ti = TI(task, run_id=self.run_id, map_index=index, state=None, dag_version_id=dag_version_id) self.log.debug("Expanding TIs upserted %s", ti) task_instance_mutation_hook(ti) ti = session.merge(ti) diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index 9044817c851c0..aff5146eb4343 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -24,6 +24,7 @@ import math import operator import os +import uuid from collections import defaultdict from collections.abc import Collection, Generator, Iterable, Sequence from datetime import timedelta @@ -566,8 +567,7 @@ class TaskInstance(Base, LoggingMixin): _task_display_property_value = Column("task_display_name", String(2000), nullable=True) dag_version_id = Column( - UUIDType(binary=False), - ForeignKey("dag_version.id", ondelete="RESTRICT"), + UUIDType(binary=False), ForeignKey("dag_version.id", ondelete="RESTRICT"), nullable=False ) dag_version = relationship("DagVersion", back_populates="task_instances") @@ -632,10 +632,10 @@ class TaskInstance(Base, LoggingMixin): def __init__( self, task: Operator, + dag_version_id: UUIDType | uuid.UUID, run_id: str | None = None, state: str | None = None, map_index: int = -1, - dag_version_id: UUIDType | None = None, ): super().__init__() self.dag_id = task.dag_id @@ -645,7 +645,6 @@ def __init__( self.refresh_from_task(task) if TYPE_CHECKING: assert self.task - # init_on_load will config the log self.init_on_load() @@ -675,7 +674,7 @@ def stats_tags(self) -> dict[str, str]: @staticmethod def insert_mapping( - run_id: str, task: Operator, map_index: int, dag_version_id: UUIDType | None + run_id: str, task: Operator, map_index: 
int, dag_version_id: UUIDType ) -> dict[str, Any]: """ Insert mapping. @@ -683,7 +682,7 @@ def insert_mapping( :meta private: """ priority_weight = task.weight_rule.get_weight( - TaskInstance(task=task, run_id=run_id, map_index=map_index) + TaskInstance(task=task, run_id=run_id, map_index=map_index, dag_version_id=dag_version_id) ) return { @@ -738,6 +737,7 @@ def from_runtime_ti(cls, runtime_ti: RuntimeTaskInstanceProtocol) -> TaskInstanc run_id=runtime_ti.run_id, task=runtime_ti.task, # type: ignore[arg-type] map_index=runtime_ti.map_index, + dag_version_id=runtime_ti.dag_version_id, ) if TYPE_CHECKING: @@ -760,6 +760,7 @@ def to_runtime_ti(self, context_from_server) -> RuntimeTaskInstanceProtocol: hostname=self.hostname, _ti_context_from_server=context_from_server, start_date=self.start_date, + dag_version_id=self.dag_version_id, ) return runtime_ti diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts index 8efd3e111efb1..5f7890a4739f7 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -4641,6 +4641,9 @@ export const $TaskInstanceResponse = { type: 'string', title: 'Dag Id' }, + dag_version: { + '$ref': '#/components/schemas/DagVersionResponse' + }, dag_run_id: { type: 'string', title: 'Dag Run Id' @@ -4886,20 +4889,10 @@ export const $TaskInstanceResponse = { type: 'null' } ] - }, - dag_version: { - anyOf: [ - { - '$ref': '#/components/schemas/DagVersionResponse' - }, - { - type: 'null' - } - ] } }, type: 'object', - required: ['id', 'task_id', 'dag_id', 'dag_run_id', 'map_index', 'logical_date', 'run_after', 'start_date', 'end_date', 'duration', 'state', 'try_number', 'max_tries', 'task_display_name', 'dag_display_name', 'hostname', 'unixname', 'pool', 'pool_slots', 'queue', 'priority_weight', 'operator', 'queued_when', 'scheduled_when', 'pid', 'executor', 'executor_config', 
'note', 'rendered_map_index', 'trigger', 'triggerer_job', 'dag_version'], + required: ['id', 'task_id', 'dag_id', 'dag_version', 'dag_run_id', 'map_index', 'logical_date', 'run_after', 'start_date', 'end_date', 'duration', 'state', 'try_number', 'max_tries', 'task_display_name', 'dag_display_name', 'hostname', 'unixname', 'pool', 'pool_slots', 'queue', 'priority_weight', 'operator', 'queued_when', 'scheduled_when', 'pid', 'executor', 'executor_config', 'note', 'rendered_map_index', 'trigger', 'triggerer_job'], title: 'TaskInstanceResponse', description: 'TaskInstance serializer for responses.' } as const; diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts index e749ba3a6cb46..1fb875d7e4f42 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts @@ -1260,6 +1260,7 @@ export type TaskInstanceResponse = { id: string; task_id: string; dag_id: string; + dag_version: DagVersionResponse; dag_run_id: string; map_index: number; logical_date: string | null; @@ -1291,7 +1292,6 @@ export type TaskInstanceResponse = { }; trigger: TriggerResponse | null; triggerer_job: JobResponse | null; - dag_version: DagVersionResponse | null; }; /** diff --git a/airflow-core/src/airflow/ui/src/hooks/useSelectedVersion.ts b/airflow-core/src/airflow/ui/src/hooks/useSelectedVersion.ts index fb33fca1719a0..2ed274ef4b54a 100644 --- a/airflow-core/src/airflow/ui/src/hooks/useSelectedVersion.ts +++ b/airflow-core/src/airflow/ui/src/hooks/useSelectedVersion.ts @@ -82,7 +82,7 @@ const useSelectedVersion = (): number | undefined => { const selectedVersionNumber = selectedVersionUrl ?? - mappedTaskInstanceData?.dag_version?.version_number ?? + (mappedTaskInstanceData ? mappedTaskInstanceData.dag_version.version_number : undefined) ?? (runData?.dag_versions ?? []).at(-1)?.version_number ?? 
dagData?.latest_dag_version?.version_number; diff --git a/airflow-core/src/airflow/utils/db.py b/airflow-core/src/airflow/utils/db.py index 5685756262766..8666deac458d4 100644 --- a/airflow-core/src/airflow/utils/db.py +++ b/airflow-core/src/airflow/utils/db.py @@ -93,7 +93,7 @@ class MappedClassProtocol(Protocol): "2.10.3": "5f2621c13b39", "3.0.0": "29ce7909c52b", "3.0.3": "fe199e1abd77", - "3.1.0": "ffdb0566c7c0", + "3.1.0": "5d3072c51bac", } diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py index 9dd472e9a53a2..16ee780f0764c 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py @@ -21,6 +21,7 @@ import itertools import os from datetime import timedelta +from typing import TYPE_CHECKING from unittest import mock import pendulum @@ -33,6 +34,7 @@ from airflow.listeners.listener import get_listener_manager from airflow.models import DagRun, TaskInstance from airflow.models.baseoperator import BaseOperator +from airflow.models.dag_version import DagVersion from airflow.models.dagbag import DagBag from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF from airflow.models.taskinstancehistory import TaskInstanceHistory @@ -110,7 +112,7 @@ def create_task_instances( run_id = "TEST_DAG_RUN_ID" logical_date = self.ti_init.pop("logical_date", self.default_time) dr = None - + dag_version = DagVersion.get_latest_version(dag.dag_id, session=session) tis = [] for i in range(counter): if task_instances is None: @@ -135,7 +137,9 @@ def create_task_instances( ) session.add(dr) session.flush() - ti = TaskInstance(task=tasks[i], **self.ti_init) + if TYPE_CHECKING: + assert dag_version + ti = TaskInstance(task=tasks[i], **self.ti_init, dag_version_id=dag_version.id) session.add(ti) ti.dag_run = dr ti.note 
= "placeholder-note" @@ -178,7 +182,7 @@ def test_should_respond_200(self, test_client, session): assert response.status_code == 200 assert response.json() == { "dag_id": "example_python_operator", - "dag_version": None, + "dag_version": mock.ANY, "dag_display_name": "example_python_operator", "duration": 10000.0, "end_date": "2020-01-03T00:00:00Z", @@ -310,7 +314,7 @@ def test_should_respond_200_with_task_state_in_deferred(self, test_client, sessi assert response.status_code == 200 assert data == { "dag_id": "example_python_operator", - "dag_version": None, + "dag_version": mock.ANY, "dag_display_name": "example_python_operator", "duration": 10000.0, "end_date": "2020-01-03T00:00:00Z", @@ -362,7 +366,7 @@ def test_should_respond_200_with_task_state_in_removed(self, test_client, sessio assert response.status_code == 200 assert response.json() == { "dag_id": "example_python_operator", - "dag_version": None, + "dag_version": mock.ANY, "dag_display_name": "example_python_operator", "duration": 10000.0, "end_date": "2020-01-03T00:00:00Z", @@ -409,7 +413,7 @@ def test_should_respond_200_task_instance_with_rendered(self, test_client, sessi assert response.json() == { "dag_id": "example_python_operator", - "dag_version": None, + "dag_version": mock.ANY, "dag_display_name": "example_python_operator", "duration": 10000.0, "end_date": "2020-01-03T00:00:00Z", @@ -458,7 +462,9 @@ def test_raises_404_for_mapped_task_instance_with_multiple_indexes(self, test_cl old_ti = tis[0] for index in range(3): - ti = TaskInstance(task=old_ti.task, run_id=old_ti.run_id, map_index=index) + ti = TaskInstance( + task=old_ti.task, run_id=old_ti.run_id, map_index=index, dag_version_id=old_ti.dag_version_id + ) for attr in ["duration", "end_date", "pid", "start_date", "state", "queue", "note"]: setattr(ti, attr, getattr(old_ti, attr)) session.add(ti) @@ -476,7 +482,9 @@ def test_raises_404_for_mapped_task_instance_with_one_index(self, test_client, s old_ti = tis[0] - ti = 
TaskInstance(task=old_ti.task, run_id=old_ti.run_id, map_index=2) + ti = TaskInstance( + task=old_ti.task, run_id=old_ti.run_id, map_index=2, dag_version_id=old_ti.dag_version_id + ) for attr in ["duration", "end_date", "pid", "start_date", "state", "queue", "note"]: setattr(ti, attr, getattr(old_ti, attr)) session.add(ti) @@ -496,7 +504,9 @@ def test_should_respond_200_mapped_task_instance_with_rtif(self, test_client, se tis = self.create_task_instances(session) old_ti = tis[0] for idx in (1, 2): - ti = TaskInstance(task=old_ti.task, run_id=old_ti.run_id, map_index=idx) + ti = TaskInstance( + task=old_ti.task, run_id=old_ti.run_id, map_index=idx, dag_version_id=old_ti.dag_version_id + ) ti.rendered_task_instance_fields = RTIF(ti, render_templates=False) for attr in ["duration", "end_date", "pid", "start_date", "state", "queue", "note"]: setattr(ti, attr, getattr(old_ti, attr)) @@ -513,7 +523,7 @@ def test_should_respond_200_mapped_task_instance_with_rtif(self, test_client, se assert response.json() == { "dag_id": "example_python_operator", - "dag_version": None, + "dag_version": mock.ANY, "dag_display_name": "example_python_operator", "duration": 10000.0, "end_date": "2020-01-03T00:00:00Z", @@ -603,7 +613,7 @@ def create_dag_runs_with_mapped_tasks(self, dag_maker, session, dags=None): logical_date=DEFAULT_DATETIME_1, data_interval=(DEFAULT_DATETIME_1, DEFAULT_DATETIME_2), ) - + dag_version = DagVersion.get_latest_version(dag_id) session.add( TaskMap( dag_id=dr.dag_id, @@ -630,7 +640,9 @@ def create_dag_runs_with_mapped_tasks(self, dag_maker, session, dags=None): itertools.repeat(TaskInstanceState.RUNNING, dag["running"]), ) ): - ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index, state=state) + ti = TaskInstance( + mapped, run_id=dr.run_id, map_index=index, state=state, dag_version_id=dag_version.id + ) setattr(ti, "start_date", DEFAULT_DATETIME_1) session.add(ti) @@ -1111,7 +1123,8 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): True, 
("/dags/~/dagRuns/~/taskInstances"), {"version_number": [1, 2, 3]}, - 6, + 7, # apart from the TIs in the fixture, we also get one from + # the create_task_instances method id="test multiple version numbers filter", ), ], @@ -1125,7 +1138,9 @@ def test_should_respond_200( update_extras=update_extras, task_instances=task_instances, ) - response = test_client.get(url, params=params) + with mock.patch("airflow.api_fastapi.core_api.datamodels.dag_versions.DagBundlesManager"): + # Mock DagBundlesManager to avoid checking if dags-folder bundle is configured + response = test_client.get(url, params=params) if params == {"task_id_pattern": "task_match_id"}: import pprint @@ -1359,7 +1374,13 @@ def test_should_respond_dependencies_mapped(self, test_client, session): ) old_ti = tis[0] - ti = TaskInstance(task=old_ti.task, run_id=old_ti.run_id, map_index=0, state=old_ti.state) + ti = TaskInstance( + task=old_ti.task, + run_id=old_ti.run_id, + map_index=0, + state=old_ti.state, + dag_version_id=old_ti.dag_version_id, + ) session.add(ti) session.commit() @@ -1724,7 +1745,7 @@ def test_should_respond_200(self, test_client, session): "try_number": 1, "unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", - "dag_version": None, + "dag_version": mock.ANY, } @pytest.mark.parametrize("try_number", [1, 2]) @@ -1760,7 +1781,7 @@ def test_should_respond_200_with_different_try_numbers(self, test_client, try_nu "try_number": try_number, "unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", - "dag_version": None, + "dag_version": mock.ANY, } @pytest.mark.parametrize("try_number", [1, 2]) @@ -1770,7 +1791,9 @@ def test_should_respond_200_with_mapped_task_at_different_try_numbers( tis = self.create_task_instances(session, task_instances=[{"state": State.FAILED}]) old_ti = tis[0] for idx in (1, 2): - ti = TaskInstance(task=old_ti.task, run_id=old_ti.run_id, map_index=idx) + ti = TaskInstance( + task=old_ti.task, run_id=old_ti.run_id, map_index=idx, dag_version_id=old_ti.dag_version_id + ) 
ti.rendered_task_instance_fields = RTIF(ti, render_templates=False) ti.try_number = 1 for attr in ["duration", "end_date", "pid", "start_date", "state", "queue", "note"]: @@ -1825,7 +1848,7 @@ def test_should_respond_200_with_mapped_task_at_different_try_numbers( "try_number": try_number, "unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", - "dag_version": None, + "dag_version": mock.ANY, } def test_should_respond_200_with_task_state_in_deferred(self, test_client, session): @@ -1888,7 +1911,7 @@ def test_should_respond_200_with_task_state_in_deferred(self, test_client, sessi "try_number": 1, "unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", - "dag_version": None, + "dag_version": mock.ANY, } def test_should_respond_200_with_task_state_in_removed(self, test_client, session): @@ -1925,7 +1948,7 @@ def test_should_respond_200_with_task_state_in_removed(self, test_client, sessio "try_number": 1, "unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", - "dag_version": None, + "dag_version": mock.ANY, } def test_should_respond_401(self, unauthenticated_test_client): @@ -2471,7 +2494,7 @@ def test_should_respond_200_with_dag_run_id( { "dag_id": "example_python_operator", "dag_display_name": "example_python_operator", - "dag_version": None, + "dag_version": mock.ANY, "dag_run_id": "TEST_DAG_RUN_ID_0", "task_id": "print_the_context", "duration": mock.ANY, @@ -2841,7 +2864,7 @@ def test_should_respond_200(self, test_client, session): "try_number": 1, "unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", - "dag_version": None, + "dag_version": mock.ANY, }, { "dag_id": "example_python_operator", @@ -2868,7 +2891,7 @@ def test_should_respond_200(self, test_client, session): "try_number": 2, "unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", - "dag_version": None, + "dag_version": mock.ANY, }, ], "total_entries": 2, @@ -2928,7 +2951,7 @@ def test_ti_in_retry_state_not_returned(self, test_client, session): "try_number": 1, "unixname": getuser(), "dag_run_id": 
"TEST_DAG_RUN_ID", - "dag_version": None, + "dag_version": mock.ANY, }, ], "total_entries": 1, @@ -2938,7 +2961,9 @@ def test_mapped_task_should_respond_200(self, test_client, session): tis = self.create_task_instances(session, task_instances=[{"state": State.FAILED}]) old_ti = tis[0] for idx in (1, 2): - ti = TaskInstance(task=old_ti.task, run_id=old_ti.run_id, map_index=idx) + ti = TaskInstance( + task=old_ti.task, run_id=old_ti.run_id, map_index=idx, dag_version_id=old_ti.dag_version_id + ) for attr in ["duration", "end_date", "pid", "start_date", "state", "queue"]: setattr(ti, attr, getattr(old_ti, attr)) ti.try_number = 1 @@ -2997,7 +3022,7 @@ def test_mapped_task_should_respond_200(self, test_client, session): "try_number": 1, "unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", - "dag_version": None, + "dag_version": mock.ANY, }, { "dag_id": "example_python_operator", @@ -3024,7 +3049,7 @@ def test_mapped_task_should_respond_200(self, test_client, session): "try_number": 2, "unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", - "dag_version": None, + "dag_version": mock.ANY, }, ], "total_entries": 2, @@ -3160,7 +3185,7 @@ def test_should_call_mocked_api(self, mock_set_ti_state, test_client, session): { "dag_id": self.DAG_ID, "dag_display_name": self.DAG_DISPLAY_NAME, - "dag_version": None, + "dag_version": mock.ANY, "dag_run_id": self.RUN_ID, "logical_date": "2020-01-01T00:00:00Z", "task_id": self.TASK_ID, @@ -3227,7 +3252,9 @@ def test_should_update_task_instance_state(self, test_client, session): def test_should_update_mapped_task_instance_state(self, test_client, session): map_index = 1 tis = self.create_task_instances(session) - ti = TaskInstance(task=tis[0].task, run_id=tis[0].run_id, map_index=map_index) + ti = TaskInstance( + task=tis[0].task, run_id=tis[0].run_id, map_index=map_index, dag_version_id=tis[0].dag_version_id + ) ti.rendered_task_instance_fields = RTIF(ti, render_templates=False) session.add(ti) session.commit() @@ -3248,7 +3275,12 
@@ def test_should_update_mapped_task_instance_summary_state(self, test_client, ses tis = self.create_task_instances(session) for map_index in [1, 2, 3]: - ti = TaskInstance(task=tis[0].task, run_id=tis[0].run_id, map_index=map_index) + ti = TaskInstance( + task=tis[0].task, + run_id=tis[0].run_id, + map_index=map_index, + dag_version_id=tis[0].dag_version_id, + ) ti.rendered_task_instance_fields = RTIF(ti, render_templates=False) session.add(ti) tis[0].map_index = 0 @@ -3404,7 +3436,7 @@ def test_should_raise_422_for_invalid_task_instance_state(self, payload, expecte { "dag_id": "example_python_operator", "dag_display_name": "example_python_operator", - "dag_version": None, + "dag_version": mock.ANY, "dag_run_id": "TEST_DAG_RUN_ID", "logical_date": "2020-01-01T00:00:00Z", "task_id": "print_the_context", @@ -3521,7 +3553,7 @@ def test_update_mask_set_note_should_respond_200( { "dag_id": self.DAG_ID, "dag_display_name": self.DAG_DISPLAY_NAME, - "dag_version": None, + "dag_version": mock.ANY, "duration": 10000.0, "end_date": "2020-01-03T00:00:00Z", "logical_date": "2020-01-01T00:00:00Z", @@ -3572,7 +3604,7 @@ def test_set_note_should_respond_200(self, test_client, session): { "dag_id": self.DAG_ID, "dag_display_name": self.DAG_DISPLAY_NAME, - "dag_version": None, + "dag_version": mock.ANY, "duration": 10000.0, "end_date": "2020-01-03T00:00:00Z", "logical_date": "2020-01-01T00:00:00Z", @@ -3617,7 +3649,9 @@ def test_set_note_should_respond_200_mapped_task_with_rtif(self, test_client, se tis = self.create_task_instances(session) old_ti = tis[0] for idx in (1, 2): - ti = TaskInstance(task=old_ti.task, run_id=old_ti.run_id, map_index=idx) + ti = TaskInstance( + task=old_ti.task, run_id=old_ti.run_id, map_index=idx, dag_version_id=old_ti.dag_version_id + ) ti.rendered_task_instance_fields = RTIF(ti, render_templates=False) for attr in ["duration", "end_date", "pid", "start_date", "state", "queue", "note"]: setattr(ti, attr, getattr(old_ti, attr)) @@ -3640,7 +3674,7 @@ def 
test_set_note_should_respond_200_mapped_task_with_rtif(self, test_client, se { "dag_id": self.DAG_ID, "dag_display_name": self.DAG_DISPLAY_NAME, - "dag_version": None, + "dag_version": mock.ANY, "duration": 10000.0, "end_date": "2020-01-03T00:00:00Z", "logical_date": "2020-01-01T00:00:00Z", @@ -3687,7 +3721,9 @@ def test_set_note_should_respond_200_mapped_task_summary_with_rtif(self, test_cl tis = self.create_task_instances(session) old_ti = tis[0] for idx in (1, 2): - ti = TaskInstance(task=old_ti.task, run_id=old_ti.run_id, map_index=idx) + ti = TaskInstance( + task=old_ti.task, run_id=old_ti.run_id, map_index=idx, dag_version_id=old_ti.dag_version_id + ) ti.rendered_task_instance_fields = RTIF(ti, render_templates=False) for attr in ["duration", "end_date", "pid", "start_date", "state", "queue", "note"]: setattr(ti, attr, getattr(old_ti, attr)) @@ -3709,7 +3745,7 @@ def test_set_note_should_respond_200_mapped_task_summary_with_rtif(self, test_cl assert response_ti == { "dag_id": self.DAG_ID, "dag_display_name": self.DAG_DISPLAY_NAME, - "dag_version": None, + "dag_version": mock.ANY, "duration": 10000.0, "end_date": "2020-01-03T00:00:00Z", "logical_date": "2020-01-01T00:00:00Z", @@ -3814,7 +3850,7 @@ def test_should_call_mocked_api(self, mock_set_ti_state, test_client, session): { "dag_id": self.DAG_ID, "dag_display_name": self.DAG_DISPLAY_NAME, - "dag_version": None, + "dag_version": mock.ANY, "dag_run_id": self.RUN_ID, "logical_date": "2020-01-01T00:00:00Z", "task_id": self.TASK_ID, @@ -3914,7 +3950,9 @@ def test_should_respond_403(self, unauthorized_test_client): def test_should_not_update_mapped_task_instance(self, test_client, session): map_index = 1 tis = self.create_task_instances(session) - ti = TaskInstance(task=tis[0].task, run_id=tis[0].run_id, map_index=map_index) + ti = TaskInstance( + task=tis[0].task, run_id=tis[0].run_id, map_index=map_index, dag_version_id=tis[0].dag_version_id + ) ti.rendered_task_instance_fields = RTIF(ti, 
render_templates=False) session.add(ti) session.commit() @@ -3945,6 +3983,7 @@ def test_should_not_update_mapped_task_instance_summary(self, test_client, sessi run_id=tis[0].run_id, map_index=map_index, state="running", + dag_version_id=tis[0].dag_version_id, ) ti.rendered_task_instance_fields = RTIF(ti, render_templates=False) session.add(ti) @@ -4087,7 +4126,7 @@ def test_should_raise_422_for_invalid_task_instance_state(self, payload, expecte { "dag_id": "example_python_operator", "dag_display_name": "example_python_operator", - "dag_version": None, + "dag_version": mock.ANY, "dag_run_id": "TEST_DAG_RUN_ID", "logical_date": "2020-01-01T00:00:00Z", "task_id": "print_the_context", diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py index 45527901be303..5594f89f6f014 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py @@ -21,9 +21,11 @@ import pytest +from airflow import DAG from airflow.api_fastapi.core_api.datamodels.xcom import XComCreateBody -from airflow.models.dag import DagModel +from airflow.models.dag_version import DagVersion from airflow.models.dagrun import DagRun +from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance from airflow.models.xcom import XComModel from airflow.providers.standard.operators.empty import EmptyOperator @@ -381,8 +383,9 @@ def test_should_respond_200_with_xcom_key(self, key, expected_entries, test_clie @provide_session def _create_xcom_entries(self, dag_id, run_id, logical_date, task_id, mapped_ti=False, session=None): - dag = DagModel(dag_id=dag_id) - session.add(dag) + dag = DAG(dag_id=dag_id) + dag.sync_to_db(session=session) + SerializedDagModel.write_dag(dag, bundle_name="testing") dagrun = DagRun( dag_id=dag_id, run_id=run_id, @@ -391,13 +394,16 @@ def 
_create_xcom_entries(self, dag_id, run_id, logical_date, task_id, mapped_ti= run_type=DagRunType.MANUAL, ) session.add(dagrun) + dag_version = DagVersion.get_latest_version(dag.dag_id) if mapped_ti: for i in [0, 1]: - ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id, map_index=i) + ti = TaskInstance( + EmptyOperator(task_id=task_id), run_id=run_id, map_index=i, dag_version_id=dag_version.id + ) ti.dag_id = dag_id session.add(ti) else: - ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id) + ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id, dag_version_id=dag_version.id) ti.dag_id = dag_id session.add(ti) session.commit() diff --git a/airflow-core/tests/unit/callbacks/test_callback_requests.py b/airflow-core/tests/unit/callbacks/test_callback_requests.py index 37a7a3023d872..fab145ee8d7a9 100644 --- a/airflow-core/tests/unit/callbacks/test_callback_requests.py +++ b/airflow-core/tests/unit/callbacks/test_callback_requests.py @@ -16,6 +16,7 @@ # under the License. 
from __future__ import annotations +import uuid from datetime import datetime import pytest @@ -65,6 +66,7 @@ def test_from_json(self, input, request_class): ), run_id="fake_run", state=State.RUNNING, + dag_version_id=uuid.uuid4(), ) ti.start_date = timezone.utcnow() diff --git a/airflow-core/tests/unit/cli/commands/test_task_command.py b/airflow-core/tests/unit/cli/commands/test_task_command.py index faad02503e407..5ce864667027d 100644 --- a/airflow-core/tests/unit/cli/commands/test_task_command.py +++ b/airflow-core/tests/unit/cli/commands/test_task_command.py @@ -38,6 +38,8 @@ from airflow.configuration import conf from airflow.exceptions import DagRunNotFound from airflow.models import DagBag, DagRun, TaskInstance +from airflow.models.dag_version import DagVersion +from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.standard.operators.bash import BashOperator from airflow.serialization.serialized_objects import SerializedDAG from airflow.utils import timezone @@ -275,6 +277,9 @@ def test_mapped_task_render(self): """ tasks render should render and displays templated fields for a given mapping task """ + dag = DagBag().get_dag("test_mapped_classic") + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") with redirect_stdout(io.StringIO()) as stdout: task_command.task_render( self.parser.parse_args( @@ -341,6 +346,8 @@ def test_task_state(self): def test_task_states_for_dag_run(self): dag2 = DagBag().dags["example_python_operator"] + + SerializedDagModel.write_dag(dag2, bundle_name="testing") task2 = dag2.get_task(task_id="print_the_context") dag2 = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag2)) @@ -357,7 +364,8 @@ def test_task_states_for_dag_run(self): run_type=DagRunType.MANUAL, triggered_by=DagRunTriggeredByType.CLI, ) - ti2 = TaskInstance(task2, run_id=dagrun.run_id) + dag_version = DagVersion.get_latest_version(dag2.dag_id) + ti2 = TaskInstance(task2, run_id=dagrun.run_id, 
dag_version_id=dag_version.id) ti2.set_state(State.SUCCESS) ti_start = ti2.start_date ti_end = ti2.end_date diff --git a/airflow-core/tests/unit/executors/test_base_executor.py b/airflow-core/tests/unit/executors/test_base_executor.py index 575201b2f4498..d99fde3abc8eb 100644 --- a/airflow-core/tests/unit/executors/test_base_executor.py +++ b/airflow-core/tests/unit/executors/test_base_executor.py @@ -57,7 +57,7 @@ def test_invalid_slotspool(): def test_get_task_log(): executor = BaseExecutor() - ti = TaskInstance(task=BaseOperator(task_id="dummy")) + ti = TaskInstance(task=BaseOperator(task_id="dummy"), dag_version_id=mock.MagicMock()) assert executor.get_task_log(ti=ti, try_number=1) == ([], []) diff --git a/airflow-core/tests/unit/executors/test_local_executor.py b/airflow-core/tests/unit/executors/test_local_executor.py index eb4b9528d1c1a..e5b37bb268d3c 100644 --- a/airflow-core/tests/unit/executors/test_local_executor.py +++ b/airflow-core/tests/unit/executors/test_local_executor.py @@ -59,6 +59,7 @@ def _test_execute(self, mock_supervise, parallelism=1): success_tis = [ workloads.TaskInstance( id=uuid7(), + dag_version_id=uuid7(), task_id=f"success_{i}", dag_id="mydag", run_id="run1", diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index 456bac3748821..81df9ec2549b9 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -573,8 +573,9 @@ def test_execute_task_instances_backfill_tasks_will_execute(self, dag_maker): session = settings.Session() dr1 = dag_maker.create_dagrun(run_type=DagRunType.BACKFILL_JOB) + dag_version = DagVersion.get_latest_version(dr1.dag_id) - ti1 = TaskInstance(task1, run_id=dr1.run_id) + ti1 = TaskInstance(task1, run_id=dr1.run_id, dag_version_id=dag_version.id) ti1.refresh_from_db() ti1.state = State.SCHEDULED session.merge(ti1) @@ -5511,7 +5512,8 @@ def test_no_dagruns_would_stuck_in_running(self, 
dag_maker): scheduler_job = Job(executor=MockExecutor(do_update=False)) self.job_runner = SchedulerJobRunner(job=scheduler_job) - ti = TaskInstance(task=task1, run_id=dr1_running.run_id) + dag_version = DagVersion.get_latest_version(dag_id=dag.dag_id) + ti = TaskInstance(task=task1, run_id=dr1_running.run_id, dag_version_id=dag_version.id) ti.refresh_from_db() ti.state = State.SUCCESS session.merge(ti) @@ -6011,10 +6013,10 @@ def test_task_instance_heartbeat_timeout_message(self, session, create_dagrun): # We will provision 2 tasks so we can check we only find task instance heartbeat timeouts from this scheduler tasks_to_setup = ["branching", "run_this_first"] - + dag_version = DagVersion.get_latest_version(dag.dag_id) for task_id in tasks_to_setup: task = dag.get_task(task_id=task_id) - ti = TaskInstance(task, run_id=dag_run.run_id, state=State.RUNNING) + ti = TaskInstance(task, run_id=dag_run.run_id, state=State.RUNNING, dag_version_id=dag_version.id) ti.queued_by_job_id = 999 session.add(ti) diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index d859e9ed70bbc..0fdb1c5fb0523 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -43,6 +43,8 @@ from airflow.models.baseoperator import BaseOperator from airflow.models.connection import Connection from airflow.models.dag import DAG +from airflow.models.dag_version import DagVersion +from airflow.models.serialized_dag import SerializedDagModel from airflow.models.variable import Variable from airflow.models.xcom import XComModel from airflow.providers.standard.operators.empty import EmptyOperator @@ -103,10 +105,12 @@ def create_trigger_in_db(session, trigger, operator=None): else: operator = BaseOperator(task_id="test_ti", dag=dag) session.add(dag_model) + SerializedDagModel.write_dag(dag, bundle_name="testing") session.add(run) session.add(trigger_orm) session.flush() - 
task_instance = TaskInstance(operator, run_id=run.run_id) + dag_version = DagVersion.get_latest_version(dag.dag_id) + task_instance = TaskInstance(operator, run_id=run.run_id, dag_version_id=dag_version.id) task_instance.trigger_id = trigger_orm.id session.add(task_instance) session.commit() @@ -396,17 +400,20 @@ async def test_trigger_create_race_condition_38599(session, supervisor_builder): trigger_orm = Trigger.from_object(trigger) session.add(trigger_orm) session.flush() - - dag = DagModel(dag_id="test-dag") + dag = DAG(dag_id="test-dag") + dm = DagModel(dag_id="test-dag") + session.add(dm) + SerializedDagModel.write_dag(dag, bundle_name="testing") dag_run = DagRun(dag.dag_id, run_id="abc", run_type="none", run_after=timezone.utcnow()) + dag_version = DagVersion.get_latest_version(dag.dag_id) ti = TaskInstance( PythonOperator(task_id="dummy-task", python_callable=print), run_id=dag_run.run_id, state=TaskInstanceState.DEFERRED, + dag_version_id=dag_version.id, ) ti.dag_id = dag.dag_id ti.trigger_id = trigger_orm.id - session.add(dag) session.add(dag_run) session.add(ti) diff --git a/airflow-core/tests/unit/models/test_baseoperator.py b/airflow-core/tests/unit/models/test_baseoperator.py index b4f7197583ec5..501fe93c70b49 100644 --- a/airflow-core/tests/unit/models/test_baseoperator.py +++ b/airflow-core/tests/unit/models/test_baseoperator.py @@ -28,7 +28,9 @@ BaseOperator, ) from airflow.models.dag import DAG +from airflow.models.dag_version import DagVersion from airflow.models.dagrun import DagRun +from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance from airflow.models.trigger import TriggerFailureReason from airflow.providers.common.sql.operators import sql @@ -228,17 +230,20 @@ def test_get_task_instances(session): test_dag = DAG(dag_id="test_dag", schedule=None, start_date=first_logical_date) task = BaseOperator(task_id="test_task", dag=test_dag) + test_dag.sync_to_db() + 
SerializedDagModel.write_dag(test_dag, bundle_name="testing") + dag_version = DagVersion.get_latest_version(test_dag.dag_id) common_dr_kwargs = { "dag_id": test_dag.dag_id, "run_type": DagRunType.MANUAL, } dr1 = DagRun(logical_date=first_logical_date, run_id="test_run_id_1", **common_dr_kwargs) - ti_1 = TaskInstance(run_id=dr1.run_id, task=task) + ti_1 = TaskInstance(run_id=dr1.run_id, task=task, dag_version_id=dag_version.id) dr2 = DagRun(logical_date=second_logical_date, run_id="test_run_id_2", **common_dr_kwargs) - ti_2 = TaskInstance(run_id=dr2.run_id, task=task) + ti_2 = TaskInstance(run_id=dr2.run_id, task=task, dag_version_id=dag_version.id) dr3 = DagRun(logical_date=third_logical_date, run_id="test_run_id_3", **common_dr_kwargs) - ti_3 = TaskInstance(run_id=dr3.run_id, task=task) + ti_3 = TaskInstance(run_id=dr3.run_id, task=task, dag_version_id=dag_version.id) session.add_all([dr1, dr2, dr3, ti_1, ti_2, ti_3]) session.commit() diff --git a/airflow-core/tests/unit/models/test_cleartasks.py b/airflow-core/tests/unit/models/test_cleartasks.py index 832cc141a29aa..801862e8ad865 100644 --- a/airflow-core/tests/unit/models/test_cleartasks.py +++ b/airflow-core/tests/unit/models/test_cleartasks.py @@ -44,10 +44,12 @@ class TestClearTasks: @pytest.fixture(autouse=True, scope="class") def clean(self): db.clear_db_runs() + db.clear_db_serialized_dags() yield db.clear_db_runs() + db.clear_db_serialized_dags() def test_clear_task_instances(self, dag_maker): # Explicitly needs catchup as True as test is creating history runs @@ -368,9 +370,6 @@ def test_clear_task_instances_without_dag_param(self, dag_maker, session): task0 = EmptyOperator(task_id="task0") task1 = EmptyOperator(task_id="task1", retries=2) - # Write DAG to the database so it can be found by clear_task_instances(). 
- SerializedDagModel.write_dag(dag, bundle_name="testing", session=session) - dr = dag_maker.create_dagrun( state=State.RUNNING, run_type=DagRunType.SCHEDULED, diff --git a/airflow-core/tests/unit/models/test_dag.py b/airflow-core/tests/unit/models/test_dag.py index a173d7ec50f17..055302e4f55c5 100644 --- a/airflow-core/tests/unit/models/test_dag.py +++ b/airflow-core/tests/unit/models/test_dag.py @@ -57,6 +57,7 @@ ExecutorLoader, get_asset_triggered_next_run_info, ) +from airflow.models.dag_version import DagVersion from airflow.models.dagrun import DagRun from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance as TI @@ -148,6 +149,8 @@ def _create_dagrun( start_date: datetime.datetime | None = None, **kwargs, ) -> DagRun: + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") logical_date = timezone.coerce_datetime(logical_date) if not isinstance(data_interval, DataInterval): data_interval = DataInterval(*map(timezone.coerce_datetime, data_interval)) @@ -239,6 +242,10 @@ def test_get_num_task_instances(self): test_dag = DAG(dag_id=test_dag_id, schedule=None, start_date=DEFAULT_DATE) test_task = EmptyOperator(task_id=test_task_id, dag=test_dag) + test_dag.sync_to_db() + SerializedDagModel.write_dag(test_dag, bundle_name="testing") + dag_version = DagVersion.get_latest_version(test_dag_id) + dag_version_id = dag_version.id dr1 = _create_dagrun( test_dag, @@ -273,17 +280,16 @@ def test_get_num_task_instances(self): DEFAULT_DATE + datetime.timedelta(days=2), ), ) - - ti1 = TI(task=test_task, run_id=dr1.run_id) + ti1 = TI(task=test_task, run_id=dr1.run_id, dag_version_id=dag_version_id) ti1.refresh_from_db() ti1.state = None - ti2 = TI(task=test_task, run_id=dr2.run_id) + ti2 = TI(task=test_task, run_id=dr2.run_id, dag_version_id=dag_version_id) ti2.refresh_from_db() ti2.state = State.RUNNING - ti3 = TI(task=test_task, run_id=dr3.run_id) + ti3 = TI(task=test_task, run_id=dr3.run_id, 
dag_version_id=dag_version_id) ti3.refresh_from_db() ti3.state = State.QUEUED - ti4 = TI(task=test_task, run_id=dr4.run_id) + ti4 = TI(task=test_task, run_id=dr4.run_id, dag_version_id=dag_version_id) ti4.refresh_from_db() ti4.state = State.RUNNING session = settings.Session() @@ -333,6 +339,8 @@ def test_get_task_instances_before(self): test_dag = DAG(dag_id=test_dag_id, schedule=None, start_date=BASE_DATE) EmptyOperator(task_id=test_task_id, dag=test_dag) + test_dag.sync_to_db() + SerializedDagModel.write_dag(test_dag, bundle_name="testing") session = settings.Session() @@ -540,6 +548,8 @@ def test_create_dagrun_when_schedule_is_none_and_empty_start_date(self): # Check that we don't get an AttributeError 'start_date' for self.start_date when schedule is none dag = DAG("dag_with_none_schedule_and_empty_start_date", schedule=None) dag.add_task(BaseOperator(task_id="task_without_start_date")) + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") dagrun = dag.create_dagrun( run_id="test", state=State.RUNNING, @@ -738,10 +748,10 @@ def test_bulk_write_to_db_max_active_runs(self, testing_dag_bundle, state, catch dag.max_active_runs = 1 EmptyOperator(task_id="dummy", dag=dag, owner="airflow") - session = settings.Session() dag.clear() DAG.bulk_write_to_db("testing", None, [dag], session=session) + SerializedDagModel.write_dag(dag, bundle_name="testing") model = session.get(DagModel, dag.dag_id) @@ -1011,7 +1021,8 @@ def test_schedule_dag_no_previous_runs(self): dag_id = "test_schedule_dag_no_previous_runs" dag = DAG(dag_id=dag_id, schedule=None) dag.add_task(BaseOperator(task_id="faketastic", owner="Also fake", start_date=TEST_DATE)) - + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") dag_run = dag.create_dagrun( run_id="test", run_type=DagRunType.SCHEDULED, @@ -1048,6 +1059,8 @@ def test_dag_handle_callback_crash(self, mock_stats): ) when = TEST_DATE dag.add_task(BaseOperator(task_id="faketastic", owner="Also fake", 
start_date=when)) + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") with create_session() as session: dag_run = dag.create_dagrun( @@ -1084,7 +1097,8 @@ def test_dag_handle_callback_with_removed_task(self, dag_maker, session): ) as dag: EmptyOperator(task_id="faketastic") task_removed = EmptyOperator(task_id="removed_task") - + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") with create_session() as session: dag_run = dag.create_dagrun( run_id="test", @@ -1314,6 +1328,8 @@ def test_description_from_timetable(self, timetable, expected_description): def test_create_dagrun_job_id_is_set(self): job_id = 42 dag = DAG(dag_id="test_create_dagrun_job_id_is_set", schedule=None) + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") dr = dag.create_dagrun( run_id="test_create_dagrun_job_id_is_set", logical_date=DEFAULT_DATE, @@ -1406,13 +1422,12 @@ def test_clear_set_dagrun_state(self, dag_run_state, dag_maker, session): assert dr.state == dag_run_state @pytest.mark.parametrize("dag_run_state", [DagRunState.QUEUED, DagRunState.RUNNING]) - @pytest.mark.need_serialized_dag - def test_clear_set_dagrun_state_for_mapped_task(self, dag_maker, dag_run_state): + def test_clear_set_dagrun_state_for_mapped_task(self, session, dag_run_state): dag_id = "test_clear_set_dagrun_state" task_id = "t1" - with dag_maker(dag_id, schedule=None, start_date=DEFAULT_DATE, max_active_runs=1) as dag: + with DAG(dag_id, schedule=None, start_date=DEFAULT_DATE, max_active_runs=1) as dag: @task_decorator def make_arg_lists(): @@ -1423,7 +1438,6 @@ def consumer(value): PythonOperator.partial(task_id=task_id, python_callable=consumer).expand(op_args=make_arg_lists()) - session = dag_maker.session dagrun_1 = _create_dagrun( dag, run_type=DagRunType.BACKFILL_JOB, @@ -1466,6 +1480,8 @@ def consumer(value): def test_dag_test_basic(self): dag = DAG(dag_id="test_local_testing_conn_file", schedule=None, start_date=DEFAULT_DATE) + 
dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") mock_object = mock.MagicMock() @task_decorator @@ -1481,6 +1497,8 @@ def check_task(): def test_dag_test_with_dependencies(self): dag = DAG(dag_id="test_local_testing_conn_file", schedule=None, start_date=DEFAULT_DATE) + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") mock_object = mock.MagicMock() @task_decorator @@ -1520,6 +1538,8 @@ def handle_dag_failure(context): mock_task_object_1 = mock.MagicMock() mock_task_object_2 = mock.MagicMock() + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") @task_decorator def check_task(): @@ -1557,6 +1577,8 @@ def test_dag_connection_file(self, tmp_path): conn_type: postgres """ dag = DAG(dag_id="test_local_testing_conn_file", schedule=None, start_date=DEFAULT_DATE) + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") @task_decorator def check_task(): @@ -1958,6 +1980,8 @@ def test_validate_executor_field(self): def test_validate_params_on_trigger_dag(self): dag = DAG("dummy-dag", schedule=None, params={"param1": Param(type="string")}) + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") with pytest.raises(ParamValidationError, match="No value passed and Param has no default value"): dag.create_dagrun( run_id="test_dagrun_missing_param", @@ -2464,6 +2488,8 @@ def teardown_method(self) -> None: @pytest.mark.parametrize("tasks_count", [3, 12]) def test_count_number_queries(self, tasks_count): dag = DAG("test_dagrun_query_count", schedule=None, start_date=DEFAULT_DATE) + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") for i in range(tasks_count): EmptyOperator(task_id=f"dummy_task_{i}", owner="test", dag=dag) with assert_queries_count(5): diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py index ca0ca9225649e..dbcdf9382c2ce 100644 --- 
a/airflow-core/tests/unit/models/test_dagrun.py +++ b/airflow-core/tests/unit/models/test_dagrun.py @@ -116,6 +116,7 @@ def create_dag_run( else: run_type = DagRunType.MANUAL data_interval = dag.timetable.infer_manual_data_interval(run_after=logical_date) + dag_run = dag.create_dagrun( run_id=dag.timetable.generate_run_id( run_type=run_type, @@ -455,7 +456,7 @@ def test_on_success_callback_when_task_skipped(self, session): _ = EmptyOperator(task_id="test_state_succeeded1", dag=dag) dag.sync_to_db() - SerializedDagModel.write_dag(dag, bundle_name="testing", session=session) + SerializedDagModel.write_dag(dag, bundle_name="testing") initial_task_states = { "test_state_succeeded1": TaskInstanceState.SKIPPED, @@ -659,7 +660,7 @@ def on_success_callable(context): # Scheduler uses Serialized DAG -- so use that instead of the Actual DAG dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) dag.relative_fileloc = relative_fileloc - SerializedDagModel.write_dag(dag, bundle_name="testing", session=session) + SerializedDagModel.write_dag(dag, bundle_name="testing") session.commit() dag_run = self.create_dag_run(dag=dag, task_states=initial_task_states, session=session) @@ -708,7 +709,7 @@ def on_failure_callable(context): # Scheduler uses Serialized DAG -- so use that instead of the Actual DAG dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) dag.relative_fileloc = relative_fileloc - SerializedDagModel.write_dag(dag, bundle_name="testing", session=session) + SerializedDagModel.write_dag(dag, bundle_name="testing") session.commit() dag_run = self.create_dag_run(dag=dag, task_states=initial_task_states, session=session) @@ -968,11 +969,11 @@ def test_depends_on_past(self, dag_maker, session, prev_ti_state, is_ti_schedula run_type=DagRunType.SCHEDULED, ) - prev_ti = TI(task, run_id=dag_run_1.run_id) + prev_ti = TI(task, run_id=dag_run_1.run_id, dag_version_id=dag_run_1.created_dag_version_id) prev_ti.refresh_from_db(session=session) prev_ti.set_state(prev_ti_state, 
session=session) session.flush() - ti = TI(task, run_id=dag_run_2.run_id) + ti = TI(task, run_id=dag_run_2.run_id, dag_version_id=dag_run_1.created_dag_version_id) ti.refresh_from_db(session=session) decision = dag_run_2.task_instance_scheduling_decisions(session=session) @@ -1045,7 +1046,7 @@ def test_next_dagruns_to_examine_only_unpaused(self, session, state): ) session.add(orm_dag) session.flush() - SerializedDagModel.write_dag(dag, bundle_name="testing", session=session) + SerializedDagModel.write_dag(dag, bundle_name="testing") dr = dag.create_dagrun( run_id=dag.timetable.generate_run_id( run_type=DagRunType.SCHEDULED, @@ -1071,7 +1072,8 @@ def test_next_dagruns_to_examine_only_unpaused(self, session, state): assert runs == [dr] orm_dag.is_paused = True - session.flush() + session.merge(orm_dag) + session.commit() runs = func(session).all() assert runs == [] @@ -1086,7 +1088,7 @@ def test_no_scheduling_delay_for_nonscheduled_runs(self, stats_mock, session): dag_task = EmptyOperator(task_id="dummy", dag=dag) dag.sync_to_db(session=session) - SerializedDagModel.write_dag(dag, bundle_name="testing", session=session) + SerializedDagModel.write_dag(dag, bundle_name="testing") initial_task_states = { dag_task.task_id: TaskInstanceState.SUCCESS, @@ -1127,7 +1129,7 @@ def test_emit_scheduling_delay(self, session, schedule, expected): orm_dag = DagModel(**orm_dag_kwargs) session.add(orm_dag) session.flush() - SerializedDagModel.write_dag(dag, bundle_name="testing", session=session) + SerializedDagModel.write_dag(dag, bundle_name="testing") dag_run = dag.create_dagrun( run_id=dag.timetable.generate_run_id( run_type=DagRunType.SCHEDULED, @@ -1647,7 +1649,7 @@ def task_2(arg2): ... 
assert len(decision.schedulable_tis) == 2 # We insert a faulty record - session.add(TaskInstance(task=dag.get_task("task_2"), run_id=dr.run_id)) + session.add(TaskInstance(task=dag.get_task("task_2"), run_id=dr.run_id, dag_version_id=ti.dag_version_id)) session.flush() decision = dr.task_instance_scheduling_decisions() @@ -1946,9 +1948,22 @@ def test_schedule_tis_map_index(dag_maker, session): task = BaseOperator(task_id="task_1") dr = DagRun(dag_id="test", run_id="test", run_type=DagRunType.MANUAL) - ti0 = TI(task=task, run_id=dr.run_id, map_index=0, state=TaskInstanceState.SUCCESS) - ti1 = TI(task=task, run_id=dr.run_id, map_index=1, state=None) - ti2 = TI(task=task, run_id=dr.run_id, map_index=2, state=TaskInstanceState.SUCCESS) + dag_version = DagVersion.get_latest_version(dag_id=dr.dag_id) + ti0 = TI( + task=task, + run_id=dr.run_id, + map_index=0, + state=TaskInstanceState.SUCCESS, + dag_version_id=dag_version.id, + ) + ti1 = TI(task=task, run_id=dr.run_id, map_index=1, state=None, dag_version_id=dag_version.id) + ti2 = TI( + task=task, + run_id=dr.run_id, + map_index=2, + state=TaskInstanceState.SUCCESS, + dag_version_id=dag_version.id, + ) session.add_all((dr, ti0, ti1, ti2)) session.flush() @@ -1987,8 +2002,8 @@ def execute_complete(self): task = TestOperator(task_id="test_task") dr: DagRun = dag_maker.create_dagrun() - - ti = TI(task=task, run_id=dr.run_id, state=None) + dag_version = DagVersion.get_latest_version(dag_id=dr.dag_id) + ti = TI(task=task, run_id=dr.run_id, state=None, dag_version_id=dag_version.id) assert ti.state is None dr.schedule_tis((ti,), session=session) assert ti.state == TaskInstanceState.DEFERRED @@ -2149,7 +2164,7 @@ def do_something_else(i): ti.map_index = 0 task = ti.task for map_index in range(1, 5): - ti = TI(task, run_id=dr.run_id, map_index=map_index) + ti = TI(task, run_id=dr.run_id, map_index=map_index, dag_version_id=ti.dag_version_id) session.add(ti) ti.dag_run = dr session.flush() diff --git 
a/airflow-core/tests/unit/models/test_mappedoperator.py b/airflow-core/tests/unit/models/test_mappedoperator.py index b116082111a47..2922ee1c23e50 100644 --- a/airflow-core/tests/unit/models/test_mappedoperator.py +++ b/airflow-core/tests/unit/models/test_mappedoperator.py @@ -28,6 +28,8 @@ from airflow.exceptions import AirflowSkipException from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG +from airflow.models.dag_version import DagVersion +from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap import TaskMap from airflow.providers.standard.operators.python import PythonOperator @@ -67,6 +69,8 @@ def execute(self, context: Context): unrenderable_values = [UnrenderableClass(), UnrenderableClass()] mapped = CustomOperator.partial(task_id="task_2").expand(arg=unrenderable_values) task1 >> mapped + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") dag.test() assert ( "Unable to check if the value of type 'UnrenderableClass' is False for task 'task_2', field 'arg'" @@ -124,9 +128,17 @@ def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, expec TaskInstance.run_id == dr.run_id, ).delete() + dag_version = DagVersion.get_latest_version(dr.dag_id) + for index in range(num_existing_tis): # Give the existing TIs a state to make sure we don't change them - ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index, state=TaskInstanceState.SUCCESS) + ti = TaskInstance( + mapped, + run_id=dr.run_id, + map_index=index, + state=TaskInstanceState.SUCCESS, + dag_version_id=dag_version.id, + ) session.add(ti) session.flush() @@ -165,10 +177,16 @@ def test_expand_mapped_task_failed_state_in_db(dag_maker, session): keys=None, ) ) - + dag_version = DagVersion.get_latest_version(dr.dag_id) for index in range(2): # Give the existing TIs a state to make sure we don't change them - ti = TaskInstance(mapped, 
run_id=dr.run_id, map_index=index, state=TaskInstanceState.SUCCESS) + ti = TaskInstance( + mapped, + run_id=dr.run_id, + map_index=index, + state=TaskInstanceState.SUCCESS, + dag_version_id=dag_version.id, + ) session.add(ti) session.flush() @@ -260,10 +278,17 @@ def test_expand_kwargs_mapped_task_instance(dag_maker, session, num_existing_tis TaskInstance.task_id == mapped.task_id, TaskInstance.run_id == dr.run_id, ).delete() + dag_version = DagVersion.get_latest_version(dr.dag_id) for index in range(num_existing_tis): # Give the existing TIs a state to make sure we don't change them - ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index, state=TaskInstanceState.SUCCESS) + ti = TaskInstance( + mapped, + run_id=dr.run_id, + map_index=index, + state=TaskInstanceState.SUCCESS, + dag_version_id=dag_version.id, + ) session.add(ti) session.flush() diff --git a/airflow-core/tests/unit/models/test_pool.py b/airflow-core/tests/unit/models/test_pool.py index 587f0fd60cba1..ad0c935d618ca 100644 --- a/airflow-core/tests/unit/models/test_pool.py +++ b/airflow-core/tests/unit/models/test_pool.py @@ -21,6 +21,7 @@ from airflow import settings from airflow.exceptions import AirflowException, PoolNotFound +from airflow.models.dag_version import DagVersion from airflow.models.pool import Pool from airflow.models.taskinstance import TaskInstance as TI from airflow.providers.standard.operators.empty import EmptyOperator @@ -178,10 +179,10 @@ def test_infinite_slots(self, dag_maker): op2 = EmptyOperator(task_id="dummy2", pool="test_pool") dr = dag_maker.create_dagrun() - - ti1 = TI(task=op1, run_id=dr.run_id) + dag_version = DagVersion.get_latest_version(dr.dag_id) + ti1 = TI(task=op1, run_id=dr.run_id, dag_version_id=dag_version.id) ti1.refresh_from_db() - ti2 = TI(task=op2, run_id=dr.run_id) + ti2 = TI(task=op2, run_id=dr.run_id, dag_version_id=dag_version.id) ti2.refresh_from_db() ti1.state = State.RUNNING ti2.state = State.QUEUED @@ -228,10 +229,10 @@ def 
test_default_pool_open_slots(self, dag_maker): op3 = EmptyOperator(task_id="dummy3") dr = dag_maker.create_dagrun() - - ti1 = TI(task=op1, run_id=dr.run_id) - ti2 = TI(task=op2, run_id=dr.run_id) - ti3 = TI(task=op3, run_id=dr.run_id) + dag_version = DagVersion.get_latest_version(dr.dag_id) + ti1 = TI(task=op1, run_id=dr.run_id, dag_version_id=dag_version.id) + ti2 = TI(task=op2, run_id=dr.run_id, dag_version_id=dag_version.id) + ti3 = TI(task=op3, run_id=dr.run_id, dag_version_id=dag_version.id) ti1.refresh_from_db() ti1.state = State.RUNNING ti2.refresh_from_db() diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py index c4a01490f940c..f7bf2d240e00d 100644 --- a/airflow-core/tests/unit/models/test_taskinstance.py +++ b/airflow-core/tests/unit/models/test_taskinstance.py @@ -44,6 +44,7 @@ from airflow.models.baseoperator import BaseOperator from airflow.models.connection import Connection from airflow.models.dag import DAG +from airflow.models.dag_version import DagVersion from airflow.models.dagrun import DagRun from airflow.models.pool import Pool from airflow.models.renderedtifields import RenderedTaskInstanceFields @@ -425,7 +426,7 @@ def test_ti_updates_with_task(self, create_task_instance, session): dag=dag, ) - ti2 = TI(task=task2, run_id=ti.run_id) + ti2 = TI(task=task2, run_id=ti.run_id, dag_version_id=ti.dag_version_id) session.add(ti2) session.flush() @@ -1225,7 +1226,9 @@ def do_something_else(i): base_task = ti.task for map_index in range(1, 5): - ti = TaskInstance(base_task, run_id=dr.run_id, map_index=map_index) + ti = TaskInstance( + base_task, run_id=dr.run_id, map_index=map_index, dag_version_id=ti.dag_version_id + ) session.add(ti) ti.dag_run = dr session.flush() @@ -1276,7 +1279,7 @@ def test_are_dependents_done( downstream_task = EmptyOperator(task_id="downstream_task", dag=dag) ti.task >> downstream_task - downstream_ti = TI(downstream_task, run_id=ti.run_id) + downstream_ti = 
TI(downstream_task, run_id=ti.run_id, dag_version_id=ti.dag_version_id) downstream_ti.set_state(downstream_ti_state, session) session.flush() @@ -1310,7 +1313,9 @@ def test_check_and_change_state_before_execution(self, create_task_instance, tes SerializedDagModel.write_dag(ti.task.dag, bundle_name="testing") serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag - ti_from_deserialized_task = TI(task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id) + ti_from_deserialized_task = TI( + task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id, dag_version_id=ti.dag_version_id + ) assert ti_from_deserialized_task.try_number == 0 assert ti_from_deserialized_task.check_and_change_state_before_execution() @@ -1331,7 +1336,9 @@ def test_check_and_change_state_before_execution_provided_id_overrides( SerializedDagModel.write_dag(ti.task.dag, bundle_name="testing") serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag - ti_from_deserialized_task = TI(task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id) + ti_from_deserialized_task = TI( + task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id, dag_version_id=ti.dag_version_id + ) assert ti_from_deserialized_task.try_number == 0 assert ti_from_deserialized_task.check_and_change_state_before_execution( @@ -1349,7 +1356,9 @@ def test_check_and_change_state_before_execution_with_exec_id(self, create_task_ SerializedDagModel.write_dag(ti.task.dag, bundle_name="testing") serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag - ti_from_deserialized_task = TI(task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id) + ti_from_deserialized_task = TI( + task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id, dag_version_id=ti.dag_version_id + ) assert ti_from_deserialized_task.try_number == 0 assert ti_from_deserialized_task.check_and_change_state_before_execution( @@ -1369,7 +1378,9 @@ def test_check_and_change_state_before_execution_dep_not_met( 
SerializedDagModel.write_dag(ti.task.dag, bundle_name="testing") serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag - ti2 = TI(task=serialized_dag.get_task(task2.task_id), run_id=ti.run_id) + ti2 = TI( + task=serialized_dag.get_task(task2.task_id), run_id=ti.run_id, dag_version_id=ti.dag_version_id + ) assert not ti2.check_and_change_state_before_execution() def test_check_and_change_state_before_execution_dep_not_met_already_running( @@ -1383,7 +1394,9 @@ def test_check_and_change_state_before_execution_dep_not_met_already_running( SerializedDagModel.write_dag(ti.task.dag, bundle_name="testing") serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag - ti_from_deserialized_task = TI(task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id) + ti_from_deserialized_task = TI( + task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id, dag_version_id=ti.dag_version_id + ) assert not ti_from_deserialized_task.check_and_change_state_before_execution() assert ti_from_deserialized_task.state == State.RUNNING @@ -1400,7 +1413,9 @@ def test_check_and_change_state_before_execution_dep_not_met_not_runnable_state( SerializedDagModel.write_dag(ti.task.dag, bundle_name="testing") serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag - ti_from_deserialized_task = TI(task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id) + ti_from_deserialized_task = TI( + task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id, dag_version_id=ti.dag_version_id + ) assert not ti_from_deserialized_task.check_and_change_state_before_execution() assert ti_from_deserialized_task.state == State.FAILED @@ -1554,7 +1569,7 @@ def test_overwrite_params_with_dag_run_conf_none(self, create_task_instance): def test_set_duration(self): task = EmptyOperator(task_id="op", email="test@test.test") - ti = TI(task=task) + ti = TI(task=task, dag_version_id=mock.MagicMock()) ti.start_date = datetime.datetime(2018, 10, 1, 1) ti.end_date = datetime.datetime(2018, 10, 1, 2) 
ti.set_duration() @@ -1562,7 +1577,7 @@ def test_set_duration(self): def test_set_duration_empty_dates(self): task = EmptyOperator(task_id="op", email="test@test.test") - ti = TI(task=task) + ti = TI(task=task, dag_version_id=mock.MagicMock()) ti.set_duration() assert ti.duration is None @@ -2363,7 +2378,7 @@ def execute(self, context): ... retries=1, dag=dag, ) - ti2 = TI(task=task2, run_id=dr.run_id) + ti2 = TI(task=task2, run_id=dr.run_id, dag_version_id=ti1.dag_version_id) ti2.state = State.FAILED session.add(ti2) session.flush() @@ -2391,7 +2406,7 @@ def execute(self, context): ... retries=1, dag=dag, ) - ti3 = TI(task=task3, run_id=dr.run_id) + ti3 = TI(task=task3, run_id=dr.run_id, dag_version_id=ti1.dag_version_id) session.add(ti3) session.flush() ti3.state = State.FAILED @@ -2497,13 +2512,13 @@ def execute(self, context): ... tasks = [] for i, state in enumerate(states): op = CustomOp(task_id=f"reg_Task{i}", dag=dag) - ti = TI(task=op, run_id=dr.run_id) + ti = TI(task=op, run_id=dr.run_id, dag_version_id=ti1.dag_version_id) ti.state = state session.add(ti) tasks.append(ti) fail_task = CustomOp(task_id="fail_Task", dag=dag) - ti_ff = TI(task=fail_task, run_id=dr.run_id) + ti_ff = TI(task=fail_task, run_id=dr.run_id, dag_version_id=ti1.dag_version_id) ti_ff.state = State.FAILED session.add(ti_ff) session.commit() @@ -2635,7 +2650,7 @@ def test_refresh_from_db(self, create_task_instance): mock_task.task_id = expected_values["task_id"] mock_task.dag_id = expected_values["dag_id"] - ti = TI(task=mock_task, run_id="test") + ti = TI(task=mock_task, run_id="test", dag_version_id=ti.dag_version_id) ti.refresh_from_db() for key, expected_value in expected_values.items(): assert hasattr(ti, key), f"Key {key} is missing in the TaskInstance." 
@@ -2656,7 +2671,7 @@ def test_operator_field_with_serialization(self, create_task_instance): deserialized_op = SerializedBaseOperator.deserialize_operator(serialized_op) assert deserialized_op.task_type == "EmptyOperator" # Verify that ti.operator field renders correctly "with" Serialization - ser_ti = TI(task=deserialized_op, run_id=None) + ser_ti = TI(task=deserialized_op, run_id=None, dag_version_id=ti.dag_version_id) assert ser_ti.operator == "EmptyOperator" assert ser_ti.task.operator_name == "EmptyOperator" @@ -2796,7 +2811,7 @@ def mock_policy(task_instance: TaskInstance): retries=30, executor_config={"KubernetesExecutor": {"image": "myCustomDockerImage"}}, ) - ti = TI(task, run_id=None) + ti = TI(task, run_id=None, dag_version_id=mock.MagicMock()) ti.refresh_from_task(task, pool_override=pool_override) assert ti.queue == expected_queue @@ -3162,7 +3177,6 @@ def test_delete_dagversion_restricted_when_taskinstance_exists(dag_maker, sessio """ Ensure that deleting a DagVersion with existing TaskInstance references is restricted (ON DELETE RESTRICT). """ - from airflow.models.dag_version import DagVersion with dag_maker(dag_id="test_dag_restrict", session=session) as dag: EmptyOperator(task_id="task1") diff --git a/airflow-core/tests/unit/models/test_taskmap.py b/airflow-core/tests/unit/models/test_taskmap.py index 10fb4d99bba09..b1ef191c6c046 100644 --- a/airflow-core/tests/unit/models/test_taskmap.py +++ b/airflow-core/tests/unit/models/test_taskmap.py @@ -17,6 +17,8 @@ # under the License. 
from __future__ import annotations +from unittest import mock + import pytest from airflow.models.taskinstance import TaskInstance @@ -28,7 +30,7 @@ def test_task_map_from_task_instance_xcom(): task = EmptyOperator(task_id="test_task") - ti = TaskInstance(task=task, run_id="test_run", map_index=0) + ti = TaskInstance(task=task, run_id="test_run", map_index=0, dag_version_id=mock.MagicMock()) ti.dag_id = "test_dag" value = {"key1": "value1", "key2": "value2"} @@ -49,7 +51,7 @@ def test_task_map_from_task_instance_xcom(): def test_task_map_with_invalid_task_instance(): task = EmptyOperator(task_id="test_task") - ti = TaskInstance(task=task, run_id=None, map_index=0) + ti = TaskInstance(task=task, run_id=None, map_index=0, dag_version_id=mock.MagicMock()) ti.dag_id = "test_dag" # Define some arbitrary XCom-like value data diff --git a/airflow-core/tests/unit/models/test_trigger.py b/airflow-core/tests/unit/models/test_trigger.py index e4bd5f591f7e7..83e7886b79304 100644 --- a/airflow-core/tests/unit/models/test_trigger.py +++ b/airflow-core/tests/unit/models/test_trigger.py @@ -117,12 +117,16 @@ def test_clean_unused(session, create_task_instance): task_instance.trigger_id = trigger1.id session.add(task_instance) fake_task1 = EmptyOperator(task_id="fake2", dag=task_instance.task.dag) - task_instance1 = TaskInstance(task=fake_task1, run_id=task_instance.run_id) + task_instance1 = TaskInstance( + task=fake_task1, run_id=task_instance.run_id, dag_version_id=task_instance.dag_version_id + ) task_instance1.state = State.SUCCESS task_instance1.trigger_id = trigger2.id session.add(task_instance1) fake_task2 = EmptyOperator(task_id="fake3", dag=task_instance.task.dag) - task_instance2 = TaskInstance(task=fake_task2, run_id=task_instance.run_id) + task_instance2 = TaskInstance( + task=fake_task2, run_id=task_instance.run_id, dag_version_id=task_instance.dag_version_id + ) task_instance2.state = State.SUCCESS task_instance2.trigger_id = trigger4.id session.add(task_instance2) 
diff --git a/airflow-core/tests/unit/models/test_xcom.py b/airflow-core/tests/unit/models/test_xcom.py index a67bbb7343c85..914271f9386e3 100644 --- a/airflow-core/tests/unit/models/test_xcom.py +++ b/airflow-core/tests/unit/models/test_xcom.py @@ -24,8 +24,11 @@ import pytest +from airflow import DAG from airflow.configuration import conf +from airflow.models.dag_version import DagVersion from airflow.models.dagrun import DagRun, DagRunType +from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance from airflow.models.xcom import XComModel from airflow.providers.standard.operators.empty import EmptyOperator @@ -60,6 +63,9 @@ def reset_db(): @pytest.fixture def task_instance_factory(request, session: Session): def func(*, dag_id, task_id, logical_date, run_after=None): + dag = DAG(dag_id=dag_id) + dag.sync_to_db(session=session) + SerializedDagModel.write_dag(dag, bundle_name="testing") run_id = DagRun.generate_run_id( run_type=DagRunType.SCHEDULED, logical_date=logical_date, @@ -75,7 +81,9 @@ def func(*, dag_id, task_id, logical_date, run_after=None): run_after=run_after if run_after is not None else logical_date, ) session.add(run) - ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id) + session.flush() + dag_version = DagVersion.get_latest_version(run.dag_id, session=session) + ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id, dag_version_id=dag_version.id) ti.dag_id = dag_id session.add(ti) session.commit() @@ -102,7 +110,11 @@ def task_instance(task_instance_factory): @pytest.fixture def task_instances(session, task_instance): - ti2 = TaskInstance(EmptyOperator(task_id="task_2"), run_id=task_instance.run_id) + ti2 = TaskInstance( + EmptyOperator(task_id="task_2"), + run_id=task_instance.run_id, + dag_version_id=task_instance.dag_version_id, + ) ti2.dag_id = task_instance.dag_id session.add(ti2) session.commit() @@ -317,6 +329,7 @@ def tis_for_xcom_get_many_from_prior_dates(self, 
task_instance_factory, push_sim def test_xcom_get_many_from_prior_dates(self, session, tis_for_xcom_get_many_from_prior_dates): ti1, ti2 = tis_for_xcom_get_many_from_prior_dates + session.add(ti1) # for some reason, ti1 goes out of the session scope stored_xcoms = XComModel.get_many( run_id=ti2.run_id, key="xcom_1", diff --git a/airflow-core/tests/unit/serialization/test_serialized_objects.py b/airflow-core/tests/unit/serialization/test_serialized_objects.py index 1fc9d0486fecd..e345dcc4fbc34 100644 --- a/airflow-core/tests/unit/serialization/test_serialized_objects.py +++ b/airflow-core/tests/unit/serialization/test_serialized_objects.py @@ -26,6 +26,7 @@ from dateutil import relativedelta from kubernetes.client import models as k8s from pendulum.tz.timezone import FixedTimezone, Timezone +from uuid6 import uuid7 from airflow.callbacks.callback_requests import DagCallbackRequest, TaskCallbackRequest from airflow.exceptions import ( @@ -175,12 +176,14 @@ def validate(self, obj): task=EmptyOperator(task_id="test-task"), run_id="fake_run", state=State.RUNNING, + dag_version_id=uuid7(), ) TI_WITH_START_DAY = TaskInstance( task=EmptyOperator(task_id="test-task"), run_id="fake_run", state=State.RUNNING, + dag_version_id=uuid7(), ) TI_WITH_START_DAY.start_date = timezone.utcnow() diff --git a/airflow-core/tests/unit/ti_deps/deps/test_dag_ti_slots_available_dep.py b/airflow-core/tests/unit/ti_deps/deps/test_dag_ti_slots_available_dep.py index bfc89fc90c1a0..1acdeb6fb6dea 100644 --- a/airflow-core/tests/unit/ti_deps/deps/test_dag_ti_slots_available_dep.py +++ b/airflow-core/tests/unit/ti_deps/deps/test_dag_ti_slots_available_dep.py @@ -17,6 +17,7 @@ # under the License. 
from __future__ import annotations +from unittest import mock from unittest.mock import Mock import pytest @@ -34,7 +35,7 @@ def test_concurrency_reached(self): """ dag = Mock(concurrency=1, get_concurrency_reached=Mock(return_value=True)) task = Mock(dag=dag, pool_slots=1) - ti = TaskInstance(task) + ti = TaskInstance(task, dag_version_id=mock.MagicMock()) assert not DagTISlotsAvailableDep().is_met(ti=ti) @@ -44,6 +45,6 @@ def test_all_conditions_met(self): """ dag = Mock(concurrency=1, get_concurrency_reached=Mock(return_value=False)) task = Mock(dag=dag, pool_slots=1) - ti = TaskInstance(task) + ti = TaskInstance(task, dag_version_id=mock.MagicMock()) assert DagTISlotsAvailableDep().is_met(ti=ti) diff --git a/airflow-core/tests/unit/ti_deps/deps/test_dag_unpaused_dep.py b/airflow-core/tests/unit/ti_deps/deps/test_dag_unpaused_dep.py index 4fb1bd09a9304..bf1a8bdf8945e 100644 --- a/airflow-core/tests/unit/ti_deps/deps/test_dag_unpaused_dep.py +++ b/airflow-core/tests/unit/ti_deps/deps/test_dag_unpaused_dep.py @@ -17,6 +17,7 @@ # under the License. 
from __future__ import annotations +from unittest import mock from unittest.mock import Mock import pytest @@ -34,7 +35,7 @@ def test_concurrency_reached(self): """ dag = Mock(**{"get_is_paused.return_value": True}) task = Mock(dag=dag) - ti = TaskInstance(task=task) + ti = TaskInstance(task=task, dag_version_id=mock.MagicMock()) assert not DagUnpausedDep().is_met(ti=ti) @@ -44,6 +45,6 @@ def test_all_conditions_met(self): """ dag = Mock(**{"get_is_paused.return_value": False}) task = Mock(dag=dag) - ti = TaskInstance(task=task) + ti = TaskInstance(task=task, dag_version_id=mock.MagicMock()) assert DagUnpausedDep().is_met(ti=ti) diff --git a/airflow-core/tests/unit/ti_deps/deps/test_not_in_retry_period_dep.py b/airflow-core/tests/unit/ti_deps/deps/test_not_in_retry_period_dep.py index 44c241b71491d..3b8aec4763ecb 100644 --- a/airflow-core/tests/unit/ti_deps/deps/test_not_in_retry_period_dep.py +++ b/airflow-core/tests/unit/ti_deps/deps/test_not_in_retry_period_dep.py @@ -18,6 +18,7 @@ from __future__ import annotations from datetime import timedelta +from unittest import mock from unittest.mock import Mock import pytest @@ -34,7 +35,7 @@ class TestNotInRetryPeriodDep: def _get_task_instance(self, state, end_date=None, retry_delay=timedelta(minutes=15)): task = Mock(retry_delay=retry_delay, retry_exponential_backoff=False) - ti = TaskInstance(task=task, state=state) + ti = TaskInstance(task=task, state=state, dag_version_id=mock.MagicMock()) ti.end_date = end_date return ti diff --git a/airflow-core/tests/unit/ti_deps/deps/test_prev_dagrun_dep.py b/airflow-core/tests/unit/ti_deps/deps/test_prev_dagrun_dep.py index 9fd99124d6ef2..182241232fa73 100644 --- a/airflow-core/tests/unit/ti_deps/deps/test_prev_dagrun_dep.py +++ b/airflow-core/tests/unit/ti_deps/deps/test_prev_dagrun_dep.py @@ -24,6 +24,7 @@ from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG +from airflow.models.serialized_dag import SerializedDagModel from 
airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep from airflow.utils.state import DagRunState, TaskInstanceState @@ -54,6 +55,8 @@ def test_first_task_run_of_new_task(self): start_date=START_DATE, wait_for_downstream=False, ) + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") # Old DAG run will include only TaskInstance of old_task dag.create_dagrun( run_id="old_run", diff --git a/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py b/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py index 133ba83c72b69..581fd4f22ea91 100644 --- a/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py +++ b/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py @@ -26,6 +26,7 @@ import pytest from airflow.models.baseoperator import BaseOperator +from airflow.models.dag_version import DagVersion from airflow.models.taskinstance import TaskInstance from airflow.providers.standard.operators.empty import EmptyOperator from airflow.sdk import task, task_group @@ -120,7 +121,7 @@ def do_something(i): def do_something_else(i): return 1 - with dag_maker(dag_id="test_dag"): + with dag_maker(dag_id="test_dag") as dag: nums = do_something.expand(i=[i + 1 for i in range(5)]) do_something_else.expand(i=nums) if add_setup_tasks: @@ -136,8 +137,13 @@ def do_something_else(i): def _expand_tasks(task_instance: str, upstream: str) -> BaseOperator | None: ti = dr.get_task_instance(task_instance, session=session) ti.map_index = 0 + dag_version = DagVersion.get_latest_version(dag.dag_id) + if TYPE_CHECKING: + assert dag_version for map_index in range(1, 5): - ti = TaskInstance(ti.task, run_id=dr.run_id, map_index=map_index) + ti = TaskInstance( + ti.task, run_id=dr.run_id, map_index=map_index, dag_version_id=dag_version.id + ) session.add(ti) ti.dag_run = dr session.flush() diff --git a/airflow-core/tests/unit/utils/test_db_cleanup.py 
b/airflow-core/tests/unit/utils/test_db_cleanup.py index 89cfe892f3817..7b4fc055405d5 100644 --- a/airflow-core/tests/unit/utils/test_db_cleanup.py +++ b/airflow-core/tests/unit/utils/test_db_cleanup.py @@ -30,8 +30,11 @@ from sqlalchemy.exc import OperationalError, SQLAlchemyError from sqlalchemy.ext.declarative import DeclarativeMeta +from airflow import DAG from airflow.exceptions import AirflowException from airflow.models import DagModel, DagRun, TaskInstance +from airflow.models.dag_version import DagVersion +from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.standard.operators.python import PythonOperator from airflow.utils import timezone from airflow.utils.db_cleanup import ( @@ -665,8 +668,12 @@ def test_drop_archived_tables(self, mock_input, confirm_mock, inspect_mock, capl def create_tis(base_date, num_tis, run_type=DagRunType.SCHEDULED): with create_session() as session: - dag = DagModel(dag_id=f"test-dag_{uuid4()}") - session.add(dag) + dag_id = f"test-dag_{uuid4()}" + dag = DAG(dag_id=dag_id) + dm = DagModel(dag_id=dag_id) + session.add(dm) + SerializedDagModel.write_dag(dag, bundle_name="testing") + dag_version = DagVersion.get_latest_version(dag.dag_id) for num in range(num_tis): start_date = base_date.add(days=num) dag_run = DagRun( @@ -676,7 +683,9 @@ def create_tis(base_date, num_tis, run_type=DagRunType.SCHEDULED): start_date=start_date, ) ti = TaskInstance( - PythonOperator(task_id="dummy-task", python_callable=print), run_id=dag_run.run_id + PythonOperator(task_id="dummy-task", python_callable=print), + run_id=dag_run.run_id, + dag_version_id=dag_version.id, ) ti.dag_id = dag.dag_id ti.start_date = start_date diff --git a/airflow-core/tests/unit/utils/test_log_handlers.py b/airflow-core/tests/unit/utils/test_log_handlers.py index 286eae5c84255..8367aed97ce67 100644 --- a/airflow-core/tests/unit/utils/test_log_handlers.py +++ b/airflow-core/tests/unit/utils/test_log_handlers.py @@ -40,6 +40,7 @@ from 
airflow.executors import executor_constants, executor_loader from airflow.jobs.job import Job from airflow.jobs.triggerer_job_runner import TriggererJobRunner +from airflow.models.dag_version import DagVersion from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancehistory import TaskInstanceHistory @@ -118,7 +119,8 @@ def task_callable(ti): ) dagrun = dag_maker.create_dagrun() - ti = TaskInstance(task=task, run_id=dagrun.run_id) + dag_version = DagVersion.get_latest_version(dagrun.dag_id) + ti = TaskInstance(task=task, run_id=dagrun.run_id, dag_version_id=dag_version.id) logger = ti.log ti.log.disabled = False @@ -161,7 +163,8 @@ def task_callable(ti): python_callable=task_callable, ) dagrun = dag_maker.create_dagrun() - ti = TaskInstance(task=task, run_id=dagrun.run_id) + dag_version = DagVersion.get_latest_version(dagrun.dag_id) + ti = TaskInstance(task=task, run_id=dagrun.run_id, dag_version_id=dag_version.id) ti.try_number = 0 ti.state = State.SKIPPED @@ -329,7 +332,8 @@ def task_callable(ti): python_callable=task_callable, ) dagrun = dag_maker.create_dagrun() - ti = TaskInstance(task=task, run_id=dagrun.run_id) + dag_version = DagVersion.get_latest_version(dagrun.dag_id) + ti = TaskInstance(task=task, run_id=dagrun.run_id, dag_version_id=dag_version.id) ti.try_number = 2 ti.state = State.RUNNING @@ -380,7 +384,8 @@ def task_callable(ti): python_callable=task_callable, ) dagrun = dag_maker.create_dagrun() - ti = TaskInstance(task=task, run_id=dagrun.run_id) + dag_version = DagVersion.get_latest_version(dagrun.dag_id) + ti = TaskInstance(task=task, run_id=dagrun.run_id, dag_version_id=dag_version.id) ti.try_number = 1 ti.state = State.RUNNING diff --git a/airflow-core/tests/unit/utils/test_sqlalchemy.py b/airflow-core/tests/unit/utils/test_sqlalchemy.py index 8ddb21e0a1971..602a46a8ab179 100644 --- a/airflow-core/tests/unit/utils/test_sqlalchemy.py +++ 
b/airflow-core/tests/unit/utils/test_sqlalchemy.py @@ -30,6 +30,7 @@ from airflow import settings from airflow.models.dag import DAG +from airflow.models.serialized_dag import SerializedDagModel from airflow.serialization.enums import DagAttributeTypes, Encoding from airflow.serialization.serialized_objects import BaseSerialization from airflow.settings import Session @@ -73,7 +74,8 @@ def test_utc_transformations(self): dag = DAG(dag_id=dag_id, schedule=datetime.timedelta(days=1), start_date=start_date) dag.clear() - + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") run = dag.create_dagrun( run_id=iso_date, run_type=DagRunType.MANUAL, diff --git a/airflow-core/tests/unit/utils/test_state.py b/airflow-core/tests/unit/utils/test_state.py index d418fef0745f7..e1f5dd05cfbd1 100644 --- a/airflow-core/tests/unit/utils/test_state.py +++ b/airflow-core/tests/unit/utils/test_state.py @@ -22,6 +22,7 @@ from airflow.models.dag import DAG from airflow.models.dagrun import DagRun +from airflow.models.serialized_dag import SerializedDagModel from airflow.utils.session import create_session from airflow.utils.state import DagRunState from airflow.utils.types import DagRunTriggeredByType, DagRunType @@ -38,6 +39,8 @@ def test_dagrun_state_enum_escape(): """ with create_session() as session: dag = DAG(dag_id="test_dagrun_state_enum_escape", schedule=timedelta(days=1), start_date=DEFAULT_DATE) + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") dag.create_dagrun( run_id=dag.timetable.generate_run_id( run_type=DagRunType.SCHEDULED, diff --git a/airflow-core/tests/unit/utils/test_task_handler_with_custom_formatter.py b/airflow-core/tests/unit/utils/test_task_handler_with_custom_formatter.py index 118e442d0c15a..5e5519047980a 100644 --- a/airflow-core/tests/unit/utils/test_task_handler_with_custom_formatter.py +++ b/airflow-core/tests/unit/utils/test_task_handler_with_custom_formatter.py @@ -22,6 +22,7 @@ import pytest from 
airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG +from airflow.models.dag_version import DagVersion from airflow.models.taskinstance import TaskInstance from airflow.providers.standard.operators.empty import EmptyOperator from airflow.utils.log.logging_mixin import set_context @@ -74,7 +75,8 @@ def task_instance(dag_maker): data_interval=dag.timetable.infer_manual_data_interval(run_after=DEFAULT_DATE), **triggered_by_kwargs, ) - ti = TaskInstance(task=task, run_id=dagrun.run_id) + dag_version = DagVersion.get_latest_version(dag.dag_id) + ti = TaskInstance(task=task, run_id=dagrun.run_id, dag_version_id=dag_version.id) ti.log.disabled = False yield ti clear_db_runs() diff --git a/airflow-ctl/src/airflowctl/api/datamodels/generated.py b/airflow-ctl/src/airflowctl/api/datamodels/generated.py index a4f2e00bbff79..c64907a938a47 100644 --- a/airflow-ctl/src/airflowctl/api/datamodels/generated.py +++ b/airflow-ctl/src/airflowctl/api/datamodels/generated.py @@ -1587,6 +1587,7 @@ class TaskInstanceResponse(BaseModel): id: Annotated[str, Field(title="Id")] task_id: Annotated[str, Field(title="Task Id")] dag_id: Annotated[str, Field(title="Dag Id")] + dag_version: DagVersionResponse dag_run_id: Annotated[str, Field(title="Dag Run Id")] map_index: Annotated[int, Field(title="Map Index")] logical_date: Annotated[datetime | None, Field(title="Logical Date")] = None @@ -1616,7 +1617,6 @@ class TaskInstanceResponse(BaseModel): rendered_fields: Annotated[dict[str, Any] | None, Field(title="Rendered Fields")] = None trigger: TriggerResponse | None = None triggerer_job: JobResponse | None = None - dag_version: DagVersionResponse | None = None class TaskResponse(BaseModel): diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index 2c04feffa92df..7efb6de98a784 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -2296,6 +2296,7 @@ def 
_create_task_instance( run_id=run_id, try_number=try_number, map_index=map_index, + dag_version_id=uuid7(), ), dag_rel_path="", bundle_info=BundleInfo(name="anything", version="any"), diff --git a/devel-common/src/tests_common/test_utils/mock_context.py b/devel-common/src/tests_common/test_utils/mock_context.py index 4391490dfb4a2..200d8a583ab0b 100644 --- a/devel-common/src/tests_common/test_utils/mock_context.py +++ b/devel-common/src/tests_common/test_utils/mock_context.py @@ -16,8 +16,10 @@ # under the License. from __future__ import annotations +import inspect from collections.abc import Iterable from typing import TYPE_CHECKING, Any +from unittest import mock from airflow.utils.context import Context @@ -41,7 +43,23 @@ def __init__( state: str | None = TaskInstanceState.RUNNING, map_index: int = -1, ): - super().__init__(task=task, run_id=run_id, state=state, map_index=map_index) + # Inspect the parameters of TaskInstance.__init__ + init_sig = inspect.signature(super().__init__) + if "dag_version_id" in init_sig.parameters: + super().__init__( + task=task, + run_id=run_id, + state=state, + map_index=map_index, + dag_version_id=mock.MagicMock(), + ) + else: + super().__init__( + task=task, + run_id=run_id, + state=state, + map_index=map_index, + ) # type: ignore[call-arg] self.values: dict[str, Any] = {} def xcom_pull( diff --git a/devel-common/src/tests_common/test_utils/system_tests.py b/devel-common/src/tests_common/test_utils/system_tests.py index 0950922308867..744b793d205a9 100644 --- a/devel-common/src/tests_common/test_utils/system_tests.py +++ b/devel-common/src/tests_common/test_utils/system_tests.py @@ -26,6 +26,8 @@ from airflow.utils.state import DagRunState +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS + if TYPE_CHECKING: from airflow.models.dagrun import DagRun from airflow.sdk.definitions.context import Context @@ -65,6 +67,17 @@ def test_run(): dag.on_success_callback = add_callback(dag.on_success_callback, callback) # 
If the env variable ``_AIRFLOW__SYSTEM_TEST_USE_EXECUTOR`` is set, then use an executor to run the # DAG + if AIRFLOW_V_3_0_PLUS: + from airflow.models.dag import DagModel + from airflow.models.serialized_dag import SerializedDagModel + from airflow.settings import Session + + d = DagModel(dag_id=dag.dag_id) + s = Session() + s.add(d) + s.commit() + SerializedDagModel.write_dag(dag, bundle_name="testing") + dag_run = dag.test( use_executor=os.environ.get("_AIRFLOW__SYSTEM_TEST_USE_EXECUTOR") == "1", **test_kwargs, diff --git a/kubernetes-tests/tests/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes-tests/tests/kubernetes_tests/test_kubernetes_pod_operator.py index afbcf113932a6..d504cee468933 100644 --- a/kubernetes-tests/tests/kubernetes_tests/test_kubernetes_pod_operator.py +++ b/kubernetes-tests/tests/kubernetes_tests/test_kubernetes_pod_operator.py @@ -76,7 +76,10 @@ def create_context(task) -> Context: logical_date=logical_date, ), ) - task_instance = TaskInstance(task=task) + if AIRFLOW_V_3_0_PLUS: + task_instance = TaskInstance(task=task, run_id=dag_run.run_id, dag_version_id=mock.MagicMock()) + else: + task_instance = TaskInstance(task=task) task_instance.dag_run = dag_run task_instance.dag_id = dag.dag_id task_instance.try_number = 1 diff --git a/providers/amazon/tests/unit/amazon/aws/log/test_cloudwatch_task_handler.py b/providers/amazon/tests/unit/amazon/aws/log/test_cloudwatch_task_handler.py index 30a66e46608d4..097ccc9eabbd2 100644 --- a/providers/amazon/tests/unit/amazon/aws/log/test_cloudwatch_task_handler.py +++ b/providers/amazon/tests/unit/amazon/aws/log/test_cloudwatch_task_handler.py @@ -33,6 +33,7 @@ from watchtower import CloudWatchLogHandler from airflow.models import DAG, DagRun, TaskInstance +from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook from airflow.providers.amazon.aws.log.cloudwatch_task_handler import ( CloudWatchRemoteLogIO, @@ -199,30 +200,38 
@@ def setup(self, create_log_template, tmp_path_factory, session): f"arn:aws:logs:{self.region_name}:11111111:log-group:{self.remote_log_group}", ) - date = datetime(2020, 1, 1) - dag_id = "dag_for_testing_cloudwatch_task_handler" - task_id = "task_for_testing_cloudwatch_log_handler" - self.dag = DAG(dag_id=dag_id, schedule=None, start_date=date) - task = EmptyOperator(task_id=task_id, dag=self.dag) - if AIRFLOW_V_3_0_PLUS: - dag_run = DagRun( - dag_id=self.dag.dag_id, - logical_date=date, - run_id="test", - run_type="scheduled", - ) - else: - dag_run = DagRun( - dag_id=self.dag.dag_id, - execution_date=date, - run_id="test", - run_type="scheduled", - ) - session.add(dag_run) - session.commit() - session.refresh(dag_run) + date = datetime(2020, 1, 1) + dag_id = "dag_for_testing_cloudwatch_task_handler" + task_id = "task_for_testing_cloudwatch_log_handler" + self.dag = DAG(dag_id=dag_id, schedule=None, start_date=date) + task = EmptyOperator(task_id=task_id, dag=self.dag) + if AIRFLOW_V_3_0_PLUS: + self.dag.sync_to_db() + SerializedDagModel.write_dag(self.dag, bundle_name="testing") + dag_run = DagRun( + dag_id=self.dag.dag_id, + logical_date=date, + run_id="test", + run_type="scheduled", + ) + else: + dag_run = DagRun( + dag_id=self.dag.dag_id, + execution_date=date, + run_id="test", + run_type="scheduled", + ) + session.add(dag_run) + session.commit() + session.refresh(dag_run) - self.ti = TaskInstance(task=task, run_id=dag_run.run_id) + if AIRFLOW_V_3_0_PLUS: + from airflow.models.dag_version import DagVersion + + dag_version = DagVersion.get_latest_version(self.dag.dag_id, session=session) + self.ti = TaskInstance(task=task, run_id=dag_run.run_id, dag_version_id=dag_version.id) + else: + self.ti = TaskInstance(task=task, run_id=dag_run.run_id) self.ti.dag_run = dag_run self.ti.try_number = 1 self.ti.state = State.RUNNING diff --git a/providers/amazon/tests/unit/amazon/aws/log/test_s3_task_handler.py 
b/providers/amazon/tests/unit/amazon/aws/log/test_s3_task_handler.py index b28ca64e12769..f254eeeec07a7 100644 --- a/providers/amazon/tests/unit/amazon/aws/log/test_s3_task_handler.py +++ b/providers/amazon/tests/unit/amazon/aws/log/test_s3_task_handler.py @@ -27,7 +27,8 @@ from botocore.exceptions import ClientError from moto import mock_aws -from airflow.models import DAG, DagRun, TaskInstance +from airflow.models import DAG, DagModel, DagRun, TaskInstance +from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.providers.amazon.aws.log.s3_task_handler import S3TaskHandler from airflow.providers.standard.operators.empty import EmptyOperator @@ -59,28 +60,35 @@ def setup_tests(self, create_log_template, tmp_path_factory, session): self.subject = self.s3_task_handler.io assert self.subject.hook is not None - date = datetime(2016, 1, 1) - self.dag = DAG("dag_for_testing_s3_task_handler", schedule=None, start_date=date) - task = EmptyOperator(task_id="task_for_testing_s3_log_handler", dag=self.dag) - if AIRFLOW_V_3_0_PLUS: - dag_run = DagRun( - dag_id=self.dag.dag_id, - logical_date=date, - run_id="test", - run_type="manual", - ) - else: - dag_run = DagRun( - dag_id=self.dag.dag_id, - execution_date=date, - run_id="test", - run_type="manual", - ) - session.add(dag_run) - session.commit() - session.refresh(dag_run) - - self.ti = TaskInstance(task=task, run_id=dag_run.run_id) + date = datetime(2016, 1, 1) + self.dag = DAG("dag_for_testing_s3_task_handler", schedule=None, start_date=date) + task = EmptyOperator(task_id="task_for_testing_s3_log_handler", dag=self.dag) + if AIRFLOW_V_3_0_PLUS: + self.dag.sync_to_db() + SerializedDagModel.write_dag(self.dag, bundle_name="testing") + dag_run = DagRun( + dag_id=self.dag.dag_id, + logical_date=date, + run_id="test", + run_type="manual", + ) + else: + dag_run = DagRun( + dag_id=self.dag.dag_id, + execution_date=date, + run_id="test", + 
run_type="manual", + ) + session.add(dag_run) + session.commit() + session.refresh(dag_run) + if AIRFLOW_V_3_0_PLUS: + from airflow.models.dag_version import DagVersion + + dag_version = DagVersion.get_latest_version(self.dag.dag_id) + self.ti = TaskInstance(task=task, run_id=dag_run.run_id, dag_version_id=dag_version.id) + else: + self.ti = TaskInstance(task=task, run_id=dag_run.run_id) self.ti.dag_run = dag_run self.ti.try_number = 1 self.ti.state = State.RUNNING @@ -94,6 +102,7 @@ def setup_tests(self, create_log_template, tmp_path_factory, session): self.dag.clear() session.query(DagRun).delete() + session.query(DagModel).delete() if self.s3_task_handler.handler: with contextlib.suppress(Exception): os.remove(self.s3_task_handler.handler.baseFilename) @@ -178,44 +187,51 @@ def setup_tests(self, create_log_template, tmp_path_factory, session): # Verify the hook now with the config override assert self.s3_task_handler.io.hook is not None - date = datetime(2016, 1, 1) - self.dag = DAG("dag_for_testing_s3_task_handler", schedule=None, start_date=date) - task = EmptyOperator(task_id="task_for_testing_s3_log_handler", dag=self.dag) - if AIRFLOW_V_3_0_PLUS: - dag_run = DagRun( - dag_id=self.dag.dag_id, - logical_date=date, - run_id="test", - run_type="manual", - ) - else: - dag_run = DagRun( - dag_id=self.dag.dag_id, - execution_date=date, - run_id="test", - run_type="manual", - ) - session.add(dag_run) - session.commit() - session.refresh(dag_run) + date = datetime(2016, 1, 1) + self.dag = DAG("dag_for_testing_s3_task_handler", schedule=None, start_date=date) + task = EmptyOperator(task_id="task_for_testing_s3_log_handler", dag=self.dag) + if AIRFLOW_V_3_0_PLUS: + self.dag.sync_to_db() + SerializedDagModel.write_dag(self.dag, bundle_name="testing") + dag_run = DagRun( + dag_id=self.dag.dag_id, + logical_date=date, + run_id="test", + run_type="manual", + ) + else: + dag_run = DagRun( + dag_id=self.dag.dag_id, + execution_date=date, + run_id="test",
session.add(dag_run) + session.commit() + session.refresh(dag_run) + if AIRFLOW_V_3_0_PLUS: + from airflow.models.dag_version import DagVersion + dag_version = DagVersion.get_latest_version(self.dag.dag_id) + self.ti = TaskInstance(task=task, run_id=dag_run.run_id, dag_version_id=dag_version.id) + else: self.ti = TaskInstance(task=task, run_id=dag_run.run_id) - self.ti.dag_run = dag_run - self.ti.try_number = 1 - self.ti.state = State.RUNNING - session.add(self.ti) - session.commit() - - self.conn = boto3.client("s3") - self.conn.create_bucket(Bucket="bucket") - yield - - self.dag.clear() - - session.query(DagRun).delete() - if self.s3_task_handler.handler: - with contextlib.suppress(Exception): - os.remove(self.s3_task_handler.handler.baseFilename) + self.ti.dag_run = dag_run + self.ti.try_number = 1 + self.ti.state = State.RUNNING + session.add(self.ti) + session.commit() + + self.conn = boto3.client("s3") + self.conn.create_bucket(Bucket="bucket") + yield + + self.dag.clear() + + session.query(DagRun).delete() + if self.s3_task_handler.handler: + with contextlib.suppress(Exception): + os.remove(self.s3_task_handler.handler.baseFilename) def test_set_context_raw(self): self.ti.raw = True diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_athena.py b/providers/amazon/tests/unit/amazon/aws/operators/test_athena.py index be9ba5f406d50..8a343a0be686e 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_athena.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_athena.py @@ -24,6 +24,7 @@ from airflow.exceptions import AirflowException, TaskDeferred from airflow.models import DAG, DagRun, TaskInstance +from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.amazon.aws.hooks.athena import AthenaHook from airflow.providers.amazon.aws.operators.athena import AthenaOperator from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger @@ -235,6 +236,12 @@ def test_return_value( ): """Test we 
return the right value -- that will get put in to XCom by the execution engine""" if AIRFLOW_V_3_0_PLUS: + from airflow.models.dag_version import DagVersion + + self.dag.sync_to_db() + SerializedDagModel.write_dag(self.dag, bundle_name="testing") + dag_version = DagVersion.get_latest_version(self.dag.dag_id) + ti = TaskInstance(task=self.athena, dag_version_id=dag_version.id) dag_run = DagRun( dag_id=self.dag.dag_id, logical_date=timezone.utcnow(), @@ -250,7 +257,7 @@ def test_return_value( run_type=DagRunType.MANUAL, state=DagRunState.RUNNING, ) - ti = TaskInstance(task=self.athena) + ti = TaskInstance(task=self.athena) ti.dag_run = dag_run session.add(ti) session.commit() diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_datasync.py b/providers/amazon/tests/unit/amazon/aws/operators/test_datasync.py index ab80cb6bc6ca2..50101edad0368 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_datasync.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_datasync.py @@ -24,6 +24,7 @@ from airflow.exceptions import AirflowException from airflow.models import DAG, DagRun, TaskInstance +from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.amazon.aws.hooks.datasync import DataSyncHook from airflow.providers.amazon.aws.links.datasync import DataSyncTaskLink from airflow.providers.amazon.aws.operators.datasync import DataSyncOperator @@ -355,6 +356,11 @@ def test_return_value(self, mock_get_conn, session, clean_dags_and_dagruns): self.set_up_operator() if AIRFLOW_V_3_0_PLUS: + self.dag.sync_to_db() + SerializedDagModel.write_dag(self.dag, bundle_name="testing") + from airflow.models.dag_version import DagVersion + + dag_version = DagVersion.get_latest_version(self.dag.dag_id) dag_run = DagRun( dag_id=self.dag.dag_id, logical_date=timezone.utcnow(), @@ -362,6 +368,7 @@ def test_return_value(self, mock_get_conn, session, clean_dags_and_dagruns): run_type=DagRunType.MANUAL, state=DagRunState.RUNNING, ) + ti = 
TaskInstance(task=self.datasync, dag_version_id=dag_version.id) else: dag_run = DagRun( dag_id=self.dag.dag_id, @@ -370,7 +377,7 @@ def test_return_value(self, mock_get_conn, session, clean_dags_and_dagruns): run_type=DagRunType.MANUAL, state=DagRunState.RUNNING, ) - ti = TaskInstance(task=self.datasync) + ti = TaskInstance(task=self.datasync) ti.dag_run = dag_run session.add(ti) session.commit() @@ -569,6 +576,12 @@ def test_return_value(self, mock_get_conn, session, clean_dags_and_dagruns): self.set_up_operator() if AIRFLOW_V_3_0_PLUS: + self.dag.sync_to_db() + SerializedDagModel.write_dag(self.dag, bundle_name="testing") + from airflow.models.dag_version import DagVersion + + dag_version = DagVersion.get_latest_version(self.dag.dag_id) + ti = TaskInstance(task=self.datasync, dag_version_id=dag_version.id) dag_run = DagRun( dag_id=self.dag.dag_id, logical_date=timezone.utcnow(), @@ -584,7 +597,7 @@ def test_return_value(self, mock_get_conn, session, clean_dags_and_dagruns): run_type=DagRunType.MANUAL, state=DagRunState.RUNNING, ) - ti = TaskInstance(task=self.datasync) + ti = TaskInstance(task=self.datasync) ti.dag_run = dag_run session.add(ti) session.commit() @@ -685,6 +698,12 @@ def test_return_value(self, mock_get_conn, session, clean_dags_and_dagruns): self.set_up_operator() if AIRFLOW_V_3_0_PLUS: + self.dag.sync_to_db() + SerializedDagModel.write_dag(self.dag, bundle_name="testing") + from airflow.models.dag_version import DagVersion + + dag_version = DagVersion.get_latest_version(self.dag.dag_id) + ti = TaskInstance(task=self.datasync, dag_version_id=dag_version.id) dag_run = DagRun( dag_id=self.dag.dag_id, logical_date=timezone.utcnow(), @@ -700,7 +719,7 @@ def test_return_value(self, mock_get_conn, session, clean_dags_and_dagruns): run_type=DagRunType.MANUAL, state=DagRunState.RUNNING, ) - ti = TaskInstance(task=self.datasync) + ti = TaskInstance(task=self.datasync) ti.dag_run = dag_run session.add(ti) session.commit() @@ -894,6 +913,12 @@ def 
test_return_value(self, mock_get_conn, session, clean_dags_and_dagruns): self.set_up_operator() if AIRFLOW_V_3_0_PLUS: + self.dag.sync_to_db() + SerializedDagModel.write_dag(self.dag, bundle_name="testing") + from airflow.models.dag_version import DagVersion + + dag_version = DagVersion.get_latest_version(self.dag.dag_id) + ti = TaskInstance(task=self.datasync, dag_version_id=dag_version.id) dag_run = DagRun( dag_id=self.dag.dag_id, logical_date=timezone.utcnow(), @@ -909,7 +934,7 @@ def test_return_value(self, mock_get_conn, session, clean_dags_and_dagruns): run_type=DagRunType.MANUAL, state=DagRunState.RUNNING, ) - ti = TaskInstance(task=self.datasync) + ti = TaskInstance(task=self.datasync) ti.dag_run = dag_run session.add(ti) session.commit() @@ -1006,6 +1031,12 @@ def test_return_value(self, mock_get_conn, session, clean_dags_and_dagruns): self.set_up_operator() if AIRFLOW_V_3_0_PLUS: + self.dag.sync_to_db() + SerializedDagModel.write_dag(self.dag, bundle_name="testing") + from airflow.models.dag_version import DagVersion + + dag_version = DagVersion.get_latest_version(self.dag.dag_id) + ti = TaskInstance(task=self.datasync, dag_version_id=dag_version.id) dag_run = DagRun( dag_id=self.dag.dag_id, logical_date=timezone.utcnow(), @@ -1021,7 +1052,7 @@ def test_return_value(self, mock_get_conn, session, clean_dags_and_dagruns): run_type=DagRunType.MANUAL, state=DagRunState.RUNNING, ) - ti = TaskInstance(task=self.datasync) + ti = TaskInstance(task=self.datasync) ti.dag_run = dag_run session.add(ti) session.commit() diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py b/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py index 5771483cd320f..3418019ee7751 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py @@ -25,6 +25,7 @@ from airflow.exceptions import AirflowException, TaskDeferred from airflow.models import DAG, DagRun, TaskInstance +from 
airflow.models.serialized_dag import SerializedDagModel from airflow.models.variable import Variable from airflow.providers.amazon.aws.hooks.dms import DmsHook from airflow.providers.amazon.aws.operators.dms import ( @@ -51,6 +52,9 @@ from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS from unit.amazon.aws.utils.test_template_fields import validate_template_fields +if AIRFLOW_V_3_0_PLUS: + from airflow.models.dag_version import DagVersion + TASK_ARN = "test_arn" @@ -315,6 +319,10 @@ def test_describe_tasks_return_value(self, mock_conn, mock_describe_replication_ ) if AIRFLOW_V_3_0_PLUS: + self.dag.sync_to_db() + SerializedDagModel.write_dag(self.dag, bundle_name="testing") + dag_version = DagVersion.get_latest_version(self.dag.dag_id) + ti = TaskInstance(task=describe_task, dag_version_id=dag_version.id) dag_run = DagRun( dag_id=self.dag.dag_id, logical_date=timezone.utcnow(), @@ -330,7 +338,7 @@ def test_describe_tasks_return_value(self, mock_conn, mock_describe_replication_ run_type=DagRunType.MANUAL, state=DagRunState.RUNNING, ) - ti = TaskInstance(task=describe_task) + ti = TaskInstance(task=describe_task) ti.dag_run = dag_run session.add(ti) session.commit() @@ -512,6 +520,10 @@ def test_template_fields_native(self, mock_conn, session): ) if AIRFLOW_V_3_0_PLUS: + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") + dag_version = DagVersion.get_latest_version(dag.dag_id) + ti = TaskInstance(task=op, dag_version_id=dag_version.id) dag_run = DagRun( dag_id=dag.dag_id, run_id="test", @@ -527,7 +539,7 @@ def test_template_fields_native(self, mock_conn, session): state=DagRunState.RUNNING, execution_date=logical_date, ) - ti = TaskInstance(task=op) + ti = TaskInstance(task=op) ti.dag_run = dag_run session.add(ti) session.commit() diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_add_steps.py b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_add_steps.py index e18f1ab3ac090..0cfa38014d1bb 
100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_add_steps.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_add_steps.py @@ -27,6 +27,7 @@ from airflow.exceptions import AirflowException, TaskDeferred from airflow.models import DAG, DagRun, TaskInstance +from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.amazon.aws.operators.emr import EmrAddStepsOperator from airflow.providers.amazon.aws.triggers.emr import EmrAddStepsTrigger from airflow.utils import timezone @@ -101,6 +102,12 @@ def test_validate_mutually_exclusive_args(self, job_flow_id, job_flow_name): @pytest.mark.db_test def test_render_template(self, session, clean_dags_and_dagruns): if AIRFLOW_V_3_0_PLUS: + self.operator.dag.sync_to_db() + SerializedDagModel.write_dag(self.operator.dag, bundle_name="testing") + from airflow.models.dag_version import DagVersion + + dag_version = DagVersion.get_latest_version(self.operator.dag.dag_id) + ti = TaskInstance(task=self.operator, dag_version_id=dag_version.id) dag_run = DagRun( dag_id=self.operator.dag.dag_id, logical_date=DEFAULT_DATE, @@ -116,7 +123,7 @@ def test_render_template(self, session, clean_dags_and_dagruns): run_type=DagRunType.MANUAL, state=DagRunState.RUNNING, ) - ti = TaskInstance(task=self.operator) + ti = TaskInstance(task=self.operator) ti.dag_run = dag_run session.add(ti) session.commit() @@ -168,6 +175,12 @@ def test_render_template_from_file(self, mocked_hook_client, session, clean_dags do_xcom_push=False, ) if AIRFLOW_V_3_0_PLUS: + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") + from airflow.models.dag_version import DagVersion + + dag_version = DagVersion.get_latest_version(dag.dag_id) + ti = TaskInstance(task=test_task, dag_version_id=dag_version.id) dag_run = DagRun( dag_id=dag.dag_id, logical_date=timezone.utcnow(), @@ -183,7 +196,7 @@ def test_render_template_from_file(self, mocked_hook_client, session, clean_dags 
run_type=DagRunType.MANUAL, state=DagRunState.RUNNING, ) - ti = TaskInstance(task=test_task) + ti = TaskInstance(task=test_task) ti.dag_run = dag_run session.add(ti) session.commit() diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py index 037fb4430360f..686f40c103762 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py @@ -28,6 +28,7 @@ from airflow.exceptions import TaskDeferred from airflow.models import DAG, DagRun, TaskInstance +from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.amazon.aws.operators.emr import EmrCreateJobFlowOperator from airflow.providers.amazon.aws.triggers.emr import EmrCreateJobFlowTrigger from airflow.providers.amazon.aws.utils.waiter import WAITER_POLICY_NAME_MAPPING, WaitPolicy @@ -100,6 +101,12 @@ def test_init(self): def test_render_template(self, session, clean_dags_and_dagruns): self.operator.job_flow_overrides = self._config if AIRFLOW_V_3_0_PLUS: + from airflow.models.dag_version import DagVersion + + self.operator.dag.sync_to_db() + SerializedDagModel.write_dag(self.operator.dag, bundle_name="testing") + dag_version = DagVersion.get_latest_version(self.operator.dag.dag_id) + ti = TaskInstance(task=self.operator, dag_version_id=dag_version.id) dag_run = DagRun( dag_id=self.operator.dag_id, logical_date=DEFAULT_DATE, @@ -115,7 +122,7 @@ def test_render_template(self, session, clean_dags_and_dagruns): run_type=DagRunType.MANUAL, state=DagRunState.RUNNING, ) - ti = TaskInstance(task=self.operator) + ti = TaskInstance(task=self.operator) ti.dag_run = dag_run session.add(ti) session.commit() @@ -148,6 +155,12 @@ def test_render_template_from_file(self, mocked_hook_client, session, clean_dags self.operator.params = {"releaseLabel": "5.11.0"} if AIRFLOW_V_3_0_PLUS: + 
from airflow.models.dag_version import DagVersion + + self.operator.dag.sync_to_db() + SerializedDagModel.write_dag(self.operator.dag, bundle_name="testing") + dag_version = DagVersion.get_latest_version(self.operator.dag.dag_id) + ti = TaskInstance(task=self.operator, dag_version_id=dag_version.id) dag_run = DagRun( dag_id=self.operator.dag_id, logical_date=DEFAULT_DATE, @@ -163,7 +176,7 @@ def test_render_template_from_file(self, mocked_hook_client, session, clean_dags run_type=DagRunType.MANUAL, state=DagRunState.RUNNING, ) - ti = TaskInstance(task=self.operator) + ti = TaskInstance(task=self.operator) ti.dag_run = dag_run session.add(ti) session.commit() diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_s3.py b/providers/amazon/tests/unit/amazon/aws/operators/test_s3.py index 223943511ae77..c422de837607e 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_s3.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_s3.py @@ -33,6 +33,7 @@ from airflow import DAG from airflow.exceptions import AirflowException from airflow.models.dagrun import DagRun +from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.providers.amazon.aws.operators.s3 import ( @@ -59,6 +60,7 @@ from airflow.utils.timezone import datetime, utcnow from airflow.utils.types import DagRunType +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS from unit.amazon.aws.utils.test_template_fields import validate_template_fields BUCKET_NAME = os.environ.get("BUCKET_NAME", "test-airflow-bucket") @@ -667,7 +669,15 @@ def test_dates_from_template(self, session): run_type=DagRunType.MANUAL, state=DagRunState.RUNNING, ) - ti = TaskInstance(task=op) + if AIRFLOW_V_3_0_PLUS: + from airflow.models.dag_version import DagVersion + + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") + dag_version = 
DagVersion.get_latest_version(dag.dag_id) + ti = TaskInstance(task=op, dag_version_id=dag_version.id) + else: + ti = TaskInstance(task=op) ti.dag_run = dag_run session.add(ti) session.commit() diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_base.py b/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_base.py index 6219001f03feb..9c9515d4f88ea 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_base.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_base.py @@ -26,6 +26,7 @@ from airflow.exceptions import AirflowException from airflow.models import DAG, DagRun, TaskInstance +from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.amazon.aws.operators.sagemaker import ( SageMakerBaseOperator, SageMakerCreateExperimentOperator, @@ -209,6 +210,12 @@ def test_create_experiment(self, conn_mock, session, clean_dags_and_dagruns): dag=dag, ) if AIRFLOW_V_3_0_PLUS: + from airflow.models.dag_version import DagVersion + + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") + dag_version = DagVersion.get_latest_version(dag.dag_id) + ti = TaskInstance(task=op, dag_version_id=dag_version.id) dag_run = DagRun( dag_id=dag.dag_id, logical_date=logical_date, @@ -224,7 +231,7 @@ def test_create_experiment(self, conn_mock, session, clean_dags_and_dagruns): run_type=DagRunType.MANUAL, state=DagRunState.RUNNING, ) - ti = TaskInstance(task=op) + ti = TaskInstance(task=op) ti.dag_run = dag_run session.add(ti) session.commit() diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_s3.py b/providers/amazon/tests/unit/amazon/aws/sensors/test_s3.py index 4ada09d092303..aecaf0b91a574 100644 --- a/providers/amazon/tests/unit/amazon/aws/sensors/test_s3.py +++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_s3.py @@ -26,6 +26,7 @@ from airflow.exceptions import AirflowException from airflow.models import DAG, DagRun, TaskInstance +from 
airflow.models.serialized_dag import SerializedDagModel from airflow.models.variable import Variable from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor, S3KeysUnchangedSensor @@ -129,6 +130,11 @@ def test_parse_bucket_key_from_jinja(self, mock_head_object, session, clean_dags ) if AIRFLOW_V_3_0_PLUS: + from airflow.models.dag_version import DagVersion + + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") + dag_version = DagVersion.get_latest_version(dag.dag_id) dag_run = DagRun( dag_id=dag.dag_id, logical_date=logical_date, @@ -136,6 +142,7 @@ def test_parse_bucket_key_from_jinja(self, mock_head_object, session, clean_dags run_type=DagRunType.MANUAL, state=DagRunState.RUNNING, ) + ti = TaskInstance(task=op, dag_version_id=dag_version.id) else: dag_run = DagRun( dag_id=dag.dag_id, @@ -144,7 +151,7 @@ def test_parse_bucket_key_from_jinja(self, mock_head_object, session, clean_dags run_type=DagRunType.MANUAL, state=DagRunState.RUNNING, ) - ti = TaskInstance(task=op) + ti = TaskInstance(task=op) ti.dag_run = dag_run session.add(ti) session.commit() @@ -179,7 +186,13 @@ def test_parse_list_of_bucket_keys_from_jinja(self, mock_head_object, session, c run_type=DagRunType.MANUAL, state=DagRunState.RUNNING, ) + ti = TaskInstance(task=op) else: + from airflow.models.dag_version import DagVersion + + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") + dag_version = DagVersion.get_latest_version(dag.dag_id) dag_run = DagRun( dag_id=dag.dag_id, logical_date=logical_date, @@ -187,7 +200,7 @@ def test_parse_list_of_bucket_keys_from_jinja(self, mock_head_object, session, c run_type=DagRunType.MANUAL, state=DagRunState.RUNNING, ) - ti = TaskInstance(task=op) + ti = TaskInstance(task=op, dag_version_id=dag_version.id) ti.dag_run = dag_run session.add(ti) session.commit() diff --git a/providers/amazon/tests/unit/amazon/aws/transfers/test_base.py 
b/providers/amazon/tests/unit/amazon/aws/transfers/test_base.py index 235955bd6eb84..eb03cace7ea4a 100644 --- a/providers/amazon/tests/unit/amazon/aws/transfers/test_base.py +++ b/providers/amazon/tests/unit/amazon/aws/transfers/test_base.py @@ -21,6 +21,7 @@ from airflow import DAG from airflow.models import DagRun, TaskInstance +from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.amazon.aws.transfers.base import AwsToAwsBaseOperator from airflow.utils import timezone from airflow.utils.state import DagRunState @@ -44,8 +45,14 @@ def test_render_template(self, session, clean_dags_and_dagruns): source_aws_conn_id="{{ ds }}", dest_aws_conn_id="{{ ds }}", ) - ti = TaskInstance(operator, run_id="something") + if AIRFLOW_V_3_0_PLUS: + from airflow.models.dag_version import DagVersion + + self.dag.sync_to_db() + SerializedDagModel.write_dag(self.dag, bundle_name="testing") + dag_version = DagVersion.get_latest_version(self.dag.dag_id) + ti = TaskInstance(operator, run_id="something", dag_version_id=dag_version.id) ti.dag_run = DagRun( dag_id=self.dag.dag_id, run_id="something", @@ -54,6 +61,7 @@ def test_render_template(self, session, clean_dags_and_dagruns): state=DagRunState.RUNNING, ) else: + ti = TaskInstance(operator, run_id="something") ti.dag_run = DagRun( dag_id=self.dag.dag_id, run_id="something", diff --git a/providers/amazon/tests/unit/amazon/aws/transfers/test_dynamodb_to_s3.py b/providers/amazon/tests/unit/amazon/aws/transfers/test_dynamodb_to_s3.py index dcd3b6058a660..34519383b6cec 100644 --- a/providers/amazon/tests/unit/amazon/aws/transfers/test_dynamodb_to_s3.py +++ b/providers/amazon/tests/unit/amazon/aws/transfers/test_dynamodb_to_s3.py @@ -26,6 +26,7 @@ from airflow import DAG from airflow.models import DagRun, TaskInstance +from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.amazon.aws.transfers.dynamodb_to_s3 import ( DynamoDBToS3Operator, JSONEncoder, @@ -274,8 +275,14 @@ def 
test_render_template(self, session): source_aws_conn_id="{{ ds }}", dest_aws_conn_id="{{ ds }}", ) - ti = TaskInstance(operator, run_id="something") + if AIRFLOW_V_3_0_PLUS: + from airflow.models.dag_version import DagVersion + + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") + dag_version = DagVersion.get_latest_version(dag.dag_id) + ti = TaskInstance(operator, run_id="something", dag_version_id=dag_version.id) ti.dag_run = DagRun( dag_id=dag.dag_id, run_id="something", @@ -284,6 +291,7 @@ def test_render_template(self, session): state=DagRunState.RUNNING, ) else: + ti = TaskInstance(operator, run_id="something") ti.dag_run = DagRun( dag_id=dag.dag_id, run_id="something", diff --git a/providers/amazon/tests/unit/amazon/aws/transfers/test_mongo_to_s3.py b/providers/amazon/tests/unit/amazon/aws/transfers/test_mongo_to_s3.py index 764faf6dadbfe..5401e557be759 100644 --- a/providers/amazon/tests/unit/amazon/aws/transfers/test_mongo_to_s3.py +++ b/providers/amazon/tests/unit/amazon/aws/transfers/test_mongo_to_s3.py @@ -22,6 +22,7 @@ import pytest from airflow.models import DAG, DagRun, TaskInstance +from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.amazon.aws.transfers.mongo_to_s3 import MongoToS3Operator from airflow.utils import timezone from airflow.utils.state import DagRunState @@ -84,6 +85,12 @@ def test_template_field_overrides(self): @pytest.mark.db_test def test_render_template(self, session): if AIRFLOW_V_3_0_PLUS: + from airflow.models.dag_version import DagVersion + + self.dag.sync_to_db() + SerializedDagModel.write_dag(self.dag, bundle_name="testing") + dag_version = DagVersion.get_latest_version(self.mock_operator.dag_id) + ti = TaskInstance(self.mock_operator, dag_version_id=dag_version.id) dag_run = DagRun( dag_id=self.mock_operator.dag_id, logical_date=DEFAULT_DATE, @@ -99,7 +106,7 @@ def test_render_template(self, session): run_type=DagRunType.MANUAL, state=DagRunState.RUNNING, ) - ti = 
TaskInstance(task=self.mock_operator) + ti = TaskInstance(task=self.mock_operator) ti.dag_run = dag_run session.add(ti) session.commit() diff --git a/providers/apache/kylin/tests/unit/apache/kylin/operators/test_kylin_cube.py b/providers/apache/kylin/tests/unit/apache/kylin/operators/test_kylin_cube.py index 47e53cbc5f373..1f56d597d05bd 100644 --- a/providers/apache/kylin/tests/unit/apache/kylin/operators/test_kylin_cube.py +++ b/providers/apache/kylin/tests/unit/apache/kylin/operators/test_kylin_cube.py @@ -170,8 +170,15 @@ def test_render_template(self, session): "end_time": "1483286400000", }, ) - ti = TaskInstance(operator, run_id="kylin_test") + if AIRFLOW_V_3_0_PLUS: + self.dag.sync_to_db() + from airflow.models.dag_version import DagVersion + from airflow.models.serialized_dag import SerializedDagModel + + SerializedDagModel.write_dag(dag=self.dag, bundle_name="testing") + dag_version = DagVersion.get_latest_version(operator.dag_id) + ti = TaskInstance(operator, run_id="kylin_test", dag_version_id=dag_version.id) ti.dag_run = DagRun( dag_id=self.dag.dag_id, run_id="kylin_test", @@ -182,6 +189,7 @@ def test_render_template(self, session): state=state.DagRunState.RUNNING, ) else: + ti = TaskInstance(operator, run_id="kylin_test") ti.dag_run = DagRun( dag_id=self.dag.dag_id, run_id="kylin_test", diff --git a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py index 264606748b167..3e05f89abadc9 100644 --- a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py +++ b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py @@ -197,8 +197,15 @@ def test_spark_submit_cmd_connection_overrides(self): def test_render_template(self, session): # Given operator = SparkSubmitOperator(task_id="spark_submit_job", dag=self.dag, **self._config) - ti = TaskInstance(operator, run_id="spark_test") + if AIRFLOW_V_3_0_PLUS: + from 
airflow.models.dag_version import DagVersion + from airflow.models.serialized_dag import SerializedDagModel + + self.dag.sync_to_db() + SerializedDagModel.write_dag(dag=self.dag, bundle_name="testing") + dag_version = DagVersion.get_latest_version(operator.dag_id) + ti = TaskInstance(operator, run_id="spark_test", dag_version_id=dag_version.id) ti.dag_run = DagRun( dag_id=self.dag.dag_id, run_id="spark_test", @@ -209,6 +216,7 @@ def test_render_template(self, session): state="running", ) else: + ti = TaskInstance(operator, run_id="spark_test") ti.dag_run = DagRun( dag_id=self.dag.dag_id, run_id="spark_test", diff --git a/providers/celery/tests/integration/celery/test_celery_executor.py b/providers/celery/tests/integration/celery/test_celery_executor.py index 08575811cc94e..874fbce3cbc5d 100644 --- a/providers/celery/tests/integration/celery/test_celery_executor.py +++ b/providers/celery/tests/integration/celery/test_celery_executor.py @@ -30,6 +30,7 @@ # leave this it is used by the test worker import celery.contrib.testing.tasks # noqa: F401 import pytest +import uuid6 from celery import Celery from celery.backends.base import BaseBackend, BaseKeyValueStoreBackend from celery.backends.database import DatabaseBackend @@ -220,7 +221,10 @@ def fake_task(): dag=DAG(dag_id="dag_id"), start_date=datetime.now(), ) - ti = TaskInstance(task=task, run_id="abc") + if AIRFLOW_V_3_0_PLUS: + ti = TaskInstance(task=task, run_id="abc", dag_version_id=uuid6.uuid7()) + else: + ti = TaskInstance(task=task, run_id="abc") workload = workloads.ExecuteTask.model_construct( ti=workloads.TaskInstance.model_validate(ti, from_attributes=True), ) @@ -256,7 +260,10 @@ def test_retry_on_error_sending_task(self, caplog): dag=DAG(dag_id="id"), start_date=datetime.now(), ) - ti = TaskInstance(task=task, run_id="abc") + if AIRFLOW_V_3_0_PLUS: + ti = TaskInstance(task=task, run_id="abc", dag_version_id=uuid6.uuid7()) + else: + ti = TaskInstance(task=task, run_id="abc") workload = 
workloads.ExecuteTask.model_construct( ti=workloads.TaskInstance.model_validate(ti, from_attributes=True), ) diff --git a/providers/celery/tests/unit/celery/executors/test_celery_executor.py b/providers/celery/tests/unit/celery/executors/test_celery_executor.py index d71fa9a64bb56..c56f16ddf3a3c 100644 --- a/providers/celery/tests/unit/celery/executors/test_celery_executor.py +++ b/providers/celery/tests/unit/celery/executors/test_celery_executor.py @@ -36,6 +36,7 @@ from airflow.configuration import conf from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG +from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance, TaskInstanceKey from airflow.providers.celery.executors import celery_executor, celery_executor_utils, default_celery from airflow.providers.celery.executors.celery_executor import CeleryExecutor @@ -46,6 +47,9 @@ from tests_common.test_utils.config import conf_vars from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +if AIRFLOW_V_3_0_PLUS: + from airflow.models.dag_version import DagVersion + pytestmark = pytest.mark.db_test @@ -192,10 +196,16 @@ def test_command_validation(self, command, raise_exception): def test_try_adopt_task_instances_none(self): start_date = timezone.utcnow() - timedelta(days=2) - with DAG("test_try_adopt_task_instances_none", schedule=None): + with DAG("test_try_adopt_task_instances_none", schedule=None) as dag: task_1 = BaseOperator(task_id="task_1", start_date=start_date) - key1 = TaskInstance(task=task_1, run_id=None) + if AIRFLOW_V_3_0_PLUS: + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") + dag_version = DagVersion.get_latest_version(dag.dag_id) + key1 = TaskInstance(task=task_1, run_id=None, dag_version_id=dag_version.id) + else: + key1 = TaskInstance(task=task_1, run_id=None) tis = [key1] executor = celery_executor.CeleryExecutor() @@ -211,10 +221,17 @@ def 
test_try_adopt_task_instances(self): task_1 = BaseOperator(task_id="task_1", start_date=start_date) task_2 = BaseOperator(task_id="task_2", start_date=start_date) - ti1 = TaskInstance(task=task_1, run_id=None) + if AIRFLOW_V_3_0_PLUS: + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") + dag_version = DagVersion.get_latest_version(dag.dag_id) + ti1 = TaskInstance(task=task_1, run_id=None, dag_version_id=dag_version.id) + ti2 = TaskInstance(task=task_2, run_id=None, dag_version_id=dag_version.id) + else: + ti1 = TaskInstance(task=task_1, run_id=None) + ti2 = TaskInstance(task=task_2, run_id=None) ti1.external_executor_id = "231" ti1.state = State.QUEUED - ti2 = TaskInstance(task=task_2, run_id=None) ti2.external_executor_id = "232" ti2.state = State.QUEUED @@ -244,10 +261,16 @@ def mock_celery_revoke(self): def test_cleanup_stuck_queued_tasks(self, mock_fail): start_date = timezone.utcnow() - timedelta(days=2) - with DAG("test_cleanup_stuck_queued_tasks_failed", schedule=None): + with DAG("test_cleanup_stuck_queued_tasks_failed", schedule=None) as dag: task = BaseOperator(task_id="task_1", start_date=start_date) - ti = TaskInstance(task=task, run_id=None) + if AIRFLOW_V_3_0_PLUS: + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") + dag_version = DagVersion.get_latest_version(task.dag.dag_id) + ti = TaskInstance(task=task, run_id=None, dag_version_id=dag_version.id) + else: + ti = TaskInstance(task=task, run_id=None) ti.external_executor_id = "231" ti.state = State.QUEUED ti.queued_dttm = timezone.utcnow() - timedelta(minutes=30) @@ -273,10 +296,16 @@ def test_cleanup_stuck_queued_tasks(self, mock_fail): def test_revoke_task(self, mock_fail): start_date = timezone.utcnow() - timedelta(days=2) - with DAG("test_revoke_task", schedule=None): + with DAG("test_revoke_task", schedule=None) as dag: task = BaseOperator(task_id="task_1", start_date=start_date) - ti = TaskInstance(task=task, run_id=None) + if 
AIRFLOW_V_3_0_PLUS: + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") + dag_version = DagVersion.get_latest_version(task.dag.dag_id) + ti = TaskInstance(task=task, run_id=None, dag_version_id=dag_version.id) + else: + ti = TaskInstance(task=task, run_id=None) ti.external_executor_id = "231" ti.state = State.QUEUED ti.queued_dttm = timezone.utcnow() - timedelta(minutes=30) diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py index 46e41e348f145..41737d713eec2 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py @@ -66,7 +66,12 @@ def generate_pod_yaml(args): kube_config = KubeConfig() for task in dag.tasks: - ti = TaskInstance(task, run_id=dr.run_id) + if AIRFLOW_V_3_0_PLUS: + from uuid6 import uuid7 + + ti = TaskInstance(task, run_id=dr.run_id, dag_version_id=uuid7()) + else: + ti = TaskInstance(task, run_id=dr.run_id) ti.dag_run = dr ti.dag_model = dm diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/log_handlers/test_log_handlers.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/log_handlers/test_log_handlers.py index e467bcda87633..231b62f85a867 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/log_handlers/test_log_handlers.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/log_handlers/test_log_handlers.py @@ -31,6 +31,7 @@ from airflow.executors import executor_loader from airflow.models.dag import DAG from airflow.models.dagrun import DagRun +from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance from airflow.utils.log.file_task_handler import ( FileTaskHandler, @@ -132,6 +133,8 @@ def task_callable(ti): "run_after": DEFAULT_DATE, 
"triggered_by": DagRunTriggeredByType.TEST, } + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") else: dagrun_kwargs = {"execution_date": DEFAULT_DATE} dagrun = dag.create_dagrun( @@ -141,7 +144,10 @@ def task_callable(ti): data_interval=dag.timetable.infer_manual_data_interval(run_after=DEFAULT_DATE), **dagrun_kwargs, ) - ti = TaskInstance(task=task, run_id=dagrun.run_id) + if AIRFLOW_V_3_0_PLUS: + ti = TaskInstance(task=task, run_id=dagrun.run_id, dag_version_id=dagrun.created_dag_version_id) + else: + ti = TaskInstance(task=task, run_id=dagrun.run_id) ti.try_number = 3 ti.executor = "KubernetesExecutor" diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_job.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_job.py index 691a549696116..130a87cbbe3b1 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_job.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_job.py @@ -28,6 +28,7 @@ from airflow.exceptions import AirflowException from airflow.models import DAG, DagModel, DagRun, TaskInstance +from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.cncf.kubernetes.operators.job import ( KubernetesDeleteJobOperator, KubernetesJobOperator, @@ -60,6 +61,8 @@ def create_context(task, persist_to_db=False, map_index=None): dag = DAG(dag_id="dag", schedule=None, start_date=pendulum.now()) dag.add_task(task) if AIRFLOW_V_3_0_PLUS: + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") dag_run = DagRun( run_id=DagRun.generate_run_id( run_type=DagRunType.MANUAL, logical_date=DEFAULT_DATE, run_after=DEFAULT_DATE @@ -73,7 +76,13 @@ def create_context(task, persist_to_db=False, map_index=None): run_type=DagRunType.MANUAL, dag_id=dag.dag_id, ) - task_instance = TaskInstance(task=task, run_id=dag_run.run_id) + if AIRFLOW_V_3_0_PLUS: + from airflow.models.dag_version import DagVersion + + dag_version = 
DagVersion.get_latest_version(dag.dag_id) + task_instance = TaskInstance(task=task, run_id=dag_run.run_id, dag_version_id=dag_version.id) + else: + task_instance = TaskInstance(task=task, run_id=dag_run.run_id) task_instance.dag_run = dag_run if map_index is not None: task_instance.map_index = map_index diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py index 37af22925d1b0..db6f966ee6fff 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py @@ -37,6 +37,7 @@ TaskDeferred, ) from airflow.models import DAG, DagModel, DagRun, TaskInstance +from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.cncf.kubernetes import pod_generator from airflow.providers.cncf.kubernetes.operators.pod import ( KubernetesPodOperator, @@ -105,6 +106,8 @@ def create_context(task, persist_to_db=False, map_index=None): dag.add_task(task) now = timezone.utcnow() if AIRFLOW_V_3_0_PLUS: + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") dag_run = DagRun( run_id=DagRun.generate_run_id( run_type=DagRunType.MANUAL, logical_date=DEFAULT_DATE, run_after=DEFAULT_DATE @@ -122,13 +125,20 @@ def create_context(task, persist_to_db=False, map_index=None): dag_id=dag.dag_id, execution_date=now, ) - task_instance = TaskInstance(task=task, run_id=dag_run.run_id) + if AIRFLOW_V_3_0_PLUS: + from airflow.models.dag_version import DagVersion + + dag_version = DagVersion.get_latest_version(dag.dag_id) + task_instance = TaskInstance(task=task, run_id=dag_run.run_id, dag_version_id=dag_version.id) + else: + task_instance = TaskInstance(task=task, run_id=dag_run.run_id) task_instance.dag_run = dag_run if map_index is not None: task_instance.map_index = map_index if persist_to_db: with create_session() as session: - 
session.add(DagModel(dag_id=dag.dag_id)) + if not AIRFLOW_V_3_0_PLUS: + session.add(DagModel(dag_id=dag.dag_id)) session.add(dag_run) session.add(task_instance) session.commit() diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py index fa91a6686368a..e16ae58a643e3 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py @@ -211,7 +211,12 @@ def create_context(task): execution_date=logical_date, run_id=DagRun.generate_run_id(DagRunType.MANUAL, logical_date), ) - task_instance = TaskInstance(task=task) + if AIRFLOW_V_3_0_PLUS: + from uuid6 import uuid7 + + task_instance = TaskInstance(task=task, dag_version_id=uuid7()) + else: + task_instance = TaskInstance(task=task) task_instance.dag_run = dag_run task_instance.dag_id = dag.dag_id task_instance.xcom_push = mock.Mock() diff --git a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py index 058ccf2b55b25..489e228dd8055 100644 --- a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py +++ b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py @@ -36,6 +36,7 @@ from airflow import DAG from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.models import Connection, DagRun, TaskInstance as TI +from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.common.sql.hooks.handlers import fetch_all_handler from airflow.providers.common.sql.operators.sql import ( BaseSQLOperator, @@ -1103,6 +1104,9 @@ def setup_method(self): self.branch_1 = EmptyOperator(task_id="branch_1", dag=self.dag) self.branch_2 = EmptyOperator(task_id="branch_2", dag=self.dag) self.branch_3 = None + if 
AIRFLOW_V_3_0_PLUS: + self.dag.sync_to_db() + SerializedDagModel.write_dag(self.dag, bundle_name="testing") def get_ti(self, task_id, dr=None): if dr is None: diff --git a/providers/docker/tests/unit/docker/decorators/test_docker.py b/providers/docker/tests/unit/docker/decorators/test_docker.py index c2264a6c76340..94939f11ad20d 100644 --- a/providers/docker/tests/unit/docker/decorators/test_docker.py +++ b/providers/docker/tests/unit/docker/decorators/test_docker.py @@ -97,7 +97,10 @@ def f(): ret = f() dr = dag_maker.create_dagrun() - ti = TaskInstance(task=ret.operator, run_id=dr.run_id) + if AIRFLOW_V_3_0_PLUS: + ti = TaskInstance(task=ret.operator, run_id=dr.run_id, dag_version_id=dr.created_dag_version_id) + else: + ti = TaskInstance(task=ret.operator, run_id=dr.run_id) rendered = ti.render_templates() assert rendered.container_name == f"python_{dr.dag_id}" assert rendered.mounts[0]["Target"] == f"/{ti.run_id}" diff --git a/providers/docker/tests/unit/docker/operators/test_docker.py b/providers/docker/tests/unit/docker/operators/test_docker.py index a8ab46e89af87..1e540c6404583 100644 --- a/providers/docker/tests/unit/docker/operators/test_docker.py +++ b/providers/docker/tests/unit/docker/operators/test_docker.py @@ -27,7 +27,6 @@ from docker.types import DeviceRequest, LogConfig, Mount, Ulimit from airflow.exceptions import AirflowException, AirflowSkipException -from airflow.models import TaskInstance from airflow.providers.docker.exceptions import DockerContainerFailedException from airflow.providers.docker.operators.docker import DockerOperator, fetch_logs @@ -810,7 +809,7 @@ def test_basic_docker_operator_with_template_fields(self, dag_maker): operator.execute({}) dr = dag_maker.create_dagrun() - ti = TaskInstance(task=operator, run_id=dr.run_id) + ti = dr.task_instances[0] rendered = ti.render_templates() assert rendered.container_name == f"python_{dr.dag_id}" assert rendered.mounts[0]["Target"] == f"/{ti.run_id}" diff --git 
a/providers/edge3/tests/unit/edge3/cli/test_worker.py b/providers/edge3/tests/unit/edge3/cli/test_worker.py index 5133d92ed37b9..8671bf590bfca 100644 --- a/providers/edge3/tests/unit/edge3/cli/test_worker.py +++ b/providers/edge3/tests/unit/edge3/cli/test_worker.py @@ -63,6 +63,7 @@ "dag_id": "mock", "run_id": "mock", "try_number": 1, + "dag_version_id": "01234567-89ab-cdef-0123-456789abcdef", "pool_slots": 1, "queue": "default", "priority_weight": 1, diff --git a/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py b/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py index 02ab7582a54cd..947264ea39503 100644 --- a/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py +++ b/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py @@ -328,6 +328,7 @@ def test_queue_workload(self): queue="default", priority_weight=1, start_date=timezone.utcnow(), + dag_version_id="4d828a62-a417-4936-a7a6-2b3fabacecab", ), dag_rel_path="mock.py", log_path="mock.log", diff --git a/providers/google/tests/unit/google/cloud/utils/airflow_util.py b/providers/google/tests/unit/google/cloud/utils/airflow_util.py index 3e0b14cb0a567..15a191715075e 100644 --- a/providers/google/tests/unit/google/cloud/utils/airflow_util.py +++ b/providers/google/tests/unit/google/cloud/utils/airflow_util.py @@ -20,6 +20,7 @@ from unittest import mock import pendulum +from uuid6 import uuid7 from airflow.models import DAG, Connection from airflow.models.dagrun import DagRun @@ -27,6 +28,8 @@ from airflow.utils import timezone from airflow.utils.types import DagRunType +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS + if TYPE_CHECKING: from airflow.providers.google.version_compat import BaseOperator @@ -39,7 +42,9 @@ def get_dag_run(dag_id: str = "test_dag_id", run_id: str = "test_dag_id") -> Dag def get_task_instance(task: BaseOperator) -> TaskInstance: - return TaskInstance(task, timezone.datetime(2022, 1, 1)) + if AIRFLOW_V_3_0_PLUS: + return 
TaskInstance(task=task, run_id=None, dag_version_id=uuid7()) + return TaskInstance(task=task, run_id=None) # type: ignore def get_conn() -> Connection: diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_data_factory.py b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_data_factory.py index 23d7d18d4a407..714bb3a4f384b 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_data_factory.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_data_factory.py @@ -319,6 +319,8 @@ def get_dag_run(self, dag_id: str = "test_dag_id", run_id: str = "test_dag_id") return dag_run def get_task_instance(self, task: BaseOperator) -> TaskInstance: + if AIRFLOW_V_3_0_PLUS: + return TaskInstance(task, run_id=timezone.datetime(2022, 1, 1), dag_version_id=mock.MagicMock()) return TaskInstance(task, timezone.datetime(2022, 1, 1)) def get_conn( @@ -348,8 +350,10 @@ def create_context(self, task, dag=None): execution_date=logical_date, run_id=DagRun.generate_run_id(DagRunType.MANUAL, logical_date), ) - - task_instance = TaskInstance(task=task) + if AIRFLOW_V_3_0_PLUS: + task_instance = TaskInstance(task=task, dag_version_id=mock.MagicMock()) + else: + task_instance = TaskInstance(task=task) task_instance.dag_run = dag_run task_instance.xcom_push = mock.Mock() date_key = "logical_date" if AIRFLOW_V_3_0_PLUS else "execution_date" diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/sensors/test_wasb.py b/providers/microsoft/azure/tests/unit/microsoft/azure/sensors/test_wasb.py index c93db4ea05b58..c1344b32adc86 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/sensors/test_wasb.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/sensors/test_wasb.py @@ -34,6 +34,8 @@ from airflow.utils import timezone from airflow.utils.types import DagRunType +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS + if TYPE_CHECKING: from 
airflow.models.baseoperator import BaseOperator @@ -96,6 +98,8 @@ def get_dag_run(self, dag_id: str = "test_dag_id", run_id: str = "test_dag_id") return dag_run def get_task_instance(self, task: BaseOperator) -> TaskInstance: + if AIRFLOW_V_3_0_PLUS: + return TaskInstance(task, run_id=timezone.datetime(2022, 1, 1), dag_version_id=mock.MagicMock()) return TaskInstance(task, timezone.datetime(2022, 1, 1)) def get_conn(self) -> Connection: @@ -123,8 +127,10 @@ def create_context(self, task, dag=None): run_type=DagRunType.MANUAL, logical_date=logical_date, run_after=logical_date ), ) - - task_instance = TaskInstance(task=task) + if AIRFLOW_V_3_0_PLUS: + task_instance = TaskInstance(task=task, dag_version_id=mock.MagicMock()) + else: + task_instance = TaskInstance(task=task) task_instance.dag_run = dag_run task_instance.xcom_push = mock.Mock() date_key = "execution_date" if hasattr(DagRun, "execution_date") else "logical_date" @@ -241,6 +247,8 @@ def get_dag_run(self, dag_id: str = "test_dag_id", run_id: str = "test_dag_id") return dag_run def get_task_instance(self, task: BaseOperator) -> TaskInstance: + if AIRFLOW_V_3_0_PLUS: + return TaskInstance(task, run_id=timezone.datetime(2022, 1, 1), dag_version_id=mock.MagicMock()) return TaskInstance(task, timezone.datetime(2022, 1, 1)) def get_conn(self) -> Connection: @@ -268,8 +276,10 @@ def create_context(self, task, dag=None): run_type=DagRunType.MANUAL, logical_date=logical_date, run_after=logical_date ), ) - - task_instance = TaskInstance(task=task) + if AIRFLOW_V_3_0_PLUS: + task_instance = TaskInstance(task=task, dag_version_id=mock.MagicMock()) + else: + task_instance = TaskInstance(task=task) task_instance.dag_run = dag_run task_instance.xcom_push = mock.Mock() date_key = "execution_date" if hasattr(DagRun, "execution_date") else "logical_date" diff --git a/providers/openlineage/tests/unit/openlineage/plugins/test_adapter.py b/providers/openlineage/tests/unit/openlineage/plugins/test_adapter.py index 
d0401c8a4f0ae..e79e79489f177 100644 --- a/providers/openlineage/tests/unit/openlineage/plugins/test_adapter.py +++ b/providers/openlineage/tests/unit/openlineage/plugins/test_adapter.py @@ -869,11 +869,24 @@ def test_emit_dag_complete_event( dag_run._state = DagRunState.SUCCESS dag_run.end_date = event_time - mocked_fetch_tis.return_value = [ - TaskInstance(task=task_0, run_id=run_id, state=TaskInstanceState.SUCCESS), - TaskInstance(task=task_1, run_id=run_id, state=TaskInstanceState.SKIPPED), - TaskInstance(task=task_2, run_id=run_id, state=TaskInstanceState.FAILED), - ] + if AIRFLOW_V_3_0_PLUS: + mocked_fetch_tis.return_value = [ + TaskInstance( + task=task_0, run_id=run_id, state=TaskInstanceState.SUCCESS, dag_version_id=mock.MagicMock() + ), + TaskInstance( + task=task_1, run_id=run_id, state=TaskInstanceState.SKIPPED, dag_version_id=mock.MagicMock() + ), + TaskInstance( + task=task_2, run_id=run_id, state=TaskInstanceState.FAILED, dag_version_id=mock.MagicMock() + ), + ] + else: + mocked_fetch_tis.return_value = [ + TaskInstance(task=task_0, run_id=run_id, state=TaskInstanceState.SUCCESS), + TaskInstance(task=task_1, run_id=run_id, state=TaskInstanceState.SKIPPED), + TaskInstance(task=task_2, run_id=run_id, state=TaskInstanceState.FAILED), + ] generate_static_uuid.return_value = random_uuid adapter.dag_success( @@ -1007,11 +1020,24 @@ def test_emit_dag_failed_event( ) dag_run._state = DagRunState.FAILED dag_run.end_date = event_time - mocked_fetch_tis.return_value = [ - TaskInstance(task=task_0, run_id=run_id, state=TaskInstanceState.SUCCESS), - TaskInstance(task=task_1, run_id=run_id, state=TaskInstanceState.SKIPPED), - TaskInstance(task=task_2, run_id=run_id, state=TaskInstanceState.FAILED), - ] + if AIRFLOW_V_3_0_PLUS: + mocked_fetch_tis.return_value = [ + TaskInstance( + task=task_0, run_id=run_id, state=TaskInstanceState.SUCCESS, dag_version_id=mock.MagicMock() + ), + TaskInstance( + task=task_1, run_id=run_id, state=TaskInstanceState.SKIPPED, 
dag_version_id=mock.MagicMock() + ), + TaskInstance( + task=task_2, run_id=run_id, state=TaskInstanceState.FAILED, dag_version_id=mock.MagicMock() + ), + ] + else: + mocked_fetch_tis.return_value = [ + TaskInstance(task=task_0, run_id=run_id, state=TaskInstanceState.SUCCESS), + TaskInstance(task=task_1, run_id=run_id, state=TaskInstanceState.SKIPPED), + TaskInstance(task=task_2, run_id=run_id, state=TaskInstanceState.FAILED), + ] generate_static_uuid.return_value = random_uuid adapter.dag_failed( diff --git a/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py b/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py index 563a16e1363fd..a33abd926fbcf 100644 --- a/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py +++ b/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py @@ -203,7 +203,10 @@ def sample_callable(**kwargs): state=DagRunState.RUNNING, execution_date=date, # type: ignore ) - task_instance = TaskInstance(t, run_id=run_id) + if AIRFLOW_V_3_0_PLUS: + task_instance = TaskInstance(t, run_id=run_id, dag_version_id=dagrun.created_dag_version_id) + else: + task_instance = TaskInstance(t, run_id=run_id) # type: ignore return dagrun, task_instance def _create_listener_and_task_instance(self) -> tuple[OpenLineageListener, TaskInstance]: @@ -241,8 +244,10 @@ def mock_task_id(dag_id, task_id, try_number, logical_date, map_index): adapter.fail_task = mock.Mock() adapter.complete_task = mock.Mock() listener.adapter = adapter - - task_instance = TaskInstance(task=mock.Mock()) + if AIRFLOW_V_3_0_PLUS: + task_instance = TaskInstance(task=mock.Mock(), dag_version_id=mock.MagicMock()) + else: + task_instance = TaskInstance(task=mock.Mock()) # type: ignore task_instance.dag_run = DagRun() task_instance.dag_run.run_id = "dag_run_run_id" task_instance.dag_run.data_interval_start = None @@ -957,6 +962,7 @@ def test_listener_does_not_change_task_instance(self, render_mock, mock_supervis run_id=run_id, 
try_number=1, map_index=-1, + dag_version_id=uuid7(), ) runtime_ti = RuntimeTaskInstance.model_construct(**ti.model_dump(exclude_unset=True), task=task) @@ -1044,7 +1050,7 @@ def sample_callable(**kwargs): state=DagRunState.QUEUED, **dagrun_kwargs, # type: ignore ) - task_instance = TaskInstance(t, run_id=run_id) + task_instance = TaskInstance(t, run_id=run_id) # type: ignore task_instance.dag_run = dagrun return dagrun, task_instance @@ -1068,7 +1074,7 @@ def _create_listener_and_task_instance( if not runtime_ti: # TaskInstance is used when on API server (when listener gets called about manual state change) - task_instance = TaskInstance(task=MagicMock()) # type: ignore + task_instance = TaskInstance(task=MagicMock(), dag_version_id=uuid7()) # type: ignore task_instance.dag_run = DagRun() task_instance.dag_run.dag_id = "dag_id_from_dagrun_and_not_ti" task_instance.dag_run.run_id = "dag_run_run_id" @@ -1109,6 +1115,7 @@ def _create_listener_and_task_instance( run_id="dag_run_run_id", try_number=1, map_index=-1, + dag_version_id=uuid7(), ) task_instance = RuntimeTaskInstance.model_construct( # type: ignore **sdk_task_instance.model_dump(exclude_unset=True), diff --git a/providers/openlineage/tests/unit/openlineage/utils/test_utils.py b/providers/openlineage/tests/unit/openlineage/utils/test_utils.py index 04ff2bdb5c720..ca8d78ce7c6b3 100644 --- a/providers/openlineage/tests/unit/openlineage/utils/test_utils.py +++ b/providers/openlineage/tests/unit/openlineage/utils/test_utils.py @@ -19,6 +19,7 @@ import datetime import pathlib +from unittest import mock from unittest.mock import MagicMock, PropertyMock, patch import pendulum @@ -832,12 +833,23 @@ def test_get_task_groups_details_no_task_groups(): @patch("airflow.providers.openlineage.conf.custom_run_facets", return_value=set()) def test_get_user_provided_run_facets_with_no_function_definition(mock_custom_facet_funcs): - sample_ti = TaskInstance( - task=EmptyOperator( - task_id="test-task", dag=DAG("test-dag", 
schedule=None, start_date=datetime.datetime(2024, 7, 1)) - ), - state="running", - ) + if AIRFLOW_V_3_0_PLUS: + sample_ti = TaskInstance( + task=EmptyOperator( + task_id="test-task", + dag=DAG("test-dag", schedule=None, start_date=datetime.datetime(2024, 7, 1)), + ), + state="running", + dag_version_id=mock.MagicMock(), + ) + else: + sample_ti = TaskInstance( + task=EmptyOperator( + task_id="test-task", + dag=DAG("test-dag", schedule=None, start_date=datetime.datetime(2024, 7, 1)), + ), + state="running", + ) result = get_user_provided_run_facets(sample_ti, TaskInstanceState.RUNNING) assert result == {} @@ -847,12 +859,23 @@ def test_get_user_provided_run_facets_with_no_function_definition(mock_custom_fa return_value={"unit.openlineage.utils.custom_facet_fixture.get_additional_test_facet"}, ) def test_get_user_provided_run_facets_with_function_definition(mock_custom_facet_funcs): - sample_ti = TaskInstance( - task=EmptyOperator( - task_id="test-task", dag=DAG("test-dag", schedule=None, start_date=datetime.datetime(2024, 7, 1)) - ), - state="running", - ) + if AIRFLOW_V_3_0_PLUS: + sample_ti = TaskInstance( + task=EmptyOperator( + task_id="test-task", + dag=DAG("test-dag", schedule=None, start_date=datetime.datetime(2024, 7, 1)), + ), + state="running", + dag_version_id=mock.MagicMock(), + ) + else: + sample_ti = TaskInstance( + task=EmptyOperator( + task_id="test-task", + dag=DAG("test-dag", schedule=None, start_date=datetime.datetime(2024, 7, 1)), + ), + state="running", + ) result = get_user_provided_run_facets(sample_ti, TaskInstanceState.RUNNING) assert len(result) == 1 assert result["additional_run_facet"].name == f"test-lineage-namespace-{TaskInstanceState.RUNNING}" @@ -866,14 +889,25 @@ def test_get_user_provided_run_facets_with_function_definition(mock_custom_facet }, ) def test_get_user_provided_run_facets_with_return_value_as_none(mock_custom_facet_funcs): - sample_ti = TaskInstance( - task=BashOperator( - task_id="test-task", - bash_command="exit 0;", - 
dag=DAG("test-dag", schedule=None, start_date=datetime.datetime(2024, 7, 1)), - ), - state="running", - ) + if AIRFLOW_V_3_0_PLUS: + sample_ti = TaskInstance( + task=BashOperator( + task_id="test-task", + bash_command="exit 0;", + dag=DAG("test-dag", schedule=None, start_date=datetime.datetime(2024, 7, 1)), + ), + state="running", + dag_version_id=mock.MagicMock(), + ) + else: + sample_ti = TaskInstance( + task=BashOperator( + task_id="test-task", + bash_command="exit 0;", + dag=DAG("test-dag", schedule=None, start_date=datetime.datetime(2024, 7, 1)), + ), + state="running", + ) result = get_user_provided_run_facets(sample_ti, TaskInstanceState.RUNNING) assert result == {} @@ -888,12 +922,23 @@ def test_get_user_provided_run_facets_with_return_value_as_none(mock_custom_face }, ) def test_get_user_provided_run_facets_with_multiple_function_definition(mock_custom_facet_funcs): - sample_ti = TaskInstance( - task=EmptyOperator( - task_id="test-task", dag=DAG("test-dag", schedule=None, start_date=datetime.datetime(2024, 7, 1)) - ), - state="running", - ) + if AIRFLOW_V_3_0_PLUS: + sample_ti = TaskInstance( + task=EmptyOperator( + task_id="test-task", + dag=DAG("test-dag", schedule=None, start_date=datetime.datetime(2024, 7, 1)), + ), + state="running", + dag_version_id=mock.MagicMock(), + ) + else: + sample_ti = TaskInstance( + task=EmptyOperator( + task_id="test-task", + dag=DAG("test-dag", schedule=None, start_date=datetime.datetime(2024, 7, 1)), + ), + state="running", + ) result = get_user_provided_run_facets(sample_ti, TaskInstanceState.RUNNING) assert len(result) == 2 assert result["additional_run_facet"].name == f"test-lineage-namespace-{TaskInstanceState.RUNNING}" @@ -909,12 +954,23 @@ def test_get_user_provided_run_facets_with_multiple_function_definition(mock_cus }, ) def test_get_user_provided_run_facets_with_duplicate_facet_keys(mock_custom_facet_funcs): - sample_ti = TaskInstance( - task=EmptyOperator( - task_id="test-task", dag=DAG("test-dag", 
schedule=None, start_date=datetime.datetime(2024, 7, 1)) - ), - state="running", - ) + if AIRFLOW_V_3_0_PLUS: + sample_ti = TaskInstance( + task=EmptyOperator( + task_id="test-task", + dag=DAG("test-dag", schedule=None, start_date=datetime.datetime(2024, 7, 1)), + ), + state="running", + dag_version_id=mock.MagicMock(), + ) + else: + sample_ti = TaskInstance( + task=EmptyOperator( + task_id="test-task", + dag=DAG("test-dag", schedule=None, start_date=datetime.datetime(2024, 7, 1)), + ), + state="running", + ) result = get_user_provided_run_facets(sample_ti, TaskInstanceState.RUNNING) assert len(result) == 1 assert result["additional_run_facet"].name == f"test-lineage-namespace-{TaskInstanceState.RUNNING}" @@ -926,12 +982,23 @@ def test_get_user_provided_run_facets_with_duplicate_facet_keys(mock_custom_face return_value={"invalid_function"}, ) def test_get_user_provided_run_facets_with_invalid_function_definition(mock_custom_facet_funcs): - sample_ti = TaskInstance( - task=EmptyOperator( - task_id="test-task", dag=DAG("test-dag", schedule=None, start_date=datetime.datetime(2024, 7, 1)) - ), - state="running", - ) + if AIRFLOW_V_3_0_PLUS: + sample_ti = TaskInstance( + task=EmptyOperator( + task_id="test-task", + dag=DAG("test-dag", schedule=None, start_date=datetime.datetime(2024, 7, 1)), + ), + state="running", + dag_version_id=mock.MagicMock(), + ) + else: + sample_ti = TaskInstance( + task=EmptyOperator( + task_id="test-task", + dag=DAG("test-dag", schedule=None, start_date=datetime.datetime(2024, 7, 1)), + ), + state="running", + ) result = get_user_provided_run_facets(sample_ti, TaskInstanceState.RUNNING) assert result == {} @@ -941,12 +1008,23 @@ def test_get_user_provided_run_facets_with_invalid_function_definition(mock_cust return_value={"providers.unit.openlineage.utils.custom_facet_fixture.return_type_is_not_dict"}, ) def test_get_user_provided_run_facets_with_wrong_return_type_function(mock_custom_facet_funcs): - sample_ti = TaskInstance( - 
task=EmptyOperator( - task_id="test-task", dag=DAG("test-dag", schedule=None, start_date=datetime.datetime(2024, 7, 1)) - ), - state="running", - ) + if AIRFLOW_V_3_0_PLUS: + sample_ti = TaskInstance( + task=EmptyOperator( + task_id="test-task", + dag=DAG("test-dag", schedule=None, start_date=datetime.datetime(2024, 7, 1)), + ), + state="running", + dag_version_id=mock.MagicMock(), + ) + else: + sample_ti = TaskInstance( + task=EmptyOperator( + task_id="test-task", + dag=DAG("test-dag", schedule=None, start_date=datetime.datetime(2024, 7, 1)), + ), + state="running", + ) result = get_user_provided_run_facets(sample_ti, TaskInstanceState.RUNNING) assert result == {} @@ -956,12 +1034,23 @@ def test_get_user_provided_run_facets_with_wrong_return_type_function(mock_custo return_value={"providers.unit.openlineage.utils.custom_facet_fixture.get_custom_facet_throws_exception"}, ) def test_get_user_provided_run_facets_with_exception(mock_custom_facet_funcs): - sample_ti = TaskInstance( - task=EmptyOperator( - task_id="test-task", dag=DAG("test-dag", schedule=None, start_date=datetime.datetime(2024, 7, 1)) - ), - state="running", - ) + if AIRFLOW_V_3_0_PLUS: + sample_ti = TaskInstance( + task=EmptyOperator( + task_id="test-task", + dag=DAG("test-dag", schedule=None, start_date=datetime.datetime(2024, 7, 1)), + ), + state="running", + dag_version_id=mock.MagicMock(), + ) + else: + sample_ti = TaskInstance( + task=EmptyOperator( + task_id="test-task", + dag=DAG("test-dag", schedule=None, start_date=datetime.datetime(2024, 7, 1)), + ), + state="running", + ) result = get_user_provided_run_facets(sample_ti, TaskInstanceState.RUNNING) assert result == {} @@ -1658,6 +1747,7 @@ def test_taskinstance_info_af3(): run_id="test_run", try_number=1, map_index=2, + dag_version_id=ti_id, ) start_date = timezone.datetime(2025, 1, 1) diff --git a/providers/redis/tests/unit/redis/log/test_redis_task_handler.py b/providers/redis/tests/unit/redis/log/test_redis_task_handler.py index 
fb845ba48dd45..99bb497c56318 100644 --- a/providers/redis/tests/unit/redis/log/test_redis_task_handler.py +++ b/providers/redis/tests/unit/redis/log/test_redis_task_handler.py @@ -23,6 +23,7 @@ import pytest from airflow.models import DAG, DagRun, TaskInstance +from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.redis.log.redis_task_handler import RedisTaskHandler from airflow.providers.standard.operators.empty import EmptyOperator from airflow.utils.session import create_session @@ -62,7 +63,15 @@ def ti(self): session.commit() session.refresh(dag_run) - ti = TaskInstance(task=task, run_id=dag_run.run_id) + if AIRFLOW_V_3_0_PLUS: + from airflow.models.dag_version import DagVersion + + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") + dag_version = DagVersion.get_latest_version(dag.dag_id) + ti = TaskInstance(task=task, run_id=dag_run.run_id, dag_version_id=dag_version.id) + else: + ti = TaskInstance(task=task, run_id=dag_run.run_id) ti.dag_run = dag_run ti.try_number = 1 ti.state = State.RUNNING diff --git a/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py b/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py index 9f52f80a4eb41..721b72e57811e 100644 --- a/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py +++ b/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py @@ -177,6 +177,13 @@ def create_context(task, dag=None): tzinfo = pendulum.timezone("UTC") logical_date = timezone.datetime(2022, 1, 1, 1, 0, 0, tzinfo=tzinfo) if AIRFLOW_V_3_0_PLUS: + from airflow.models.dag_version import DagVersion + from airflow.models.serialized_dag import SerializedDagModel + + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") + dag_version = DagVersion.get_latest_version(dag.dag_id) + task_instance = TaskInstance(task=task, run_id="test_run_id", dag_version_id=dag_version.id) dag_run = DagRun( dag_id=dag.dag_id, 
logical_date=logical_date, @@ -191,7 +198,7 @@ def create_context(task, dag=None): run_id=DagRun.generate_run_id(DagRunType.MANUAL, logical_date), ) - task_instance = TaskInstance(task=task) + task_instance = TaskInstance(task=task) task_instance.dag_run = dag_run task_instance.xcom_push = mock.Mock() date_key = "logical_date" if AIRFLOW_V_3_0_PLUS else "execution_date" @@ -208,6 +215,7 @@ def create_context(task, dag=None): } +@pytest.mark.db_test class TestSnowflakeSqlApiOperator: @pytest.fixture def mock_execute_query(self): diff --git a/providers/ssh/tests/unit/ssh/operators/test_ssh.py b/providers/ssh/tests/unit/ssh/operators/test_ssh.py index b100fc13d31a4..2222db83cc74a 100644 --- a/providers/ssh/tests/unit/ssh/operators/test_ssh.py +++ b/providers/ssh/tests/unit/ssh/operators/test_ssh.py @@ -27,12 +27,17 @@ from airflow.exceptions import AirflowException, AirflowSkipException, AirflowTaskTimeout from airflow.models import TaskInstance +from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.ssh.hooks.ssh import SSHHook from airflow.providers.ssh.operators.ssh import SSHOperator from airflow.utils.timezone import datetime from airflow.utils.types import NOTSET from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS + +if AIRFLOW_V_3_0_PLUS: + from airflow.models.dag_version import DagVersion pytestmark = pytest.mark.db_test @@ -263,10 +268,16 @@ def test_push_ssh_exit_to_xcom(self, request, dag_maker): ssh_exit_code = random.randrange(1, 100) self.exec_ssh_client_command.return_value = (ssh_exit_code, b"", b"ssh output") - with dag_maker(dag_id=f"dag_{request.node.name}"): + with dag_maker(dag_id=f"dag_{request.node.name}") as dag: task = SSHOperator(task_id="push_xcom", ssh_hook=self.hook, command=command) dr = dag_maker.create_dagrun(run_id="push_xcom") - ti = TaskInstance(task=task, run_id=dr.run_id) + if AIRFLOW_V_3_0_PLUS: + dag.sync_to_db() + 
SerializedDagModel.write_dag(dag, bundle_name="testing") + dag_version = DagVersion.get_latest_version(dag.dag_id) + ti = TaskInstance(task=task, run_id=dr.run_id, dag_version_id=dag_version.id) + else: + ti = TaskInstance(task=task, run_id=dr.run_id) with pytest.raises(AirflowException, match=f"SSH operator error: exit status = {ssh_exit_code}"): dag_maker.run_ti("push_xcom", dr) assert ti.xcom_pull(task_ids=task.task_id, key="ssh_exit") == ssh_exit_code diff --git a/providers/standard/tests/unit/standard/decorators/test_python.py b/providers/standard/tests/unit/standard/decorators/test_python.py index 221aaf9b06fe8..e617fa8c88be7 100644 --- a/providers/standard/tests/unit/standard/decorators/test_python.py +++ b/providers/standard/tests/unit/standard/decorators/test_python.py @@ -382,7 +382,10 @@ def arg_task(*args): ret = arg_task(4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on {{ds}}.", named_tuple) dr = self.create_dag_run() - ti = TaskInstance(task=ret.operator, run_id=dr.run_id) + if AIRFLOW_V_3_0_PLUS: + ti = TaskInstance(task=ret.operator, run_id=dr.run_id, dag_version_id=dr.created_dag_version_id) + else: + ti = TaskInstance(task=ret.operator, run_id=dr.run_id) rendered_op_args = ti.render_templates().op_args assert len(rendered_op_args) == 4 assert rendered_op_args[0] == 4 @@ -403,7 +406,10 @@ def kwargs_task(an_int, a_date, a_templated_string): ) dr = self.create_dag_run() - ti = TaskInstance(task=ret.operator, run_id=dr.run_id) + if AIRFLOW_V_3_0_PLUS: + ti = TaskInstance(task=ret.operator, run_id=dr.run_id, dag_version_id=dr.created_dag_version_id) + else: + ti = TaskInstance(task=ret.operator, run_id=dr.run_id) rendered_op_kwargs = ti.render_templates().op_kwargs assert rendered_op_kwargs["an_int"] == 4 assert rendered_op_kwargs["a_date"] == date(2019, 1, 1) diff --git a/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py b/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py index 
abbccf97cd9be..daff66dcb13f8 100644 --- a/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py +++ b/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py @@ -75,6 +75,7 @@ from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS if AIRFLOW_V_3_0_PLUS: + from airflow.models.dag_version import DagVersion from airflow.sdk import task as task_deco from airflow.utils.types import DagRunTriggeredByType else: @@ -121,10 +122,17 @@ def add_fake_task_group(self, target_states=None): with TaskGroup(group_id=TEST_TASK_GROUP_ID) as task_group: _ = [EmptyOperator(task_id=f"task{i}") for i in range(len(target_states))] dag.sync_to_db() + if AIRFLOW_V_3_0_PLUS: + SerializedDagModel.write_dag(dag, bundle_name="testing") + else: SerializedDagModel.write_dag(dag) for idx, task in enumerate(task_group): - ti = TaskInstance(task=task, run_id=self.dag_run_id) + if AIRFLOW_V_3_0_PLUS: + dag_version = DagVersion.get_latest_version(task_group[idx].dag_id) + ti = TaskInstance(task=task, run_id=self.dag_run_id, dag_version_id=dag_version.id) + else: + ti = TaskInstance(task=task, run_id=self.dag_run_id) ti.run(ignore_ti_state=True, mark_success=True) ti.set_state(target_states[idx]) @@ -144,16 +152,27 @@ def fake_mapped_task(x: int): fake_task() fake_mapped_task.expand(x=list(map_indexes)) dag.sync_to_db() - SerializedDagModel.write_dag(dag) + if AIRFLOW_V_3_0_PLUS: + SerializedDagModel.write_dag(dag, bundle_name="testing") + else: + SerializedDagModel.write_dag(dag) for task in task_group: if task.task_id == "fake_mapped_task": for map_index in map_indexes: - ti = TaskInstance(task=task, run_id=self.dag_run_id, map_index=map_index) + if AIRFLOW_V_3_0_PLUS: + dag_version = DagVersion.get_latest_version(dag.dag_id) + ti = TaskInstance(task=task, run_id=self.dag_run_id, dag_version_id=dag_version.id) + else: + ti = TaskInstance(task=task, run_id=self.dag_run_id, map_index=map_index) ti.run(ignore_ti_state=True, mark_success=True) 
ti.set_state(target_state) else: - ti = TaskInstance(task=task, run_id=self.dag_run_id) + if AIRFLOW_V_3_0_PLUS: + dag_version = DagVersion.get_latest_version(dag.dag_id) + ti = TaskInstance(task=task, run_id=self.dag_run_id, dag_version_id=dag_version.id) + else: + ti = TaskInstance(task=task, run_id=self.dag_run_id) ti.run(ignore_ti_state=True, mark_success=True) ti.set_state(target_state) @@ -1847,6 +1866,8 @@ def _factory(depth: int) -> DagBag: for dag in dags: if AIRFLOW_V_3_0_PLUS: + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") dag_bag.bag_dag(dag=dag) else: dag_bag.bag_dag(dag=dag, root_dag=dag) # type: ignore[call-arg] @@ -2004,6 +2025,9 @@ def dag_bag_head_tail(session): @provide_session def test_clear_overlapping_external_task_marker(dag_bag_head_tail, session): dag: DAG = dag_bag_head_tail.get_dag("head_tail") + dag.sync_to_db() + if AIRFLOW_V_3_0_PLUS: + SerializedDagModel.write_dag(dag, bundle_name="testing") # "Run" 10 times. for delta in range(10): @@ -2021,7 +2045,12 @@ def test_clear_overlapping_external_task_marker(dag_bag_head_tail, session): dagrun.execution_date = logical_date session.add(dagrun) for task in dag.tasks: - ti = TaskInstance(task=task) + if AIRFLOW_V_3_0_PLUS: + dag_version = DagVersion.get_latest_version(task.dag_id, session=session) + ti = TaskInstance(task=task, dag_version_id=dag_version.id) + + else: + ti = TaskInstance(task=task) dagrun.task_instances.append(ti) ti.state = TaskInstanceState.SUCCESS session.flush() @@ -2029,9 +2058,11 @@ def test_clear_overlapping_external_task_marker(dag_bag_head_tail, session): assert dag.clear(start_date=DEFAULT_DATE, dag_bag=dag_bag_head_tail, session=session) == 30 -@provide_session def test_clear_overlapping_external_task_marker_with_end_date(dag_bag_head_tail, session): dag: DAG = dag_bag_head_tail.get_dag("head_tail") + dag.sync_to_db() + if AIRFLOW_V_3_0_PLUS: + SerializedDagModel.write_dag(dag=dag, bundle_name="testing") # "Run" 10 times. 
for delta in range(10): @@ -2048,8 +2079,13 @@ def test_clear_overlapping_external_task_marker_with_end_date(dag_bag_head_tail, else: dagrun.execution_date = logical_date session.add(dagrun) + for task in dag.tasks: - ti = TaskInstance(task=task) + if AIRFLOW_V_3_0_PLUS: + dag_version = DagVersion.get_latest_version(dag.dag_id, session=session) + ti = TaskInstance(task=task, dag_version_id=dag_version.id) + else: + ti = TaskInstance(task=task) dagrun.task_instances.append(ti) ti.state = TaskInstanceState.SUCCESS session.flush() @@ -2108,6 +2144,7 @@ def dummy_task(x: int): head >> body >> tail if AIRFLOW_V_3_0_PLUS: + dag.sync_to_db() dag_bag.bag_dag(dag=dag) else: dag_bag.bag_dag(dag=dag, root_dag=dag) @@ -2115,10 +2152,11 @@ def dummy_task(x: int): return dag_bag -@provide_session def test_clear_overlapping_external_task_marker_mapped_tasks(dag_bag_head_tail_mapped_tasks, session): dag: DAG = dag_bag_head_tail_mapped_tasks.get_dag("head_tail") - + dag.sync_to_db() + if AIRFLOW_V_3_0_PLUS: + SerializedDagModel.write_dag(dag=dag, bundle_name="testing") # "Run" 10 times. 
for delta in range(10): logical_date = DEFAULT_DATE + timedelta(days=delta) @@ -2137,11 +2175,24 @@ def test_clear_overlapping_external_task_marker_mapped_tasks(dag_bag_head_tail_m for task in dag.tasks: if task.task_id == "dummy_task": for map_index in range(5): - ti = TaskInstance(task=task, run_id=dagrun.run_id, map_index=map_index) + if AIRFLOW_V_3_0_PLUS: + dag_version = DagVersion.get_latest_version(dag.dag_id, session=session) + ti = TaskInstance( + task=task, + run_id=dagrun.run_id, + map_index=map_index, + dag_version_id=dag_version.id, + ) + else: + ti = TaskInstance(task=task, run_id=dagrun.run_id, map_index=map_index) ti.state = TaskInstanceState.SUCCESS dagrun.task_instances.append(ti) else: - ti = TaskInstance(task=task, run_id=dagrun.run_id) + if AIRFLOW_V_3_0_PLUS: + dag_version = DagVersion.get_latest_version(dag.dag_id, session=session) + ti = TaskInstance(task=task, run_id=dagrun.run_id, dag_version_id=dag_version.id) + else: + ti = TaskInstance(task=task, run_id=dagrun.run_id) ti.state = TaskInstanceState.SUCCESS dagrun.task_instances.append(ti) session.flush() diff --git a/providers/standard/tests/unit/standard/sensors/test_time_delta.py b/providers/standard/tests/unit/standard/sensors/test_time_delta.py index bc70aca8a9b2b..3ba3d14d0dec5 100644 --- a/providers/standard/tests/unit/standard/sensors/test_time_delta.py +++ b/providers/standard/tests/unit/standard/sensors/test_time_delta.py @@ -88,18 +88,18 @@ def test_timedelta_sensor_run_after_vs_interval(run_after, interval_end, dag_mak with dag_maker() as dag: op = TimeDeltaSensor(task_id="wait_sensor_check", delta=delta, dag=dag, mode="reschedule") - kwargs = {} - if AIRFLOW_V_3_0_PLUS: - from airflow.utils.types import DagRunTriggeredByType - - kwargs.update(triggered_by=DagRunTriggeredByType.TEST, run_after=run_after) - dr = dag.create_dagrun( - run_id="abcrhroceuh", - run_type=DagRunType.MANUAL, - state=None, - session=session, - **kwargs, - ) + kwargs = {} + if AIRFLOW_V_3_0_PLUS: + from 
airflow.utils.types import DagRunTriggeredByType + + kwargs.update(triggered_by=DagRunTriggeredByType.TEST, run_after=run_after) + dr = dag.create_dagrun( + run_id="abcrhroceuh", + run_type=DagRunType.MANUAL, + state=None, + session=session, + **kwargs, + ) ti = dr.task_instances[0] context.update(dag_run=dr, ti=ti) expected = interval_end or run_after @@ -138,19 +138,19 @@ def test_timedelta_sensor_deferrable_run_after_vs_interval(run_after, interval_e deferrable=True, # <-- the feature under test ) - dr = dag.create_dagrun( - run_id="abcrhroceuh", - run_type=DagRunType.MANUAL, - state=None, - **kwargs, - ) - context.update(dag_run=dr) + dr = dag.create_dagrun( + run_id="abcrhroceuh", + run_type=DagRunType.MANUAL, + state=None, + **kwargs, + ) + context.update(dag_run=dr) - expected_base = interval_end or run_after - expected_fire_time = expected_base + delta + expected_base = interval_end or run_after + expected_fire_time = expected_base + delta - with pytest.raises(TaskDeferred) as td: - sensor.execute(context) + with pytest.raises(TaskDeferred) as td: + sensor.execute(context) # The sensor should defer once with a DateTimeTrigger trigger = td.value.trigger @@ -218,25 +218,26 @@ def test_timedelta_sensor_async_run_after_vs_interval(self, run_after, interval_ if interval_end: context["data_interval_end"] = interval_end with dag_maker() as dag: - kwargs = {} - if AIRFLOW_V_3_0_PLUS: - from airflow.utils.types import DagRunTriggeredByType - - kwargs.update(triggered_by=DagRunTriggeredByType.TEST, run_after=run_after) - - dr = dag.create_dagrun( - run_id="abcrhroceuh", - run_type=DagRunType.MANUAL, - state=None, - **kwargs, - ) - context.update(dag_run=dr) - delta = timedelta(seconds=1) - with pytest.warns(AirflowProviderDeprecationWarning): - op = TimeDeltaSensorAsync(task_id="wait_sensor_check", delta=delta, dag=dag) - base_time = interval_end or run_after - expected_time = base_time + delta - with pytest.raises(TaskDeferred) as caught: - op.execute(context) - - 
assert caught.value.trigger.moment == expected_time + ... + kwargs = {} + if AIRFLOW_V_3_0_PLUS: + from airflow.utils.types import DagRunTriggeredByType + + kwargs.update(triggered_by=DagRunTriggeredByType.TEST, run_after=run_after) + + dr = dag.create_dagrun( + run_id="abcrhroceuh", + run_type=DagRunType.MANUAL, + state=None, + **kwargs, + ) + context.update(dag_run=dr) + delta = timedelta(seconds=1) + with pytest.warns(AirflowProviderDeprecationWarning): + op = TimeDeltaSensorAsync(task_id="wait_sensor_check", delta=delta, dag=dag) + base_time = interval_end or run_after + expected_time = base_time + delta + with pytest.raises(TaskDeferred) as caught: + op.execute(context) + + assert caught.value.trigger.moment == expected_time diff --git a/providers/standard/tests/unit/standard/sensors/test_weekday.py b/providers/standard/tests/unit/standard/sensors/test_weekday.py index 7bf3dd812c48a..b86349a032e93 100644 --- a/providers/standard/tests/unit/standard/sensors/test_weekday.py +++ b/providers/standard/tests/unit/standard/sensors/test_weekday.py @@ -143,10 +143,10 @@ def test_weekday_sensor_should_use_run_after_when_logical_date_is_not_provided(s use_task_logical_date=True, dag=dag, ) - dr = dag_maker.create_dagrun( - run_id="manual_run", - start_date=DEFAULT_DATE, - logical_date=None, - **{"run_after": timezone.utcnow()}, - ) - assert op.poke(context={"logical_date": None, "dag_run": dr}) is True + dr = dag_maker.create_dagrun( + run_id="manual_run", + start_date=DEFAULT_DATE, + logical_date=None, + **{"run_after": timezone.utcnow()}, + ) + assert op.poke(context={"logical_date": None, "dag_run": dr}) is True diff --git a/providers/standard/tests/unit/standard/utils/test_skipmixin.py b/providers/standard/tests/unit/standard/utils/test_skipmixin.py index 360a591af10fb..30b26b2624f16 100644 --- a/providers/standard/tests/unit/standard/utils/test_skipmixin.py +++ b/providers/standard/tests/unit/standard/utils/test_skipmixin.py @@ -37,6 +37,7 @@ if AIRFLOW_V_3_0_PLUS: 
from airflow.exceptions import DownstreamTasksSkipped + from airflow.models.dag_version import DagVersion from airflow.providers.standard.utils.skipmixin import SkipMixin from airflow.sdk import task, task_group else: @@ -140,15 +141,18 @@ def test_skip_all_except__branch_task_ids_none( with dag_maker( "dag_test_skip_all_except", serialized=True, - ): + ) as dag: task1 = EmptyOperator(task_id="task1") task2 = EmptyOperator(task_id="task2") task3 = EmptyOperator(task_id="task3") task1 >> [task2, task3] dag_maker.create_dagrun(run_id=DEFAULT_DAG_RUN_ID) - - ti1 = TI(task1, run_id=DEFAULT_DAG_RUN_ID) + if AIRFLOW_V_3_0_PLUS: + dag_version = DagVersion.get_latest_version(dag.dag_id) + ti1 = TI(task1, run_id=DEFAULT_DAG_RUN_ID, dag_version_id=dag_version.id) + else: + ti1 = TI(task1, run_id=DEFAULT_DAG_RUN_ID) if AIRFLOW_V_3_0_PLUS: with pytest.raises(DownstreamTasksSkipped) as exc_info: @@ -191,8 +195,11 @@ def test_skip_all_except__skip_task3(self, dag_maker, branch_task_ids, expected_ task1 >> [task2, task3] dag_maker.create_dagrun(run_id=DEFAULT_DAG_RUN_ID) - - ti1 = TI(task1, run_id=DEFAULT_DAG_RUN_ID) + if AIRFLOW_V_3_0_PLUS: + dag_version = DagVersion.get_latest_version(task1.dag_id) + ti1 = TI(task1, run_id=DEFAULT_DAG_RUN_ID, dag_version_id=dag_version.id) + else: + ti1 = TI(task1, run_id=DEFAULT_DAG_RUN_ID) if AIRFLOW_V_3_0_PLUS: with pytest.raises(DownstreamTasksSkipped) as exc_info: @@ -234,12 +241,55 @@ def task_group_op(k): task_group_op.expand(k=[0, 1]) dag_maker.create_dagrun(run_id=DEFAULT_DAG_RUN_ID) - branch_op_ti_0 = TI(dag.get_task("task_group_op.branch_op"), run_id=DEFAULT_DAG_RUN_ID, map_index=0) - branch_op_ti_1 = TI(dag.get_task("task_group_op.branch_op"), run_id=DEFAULT_DAG_RUN_ID, map_index=1) - branch_a_ti_0 = TI(dag.get_task("task_group_op.branch_a"), run_id=DEFAULT_DAG_RUN_ID, map_index=0) - branch_a_ti_1 = TI(dag.get_task("task_group_op.branch_a"), run_id=DEFAULT_DAG_RUN_ID, map_index=1) - branch_b_ti_0 = 
TI(dag.get_task("task_group_op.branch_b"), run_id=DEFAULT_DAG_RUN_ID, map_index=0) - branch_b_ti_1 = TI(dag.get_task("task_group_op.branch_b"), run_id=DEFAULT_DAG_RUN_ID, map_index=1) + if AIRFLOW_V_3_0_PLUS: + dag_version = DagVersion.get_latest_version(dag.dag_id) + branch_op_ti_0 = TI( + dag.get_task("task_group_op.branch_op"), + run_id=DEFAULT_DAG_RUN_ID, + map_index=0, + dag_version_id=dag_version.id, + ) + branch_op_ti_1 = TI( + dag.get_task("task_group_op.branch_op"), + run_id=DEFAULT_DAG_RUN_ID, + map_index=1, + dag_version_id=dag_version.id, + ) + branch_a_ti_0 = TI( + dag.get_task("task_group_op.branch_a"), + run_id=DEFAULT_DAG_RUN_ID, + map_index=0, + dag_version_id=dag_version.id, + ) + branch_a_ti_1 = TI( + dag.get_task("task_group_op.branch_a"), + run_id=DEFAULT_DAG_RUN_ID, + map_index=1, + dag_version_id=dag_version.id, + ) + branch_b_ti_0 = TI( + dag.get_task("task_group_op.branch_b"), + run_id=DEFAULT_DAG_RUN_ID, + map_index=0, + dag_version_id=dag_version.id, + ) + branch_b_ti_1 = TI( + dag.get_task("task_group_op.branch_b"), + run_id=DEFAULT_DAG_RUN_ID, + map_index=1, + dag_version_id=dag_version.id, + ) + else: + branch_op_ti_0 = TI( + dag.get_task("task_group_op.branch_op"), run_id=DEFAULT_DAG_RUN_ID, map_index=0 + ) + branch_op_ti_1 = TI( + dag.get_task("task_group_op.branch_op"), run_id=DEFAULT_DAG_RUN_ID, map_index=1 + ) + branch_a_ti_0 = TI(dag.get_task("task_group_op.branch_a"), run_id=DEFAULT_DAG_RUN_ID, map_index=0) + branch_a_ti_1 = TI(dag.get_task("task_group_op.branch_a"), run_id=DEFAULT_DAG_RUN_ID, map_index=1) + branch_b_ti_0 = TI(dag.get_task("task_group_op.branch_b"), run_id=DEFAULT_DAG_RUN_ID, map_index=0) + branch_b_ti_1 = TI(dag.get_task("task_group_op.branch_b"), run_id=DEFAULT_DAG_RUN_ID, map_index=1) SkipMixin().skip_all_except(ti=branch_op_ti_0, branch_task_ids="task_group_op.branch_a") SkipMixin().skip_all_except(ti=branch_op_ti_1, branch_task_ids="task_group_op.branch_b") @@ -257,7 +307,11 @@ def 
test_raise_exception_on_not_accepted_branch_task_ids_type(self, dag_maker): with dag_maker("dag_test_skip_all_except_wrong_type"): task = EmptyOperator(task_id="task") dag_maker.create_dagrun(run_id=DEFAULT_DAG_RUN_ID) - ti1 = TI(task, run_id=DEFAULT_DAG_RUN_ID) + if AIRFLOW_V_3_0_PLUS: + dag_version = DagVersion.get_latest_version(task.dag_id) + ti1 = TI(task, run_id=DEFAULT_DAG_RUN_ID, dag_version_id=dag_version.id) + else: + ti1 = TI(task, run_id=DEFAULT_DAG_RUN_ID) error_message = ( r"'branch_task_ids' must be either None, a task ID, or an Iterable of IDs, but got 'int'\." ) @@ -268,7 +322,11 @@ def test_raise_exception_on_not_accepted_iterable_branch_task_ids_type(self, dag with dag_maker("dag_test_skip_all_except_wrong_type"): task = EmptyOperator(task_id="task") dag_maker.create_dagrun(run_id=DEFAULT_DAG_RUN_ID) - ti1 = TI(task, run_id=DEFAULT_DAG_RUN_ID) + if AIRFLOW_V_3_0_PLUS: + dag_version = DagVersion.get_latest_version(task.dag_id) + ti1 = TI(task, run_id=DEFAULT_DAG_RUN_ID, dag_version_id=dag_version.id) + else: + ti1 = TI(task, run_id=DEFAULT_DAG_RUN_ID) error_message = ( r"'branch_task_ids' expected all task IDs are strings. " r"Invalid tasks found: \{\(42, 'int'\)\}\." @@ -292,8 +350,11 @@ def test_raise_exception_on_not_valid_branch_task_ids(self, dag_maker, branch_ta task1 >> [task2, task3] dag_maker.create_dagrun(run_id=DEFAULT_DAG_RUN_ID) - - ti1 = TI(task1, run_id=DEFAULT_DAG_RUN_ID) + if AIRFLOW_V_3_0_PLUS: + dag_version = DagVersion.get_latest_version(task1.dag_id) + ti1 = TI(task1, run_id=DEFAULT_DAG_RUN_ID, dag_version_id=dag_version.id) + else: + ti1 = TI(task1, run_id=DEFAULT_DAG_RUN_ID) error_message = r"'branch_task_ids' must contain only valid task_ids. 
Invalid tasks found: .*" with pytest.raises(AirflowException, match=error_message): diff --git a/providers/yandex/tests/unit/yandex/links/test_yq.py b/providers/yandex/tests/unit/yandex/links/test_yq.py index bb6ef4bf3608d..b9e8050f8dc8c 100644 --- a/providers/yandex/tests/unit/yandex/links/test_yq.py +++ b/providers/yandex/tests/unit/yandex/links/test_yq.py @@ -52,7 +52,10 @@ def test_default_link(): link = YQLink() op = MockOperator(task_id="test_task_id") - ti = TaskInstance(task=op, run_id="run_id1") + if AIRFLOW_V_3_0_PLUS: + ti = TaskInstance(task=op, run_id="run_id1", dag_version_id=mock.MagicMock()) + else: + ti = TaskInstance(task=op, run_id="run_id1") assert link.get_link(op, ti_key=ti.key) == "https://yq.cloud.yandex.ru" @@ -62,5 +65,8 @@ def test_link(): link = YQLink() op = MockOperator(task_id="test_task_id") - ti = TaskInstance(task=op, run_id="run_id1") + if AIRFLOW_V_3_0_PLUS: + ti = TaskInstance(task=op, run_id="run_id1", dag_version_id=mock.MagicMock()) + else: + ti = TaskInstance(task=op, run_id="run_id1") assert link.get_link(op, ti_key=ti.key) == "https://g.com" diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index ac1e51d5e55c9..674249f2d49d6 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -27,7 +27,7 @@ from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, RootModel -API_VERSION: Final[str] = "2025-05-20" +API_VERSION: Final[str] = "2025-08-10" class AssetAliasReferenceAssetEventDagRun(BaseModel): @@ -398,6 +398,7 @@ class TaskInstance(BaseModel): dag_id: Annotated[str, Field(title="Dag Id")] run_id: Annotated[str, Field(title="Run Id")] try_number: Annotated[int, Field(title="Try Number")] + dag_version_id: Annotated[UUID, Field(title="Dag Version Id")] map_index: Annotated[int | None, Field(title="Map Index")] = -1 hostname: Annotated[str | None, 
Field(title="Hostname")] = None context_carrier: Annotated[dict[str, Any] | None, Field(title="Context Carrier")] = None diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py index 4ba12597e6145..6de5878465414 100644 --- a/task-sdk/src/airflow/sdk/definitions/dag.py +++ b/task-sdk/src/airflow/sdk/definitions/dag.py @@ -1241,6 +1241,7 @@ def _run_task(*, ti, run_triggerer=False): run_id=ti.run_id, try_number=ti.try_number, map_index=ti.map_index, + dag_version_id=ti.dag_version_id, ), task=ti.task, ) diff --git a/task-sdk/src/airflow/sdk/types.py b/task-sdk/src/airflow/sdk/types.py index 7e48b24c86f55..51d9d207f4b6d 100644 --- a/task-sdk/src/airflow/sdk/types.py +++ b/task-sdk/src/airflow/sdk/types.py @@ -55,6 +55,7 @@ class RuntimeTaskInstanceProtocol(Protocol): """Minimal interface for a task instance available during the execution.""" id: uuid.UUID + dag_version_id: uuid.UUID task: BaseOperator task_id: str dag_id: str diff --git a/task-sdk/tests/task_sdk/execution_time/test_comms.py b/task-sdk/tests/task_sdk/execution_time/test_comms.py index 5adaa2562abc7..b2e7f5e71c0fa 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_comms.py +++ b/task-sdk/tests/task_sdk/execution_time/test_comms.py @@ -43,6 +43,7 @@ def test_recv_StartupDetails(self): "try_number": 1, "run_id": "b", "dag_id": "c", + "dag_version_id": uuid.UUID("4d828a62-a417-4936-a7a6-2b3fabacecab"), }, "ti_context": { "dag_run": { diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 742dba98c4d38..67f565ff9f4c1 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -190,6 +190,7 @@ def subprocess_main(): dag_id="c", run_id="d", try_number=1, + dag_version_id=uuid7(), ), client=client_with_ti_start, target=subprocess_main, @@ -262,6 +263,7 @@ def subprocess_main(): dag_id="c", run_id="d", 
try_number=1, + dag_version_id=uuid7(), ), client=client_with_ti_start, target=subprocess_main, @@ -296,6 +298,7 @@ def subprocess_main(): dag_id="c", run_id="d", try_number=1, + dag_version_id=uuid7(), ), client=client_with_ti_start, target=subprocess_main, @@ -314,7 +317,9 @@ def subprocess_main(): proc = ActivitySubprocess.start( dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, - what=TaskInstance(id=uuid7(), task_id="b", dag_id="c", run_id="d", try_number=1), + what=TaskInstance( + id=uuid7(), task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7() + ), client=MagicMock(spec=sdk_client.Client), target=subprocess_main, ) @@ -347,7 +352,9 @@ def subprocess_main(): proc = ActivitySubprocess.start( dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, - what=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1), + what=TaskInstance( + id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7() + ), client=sdk_client.Client(base_url="", dry_run=True, token=""), target=subprocess_main, ) @@ -382,7 +389,9 @@ def _on_child_started(self, *args, **kwargs): proc = ActivitySubprocess.start( dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, - what=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1), + what=TaskInstance( + id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7() + ), client=sdk_client.Client(base_url="", dry_run=True, token=""), target=subprocess_main, ) @@ -402,6 +411,7 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine, mocker dag_id="super_basic_run", run_id="c", try_number=1, + dag_version_id=uuid7(), ) bundle_info = BundleInfo(name="my-bundle", version=None) @@ -443,6 +453,7 @@ def test_supervise_handles_deferred_task( dag_id="super_basic_deferred_run", run_id="d", try_number=1, + dag_version_id=uuid7(), ) # Create a mock client to assert calls to the client @@ -497,7 +508,9 @@ def test_supervise_handles_deferred_task( def 
test_supervisor_handles_already_running_task(self): """Test that Supervisor prevents starting a Task Instance that is already running.""" - ti = TaskInstance(id=uuid7(), task_id="b", dag_id="c", run_id="d", try_number=1) + ti = TaskInstance( + id=uuid7(), task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7() + ) # Mock API Server response indicating the TI is already running # The API Server would return a 409 Conflict status code if the TI is not @@ -576,7 +589,9 @@ def handle_request(request: httpx.Request) -> httpx.Response: proc = ActivitySubprocess.start( dag_rel_path=os.devnull, - what=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1), + what=TaskInstance( + id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7() + ), client=make_client(transport=httpx.MockTransport(handle_request)), target=subprocess_main, bundle_info=FAKE_BUNDLE, @@ -803,6 +818,7 @@ def subprocess_main(): dag_id="c", run_id="d", try_number=1, + dag_version_id=uuid7(), ), client=client_with_ti_start, target=subprocess_main, @@ -959,7 +975,9 @@ def _handler(sig, frame): proc = ActivitySubprocess.start( dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, - what=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1), + what=TaskInstance( + id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7() + ), client=client_with_ti_start, target=subprocess_main, ) diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index f2ff89670818b..a76c2c689a3d6 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -145,6 +145,7 @@ def test_parse(test_dags_dir: Path, make_ti_context): dag_id="super_basic", run_id="c", try_number=1, + dag_version_id=uuid7(), ), dag_rel_path="super_basic.py", bundle_info=BundleInfo(name="my-bundle", 
version=None), @@ -200,6 +201,7 @@ def test_parse_not_found(test_dags_dir: Path, make_ti_context, dag_id, task_id, dag_id=dag_id, run_id="c", try_number=1, + dag_version_id=uuid7(), ), dag_rel_path="super_basic.py", bundle_info=BundleInfo(name="my-bundle", version=None), @@ -253,6 +255,7 @@ def test_parse_module_in_bundle_root(tmp_path: Path, make_ti_context): dag_id="dag_name", run_id="c", try_number=1, + dag_version_id=uuid7(), ), dag_rel_path="path_test.py", bundle_info=BundleInfo(name="my-bundle", version=None), @@ -504,6 +507,7 @@ def test_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comm dag_id="basic_templated_dag", run_id="c", try_number=1, + dag_version_id=uuid7(), ), bundle_info=FAKE_BUNDLE, dag_rel_path="", @@ -617,6 +621,7 @@ def execute(self, context): dag_id="basic_dag", run_id="c", try_number=1, + dag_version_id=uuid7(), ), dag_rel_path="", bundle_info=FAKE_BUNDLE, @@ -663,6 +668,7 @@ def execute(self, context): dag_id="basic_dag", run_id="c", try_number=1, + dag_version_id=uuid7(), ), dag_rel_path="", bundle_info=FAKE_BUNDLE, @@ -709,6 +715,7 @@ def execute(self, context): dag_id="basic_dag", run_id="c", try_number=1, + dag_version_id=uuid7(), ), dag_rel_path="", bundle_info=FAKE_BUNDLE, @@ -842,7 +849,9 @@ def test_dag_parsing_context(make_ti_context, mock_supervisor_comms, monkeypatch task_id = "conditional_task" what = StartupDetails( - ti=TaskInstance(id=uuid7(), task_id=task_id, dag_id=dag_id, run_id="c", try_number=1), + ti=TaskInstance( + id=uuid7(), task_id=task_id, dag_id=dag_id, run_id="c", try_number=1, dag_version_id=uuid7() + ), dag_rel_path="dag_parsing_context.py", bundle_info=BundleInfo(name="my-bundle", version=None), ti_context=make_ti_context(dag_id=dag_id, run_id="c"), @@ -1085,6 +1094,7 @@ def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_ dag_id=dag_id, run_id="test_run", try_number=1, + dag_version_id=uuid7(), ) start_date = timezone.datetime(2025, 1, 1) @@ -2196,6 +2206,7 @@ 
def execute(self, context): dag_id="basic_dag", run_id="c", try_number=1, + dag_version_id=uuid7(), ), dag_rel_path="", bundle_info=FAKE_BUNDLE, @@ -2234,6 +2245,7 @@ def execute(self, context): dag_id=dag.dag_id, run_id="test_run", try_number=1, + dag_version_id=uuid7(), ) runtime_ti = RuntimeTaskInstance.model_construct( @@ -2272,6 +2284,7 @@ def execute(self, context): dag_id=dag.dag_id, run_id="test_run", try_number=1, + dag_version_id=uuid7(), ) runtime_ti = RuntimeTaskInstance.model_construct( diff --git a/task-sdk/tests/task_sdk/log/test_log.py b/task-sdk/tests/task_sdk/log/test_log.py index 15db5a801e9e0..60305494beaf6 100644 --- a/task-sdk/tests/task_sdk/log/test_log.py +++ b/task-sdk/tests/task_sdk/log/test_log.py @@ -50,13 +50,15 @@ def test_json_rendering(captured_logs): task_id="test_task", run_id="test_run", try_number=1, + dag_version_id=UUID("ffec3c8e-2898-46f8-b7d5-3cc571577368"), ), ) assert captured_logs assert isinstance(captured_logs[0], bytes) assert json.loads(captured_logs[0]) == { "event": "A test message with a Pydantic class", - "pydantic_class": "TaskInstance(id=UUID('ffec3c8e-2898-46f8-b7d5-3cc571577368'), task_id='test_task', dag_id='test_dag', run_id='test_run', try_number=1, map_index=-1, hostname=None, context_carrier=None)", + "pydantic_class": "TaskInstance(id=UUID('ffec3c8e-2898-46f8-b7d5-3cc571577368'), task_id='test_task', dag_id='test_dag', run_id='test_run', " + "try_number=1, dag_version_id=UUID('ffec3c8e-2898-46f8-b7d5-3cc571577368'), map_index=-1, hostname=None, context_carrier=None)", "timestamp": unittest.mock.ANY, "level": "info", } @@ -83,6 +85,7 @@ def test_jwt_token_is_redacted(captured_logs): task_id="test_task", run_id="test_run", try_number=1, + dag_version_id=UUID("ffec3c8e-2898-46f8-b7d5-3cc571577368"), ), ) assert captured_logs @@ -92,7 +95,7 @@ def test_jwt_token_is_redacted(captured_logs): "level": "info", "pydantic_class": "TaskInstance(id=UUID('ffec3c8e-2898-46f8-b7d5-3cc571577368'), " "task_id='test_task', 
dag_id='test_dag', run_id='test_run', " - "try_number=1, map_index=-1, hostname=None, context_carrier=None)", + "try_number=1, dag_version_id=UUID('ffec3c8e-2898-46f8-b7d5-3cc571577368'), map_index=-1, hostname=None, context_carrier=None)", "timestamp": unittest.mock.ANY, "token": "eyJ***", } @@ -121,6 +124,7 @@ def test_logs_are_masked(captured_logs): try_number=1, map_index=-1, hostname=None, + dag_version_id=UUID("ffec3c8e-2898-46f8-b7d5-3cc571577368"), ), "timestamp": "2025-03-25T05:13:27.073918Z", }, @@ -133,6 +137,7 @@ def test_logs_are_masked(captured_logs): task_id="test_task", run_id="test_run", try_number=1, + dag_version_id=UUID("ffec3c8e-2898-46f8-b7d5-3cc571577368"), ), ) assert captured_logs @@ -144,7 +149,7 @@ def test_logs_are_masked(captured_logs): "level": "info", "pydantic_class": "TaskInstance(id=UUID('ffec3c8e-2898-46f8-b7d5-3cc571577368'), " "task_id='test_task', dag_id='test_dag', run_id='test_run', " - "try_number=1, map_index=-1, hostname=None, context_carrier=None)", + "try_number=1, dag_version_id=UUID('ffec3c8e-2898-46f8-b7d5-3cc571577368'), map_index=-1, hostname=None, context_carrier=None)", "timestamp": "2025-03-25T05:13:27.073918Z", }