Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-72: Handling deferrable tasks in the Execution API as well as the Task SDK #44241

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion airflow/api_fastapi/execution_api/datamodels/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import annotations

import uuid
from typing import Annotated, Literal, Union
from typing import Annotated, Any, Literal, Union

from pydantic import BaseModel, ConfigDict, Discriminator, Tag, WithJsonSchema

Expand Down Expand Up @@ -61,6 +61,29 @@ class TITargetStatePayload(BaseModel):
state: IntermediateTIState


class TIDeferredStatePayload(BaseModel):
    """
    Schema for updating TaskInstance to a deferred state.

    Besides the target ``state``, the payload carries the data used to create
    the trigger row (``classpath``, ``kwargs``, ``created_date``) and the
    values recorded on the task instance for when it resumes
    (``next_method``, ``timeout``).
    """

    # Only the literal deferred state is accepted by this model.
    state: Annotated[
        Literal[IntermediateTIState.DEFERRED],
        # Specify a default in the schema, but not in code, so Pydantic marks it as required.
        WithJsonSchema(
            {
                "type": "string",
                "enum": [IntermediateTIState.DEFERRED],
                "default": IntermediateTIState.DEFERRED,
            }
        ),
    ]

    # Stored as the ``classpath`` of the Trigger row created by the state-update route.
    # Presumably the import path of the trigger class -- confirm against the client.
    classpath: str
    # Passed to the Trigger row as its kwargs; the route also copies it to
    # ``TaskInstance.next_kwargs`` (falling back to ``{}`` when empty).
    kwargs: dict[str, Any]
    # Stored as the Trigger row's ``created_date``.
    # NOTE(review): whether the client should send this at all, or the server
    # should stamp it on receipt, is still an open review question.
    created_date: UtcDateTime
    # Recorded on the task instance as ``next_method``.
    next_method: str
    # TODO: need to serialise to datetime.timedelta
    # Nullable, but declared without a default. The unit is not settled yet:
    # the route currently does ``timedelta(days=int(timeout))``, which review
    # has flagged as incorrect, so do not rely on that interpretation.
    timeout: str | None


def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str:
"""
Determine the discriminator key for TaskInstance state transitions.
Expand All @@ -78,6 +101,8 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str:
return str(state)
elif state in set(TerminalTIState):
return "_terminal_"
elif state == "deferred":
return "deferred"
return "_other_"


Expand All @@ -88,6 +113,7 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str:
Annotated[TIEnterRunningPayload, Tag("running")],
Annotated[TITerminalStatePayload, Tag("_terminal_")],
Annotated[TITargetStatePayload, Tag("_other_")],
Annotated[TIDeferredStatePayload, Tag("deferred")],
],
Discriminator(ti_state_discriminator),
]
Expand Down
54 changes: 52 additions & 2 deletions airflow/api_fastapi/execution_api/routes/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from __future__ import annotations

import logging
from typing import Annotated
from datetime import timedelta
from typing import TYPE_CHECKING, Annotated
from uuid import UUID

from fastapi import Body, Depends, HTTPException, status
Expand All @@ -30,14 +31,16 @@
from airflow.api_fastapi.common.db.common import get_session
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
TIDeferredStatePayload,
TIEnterRunningPayload,
TIHeartbeatInfo,
TIStateUpdate,
TITerminalStatePayload,
)
from airflow.models import Trigger
from airflow.models.taskinstance import TaskInstance as TI
from airflow.utils import timezone
from airflow.utils.state import State
from airflow.utils.state import State, TaskInstanceState

# TODO: Add dependency on JWT token
router = AirflowRouter()
Expand Down Expand Up @@ -122,6 +125,53 @@ def ti_update_state(
)
elif isinstance(ti_patch_payload, TITerminalStatePayload):
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
elif isinstance(ti_patch_payload, TIDeferredStatePayload):
trigger_row = Trigger(
classpath=ti_patch_payload.classpath,
kwargs=ti_patch_payload.kwargs,
created_date=ti_patch_payload.created_date,
)
session.add(trigger_row)
session.flush()

ti = session.query(TI).filter(TI.id == ti_id_str).one_or_none()

if not ti:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"message": f"TaskInstance with id {ti_id_str} not found.",
},
)

ti.state = TaskInstanceState.DEFERRED
ti.trigger_id = trigger_row.id
ti.next_method = ti_patch_payload.next_method
ti.next_kwargs = ti_patch_payload.kwargs or {}
# handle properly based on client
timeout = ti_patch_payload.timeout
# Calculate timeout too if it was passed
if timeout is not None:
ti.trigger_timeout = timezone.utcnow() + timedelta(days=int(timeout))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Timeout shouldn't be days!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to handle this better based on what I can send from the client; this isn't right yet.

else:
ti.trigger_timeout = None

# If an execution_timeout is set, set the timeout to the minimum of
# it and the trigger timeout
if ti.task:
execution_timeout = ti.task.execution_timeout
if execution_timeout:
if TYPE_CHECKING:
assert ti.start_date
if ti.trigger_timeout:
ti.trigger_timeout = min(ti.start_date + execution_timeout, ti.trigger_timeout)
else:
ti.trigger_timeout = ti.start_date + execution_timeout

session.commit()

log.info("TI %s state updated to: deferred", ti_id_str)
return

# TODO: Replace this with FastAPI's Custom Exception handling:
# https://fastapi.tiangolo.com/tutorial/handling-errors/#install-custom-exception-handlers
Expand Down
47 changes: 46 additions & 1 deletion tests/api_fastapi/execution_api/routes/test_task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@

from __future__ import annotations

from datetime import datetime
from unittest import mock

import pytest
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError

from airflow.models import Trigger
from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
from airflow.utils.state import State
from airflow.utils.state import State, TaskInstanceState

from tests_common.test_utils.db import clear_db_runs

Expand Down Expand Up @@ -193,6 +195,49 @@ def test_ti_update_state_database_error(self, client, session, create_task_insta
assert response.status_code == 500
assert response.json()["detail"] == "Database error occurred"

def test_ti_update_state_to_deferred(self, client, session, create_task_instance):
"""
Test that tests if the transition to deferred state is handled correctly.
"""

ti = create_task_instance(
task_id="test_ti_update_state_to_deferred",
state=State.RUNNING,
session=session,
)
session.commit()

payload = {
"state": "deferred",
"classpath": "my-class-path",
"kwargs": {},
"created_date": "2024-10-31T12:00:00Z",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, why is this being sent in the payload? Should it be the time the server received the request instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is handled this way:

self.created_date = created_date or timezone.utcnow()

If we do not pass it, it defaults to the current UTC time. Should we allow overriding it, or remove it entirely?

"next_method": "execute_callback",
"timeout": None,
}

response = client.patch(f"/execution/task-instances/{ti.id}/state", json=payload)

assert response.status_code == 204
assert response.text == ""

session.expire_all()

t = session.query(Trigger).all()
assert len(t) == 1
assert t[0].created_date == datetime(2024, 10, 31, 12, 0, tzinfo=timezone.utc)
assert t[0].classpath == "my-class-path"
assert t[0].kwargs == {}

tis = session.query(TaskInstance).all()
assert len(tis) == 1

assert tis[0].state == TaskInstanceState.DEFERRED
assert tis[0].trigger_id == t[0].id
assert tis[0].next_method == "execute_callback"
assert tis[0].next_kwargs == {}
assert tis[0].trigger_timeout is None


class TestTIHealthEndpoint:
def setup_method(self):
Expand Down