Skip to content

Commit

Permalink
AIP 72: Handling "deferrable" tasks in execution_api and task SDK (#4…
Browse files Browse the repository at this point in the history
…4241)

closes: #44137

Co-authored-by: Kaxil Naik <kaxilnaik@gmail.com>
  • Loading branch information
amoghrajesh and kaxil authored Nov 27, 2024
1 parent 6f0d731 commit 761cedd
Show file tree
Hide file tree
Showing 11 changed files with 257 additions and 18 deletions.
28 changes: 26 additions & 2 deletions airflow/api_fastapi/execution_api/datamodels/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
from __future__ import annotations

import uuid
from typing import Annotated, Literal, Union
from datetime import timedelta
from typing import Annotated, Any, Literal, Union

from pydantic import Discriminator, Tag, WithJsonSchema
from pydantic import Discriminator, Field, Tag, WithJsonSchema

from airflow.api_fastapi.common.types import UtcDateTime
from airflow.api_fastapi.core_api.base import BaseModel
Expand Down Expand Up @@ -60,6 +61,26 @@ class TITargetStatePayload(BaseModel):
state: IntermediateTIState


class TIDeferredStatePayload(BaseModel):
    """Schema for updating TaskInstance to a deferred state."""

    # Discriminating field: the only accepted value is IntermediateTIState.DEFERRED.
    state: Annotated[
        Literal[IntermediateTIState.DEFERRED],
        # Specify a default in the schema, but not in code, so Pydantic marks it as required.
        WithJsonSchema(
            {
                "type": "string",
                "enum": [IntermediateTIState.DEFERRED],
                "default": IntermediateTIState.DEFERRED,
            }
        ),
    ]
    # Stored as ``classpath`` on the Trigger row created by the state-update route.
    classpath: str
    # Stored as ``kwargs`` on that Trigger row; a fresh empty dict is used when omitted.
    trigger_kwargs: Annotated[dict[str, Any], Field(default_factory=dict)]
    # Written to the task instance's ``next_method`` column by the state-update route.
    next_method: str
    # Optional relative timeout; the route adds it to the current UTC time to get a deadline.
    trigger_timeout: timedelta | None = None


def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str:
"""
Determine the discriminator key for TaskInstance state transitions.
Expand All @@ -77,6 +98,8 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str:
return str(state)
elif state in set(TerminalTIState):
return "_terminal_"
elif state == TIState.DEFERRED:
return "deferred"
return "_other_"


Expand All @@ -87,6 +110,7 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str:
Annotated[TIEnterRunningPayload, Tag("running")],
Annotated[TITerminalStatePayload, Tag("_terminal_")],
Annotated[TITargetStatePayload, Tag("_other_")],
Annotated[TIDeferredStatePayload, Tag("deferred")],
],
Discriminator(ti_state_discriminator),
]
Expand Down
25 changes: 25 additions & 0 deletions airflow/api_fastapi/execution_api/routes/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,14 @@
from airflow.api_fastapi.common.db.common import get_session
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
TIDeferredStatePayload,
TIEnterRunningPayload,
TIHeartbeatInfo,
TIStateUpdate,
TITerminalStatePayload,
)
from airflow.models.taskinstance import TaskInstance as TI
from airflow.models.trigger import Trigger
from airflow.utils import timezone
from airflow.utils.state import State

Expand Down Expand Up @@ -122,6 +124,29 @@ def ti_update_state(
)
elif isinstance(ti_patch_payload, TITerminalStatePayload):
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
elif isinstance(ti_patch_payload, TIDeferredStatePayload):
# Calculate timeout if it was passed
timeout = None
if ti_patch_payload.trigger_timeout is not None:
timeout = timezone.utcnow() + ti_patch_payload.trigger_timeout

trigger_row = Trigger(
classpath=ti_patch_payload.classpath,
kwargs=ti_patch_payload.trigger_kwargs,
)
session.add(trigger_row)

# TODO: HANDLE execution timeout later as it requires a call to the DB
# either get it from the serialised DAG or get it from the API

query = update(TI).where(TI.id == ti_id_str)
query = query.values(
state=State.DEFERRED,
trigger_id=trigger_row.id,
next_method=ti_patch_payload.next_method,
next_kwargs=ti_patch_payload.trigger_kwargs,
trigger_timeout=timeout,
)

# TODO: Replace this with FastAPI's Custom Exception handling:
# https://fastapi.tiangolo.com/tutorial/handling-errors/#install-custom-exception-handlers
Expand Down
9 changes: 9 additions & 0 deletions task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from airflow.sdk.api.datamodels._generated import (
ConnectionResponse,
TerminalTIState,
TIDeferredStatePayload,
TIEnterRunningPayload,
TIHeartbeatInfo,
TITerminalStatePayload,
Expand Down Expand Up @@ -116,6 +117,7 @@ def start(self, id: uuid.UUID, pid: int, when: datetime):

def finish(self, id: uuid.UUID, state: TerminalTIState, when: datetime):
    """Report to the API server that this TI ended in the given terminal state."""
    # TODO: Pick a better name. "finish" is misleading because deferring a task
    # also ends the current run of the task.
    payload = TITerminalStatePayload(end_date=when, state=TerminalTIState(state))
    serialized = payload.model_dump_json()

    self.client.patch(f"task-instances/{id}/state", content=serialized)
Expand All @@ -124,6 +126,13 @@ def heartbeat(self, id: uuid.UUID, pid: int):
body = TIHeartbeatInfo(pid=pid, hostname=get_hostname())
self.client.put(f"task-instances/{id}/heartbeat", content=body.model_dump_json())

def defer(self, id: uuid.UUID, msg):
    """
    Tell the API server that this TI has been deferred.

    :param id: Identifier of the task instance whose state is being updated.
    :param msg: Pydantic message whose explicitly-set fields populate the
        ``TIDeferredStatePayload`` sent to the server.
    """
    # Build the API payload from the fields set on the message. The comms
    # discriminator ``type`` is dropped explicitly: it is not a field of
    # TIDeferredStatePayload, and it counts as "set" once the message has been
    # parsed from the wire, so ``exclude_unset`` alone does not remove it.
    body = TIDeferredStatePayload(**msg.model_dump(exclude_unset=True, exclude={"type"}))

    self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())


class ConnectionOperations:
__slots__ = ("client",)
Expand Down
14 changes: 13 additions & 1 deletion task_sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from __future__ import annotations

from datetime import datetime
from datetime import datetime, timedelta
from enum import Enum
from typing import Annotated, Any, Literal
from uuid import UUID
Expand Down Expand Up @@ -58,6 +58,18 @@ class IntermediateTIState(str, Enum):
DEFERRED = "deferred"


class TIDeferredStatePayload(BaseModel):
    """
    Schema for updating TaskInstance to a deferred state.
    """

    # Optional on the client side; defaults to the literal "deferred".
    state: Annotated[Literal["deferred"] | None, Field(title="State")] = "deferred"
    # Required.
    classpath: Annotated[str, Field(title="Classpath")]
    # Optional; defaults to None when not supplied.
    trigger_kwargs: Annotated[dict[str, Any] | None, Field(title="Trigger Kwargs")] = None
    # Required.
    next_method: Annotated[str, Field(title="Next Method")]
    # Optional relative timeout; defaults to None when not supplied.
    trigger_timeout: Annotated[timedelta | None, Field(title="Trigger Timeout")] = None


class TIEnterRunningPayload(BaseModel):
"""
Schema for updating TaskInstance to 'RUNNING' state with minimal required fields.
Expand Down
9 changes: 8 additions & 1 deletion task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
ConnectionResponse,
TaskInstance,
TerminalTIState,
TIDeferredStatePayload,
VariableResponse,
XComResponse,
)
Expand Down Expand Up @@ -103,6 +104,12 @@ class TaskState(BaseModel):
type: Literal["TaskState"] = "TaskState"


class DeferTask(TIDeferredStatePayload):
    """
    Update a task instance state to deferred.

    Carries every field of ``TIDeferredStatePayload`` plus the ``type`` tag, which is
    the discriminator used by the ``ToSupervisor`` union.
    """

    type: Literal["DeferTask"] = "DeferTask"


class GetXCom(BaseModel):
key: str
dag_id: str
Expand All @@ -123,6 +130,6 @@ class GetVariable(BaseModel):


ToSupervisor = Annotated[
Union[TaskState, GetXCom, GetConnection, GetVariable],
Union[TaskState, GetXCom, GetConnection, GetVariable, DeferTask],
Field(discriminator="type"),
]
28 changes: 23 additions & 5 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@
from pydantic import TypeAdapter

from airflow.sdk.api.client import Client
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.api.datamodels._generated import IntermediateTIState, TaskInstance, TerminalTIState
from airflow.sdk.execution_time.comms import (
DeferTask,
GetConnection,
GetVariable,
GetXCom,
Expand Down Expand Up @@ -263,6 +264,7 @@ class WatchedSubprocess:
_process: psutil.Process
_exit_code: int | None = None
_terminal_state: str | None = None
_final_state: str | None = None

_last_heartbeat: float = 0

Expand Down Expand Up @@ -398,9 +400,10 @@ def wait(self) -> int:
# If it hasn't, assume it's failed
self._exit_code = self._exit_code if self._exit_code is not None else 1

self.client.task_instances.finish(
id=self.ti_id, state=self.final_state, when=datetime.now(tz=timezone.utc)
)
if self.final_state in TerminalTIState:
self.client.task_instances.finish(
id=self.ti_id, state=self.final_state, when=datetime.now(tz=timezone.utc)
)
return self._exit_code

def _monitor_subprocess(self):
Expand Down Expand Up @@ -472,10 +475,20 @@ def final_state(self):
Not valid before the process has finished.
"""
if self._final_state:
return self._final_state
if self._exit_code == 0:
return self._terminal_state or TerminalTIState.SUCCESS
return TerminalTIState.FAILED

@final_state.setter
def final_state(self, value):
    """Record an explicit final state; members of TerminalTIState are ignored."""
    # TODO: Remove the setter and manage using the final_state property
    # to be taken in a follow up
    if value in TerminalTIState:
        return
    self._final_state = value

def __rich_repr__(self):
yield "ti_id", self.ti_id
yield "pid", self.pid
Expand Down Expand Up @@ -518,11 +531,16 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N
elif isinstance(msg, GetXCom):
xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index)
resp = xcom.model_dump_json(exclude_unset=True).encode()
elif isinstance(msg, DeferTask):
self.final_state = IntermediateTIState.DEFERRED
self.client.task_instances.defer(self.ti_id, msg)
resp = None
else:
log.error("Unhandled request", msg=msg)
continue

self.stdin.write(resp + b"\n")
if resp:
self.stdin.write(resp + b"\n")


# Sockets, even the `.makefile()` function don't correctly do line buffering on reading. If a chunk is read
Expand Down
17 changes: 14 additions & 3 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from airflow.sdk.api.datamodels._generated import TaskInstance
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.execution_time.comms import StartupDetails, ToSupervisor, ToTask
from airflow.sdk.execution_time.comms import DeferTask, StartupDetails, ToSupervisor, ToTask

if TYPE_CHECKING:
from structlog.typing import FilteringBoundLogger as Logger
Expand Down Expand Up @@ -159,8 +159,19 @@ def run(ti: RuntimeTaskInstance, log: Logger):
# TODO next_method to support resuming from deferred
# TODO: Get a real context object
ti.task.execute({"task_instance": ti}) # type: ignore[attr-defined]
except TaskDeferred:
...
except TaskDeferred as defer:
classpath, trigger_kwargs = defer.trigger.serialize()
next_method = defer.method_name
timeout = defer.timeout
msg = DeferTask(
state="deferred",
classpath=classpath,
trigger_kwargs=trigger_kwargs,
next_method=next_method,
trigger_timeout=timeout,
)
global SUPERVISOR_COMMS
SUPERVISOR_COMMS.send_request(msg=msg, log=log)
except AirflowSkipException:
...
except AirflowRescheduleException:
Expand Down
37 changes: 37 additions & 0 deletions task_sdk/tests/dags/super_basic_deferred_run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import datetime

from airflow.providers.standard.sensors.date_time import DateTimeSensorAsync
from airflow.sdk.definitions.dag import dag
from airflow.utils import timezone


@dag()
def super_basic_deferred_run():
    # Single-task DAG: an async date-time sensor whose target is three seconds
    # after the moment this function is evaluated.
    fire_at = timezone.utcnow() + datetime.timedelta(seconds=3)
    DateTimeSensorAsync(
        task_id="async",
        target_time=str(fire_at),
        poke_interval=60,
        timeout=600,
    )


super_basic_deferred_run()
Loading

0 comments on commit 761cedd

Please sign in to comment.