diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index bbc557d012463..92a1e933dc93a 100644 --- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -25,7 +25,10 @@ from airflow.api_fastapi.common.types import UtcDateTime from airflow.api_fastapi.core_api.base import BaseModel +from airflow.api_fastapi.execution_api.datamodels.connection import ConnectionResponse +from airflow.api_fastapi.execution_api.datamodels.variable import VariableResponse from airflow.utils.state import IntermediateTIState, TaskInstanceState as TIState, TerminalTIState +from airflow.utils.types import DagRunType class TIEnterRunningPayload(BaseModel): @@ -94,9 +97,7 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str: state = v.get("state") else: state = getattr(v, "state", None) - if state == TIState.RUNNING: - return str(state) - elif state in set(TerminalTIState): + if state in set(TerminalTIState): return "_terminal_" elif state == TIState.DEFERRED: return "deferred" @@ -107,7 +108,6 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str: # and "_other_" is a catch-all for all other states that are not covered by the other schemas. TIStateUpdate = Annotated[ Union[ - Annotated[TIEnterRunningPayload, Tag("running")], Annotated[TITerminalStatePayload, Tag("_terminal_")], Annotated[TITargetStatePayload, Tag("_other_")], Annotated[TIDeferredStatePayload, Tag("deferred")], @@ -135,3 +135,34 @@ class TaskInstance(BaseModel): run_id: str try_number: int map_index: int | None = None + + +class DagRun(BaseModel): + """Schema for DagRun model with minimal required fields needed for Runtime.""" + + # TODO: `dag_id` and `run_id` are duplicated from TaskInstance + # See if we can avoid sending these fields from API server and instead + # use the TaskInstance data to get the DAG run information in the client (Task Execution Interface). 
+ dag_id: str + run_id: str + + logical_date: UtcDateTime + data_interval_start: UtcDateTime | None + data_interval_end: UtcDateTime | None + start_date: UtcDateTime + end_date: UtcDateTime | None + run_type: DagRunType + conf: Annotated[dict[str, Any], Field(default_factory=dict)] + + +class TIRunContext(BaseModel): + """Response schema for TaskInstance run context.""" + + dag_run: DagRun + """DAG run information for the task instance.""" + + variables: Annotated[list[VariableResponse], Field(default_factory=list)] + """Variables that can be accessed by the task instance.""" + + connections: Annotated[list[ConnectionResponse], Field(default_factory=list)] + """Connections that can be accessed by the task instance.""" diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow/api_fastapi/execution_api/routes/task_instances.py index e06798209c5da..3a1545283e81b 100644 --- a/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -30,12 +30,15 @@ from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.common.router import AirflowRouter from airflow.api_fastapi.execution_api.datamodels.taskinstance import ( + DagRun, TIDeferredStatePayload, TIEnterRunningPayload, TIHeartbeatInfo, + TIRunContext, TIStateUpdate, TITerminalStatePayload, ) +from airflow.models.dagrun import DagRun as DR from airflow.models.taskinstance import TaskInstance as TI, _update_rtif from airflow.models.trigger import Trigger from airflow.utils import timezone @@ -48,6 +51,110 @@ log = logging.getLogger(__name__) +@router.patch( + "/{task_instance_id}/run", + status_code=status.HTTP_200_OK, + responses={ + status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"}, + status.HTTP_409_CONFLICT: {"description": "The TI is already in the requested state"}, + status.HTTP_422_UNPROCESSABLE_ENTITY: {"description": "Invalid payload for the state transition"}, + }, +) +def ti_run( + task_instance_id: UUID, ti_run_payload: Annotated[TIEnterRunningPayload, Body()], session: SessionDep +) -> TIRunContext: + """ + Run a TaskInstance. + + This endpoint is used to start a TaskInstance that is in the QUEUED state. + """ + # We only use UUID above for validation purposes + ti_id_str = str(task_instance_id) + + old = select(TI.state, TI.dag_id, TI.run_id).where(TI.id == ti_id_str).with_for_update() + try: + (previous_state, dag_id, run_id) = session.execute(old).one() + except NoResultFound: + log.error("Task Instance %s not found", ti_id_str) + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={ + "reason": "not_found", + "message": "Task Instance not found", + }, + ) + + # We exclude_unset to avoid updating fields that are not set in the payload + data = ti_run_payload.model_dump(exclude_unset=True) + + query = update(TI).where(TI.id == ti_id_str).values(data) + + # TODO: We will need to change this for other states like: + # reschedule, retry, defer etc. 
+ if previous_state != State.QUEUED: + log.warning( + "Can not start Task Instance ('%s') in invalid state: %s", + ti_id_str, + previous_state, + ) + + # TODO: Pass a RFC 9457 compliant error message in "detail" field + # https://datatracker.ietf.org/doc/html/rfc9457 + # to provide more information about the error + # FastAPI will automatically convert this to a JSON response + # This might be added in FastAPI in https://github.com/fastapi/fastapi/issues/10370 + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail={ + "reason": "invalid_state", + "message": "TI was not in a state where it could be marked as running", + "previous_state": previous_state, + }, + ) + log.info("Task with %s state started on %s ", previous_state, ti_run_payload.hostname) + # Ensure there is no end date set. + query = query.values( + end_date=None, + hostname=ti_run_payload.hostname, + unixname=ti_run_payload.unixname, + pid=ti_run_payload.pid, + state=State.RUNNING, + ) + + try: + result = session.execute(query) + log.info("TI %s state updated: %s row(s) affected", ti_id_str, result.rowcount) + + dr = session.execute( + select( + DR.run_id, + DR.dag_id, + DR.data_interval_start, + DR.data_interval_end, + DR.start_date, + DR.end_date, + DR.run_type, + DR.conf, + DR.logical_date, + ).filter_by(dag_id=dag_id, run_id=run_id) + ).one_or_none() + + if not dr: + raise ValueError(f"DagRun with dag_id={dag_id} and run_id={run_id} not found.") + + return TIRunContext( + dag_run=DagRun.model_validate(dr, from_attributes=True), + # TODO: Add variables and connections that are needed (and has perms) for the task + variables=[], + connections=[], + ) + except SQLAlchemyError as e: + log.error("Error marking Task Instance state as running: %s", e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Database error occurred" + ) + + @router.patch( "/{task_instance_id}/state", status_code=status.HTTP_204_NO_CONTENT, @@ -92,37 +199,7 @@ def ti_update_state( query = update(TI).where(TI.id == ti_id_str).values(data) - if isinstance(ti_patch_payload, TIEnterRunningPayload): - if previous_state != State.QUEUED: - log.warning( - "Can not start Task Instance ('%s') in invalid state: %s", - ti_id_str, - previous_state, - ) - - # TODO: Pass a RFC 9457 compliant error message in "detail" field - # https://datatracker.ietf.org/doc/html/rfc9457 - # to provide more information about the error - # FastAPI will automatically convert this to a JSON response - # This might be added in FastAPI in https://github.com/fastapi/fastapi/issues/10370 - raise HTTPException( - status_code=status.HTTP_409_CONFLICT, - detail={ - "reason": "invalid_state", - "message": "TI was not in a state where it could be marked as running", - "previous_state": previous_state, - }, - ) - log.info("Task with %s state started on %s ", previous_state, ti_patch_payload.hostname) - # Ensure there is no end date set. 
- query = query.values( - end_date=None, - hostname=ti_patch_payload.hostname, - unixname=ti_patch_payload.unixname, - pid=ti_patch_payload.pid, - state=State.RUNNING, - ) - elif isinstance(ti_patch_payload, TITerminalStatePayload): + if isinstance(ti_patch_payload, TITerminalStatePayload): query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind) elif isinstance(ti_patch_payload, TIDeferredStatePayload): # Calculate timeout if it was passed diff --git a/task_sdk/src/airflow/sdk/api/client.py b/task_sdk/src/airflow/sdk/api/client.py index 568eb3c90bd86..5f08f2a6242c4 100644 --- a/task_sdk/src/airflow/sdk/api/client.py +++ b/task_sdk/src/airflow/sdk/api/client.py @@ -30,10 +30,12 @@ from airflow.sdk import __version__ from airflow.sdk.api.datamodels._generated import ( ConnectionResponse, + DagRunType, TerminalTIState, TIDeferredStatePayload, TIEnterRunningPayload, TIHeartbeatInfo, + TIRunContext, TITerminalStatePayload, ValidationError as RemoteValidationError, VariablePostBody, @@ -110,11 +112,12 @@ class TaskInstanceOperations: def __init__(self, client: Client): self.client = client - def start(self, id: uuid.UUID, pid: int, when: datetime): + def start(self, id: uuid.UUID, pid: int, when: datetime) -> TIRunContext: """Tell the API server that this TI has started running.""" body = TIEnterRunningPayload(pid=pid, hostname=get_hostname(), unixname=getuser(), start_date=when) - self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json()) + resp = self.client.patch(f"task-instances/{id}/run", content=body.model_dump_json()) + return TIRunContext.model_validate_json(resp.read()) def finish(self, id: uuid.UUID, state: TerminalTIState, when: datetime): """Tell the API server that this TI has reached a terminal state.""" @@ -218,7 +221,23 @@ def auth_flow(self, request: httpx.Request): # This exists as a aid for debugging or local running via the `dry_run` argument to Client. It doesn't make # sense for returning connections etc. def noop_handler(request: httpx.Request) -> httpx.Response: - log.debug("Dry-run request", method=request.method, path=request.url.path) + path = request.url.path + log.debug("Dry-run request", method=request.method, path=path) + + if path.startswith("/task-instances/") and path.endswith("/run"): + # Return a fake context + return httpx.Response( + 200, + json={ + "dag_run": { + "dag_id": "test_dag", + "run_id": "test_run", + "logical_date": "2021-01-01T00:00:00Z", + "start_date": "2021-01-01T00:00:00Z", + "run_type": DagRunType.MANUAL, + }, + }, + ) return httpx.Response(200, json={"text": "Hello, world!"}) diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py index 37659ffcc1be6..5a103e78fc0ff 100644 --- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -44,6 +44,17 @@ class ConnectionResponse(BaseModel): extra: Annotated[str | None, Field(title="Extra")] = None +class DagRunType(str, Enum): + """ + Class with DagRun types. + """ + + BACKFILL = "backfill" + SCHEDULED = "scheduled" + MANUAL = "manual" + ASSET_TRIGGERED = "asset_triggered" + + class IntermediateTIState(str, Enum): """ States that a Task Instance can be in that indicate it is not yet in a terminal or running state. 
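A minimal sketch of the round trip that `TaskInstanceOperations.start` now performs against this endpoint, using only the models in this diff. The bare `httpx.Client`, the `start_ti` helper name, and the literal payload values are illustrative assumptions, not part of the change:

```python
import uuid
from datetime import datetime, timezone

import httpx

from airflow.sdk.api.datamodels._generated import TIEnterRunningPayload, TIRunContext


def start_ti(client: httpx.Client, ti_id: uuid.UUID) -> TIRunContext:
    # Same payload the SDK builds in TaskInstanceOperations.start()
    body = TIEnterRunningPayload(
        pid=100,
        hostname="worker-1",
        unixname="airflow",
        start_date=datetime.now(tz=timezone.utc),
    )
    # The QUEUED -> RUNNING transition now targets .../run instead of .../state,
    # and answers 200 with a TIRunContext body rather than 204 No Content.
    resp = client.patch(f"task-instances/{ti_id}/run", content=body.model_dump_json())
    return TIRunContext.model_validate_json(resp.read())
```

Returning the context in the same response that flips the TI to RUNNING spares the worker a second request before it can build the task's template context.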
@@ -159,10 +170,36 @@ class TaskInstance(BaseModel): map_index: Annotated[int | None, Field(title="Map Index")] = None +class DagRun(BaseModel): + """ + Schema for DagRun model with minimal required fields needed for Runtime. + """ + + dag_id: Annotated[str, Field(title="Dag Id")] + run_id: Annotated[str, Field(title="Run Id")] + logical_date: Annotated[datetime, Field(title="Logical Date")] + data_interval_start: Annotated[datetime | None, Field(title="Data Interval Start")] = None + data_interval_end: Annotated[datetime | None, Field(title="Data Interval End")] = None + start_date: Annotated[datetime, Field(title="Start Date")] + end_date: Annotated[datetime | None, Field(title="End Date")] = None + run_type: DagRunType + conf: Annotated[dict[str, Any] | None, Field(title="Conf")] = None + + class HTTPValidationError(BaseModel): detail: Annotated[list[ValidationError] | None, Field(title="Detail")] = None +class TIRunContext(BaseModel): + """ + Response schema for TaskInstance run context. + """ + + dag_run: DagRun + variables: Annotated[list[VariableResponse] | None, Field(title="Variables")] = None + connections: Annotated[list[ConnectionResponse] | None, Field(title="Connections")] = None + + class TITerminalStatePayload(BaseModel): """ Schema for updating TaskInstance to a terminal state (e.g., SUCCESS or FAILED). diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py index 9e6093a092da0..03f92c549fd75 100644 --- a/task_sdk/src/airflow/sdk/execution_time/comms.py +++ b/task_sdk/src/airflow/sdk/execution_time/comms.py @@ -54,6 +54,7 @@ TaskInstance, TerminalTIState, TIDeferredStatePayload, + TIRunContext, VariableResponse, XComResponse, ) @@ -70,6 +71,7 @@ class StartupDetails(BaseModel): Responses will come back on stdin """ + ti_context: TIRunContext type: Literal["StartupDetails"] = "StartupDetails" diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index 677030b7bdce2..589cae56434c4 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -397,7 +397,7 @@ def _on_child_started(self, ti: TaskInstance, path: str | os.PathLike[str], requ # We've forked, but the task won't start doing anything until we send it the StartupDetails # message. But before we do that, we need to tell the server it's started (so it has the chance to # tell us "no, stop!" for any reason) - self.client.task_instances.start(ti.id, self.pid, datetime.now(tz=timezone.utc)) + ti_context = self.client.task_instances.start(ti.id, self.pid, datetime.now(tz=timezone.utc)) self._last_successful_heartbeat = time.monotonic() except Exception: # On any error kill that subprocess! 
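Together with the `comms.py` change above, the supervisor hunks below close the loop: the `TIRunContext` returned by `start()` is handed to the forked task process inside `StartupDetails`. A sketch of the message the supervisor now assembles, assuming the generated models shown earlier; every field value here is a placeholder:

```python
import uuid6  # same uuid7 helper the Task SDK tests use

from airflow.sdk.api.datamodels._generated import DagRun, TaskInstance, TIRunContext
from airflow.sdk.execution_time.comms import StartupDetails

ti = TaskInstance(id=uuid6.uuid7(), task_id="hello", dag_id="demo", run_id="manual_1", try_number=1)

# What client.task_instances.start() now returns; pydantic coerces the ISO
# strings and the "manual" literal into datetime / DagRunType values.
ti_context = TIRunContext(
    dag_run=DagRun(
        dag_id="demo",
        run_id="manual_1",
        logical_date="2024-12-01T01:00:00Z",
        start_date="2024-12-01T01:00:00Z",
        run_type="manual",
    )
)

# The context rides along to the forked process in the startup message.
msg = StartupDetails(ti=ti, file="/files/dags/demo.py", requests_fd=0, ti_context=ti_context)
```

Since `ti_context` has no default, a `StartupDetails` message can no longer be built before the server has acknowledged the QUEUED to RUNNING transition.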
@@ -408,6 +408,7 @@ def _on_child_started(self, ti: TaskInstance, path: str | os.PathLike[str], requ
             ti=ti,
             file=os.fspath(path),
             requests_fd=requests_fd,
+            ti_context=ti_context,
         )
 
         # Send the message to tell the process what it needs to execute
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index 5aca25f590e5e..92f400d46e2bb 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -23,13 +23,13 @@
 import sys
 from datetime import datetime, timezone
 from io import FileIO
-from typing import TYPE_CHECKING, Any, Generic, TextIO, TypeVar
+from typing import TYPE_CHECKING, Annotated, Any, Generic, TextIO, TypeVar
 
 import attrs
 import structlog
-from pydantic import BaseModel, ConfigDict, JsonValue, TypeAdapter
+from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter
 
-from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
+from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState, TIRunContext
 from airflow.sdk.definitions.baseoperator import BaseOperator
 from airflow.sdk.execution_time.comms import (
     DeferTask,
@@ -48,9 +48,13 @@ class RuntimeTaskInstance(TaskInstance):
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
     task: BaseOperator
+    _ti_context_from_server: Annotated[TIRunContext | None, Field(repr=False)] = None
+    """The Task Instance context from the API server, if any."""
 
     def get_template_context(self):
+        # TODO: Assess if we need to pass it through airflow.utils.timezone.coerce_datetime()
         context: dict[str, Any] = {
+            # From the Task Execution interface
             "dag": self.task.dag,
             "inlets": self.task.inlets,
             "map_index_template": self.task.map_index_template,
@@ -59,15 +63,9 @@ def get_template_context(self):
             "task": self.task,
             "task_instance": self,
             "ti": self,
-            # "dag_run": dag_run,
-            # "data_interval_end": timezone.coerce_datetime(data_interval.end),
-            # "data_interval_start": timezone.coerce_datetime(data_interval.start),
             # "outlet_events": OutletEventAccessors(),
-            # "ds": ds,
-            # "ds_nodash": ds_nodash,
             # "expanded_ti_count": expanded_ti_count,
             # "inlet_events": InletEventsAccessors(task.inlets, session=session),
-            # "logical_date": logical_date,
             # "macros": macros,
             # "params": validated_params,
             # "prev_data_interval_start_success": get_prev_data_interval_start_success(),
@@ -77,15 +75,36 @@ def get_template_context(self):
             # "task_instance_key_str": f"{task.dag_id}__{task.task_id}__{ds_nodash}",
             # "test_mode": task_instance.test_mode,
             # "triggering_asset_events": lazy_object_proxy.Proxy(get_triggering_events),
-            # "ts": ts,
-            # "ts_nodash": ts_nodash,
-            # "ts_nodash_with_tz": ts_nodash_with_tz,
             # "var": {
             #     "json": VariableAccessor(deserialize_json=True),
             #     "value": VariableAccessor(deserialize_json=False),
             # },
             # "conn": ConnectionAccessor(),
         }
+        if self._ti_context_from_server:
+            dag_run = self._ti_context_from_server.dag_run
+
+            logical_date = dag_run.logical_date
+            ds = logical_date.strftime("%Y-%m-%d")
+            ds_nodash = ds.replace("-", "")
+            ts = logical_date.isoformat()
+            ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S")
+            ts_nodash_with_tz = ts.replace("-", "").replace(":", "")
+
+            context_from_server = {
+                # TODO: Assess if we need to pass these through timezone.coerce_datetime
+                "dag_run": dag_run,
+                "data_interval_end": dag_run.data_interval_end,
+                "data_interval_start": dag_run.data_interval_start,
+                "logical_date": logical_date,
+                "ds": ds,
+                "ds_nodash": 
ds_nodash, + "task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{ds_nodash}", + "ts": ts, + "ts_nodash": ts_nodash, + "ts_nodash_with_tz": ts_nodash_with_tz, + } + context.update(context_from_server) return context @@ -113,7 +132,11 @@ def parse(what: StartupDetails) -> RuntimeTaskInstance: if not isinstance(task, BaseOperator): raise TypeError(f"task is of the wrong type, got {type(task)}, wanted {BaseOperator}") - return RuntimeTaskInstance.model_construct(**what.ti.model_dump(exclude_unset=True), task=task) + return RuntimeTaskInstance.model_construct( + **what.ti.model_dump(exclude_unset=True), + task=task, + _ti_context_from_server=what.ti_context, + ) SendMsgType = TypeVar("SendMsgType", bound=BaseModel) diff --git a/task_sdk/tests/api/test_client.py b/task_sdk/tests/api/test_client.py index d10531ba1bb4f..346c3adfcc137 100644 --- a/task_sdk/tests/api/test_client.py +++ b/task_sdk/tests/api/test_client.py @@ -94,23 +94,31 @@ class TestTaskInstanceOperations: response parsing. """ - def test_task_instance_start(self): + def test_task_instance_start(self, make_ti_context): # Simulate a successful response from the server that starts a task ti_id = uuid6.uuid7() + start_date = "2024-10-31T12:00:00Z" + ti_context = make_ti_context( + start_date=start_date, + logical_date="2024-10-31T12:00:00Z", + run_type="manual", + ) def handle_request(request: httpx.Request) -> httpx.Response: - if request.url.path == f"/task-instances/{ti_id}/state": + if request.url.path == f"/task-instances/{ti_id}/run": actual_body = json.loads(request.read()) assert actual_body["pid"] == 100 - assert actual_body["start_date"] == "2024-10-31T12:00:00Z" + assert actual_body["start_date"] == start_date assert actual_body["state"] == "running" return httpx.Response( - status_code=204, + status_code=200, + json=ti_context.model_dump(mode="json"), ) return httpx.Response(status_code=400, json={"detail": "Bad Request"}) client = make_client(transport=httpx.MockTransport(handle_request)) - client.task_instances.start(ti_id, 100, "2024-10-31T12:00:00Z") + resp = client.task_instances.start(ti_id, 100, start_date) + assert resp == ti_context @pytest.mark.parametrize("state", [state for state in TerminalTIState]) def test_task_instance_finish(self, state): diff --git a/task_sdk/tests/conftest.py b/task_sdk/tests/conftest.py index 04e94008842d3..25d0a1b0061b6 100644 --- a/task_sdk/tests/conftest.py +++ b/task_sdk/tests/conftest.py @@ -19,7 +19,7 @@ import logging import os from pathlib import Path -from typing import TYPE_CHECKING, NoReturn +from typing import TYPE_CHECKING, Any, NoReturn, Protocol import pytest @@ -29,8 +29,12 @@ os.environ["_AIRFLOW_SKIP_DB_TESTS"] = "true" if TYPE_CHECKING: + from datetime import datetime + from structlog.typing import EventDict, WrappedLogger + from airflow.sdk.api.datamodels._generated import TIRunContext + @pytest.hookimpl() def pytest_addhooks(pluginmanager: pytest.PytestPluginManager): @@ -116,7 +120,7 @@ def _disable_ol_plugin(): # The OpenLineage plugin imports setproctitle, and that now causes (C) level thread calls, which on Py # 3.12+ issues a warning when os.fork happens. 
So for this plugin we disable it
-    # And we load plugins when setting the priorty_weight field
+    # And we load plugins when setting the priority_weight field
     import airflow.plugins_manager
 
     old = airflow.plugins_manager.plugins
@@ -128,3 +132,85 @@ def _disable_ol_plugin():
     yield
 
     airflow.plugins_manager.plugins = None
+
+
+class MakeTIContextCallable(Protocol):
+    def __call__(
+        self,
+        dag_id: str = ...,
+        run_id: str = ...,
+        logical_date: str | datetime = ...,
+        data_interval_start: str | datetime = ...,
+        data_interval_end: str | datetime = ...,
+        start_date: str | datetime = ...,
+        run_type: str = ...,
+    ) -> TIRunContext: ...
+
+
+class MakeTIContextDictCallable(Protocol):
+    def __call__(
+        self,
+        dag_id: str = ...,
+        run_id: str = ...,
+        logical_date: str | datetime = ...,
+        data_interval_start: str | datetime = ...,
+        data_interval_end: str | datetime = ...,
+        start_date: str | datetime = ...,
+        run_type: str = ...,
+    ) -> dict[str, Any]: ...
+
+
+@pytest.fixture
+def make_ti_context() -> MakeTIContextCallable:
+    """Factory for creating TIRunContext objects."""
+    from airflow.sdk.api.datamodels._generated import DagRun, TIRunContext
+
+    def _make_context(
+        dag_id: str = "test_dag",
+        run_id: str = "test_run",
+        logical_date: str | datetime = "2024-12-01T01:00:00Z",
+        data_interval_start: str | datetime = "2024-12-01T00:00:00Z",
+        data_interval_end: str | datetime = "2024-12-01T01:00:00Z",
+        start_date: str | datetime = "2024-12-01T01:00:00Z",
+        run_type: str = "manual",
+    ) -> TIRunContext:
+        return TIRunContext(
+            dag_run=DagRun(
+                dag_id=dag_id,
+                run_id=run_id,
+                logical_date=logical_date,  # type: ignore
+                data_interval_start=data_interval_start,  # type: ignore
+                data_interval_end=data_interval_end,  # type: ignore
+                start_date=start_date,  # type: ignore
+                run_type=run_type,  # type: ignore
+            )
+        )
+
+    return _make_context
+
+
+@pytest.fixture
+def make_ti_context_dict(make_ti_context: MakeTIContextCallable) -> MakeTIContextDictCallable:
+    """Factory for creating context dictionaries suited for API Server response."""
+
+    def _make_context_dict(
+        dag_id: str = "test_dag",
+        run_id: str = "test_run",
+        logical_date: str | datetime = "2024-12-01T00:00:00Z",
+        data_interval_start: str | datetime = "2024-12-01T00:00:00Z",
+        data_interval_end: str | datetime = "2024-12-01T01:00:00Z",
+        start_date: str | datetime = "2024-12-01T00:00:00Z",
+        run_type: str = "manual",
+    ) -> dict[str, Any]:
+        context = make_ti_context(
+            dag_id=dag_id,
+            run_id=run_id,
+            logical_date=logical_date,
+            data_interval_start=data_interval_start,
+            data_interval_end=data_interval_end,
+            start_date=start_date,
+            run_type=run_type,
+        )
+        return context.model_dump(exclude_unset=True, mode="json")
+
+    return _make_context_dict
diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py
index 406b2ee26996f..70f9e26486408 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -254,7 +254,7 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine):
             try_number=1,
         )
         # Assert Exit Code is 0
-        assert supervise(ti=ti, dag_path=dagfile_path, token="", server="", dry_run=True) == 0
+        assert supervise(ti=ti, dag_path=dagfile_path, token="", server="", dry_run=True) == 0, captured_logs
 
         # We should have a log from the task!
assert { @@ -265,7 +265,9 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine): "timestamp": "2024-11-07T12:34:56.078901Z", } in captured_logs - def test_supervise_handles_deferred_task(self, test_dags_dir, captured_logs, time_machine, mocker): + def test_supervise_handles_deferred_task( + self, test_dags_dir, captured_logs, time_machine, mocker, make_ti_context + ): """ Test that the supervisor handles a deferred task correctly. @@ -281,12 +283,13 @@ def test_supervise_handles_deferred_task(self, test_dags_dir, captured_logs, tim # Create a mock client to assert calls to the client # We assume the implementation of the client is correct and only need to check the calls mock_client = mocker.Mock(spec=sdk_client.Client) + mock_client.task_instances.start.return_value = make_ti_context() instant = tz.datetime(2024, 11, 7, 12, 34, 56, 0) time_machine.move_to(instant, tick=False) # Assert supervisor runs the task successfully - assert supervise(ti=ti, dag_path=dagfile_path, token="", client=mock_client) == 0 + assert supervise(ti=ti, dag_path=dagfile_path, token="", client=mock_client) == 0, captured_logs # Validate calls to the client mock_client.task_instances.start.assert_called_once_with(ti.id, mocker.ANY, mocker.ANY) @@ -320,7 +323,7 @@ def test_supervisor_handles_already_running_task(self): # The API Server would return a 409 Conflict status code if the TI is not # in a "queued" state. def handle_request(request: httpx.Request) -> httpx.Response: - if request.url.path == f"/task-instances/{ti.id}/state": + if request.url.path == f"/task-instances/{ti.id}/run": return httpx.Response( 409, json={ @@ -345,7 +348,7 @@ def handle_request(request: httpx.Request) -> httpx.Response: } @pytest.mark.parametrize("captured_logs", [logging.ERROR], indirect=True, ids=["log_level=error"]) - def test_state_conflict_on_heartbeat(self, captured_logs, monkeypatch, mocker): + def test_state_conflict_on_heartbeat(self, captured_logs, monkeypatch, mocker, make_ti_context_dict): """ Test that ensures that the Supervisor does not cause the task to fail if the Task Instance is no longer in the running state. 
Instead, it logs the error and terminates the task process if it @@ -383,7 +386,9 @@ def handle_request(request: httpx.Request) -> httpx.Response: "current_state": "success", }, ) - # Return a 204 for all other requests like the initial call to mark the task as running + elif request.url.path == f"/task-instances/{ti_id}/run": + return httpx.Response(200, json=make_ti_context_dict()) + # Return a 204 for all other requests return httpx.Response(status_code=204) proc = WatchedSubprocess.start( diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index c9755c252bbe6..2b812c92a7338 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -94,7 +94,12 @@ def test_recv_StartupDetails(self): w.makefile("wb").write( b'{"type":"StartupDetails", "ti": {' b'"id": "4d828a62-a417-4936-a7a6-2b3fabacecab", "task_id": "a", "try_number": 1, "run_id": "b", "dag_id": "c" }, ' - b'"file": "/dev/null", "requests_fd": ' + str(w2.fileno()).encode("ascii") + b"}\n" + b'"ti_context":{"dag_run":{"dag_id":"c","run_id":"b","logical_date":"2024-12-01T01:00:00Z",' + b'"data_interval_start":"2024-12-01T00:00:00Z","data_interval_end":"2024-12-01T01:00:00Z",' + b'"start_date":"2024-12-01T01:00:00Z","end_date":null,"run_type":"manual","conf":null},' + b'"variables":null,"connections":null},"file": "/dev/null", "requests_fd": ' + + str(w2.fileno()).encode("ascii") + + b"}\n" ) decoder = CommsDecoder(input=r.makefile("r")) @@ -112,12 +117,13 @@ def test_recv_StartupDetails(self): assert decoder.request_socket.fileno() == w2.fileno() -def test_parse(test_dags_dir: Path): +def test_parse(test_dags_dir: Path, make_ti_context): """Test that checks parsing of a basic dag with an un-mocked parse.""" what = StartupDetails( ti=TaskInstance(id=uuid7(), task_id="a", dag_id="super_basic", run_id="c", try_number=1), file=str(test_dags_dir / "super_basic.py"), requests_fd=0, + ti_context=make_ti_context(), ) ti = parse(what) @@ -128,12 +134,13 @@ def test_parse(test_dags_dir: Path): assert isinstance(ti.task.dag, DAG) -def test_run_basic(time_machine, mocked_parse): +def test_run_basic(time_machine, mocked_parse, make_ti_context): """Test running a basic task.""" what = StartupDetails( ti=TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1), file="", requests_fd=0, + ti_context=make_ti_context(), ) instant = timezone.datetime(2024, 12, 3, 10, 0) @@ -150,7 +157,7 @@ def test_run_basic(time_machine, mocked_parse): ) -def test_run_deferred_basic(time_machine, mocked_parse): +def test_run_deferred_basic(time_machine, mocked_parse, make_ti_context): """Test that a task can transition to a deferred state.""" import datetime @@ -169,6 +176,7 @@ def test_run_deferred_basic(time_machine, mocked_parse): ti=TaskInstance(id=uuid7(), task_id="async", dag_id="basic_deferred_run", run_id="c", try_number=1), file="", requests_fd=0, + ti_context=make_ti_context(), ) # Expected DeferTask @@ -194,7 +202,7 @@ def test_run_deferred_basic(time_machine, mocked_parse): mock_supervisor_comms.send_request.assert_called_once_with(msg=expected_defer_task, log=mock.ANY) -def test_run_basic_skipped(time_machine, mocked_parse): +def test_run_basic_skipped(time_machine, mocked_parse, make_ti_context): """Test running a basic task that marks itself skipped.""" from airflow.providers.standard.operators.python import PythonOperator @@ -209,6 +217,7 @@ def test_run_basic_skipped(time_machine, mocked_parse): 
ti=TaskInstance(id=uuid7(), task_id="skip", dag_id="basic_skipped", run_id="c", try_number=1), file="", requests_fd=0, + ti_context=make_ti_context(), ) ti = mocked_parse(what, "basic_skipped", task) @@ -226,7 +235,7 @@ def test_run_basic_skipped(time_machine, mocked_parse): ) -def test_startup_basic_templated_dag(mocked_parse): +def test_startup_basic_templated_dag(mocked_parse, make_ti_context): """Test running a DAG with templated task.""" from airflow.providers.standard.operators.bash import BashOperator @@ -241,6 +250,7 @@ def test_startup_basic_templated_dag(mocked_parse): ), file="", requests_fd=0, + ti_context=make_ti_context(), ) mocked_parse(what, "basic_templated_dag", task) @@ -288,7 +298,9 @@ def test_startup_basic_templated_dag(mocked_parse): ), ], ) -def test_startup_dag_with_templated_fields(mocked_parse, task_params, expected_rendered_fields): +def test_startup_dag_with_templated_fields( + mocked_parse, task_params, expected_rendered_fields, make_ti_context +): """Test startup of a DAG with various templated fields.""" class CustomOperator(BaseOperator): @@ -305,6 +317,7 @@ def __init__(self, *args, **kwargs): ti=TaskInstance(id=uuid7(), task_id="templated_task", dag_id="basic_dag", run_id="c", try_number=1), file="", requests_fd=0, + ti_context=make_ti_context(), ) mocked_parse(what, "basic_dag", task) @@ -318,3 +331,73 @@ def __init__(self, *args, **kwargs): msg=SetRenderedFields(rendered_fields=expected_rendered_fields), log=mock.ANY, ) + + +class TestRuntimeTaskInstance: + def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_context): + """Test get_template_context without ti_context_from_server.""" + + task = BaseOperator(task_id="hello") + + ti_id = uuid7() + ti = TaskInstance( + id=ti_id, task_id=task.task_id, dag_id="basic_task", run_id="test_run", try_number=1 + ) + + what = StartupDetails(ti=ti, file="", requests_fd=0, ti_context=make_ti_context()) + runtime_ti = mocked_parse(what, ti.dag_id, task) + context = runtime_ti.get_template_context() + + # Verify the context keys and values + assert context == { + "dag": runtime_ti.task.dag, + "inlets": task.inlets, + "map_index_template": task.map_index_template, + "outlets": task.outlets, + "run_id": "test_run", + "task": task, + "task_instance": runtime_ti, + "ti": runtime_ti, + } + + def test_get_context_with_ti_context_from_server(self, mocked_parse, make_ti_context): + """Test the context keys are added when sent from API server (mocked)""" + from airflow.utils import timezone + + ti = TaskInstance(id=uuid7(), task_id="hello", dag_id="basic_task", run_id="test_run", try_number=1) + + task = BaseOperator(task_id=ti.task_id) + + ti_context = make_ti_context(dag_id=ti.dag_id, run_id=ti.run_id) + what = StartupDetails(ti=ti, file="", requests_fd=0, ti_context=ti_context) + + runtime_ti = mocked_parse(what, ti.dag_id, task) + + # Assume the context is sent from the API server + # `task_sdk/tests/api/test_client.py::test_task_instance_start` checks the context is received + # from the API server + runtime_ti._ti_context_from_server = ti_context + dr = ti_context.dag_run + + context = runtime_ti.get_template_context() + + assert context == { + "dag": runtime_ti.task.dag, + "inlets": task.inlets, + "map_index_template": task.map_index_template, + "outlets": task.outlets, + "run_id": "test_run", + "task": task, + "task_instance": runtime_ti, + "ti": runtime_ti, + "dag_run": dr, + "data_interval_end": timezone.datetime(2024, 12, 1, 1, 0, 0), + "data_interval_start": timezone.datetime(2024, 12, 1, 
0, 0, 0), + "logical_date": timezone.datetime(2024, 12, 1, 1, 0, 0), + "ds": "2024-12-01", + "ds_nodash": "20241201", + "task_instance_key_str": "basic_task__hello__20241201", + "ts": "2024-12-01T01:00:00+00:00", + "ts_nodash": "20241201T010000", + "ts_nodash_with_tz": "20241201T010000+0000", + } diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py b/tests/api_fastapi/execution_api/routes/test_task_instances.py index 15e56bbc58710..e67d82a718cd6 100644 --- a/tests/api_fastapi/execution_api/routes/test_task_instances.py +++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py @@ -39,40 +39,58 @@ DEFAULT_END_DATE = timezone.parse("2024-10-31T12:00:00Z") -class TestTIUpdateState: +class TestTIRunState: def setup_method(self): clear_db_runs() def teardown_method(self): clear_db_runs() - def test_ti_update_state_to_running(self, client, session, create_task_instance): + def test_ti_run_state_to_running(self, client, session, create_task_instance, time_machine): """ Test that the Task Instance state is updated to running when the Task Instance is in a state where it can be marked as running. """ + instant_str = "2024-09-30T12:00:00Z" + instant = timezone.parse(instant_str) + time_machine.move_to(instant, tick=False) ti = create_task_instance( - task_id="test_ti_update_state_to_running", + task_id="test_ti_run_state_to_running", state=State.QUEUED, session=session, + start_date=instant, ) session.commit() response = client.patch( - f"/execution/task-instances/{ti.id}/state", + f"/execution/task-instances/{ti.id}/run", json={ "state": "running", "hostname": "random-hostname", "unixname": "random-unixname", "pid": 100, - "start_date": "2024-10-31T12:00:00Z", + "start_date": instant_str, }, ) - assert response.status_code == 204 - assert response.text == "" + assert response.status_code == 200 + assert response.json() == { + "dag_run": { + "dag_id": "dag", + "run_id": "test", + "logical_date": instant_str, + "data_interval_start": instant.subtract(days=1).to_iso8601_string(), + "data_interval_end": instant_str, + "start_date": instant_str, + "end_date": None, + "run_type": "manual", + "conf": {}, + }, + "variables": [], + "connections": [], + } # Refresh the Task Instance from the database so that we can check the updated values session.refresh(ti) @@ -80,10 +98,10 @@ def test_ti_update_state_to_running(self, client, session, create_task_instance) assert ti.hostname == "random-hostname" assert ti.unixname == "random-unixname" assert ti.pid == 100 - assert ti.start_date.isoformat() == "2024-10-31T12:00:00+00:00" + assert ti.start_date == instant @pytest.mark.parametrize("initial_ti_state", [s for s in TaskInstanceState if s != State.QUEUED]) - def test_ti_update_state_conflict_if_not_queued( + def test_ti_run_state_conflict_if_not_queued( self, client, session, create_task_instance, initial_ti_state ): """ @@ -91,13 +109,13 @@ def test_ti_update_state_conflict_if_not_queued( running. In this case, the Task Instance is first in NONE state so it cannot be marked as running. 
""" ti = create_task_instance( - task_id="test_ti_update_state_conflict_if_not_queued", + task_id="test_ti_run_state_conflict_if_not_queued", state=initial_ti_state, ) session.commit() response = client.patch( - f"/execution/task-instances/{ti.id}/state", + f"/execution/task-instances/{ti.id}/run", json={ "state": "running", "hostname": "random-hostname", @@ -118,6 +136,14 @@ def test_ti_update_state_conflict_if_not_queued( assert session.scalar(select(TaskInstance.state).where(TaskInstance.id == ti.id)) == initial_ti_state + +class TestTIUpdateState: + def setup_method(self): + clear_db_runs() + + def teardown_method(self): + clear_db_runs() + @pytest.mark.parametrize( ("state", "end_date", "expected_state"), [ @@ -160,7 +186,7 @@ def test_ti_update_state_not_found(self, client, session): task_instance_id = "0182e924-0f1e-77e6-ab50-e977118bc139" # Pre-condition: the Task Instance does not exist - assert session.scalar(select(TaskInstance.id).where(TaskInstance.id == task_instance_id)) is None + assert session.get(TaskInstance, task_instance_id) is None payload = {"state": "success", "end_date": "2024-10-31T12:30:00Z"} @@ -171,6 +197,26 @@ def test_ti_update_state_not_found(self, client, session): "message": "Task Instance not found", } + def test_ti_update_state_running_errors(self, client, session, create_task_instance, time_machine): + """ + Test that a 422 error is returned when the Task Instance state is RUNNING in the payload. + + Task should be set to Running state via the /execution/task-instances/{task_instance_id}/run endpoint. + """ + + ti = create_task_instance( + task_id="test_ti_update_state_running_errors", + state=State.QUEUED, + session=session, + start_date=DEFAULT_START_DATE, + ) + + session.commit() + + response = client.patch(f"/execution/task-instances/{ti.id}/state", json={"state": "running"}) + + assert response.status_code == 422 + def test_ti_update_state_database_error(self, client, session, create_task_instance): """ Test that a database error is handled correctly when updating the Task Instance state. @@ -181,17 +227,14 @@ def test_ti_update_state_database_error(self, client, session, create_task_insta ) session.commit() payload = { - "state": "running", - "hostname": "random-hostname", - "unixname": "random-unixname", - "pid": 100, - "start_date": "2024-10-31T12:00:00Z", + "state": "success", + "end_date": "2024-10-31T12:00:00Z", } with mock.patch( "airflow.api_fastapi.common.db.common.Session.execute", side_effect=[ - mock.Mock(one=lambda: ("queued",)), # First call returns "queued" + mock.Mock(one=lambda: ("running",)), # First call returns "queued" SQLAlchemyError("Database error"), # Second call raises an error ], ): @@ -334,7 +377,7 @@ def test_ti_heartbeat_non_existent_task(self, client, session, create_task_insta task_instance_id = "0182e924-0f1e-77e6-ab50-e977118bc139" # Pre-condition: the Task Instance does not exist - assert session.scalar(select(TaskInstance.id).where(TaskInstance.id == task_instance_id)) is None + assert session.get(TaskInstance, task_instance_id) is None response = client.put( f"/execution/task-instances/{task_instance_id}/heartbeat",