Skip to content

Commit

Permalink
Add set note functionality, update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
omkar-foss committed Nov 21, 2024
1 parent 4795c73 commit addcfe1
Show file tree
Hide file tree
Showing 8 changed files with 289 additions and 104 deletions.
35 changes: 22 additions & 13 deletions airflow/api_fastapi/core_api/datamodels/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@
from __future__ import annotations

from datetime import datetime
from typing import Annotated
from typing import Annotated, Any

from pydantic import (
AliasChoices,
AliasPath,
AwareDatetime,
BaseModel,
Expand All @@ -29,6 +28,7 @@
Field,
NonNegativeInt,
field_validator,
model_validator,
)

from airflow.api_fastapi.core_api.datamodels.job import JobResponse
Expand Down Expand Up @@ -154,25 +154,34 @@ class TaskInstanceHistoryCollectionResponse(BaseModel):
total_entries: int


class TaskInstanceReferenceResponse(BaseModel):
"""Task Instance Reference serializer for responses."""

task_id: str
dag_run_id: str = Field(validation_alias=AliasChoices("run_id"))
dag_id: str
logical_date: datetime


class PatchTaskInstanceBody(BaseModel):
"""Request body for Clear Task Instances endpoint."""

dry_run: bool = True
new_state: str
new_state: str | None = None
note: str | None = None

@model_validator(mode="before")
@classmethod
def validate_model(cls, data: Any) -> Any:
if data.get("note") is None and data.get("new_state") is None:
raise ValueError("new_state is required.")
return data

@field_validator("note", mode="before")
@classmethod
def validate_note(cls, note: str | None) -> str | None:
"""Validate note."""
if note is None:
return None
if len(note) > 1000:
raise ValueError("Note length should not exceed 1000 characters.")
return note

@field_validator("new_state", mode="before")
@classmethod
def validate_new_state(cls, ns: str) -> str:
"""Convert timezone attribute to string representation."""
"""Validate new_state."""
valid_states = [
vs.name.lower()
for vs in (TaskInstanceState.SUCCESS, TaskInstanceState.FAILED, TaskInstanceState.SKIPPED)
Expand Down
50 changes: 22 additions & 28 deletions airflow/api_fastapi/core_api/openapi/v1-generated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3434,7 +3434,7 @@ paths:
content:
application/json:
schema:
$ref: '#/components/schemas/TaskInstanceReferenceResponse'
$ref: '#/components/schemas/TaskInstanceResponse'
'401':
content:
application/json:
Expand All @@ -3447,6 +3447,12 @@ paths:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Forbidden
'400':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Bad Request
'404':
content:
application/json:
Expand Down Expand Up @@ -3892,7 +3898,7 @@ paths:
content:
application/json:
schema:
$ref: '#/components/schemas/TaskInstanceReferenceResponse'
$ref: '#/components/schemas/TaskInstanceResponse'
'401':
content:
application/json:
Expand All @@ -3905,6 +3911,12 @@ paths:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Forbidden
'400':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Bad Request
'404':
content:
application/json:
Expand Down Expand Up @@ -6553,11 +6565,16 @@ components:
title: Dry Run
default: true
new_state:
type: string
anyOf:
- type: string
- type: 'null'
title: New State
note:
anyOf:
- type: string
- type: 'null'
title: Note
type: object
required:
- new_state
title: PatchTaskInstanceBody
description: Request body for Clear Task Instances endpoint.
PluginCollectionResponse:
Expand Down Expand Up @@ -7038,29 +7055,6 @@ components:
- executor_config
title: TaskInstanceHistoryResponse
description: TaskInstanceHistory serializer for responses.
TaskInstanceReferenceResponse:
properties:
task_id:
type: string
title: Task Id
dag_run_id:
type: string
title: Dag Run Id
dag_id:
type: string
title: Dag Id
logical_date:
type: string
format: date-time
title: Logical Date
type: object
required:
- task_id
- dag_run_id
- dag_id
- logical_date
title: TaskInstanceReferenceResponse
description: Task Instance Reference serializer for responses.
TaskInstanceResponse:
properties:
id:
Expand Down
50 changes: 38 additions & 12 deletions airflow/api_fastapi/core_api/routes/public/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
from typing import Annotated, Literal

from fastapi import Depends, HTTPException, Request, status
from sqlalchemy import or_, select
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.orm import Session, joinedload
from sqlalchemy.sql import select

from airflow.api_fastapi.common.db.common import get_session, paginated_select
from airflow.api_fastapi.common.parameters import (
Expand Down Expand Up @@ -52,7 +53,6 @@
TaskDependencyCollectionResponse,
TaskInstanceCollectionResponse,
TaskInstanceHistoryResponse,
TaskInstanceReferenceResponse,
TaskInstanceResponse,
TaskInstancesBatchBody,
)
Expand Down Expand Up @@ -488,11 +488,11 @@ def get_mapped_task_instance_try_details(

@task_instances_router.patch(
"/{task_id}",
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST]),
)
@task_instances_router.patch(
"/{task_id}/{map_index}",
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST]),
)
def patch_task_instance(
dag_id: str,
Expand All @@ -502,7 +502,7 @@ def patch_task_instance(
body: PatchTaskInstanceBody,
session: Annotated[Session, Depends(get_session)],
map_index: int = -1,
) -> TaskInstanceReferenceResponse:
) -> TaskInstanceResponse:
"""Update the state of a task instance."""
dag = request.app.state.dag_bag.get_dag(dag_id)
if not dag:
Expand All @@ -511,13 +511,27 @@ def patch_task_instance(
if not dag.has_task(task_id):
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Task '{task_id}' not found in DAG '{dag_id}'")

ti: TI | None = session.scalars(
select(TI).where(
TI.dag_id == dag_id, TI.task_id == task_id, TI.run_id == dag_run_id, TI.map_index == map_index
query = (
select(TI)
.where(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == task_id)
.join(TI.dag_run)
.options(joinedload(TI.rendered_task_instance_fields))
)
if map_index == -1:
query = query.where(or_(TI.map_index == -1, TI.map_index is None))
else:
query = query.where(TI.map_index == map_index)

try:
ti = session.scalar(query)
except MultipleResultsFound:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
"Multiple task instances found. As the TI is mapped, add the map_index value to the URL",
)
).one_or_none()
err_msg_404 = f"Task instance not found for task '{task_id}' on DAG run with ID '{dag_run_id}'"
if not ti:

err_msg_404 = f"Task Instance not found for dag_id={dag_id}, run_id={dag_run_id}, task_id={task_id}"
if ti is None:
raise HTTPException(status.HTTP_404_NOT_FOUND, err_msg_404)

if not body.dry_run:
Expand All @@ -532,4 +546,16 @@ def patch_task_instance(
if not ti:
raise HTTPException(status.HTTP_404_NOT_FOUND, err_msg_404)
ti = tis[0] if isinstance(tis, list) else tis
return TaskInstanceReferenceResponse.model_validate(ti, from_attributes=True)

# Set new note to the task instance if available in body.
if body.note is not None:
# @TODO: replace None passed for user_id with actual user id when
# permissions and auth is in place.
if ti.task_instance_note is None:
ti.note = (body.note, None)
else:
ti.task_instance_note.content = body.note
ti.task_instance_note.user_id = None
session.commit()

return TaskInstanceResponse.model_validate(ti, from_attributes=True)
4 changes: 2 additions & 2 deletions airflow/ui/openapi-gen/queries/queries.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2851,7 +2851,7 @@ export const usePoolServicePatchPool = <
* @param data.taskId
* @param data.requestBody
* @param data.mapIndex
* @returns TaskInstanceReferenceResponse Successful Response
* @returns TaskInstanceResponse Successful Response
* @throws ApiError
*/
export const useTaskInstanceServicePatchTaskInstance = <
Expand Down Expand Up @@ -2906,7 +2906,7 @@ export const useTaskInstanceServicePatchTaskInstance = <
* @param data.taskId
* @param data.mapIndex
* @param data.requestBody
* @returns TaskInstanceReferenceResponse Successful Response
* @returns TaskInstanceResponse Successful Response
* @throws ApiError
*/
export const useTaskInstanceServicePatchTaskInstance1 = <
Expand Down
47 changes: 19 additions & 28 deletions airflow/ui/openapi-gen/requests/schemas.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2660,12 +2660,29 @@ export const $PatchTaskInstanceBody = {
default: true,
},
new_state: {
type: "string",
anyOf: [
{
type: "string",
},
{
type: "null",
},
],
title: "New State",
},
note: {
anyOf: [
{
type: "string",
},
{
type: "null",
},
],
title: "Note",
},
},
type: "object",
required: ["new_state"],
title: "PatchTaskInstanceBody",
description: "Request body for Clear Task Instances endpoint.",
} as const;
Expand Down Expand Up @@ -3365,32 +3382,6 @@ export const $TaskInstanceHistoryResponse = {
description: "TaskInstanceHistory serializer for responses.",
} as const;

export const $TaskInstanceReferenceResponse = {
properties: {
task_id: {
type: "string",
title: "Task Id",
},
dag_run_id: {
type: "string",
title: "Dag Run Id",
},
dag_id: {
type: "string",
title: "Dag Id",
},
logical_date: {
type: "string",
format: "date-time",
title: "Logical Date",
},
},
type: "object",
required: ["task_id", "dag_run_id", "dag_id", "logical_date"],
title: "TaskInstanceReferenceResponse",
description: "Task Instance Reference serializer for responses.",
} as const;

export const $TaskInstanceResponse = {
properties: {
id: {
Expand Down
6 changes: 4 additions & 2 deletions airflow/ui/openapi-gen/requests/services.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1869,7 +1869,7 @@ export class TaskInstanceService {
* @param data.taskId
* @param data.requestBody
* @param data.mapIndex
* @returns TaskInstanceReferenceResponse Successful Response
* @returns TaskInstanceResponse Successful Response
* @throws ApiError
*/
public static patchTaskInstance(
Expand All @@ -1889,6 +1889,7 @@ export class TaskInstanceService {
body: data.requestBody,
mediaType: "application/json",
errors: {
400: "Bad Request",
401: "Unauthorized",
403: "Forbidden",
404: "Not Found",
Expand Down Expand Up @@ -2070,7 +2071,7 @@ export class TaskInstanceService {
* @param data.taskId
* @param data.mapIndex
* @param data.requestBody
* @returns TaskInstanceReferenceResponse Successful Response
* @returns TaskInstanceResponse Successful Response
* @throws ApiError
*/
public static patchTaskInstance1(
Expand All @@ -2088,6 +2089,7 @@ export class TaskInstanceService {
body: data.requestBody,
mediaType: "application/json",
errors: {
400: "Bad Request",
401: "Unauthorized",
403: "Forbidden",
404: "Not Found",
Expand Down
Loading

0 comments on commit addcfe1

Please sign in to comment.