diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index 1c021015ed590..ef51da9827943 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -197,7 +197,7 @@ def get_extra_schemas() -> dict[str, dict]: """Get all the extra schemas that are not part of the main FastAPI app.""" from airflow.api_fastapi.execution_api.datamodels.taskinstance import TaskInstance from airflow.executors.workloads import BundleInfo - from airflow.utils.state import TerminalTIState + from airflow.utils.state import TaskInstanceState, TerminalTIState return { "TaskInstance": TaskInstance.model_json_schema(), @@ -205,6 +205,7 @@ def get_extra_schemas() -> dict[str, dict]: # Include the combined state enum too. In the datamodels we separate out SUCCESS from the other states # as that has different payload requirements "TerminalTIState": {"type": "string", "enum": list(TerminalTIState)}, + "TaskInstanceState": {"type": "string", "enum": list(TaskInstanceState)}, } diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index 50ac601cd1a4c..1e8dbb8ff4bdd 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -48,7 +48,7 @@ from airflow.models.taskinstance import TaskInstance from airflow.providers.standard.operators.empty import EmptyOperator from airflow.sdk import Context - from airflow.sdk.api.datamodels._generated import IntermediateTIState, TerminalTIState + from airflow.sdk.api.datamodels._generated import TaskInstanceState as TIState from airflow.sdk.bases.operator import BaseOperator as TaskSDKBaseOperator from airflow.sdk.execution_time.comms import StartupDetails, ToSupervisor from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance @@ -2020,7 +2020,7 @@ class RunTaskCallable(Protocol): """Protocol for better type hints for the fixture `run_task`.""" @property - def state(self) -> IntermediateTIState | TerminalTIState: ... + def state(self) -> TIState: ... @property def msg(self) -> ToSupervisor | None: ... @@ -2052,7 +2052,7 @@ def __call__( ti_id: UUID | None = None, max_tries: int | None = None, context_update: dict[str, Any] | None = None, - ) -> tuple[IntermediateTIState | TerminalTIState, ToSupervisor | None, BaseException | None]: ... + ) -> tuple[TIState, ToSupervisor | None, BaseException | None]: ... @pytest.fixture @@ -2197,7 +2197,7 @@ def execute(self, context): task = MyTaskOperator(task_id="test_task") run_task(task) - assert run_task.state == TerminalTIState.SUCCESS + assert run_task.state == TaskInstanceState.SUCCESS assert run_task.error is None """ import structlog @@ -2309,7 +2309,7 @@ def __init__(self, create_runtime_ti): self._context = None @property - def state(self) -> IntermediateTIState | TerminalTIState: + def state(self) -> TIState: """Get the task state.""" return self._state @@ -2348,7 +2348,7 @@ def __call__( ti_id: UUID | None = None, max_tries: int | None = None, context_update: dict[str, Any] | None = None, - ) -> tuple[IntermediateTIState | TerminalTIState, ToSupervisor | None, BaseException | None]: + ) -> tuple[TIState, ToSupervisor | None, BaseException | None]: now = timezone.utcnow() if logical_date is None: logical_date = now diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index c64f721b9ae61..9c9c6897b1190 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -41,9 +41,9 @@ DagRunStateResponse, DagRunType, PrevSuccessfulDagRunResponse, + TaskInstanceState, TaskStatesResponse, TerminalStateNonSuccess, - TerminalTIState, TIDeferredStatePayload, TIEnterRunningPayload, TIHeartbeatInfo, @@ -148,7 +148,7 @@ def start(self, id: uuid.UUID, pid: int, when: datetime) -> TIRunContext: def finish(self, id: uuid.UUID, state: TerminalStateNonSuccess, when: datetime): """Tell the API server that this TI has reached a terminal state.""" - if state == TerminalTIState.SUCCESS: + if state == TaskInstanceState.SUCCESS: raise ValueError("Logic error. SUCCESS state should call the `succeed` function instead") # TODO: handle the naming better. finish sounds wrong as "even" deferred is essentially finishing. body = TITerminalStatePayload(end_date=when, state=TerminalStateNonSuccess(state)) diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index d9c5982e53995..7c7647635ea05 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -384,6 +384,21 @@ class TerminalTIState(str, Enum): REMOVED = "removed" +class TaskInstanceState(str, Enum): + REMOVED = "removed" + SCHEDULED = "scheduled" + QUEUED = "queued" + RUNNING = "running" + SUCCESS = "success" + RESTARTING = "restarting" + FAILED = "failed" + UP_FOR_RETRY = "up_for_retry" + UP_FOR_RESCHEDULE = "up_for_reschedule" + UPSTREAM_FAILED = "upstream_failed" + SKIPPED = "skipped" + DEFERRED = "deferred" + + class AssetEventDagRunReference(BaseModel): """ Schema for AssetEvent model used in DagRun. diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index b4d68086b0c4d..af9e612fe83c5 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -63,8 +63,8 @@ DagRunStateResponse, PrevSuccessfulDagRunResponse, TaskInstance, + TaskInstanceState, TaskStatesResponse, - TerminalTIState, TIDeferredStatePayload, TIRescheduleStatePayload, TIRetryStatePayload, @@ -370,9 +370,9 @@ class TaskState(BaseModel): """ state: Literal[ - TerminalTIState.FAILED, - TerminalTIState.SKIPPED, - TerminalTIState.REMOVED, + TaskInstanceState.FAILED, + TaskInstanceState.SKIPPED, + TaskInstanceState.REMOVED, ] end_date: datetime | None = None type: Literal["TaskState"] = "TaskState" diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index c90d6ea5a0241..053379ce947d4 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -54,10 +54,9 @@ from airflow.sdk.api.datamodels._generated import ( AssetResponse, ConnectionResponse, - IntermediateTIState, TaskInstance, + TaskInstanceState, TaskStatesResponse, - TerminalTIState, VariableResponse, ) from airflow.sdk.exceptions import ErrorType @@ -128,10 +127,10 @@ # "Directly" here means that the PATCH API calls to transition into these states are # made from _handle_request() itself and don't have to come all the way to wait(). STATES_SENT_DIRECTLY = [ - IntermediateTIState.DEFERRED, - IntermediateTIState.UP_FOR_RESCHEDULE, - IntermediateTIState.UP_FOR_RETRY, - TerminalTIState.SUCCESS, + TaskInstanceState.DEFERRED, + TaskInstanceState.UP_FOR_RESCHEDULE, + TaskInstanceState.UP_FOR_RETRY, + TaskInstanceState.SUCCESS, SERVER_TERMINATED, ] @@ -975,10 +974,10 @@ def final_state(self): Not valid before the process has finished. """ if self._exit_code == 0: - return self._terminal_state or TerminalTIState.SUCCESS + return self._terminal_state or TaskInstanceState.SUCCESS if self._exit_code != 0 and self._terminal_state == SERVER_TERMINATED: return SERVER_TERMINATED - return TerminalTIState.FAILED + return TaskInstanceState.FAILED def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): log.debug("Received message from task runner", msg=msg) @@ -1035,10 +1034,10 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): len = self.client.xcoms.head(msg.dag_id, msg.run_id, msg.task_id, msg.key) resp = XComCountResponse(len=len) elif isinstance(msg, DeferTask): - self._terminal_state = IntermediateTIState.DEFERRED + self._terminal_state = TaskInstanceState.DEFERRED self.client.task_instances.defer(self.id, msg) elif isinstance(msg, RescheduleTask): - self._terminal_state = IntermediateTIState.UP_FOR_RESCHEDULE + self._terminal_state = TaskInstanceState.UP_FOR_RESCHEDULE self.client.task_instances.reschedule(self.id, msg) elif isinstance(msg, SkipDownstreamTasks): self.client.task_instances.skip_downstream_tasks(self.id, msg) diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 00c09528c19a1..6e1e5f885f5e1 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -42,9 +42,8 @@ from airflow.listeners.listener import get_listener_manager from airflow.sdk.api.datamodels._generated import ( AssetProfile, - IntermediateTIState, TaskInstance, - TerminalTIState, + TaskInstanceState, TIRunContext, ) from airflow.sdk.bases.operator import BaseOperator, ExecutorSafeguard @@ -94,7 +93,6 @@ ) from airflow.sdk.execution_time.xcom import XCom from airflow.utils.net import get_hostname -from airflow.utils.state import TaskInstanceState from airflow.utils.timezone import coerce_datetime if TYPE_CHECKING: @@ -131,6 +129,8 @@ class RuntimeTaskInstance(TaskInstance): end_date: AwareDatetime | None = None + state: TaskInstanceState | None = None + is_mapped: bool | None = None """True if the original task was mapped.""" @@ -600,6 +600,7 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance: _ti_context_from_server=what.ti_context, max_tries=what.ti_context.max_tries, start_date=what.start_date, + state=TaskInstanceState.RUNNING, ) @@ -765,7 +766,7 @@ def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSuperv def _defer_task( defer: TaskDeferred, ti: RuntimeTaskInstance, log: Logger -) -> tuple[ToSupervisor, IntermediateTIState]: +) -> tuple[ToSupervisor, TaskInstanceState]: # TODO: Should we use structlog.bind_contextvars here for dag_id, task_id & run_id? log.info("Pausing task as DEFERRED. ", dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id) @@ -778,7 +779,7 @@ def _defer_task( next_method=defer.method_name, next_kwargs=defer.kwargs or {}, ) - state = IntermediateTIState.DEFERRED + state = TaskInstanceState.DEFERRED return msg, state @@ -787,7 +788,7 @@ def run( ti: RuntimeTaskInstance, context: Context, log: Logger, -) -> tuple[IntermediateTIState | TerminalTIState, ToSupervisor | None, BaseException | None]: +) -> tuple[TaskInstanceState, ToSupervisor | None, BaseException | None]: """Run the task in this process.""" from airflow.exceptions import ( AirflowException, @@ -807,7 +808,7 @@ def run( assert isinstance(ti.task, BaseOperator) msg: ToSupervisor | None = None - state: IntermediateTIState | TerminalTIState + state: TaskInstanceState error: BaseException | None = None try: @@ -827,7 +828,7 @@ def run( # catch it and handle it like a normal task failure if early_exit := _prepare(ti, log, context): msg = early_exit - state = TerminalTIState.FAILED + ti.state = state = TaskInstanceState.FAILED return state, msg, error result = _execute_task(context, ti, log) @@ -848,16 +849,16 @@ def run( if e.args: log.info("Skipping task.", reason=e.args[0]) msg = TaskState( - state=TerminalTIState.SKIPPED, + state=TaskInstanceState.SKIPPED, end_date=datetime.now(tz=timezone.utc), ) - state = TerminalTIState.SKIPPED + state = TaskInstanceState.SKIPPED except AirflowRescheduleException as reschedule: log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE") msg = RescheduleTask( reschedule_date=reschedule.reschedule_date, end_date=datetime.now(tz=timezone.utc) ) - state = IntermediateTIState.UP_FOR_RESCHEDULE + state = TaskInstanceState.UP_FOR_RESCHEDULE except (AirflowFailException, AirflowSensorTimeout) as e: # If AirflowFailException is raised, task should not retry. # If a sensor in reschedule mode reaches timeout, task should not retry. @@ -865,10 +866,10 @@ def run( # TODO: Handle fail_stop here: https://github.com/apache/airflow/issues/44951 # TODO: Handle addition to Log table: https://github.com/apache/airflow/issues/44952 msg = TaskState( - state=TerminalTIState.FAILED, + state=TaskInstanceState.FAILED, end_date=datetime.now(tz=timezone.utc), ) - state = TerminalTIState.FAILED + state = TaskInstanceState.FAILED error = e except (AirflowTaskTimeout, AirflowException) as e: # We should allow retries if the task has defined it. @@ -881,10 +882,10 @@ def run( # If these are thrown, we should mark the TI state as failed. log.exception("Task failed with exception") msg = TaskState( - state=TerminalTIState.FAILED, + state=TaskInstanceState.FAILED, end_date=datetime.now(tz=timezone.utc), ) - state = TerminalTIState.FAILED + state = TaskInstanceState.FAILED error = e except SystemExit as e: # SystemExit needs to be retried if they are eligible. @@ -898,14 +899,16 @@ def run( finally: if msg: SUPERVISOR_COMMS.send_request(msg=msg, log=log) + # Return the message to make unit tests easier too + ti.state = state return state, msg, error def _handle_current_task_success( context: Context, ti: RuntimeTaskInstance, -) -> tuple[SucceedTask, TerminalTIState]: +) -> tuple[SucceedTask, TaskInstanceState]: task_outlets = list(_build_asset_profiles(ti.task.outlets)) outlet_events = list(_serialize_outlet_events(context["outlet_events"])) msg = SucceedTask( @@ -913,21 +916,21 @@ def _handle_current_task_success( task_outlets=task_outlets, outlet_events=outlet_events, ) - return msg, TerminalTIState.SUCCESS + return msg, TaskInstanceState.SUCCESS def _handle_current_task_failed( ti: RuntimeTaskInstance, -) -> tuple[RetryTask, IntermediateTIState] | tuple[TaskState, TerminalTIState]: +) -> tuple[RetryTask, TaskInstanceState] | tuple[TaskState, TaskInstanceState]: end_date = datetime.now(tz=timezone.utc) if ti._ti_context_from_server and ti._ti_context_from_server.should_retry: - return RetryTask(end_date=end_date), IntermediateTIState.UP_FOR_RETRY - return TaskState(state=TerminalTIState.FAILED, end_date=end_date), TerminalTIState.FAILED + return RetryTask(end_date=end_date), TaskInstanceState.UP_FOR_RETRY + return TaskState(state=TaskInstanceState.FAILED, end_date=end_date), TaskInstanceState.FAILED def _handle_trigger_dag_run( drte: DagRunTriggerException, context: Context, ti: RuntimeTaskInstance, log: Logger -) -> tuple[ToSupervisor, IntermediateTIState | TerminalTIState]: +) -> tuple[ToSupervisor, TaskInstanceState]: """Handle exception from TriggerDagRunOperator.""" log.info("Triggering Dag Run.", trigger_dag_id=drte.trigger_dag_id) SUPERVISOR_COMMS.send_request( @@ -948,12 +951,12 @@ def _handle_trigger_dag_run( "Dag Run already exists, skipping task as skip_when_already_exists is set to True.", dag_id=drte.trigger_dag_id, ) - msg = TaskState(state=TerminalTIState.SKIPPED, end_date=datetime.now(tz=timezone.utc)) - state = TerminalTIState.SKIPPED + msg = TaskState(state=TaskInstanceState.SKIPPED, end_date=datetime.now(tz=timezone.utc)) + state = TaskInstanceState.SKIPPED else: log.error("Dag Run already exists, marking task as failed.", dag_id=drte.trigger_dag_id) - msg = TaskState(state=TerminalTIState.FAILED, end_date=datetime.now(tz=timezone.utc)) - state = TerminalTIState.FAILED + msg = TaskState(state=TaskInstanceState.FAILED, end_date=datetime.now(tz=timezone.utc)) + state = TaskInstanceState.FAILED return msg, state @@ -998,8 +1001,8 @@ def _handle_trigger_dag_run( log.error( "DagRun finished with failed state.", dag_id=drte.trigger_dag_id, state=comms_msg.state ) - msg = TaskState(state=TerminalTIState.FAILED, end_date=datetime.now(tz=timezone.utc)) - state = TerminalTIState.FAILED + msg = TaskState(state=TaskInstanceState.FAILED, end_date=datetime.now(tz=timezone.utc)) + state = TaskInstanceState.FAILED return msg, state if comms_msg.state in drte.allowed_states: log.info( @@ -1151,7 +1154,7 @@ def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance, log: Logger): def finalize( ti: RuntimeTaskInstance, - state: IntermediateTIState | TerminalTIState, + state: TaskInstanceState, context: Context, log: Logger, error: BaseException | None = None, @@ -1172,7 +1175,7 @@ def finalize( ) log.debug("Running finalizers", ti=ti) - if state == TerminalTIState.SUCCESS: + if state == TaskInstanceState.SUCCESS: _run_task_state_change_callbacks(task, "on_success_callback", context, log) try: get_listener_manager().hook.on_task_instance_success( @@ -1180,9 +1183,9 @@ def finalize( ) except Exception: log.exception("error calling listener") - elif state == TerminalTIState.SKIPPED: + elif state == TaskInstanceState.SKIPPED: _run_task_state_change_callbacks(task, "on_skipped_callback", context, log) - elif state == IntermediateTIState.UP_FOR_RETRY: + elif state == TaskInstanceState.UP_FOR_RETRY: _run_task_state_change_callbacks(task, "on_retry_callback", context, log) try: get_listener_manager().hook.on_task_instance_failed( @@ -1192,7 +1195,7 @@ def finalize( log.exception("error calling listener") if error and task.email_on_retry and task.email: _send_task_error_email(task.email, ti, error) - elif state == TerminalTIState.FAILED: + elif state == TaskInstanceState.FAILED: _run_task_state_change_callbacks(task, "on_failure_callback", context, log) try: get_listener_manager().hook.on_task_instance_failed( diff --git a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py index 5373202ab9f17..d6b8ca8da4e96 100644 --- a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py +++ b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py @@ -24,7 +24,7 @@ import pendulum import pytest -from airflow.sdk.api.datamodels._generated import TerminalTIState +from airflow.sdk.api.datamodels._generated import TaskInstanceState from airflow.sdk.bases.operator import BaseOperator from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.mappedoperator import MappedOperator @@ -405,7 +405,7 @@ def g2(y): g1.expand(x=t()) -RunTI = Callable[[DAG, str, int], TerminalTIState] +RunTI = Callable[[DAG, str, int], TaskInstanceState] def test_map_cross_product(run_ti: RunTI, mock_supervisor_comms): @@ -439,7 +439,7 @@ def xcom_get(): mock_supervisor_comms.get_message.side_effect = xcom_get states = [run_ti(dag, "show", map_index) for map_index in range(6)] - assert states == [TerminalTIState.SUCCESS] * 6 + assert states == [TaskInstanceState.SUCCESS] * 6 assert outputs == [ (1, ("a", "x")), (1, ("b", "y")), @@ -479,7 +479,7 @@ def xcom_get(): mock_supervisor_comms.get_message.side_effect = xcom_get states = [run_ti(dag, "show", map_index) for map_index in range(4)] - assert states == [TerminalTIState.SUCCESS] * 4 + assert states == [TaskInstanceState.SUCCESS] * 4 assert outputs == [(1, 1), (1, 2), (2, 1), (2, 2)] diff --git a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py index 9a16be08a352c..cc968194132d6 100644 --- a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py +++ b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py @@ -25,13 +25,13 @@ from pytest_unordered import unordered from airflow.exceptions import AirflowSkipException -from airflow.sdk.api.datamodels._generated import TerminalTIState +from airflow.sdk.api.datamodels._generated import TaskInstanceState from airflow.sdk.definitions.dag import DAG from airflow.sdk.execution_time.comms import GetXCom, XComResult log = structlog.get_logger(__name__) -RunTI = Callable[[DAG, str, int], TerminalTIState] +RunTI = Callable[[DAG, str, int], TaskInstanceState] def test_xcom_map(run_ti: RunTI, mock_supervisor_comms): @@ -55,7 +55,7 @@ def pull(value): mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=["a", "b", "c"]) for map_index in range(3): - assert run_ti(dag, "pull", map_index) == TerminalTIState.SUCCESS + assert run_ti(dag, "pull", map_index) == TaskInstanceState.SUCCESS assert results == {"aa", "bb", "cc"} @@ -85,7 +85,7 @@ def c_to_none(v): # Run "pull". This should automatically convert "c" to None. for map_index in range(3): - assert run_ti(dag, "pull", map_index) == TerminalTIState.SUCCESS + assert run_ti(dag, "pull", map_index) == TaskInstanceState.SUCCESS assert results == {"a", "b", None} @@ -115,13 +115,13 @@ def c_to_none(v): # The first two "pull" tis should succeed. for map_index in range(2): - assert run_ti(dag, "pull", map_index) == TerminalTIState.SUCCESS + assert run_ti(dag, "pull", map_index) == TaskInstanceState.SUCCESS # Clear captured logs from the above captured_logs[:] = [] # But the third one fails because the map() result cannot be used as kwargs. - assert run_ti(dag, "pull", 2) == TerminalTIState.FAILED + assert run_ti(dag, "pull", 2) == TaskInstanceState.FAILED assert captured_logs == unordered( [ @@ -165,7 +165,7 @@ def does_not_work_with_c(v): # Mock xcom result from push task mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=["a", "b", "c"]) # The third one (for "c") will fail. - assert run_ti(dag, "pull", 2) == TerminalTIState.FAILED + assert run_ti(dag, "pull", 2) == TaskInstanceState.FAILED assert captured_logs == unordered( [ @@ -209,7 +209,7 @@ def pull(value): # Now "pull" should apply the mapping functions in order. for map_index in range(3): - assert run_ti(dag, "pull", map_index) == TerminalTIState.SUCCESS + assert run_ti(dag, "pull", map_index) == TaskInstanceState.SUCCESS assert results == {"aa", "bb", "cc"} @@ -256,7 +256,7 @@ def xcom_get(): # Run "pull". for map_index in range(4): - assert run_ti(dag, "pull", map_index) == TerminalTIState.SUCCESS + assert run_ti(dag, "pull", map_index) == TaskInstanceState.SUCCESS assert results == {"aa", "bbbb", "cccccc", "dddddddd"} @@ -287,7 +287,7 @@ def skip_c(v): # Run "forward". This should automatically skip "c". states = [run_ti(dag, "forward", map_index) for map_index in range(3)] - assert states == [TerminalTIState.SUCCESS, TerminalTIState.SUCCESS, TerminalTIState.SKIPPED] + assert states == [TaskInstanceState.SUCCESS, TaskInstanceState.SUCCESS, TaskInstanceState.SKIPPED] assert result == ["a", "b"] @@ -353,9 +353,9 @@ def xcom_get(): mock_supervisor_comms.get_message.side_effect = xcom_get # Run "pull_one" and "pull_all". - assert run_ti(dag, "pull_all", None) == TerminalTIState.SUCCESS + assert run_ti(dag, "pull_all", None) == TaskInstanceState.SUCCESS assert all_results == ["a", "b", "c", 1, 2] states = [run_ti(dag, "pull_one", map_index) for map_index in range(5)] - assert states == [TerminalTIState.SUCCESS] * 5 + assert states == [TaskInstanceState.SUCCESS] * 5 assert agg_results == {"a", "b", "c", 1, 2} diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 4aff50a5fd486..8709966f98046 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -48,7 +48,7 @@ AssetResponse, DagRunState, TaskInstance, - TerminalTIState, + TaskInstanceState, ) from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType from airflow.sdk.execution_time.comms import ( @@ -311,7 +311,7 @@ def subprocess_main(): @spy_agency.spy_for(ActivitySubprocess._on_child_started) def _on_child_started(self, *args, **kwargs): # Set it up so we are in overtime straight away - self._terminal_state = TerminalTIState.SUCCESS + self._terminal_state = TaskInstanceState.SUCCESS ActivitySubprocess._on_child_started.call_original(self, *args, **kwargs) heartbeat_spy = spy_agency.spy_on(sdk_client.TaskInstanceOperations.heartbeat) @@ -649,11 +649,11 @@ def test_heartbeat_failures_handling(self, monkeypatch, mocker, captured_logs, t False, id="no_terminal_state", ), - pytest.param(TerminalTIState.SUCCESS, 15.0, 10, False, id="below_threshold"), - pytest.param(TerminalTIState.SUCCESS, 9.0, 10, True, id="above_threshold"), - pytest.param(TerminalTIState.FAILED, 9.0, 10, True, id="above_threshold_failed_state"), - pytest.param(TerminalTIState.SKIPPED, 9.0, 10, True, id="above_threshold_skipped_state"), - pytest.param(TerminalTIState.SUCCESS, None, 20, False, id="task_end_datetime_none"), + pytest.param(TaskInstanceState.SUCCESS, 15.0, 10, False, id="below_threshold"), + pytest.param(TaskInstanceState.SUCCESS, 9.0, 10, True, id="above_threshold"), + pytest.param(TaskInstanceState.FAILED, 9.0, 10, True, id="above_threshold_failed_state"), + pytest.param(TaskInstanceState.SKIPPED, 9.0, 10, True, id="above_threshold_skipped_state"), + pytest.param(TaskInstanceState.SUCCESS, None, 20, False, id="task_end_datetime_none"), ], ) def test_overtime_handling( @@ -1180,10 +1180,10 @@ def watched_subprocess(self, mocker): OKResponse(ok=True), id="delete_xcom", ), - # we aren't adding all states under TerminalTIState here, because this test's scope is only to check + # we aren't adding all states under TaskInstanceState here, because this test's scope is only to check # if it can handle TaskState message pytest.param( - TaskState(state=TerminalTIState.SKIPPED, end_date=timezone.parse("2024-10-31T12:00:00Z")), + TaskState(state=TaskInstanceState.SKIPPED, end_date=timezone.parse("2024-10-31T12:00:00Z")), b"", "", (), diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 044e7864a58b4..f9ad46f357e7c 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -52,9 +52,8 @@ AssetProfile, AssetResponse, DagRunState, - IntermediateTIState, TaskInstance, - TerminalTIState, + TaskInstanceState, ) from airflow.sdk.bases.xcom import BaseXCom from airflow.sdk.definitions._internal.types import SET_DURING_EXECUTION @@ -113,7 +112,6 @@ ) from airflow.sdk.execution_time.xcom import XCom from airflow.utils import timezone -from airflow.utils.state import TaskInstanceState from airflow.utils.types import NOTSET, ArgNotSet from tests_common.test_utils.mock_operators import AirflowLink @@ -306,7 +304,9 @@ def test_run_deferred_basic(time_machine, create_runtime_ti, mock_supervisor_com # Run the task ti = create_runtime_ti(dag_id="basic_deferred_run", task=task) - state, msg, err = run(ti, context=ti.get_template_context(), log=mock.MagicMock()) + run(ti, context=ti.get_template_context(), log=mock.MagicMock()) + + assert ti.state == TaskInstanceState.DEFERRED # send_request will only be called when the TaskDeferred exception is raised mock_supervisor_comms.send_request.assert_any_call(msg=expected_defer_task, log=mock.ANY) @@ -328,7 +328,7 @@ def execute(self, context): context = ti.get_template_context() log = mock.MagicMock() run(ti, context=context, log=log) - finalize(ti, context=context, log=mock.MagicMock(), state=TerminalTIState.SUCCESS) + finalize(ti, context=context, log=mock.MagicMock(), state=TaskInstanceState.SUCCESS) assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.SUCCESS] log.info.assert_called_with("Skipping downstream tasks.") @@ -360,6 +360,7 @@ def test_resume_from_deferred(time_machine, create_runtime_ti, mock_supervisor_c state, msg, err = run(ti, context=ti.get_template_context(), log=mock.MagicMock()) assert err is None assert state == TaskInstanceState.SUCCESS + assert ti.state == TaskInstanceState.SUCCESS spy_agency.assert_spy_called_with(spy, mock.ANY, event=instant) @@ -381,8 +382,10 @@ def test_run_basic_skipped(time_machine, create_runtime_ti, mock_supervisor_comm run(ti, context=ti.get_template_context(), log=mock.MagicMock()) + assert ti.state == TaskInstanceState.SKIPPED + mock_supervisor_comms.send_request.assert_called_with( - msg=TaskState(state=TerminalTIState.SKIPPED, end_date=instant), log=mock.ANY + msg=TaskState(state=TaskInstanceState.SKIPPED, end_date=instant), log=mock.ANY ) @@ -401,9 +404,11 @@ def test_run_raises_base_exception(time_machine, create_runtime_ti, mock_supervi run(ti, context=ti.get_template_context(), log=mock.MagicMock()) + assert ti.state == TaskInstanceState.FAILED + mock_supervisor_comms.send_request.assert_called_with( msg=TaskState( - state=TerminalTIState.FAILED, + state=TaskInstanceState.FAILED, end_date=instant, ), log=mock.ANY, @@ -426,9 +431,11 @@ def test_run_raises_system_exit(time_machine, create_runtime_ti, mock_supervisor log = mock.MagicMock() run(ti, context=ti.get_template_context(), log=log) + assert ti.state == TaskInstanceState.FAILED + mock_supervisor_comms.send_request.assert_called_with( msg=TaskState( - state=TerminalTIState.FAILED, + state=TaskInstanceState.FAILED, end_date=instant, ), log=mock.ANY, @@ -455,9 +462,11 @@ def test_run_raises_airflow_exception(time_machine, create_runtime_ti, mock_supe run(ti, context=ti.get_template_context(), log=mock.MagicMock()) + assert ti.state == TaskInstanceState.FAILED + mock_supervisor_comms.send_request.assert_called_with( msg=TaskState( - state=TerminalTIState.FAILED, + state=TaskInstanceState.FAILED, end_date=instant, ), log=mock.ANY, @@ -481,10 +490,12 @@ def test_run_task_timeout(time_machine, create_runtime_ti, mock_supervisor_comms run(ti, context=ti.get_template_context(), log=mock.MagicMock()) + assert ti.state == TaskInstanceState.FAILED + # this state can only be reached if the try block passed down the exception to handler of AirflowTaskTimeout mock_supervisor_comms.send_request.assert_called_with( msg=TaskState( - state=TerminalTIState.FAILED, + state=TaskInstanceState.FAILED, end_date=instant, ), log=mock.ANY, @@ -526,6 +537,7 @@ def test_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comm spy_agency.assert_spy_called(task.prepare_for_execution) assert ti.task._lock_for_execution assert ti.task is not task, "ti.task should be a copy of the original task" + assert ti.state == TaskInstanceState.SUCCESS mock_supervisor_comms.send_request.assert_any_call( msg=SetRenderedFields( @@ -644,7 +656,7 @@ def execute(self, context): mock.call.send_request( msg=SucceedTask( end_date=instant, - state=TerminalTIState.SUCCESS, + state=TaskInstanceState.SUCCESS, task_outlets=[], outlet_events=[], ), @@ -699,9 +711,11 @@ def execute(self, context): run(ti, context=ti.get_template_context(), log=mock.MagicMock()) + assert ti.state == TaskInstanceState.SUCCESS + # Ensure the task is Successful mock_supervisor_comms.send_request.assert_called_once_with( - msg=SucceedTask(state=TerminalTIState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]), + msg=SucceedTask(state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]), log=mock.ANY, ) @@ -747,8 +761,10 @@ def execute(self, context): run(ti, context=ti.get_template_context(), log=mock.MagicMock()) + assert ti.state == TaskInstanceState.FAILED + mock_supervisor_comms.send_request.assert_called_once_with( - msg=TaskState(state=TerminalTIState.FAILED, end_date=instant), log=mock.ANY + msg=TaskState(state=TaskInstanceState.FAILED, end_date=instant), log=mock.ANY ) @@ -1445,7 +1461,7 @@ def execute(self, context): mock_supervisor_comms.send_request.assert_called_once_with( msg=SucceedTask( - state=TerminalTIState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] + state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] ), log=mock.ANY, ) @@ -1454,7 +1470,7 @@ def execute(self, context): finalize( runtime_ti, log=mock.MagicMock(), - state=TerminalTIState.SUCCESS, + state=TaskInstanceState.SUCCESS, context=runtime_ti.get_template_context(), ) mock_xcom_set.assert_called_once_with( @@ -1491,7 +1507,7 @@ def __init__(self, bash_command, *args, **kwargs): finalize( runtime_ti, - state=TerminalTIState.SUCCESS, + state=TaskInstanceState.SUCCESS, context=runtime_ti.get_template_context(), log=mock.MagicMock(), ) @@ -1862,7 +1878,7 @@ def execute(self, context): mock_supervisor_comms.send_request.assert_called_once_with( msg=SucceedTask( - state=TerminalTIState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] + state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] ), log=mock.ANY, ) @@ -1891,7 +1907,7 @@ def execute(self, context): run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) mock_supervisor_comms.send_request.assert_called_once_with( msg=SucceedTask( - state=TerminalTIState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] + state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] ), log=mock.ANY, ) @@ -1918,7 +1934,7 @@ def execute(self, context): run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) mock_supervisor_comms.send_request.assert_called_once_with( msg=SucceedTask( - state=TerminalTIState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] + state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] ), log=mock.ANY, ) @@ -1953,7 +1969,7 @@ def execute(self, context): mock_supervisor_comms.send_request.assert_called_once_with( msg=SucceedTask( - state=TerminalTIState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] + state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] ), log=mock.ANY, ) @@ -1992,7 +2008,7 @@ def execute(self, context): mock_supervisor_comms.send_request.assert_called_once_with( msg=SucceedTask( - state=TerminalTIState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] + state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] ), log=mock.ANY, ) @@ -2023,7 +2039,7 @@ def return_num(num): mock_supervisor_comms.send_request.assert_any_call( msg=SucceedTask( - state=TerminalTIState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] + state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] ), log=mock.ANY, ) @@ -2195,28 +2211,28 @@ def _execute_failure(self, context): pytest.param( _execute_success, False, - TerminalTIState.SUCCESS, + TaskInstanceState.SUCCESS, ["on-execute callback", "execute success", "on-success callback"], id="success", ), pytest.param( _execute_skipped, False, - TerminalTIState.SKIPPED, + TaskInstanceState.SKIPPED, ["on-execute callback", "execute skipped", "on-skipped callback"], id="skipped", ), pytest.param( _execute_failure, False, - TerminalTIState.FAILED, + TaskInstanceState.FAILED, ["on-execute callback", "execute failure", "on-failure callback"], id="failure", ), pytest.param( _execute_failure, True, - IntermediateTIState.UP_FOR_RETRY, + TaskInstanceState.UP_FOR_RETRY, ["on-execute callback", "execute failure", "on-retry callback"], id="retry", ), @@ -2263,7 +2279,7 @@ class CustomOperator(BaseOperator): "on_success_callback", _execute_success, False, - TerminalTIState.SUCCESS, + TaskInstanceState.SUCCESS, ["on-execute 1", "on-execute 3", "execute success", "on-success 1", "on-success 3"], [], id="success", @@ -2272,7 +2288,7 @@ class CustomOperator(BaseOperator): "on_skipped_callback", _execute_skipped, False, - TerminalTIState.SKIPPED, + TaskInstanceState.SKIPPED, ["on-execute 1", "on-execute 3", "execute skipped", "on-skipped 1", "on-skipped 3"], [], id="skipped", @@ -2281,7 +2297,7 @@ class CustomOperator(BaseOperator): "on_failure_callback", _execute_failure, False, - TerminalTIState.FAILED, + TaskInstanceState.FAILED, ["on-execute 1", "on-execute 3", "execute failure", "on-failure 1", "on-failure 3"], [(1, mock.call("Task failed with exception"))], id="failure", @@ -2290,7 +2306,7 @@ class CustomOperator(BaseOperator): "on_retry_callback", _execute_failure, True, - IntermediateTIState.UP_FOR_RETRY, + TaskInstanceState.UP_FOR_RETRY, ["on-execute 1", "on-execute 3", "execute failure", "on-retry 1", "on-retry 3"], [(1, mock.call("Task failed with exception"))], id="retry", @@ -2555,7 +2571,7 @@ def test_handle_trigger_dag_run_wait_for_completion( @pytest.mark.parametrize( ["allowed_states", "failed_states", "intermediate_state"], [ - ([DagRunState.SUCCESS], None, IntermediateTIState.DEFERRED), + ([DagRunState.SUCCESS], None, TaskInstanceState.DEFERRED), ], ) def test_handle_trigger_dag_run_deferred(