Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion airflow-core/src/airflow/api_fastapi/execution_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,15 @@ def get_extra_schemas() -> dict[str, dict]:
"""Get all the extra schemas that are not part of the main FastAPI app."""
from airflow.api_fastapi.execution_api.datamodels.taskinstance import TaskInstance
from airflow.executors.workloads import BundleInfo
from airflow.utils.state import TerminalTIState
from airflow.utils.state import TaskInstanceState, TerminalTIState

return {
"TaskInstance": TaskInstance.model_json_schema(),
"BundleInfo": BundleInfo.model_json_schema(),
# Include the combined state enum too. In the datamodels we separate out SUCCESS from the other states
# as that has different payload requirements
"TerminalTIState": {"type": "string", "enum": list(TerminalTIState)},
"TaskInstanceState": {"type": "string", "enum": list(TaskInstanceState)},
}


Expand Down
12 changes: 6 additions & 6 deletions devel-common/src/tests_common/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from airflow.models.taskinstance import TaskInstance
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.sdk import Context
from airflow.sdk.api.datamodels._generated import IntermediateTIState, TerminalTIState
from airflow.sdk.api.datamodels._generated import TaskInstanceState as TIState
from airflow.sdk.bases.operator import BaseOperator as TaskSDKBaseOperator
from airflow.sdk.execution_time.comms import StartupDetails, ToSupervisor
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
Expand Down Expand Up @@ -2020,7 +2020,7 @@ class RunTaskCallable(Protocol):
"""Protocol for better type hints for the fixture `run_task`."""

@property
def state(self) -> IntermediateTIState | TerminalTIState: ...
def state(self) -> TIState: ...

@property
def msg(self) -> ToSupervisor | None: ...
Expand Down Expand Up @@ -2052,7 +2052,7 @@ def __call__(
ti_id: UUID | None = None,
max_tries: int | None = None,
context_update: dict[str, Any] | None = None,
) -> tuple[IntermediateTIState | TerminalTIState, ToSupervisor | None, BaseException | None]: ...
) -> tuple[TIState, ToSupervisor | None, BaseException | None]: ...


@pytest.fixture
Expand Down Expand Up @@ -2197,7 +2197,7 @@ def execute(self, context):

task = MyTaskOperator(task_id="test_task")
run_task(task)
assert run_task.state == TerminalTIState.SUCCESS
assert run_task.state == TaskInstanceState.SUCCESS
assert run_task.error is None
"""
import structlog
Expand Down Expand Up @@ -2309,7 +2309,7 @@ def __init__(self, create_runtime_ti):
self._context = None

@property
def state(self) -> IntermediateTIState | TerminalTIState:
def state(self) -> TIState:
"""Get the task state."""
return self._state

Expand Down Expand Up @@ -2348,7 +2348,7 @@ def __call__(
ti_id: UUID | None = None,
max_tries: int | None = None,
context_update: dict[str, Any] | None = None,
) -> tuple[IntermediateTIState | TerminalTIState, ToSupervisor | None, BaseException | None]:
) -> tuple[TIState, ToSupervisor | None, BaseException | None]:
now = timezone.utcnow()
if logical_date is None:
logical_date = now
Expand Down
4 changes: 2 additions & 2 deletions task-sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@
DagRunStateResponse,
DagRunType,
PrevSuccessfulDagRunResponse,
TaskInstanceState,
TaskStatesResponse,
TerminalStateNonSuccess,
TerminalTIState,
TIDeferredStatePayload,
TIEnterRunningPayload,
TIHeartbeatInfo,
Expand Down Expand Up @@ -148,7 +148,7 @@ def start(self, id: uuid.UUID, pid: int, when: datetime) -> TIRunContext:

def finish(self, id: uuid.UUID, state: TerminalStateNonSuccess, when: datetime):
"""Tell the API server that this TI has reached a terminal state."""
if state == TerminalTIState.SUCCESS:
if state == TaskInstanceState.SUCCESS:
raise ValueError("Logic error. SUCCESS state should call the `succeed` function instead")
# TODO: handle the naming better. finish sounds wrong as "even" deferred is essentially finishing.
body = TITerminalStatePayload(end_date=when, state=TerminalStateNonSuccess(state))
Expand Down
15 changes: 15 additions & 0 deletions task-sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,21 @@ class TerminalTIState(str, Enum):
REMOVED = "removed"


class TaskInstanceState(str, Enum):
REMOVED = "removed"
SCHEDULED = "scheduled"
QUEUED = "queued"
RUNNING = "running"
SUCCESS = "success"
RESTARTING = "restarting"
FAILED = "failed"
UP_FOR_RETRY = "up_for_retry"
UP_FOR_RESCHEDULE = "up_for_reschedule"
UPSTREAM_FAILED = "upstream_failed"
SKIPPED = "skipped"
DEFERRED = "deferred"


class AssetEventDagRunReference(BaseModel):
"""
Schema for AssetEvent model used in DagRun.
Expand Down
8 changes: 4 additions & 4 deletions task-sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@
DagRunStateResponse,
PrevSuccessfulDagRunResponse,
TaskInstance,
TaskInstanceState,
TaskStatesResponse,
TerminalTIState,
TIDeferredStatePayload,
TIRescheduleStatePayload,
TIRetryStatePayload,
Expand Down Expand Up @@ -370,9 +370,9 @@ class TaskState(BaseModel):
"""

state: Literal[
TerminalTIState.FAILED,
TerminalTIState.SKIPPED,
TerminalTIState.REMOVED,
TaskInstanceState.FAILED,
TaskInstanceState.SKIPPED,
TaskInstanceState.REMOVED,
]
end_date: datetime | None = None
type: Literal["TaskState"] = "TaskState"
Expand Down
19 changes: 9 additions & 10 deletions task-sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,9 @@
from airflow.sdk.api.datamodels._generated import (
AssetResponse,
ConnectionResponse,
IntermediateTIState,
TaskInstance,
TaskInstanceState,
TaskStatesResponse,
TerminalTIState,
VariableResponse,
)
from airflow.sdk.exceptions import ErrorType
Expand Down Expand Up @@ -128,10 +127,10 @@
# "Directly" here means that the PATCH API calls to transition into these states are
# made from _handle_request() itself and don't have to come all the way to wait().
STATES_SENT_DIRECTLY = [
IntermediateTIState.DEFERRED,
IntermediateTIState.UP_FOR_RESCHEDULE,
IntermediateTIState.UP_FOR_RETRY,
TerminalTIState.SUCCESS,
TaskInstanceState.DEFERRED,
TaskInstanceState.UP_FOR_RESCHEDULE,
TaskInstanceState.UP_FOR_RETRY,
TaskInstanceState.SUCCESS,
SERVER_TERMINATED,
]

Expand Down Expand Up @@ -975,10 +974,10 @@ def final_state(self):
Not valid before the process has finished.
"""
if self._exit_code == 0:
return self._terminal_state or TerminalTIState.SUCCESS
return self._terminal_state or TaskInstanceState.SUCCESS
if self._exit_code != 0 and self._terminal_state == SERVER_TERMINATED:
return SERVER_TERMINATED
return TerminalTIState.FAILED
return TaskInstanceState.FAILED

def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
log.debug("Received message from task runner", msg=msg)
Expand Down Expand Up @@ -1035,10 +1034,10 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
len = self.client.xcoms.head(msg.dag_id, msg.run_id, msg.task_id, msg.key)
resp = XComCountResponse(len=len)
elif isinstance(msg, DeferTask):
self._terminal_state = IntermediateTIState.DEFERRED
self._terminal_state = TaskInstanceState.DEFERRED
self.client.task_instances.defer(self.id, msg)
elif isinstance(msg, RescheduleTask):
self._terminal_state = IntermediateTIState.UP_FOR_RESCHEDULE
self._terminal_state = TaskInstanceState.UP_FOR_RESCHEDULE
self.client.task_instances.reschedule(self.id, msg)
elif isinstance(msg, SkipDownstreamTasks):
self.client.task_instances.skip_downstream_tasks(self.id, msg)
Expand Down
Loading