Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class TITerminalStatePayload(StrictBaseModel):

end_date: UtcDateTime
"""When the task completed executing"""
rendered_map_index: str | None = None


class TISuccessStatePayload(StrictBaseModel):
Expand All @@ -97,6 +98,7 @@ class TISuccessStatePayload(StrictBaseModel):

task_outlets: Annotated[list[AssetProfile], Field(default_factory=list)]
outlet_events: Annotated[list[dict[str, Any]], Field(default_factory=list)]
rendered_map_index: str | None = None


class TITargetStatePayload(StrictBaseModel):
Expand Down Expand Up @@ -136,6 +138,7 @@ class TIDeferredStatePayload(StrictBaseModel):

Both forms will be passed along to the TaskSDK upon resume, the server will not handle either.
"""
rendered_map_index: str | None = None


class TIRescheduleStatePayload(StrictBaseModel):
Expand Down Expand Up @@ -171,6 +174,7 @@ class TIRetryStatePayload(StrictBaseModel):
),
]
end_date: UtcDateTime
rendered_map_index: str | None = None


class TISkippedDownstreamTasksStatePayload(StrictBaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@

from cadwyn import HeadVersion, Version, VersionBundle

from airflow.api_fastapi.execution_api.versions.v2025_04_28 import AddRenderedMapIndexField

bundle = VersionBundle(
HeadVersion(),
Version("2025-04-28", AddRenderedMapIndexField),
Version("2025-04-11"),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from cadwyn import VersionChange, schema

from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
TIDeferredStatePayload,
TIRetryStatePayload,
TISuccessStatePayload,
TITerminalStatePayload,
)


class AddRenderedMapIndexField(VersionChange):
"""Add the `rendered_map_index` field to payload models."""

description = __doc__

instructions_to_migrate_to_previous_version = (
schema(TITerminalStatePayload).field("rendered_map_index").didnt_exist,
schema(TISuccessStatePayload).field("rendered_map_index").didnt_exist,
schema(TIDeferredStatePayload).field("rendered_map_index").didnt_exist,
schema(TIRetryStatePayload).field("rendered_map_index").didnt_exist,
)
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

def test_custom_openapi_includes_extra_schemas(client):
"""Test to ensure that extra schemas are correctly included in the OpenAPI schema."""
response = client.get("/execution/openapi.json?version=2025-04-11")
response = client.get("/execution/openapi.json?version=2025-04-28")
assert response.status_code == 200

openapi_schema = response.json()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@

DEFAULT_START_DATE = timezone.parse("2024-10-31T11:00:00Z")
DEFAULT_END_DATE = timezone.parse("2024-10-31T12:00:00Z")
DEFAULT_RENDERED_MAP_INDEX = "test rendered map index"


def _create_asset_aliases(session, num: int = 2) -> None:
Expand Down Expand Up @@ -465,6 +466,39 @@ def test_ti_update_state_to_terminal(
assert ti.state == expected_state
assert ti.end_date == end_date

@pytest.mark.parametrize(
("state", "end_date", "expected_state", "rendered_map_index"),
[
(State.SUCCESS, DEFAULT_END_DATE, State.SUCCESS, DEFAULT_RENDERED_MAP_INDEX),
(State.FAILED, DEFAULT_END_DATE, State.FAILED, DEFAULT_RENDERED_MAP_INDEX),
(State.SKIPPED, DEFAULT_END_DATE, State.SKIPPED, DEFAULT_RENDERED_MAP_INDEX),
],
)
def test_ti_update_state_to_terminal_with_rendered_map_index(
self, client, session, create_task_instance, state, end_date, expected_state, rendered_map_index
):
ti = create_task_instance(
task_id="test_ti_update_state_to_terminal_with_rendered_map_index",
start_date=DEFAULT_START_DATE,
state=State.RUNNING,
)
session.commit()

response = client.patch(
f"/execution/task-instances/{ti.id}/state",
json={"state": state, "end_date": end_date.isoformat(), "rendered_map_index": rendered_map_index},
)

assert response.status_code == 204
assert response.text == ""

session.expire_all()

ti = session.get(TaskInstance, ti.id)
assert ti.state == expected_state
assert ti.end_date == end_date
assert ti.rendered_map_index == rendered_map_index

@pytest.mark.parametrize(
"task_outlets",
[
Expand Down
19 changes: 13 additions & 6 deletions task-sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,22 +146,29 @@ def start(self, id: uuid.UUID, pid: int, when: datetime) -> TIRunContext:
resp = self.client.patch(f"task-instances/{id}/run", content=body.model_dump_json())
return TIRunContext.model_validate_json(resp.read())

def finish(self, id: uuid.UUID, state: TerminalStateNonSuccess, when: datetime):
def finish(self, id: uuid.UUID, state: TerminalStateNonSuccess, when: datetime, rendered_map_index):
"""Tell the API server that this TI has reached a terminal state."""
if state == TaskInstanceState.SUCCESS:
raise ValueError("Logic error. SUCCESS state should call the `succeed` function instead")
# TODO: handle the naming better. finish sounds wrong as "even" deferred is essentially finishing.
body = TITerminalStatePayload(end_date=when, state=TerminalStateNonSuccess(state))
body = TITerminalStatePayload(
end_date=when, state=TerminalStateNonSuccess(state), rendered_map_index=rendered_map_index
)
self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

def retry(self, id: uuid.UUID, end_date: datetime):
def retry(self, id: uuid.UUID, end_date: datetime, rendered_map_index):
"""Tell the API server that this TI has failed and reached a up_for_retry state."""
body = TIRetryStatePayload(end_date=end_date)
body = TIRetryStatePayload(end_date=end_date, rendered_map_index=rendered_map_index)
self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

def succeed(self, id: uuid.UUID, when: datetime, task_outlets, outlet_events):
def succeed(self, id: uuid.UUID, when: datetime, task_outlets, outlet_events, rendered_map_index):
"""Tell the API server that this TI has succeeded."""
body = TISuccessStatePayload(end_date=when, task_outlets=task_outlets, outlet_events=outlet_events)
body = TISuccessStatePayload(
end_date=when,
task_outlets=task_outlets,
outlet_events=outlet_events,
rendered_map_index=rendered_map_index,
)
self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

def heartbeat(self, id: uuid.UUID, pid: int):
Expand Down
6 changes: 5 additions & 1 deletion task-sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue

API_VERSION: Final[str] = "2025-04-11"
API_VERSION: Final[str] = "2025-04-28"


class AssetAliasReferenceAssetEventDagRun(BaseModel):
Expand Down Expand Up @@ -193,6 +193,7 @@ class TIDeferredStatePayload(BaseModel):
trigger_timeout: Annotated[timedelta | None, Field(title="Trigger Timeout")] = None
next_method: Annotated[str, Field(title="Next Method")]
next_kwargs: Annotated[dict[str, Any] | str | None, Field(title="Next Kwargs")] = None
rendered_map_index: Annotated[str | None, Field(title="Rendered Map Index")] = None


class TIEnterRunningPayload(BaseModel):
Expand Down Expand Up @@ -245,6 +246,7 @@ class TIRetryStatePayload(BaseModel):
)
state: Annotated[Literal["up_for_retry"] | None, Field(title="State")] = "up_for_retry"
end_date: Annotated[AwareDatetime, Field(title="End Date")]
rendered_map_index: Annotated[str | None, Field(title="Rendered Map Index")] = None


class TISkippedDownstreamTasksStatePayload(BaseModel):
Expand All @@ -270,6 +272,7 @@ class TISuccessStatePayload(BaseModel):
end_date: Annotated[AwareDatetime, Field(title="End Date")]
task_outlets: Annotated[list[AssetProfile] | None, Field(title="Task Outlets")] = None
outlet_events: Annotated[list[dict[str, Any]] | None, Field(title="Outlet Events")] = None
rendered_map_index: Annotated[str | None, Field(title="Rendered Map Index")] = None


class TITargetStatePayload(BaseModel):
Expand Down Expand Up @@ -494,3 +497,4 @@ class TITerminalStatePayload(BaseModel):
)
state: TerminalStateNonSuccess
end_date: Annotated[AwareDatetime, Field(title="End Date")]
rendered_map_index: Annotated[str | None, Field(title="Rendered Map Index")] = None
1 change: 1 addition & 0 deletions task-sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ class TaskState(BaseModel):
]
end_date: datetime | None = None
type: Literal["TaskState"] = "TaskState"
rendered_map_index: str | None = None


class SucceedTask(TISuccessStatePayload):
Expand Down
12 changes: 11 additions & 1 deletion task-sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,7 @@ class ActivitySubprocess(WatchedSubprocess):
# TODO: This should come from airflow.cfg: [core] task_success_overtime
TASK_OVERTIME_THRESHOLD: ClassVar[float] = 20.0
_task_end_time_monotonic: float | None = attrs.field(default=None, init=False)
_rendered_map_index: str | None = attrs.field(default=None, init=False)

decoder: ClassVar[TypeAdapter[ToSupervisor]] = TypeAdapter(ToSupervisor)

Expand Down Expand Up @@ -842,7 +843,10 @@ def wait(self) -> int:
# by the subprocess in the `handle_requests` method.
if self.final_state not in STATES_SENT_DIRECTLY:
self.client.task_instances.finish(
id=self.id, state=self.final_state, when=datetime.now(tz=timezone.utc)
id=self.id,
state=self.final_state,
when=datetime.now(tz=timezone.utc),
rendered_map_index=self._rendered_map_index,
)

# Now at the last possible moment, when all logs and comms with the subprocess has finished, lets
Expand Down Expand Up @@ -988,21 +992,26 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
if isinstance(msg, TaskState):
self._terminal_state = msg.state
self._task_end_time_monotonic = time.monotonic()
self._rendered_map_index = msg.rendered_map_index
elif isinstance(msg, SucceedTask):
self._terminal_state = msg.state
self._task_end_time_monotonic = time.monotonic()
self._rendered_map_index = msg.rendered_map_index
self.client.task_instances.succeed(
id=self.id,
when=msg.end_date,
task_outlets=msg.task_outlets,
outlet_events=msg.outlet_events,
rendered_map_index=self._rendered_map_index,
)
elif isinstance(msg, RetryTask):
self._terminal_state = msg.state
self._task_end_time_monotonic = time.monotonic()
self._rendered_map_index = msg.rendered_map_index
self.client.task_instances.retry(
id=self.id,
end_date=msg.end_date,
rendered_map_index=self._rendered_map_index,
)
elif isinstance(msg, GetConnection):
conn = self.client.connections.get(msg.conn_id)
Expand Down Expand Up @@ -1045,6 +1054,7 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
resp = xcom
elif isinstance(msg, DeferTask):
self._terminal_state = TaskInstanceState.DEFERRED
self._rendered_map_index = msg.rendered_map_index
self.client.task_instances.defer(self.id, msg)
elif isinstance(msg, RescheduleTask):
self._terminal_state = TaskInstanceState.UP_FOR_RESCHEDULE
Expand Down
Loading