Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate public endpoint Set Task Instances State to FastAPI #44246

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,7 @@ def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) ->
)


@mark_fastapi_migration_done
@security.requires_access_dag("PUT", DagAccessEntity.TASK_INSTANCE)
@action_logging
@provide_session
Expand Down
42 changes: 42 additions & 0 deletions airflow/api_fastapi/core_api/datamodels/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
ConfigDict,
Field,
NonNegativeInt,
field_validator,
)

from airflow.api_fastapi.core_api.datamodels.job import JobResponse
Expand Down Expand Up @@ -150,3 +151,44 @@ class TaskInstanceHistoryCollectionResponse(BaseModel):

task_instances: list[TaskInstanceHistoryResponse]
total_entries: int


class SetTaskInstancesStateBody(BaseModel):
"""Request body for Set Task Instances State endpoint."""

dry_run: bool = True
task_id: str
dag_run_id: str
include_upstream: bool
include_downstream: bool
include_future: bool
include_past: bool
omkar-foss marked this conversation as resolved.
Show resolved Hide resolved
new_state: str
omkar-foss marked this conversation as resolved.
Show resolved Hide resolved

@field_validator("new_state", mode="before")
@classmethod
def validate_new_state(cls, ns: str) -> str:
"""Validate new_state."""
valid_states = [
vs.name.lower()
for vs in (TaskInstanceState.SUCCESS, TaskInstanceState.FAILED, TaskInstanceState.SKIPPED)
]
ns = ns.lower()
if ns not in valid_states:
raise ValueError(f"'{ns}' is not one of {valid_states}")
return ns


class TaskInstanceReferenceResponse(BaseModel):
"""Task Instance Reference serializer for responses."""

task_id: str
dag_run_id: str = Field(validation_alias="run_id")
dag_id: str
logical_date: datetime
omkar-foss marked this conversation as resolved.
Show resolved Hide resolved


class TaskInstanceReferenceCollectionResponse(BaseModel):
"""Task Instance Reference collection serializer for responses."""

task_instances: list[TaskInstanceReferenceResponse]
130 changes: 130 additions & 0 deletions airflow/api_fastapi/core_api/openapi/v1-generated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4181,6 +4181,63 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
/public/dags/{dag_id}/updateTaskInstancesState:
put:
tags:
- Task Instance
summary: Set Task Instances State
description: Set a state of task instances.
operationId: set_task_instances_state
parameters:
- name: dag_id
in: path
required: true
schema:
type: string
title: Dag Id
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/SetTaskInstancesStateBody'
responses:
'200':
description: Successful Response
content:
application/json:
schema:
$ref: '#/components/schemas/TaskInstanceReferenceCollectionResponse'
'401':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Unauthorized
'403':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Forbidden
'400':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Bad Request
'404':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Not Found
'422':
description: Validation Error
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
/public/dags/{dag_id}/tasks:
get:
tags:
Expand Down Expand Up @@ -6709,6 +6766,44 @@ components:
- latest_scheduler_heartbeat
title: SchedulerInfoSchema
description: Schema for Scheduler info.
SetTaskInstancesStateBody:
properties:
dry_run:
type: boolean
title: Dry Run
default: true
task_id:
type: string
title: Task Id
dag_run_id:
type: string
title: Dag Run Id
include_upstream:
type: boolean
title: Include Upstream
include_downstream:
type: boolean
title: Include Downstream
include_future:
type: boolean
title: Include Future
include_past:
type: boolean
title: Include Past
new_state:
type: string
title: New State
type: object
required:
- task_id
- dag_run_id
- include_upstream
- include_downstream
- include_future
- include_past
- new_state
title: SetTaskInstancesStateBody
description: Request body for Set Task Instances State endpoint.
TaskCollectionResponse:
properties:
tasks:
Expand Down Expand Up @@ -6887,6 +6982,41 @@ components:
- executor_config
title: TaskInstanceHistoryResponse
description: TaskInstanceHistory serializer for responses.
TaskInstanceReferenceCollectionResponse:
properties:
task_instances:
items:
$ref: '#/components/schemas/TaskInstanceReferenceResponse'
type: array
title: Task Instances
type: object
required:
- task_instances
title: TaskInstanceReferenceCollectionResponse
description: Task Instance Reference collection serializer for responses.
TaskInstanceReferenceResponse:
properties:
task_id:
type: string
title: Task Id
dag_run_id:
type: string
title: Dag Run Id
dag_id:
type: string
title: Dag Id
logical_date:
type: string
format: date-time
title: Logical Date
type: object
required:
- task_id
- dag_run_id
- dag_id
- logical_date
title: TaskInstanceReferenceResponse
description: Task Instance Reference serializer for responses.
TaskInstanceResponse:
properties:
id:
Expand Down
81 changes: 69 additions & 12 deletions airflow/api_fastapi/core_api/routes/public/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,12 @@
)
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.datamodels.task_instances import (
SetTaskInstancesStateBody,
TaskDependencyCollectionResponse,
TaskInstanceCollectionResponse,
TaskInstanceHistoryResponse,
TaskInstanceReferenceCollectionResponse,
TaskInstanceReferenceResponse,
TaskInstanceResponse,
TaskInstancesBatchBody,
)
Expand All @@ -64,13 +67,12 @@
from airflow.utils.db import get_query_count
from airflow.utils.state import TaskInstanceState

task_instances_router = AirflowRouter(
tags=["Task Instance"], prefix="/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances"
)
task_instances_router = AirflowRouter(tags=["Task Instance"], prefix="/dags/{dag_id}")
task_instances_prefix = "/dagRuns/{dag_run_id}/taskInstances"
omkar-foss marked this conversation as resolved.
Show resolved Hide resolved


@task_instances_router.get(
"/{task_id}",
task_instances_prefix + "/{task_id}",
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
)
def get_task_instance(
Expand Down Expand Up @@ -99,7 +101,7 @@ def get_task_instance(


@task_instances_router.get(
"/{task_id}/listMapped",
task_instances_prefix + "/{task_id}/listMapped",
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
)
def get_mapped_task_instances(
Expand Down Expand Up @@ -182,11 +184,11 @@ def get_mapped_task_instances(


@task_instances_router.get(
"/{task_id}/dependencies",
task_instances_prefix + "/{task_id}/dependencies",
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
)
@task_instances_router.get(
"/{task_id}/{map_index}/dependencies",
task_instances_prefix + "/{task_id}/{map_index}/dependencies",
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
)
def get_task_instance_dependencies(
Expand Down Expand Up @@ -236,7 +238,7 @@ def get_task_instance_dependencies(


@task_instances_router.get(
"/{task_id}/{map_index}",
task_instances_prefix + "/{task_id}/{map_index}",
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
)
def get_mapped_task_instance(
Expand Down Expand Up @@ -265,7 +267,7 @@ def get_mapped_task_instance(


@task_instances_router.get(
"",
task_instances_prefix,
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
)
def get_task_instances(
Expand Down Expand Up @@ -348,7 +350,7 @@ def get_task_instances(


@task_instances_router.post(
"/list",
task_instances_prefix + "/list",
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
)
def get_task_instances_batch(
Expand Down Expand Up @@ -428,7 +430,7 @@ def get_task_instances_batch(


@task_instances_router.get(
"/{task_id}/tries/{task_try_number}",
task_instances_prefix + "/{task_id}/tries/{task_try_number}",
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
)
def get_task_instance_try_details(
Expand Down Expand Up @@ -463,7 +465,7 @@ def _query(orm_object: Base) -> TI | TIH | None:


@task_instances_router.get(
"/{task_id}/{map_index}/tries/{task_try_number}",
task_instances_prefix + "/{task_id}/{map_index}/tries/{task_try_number}",
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
)
def get_mapped_task_instance_try_details(
Expand All @@ -482,3 +484,58 @@ def get_mapped_task_instance_try_details(
map_index=map_index,
session=session,
)


@task_instances_router.put(
"/updateTaskInstancesState",
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST]),
)
def set_task_instances_state(
omkar-foss marked this conversation as resolved.
Show resolved Hide resolved
dag_id: str,
request: Request,
body: SetTaskInstancesStateBody,
session: Annotated[Session, Depends(get_session)],
) -> TaskInstanceReferenceCollectionResponse:
"""Set a state of task instances."""
error_message = f"Dag ID {dag_id} not found"
dag = request.app.state.dag_bag.get_dag(dag_id)
if not dag:
raise HTTPException(status.HTTP_404_NOT_FOUND, error_message)

task_id = body.task_id
task = dag.task_dict.get(task_id)
if not task:
error_message = f"Task ID {task_id} not found"
raise HTTPException(status.HTTP_404_NOT_FOUND, error_message)

run_id = body.dag_run_id
error_message = f"Task instance not found for task '{task_id}' on DAG run with ID '{run_id}'"
if not run_id:
raise HTTPException(status.HTTP_404_NOT_FOUND, error_message)

select_stmt = select(TI).where(
TI.dag_id == dag_id, TI.task_id == task_id, TI.run_id == run_id, TI.map_index == -1
)
if run_id and not session.scalars(select_stmt).one_or_none():
raise HTTPException(status.HTTP_404_NOT_FOUND, error_message)

tis = dag.set_task_instance_state(
task_id=task_id,
run_id=run_id,
state=body.new_state,
upstream=body.include_upstream,
downstream=body.include_downstream,
future=body.include_future,
past=body.include_past,
commit=not body.dry_run,
session=session,
)
return TaskInstanceReferenceCollectionResponse(
task_instances=[
TaskInstanceReferenceResponse.model_validate(
ti,
from_attributes=True,
)
for ti in tis
]
)
3 changes: 3 additions & 0 deletions airflow/ui/openapi-gen/queries/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1345,6 +1345,9 @@ export type BackfillServiceUnpauseBackfillMutationResult = Awaited<
export type BackfillServiceCancelBackfillMutationResult = Awaited<
ReturnType<typeof BackfillService.cancelBackfill>
>;
export type TaskInstanceServiceSetTaskInstancesStateMutationResult = Awaited<
ReturnType<typeof TaskInstanceService.setTaskInstancesState>
>;
export type ConnectionServicePatchConnectionMutationResult = Awaited<
ReturnType<typeof ConnectionService.patchConnection>
>;
Expand Down
Loading