diff --git a/airflow-core/docs/img/airflow_erd.sha256 b/airflow-core/docs/img/airflow_erd.sha256 index c53cc36d331db..e935d2a08ab18 100644 --- a/airflow-core/docs/img/airflow_erd.sha256 +++ b/airflow-core/docs/img/airflow_erd.sha256 @@ -1 +1 @@ -e0de73aab81a28995b99be21dd25c8ca31c4e0f4a5a0a26df8aff412e5067fd5 \ No newline at end of file +2e49ab99fe1076b0f3f22a52b9ee37eeb7fc20a5a043ea504cc26022f4315277 \ No newline at end of file diff --git a/airflow-core/docs/img/airflow_erd.svg b/airflow-core/docs/img/airflow_erd.svg index 5565970e5573f..2f9f9b4becc5e 100644 --- a/airflow-core/docs/img/airflow_erd.svg +++ b/airflow-core/docs/img/airflow_erd.svg @@ -922,467 +922,529 @@ task_map - -task_map - -dag_id - - [VARCHAR(250)] - NOT NULL - -map_index - - [INTEGER] - NOT NULL - -run_id - - [VARCHAR(250)] - NOT NULL - -task_id - - [VARCHAR(250)] - NOT NULL - -keys - - [JSONB] - -length - - [INTEGER] - NOT NULL + +task_map + +dag_id + + [VARCHAR(250)] + NOT NULL + +map_index + + [INTEGER] + NOT NULL + +run_id + + [VARCHAR(250)] + NOT NULL + +task_id + + [VARCHAR(250)] + NOT NULL + +keys + + [JSONB] + +length + + [INTEGER] + NOT NULL task_instance--task_map - -0..N -1 + +0..N +1 task_instance--task_map - -0..N -1 + +0..N +1 task_instance--task_map - -0..N -1 + +0..N +1 task_instance--task_map - -0..N -1 + +0..N +1 task_reschedule - -task_reschedule - -id - - [INTEGER] - NOT NULL - -duration - - [INTEGER] - NOT NULL - -end_date - - [TIMESTAMP] - NOT NULL - -reschedule_date - - [TIMESTAMP] - NOT NULL - -start_date - - [TIMESTAMP] - NOT NULL - -ti_id - - [UUID] - NOT NULL + +task_reschedule + +id + + [INTEGER] + NOT NULL + +duration + + [INTEGER] + NOT NULL + +end_date + + [TIMESTAMP] + NOT NULL + +reschedule_date + + [TIMESTAMP] + NOT NULL + +start_date + + [TIMESTAMP] + NOT NULL + +ti_id + + [UUID] + NOT NULL task_instance--task_reschedule - -0..N -1 + +0..N +1 xcom - -xcom - -dag_run_id - - [INTEGER] - NOT NULL - -key - - [VARCHAR(512)] - NOT NULL - -map_index - - [INTEGER] - NOT NULL - -task_id - - [VARCHAR(250)] - NOT NULL - -dag_id - - [VARCHAR(250)] - NOT NULL - -run_id - - [VARCHAR(250)] - NOT NULL - -timestamp - - [TIMESTAMP] - NOT NULL - -value - - [JSONB] + +xcom + +dag_run_id + + [INTEGER] + NOT NULL + +key + + [VARCHAR(512)] + NOT NULL + +map_index + + [INTEGER] + NOT NULL + +task_id + + [VARCHAR(250)] + NOT NULL + +dag_id + + [VARCHAR(250)] + NOT NULL + +run_id + + [VARCHAR(250)] + NOT NULL + +timestamp + + [TIMESTAMP] + NOT NULL + +value + + [JSONB] task_instance--xcom - -0..N -1 + +0..N +1 task_instance--xcom - -0..N -1 + +0..N +1 task_instance--xcom - -0..N -1 + +0..N +1 task_instance--xcom - -0..N -1 + +0..N +1 task_instance_note - -task_instance_note - -ti_id - - [UUID] - NOT NULL - -content - - [VARCHAR(1000)] - -created_at - - [TIMESTAMP] - NOT NULL - -updated_at - - [TIMESTAMP] - NOT NULL - -user_id - - [VARCHAR(128)] + +task_instance_note + +ti_id + + [UUID] + NOT NULL + +content + + [VARCHAR(1000)] + +created_at + + [TIMESTAMP] + NOT NULL + +updated_at + + [TIMESTAMP] + NOT NULL + +user_id + + [VARCHAR(128)] task_instance--task_instance_note - -1 -1 + +1 +1 task_instance_history - -task_instance_history - -task_instance_id - - [UUID] - NOT NULL - -context_carrier - - [JSONB] - -custom_operator_name - - [VARCHAR(1000)] - -dag_id - - [VARCHAR(250)] - NOT NULL - -dag_version_id - - [UUID] - -duration - - [DOUBLE_PRECISION] - -end_date - - [TIMESTAMP] - -executor - - [VARCHAR(1000)] - -executor_config - - [BYTEA] - -external_executor_id - - [VARCHAR(250)] - -hostname - - [VARCHAR(1000)] - -map_index - - [INTEGER] - NOT NULL - -max_tries - - [INTEGER] - -next_kwargs - - [JSONB] - -next_method - - [VARCHAR(1000)] - -operator - - [VARCHAR(1000)] - -pid - - [INTEGER] - -pool - - [VARCHAR(256)] - NOT NULL - -pool_slots - - [INTEGER] - NOT NULL - -priority_weight - - [INTEGER] - -queue - - [VARCHAR(256)] - -queued_by_job_id - - [INTEGER] - -queued_dttm - - [TIMESTAMP] - -rendered_map_index - - [VARCHAR(250)] - -run_id - - [VARCHAR(250)] - NOT NULL - -scheduled_dttm - - [TIMESTAMP] - -span_status - - [VARCHAR(250)] - NOT NULL - -start_date - - [TIMESTAMP] - -state - - [VARCHAR(20)] - -task_display_name - - [VARCHAR(2000)] - -task_id - - [VARCHAR(250)] - NOT NULL - -trigger_id - - [INTEGER] - -trigger_timeout - - [TIMESTAMP] - -try_number - - [INTEGER] - NOT NULL - -unixname - - [VARCHAR(1000)] - -updated_at - - [TIMESTAMP] + +task_instance_history + +task_instance_id + + [UUID] + NOT NULL + +context_carrier + + [JSONB] + +custom_operator_name + + [VARCHAR(1000)] + +dag_id + + [VARCHAR(250)] + NOT NULL + +dag_version_id + + [UUID] + +duration + + [DOUBLE_PRECISION] + +end_date + + [TIMESTAMP] + +executor + + [VARCHAR(1000)] + +executor_config + + [BYTEA] + +external_executor_id + + [VARCHAR(250)] + +hostname + + [VARCHAR(1000)] + +map_index + + [INTEGER] + NOT NULL + +max_tries + + [INTEGER] + +next_kwargs + + [JSONB] + +next_method + + [VARCHAR(1000)] + +operator + + [VARCHAR(1000)] + +pid + + [INTEGER] + +pool + + [VARCHAR(256)] + NOT NULL + +pool_slots + + [INTEGER] + NOT NULL + +priority_weight + + [INTEGER] + +queue + + [VARCHAR(256)] + +queued_by_job_id + + [INTEGER] + +queued_dttm + + [TIMESTAMP] + +rendered_map_index + + [VARCHAR(250)] + +run_id + + [VARCHAR(250)] + NOT NULL + +scheduled_dttm + + [TIMESTAMP] + +span_status + + [VARCHAR(250)] + NOT NULL + +start_date + + [TIMESTAMP] + +state + + [VARCHAR(20)] + +task_display_name + + [VARCHAR(2000)] + +task_id + + [VARCHAR(250)] + NOT NULL + +trigger_id + + [INTEGER] + +trigger_timeout + + [TIMESTAMP] + +try_number + + [INTEGER] + NOT NULL + +unixname + + [VARCHAR(1000)] + +updated_at + + [TIMESTAMP] task_instance--task_instance_history - -0..N -1 + +0..N +1 task_instance--task_instance_history - -0..N -1 + +0..N +1 task_instance--task_instance_history - -0..N -1 + +0..N +1 task_instance--task_instance_history - -0..N -1 + +0..N +1 - + -rendered_task_instance_fields - -rendered_task_instance_fields - -dag_id - - [VARCHAR(250)] - NOT NULL - -map_index - - [INTEGER] - NOT NULL - -run_id - - [VARCHAR(250)] - NOT NULL - -task_id - - [VARCHAR(250)] - NOT NULL - -k8s_pod_yaml - - [JSON] - -rendered_fields - - [JSON] - NOT NULL - - +hitl_detail + +hitl_detail + +ti_id + + [UUID] + NOT NULL + +body + + [TEXT] + +chosen_options + + [JSON] + +defaults + + [JSON] + +multiple + + [BOOLEAN] + +options + + [JSON] + NOT NULL + +params + + [JSON] + NOT NULL + +params_input + + [JSON] + NOT NULL + +response_at + + [TIMESTAMP] + +subject + + [TEXT] + NOT NULL + +user_id + + [VARCHAR(128)] + + -task_instance--rendered_task_instance_fields - -0..N -1 +task_instance--hitl_detail + +1 +1 + + + +rendered_task_instance_fields + +rendered_task_instance_fields + +dag_id + + [VARCHAR(250)] + NOT NULL + +map_index + + [INTEGER] + NOT NULL + +run_id + + [VARCHAR(250)] + NOT NULL + +task_id + + [VARCHAR(250)] + NOT NULL + +k8s_pod_yaml + + [JSON] + +rendered_fields + + [JSON] + NOT NULL task_instance--rendered_task_instance_fields - -0..N -1 + +0..N +1 task_instance--rendered_task_instance_fields - -0..N -1 + +0..N +1 task_instance--rendered_task_instance_fields - -0..N -1 + +0..N +1 + + + +task_instance--rendered_task_instance_fields + +0..N +1 @@ -2259,7 +2321,7 @@ 1 - + alembic_version alembic_version diff --git a/airflow-core/docs/migrations-ref.rst b/airflow-core/docs/migrations-ref.rst index 0a2d4ea6a89f4..0e18989fbc8a7 100644 --- a/airflow-core/docs/migrations-ref.rst +++ b/airflow-core/docs/migrations-ref.rst @@ -39,7 +39,9 @@ Here's the list of all the Database Migrations that are executed via when you ru +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | Revision ID | Revises ID | Airflow Version | Description | +=========================+==================+===================+==============================================================+ -| ``5d3072c51bac`` (head) | ``ffdb0566c7c0`` | ``3.1.0`` | Make dag_version_id non-nullable in TaskInstance. | +| ``40f7c30a228b`` (head) | ``5d3072c51bac`` | ``3.1.0`` | Add Human In the Loop Detail table. | ++-------------------------+------------------+-------------------+--------------------------------------------------------------+ +| ``5d3072c51bac`` | ``ffdb0566c7c0`` | ``3.1.0`` | Make dag_version_id non-nullable in TaskInstance. | +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | ``ffdb0566c7c0`` | ``66a7743fe20e`` | ``3.1.0`` | Add dag_favorite table. | +-------------------------+------------------+-------------------+--------------------------------------------------------------+ diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/hitl.py b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/hitl.py new file mode 100644 index 0000000000000..88ad702316423 --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/hitl.py @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from collections.abc import Mapping +from datetime import datetime +from typing import Any + +from pydantic import Field, field_validator + +from airflow.api_fastapi.core_api.base import BaseModel +from airflow.sdk import Param + + +class UpdateHITLDetailPayload(BaseModel): + """Schema for updating the content of a Human-in-the-loop detail.""" + + chosen_options: list[str] + params_input: Mapping = Field(default_factory=dict) + + +class HITLDetailResponse(BaseModel): + """Response of updating a Human-in-the-loop detail.""" + + user_id: str + response_at: datetime + chosen_options: list[str] + params_input: Mapping = Field(default_factory=dict) + + +class HITLDetail(BaseModel): + """Schema for Human-in-the-loop detail.""" + + ti_id: str + + # User Request Detail + options: list[str] + subject: str + body: str | None = None + defaults: list[str] | None = None + multiple: bool = False + params: dict[str, Any] = Field(default_factory=dict) + + # Response Content Detail + user_id: str | None = None + response_at: datetime | None = None + chosen_options: list[str] | None = None + params_input: dict[str, Any] = Field(default_factory=dict) + + response_received: bool = False + + @field_validator("params", mode="before") + @classmethod + def get_params(cls, params: dict[str, Any]) -> dict[str, Any]: + """Convert params attribute to dict representation.""" + return {k: v.dump() if isinstance(v, Param) else v for k, v in params.items()} + + +class HITLDetailCollection(BaseModel): + """Schema for a collection of Human-in-the-loop details.""" + + hitl_details: list[HITLDetail] + total_entries: int diff --git a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml index 073e2c9c8a38e..ae9a645fc2e0f 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml +++ b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml @@ -7122,6 +7122,304 @@ paths: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' + /api/v2/hitl-details/{dag_id}/{dag_run_id}/{task_id}: + patch: + tags: + - HumanInTheLoop + summary: Update Hitl Detail + description: Update a Human-in-the-loop detail. + operationId: update_hitl_detail + security: + - OAuth2PasswordBearer: [] + parameters: + - name: dag_id + in: path + required: true + schema: + type: string + title: Dag Id + - name: dag_run_id + in: path + required: true + schema: + type: string + title: Dag Run Id + - name: task_id + in: path + required: true + schema: + type: string + title: Task Id + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/UpdateHITLDetailPayload' + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/HITLDetailResponse' + '401': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Unauthorized + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Forbidden + '404': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Not Found + '409': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Conflict + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + get: + tags: + - HumanInTheLoop + summary: Get Hitl Detail + description: Get a Human-in-the-loop detail of a specific task instance. + operationId: get_hitl_detail + security: + - OAuth2PasswordBearer: [] + parameters: + - name: dag_id + in: path + required: true + schema: + type: string + title: Dag Id + - name: dag_run_id + in: path + required: true + schema: + type: string + title: Dag Run Id + - name: task_id + in: path + required: true + schema: + type: string + title: Task Id + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/HITLDetail' + '401': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Unauthorized + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Forbidden + '404': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Not Found + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + /api/v2/hitl-details/{dag_id}/{dag_run_id}/{task_id}/{map_index}: + patch: + tags: + - HumanInTheLoop + summary: Update Mapped Ti Hitl Detail + description: Update a Human-in-the-loop detail. + operationId: update_mapped_ti_hitl_detail + security: + - OAuth2PasswordBearer: [] + parameters: + - name: dag_id + in: path + required: true + schema: + type: string + title: Dag Id + - name: dag_run_id + in: path + required: true + schema: + type: string + title: Dag Run Id + - name: task_id + in: path + required: true + schema: + type: string + title: Task Id + - name: map_index + in: path + required: true + schema: + type: integer + title: Map Index + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/UpdateHITLDetailPayload' + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/HITLDetailResponse' + '401': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Unauthorized + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Forbidden + '404': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Not Found + '409': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Conflict + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + get: + tags: + - HumanInTheLoop + summary: Get Mapped Ti Hitl Detail + description: Get a Human-in-the-loop detail of a specific task instance. + operationId: get_mapped_ti_hitl_detail + security: + - OAuth2PasswordBearer: [] + parameters: + - name: dag_id + in: path + required: true + schema: + type: string + title: Dag Id + - name: dag_run_id + in: path + required: true + schema: + type: string + title: Dag Run Id + - name: task_id + in: path + required: true + schema: + type: string + title: Task Id + - name: map_index + in: path + required: true + schema: + type: integer + title: Map Index + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/HITLDetail' + '401': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Unauthorized + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Forbidden + '404': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Not Found + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + /api/v2/hitl-details/: + get: + tags: + - HumanInTheLoop + summary: Get Hitl Details + description: Get Human-in-the-loop details. + operationId: get_hitl_details + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/HITLDetailCollection' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + '403': + description: Forbidden + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + security: + - OAuth2PasswordBearer: [] /api/v2/monitor/health: get: tags: @@ -9591,6 +9889,113 @@ components: - name title: FastAPIRootMiddlewareResponse description: Serializer for Plugin FastAPI root middleware responses. + HITLDetail: + properties: + ti_id: + type: string + title: Ti Id + options: + items: + type: string + type: array + title: Options + subject: + type: string + title: Subject + body: + anyOf: + - type: string + - type: 'null' + title: Body + defaults: + anyOf: + - items: + type: string + type: array + - type: 'null' + title: Defaults + multiple: + type: boolean + title: Multiple + default: false + params: + additionalProperties: true + type: object + title: Params + user_id: + anyOf: + - type: string + - type: 'null' + title: User Id + response_at: + anyOf: + - type: string + format: date-time + - type: 'null' + title: Response At + chosen_options: + anyOf: + - items: + type: string + type: array + - type: 'null' + title: Chosen Options + params_input: + additionalProperties: true + type: object + title: Params Input + response_received: + type: boolean + title: Response Received + default: false + type: object + required: + - ti_id + - options + - subject + title: HITLDetail + description: Schema for Human-in-the-loop detail. + HITLDetailCollection: + properties: + hitl_details: + items: + $ref: '#/components/schemas/HITLDetail' + type: array + title: Hitl Details + total_entries: + type: integer + title: Total Entries + type: object + required: + - hitl_details + - total_entries + title: HITLDetailCollection + description: Schema for a collection of Human-in-the-loop details. + HITLDetailResponse: + properties: + user_id: + type: string + title: User Id + response_at: + type: string + format: date-time + title: Response At + chosen_options: + items: + type: string + type: array + title: Chosen Options + params_input: + additionalProperties: true + type: object + title: Params Input + type: object + required: + - user_id + - response_at + - chosen_options + title: HITLDetailResponse + description: Response of updating a Human-in-the-loop detail. HTTPExceptionResponse: properties: detail: @@ -11120,6 +11525,22 @@ components: - latest_triggerer_heartbeat title: TriggererInfoResponse description: Triggerer info serializer for responses. + UpdateHITLDetailPayload: + properties: + chosen_options: + items: + type: string + type: array + title: Chosen Options + params_input: + additionalProperties: true + type: object + title: Params Input + type: object + required: + - chosen_options + title: UpdateHITLDetailPayload + description: Schema for updating the content of a Human-in-the-loop detail. ValidationError: properties: loc: diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/__init__.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/__init__.py index fbbfb46dfa8d0..6db86ce2327a6 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/__init__.py @@ -37,6 +37,7 @@ from airflow.api_fastapi.core_api.routes.public.dags import dags_router from airflow.api_fastapi.core_api.routes.public.event_logs import event_logs_router from airflow.api_fastapi.core_api.routes.public.extra_links import extra_links_router +from airflow.api_fastapi.core_api.routes.public.hitl import hitl_router from airflow.api_fastapi.core_api.routes.public.import_error import import_error_router from airflow.api_fastapi.core_api.routes.public.job import job_router from airflow.api_fastapi.core_api.routes.public.log import task_instances_log_router @@ -83,6 +84,7 @@ authenticated_router.include_router(dag_parsing_router) authenticated_router.include_router(dag_tags_router) authenticated_router.include_router(dag_versions_router) +authenticated_router.include_router(hitl_router) # Include authenticated router in public router diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/hitl.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/hitl.py new file mode 100644 index 0000000000000..78c7604b51677 --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/hitl.py @@ -0,0 +1,274 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import structlog +from fastapi import Depends, HTTPException, status +from sqlalchemy import select + +from airflow.api_fastapi.auth.managers.models.resource_details import DagAccessEntity +from airflow.api_fastapi.common.db.common import SessionDep, paginated_select +from airflow.api_fastapi.common.router import AirflowRouter +from airflow.api_fastapi.core_api.datamodels.hitl import ( + HITLDetail, + HITLDetailCollection, + HITLDetailResponse, + UpdateHITLDetailPayload, +) +from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc +from airflow.api_fastapi.core_api.security import GetUserDep, ReadableTIFilterDep, requires_access_dag +from airflow.models.hitl import HITLDetail as HITLDetailModel +from airflow.models.taskinstance import TaskInstance as TI +from airflow.utils import timezone + +hitl_router = AirflowRouter(tags=["HumanInTheLoop"], prefix="/hitl-details") + +log = structlog.get_logger(__name__) + + +def _get_task_instance( + dag_id: str, + dag_run_id: str, + task_id: str, + session: SessionDep, + map_index: int | None = None, +) -> TI: + query = select(TI).where( + TI.dag_id == dag_id, + TI.run_id == dag_run_id, + TI.task_id == task_id, + ) + + if map_index is not None: + query = query.where(TI.map_index == map_index) + + task_instance = session.scalar(query) + if task_instance is None: + raise HTTPException( + status.HTTP_404_NOT_FOUND, + f"The Task Instance with dag_id: `{dag_id}`, run_id: `{dag_run_id}`, task_id: `{task_id}` and map_index: `{map_index}` was not found", + ) + if map_index is None and task_instance.map_index != -1: + raise HTTPException( + status.HTTP_404_NOT_FOUND, "Task instance is mapped, add the map_index value to the URL" + ) + + return task_instance + + +def _update_hitl_detail( + dag_id: str, + dag_run_id: str, + task_id: str, + update_hitl_detail_payload: UpdateHITLDetailPayload, + user: GetUserDep, + session: SessionDep, + map_index: int | None = None, +) -> HITLDetailResponse: + task_instance = _get_task_instance( + dag_id=dag_id, + dag_run_id=dag_run_id, + task_id=task_id, + session=session, + map_index=map_index, + ) + ti_id_str = str(task_instance.id) + hitl_detail_model = session.scalar(select(HITLDetailModel).where(HITLDetailModel.ti_id == ti_id_str)) + if not hitl_detail_model: + raise HTTPException( + status.HTTP_404_NOT_FOUND, + f"Human-in-the-loop detail does not exist for Task Instance with id {ti_id_str}", + ) + + if hitl_detail_model.response_received: + raise HTTPException( + status.HTTP_409_CONFLICT, + f"Human-in-the-loop detail has already been updated for Task Instance with id {ti_id_str} " + "and is not allowed to write again.", + ) + + hitl_detail_model.user_id = user.get_id() + hitl_detail_model.response_at = timezone.utcnow() + hitl_detail_model.chosen_options = update_hitl_detail_payload.chosen_options + hitl_detail_model.params_input = update_hitl_detail_payload.params_input + session.add(hitl_detail_model) + session.commit() + return HITLDetailResponse.model_validate(hitl_detail_model) + + +def _get_hitl_detail( + dag_id: str, + dag_run_id: str, + task_id: str, + session: SessionDep, + map_index: int | None = None, +) -> HITLDetail: + """Get a Human-in-the-loop detail of a specific task instance.""" + task_instance = _get_task_instance( + dag_id=dag_id, + dag_run_id=dag_run_id, + task_id=task_id, + session=session, + map_index=map_index, + ) + if task_instance is None: + raise HTTPException( + status.HTTP_404_NOT_FOUND, + f"The Task Instance with dag_id: `{dag_id}`, run_id: `{dag_run_id}`, task_id: `{task_id}` and map_index: `{map_index}` was not found", + ) + + ti_id_str = str(task_instance.id) + hitl_detail_model = session.scalar(select(HITLDetailModel).where(HITLDetailModel.ti_id == ti_id_str)) + if not hitl_detail_model: + log.error("Human-in-the-loop detail not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={ + "reason": "not_found", + "message": "Human-in-the-loop detail not found", + }, + ) + return HITLDetail.model_validate(hitl_detail_model) + + +@hitl_router.patch( + "/{dag_id}/{dag_run_id}/{task_id}", + responses=create_openapi_http_exception_doc( + [ + status.HTTP_404_NOT_FOUND, + status.HTTP_409_CONFLICT, + ] + ), + dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.TASK_INSTANCE))], +) +def update_hitl_detail( + dag_id: str, + dag_run_id: str, + task_id: str, + update_hitl_detail_payload: UpdateHITLDetailPayload, + user: GetUserDep, + session: SessionDep, +) -> HITLDetailResponse: + """Update a Human-in-the-loop detail.""" + return _update_hitl_detail( + dag_id=dag_id, + dag_run_id=dag_run_id, + task_id=task_id, + session=session, + update_hitl_detail_payload=update_hitl_detail_payload, + user=user, + map_index=None, + ) + + +@hitl_router.patch( + "/{dag_id}/{dag_run_id}/{task_id}/{map_index}", + responses=create_openapi_http_exception_doc( + [ + status.HTTP_404_NOT_FOUND, + status.HTTP_409_CONFLICT, + ] + ), + dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.TASK_INSTANCE))], +) +def update_mapped_ti_hitl_detail( + dag_id: str, + dag_run_id: str, + task_id: str, + update_hitl_detail_payload: UpdateHITLDetailPayload, + user: GetUserDep, + session: SessionDep, + map_index: int, +) -> HITLDetailResponse: + """Update a Human-in-the-loop detail.""" + return _update_hitl_detail( + dag_id=dag_id, + dag_run_id=dag_run_id, + task_id=task_id, + session=session, + update_hitl_detail_payload=update_hitl_detail_payload, + user=user, + map_index=map_index, + ) + + +@hitl_router.get( + "/{dag_id}/{dag_run_id}/{task_id}", + status_code=status.HTTP_200_OK, + responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), + dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.TASK_INSTANCE))], +) +def get_hitl_detail( + dag_id: str, + dag_run_id: str, + task_id: str, + session: SessionDep, +) -> HITLDetail: + """Get a Human-in-the-loop detail of a specific task instance.""" + return _get_hitl_detail( + dag_id=dag_id, + dag_run_id=dag_run_id, + task_id=task_id, + session=session, + map_index=None, + ) + + +@hitl_router.get( + "/{dag_id}/{dag_run_id}/{task_id}/{map_index}", + status_code=status.HTTP_200_OK, + responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), + dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.TASK_INSTANCE))], +) +def get_mapped_ti_hitl_detail( + dag_id: str, + dag_run_id: str, + task_id: str, + session: SessionDep, + map_index: int, +) -> HITLDetail: + """Get a Human-in-the-loop detail of a specific task instance.""" + return _get_hitl_detail( + dag_id=dag_id, + dag_run_id=dag_run_id, + task_id=task_id, + session=session, + map_index=map_index, + ) + + +@hitl_router.get( + "/", + status_code=status.HTTP_200_OK, + dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.TASK_INSTANCE))], +) +def get_hitl_details( + readable_ti_filter: ReadableTIFilterDep, + session: SessionDep, +) -> HITLDetailCollection: + """Get Human-in-the-loop details.""" + query = select(HITLDetailModel).join(TI, HITLDetailModel.ti_id == TI.id) + hitl_detail_select, total_entries = paginated_select( + statement=query, + filters=[readable_ti_filter], + session=session, + ) + hitl_details = session.scalars(hitl_detail_select) + return HITLDetailCollection( + hitl_details=hitl_details, + total_entries=total_entries, + ) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/hitl.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/hitl.py new file mode 100644 index 0000000000000..c75ca8c14f2ee --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/hitl.py @@ -0,0 +1,72 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import datetime +from typing import Any +from uuid import UUID + +from pydantic import Field + +from airflow.api_fastapi.core_api.base import BaseModel +from airflow.models.hitl import HITLDetail + + +class HITLDetailRequest(BaseModel): + """Schema for the request part of a Human-in-the-loop detail for a specific task instance.""" + + ti_id: UUID + options: list[str] + subject: str + body: str | None = None + defaults: list[str] | None = None + multiple: bool = False + params: dict[str, Any] = Field(default_factory=dict) + + +class GetHITLDetailResponsePayload(BaseModel): + """Schema for getting the response part of a Human-in-the-loop detail for a specific task instance.""" + + ti_id: UUID + + +class UpdateHITLDetailPayload(BaseModel): + """Schema for writing the response part of a Human-in-the-loop detail for a specific task instance.""" + + ti_id: UUID + chosen_options: list[str] + params_input: dict[str, Any] = Field(default_factory=dict) + + +class HITLDetailResponse(BaseModel): + """Schema for the response part of a Human-in-the-loop detail for a specific task instance.""" + + response_received: bool + user_id: str | None + response_at: datetime | None + chosen_options: list[str] | None + params_input: dict[str, Any] = Field(default_factory=dict) + + @classmethod + def from_hitl_detail_orm(cls, hitl_detail: HITLDetail) -> HITLDetailResponse: + return HITLDetailResponse( + response_received=hitl_detail.response_received, + response_at=hitl_detail.response_at, + user_id=hitl_detail.user_id, + chosen_options=hitl_detail.chosen_options, + params_input=hitl_detail.params_input or {}, + ) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py index 164c3f0942d1f..ab163f0bac569 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py @@ -26,6 +26,7 @@ connections, dag_runs, health, + hitl, task_instances, task_reschedules, variables, @@ -48,5 +49,6 @@ ) authenticated_router.include_router(variables.router, prefix="/variables", tags=["Variables"]) authenticated_router.include_router(xcoms.router, prefix="/xcoms", tags=["XComs"]) +authenticated_router.include_router(hitl.router, prefix="/hitl-details", tags=["Human in the Loop"]) execution_api_router.include_router(authenticated_router) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/hitl.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/hitl.py new file mode 100644 index 0000000000000..a82e496a8a7a2 --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/hitl.py @@ -0,0 +1,108 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import datetime, timezone +from uuid import UUID + +import structlog +from fastapi import APIRouter, HTTPException, status +from sqlalchemy import select + +from airflow.api_fastapi.common.db.common import SessionDep +from airflow.api_fastapi.execution_api.datamodels.hitl import ( + HITLDetailRequest, + HITLDetailResponse, +) +from airflow.models.hitl import HITLDetail +from airflow.sdk.execution_time.comms import CreateHITLDetailPayload, UpdateHITLDetail + +router = APIRouter() + +log = structlog.get_logger(__name__) + + +@router.post( + "/{task_instance_id}", + status_code=status.HTTP_201_CREATED, +) +def add_hitl_detail( + task_instance_id: UUID, + payload: CreateHITLDetailPayload, + session: SessionDep, +) -> HITLDetailRequest: + """Get Human-in-the-loop detail for a specific Task Instance.""" + ti_id_str = str(task_instance_id) + hitl_detail_model = session.scalar(select(HITLDetail).where(HITLDetail.ti_id == ti_id_str)) + if hitl_detail_model: + raise HTTPException( + status.HTTP_409_CONFLICT, + f"Human-in-the-loop detail for Task Instance with id {ti_id_str} already exists.", + ) + + hitl_detail = HITLDetail( + ti_id=ti_id_str, + options=payload.options, + subject=payload.subject, + body=payload.body, + defaults=payload.defaults, + multiple=payload.multiple, + params=payload.params, + ) + session.add(hitl_detail) + session.commit() + return HITLDetailRequest.model_validate(hitl_detail) + + +@router.patch("/{task_instance_id}") +def update_hitl_detail( + task_instance_id: UUID, + payload: UpdateHITLDetail, + session: SessionDep, +) -> HITLDetailResponse: + """Update the response part of a Human-in-the-loop detail for a specific Task Instance.""" + ti_id_str = str(task_instance_id) + hitl_detail_model = session.execute(select(HITLDetail).where(HITLDetail.ti_id == ti_id_str)).scalar() + if hitl_detail_model.response_received: + raise HTTPException( + status.HTTP_409_CONFLICT, + f"Human-in-the-loop detail for Task Instance with id {ti_id_str} already exists.", + ) + + hitl_detail_model.user_id = "Fallback to defaults" + hitl_detail_model.response_at = datetime.now(timezone.utc) + hitl_detail_model.chosen_options = payload.chosen_options + hitl_detail_model.params_input = payload.params_input + session.add(hitl_detail_model) + session.commit() + return HITLDetailResponse.from_hitl_detail_orm(hitl_detail_model) + + +@router.get( + "/{task_instance_id}", + status_code=status.HTTP_200_OK, +) +def get_hitl_detail( + task_instance_id: UUID, + session: SessionDep, +) -> HITLDetailResponse: + """Get Human-in-the-loop detail for a specific Task Instance.""" + ti_id_str = str(task_instance_id) + hitl_detail_model = session.execute( + select(HITLDetail).where(HITLDetail.ti_id == ti_id_str), + ).scalar() + return HITLDetailResponse.from_hitl_detail_orm(hitl_detail_model) diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index fd02baecb7176..07966655cfe99 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -43,6 +43,7 @@ from airflow.jobs.base_job_runner import BaseJobRunner from airflow.jobs.job import perform_heartbeat from airflow.models.trigger import Trigger +from airflow.sdk.api.datamodels._generated import HITLDetailResponse from airflow.sdk.execution_time.comms import ( CommsDecoder, ConnectionResult, @@ -52,12 +53,14 @@ GetConnection, GetDagRunState, GetDRCount, + GetHITLDetailResponse, GetTaskStates, GetTICount, GetVariable, GetXCom, TaskStatesResult, TICount, + UpdateHITLDetail, VariableResult, XComResult, _RequestFrame, @@ -209,6 +212,23 @@ class TriggerStateSync(BaseModel): to_cancel: set[int] +class HITLDetailResponseResult(HITLDetailResponse): + """Response to GetHITLDetailResponse request.""" + + type: Literal["HITLDetailResponseResult"] = "HITLDetailResponseResult" + + @classmethod + def from_api_response(cls, response: HITLDetailResponse) -> HITLDetailResponseResult: + """ + Create result class from API Response. + + API Response is autogenerated from the API schema, so we need to convert it to Result + for communication between the Supervisor and the task process since it needs a + discriminator field. + """ + return cls(**response.model_dump(exclude_defaults=True), type="HITLDetailResponseResult") + + ToTriggerRunner = Annotated[ messages.StartTriggerer | messages.TriggerStateSync @@ -219,6 +239,7 @@ class TriggerStateSync(BaseModel): | DRCount | TICount | TaskStatesResult + | HITLDetailResponseResult | ErrorResponse, Field(discriminator="type"), ] @@ -236,7 +257,9 @@ class TriggerStateSync(BaseModel): | GetTICount | GetTaskStates | GetDagRunState - | GetDRCount, + | GetDRCount + | GetHITLDetailResponse + | UpdateHITLDetail, Field(discriminator="type"), ] """ @@ -448,6 +471,16 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r resp = TaskStatesResult.from_api_response(run_id_task_state_map) else: resp = run_id_task_state_map + elif isinstance(msg, UpdateHITLDetail): + api_resp = self.client.hitl.update_response( + ti_id=msg.ti_id, + chosen_options=msg.chosen_options, + params_input=msg.params_input, + ) + resp = HITLDetailResponseResult.from_api_response(response=api_resp) + elif isinstance(msg, GetHITLDetailResponse): + api_resp = self.client.hitl.get_detail_response(ti_id=msg.ti_id) + resp = HITLDetailResponseResult.from_api_response(response=api_resp) else: raise ValueError(f"Unknown message type {type(msg)}") diff --git a/airflow-core/src/airflow/migrations/versions/0077_3_1_0_add_human_in_the_loop_response.py b/airflow-core/src/airflow/migrations/versions/0077_3_1_0_add_human_in_the_loop_response.py new file mode 100644 index 0000000000000..61f950f5d120e --- /dev/null +++ b/airflow-core/src/airflow/migrations/versions/0077_3_1_0_add_human_in_the_loop_response.py @@ -0,0 +1,78 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Add Human In the Loop Detail table. + +Revision ID: 40f7c30a228b +Revises: 5d3072c51bac +Create Date: 2025-07-04 15:05:19.459197 + +""" + +from __future__ import annotations + +import sqlalchemy_jsonfield +from alembic import op +from sqlalchemy import Boolean, Column, ForeignKeyConstraint, String, Text +from sqlalchemy.dialects import postgresql + +from airflow.settings import json +from airflow.utils.sqlalchemy import UtcDateTime + +# revision identifiers, used by Alembic. +revision = "40f7c30a228b" +down_revision = "5d3072c51bac" +branch_labels = None +depends_on = None +airflow_version = "3.1.0" + + +def upgrade(): + """Add Human In the Loop Detail table.""" + op.create_table( + "hitl_detail", + Column( + "ti_id", + String(length=36).with_variant(postgresql.UUID(), "postgresql"), + primary_key=True, + nullable=False, + ), + Column("options", sqlalchemy_jsonfield.JSONField(json=json), nullable=False), + Column("subject", Text, nullable=False), + Column("body", Text, nullable=True), + Column("defaults", sqlalchemy_jsonfield.JSONField(json=json), nullable=True), + Column("multiple", Boolean, unique=False, default=False), + Column("params", sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={}), + Column("response_at", UtcDateTime, nullable=True), + Column("user_id", String(128), nullable=True), + Column("chosen_options", sqlalchemy_jsonfield.JSONField(json=json), nullable=True), + Column("params_input", sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={}), + ForeignKeyConstraint( + ["ti_id"], + ["task_instance.id"], + name="hitl_detail_ti_fkey", + ondelete="CASCADE", + onupdate="CASCADE", + ), + ) + + +def downgrade(): + """Response Human In the Loop Detail table.""" + op.drop_table("hitl_detail") diff --git a/airflow-core/src/airflow/models/__init__.py b/airflow-core/src/airflow/models/__init__.py index 3e472b70b4bca..ac6c2a76e3274 100644 --- a/airflow-core/src/airflow/models/__init__.py +++ b/airflow-core/src/airflow/models/__init__.py @@ -103,6 +103,7 @@ def __getattr__(name): "DbCallbackRequest": "airflow.models.db_callback_request", "Deadline": "airflow.models.deadline", "Log": "airflow.models.log", + "HITLDetail": "airflow.models.hitl", "MappedOperator": "airflow.models.mappedoperator", "Operator": "airflow.models.operator", "Param": "airflow.sdk.definitions.param", diff --git a/airflow-core/src/airflow/models/hitl.py b/airflow-core/src/airflow/models/hitl.py new file mode 100644 index 0000000000000..9d060ba1c19d7 --- /dev/null +++ b/airflow-core/src/airflow/models/hitl.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import sqlalchemy_jsonfield +from sqlalchemy import Boolean, Column, ForeignKeyConstraint, String, Text +from sqlalchemy.dialects import postgresql +from sqlalchemy.ext.hybrid import hybrid_property + +from airflow.models.base import Base +from airflow.settings import json +from airflow.utils.sqlalchemy import UtcDateTime + + +class HITLDetail(Base): + """Human-in-the-loop request and corresponding response.""" + + __tablename__ = "hitl_detail" + ti_id = Column( + String(36).with_variant(postgresql.UUID(as_uuid=False), "postgresql"), + primary_key=True, + nullable=False, + ) + + # User Request Detail + options = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False) + subject = Column(Text, nullable=False) + body = Column(Text, nullable=True) + defaults = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True) + multiple = Column(Boolean, unique=False, default=False) + params = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={}) + + # Response Content Detail + response_at = Column(UtcDateTime, nullable=True) + user_id = Column(String(128), nullable=True) + chosen_options = Column( + sqlalchemy_jsonfield.JSONField(json=json), + nullable=True, + default=None, + ) + params_input = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={}) + + __table_args__ = ( + ForeignKeyConstraint( + (ti_id,), + ["task_instance.id"], + name="hitl_detail_ti_fkey", + ondelete="CASCADE", + onupdate="CASCADE", + ), + ) + + @hybrid_property + def response_received(self) -> bool: + return self.response_at is not None diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts index 143ec4c76550c..91fd42bd07909 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts @@ -1,7 +1,7 @@ // generated with @7nohe/openapi-react-query-codegen@1.6.2 import { UseQueryResult } from "@tanstack/react-query"; -import { AssetService, AuthLinksService, BackfillService, CalendarService, ConfigService, ConnectionService, DagParsingService, DagReportService, DagRunService, DagService, DagSourceService, DagStatsService, DagVersionService, DagWarningService, DashboardService, DependenciesService, EventLogService, ExperimentalService, ExtraLinksService, GridService, ImportErrorService, JobService, LoginService, MonitorService, PluginService, PoolService, ProviderService, StructureService, TaskInstanceService, TaskService, VariableService, VersionService, XcomService } from "../requests/services.gen"; +import { AssetService, AuthLinksService, BackfillService, CalendarService, ConfigService, ConnectionService, DagParsingService, DagReportService, DagRunService, DagService, DagSourceService, DagStatsService, DagVersionService, DagWarningService, DashboardService, DependenciesService, EventLogService, ExperimentalService, ExtraLinksService, GridService, HumanInTheLoopService, ImportErrorService, JobService, LoginService, MonitorService, PluginService, PoolService, ProviderService, StructureService, TaskInstanceService, TaskService, VariableService, VersionService, XcomService } from "../requests/services.gen"; import { DagRunState, DagWarningType } from "../requests/types.gen"; export type AssetServiceGetAssetsDefaultResponse = Awaited>; export type AssetServiceGetAssetsQueryResult = UseQueryResult; @@ -620,6 +620,27 @@ export const UseDagVersionServiceGetDagVersionsKeyFn = ({ bundleName, bundleVers orderBy?: string; versionNumber?: number; }, queryKey?: Array) => [useDagVersionServiceGetDagVersionsKey, ...(queryKey ?? [{ bundleName, bundleVersion, dagId, limit, offset, orderBy, versionNumber }])]; +export type HumanInTheLoopServiceGetHitlDetailDefaultResponse = Awaited>; +export type HumanInTheLoopServiceGetHitlDetailQueryResult = UseQueryResult; +export const useHumanInTheLoopServiceGetHitlDetailKey = "HumanInTheLoopServiceGetHitlDetail"; +export const UseHumanInTheLoopServiceGetHitlDetailKeyFn = ({ dagId, dagRunId, taskId }: { + dagId: string; + dagRunId: string; + taskId: string; +}, queryKey?: Array) => [useHumanInTheLoopServiceGetHitlDetailKey, ...(queryKey ?? [{ dagId, dagRunId, taskId }])]; +export type HumanInTheLoopServiceGetMappedTiHitlDetailDefaultResponse = Awaited>; +export type HumanInTheLoopServiceGetMappedTiHitlDetailQueryResult = UseQueryResult; +export const useHumanInTheLoopServiceGetMappedTiHitlDetailKey = "HumanInTheLoopServiceGetMappedTiHitlDetail"; +export const UseHumanInTheLoopServiceGetMappedTiHitlDetailKeyFn = ({ dagId, dagRunId, mapIndex, taskId }: { + dagId: string; + dagRunId: string; + mapIndex: number; + taskId: string; +}, queryKey?: Array) => [useHumanInTheLoopServiceGetMappedTiHitlDetailKey, ...(queryKey ?? [{ dagId, dagRunId, mapIndex, taskId }])]; +export type HumanInTheLoopServiceGetHitlDetailsDefaultResponse = Awaited>; +export type HumanInTheLoopServiceGetHitlDetailsQueryResult = UseQueryResult; +export const useHumanInTheLoopServiceGetHitlDetailsKey = "HumanInTheLoopServiceGetHitlDetails"; +export const UseHumanInTheLoopServiceGetHitlDetailsKeyFn = (queryKey?: Array) => [useHumanInTheLoopServiceGetHitlDetailsKey, ...(queryKey ?? [])]; export type MonitorServiceGetHealthDefaultResponse = Awaited>; export type MonitorServiceGetHealthQueryResult = UseQueryResult; export const useMonitorServiceGetHealthKey = "MonitorServiceGetHealth"; @@ -752,6 +773,8 @@ export type PoolServiceBulkPoolsMutationResult = Awaited>; export type VariableServicePatchVariableMutationResult = Awaited>; export type VariableServiceBulkVariablesMutationResult = Awaited>; +export type HumanInTheLoopServiceUpdateHitlDetailMutationResult = Awaited>; +export type HumanInTheLoopServiceUpdateMappedTiHitlDetailMutationResult = Awaited>; export type AssetServiceDeleteAssetQueuedEventsMutationResult = Awaited>; export type AssetServiceDeleteDagAssetQueuedEventsMutationResult = Awaited>; export type AssetServiceDeleteDagAssetQueuedEventMutationResult = Awaited>; diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts index 1c0fc86697c2a..d10b539687bf5 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts @@ -1,7 +1,7 @@ // generated with @7nohe/openapi-react-query-codegen@1.6.2 import { type QueryClient } from "@tanstack/react-query"; -import { AssetService, AuthLinksService, BackfillService, CalendarService, ConfigService, ConnectionService, DagReportService, DagRunService, DagService, DagSourceService, DagStatsService, DagVersionService, DagWarningService, DashboardService, DependenciesService, EventLogService, ExperimentalService, ExtraLinksService, GridService, ImportErrorService, JobService, LoginService, MonitorService, PluginService, PoolService, ProviderService, StructureService, TaskInstanceService, TaskService, VariableService, VersionService, XcomService } from "../requests/services.gen"; +import { AssetService, AuthLinksService, BackfillService, CalendarService, ConfigService, ConnectionService, DagReportService, DagRunService, DagService, DagSourceService, DagStatsService, DagVersionService, DagWarningService, DashboardService, DependenciesService, EventLogService, ExperimentalService, ExtraLinksService, GridService, HumanInTheLoopService, ImportErrorService, JobService, LoginService, MonitorService, PluginService, PoolService, ProviderService, StructureService, TaskInstanceService, TaskService, VariableService, VersionService, XcomService } from "../requests/services.gen"; import { DagRunState, DagWarningType } from "../requests/types.gen"; import * as Common from "./common"; /** @@ -1172,6 +1172,45 @@ export const ensureUseDagVersionServiceGetDagVersionsData = (queryClient: QueryC versionNumber?: number; }) => queryClient.ensureQueryData({ queryKey: Common.UseDagVersionServiceGetDagVersionsKeyFn({ bundleName, bundleVersion, dagId, limit, offset, orderBy, versionNumber }), queryFn: () => DagVersionService.getDagVersions({ bundleName, bundleVersion, dagId, limit, offset, orderBy, versionNumber }) }); /** +* Get Hitl Detail +* Get a Human-in-the-loop detail of a specific task instance. +* @param data The data for the request. +* @param data.dagId +* @param data.dagRunId +* @param data.taskId +* @returns HITLDetail Successful Response +* @throws ApiError +*/ +export const ensureUseHumanInTheLoopServiceGetHitlDetailData = (queryClient: QueryClient, { dagId, dagRunId, taskId }: { + dagId: string; + dagRunId: string; + taskId: string; +}) => queryClient.ensureQueryData({ queryKey: Common.UseHumanInTheLoopServiceGetHitlDetailKeyFn({ dagId, dagRunId, taskId }), queryFn: () => HumanInTheLoopService.getHitlDetail({ dagId, dagRunId, taskId }) }); +/** +* Get Mapped Ti Hitl Detail +* Get a Human-in-the-loop detail of a specific task instance. +* @param data The data for the request. +* @param data.dagId +* @param data.dagRunId +* @param data.taskId +* @param data.mapIndex +* @returns HITLDetail Successful Response +* @throws ApiError +*/ +export const ensureUseHumanInTheLoopServiceGetMappedTiHitlDetailData = (queryClient: QueryClient, { dagId, dagRunId, mapIndex, taskId }: { + dagId: string; + dagRunId: string; + mapIndex: number; + taskId: string; +}) => queryClient.ensureQueryData({ queryKey: Common.UseHumanInTheLoopServiceGetMappedTiHitlDetailKeyFn({ dagId, dagRunId, mapIndex, taskId }), queryFn: () => HumanInTheLoopService.getMappedTiHitlDetail({ dagId, dagRunId, mapIndex, taskId }) }); +/** +* Get Hitl Details +* Get Human-in-the-loop details. +* @returns HITLDetailCollection Successful Response +* @throws ApiError +*/ +export const ensureUseHumanInTheLoopServiceGetHitlDetailsData = (queryClient: QueryClient) => queryClient.ensureQueryData({ queryKey: Common.UseHumanInTheLoopServiceGetHitlDetailsKeyFn(), queryFn: () => HumanInTheLoopService.getHitlDetails() }); +/** * Get Health * @returns HealthInfoResponse Successful Response * @throws ApiError diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts index d220cf4d19589..2eab4c35b3ef4 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts @@ -1,7 +1,7 @@ // generated with @7nohe/openapi-react-query-codegen@1.6.2 import { type QueryClient } from "@tanstack/react-query"; -import { AssetService, AuthLinksService, BackfillService, CalendarService, ConfigService, ConnectionService, DagReportService, DagRunService, DagService, DagSourceService, DagStatsService, DagVersionService, DagWarningService, DashboardService, DependenciesService, EventLogService, ExperimentalService, ExtraLinksService, GridService, ImportErrorService, JobService, LoginService, MonitorService, PluginService, PoolService, ProviderService, StructureService, TaskInstanceService, TaskService, VariableService, VersionService, XcomService } from "../requests/services.gen"; +import { AssetService, AuthLinksService, BackfillService, CalendarService, ConfigService, ConnectionService, DagReportService, DagRunService, DagService, DagSourceService, DagStatsService, DagVersionService, DagWarningService, DashboardService, DependenciesService, EventLogService, ExperimentalService, ExtraLinksService, GridService, HumanInTheLoopService, ImportErrorService, JobService, LoginService, MonitorService, PluginService, PoolService, ProviderService, StructureService, TaskInstanceService, TaskService, VariableService, VersionService, XcomService } from "../requests/services.gen"; import { DagRunState, DagWarningType } from "../requests/types.gen"; import * as Common from "./common"; /** @@ -1172,6 +1172,45 @@ export const prefetchUseDagVersionServiceGetDagVersions = (queryClient: QueryCli versionNumber?: number; }) => queryClient.prefetchQuery({ queryKey: Common.UseDagVersionServiceGetDagVersionsKeyFn({ bundleName, bundleVersion, dagId, limit, offset, orderBy, versionNumber }), queryFn: () => DagVersionService.getDagVersions({ bundleName, bundleVersion, dagId, limit, offset, orderBy, versionNumber }) }); /** +* Get Hitl Detail +* Get a Human-in-the-loop detail of a specific task instance. +* @param data The data for the request. +* @param data.dagId +* @param data.dagRunId +* @param data.taskId +* @returns HITLDetail Successful Response +* @throws ApiError +*/ +export const prefetchUseHumanInTheLoopServiceGetHitlDetail = (queryClient: QueryClient, { dagId, dagRunId, taskId }: { + dagId: string; + dagRunId: string; + taskId: string; +}) => queryClient.prefetchQuery({ queryKey: Common.UseHumanInTheLoopServiceGetHitlDetailKeyFn({ dagId, dagRunId, taskId }), queryFn: () => HumanInTheLoopService.getHitlDetail({ dagId, dagRunId, taskId }) }); +/** +* Get Mapped Ti Hitl Detail +* Get a Human-in-the-loop detail of a specific task instance. +* @param data The data for the request. +* @param data.dagId +* @param data.dagRunId +* @param data.taskId +* @param data.mapIndex +* @returns HITLDetail Successful Response +* @throws ApiError +*/ +export const prefetchUseHumanInTheLoopServiceGetMappedTiHitlDetail = (queryClient: QueryClient, { dagId, dagRunId, mapIndex, taskId }: { + dagId: string; + dagRunId: string; + mapIndex: number; + taskId: string; +}) => queryClient.prefetchQuery({ queryKey: Common.UseHumanInTheLoopServiceGetMappedTiHitlDetailKeyFn({ dagId, dagRunId, mapIndex, taskId }), queryFn: () => HumanInTheLoopService.getMappedTiHitlDetail({ dagId, dagRunId, mapIndex, taskId }) }); +/** +* Get Hitl Details +* Get Human-in-the-loop details. +* @returns HITLDetailCollection Successful Response +* @throws ApiError +*/ +export const prefetchUseHumanInTheLoopServiceGetHitlDetails = (queryClient: QueryClient) => queryClient.prefetchQuery({ queryKey: Common.UseHumanInTheLoopServiceGetHitlDetailsKeyFn(), queryFn: () => HumanInTheLoopService.getHitlDetails() }); +/** * Get Health * @returns HealthInfoResponse Successful Response * @throws ApiError diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts index 8a7ffc0525158..f47a175614451 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts @@ -1,8 +1,8 @@ // generated with @7nohe/openapi-react-query-codegen@1.6.2 import { UseMutationOptions, UseQueryOptions, useMutation, useQuery } from "@tanstack/react-query"; -import { AssetService, AuthLinksService, BackfillService, CalendarService, ConfigService, ConnectionService, DagParsingService, DagReportService, DagRunService, DagService, DagSourceService, DagStatsService, DagVersionService, DagWarningService, DashboardService, DependenciesService, EventLogService, ExperimentalService, ExtraLinksService, GridService, ImportErrorService, JobService, LoginService, MonitorService, PluginService, PoolService, ProviderService, StructureService, TaskInstanceService, TaskService, VariableService, VersionService, XcomService } from "../requests/services.gen"; -import { BackfillPostBody, BulkBody_BulkTaskInstanceBody_, BulkBody_ConnectionBody_, BulkBody_PoolBody_, BulkBody_VariableBody_, ClearTaskInstancesBody, ConnectionBody, CreateAssetEventsBody, DAGPatchBody, DAGRunClearBody, DAGRunPatchBody, DAGRunsBatchBody, DagRunState, DagWarningType, PatchTaskInstanceBody, PoolBody, PoolPatchBody, TaskInstancesBatchBody, TriggerDAGRunPostBody, VariableBody, XComCreateBody, XComUpdateBody } from "../requests/types.gen"; +import { AssetService, AuthLinksService, BackfillService, CalendarService, ConfigService, ConnectionService, DagParsingService, DagReportService, DagRunService, DagService, DagSourceService, DagStatsService, DagVersionService, DagWarningService, DashboardService, DependenciesService, EventLogService, ExperimentalService, ExtraLinksService, GridService, HumanInTheLoopService, ImportErrorService, JobService, LoginService, MonitorService, PluginService, PoolService, ProviderService, StructureService, TaskInstanceService, TaskService, VariableService, VersionService, XcomService } from "../requests/services.gen"; +import { BackfillPostBody, BulkBody_BulkTaskInstanceBody_, BulkBody_ConnectionBody_, BulkBody_PoolBody_, BulkBody_VariableBody_, ClearTaskInstancesBody, ConnectionBody, CreateAssetEventsBody, DAGPatchBody, DAGRunClearBody, DAGRunPatchBody, DAGRunsBatchBody, DagRunState, DagWarningType, PatchTaskInstanceBody, PoolBody, PoolPatchBody, TaskInstancesBatchBody, TriggerDAGRunPostBody, UpdateHITLDetailPayload, VariableBody, XComCreateBody, XComUpdateBody } from "../requests/types.gen"; import * as Common from "./common"; /** * Get Assets @@ -1172,6 +1172,45 @@ export const useDagVersionServiceGetDagVersions = , "queryKey" | "queryFn">) => useQuery({ queryKey: Common.UseDagVersionServiceGetDagVersionsKeyFn({ bundleName, bundleVersion, dagId, limit, offset, orderBy, versionNumber }, queryKey), queryFn: () => DagVersionService.getDagVersions({ bundleName, bundleVersion, dagId, limit, offset, orderBy, versionNumber }) as TData, ...options }); /** +* Get Hitl Detail +* Get a Human-in-the-loop detail of a specific task instance. +* @param data The data for the request. +* @param data.dagId +* @param data.dagRunId +* @param data.taskId +* @returns HITLDetail Successful Response +* @throws ApiError +*/ +export const useHumanInTheLoopServiceGetHitlDetail = = unknown[]>({ dagId, dagRunId, taskId }: { + dagId: string; + dagRunId: string; + taskId: string; +}, queryKey?: TQueryKey, options?: Omit, "queryKey" | "queryFn">) => useQuery({ queryKey: Common.UseHumanInTheLoopServiceGetHitlDetailKeyFn({ dagId, dagRunId, taskId }, queryKey), queryFn: () => HumanInTheLoopService.getHitlDetail({ dagId, dagRunId, taskId }) as TData, ...options }); +/** +* Get Mapped Ti Hitl Detail +* Get a Human-in-the-loop detail of a specific task instance. +* @param data The data for the request. +* @param data.dagId +* @param data.dagRunId +* @param data.taskId +* @param data.mapIndex +* @returns HITLDetail Successful Response +* @throws ApiError +*/ +export const useHumanInTheLoopServiceGetMappedTiHitlDetail = = unknown[]>({ dagId, dagRunId, mapIndex, taskId }: { + dagId: string; + dagRunId: string; + mapIndex: number; + taskId: string; +}, queryKey?: TQueryKey, options?: Omit, "queryKey" | "queryFn">) => useQuery({ queryKey: Common.UseHumanInTheLoopServiceGetMappedTiHitlDetailKeyFn({ dagId, dagRunId, mapIndex, taskId }, queryKey), queryFn: () => HumanInTheLoopService.getMappedTiHitlDetail({ dagId, dagRunId, mapIndex, taskId }) as TData, ...options }); +/** +* Get Hitl Details +* Get Human-in-the-loop details. +* @returns HITLDetailCollection Successful Response +* @throws ApiError +*/ +export const useHumanInTheLoopServiceGetHitlDetails = = unknown[]>(queryKey?: TQueryKey, options?: Omit, "queryKey" | "queryFn">) => useQuery({ queryKey: Common.UseHumanInTheLoopServiceGetHitlDetailsKeyFn(queryKey), queryFn: () => HumanInTheLoopService.getHitlDetails() as TData, ...options }); +/** * Get Health * @returns HealthInfoResponse Successful Response * @throws ApiError @@ -1989,6 +2028,53 @@ export const useVariableServiceBulkVariables = ({ mutationFn: ({ requestBody }) => VariableService.bulkVariables({ requestBody }) as unknown as Promise, ...options }); /** +* Update Hitl Detail +* Update a Human-in-the-loop detail. +* @param data The data for the request. +* @param data.dagId +* @param data.dagRunId +* @param data.taskId +* @param data.requestBody +* @returns HITLDetailResponse Successful Response +* @throws ApiError +*/ +export const useHumanInTheLoopServiceUpdateHitlDetail = (options?: Omit, "mutationFn">) => useMutation({ mutationFn: ({ dagId, dagRunId, requestBody, taskId }) => HumanInTheLoopService.updateHitlDetail({ dagId, dagRunId, requestBody, taskId }) as unknown as Promise, ...options }); +/** +* Update Mapped Ti Hitl Detail +* Update a Human-in-the-loop detail. +* @param data The data for the request. +* @param data.dagId +* @param data.dagRunId +* @param data.taskId +* @param data.mapIndex +* @param data.requestBody +* @returns HITLDetailResponse Successful Response +* @throws ApiError +*/ +export const useHumanInTheLoopServiceUpdateMappedTiHitlDetail = (options?: Omit, "mutationFn">) => useMutation({ mutationFn: ({ dagId, dagRunId, mapIndex, requestBody, taskId }) => HumanInTheLoopService.updateMappedTiHitlDetail({ dagId, dagRunId, mapIndex, requestBody, taskId }) as unknown as Promise, ...options }); +/** * Delete Asset Queued Events * Delete queued asset events for an asset. * @param data The data for the request. diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts index 57f12caea7517..2f9e37e78d6c2 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts @@ -1,7 +1,7 @@ // generated with @7nohe/openapi-react-query-codegen@1.6.2 import { UseQueryOptions, useSuspenseQuery } from "@tanstack/react-query"; -import { AssetService, AuthLinksService, BackfillService, CalendarService, ConfigService, ConnectionService, DagReportService, DagRunService, DagService, DagSourceService, DagStatsService, DagVersionService, DagWarningService, DashboardService, DependenciesService, EventLogService, ExperimentalService, ExtraLinksService, GridService, ImportErrorService, JobService, LoginService, MonitorService, PluginService, PoolService, ProviderService, StructureService, TaskInstanceService, TaskService, VariableService, VersionService, XcomService } from "../requests/services.gen"; +import { AssetService, AuthLinksService, BackfillService, CalendarService, ConfigService, ConnectionService, DagReportService, DagRunService, DagService, DagSourceService, DagStatsService, DagVersionService, DagWarningService, DashboardService, DependenciesService, EventLogService, ExperimentalService, ExtraLinksService, GridService, HumanInTheLoopService, ImportErrorService, JobService, LoginService, MonitorService, PluginService, PoolService, ProviderService, StructureService, TaskInstanceService, TaskService, VariableService, VersionService, XcomService } from "../requests/services.gen"; import { DagRunState, DagWarningType } from "../requests/types.gen"; import * as Common from "./common"; /** @@ -1172,6 +1172,45 @@ export const useDagVersionServiceGetDagVersionsSuspense = , "queryKey" | "queryFn">) => useSuspenseQuery({ queryKey: Common.UseDagVersionServiceGetDagVersionsKeyFn({ bundleName, bundleVersion, dagId, limit, offset, orderBy, versionNumber }, queryKey), queryFn: () => DagVersionService.getDagVersions({ bundleName, bundleVersion, dagId, limit, offset, orderBy, versionNumber }) as TData, ...options }); /** +* Get Hitl Detail +* Get a Human-in-the-loop detail of a specific task instance. +* @param data The data for the request. +* @param data.dagId +* @param data.dagRunId +* @param data.taskId +* @returns HITLDetail Successful Response +* @throws ApiError +*/ +export const useHumanInTheLoopServiceGetHitlDetailSuspense = = unknown[]>({ dagId, dagRunId, taskId }: { + dagId: string; + dagRunId: string; + taskId: string; +}, queryKey?: TQueryKey, options?: Omit, "queryKey" | "queryFn">) => useSuspenseQuery({ queryKey: Common.UseHumanInTheLoopServiceGetHitlDetailKeyFn({ dagId, dagRunId, taskId }, queryKey), queryFn: () => HumanInTheLoopService.getHitlDetail({ dagId, dagRunId, taskId }) as TData, ...options }); +/** +* Get Mapped Ti Hitl Detail +* Get a Human-in-the-loop detail of a specific task instance. +* @param data The data for the request. +* @param data.dagId +* @param data.dagRunId +* @param data.taskId +* @param data.mapIndex +* @returns HITLDetail Successful Response +* @throws ApiError +*/ +export const useHumanInTheLoopServiceGetMappedTiHitlDetailSuspense = = unknown[]>({ dagId, dagRunId, mapIndex, taskId }: { + dagId: string; + dagRunId: string; + mapIndex: number; + taskId: string; +}, queryKey?: TQueryKey, options?: Omit, "queryKey" | "queryFn">) => useSuspenseQuery({ queryKey: Common.UseHumanInTheLoopServiceGetMappedTiHitlDetailKeyFn({ dagId, dagRunId, mapIndex, taskId }, queryKey), queryFn: () => HumanInTheLoopService.getMappedTiHitlDetail({ dagId, dagRunId, mapIndex, taskId }) as TData, ...options }); +/** +* Get Hitl Details +* Get Human-in-the-loop details. +* @returns HITLDetailCollection Successful Response +* @throws ApiError +*/ +export const useHumanInTheLoopServiceGetHitlDetailsSuspense = = unknown[]>(queryKey?: TQueryKey, options?: Omit, "queryKey" | "queryFn">) => useSuspenseQuery({ queryKey: Common.UseHumanInTheLoopServiceGetHitlDetailsKeyFn(queryKey), queryFn: () => HumanInTheLoopService.getHitlDetails() as TData, ...options }); +/** * Get Health * @returns HealthInfoResponse Successful Response * @throws ApiError diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts index 5a46ac9e847bd..2e31d61df5ae1 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -3408,6 +3408,162 @@ export const $FastAPIRootMiddlewareResponse = { description: 'Serializer for Plugin FastAPI root middleware responses.' } as const; +export const $HITLDetail = { + properties: { + ti_id: { + type: 'string', + title: 'Ti Id' + }, + options: { + items: { + type: 'string' + }, + type: 'array', + title: 'Options' + }, + subject: { + type: 'string', + title: 'Subject' + }, + body: { + anyOf: [ + { + type: 'string' + }, + { + type: 'null' + } + ], + title: 'Body' + }, + defaults: { + anyOf: [ + { + items: { + type: 'string' + }, + type: 'array' + }, + { + type: 'null' + } + ], + title: 'Defaults' + }, + multiple: { + type: 'boolean', + title: 'Multiple', + default: false + }, + params: { + additionalProperties: true, + type: 'object', + title: 'Params' + }, + user_id: { + anyOf: [ + { + type: 'string' + }, + { + type: 'null' + } + ], + title: 'User Id' + }, + response_at: { + anyOf: [ + { + type: 'string', + format: 'date-time' + }, + { + type: 'null' + } + ], + title: 'Response At' + }, + chosen_options: { + anyOf: [ + { + items: { + type: 'string' + }, + type: 'array' + }, + { + type: 'null' + } + ], + title: 'Chosen Options' + }, + params_input: { + additionalProperties: true, + type: 'object', + title: 'Params Input' + }, + response_received: { + type: 'boolean', + title: 'Response Received', + default: false + } + }, + type: 'object', + required: ['ti_id', 'options', 'subject'], + title: 'HITLDetail', + description: 'Schema for Human-in-the-loop detail.' +} as const; + +export const $HITLDetailCollection = { + properties: { + hitl_details: { + items: { + '$ref': '#/components/schemas/HITLDetail' + }, + type: 'array', + title: 'Hitl Details' + }, + total_entries: { + type: 'integer', + title: 'Total Entries' + } + }, + type: 'object', + required: ['hitl_details', 'total_entries'], + title: 'HITLDetailCollection', + description: 'Schema for a collection of Human-in-the-loop details.' +} as const; + +export const $HITLDetailResponse = { + properties: { + user_id: { + type: 'string', + title: 'User Id' + }, + response_at: { + type: 'string', + format: 'date-time', + title: 'Response At' + }, + chosen_options: { + items: { + type: 'string' + }, + type: 'array', + title: 'Chosen Options' + }, + params_input: { + additionalProperties: true, + type: 'object', + title: 'Params Input' + } + }, + type: 'object', + required: ['user_id', 'response_at', 'chosen_options'], + title: 'HITLDetailResponse', + description: 'Response of updating a Human-in-the-loop detail.' +} as const; + export const $HTTPExceptionResponse = { properties: { detail: { @@ -5696,6 +5852,27 @@ export const $TriggererInfoResponse = { description: 'Triggerer info serializer for responses.' } as const; +export const $UpdateHITLDetailPayload = { + properties: { + chosen_options: { + items: { + type: 'string' + }, + type: 'array', + title: 'Chosen Options' + }, + params_input: { + additionalProperties: true, + type: 'object', + title: 'Params Input' + } + }, + type: 'object', + required: ['chosen_options'], + title: 'UpdateHITLDetailPayload', + description: 'Schema for updating the content of a Human-in-the-loop detail.' +} as const; + export const $ValidationError = { properties: { loc: { diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts index 321a708f6b37f..b935e2042366d 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts @@ -3,7 +3,7 @@ import type { CancelablePromise } from './core/CancelablePromise'; import { OpenAPI } from './core/OpenAPI'; import { request as __request } from './core/request'; -import type { GetAssetsData, GetAssetsResponse, GetAssetAliasesData, GetAssetAliasesResponse, GetAssetAliasData, GetAssetAliasResponse, GetAssetEventsData, GetAssetEventsResponse, CreateAssetEventData, CreateAssetEventResponse, MaterializeAssetData, MaterializeAssetResponse, GetAssetQueuedEventsData, GetAssetQueuedEventsResponse, DeleteAssetQueuedEventsData, DeleteAssetQueuedEventsResponse, GetAssetData, GetAssetResponse, GetDagAssetQueuedEventsData, GetDagAssetQueuedEventsResponse, DeleteDagAssetQueuedEventsData, DeleteDagAssetQueuedEventsResponse, GetDagAssetQueuedEventData, GetDagAssetQueuedEventResponse, DeleteDagAssetQueuedEventData, DeleteDagAssetQueuedEventResponse, NextRunAssetsData, NextRunAssetsResponse, ListBackfillsData, ListBackfillsResponse, CreateBackfillData, CreateBackfillResponse, GetBackfillData, GetBackfillResponse, PauseBackfillData, PauseBackfillResponse, UnpauseBackfillData, UnpauseBackfillResponse, CancelBackfillData, CancelBackfillResponse, CreateBackfillDryRunData, CreateBackfillDryRunResponse, ListBackfillsUiData, ListBackfillsUiResponse, DeleteConnectionData, DeleteConnectionResponse, GetConnectionData, GetConnectionResponse, PatchConnectionData, PatchConnectionResponse, GetConnectionsData, GetConnectionsResponse, PostConnectionData, PostConnectionResponse, BulkConnectionsData, BulkConnectionsResponse, TestConnectionData, TestConnectionResponse, CreateDefaultConnectionsResponse, HookMetaDataResponse, GetDagRunData, GetDagRunResponse, DeleteDagRunData, DeleteDagRunResponse, PatchDagRunData, PatchDagRunResponse, GetUpstreamAssetEventsData, GetUpstreamAssetEventsResponse, ClearDagRunData, ClearDagRunResponse, GetDagRunsData, GetDagRunsResponse, TriggerDagRunData, TriggerDagRunResponse, WaitDagRunUntilFinishedData, WaitDagRunUntilFinishedResponse, GetListDagRunsBatchData, GetListDagRunsBatchResponse, GetDagSourceData, GetDagSourceResponse, GetDagStatsData, GetDagStatsResponse, GetDagReportsData, GetDagReportsResponse, GetConfigData, GetConfigResponse, GetConfigValueData, GetConfigValueResponse, GetConfigsResponse, ListDagWarningsData, ListDagWarningsResponse, GetDagsData, GetDagsResponse, PatchDagsData, PatchDagsResponse, GetDagData, GetDagResponse, PatchDagData, PatchDagResponse, DeleteDagData, DeleteDagResponse, GetDagDetailsData, GetDagDetailsResponse, FavoriteDagData, FavoriteDagResponse, UnfavoriteDagData, UnfavoriteDagResponse, GetDagTagsData, GetDagTagsResponse, GetDagsUiData, GetDagsUiResponse, GetEventLogData, GetEventLogResponse, GetEventLogsData, GetEventLogsResponse, GetExtraLinksData, GetExtraLinksResponse, GetTaskInstanceData, GetTaskInstanceResponse, PatchTaskInstanceData, PatchTaskInstanceResponse, DeleteTaskInstanceData, DeleteTaskInstanceResponse, GetMappedTaskInstancesData, GetMappedTaskInstancesResponse, GetTaskInstanceDependenciesByMapIndexData, GetTaskInstanceDependenciesByMapIndexResponse, GetTaskInstanceDependenciesData, GetTaskInstanceDependenciesResponse, GetTaskInstanceTriesData, GetTaskInstanceTriesResponse, GetMappedTaskInstanceTriesData, GetMappedTaskInstanceTriesResponse, GetMappedTaskInstanceData, GetMappedTaskInstanceResponse, PatchTaskInstanceByMapIndexData, PatchTaskInstanceByMapIndexResponse, GetTaskInstancesData, GetTaskInstancesResponse, BulkTaskInstancesData, BulkTaskInstancesResponse, GetTaskInstancesBatchData, GetTaskInstancesBatchResponse, GetTaskInstanceTryDetailsData, GetTaskInstanceTryDetailsResponse, GetMappedTaskInstanceTryDetailsData, GetMappedTaskInstanceTryDetailsResponse, PostClearTaskInstancesData, PostClearTaskInstancesResponse, PatchTaskInstanceDryRunByMapIndexData, PatchTaskInstanceDryRunByMapIndexResponse, PatchTaskInstanceDryRunData, PatchTaskInstanceDryRunResponse, GetLogData, GetLogResponse, GetExternalLogUrlData, GetExternalLogUrlResponse, GetImportErrorData, GetImportErrorResponse, GetImportErrorsData, GetImportErrorsResponse, GetJobsData, GetJobsResponse, GetPluginsData, GetPluginsResponse, ImportErrorsResponse, DeletePoolData, DeletePoolResponse, GetPoolData, GetPoolResponse, PatchPoolData, PatchPoolResponse, GetPoolsData, GetPoolsResponse, PostPoolData, PostPoolResponse, BulkPoolsData, BulkPoolsResponse, GetProvidersData, GetProvidersResponse, GetXcomEntryData, GetXcomEntryResponse, UpdateXcomEntryData, UpdateXcomEntryResponse, GetXcomEntriesData, GetXcomEntriesResponse, CreateXcomEntryData, CreateXcomEntryResponse, GetTasksData, GetTasksResponse, GetTaskData, GetTaskResponse, DeleteVariableData, DeleteVariableResponse, GetVariableData, GetVariableResponse, PatchVariableData, PatchVariableResponse, GetVariablesData, GetVariablesResponse, PostVariableData, PostVariableResponse, BulkVariablesData, BulkVariablesResponse, ReparseDagFileData, ReparseDagFileResponse, GetDagVersionData, GetDagVersionResponse, GetDagVersionsData, GetDagVersionsResponse, GetHealthResponse, GetVersionResponse, LoginData, LoginResponse, LogoutData, LogoutResponse, GetAuthMenusResponse, GetDependenciesData, GetDependenciesResponse, HistoricalMetricsData, HistoricalMetricsResponse, DagStatsResponse2, StructureDataData, StructureDataResponse2, GetDagStructureData, GetDagStructureResponse, GetGridRunsData, GetGridRunsResponse, GetGridTiSummariesData, GetGridTiSummariesResponse, GetLatestRunData, GetLatestRunResponse, GetCalendarData, GetCalendarResponse } from './types.gen'; +import type { GetAssetsData, GetAssetsResponse, GetAssetAliasesData, GetAssetAliasesResponse, GetAssetAliasData, GetAssetAliasResponse, GetAssetEventsData, GetAssetEventsResponse, CreateAssetEventData, CreateAssetEventResponse, MaterializeAssetData, MaterializeAssetResponse, GetAssetQueuedEventsData, GetAssetQueuedEventsResponse, DeleteAssetQueuedEventsData, DeleteAssetQueuedEventsResponse, GetAssetData, GetAssetResponse, GetDagAssetQueuedEventsData, GetDagAssetQueuedEventsResponse, DeleteDagAssetQueuedEventsData, DeleteDagAssetQueuedEventsResponse, GetDagAssetQueuedEventData, GetDagAssetQueuedEventResponse, DeleteDagAssetQueuedEventData, DeleteDagAssetQueuedEventResponse, NextRunAssetsData, NextRunAssetsResponse, ListBackfillsData, ListBackfillsResponse, CreateBackfillData, CreateBackfillResponse, GetBackfillData, GetBackfillResponse, PauseBackfillData, PauseBackfillResponse, UnpauseBackfillData, UnpauseBackfillResponse, CancelBackfillData, CancelBackfillResponse, CreateBackfillDryRunData, CreateBackfillDryRunResponse, ListBackfillsUiData, ListBackfillsUiResponse, DeleteConnectionData, DeleteConnectionResponse, GetConnectionData, GetConnectionResponse, PatchConnectionData, PatchConnectionResponse, GetConnectionsData, GetConnectionsResponse, PostConnectionData, PostConnectionResponse, BulkConnectionsData, BulkConnectionsResponse, TestConnectionData, TestConnectionResponse, CreateDefaultConnectionsResponse, HookMetaDataResponse, GetDagRunData, GetDagRunResponse, DeleteDagRunData, DeleteDagRunResponse, PatchDagRunData, PatchDagRunResponse, GetUpstreamAssetEventsData, GetUpstreamAssetEventsResponse, ClearDagRunData, ClearDagRunResponse, GetDagRunsData, GetDagRunsResponse, TriggerDagRunData, TriggerDagRunResponse, WaitDagRunUntilFinishedData, WaitDagRunUntilFinishedResponse, GetListDagRunsBatchData, GetListDagRunsBatchResponse, GetDagSourceData, GetDagSourceResponse, GetDagStatsData, GetDagStatsResponse, GetDagReportsData, GetDagReportsResponse, GetConfigData, GetConfigResponse, GetConfigValueData, GetConfigValueResponse, GetConfigsResponse, ListDagWarningsData, ListDagWarningsResponse, GetDagsData, GetDagsResponse, PatchDagsData, PatchDagsResponse, GetDagData, GetDagResponse, PatchDagData, PatchDagResponse, DeleteDagData, DeleteDagResponse, GetDagDetailsData, GetDagDetailsResponse, FavoriteDagData, FavoriteDagResponse, UnfavoriteDagData, UnfavoriteDagResponse, GetDagTagsData, GetDagTagsResponse, GetDagsUiData, GetDagsUiResponse, GetEventLogData, GetEventLogResponse, GetEventLogsData, GetEventLogsResponse, GetExtraLinksData, GetExtraLinksResponse, GetTaskInstanceData, GetTaskInstanceResponse, PatchTaskInstanceData, PatchTaskInstanceResponse, DeleteTaskInstanceData, DeleteTaskInstanceResponse, GetMappedTaskInstancesData, GetMappedTaskInstancesResponse, GetTaskInstanceDependenciesByMapIndexData, GetTaskInstanceDependenciesByMapIndexResponse, GetTaskInstanceDependenciesData, GetTaskInstanceDependenciesResponse, GetTaskInstanceTriesData, GetTaskInstanceTriesResponse, GetMappedTaskInstanceTriesData, GetMappedTaskInstanceTriesResponse, GetMappedTaskInstanceData, GetMappedTaskInstanceResponse, PatchTaskInstanceByMapIndexData, PatchTaskInstanceByMapIndexResponse, GetTaskInstancesData, GetTaskInstancesResponse, BulkTaskInstancesData, BulkTaskInstancesResponse, GetTaskInstancesBatchData, GetTaskInstancesBatchResponse, GetTaskInstanceTryDetailsData, GetTaskInstanceTryDetailsResponse, GetMappedTaskInstanceTryDetailsData, GetMappedTaskInstanceTryDetailsResponse, PostClearTaskInstancesData, PostClearTaskInstancesResponse, PatchTaskInstanceDryRunByMapIndexData, PatchTaskInstanceDryRunByMapIndexResponse, PatchTaskInstanceDryRunData, PatchTaskInstanceDryRunResponse, GetLogData, GetLogResponse, GetExternalLogUrlData, GetExternalLogUrlResponse, GetImportErrorData, GetImportErrorResponse, GetImportErrorsData, GetImportErrorsResponse, GetJobsData, GetJobsResponse, GetPluginsData, GetPluginsResponse, ImportErrorsResponse, DeletePoolData, DeletePoolResponse, GetPoolData, GetPoolResponse, PatchPoolData, PatchPoolResponse, GetPoolsData, GetPoolsResponse, PostPoolData, PostPoolResponse, BulkPoolsData, BulkPoolsResponse, GetProvidersData, GetProvidersResponse, GetXcomEntryData, GetXcomEntryResponse, UpdateXcomEntryData, UpdateXcomEntryResponse, GetXcomEntriesData, GetXcomEntriesResponse, CreateXcomEntryData, CreateXcomEntryResponse, GetTasksData, GetTasksResponse, GetTaskData, GetTaskResponse, DeleteVariableData, DeleteVariableResponse, GetVariableData, GetVariableResponse, PatchVariableData, PatchVariableResponse, GetVariablesData, GetVariablesResponse, PostVariableData, PostVariableResponse, BulkVariablesData, BulkVariablesResponse, ReparseDagFileData, ReparseDagFileResponse, GetDagVersionData, GetDagVersionResponse, GetDagVersionsData, GetDagVersionsResponse, UpdateHitlDetailData, UpdateHitlDetailResponse, GetHitlDetailData, GetHitlDetailResponse, UpdateMappedTiHitlDetailData, UpdateMappedTiHitlDetailResponse, GetMappedTiHitlDetailData, GetMappedTiHitlDetailResponse, GetHitlDetailsResponse, GetHealthResponse, GetVersionResponse, LoginData, LoginResponse, LogoutData, LogoutResponse, GetAuthMenusResponse, GetDependenciesData, GetDependenciesResponse, HistoricalMetricsData, HistoricalMetricsResponse, DagStatsResponse2, StructureDataData, StructureDataResponse2, GetDagStructureData, GetDagStructureResponse, GetGridRunsData, GetGridRunsResponse, GetGridTiSummariesData, GetGridTiSummariesResponse, GetLatestRunData, GetLatestRunResponse, GetCalendarData, GetCalendarResponse } from './types.gen'; export class AssetService { /** @@ -3360,6 +3360,150 @@ export class DagVersionService { } +export class HumanInTheLoopService { + /** + * Update Hitl Detail + * Update a Human-in-the-loop detail. + * @param data The data for the request. + * @param data.dagId + * @param data.dagRunId + * @param data.taskId + * @param data.requestBody + * @returns HITLDetailResponse Successful Response + * @throws ApiError + */ + public static updateHitlDetail(data: UpdateHitlDetailData): CancelablePromise { + return __request(OpenAPI, { + method: 'PATCH', + url: '/api/v2/hitl-details/{dag_id}/{dag_run_id}/{task_id}', + path: { + dag_id: data.dagId, + dag_run_id: data.dagRunId, + task_id: data.taskId + }, + body: data.requestBody, + mediaType: 'application/json', + errors: { + 401: 'Unauthorized', + 403: 'Forbidden', + 404: 'Not Found', + 409: 'Conflict', + 422: 'Validation Error' + } + }); + } + + /** + * Get Hitl Detail + * Get a Human-in-the-loop detail of a specific task instance. + * @param data The data for the request. + * @param data.dagId + * @param data.dagRunId + * @param data.taskId + * @returns HITLDetail Successful Response + * @throws ApiError + */ + public static getHitlDetail(data: GetHitlDetailData): CancelablePromise { + return __request(OpenAPI, { + method: 'GET', + url: '/api/v2/hitl-details/{dag_id}/{dag_run_id}/{task_id}', + path: { + dag_id: data.dagId, + dag_run_id: data.dagRunId, + task_id: data.taskId + }, + errors: { + 401: 'Unauthorized', + 403: 'Forbidden', + 404: 'Not Found', + 422: 'Validation Error' + } + }); + } + + /** + * Update Mapped Ti Hitl Detail + * Update a Human-in-the-loop detail. + * @param data The data for the request. + * @param data.dagId + * @param data.dagRunId + * @param data.taskId + * @param data.mapIndex + * @param data.requestBody + * @returns HITLDetailResponse Successful Response + * @throws ApiError + */ + public static updateMappedTiHitlDetail(data: UpdateMappedTiHitlDetailData): CancelablePromise { + return __request(OpenAPI, { + method: 'PATCH', + url: '/api/v2/hitl-details/{dag_id}/{dag_run_id}/{task_id}/{map_index}', + path: { + dag_id: data.dagId, + dag_run_id: data.dagRunId, + task_id: data.taskId, + map_index: data.mapIndex + }, + body: data.requestBody, + mediaType: 'application/json', + errors: { + 401: 'Unauthorized', + 403: 'Forbidden', + 404: 'Not Found', + 409: 'Conflict', + 422: 'Validation Error' + } + }); + } + + /** + * Get Mapped Ti Hitl Detail + * Get a Human-in-the-loop detail of a specific task instance. + * @param data The data for the request. + * @param data.dagId + * @param data.dagRunId + * @param data.taskId + * @param data.mapIndex + * @returns HITLDetail Successful Response + * @throws ApiError + */ + public static getMappedTiHitlDetail(data: GetMappedTiHitlDetailData): CancelablePromise { + return __request(OpenAPI, { + method: 'GET', + url: '/api/v2/hitl-details/{dag_id}/{dag_run_id}/{task_id}/{map_index}', + path: { + dag_id: data.dagId, + dag_run_id: data.dagRunId, + task_id: data.taskId, + map_index: data.mapIndex + }, + errors: { + 401: 'Unauthorized', + 403: 'Forbidden', + 404: 'Not Found', + 422: 'Validation Error' + } + }); + } + + /** + * Get Hitl Details + * Get Human-in-the-loop details. + * @returns HITLDetailCollection Successful Response + * @throws ApiError + */ + public static getHitlDetails(): CancelablePromise { + return __request(OpenAPI, { + method: 'GET', + url: '/api/v2/hitl-details/', + errors: { + 401: 'Unauthorized', + 403: 'Forbidden' + } + }); + } + +} + export class MonitorService { /** * Get Health diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts index ed4e32db57f84..591ce7884373c 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts @@ -917,6 +917,48 @@ export type FastAPIRootMiddlewareResponse = { [key: string]: unknown | string; }; +/** + * Schema for Human-in-the-loop detail. + */ +export type HITLDetail = { + ti_id: string; + options: Array<(string)>; + subject: string; + body?: string | null; + defaults?: Array<(string)> | null; + multiple?: boolean; + params?: { + [key: string]: unknown; + }; + user_id?: string | null; + response_at?: string | null; + chosen_options?: Array<(string)> | null; + params_input?: { + [key: string]: unknown; + }; + response_received?: boolean; +}; + +/** + * Schema for a collection of Human-in-the-loop details. + */ +export type HITLDetailCollection = { + hitl_details: Array; + total_entries: number; +}; + +/** + * Response of updating a Human-in-the-loop detail. + */ +export type HITLDetailResponse = { + user_id: string; + response_at: string; + chosen_options: Array<(string)>; + params_input?: { + [key: string]: unknown; + }; +}; + /** * HTTPException Model used for error response. */ @@ -1429,6 +1471,16 @@ export type TriggererInfoResponse = { latest_triggerer_heartbeat: string | null; }; +/** + * Schema for updating the content of a Human-in-the-loop detail. + */ +export type UpdateHITLDetailPayload = { + chosen_options: Array<(string)>; + params_input?: { + [key: string]: unknown; + }; +}; + export type ValidationError = { loc: Array<(string | number)>; msg: string; @@ -2847,6 +2899,44 @@ export type GetDagVersionsData = { export type GetDagVersionsResponse = DAGVersionCollectionResponse; +export type UpdateHitlDetailData = { + dagId: string; + dagRunId: string; + requestBody: UpdateHITLDetailPayload; + taskId: string; +}; + +export type UpdateHitlDetailResponse = HITLDetailResponse; + +export type GetHitlDetailData = { + dagId: string; + dagRunId: string; + taskId: string; +}; + +export type GetHitlDetailResponse = HITLDetail; + +export type UpdateMappedTiHitlDetailData = { + dagId: string; + dagRunId: string; + mapIndex: number; + requestBody: UpdateHITLDetailPayload; + taskId: string; +}; + +export type UpdateMappedTiHitlDetailResponse = HITLDetailResponse; + +export type GetMappedTiHitlDetailData = { + dagId: string; + dagRunId: string; + mapIndex: number; + taskId: string; +}; + +export type GetMappedTiHitlDetailResponse = HITLDetail; + +export type GetHitlDetailsResponse = HITLDetailCollection; + export type GetHealthResponse = HealthInfoResponse; export type GetVersionResponse = VersionInfo; @@ -5793,6 +5883,136 @@ export type $OpenApiTs = { }; }; }; + '/api/v2/hitl-details/{dag_id}/{dag_run_id}/{task_id}': { + patch: { + req: UpdateHitlDetailData; + res: { + /** + * Successful Response + */ + 200: HITLDetailResponse; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Not Found + */ + 404: HTTPExceptionResponse; + /** + * Conflict + */ + 409: HTTPExceptionResponse; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; + get: { + req: GetHitlDetailData; + res: { + /** + * Successful Response + */ + 200: HITLDetail; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Not Found + */ + 404: HTTPExceptionResponse; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; + }; + '/api/v2/hitl-details/{dag_id}/{dag_run_id}/{task_id}/{map_index}': { + patch: { + req: UpdateMappedTiHitlDetailData; + res: { + /** + * Successful Response + */ + 200: HITLDetailResponse; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Not Found + */ + 404: HTTPExceptionResponse; + /** + * Conflict + */ + 409: HTTPExceptionResponse; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; + get: { + req: GetMappedTiHitlDetailData; + res: { + /** + * Successful Response + */ + 200: HITLDetail; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Not Found + */ + 404: HTTPExceptionResponse; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; + }; + '/api/v2/hitl-details/': { + get: { + res: { + /** + * Successful Response + */ + 200: HITLDetailCollection; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + }; + }; + }; '/api/v2/monitor/health': { get: { res: { diff --git a/airflow-core/src/airflow/utils/db.py b/airflow-core/src/airflow/utils/db.py index 8666deac458d4..c99972a8b1e96 100644 --- a/airflow-core/src/airflow/utils/db.py +++ b/airflow-core/src/airflow/utils/db.py @@ -93,7 +93,7 @@ class MappedClassProtocol(Protocol): "2.10.3": "5f2621c13b39", "3.0.0": "29ce7909c52b", "3.0.3": "fe199e1abd77", - "3.1.0": "5d3072c51bac", + "3.1.0": "40f7c30a228b", } diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_hitl.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_hitl.py new file mode 100644 index 0000000000000..3fa34a5779a37 --- /dev/null +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_hitl.py @@ -0,0 +1,391 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest +from sqlalchemy.orm import Session + +from tests_common.test_utils.db import AIRFLOW_V_3_1_PLUS + +if not AIRFLOW_V_3_1_PLUS: + pytest.skip("Human in the loop public API compatible with Airflow >= 3.0.1", allow_module_level=True) + +from datetime import datetime +from typing import TYPE_CHECKING, Any + +import time_machine + +from airflow.models.hitl import HITLDetail + +if TYPE_CHECKING: + from fastapi.testclient import TestClient + + from airflow.models.taskinstance import TaskInstance + + from tests_common.pytest_plugin import CreateTaskInstance + + +pytestmark = pytest.mark.db_test + +DAG_ID = "test_hitl_dag" + + +@pytest.fixture +def sample_ti(create_task_instance: CreateTaskInstance) -> TaskInstance: + return create_task_instance() + + +@pytest.fixture +def sample_ti_url_identifier(sample_ti: TaskInstance) -> str: + if TYPE_CHECKING: + assert sample_ti.task + + return f"{sample_ti.dag_id}/{sample_ti.run_id}/{sample_ti.task.task_id}" + + +@pytest.fixture +def sample_hitl_detail(sample_ti: TaskInstance, session: Session) -> HITLDetail: + hitl_detail_model = HITLDetail( + ti_id=sample_ti.id, + options=["Approve", "Reject"], + subject="This is subject", + body="this is body", + defaults=["Approve"], + multiple=False, + params={"input_1": 1}, + ) + session.add(hitl_detail_model) + session.commit() + + return hitl_detail_model + + +@pytest.fixture +def expected_ti_not_found_error_msg(sample_ti: TaskInstance) -> str: + if TYPE_CHECKING: + assert sample_ti.task + + return ( + f"The Task Instance with dag_id: `{sample_ti.dag_id}`," + f" run_id: `{sample_ti.run_id}`, task_id: `{sample_ti.task.task_id}`" + " and map_index: `None` was not found" + ) + + +@pytest.fixture +def expected_mapped_ti_not_found_error_msg(sample_ti: TaskInstance) -> str: + if TYPE_CHECKING: + assert sample_ti.task + + return ( + f"The Task Instance with dag_id: `{sample_ti.dag_id}`," + f" run_id: `{sample_ti.run_id}`, task_id: `{sample_ti.task.task_id}`" + " and map_index: `-1` was not found" + ) + + +@pytest.fixture +def expected_sample_hitl_detail_dict(sample_ti: TaskInstance) -> dict[str, Any]: + return { + "body": "this is body", + "defaults": ["Approve"], + "multiple": False, + "options": ["Approve", "Reject"], + "params": {"input_1": 1}, + "params_input": {}, + "response_at": None, + "chosen_options": None, + "response_received": False, + "subject": "This is subject", + "ti_id": sample_ti.id, + "user_id": None, + } + + +class TestUpdateHITLDetailEndpoint: + @time_machine.travel(datetime(2025, 7, 3, 0, 0, 0), tick=False) + @pytest.mark.usefixtures("sample_hitl_detail") + def test_should_respond_200_with_existing_response( + self, + test_client: TestClient, + sample_ti_url_identifier: str, + ) -> None: + response = test_client.patch( + f"/hitl-details/{sample_ti_url_identifier}", + json={"chosen_options": ["Approve"], "params_input": {"input_1": 2}}, + ) + + assert response.status_code == 200 + assert response.json() == { + "params_input": {"input_1": 2}, + "chosen_options": ["Approve"], + "user_id": "test", + "response_at": "2025-07-03T00:00:00Z", + } + + def test_should_respond_404( + self, + test_client: TestClient, + sample_ti_url_identifier: str, + expected_ti_not_found_error_msg: str, + ) -> None: + response = test_client.get(f"/hitl-details/{sample_ti_url_identifier}") + assert response.status_code == 404 + assert response.json() == {"detail": expected_ti_not_found_error_msg} + + @time_machine.travel(datetime(2025, 7, 3, 0, 0, 0), tick=False) + @pytest.mark.usefixtures("sample_hitl_detail") + def test_should_respond_409( + self, + test_client: TestClient, + sample_ti_url_identifier: str, + sample_ti: TaskInstance, + ) -> None: + response = test_client.patch( + f"/hitl-details/{sample_ti_url_identifier}", + json={"chosen_options": ["Approve"], "params_input": {"input_1": 2}}, + ) + + expected_response = { + "params_input": {"input_1": 2}, + "chosen_options": ["Approve"], + "user_id": "test", + "response_at": "2025-07-03T00:00:00Z", + } + assert response.status_code == 200 + assert response.json() == expected_response + + response = test_client.patch( + f"/hitl-details/{sample_ti_url_identifier}", + json={"chosen_options": ["Approve"], "params_input": {"input_1": 2}}, + ) + assert response.status_code == 409 + assert response.json() == { + "detail": ( + "Human-in-the-loop detail has already been updated for Task Instance " + f"with id {sample_ti.id} " + "and is not allowed to write again." + ) + } + + def test_should_respond_401( + self, + unauthenticated_test_client: TestClient, + sample_ti_url_identifier: str, + ) -> None: + response = unauthenticated_test_client.get(f"/hitl-details/{sample_ti_url_identifier}") + assert response.status_code == 401 + + def test_should_respond_403( + self, + unauthorized_test_client: TestClient, + sample_ti_url_identifier: str, + ) -> None: + response = unauthorized_test_client.get(f"/hitl-details/{sample_ti_url_identifier}") + assert response.status_code == 403 + + +class TestUpdateMappedTIHITLDetail: + @time_machine.travel(datetime(2025, 7, 3, 0, 0, 0), tick=False) + @pytest.mark.usefixtures("sample_hitl_detail") + def test_should_respond_200_with_existing_response( + self, + test_client: TestClient, + sample_ti_url_identifier: str, + ) -> None: + response = test_client.patch( + f"/hitl-details/{sample_ti_url_identifier}/-1", + json={"chosen_options": ["Approve"], "params_input": {"input_1": 2}}, + ) + + assert response.status_code == 200 + assert response.json() == { + "params_input": {"input_1": 2}, + "chosen_options": ["Approve"], + "user_id": "test", + "response_at": "2025-07-03T00:00:00Z", + } + + def test_should_respond_404( + self, + test_client: TestClient, + sample_ti_url_identifier: str, + expected_mapped_ti_not_found_error_msg: str, + ) -> None: + response = test_client.get(f"/hitl-details/{sample_ti_url_identifier}/-1") + assert response.status_code == 404 + assert response.json() == {"detail": expected_mapped_ti_not_found_error_msg} + + @time_machine.travel(datetime(2025, 7, 3, 0, 0, 0), tick=False) + @pytest.mark.usefixtures("sample_hitl_detail") + def test_should_respond_409( + self, + test_client: TestClient, + sample_ti_url_identifier: str, + sample_ti: TaskInstance, + ) -> None: + response = test_client.patch( + f"/hitl-details/{sample_ti_url_identifier}/-1", + json={"chosen_options": ["Approve"], "params_input": {"input_1": 2}}, + ) + + expected_response = { + "params_input": {"input_1": 2}, + "chosen_options": ["Approve"], + "user_id": "test", + "response_at": "2025-07-03T00:00:00Z", + } + assert response.status_code == 200 + assert response.json() == expected_response + + response = test_client.patch( + f"/hitl-details/{sample_ti_url_identifier}/-1", + json={"chosen_options": ["Approve"], "params_input": {"input_1": 2}}, + ) + assert response.status_code == 409 + assert response.json() == { + "detail": ( + "Human-in-the-loop detail has already been updated for Task Instance " + f"with id {sample_ti.id} " + "and is not allowed to write again." + ) + } + + def test_should_respond_401( + self, + unauthenticated_test_client: TestClient, + sample_ti_url_identifier: str, + ) -> None: + response = unauthenticated_test_client.get(f"/hitl-details/{sample_ti_url_identifier}/-1") + assert response.status_code == 401 + + def test_should_respond_403( + self, + unauthorized_test_client: TestClient, + sample_ti_url_identifier: str, + ) -> None: + response = unauthorized_test_client.get(f"/hitl-details/{sample_ti_url_identifier}/-1") + assert response.status_code == 403 + + +class TestGetHITLDetailEndpoint: + @pytest.mark.usefixtures("sample_hitl_detail") + def test_should_respond_200_with_existing_response( + self, + test_client: TestClient, + sample_ti_url_identifier: str, + expected_sample_hitl_detail_dict: dict[str, Any], + ) -> None: + response = test_client.get(f"/hitl-details/{sample_ti_url_identifier}") + assert response.status_code == 200 + assert response.json() == expected_sample_hitl_detail_dict + + def test_should_respond_404( + self, + test_client: TestClient, + sample_ti_url_identifier: str, + expected_ti_not_found_error_msg: str, + ) -> None: + response = test_client.get(f"/hitl-details/{sample_ti_url_identifier}") + assert response.status_code == 404 + assert response.json() == {"detail": expected_ti_not_found_error_msg} + + def test_should_respond_401( + self, + unauthenticated_test_client: TestClient, + sample_ti_url_identifier: str, + ) -> None: + response = unauthenticated_test_client.get(f"/hitl-details/{sample_ti_url_identifier}") + assert response.status_code == 401 + + def test_should_respond_403( + self, + unauthorized_test_client: TestClient, + sample_ti_url_identifier: str, + ) -> None: + response = unauthorized_test_client.get(f"/hitl-details/{sample_ti_url_identifier}") + assert response.status_code == 403 + + +class TestGetMappedTIHITLDetail: + @pytest.mark.usefixtures("sample_hitl_detail") + def test_should_respond_200_with_existing_response( + self, + test_client: TestClient, + sample_ti_url_identifier: str, + expected_sample_hitl_detail_dict: dict[str, Any], + ) -> None: + response = test_client.get(f"/hitl-details/{sample_ti_url_identifier}/-1") + assert response.status_code == 200 + assert response.json() == expected_sample_hitl_detail_dict + + def test_should_respond_404( + self, + test_client: TestClient, + sample_ti_url_identifier: str, + expected_mapped_ti_not_found_error_msg: str, + ) -> None: + response = test_client.get(f"/hitl-details/{sample_ti_url_identifier}/-1") + assert response.status_code == 404 + assert response.json() == {"detail": expected_mapped_ti_not_found_error_msg} + + def test_should_respond_401( + self, + unauthenticated_test_client: TestClient, + sample_ti_url_identifier: str, + ) -> None: + response = unauthenticated_test_client.get(f"/hitl-details/{sample_ti_url_identifier}/-1") + assert response.status_code == 401 + + def test_should_respond_403( + self, + unauthorized_test_client: TestClient, + sample_ti_url_identifier: str, + ) -> None: + response = unauthorized_test_client.get(f"/hitl-details/{sample_ti_url_identifier}/-1") + assert response.status_code == 403 + + +class TestGetHITLDetailsEndpoint: + @pytest.mark.usefixtures("sample_hitl_detail") + def test_should_respond_200_with_existing_response( + self, + test_client: TestClient, + expected_sample_hitl_detail_dict: dict[str, Any], + ) -> None: + response = test_client.get("/hitl-details/") + assert response.status_code == 200 + assert response.json() == { + "hitl_details": [expected_sample_hitl_detail_dict], + "total_entries": 1, + } + + def test_should_respond_200_without_response(self, test_client: TestClient) -> None: + response = test_client.get("/hitl-details/") + assert response.status_code == 200 + assert response.json() == { + "hitl_details": [], + "total_entries": 0, + } + + def test_should_respond_401(self, unauthenticated_test_client: TestClient) -> None: + response = unauthenticated_test_client.get("/hitl-details/") + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client: TestClient) -> None: + response = unauthorized_test_client.get("/hitl-details/") + assert response.status_code == 403 diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_hitl.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_hitl.py new file mode 100644 index 0000000000000..3324730477082 --- /dev/null +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_hitl.py @@ -0,0 +1,140 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import datetime + +import pytest +import time_machine +from uuid6 import uuid7 + +from tests_common.test_utils.db import AIRFLOW_V_3_1_PLUS + +if not AIRFLOW_V_3_1_PLUS: + pytest.skip("Human in the loop public API compatible with Airflow >= 3.0.1", allow_module_level=True) + +from typing import TYPE_CHECKING, Any + +from airflow.models.hitl import HITLDetail + +if TYPE_CHECKING: + from airflow.models.taskinstance import TaskInstance + +pytestmark = pytest.mark.db_test +TI_ID = uuid7() + + +@pytest.fixture +def sample_ti(create_task_instance) -> TaskInstance: + return create_task_instance() + + +@pytest.fixture +def sample_hitl_detail(session, sample_ti) -> HITLDetail: + hitl_detail_model = HITLDetail( + ti_id=sample_ti.id, + options=["Approve", "Reject"], + subject="This is subject", + body="this is body", + defaults=["Approve"], + multiple=False, + params={"input_1": 1}, + ) + session.add(hitl_detail_model) + session.commit() + + return hitl_detail_model + + +@pytest.fixture +def expected_sample_hitl_detail_dict(sample_ti) -> dict[str, Any]: + return { + "body": "this is body", + "defaults": ["Approve"], + "multiple": False, + "options": ["Approve", "Reject"], + "params": {"input_1": 1}, + "params_input": {}, + "response_at": None, + "chosen_options": None, + "response_received": False, + "subject": "This is subject", + "ti_id": sample_ti.id, + "user_id": None, + } + + +def test_add_hitl_detail(client, create_task_instance, session) -> None: + ti = create_task_instance() + session.commit() + + response = client.post( + f"/execution/hitl-details/{ti.id}", + json={ + "ti_id": ti.id, + "options": ["Approve", "Reject"], + "subject": "This is subject", + "body": "this is body", + "defaults": ["Approve"], + "multiple": False, + "params": {"input_1": 1}, + }, + ) + assert response.status_code == 201 + assert response.json() == { + "ti_id": ti.id, + "options": ["Approve", "Reject"], + "subject": "This is subject", + "body": "this is body", + "defaults": ["Approve"], + "multiple": False, + "params": {"input_1": 1}, + } + + +@time_machine.travel(datetime(2025, 7, 3, 0, 0, 0), tick=False) +@pytest.mark.usefixtures("sample_hitl_detail") +def test_update_hitl_detail(client, sample_ti) -> None: + response = client.patch( + f"/execution/hitl-details/{sample_ti.id}", + json={ + "ti_id": sample_ti.id, + "chosen_options": ["Reject"], + "params_input": {"input_1": 2}, + }, + ) + assert response.status_code == 200 + assert response.json() == { + "params_input": {"input_1": 2}, + "response_at": "2025-07-03T00:00:00Z", + "chosen_options": ["Reject"], + "response_received": True, + "user_id": "Fallback to defaults", + } + + +@pytest.mark.usefixtures("sample_hitl_detail") +def test_get_hitl_detail(client, sample_ti) -> None: + response = client.get(f"/execution/hitl-details/{sample_ti.id}") + assert response.status_code == 200 + assert response.json() == { + "params_input": {}, + "response_at": None, + "chosen_options": None, + "response_received": False, + "user_id": None, + } diff --git a/airflow-ctl/src/airflowctl/api/datamodels/generated.py b/airflow-ctl/src/airflowctl/api/datamodels/generated.py index c64907a938a47..0824759850a68 100644 --- a/airflow-ctl/src/airflowctl/api/datamodels/generated.py +++ b/airflow-ctl/src/airflowctl/api/datamodels/generated.py @@ -573,6 +573,45 @@ class FastAPIRootMiddlewareResponse(BaseModel): name: Annotated[str, Field(title="Name")] +class HITLDetail(BaseModel): + """ + Schema for Human-in-the-loop detail. + """ + + ti_id: Annotated[str, Field(title="Ti Id")] + options: Annotated[list[str], Field(title="Options")] + subject: Annotated[str, Field(title="Subject")] + body: Annotated[str | None, Field(title="Body")] = None + defaults: Annotated[list[str] | None, Field(title="Defaults")] = None + multiple: Annotated[bool | None, Field(title="Multiple")] = False + params: Annotated[dict[str, Any] | None, Field(title="Params")] = None + user_id: Annotated[str | None, Field(title="User Id")] = None + response_at: Annotated[datetime | None, Field(title="Response At")] = None + chosen_options: Annotated[list[str] | None, Field(title="Chosen Options")] = None + params_input: Annotated[dict[str, Any] | None, Field(title="Params Input")] = None + response_received: Annotated[bool | None, Field(title="Response Received")] = False + + +class HITLDetailCollection(BaseModel): + """ + Schema for a collection of Human-in-the-loop details. + """ + + hitl_details: Annotated[list[HITLDetail], Field(title="Hitl Details")] + total_entries: Annotated[int, Field(title="Total Entries")] + + +class HITLDetailResponse(BaseModel): + """ + Response of updating a Human-in-the-loop detail. + """ + + user_id: Annotated[str, Field(title="User Id")] + response_at: Annotated[datetime, Field(title="Response At")] + chosen_options: Annotated[list[str], Field(title="Chosen Options")] + params_input: Annotated[dict[str, Any] | None, Field(title="Params Input")] = None + + class HTTPExceptionResponse(BaseModel): """ HTTPException Model used for error response. @@ -899,6 +938,15 @@ class TriggererInfoResponse(BaseModel): latest_triggerer_heartbeat: Annotated[str | None, Field(title="Latest Triggerer Heartbeat")] = None +class UpdateHITLDetailPayload(BaseModel): + """ + Schema for updating the content of a Human-in-the-loop detail. + """ + + chosen_options: Annotated[list[str], Field(title="Chosen Options")] + params_input: Annotated[dict[str, Any] | None, Field(title="Params Input")] = None + + class ValidationError(BaseModel): loc: Annotated[list[str | int], Field(title="Location")] msg: Annotated[str, Field(title="Message")] diff --git a/providers/standard/provider.yaml b/providers/standard/provider.yaml index f932014c94073..6ccab9577bc7a 100644 --- a/providers/standard/provider.yaml +++ b/providers/standard/provider.yaml @@ -69,6 +69,7 @@ operators: - airflow.providers.standard.operators.latest_only - airflow.providers.standard.operators.smooth - airflow.providers.standard.operators.branch + - airflow.providers.standard.operators.hitl sensors: - integration-name: Standard python-modules: @@ -93,6 +94,7 @@ triggers: - airflow.providers.standard.triggers.external_task - airflow.providers.standard.triggers.file - airflow.providers.standard.triggers.temporal + - airflow.providers.standard.triggers.hitl extra-links: - airflow.providers.standard.operators.trigger_dagrun.TriggerDagRunLink diff --git a/providers/standard/src/airflow/providers/standard/exceptions.py b/providers/standard/src/airflow/providers/standard/exceptions.py index 66acd54aa450f..6975e0afadfb3 100644 --- a/providers/standard/src/airflow/providers/standard/exceptions.py +++ b/providers/standard/src/airflow/providers/standard/exceptions.py @@ -55,3 +55,11 @@ class ExternalDagFailedError(AirflowExternalTaskSensorException): class DuplicateStateError(AirflowExternalTaskSensorException): """Raised when duplicate states are provided across allowed, skipped and failed states.""" + + +class HITLTriggerEventError(AirflowException): + """Raised when TriggerEvent contains error.""" + + +class HITLTimeoutError(HITLTriggerEventError): + """Raised when HILTOperator timeouts.""" diff --git a/providers/standard/src/airflow/providers/standard/get_provider_info.py b/providers/standard/src/airflow/providers/standard/get_provider_info.py index bb40bfaa7b21c..bd7118c78aadf 100644 --- a/providers/standard/src/airflow/providers/standard/get_provider_info.py +++ b/providers/standard/src/airflow/providers/standard/get_provider_info.py @@ -58,6 +58,7 @@ def get_provider_info(): "airflow.providers.standard.operators.latest_only", "airflow.providers.standard.operators.smooth", "airflow.providers.standard.operators.branch", + "airflow.providers.standard.operators.hitl", ], } ], @@ -93,6 +94,7 @@ def get_provider_info(): "airflow.providers.standard.triggers.external_task", "airflow.providers.standard.triggers.file", "airflow.providers.standard.triggers.temporal", + "airflow.providers.standard.triggers.hitl", ], } ], diff --git a/providers/standard/src/airflow/providers/standard/models/__init__.py b/providers/standard/src/airflow/providers/standard/models/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/standard/src/airflow/providers/standard/models/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/standard/src/airflow/providers/standard/operators/hitl.py b/providers/standard/src/airflow/providers/standard/operators/hitl.py new file mode 100644 index 0000000000000..6a1f88ddb4435 --- /dev/null +++ b/providers/standard/src/airflow/providers/standard/operators/hitl.py @@ -0,0 +1,232 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging + +from airflow.exceptions import AirflowOptionalProviderFeatureException +from airflow.providers.standard.version_compat import AIRFLOW_V_3_1_PLUS + +if not AIRFLOW_V_3_1_PLUS: + raise AirflowOptionalProviderFeatureException("Human in the loop functionality needs Airflow 3.1+.") + + +from collections.abc import Collection, Mapping +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any + +from airflow.models.baseoperator import BaseOperator +from airflow.providers.standard.exceptions import HITLTimeoutError, HITLTriggerEventError +from airflow.providers.standard.triggers.hitl import HITLTrigger, HITLTriggerEventSuccessPayload +from airflow.providers.standard.utils.skipmixin import SkipMixin +from airflow.sdk.definitions.param import ParamsDict +from airflow.sdk.execution_time.hitl import add_hitl_detail + +if TYPE_CHECKING: + from airflow.sdk.definitions.context import Context + + +class HITLOperator(BaseOperator): + """ + Base class for all Human-in-the-loop Operators to inherit from. + + :param subject: Headline/subject presented to the user for the interaction task. + :param options: List of options that the an user can select from to complete the task. + :param body: Descriptive text (with Markdown support) that gives the details that are needed to decide. + :param defaults: The default options and the options that are taken if timeout is passed. + :param multiple: Whether the user can select one or multiple options. + :param params: dictionary of parameter definitions that are in the format of Dag params such that + a Form Field can be rendered. Entered data is validated (schema, required fields) like for a Dag run + and added to XCom of the task result. + """ + + template_fields: Collection[str] = ("subject", "body") + + def __init__( + self, + *, + subject: str, + options: list[str], + body: str | None = None, + defaults: str | list[str] | None = None, + multiple: bool = False, + params: ParamsDict | dict[str, Any] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.subject = subject + self.body = body + + self.options = options + # allow defaults to store more than one options when multiple=True + self.defaults = [defaults] if isinstance(defaults, str) else defaults + self.multiple = multiple + + self.params: ParamsDict = params if isinstance(params, ParamsDict) else ParamsDict(params or {}) + + self.validate_defaults() + + def validate_defaults(self) -> None: + """ + Validate whether the given defaults pass the following criteria. + + 1. Default options should be the subset of options. + 2. When multiple is False, there should only be one option. + """ + if self.defaults is not None: + if not set(self.defaults).issubset(self.options): + raise ValueError(f'defaults "{self.defaults}" should be a subset of options "{self.options}"') + + if self.multiple is False and len(self.defaults) > 1: + raise ValueError('More than one defaults given when "multiple" is set to False.') + + def execute(self, context: Context): + """Add a Human-in-the-loop Response and then defer to HITLTrigger and wait for user input.""" + ti_id = context["task_instance"].id + # Write Human-in-the-loop input request to DB + add_hitl_detail( + ti_id=ti_id, + options=self.options, + subject=self.subject, + body=self.body, + defaults=self.defaults, + multiple=self.multiple, + params=self.serialzed_params, + ) + if self.execution_timeout: + timeout_datetime = datetime.now(timezone.utc) + self.execution_timeout + else: + timeout_datetime = None + self.log.info("Waiting for response") + # Defer the Human-in-the-loop response checking process to HITLTrigger + self.defer( + trigger=HITLTrigger( + ti_id=ti_id, + options=self.options, + defaults=self.defaults, + params=self.serialzed_params, + multiple=self.multiple, + timeout_datetime=timeout_datetime, + ), + method_name="execute_complete", + ) + + @property + def serialzed_params(self) -> dict[str, Any]: + return self.params.dump() if isinstance(self.params, ParamsDict) else self.params + + def execute_complete(self, context: Context, event: dict[str, Any]) -> Any: + if "error" in event: + self.process_trigger_event_error(event) + + chosen_options = event["chosen_options"] + params_input = event["params_input"] or {} + self.validate_chosen_options(chosen_options) + self.validate_params_input(params_input) + return HITLTriggerEventSuccessPayload( + chosen_options=chosen_options, + params_input=params_input, + ) + + def process_trigger_event_error(self, event: dict[str, Any]) -> None: + if "error_type" == "timeout": + raise HITLTimeoutError(event) + + raise HITLTriggerEventError(event) + + def validate_chosen_options(self, chosen_options: list[str]) -> None: + """Check whether user provide valid response.""" + if diff := set(chosen_options) - set(self.options): + raise ValueError(f"Responses {diff} not in {self.options}") + + def validate_params_input(self, params_input: Mapping) -> None: + """Check whether user provide valid params input.""" + if ( + self.serialzed_params is not None + and params_input is not None + and set(self.serialzed_params.keys()) ^ set(params_input) + ): + raise ValueError(f"params_input {params_input} does not match params {self.params}") + + +class ApprovalOperator(HITLOperator, SkipMixin): + """Human-in-the-loop Operator that has only 'Approval' and 'Reject' options.""" + + inherits_from_skipmixin = True + + FIXED_ARGS = ["options", "multiple"] + + def __init__(self, ignore_downstream_trigger_rules: bool = False, **kwargs) -> None: + for arg in self.FIXED_ARGS: + if arg in kwargs: + raise ValueError(f"Passing {arg} to ApprovalOperator is not allowed.") + + self.ignore_downstream_trigger_rules = ignore_downstream_trigger_rules + + super().__init__(options=["Approve", "Reject"], multiple=False, **kwargs) + + def execute_complete(self, context: Context, event: dict[str, Any]) -> Any: + ret = super().execute_complete(context=context, event=event) + + chosen_option = ret["chosen_options"][0] + if chosen_option == "Approve": + self.log.info("Approved. Proceeding with downstream tasks...") + return ret + + if not self.downstream_task_ids: + self.log.info("No downstream tasks; nothing to do.") + return ret + + def get_tasks_to_skip(): + if self.ignore_downstream_trigger_rules is True: + tasks = context["task"].get_flat_relatives(upstream=False) + else: + tasks = context["task"].get_direct_relatives(upstream=False) + + yield from (t for t in tasks if not t.is_teardown) + + tasks_to_skip = get_tasks_to_skip() + + # this lets us avoid an intermediate list unless debug logging + if self.log.getEffectiveLevel() <= logging.DEBUG: + self.log.debug("Downstream task IDs %s", tasks_to_skip := list(get_tasks_to_skip())) + + self.log.info("Skipping downstream tasks") + self.skip(ti=context["ti"], tasks=tasks_to_skip) + + return ret + + +class HITLBranchOperator(HITLOperator): + """BranchOperator based on Human-in-the-loop Response.""" + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + def execute_complete(self, context: Context, event: dict[str, Any]) -> None: + raise NotImplementedError + + +class HITLEntryOperator(HITLOperator): + """Human-in-the-loop Operator that is used to accept user input through TriggerForm.""" + + def __init__(self, **kwargs) -> None: + if "options" not in kwargs: + kwargs["options"] = ["OK"] + kwargs["defaults"] = ["OK"] + + super().__init__(**kwargs) diff --git a/providers/standard/src/airflow/providers/standard/triggers/hitl.py b/providers/standard/src/airflow/providers/standard/triggers/hitl.py new file mode 100644 index 0000000000000..63cea15363717 --- /dev/null +++ b/providers/standard/src/airflow/providers/standard/triggers/hitl.py @@ -0,0 +1,132 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.exceptions import AirflowOptionalProviderFeatureException +from airflow.providers.standard.version_compat import AIRFLOW_V_3_1_PLUS + +if not AIRFLOW_V_3_1_PLUS: + raise AirflowOptionalProviderFeatureException("Human in the loop functionality needs Airflow 3.1+.") + +import asyncio +from collections.abc import AsyncIterator +from datetime import datetime +from typing import Any, Literal, TypedDict +from uuid import UUID + +from asgiref.sync import sync_to_async + +from airflow.sdk.execution_time.hitl import ( + get_hitl_detail_content_detail, + update_htil_detail_response, +) +from airflow.triggers.base import BaseTrigger, TriggerEvent +from airflow.utils import timezone + + +class HITLTriggerEventSuccessPayload(TypedDict, total=False): + """Minimum required keys for a success Human-in-the-loop TriggerEvent.""" + + chosen_options: list[str] + params_input: dict[str, Any] + + +class HITLTriggerEventFailurePayload(TypedDict): + """Minimum required keys for a failed Human-in-the-loop TriggerEvent.""" + + error: str + error_type: Literal["timeout", "unknown"] + + +class HITLTrigger(BaseTrigger): + """A trigger that checks whether Human-in-the-loop responses are received.""" + + def __init__( + self, + *, + ti_id: UUID, + options: list[str], + params: dict[str, Any], + defaults: list[str] | None = None, + multiple: bool = False, + timeout_datetime: datetime | None, + poke_interval: float = 5.0, + **kwargs, + ): + super().__init__(**kwargs) + self.ti_id = ti_id + self.poke_interval = poke_interval + + self.options = options + self.multiple = multiple + self.defaults = defaults + self.timeout_datetime = timeout_datetime + + self.params = params + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize HITLTrigger arguments and classpath.""" + return ( + "airflow.providers.standard.triggers.hitl.HITLTrigger", + { + "ti_id": self.ti_id, + "options": self.options, + "defaults": self.defaults, + "params": self.params, + "multiple": self.multiple, + "timeout_datetime": self.timeout_datetime, + "poke_interval": self.poke_interval, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + """Loop until the Human-in-the-loop response received or timeout reached.""" + while True: + if self.timeout_datetime and self.timeout_datetime < timezone.utcnow(): + if self.defaults is None: + yield TriggerEvent( + HITLTriggerEventFailurePayload( + error="The timeout has passed, and the response has not yet been received.", + error_type="timeout", + ) + ) + return + + await sync_to_async(update_htil_detail_response)( + ti_id=self.ti_id, + chosen_options=self.defaults, + params_input=self.params, + ) + yield TriggerEvent( + HITLTriggerEventSuccessPayload( + chosen_options=self.defaults, + params_input=self.params, + ) + ) + return + + resp = await sync_to_async(get_hitl_detail_content_detail)(ti_id=self.ti_id) + if resp.response_received and resp.chosen_options: + self.log.info("Responded by %s at %s", resp.user_id, resp.response_at) + yield TriggerEvent( + HITLTriggerEventSuccessPayload( + chosen_options=resp.chosen_options, + params_input=resp.params_input, + ) + ) + return + await asyncio.sleep(self.poke_interval) diff --git a/providers/standard/tests/unit/standard/operators/test_hitl.py b/providers/standard/tests/unit/standard/operators/test_hitl.py new file mode 100644 index 0000000000000..767c379bdc0e5 --- /dev/null +++ b/providers/standard/tests/unit/standard/operators/test_hitl.py @@ -0,0 +1,269 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS + +if not AIRFLOW_V_3_1_PLUS: + pytest.skip("Human in the loop public API compatible with Airflow >= 3.0.1", allow_module_level=True) + +from typing import TYPE_CHECKING, Any + +from sqlalchemy import select + +from airflow.exceptions import DownstreamTasksSkipped +from airflow.models import Trigger +from airflow.models.hitl import HITLDetail +from airflow.providers.standard.operators.empty import EmptyOperator +from airflow.providers.standard.operators.hitl import ( + ApprovalOperator, + HITLEntryOperator, + HITLOperator, +) +from airflow.sdk import Param +from airflow.sdk.definitions.param import ParamsDict + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + + from tests_common.pytest_plugin import DagMaker + +pytestmark = pytest.mark.db_test + + +class TestHITLOperator: + def test_validate_defaults(self) -> None: + hitl_op = HITLOperator( + task_id="hitl_test", + subject="This is subject", + options=["1", "2", "3", "4", "5"], + body="This is body", + defaults=["1"], + multiple=False, + params=ParamsDict({"input_1": 1}), + ) + hitl_op.validate_defaults() + + @pytest.mark.parametrize( + "extra_kwargs, expected_error_msg", + [ + ({"defaults": ["0"]}, r'defaults ".*" should be a subset of options ".*"'), + ( + {"multiple": False, "defaults": ["1", "2"]}, + 'More than one defaults given when "multiple" is set to False.', + ), + ], + ids=[ + "defaults not in option", + "multiple defaults when multiple is False", + ], + ) + def test_validate_defaults_with_invalid_defaults( + self, + extra_kwargs: dict[str, Any], + expected_error_msg: str, + ) -> None: + with pytest.raises(ValueError, match=expected_error_msg): + HITLOperator( + task_id="hitl_test", + subject="This is subject", + body="This is body", + options=["1", "2", "3", "4", "5"], + params=ParamsDict({"input_1": 1}), + **extra_kwargs, + ) + + def test_execute(self, dag_maker: DagMaker, session: Session) -> None: + with dag_maker("test_dag"): + task = HITLOperator( + task_id="hitl_test", + subject="This is subject", + options=["1", "2", "3", "4", "5"], + body="This is body", + defaults=["1"], + multiple=False, + params=ParamsDict({"input_1": 1}), + ) + dr = dag_maker.create_dagrun() + ti = dag_maker.run_ti(task.task_id, dr) + + hitl_detail_model = session.scalar(select(HITLDetail).where(HITLDetail.ti_id == ti.id)) + assert hitl_detail_model.ti_id == ti.id + assert hitl_detail_model.subject == "This is subject" + assert hitl_detail_model.options == ["1", "2", "3", "4", "5"] + assert hitl_detail_model.body == "This is body" + assert hitl_detail_model.defaults == ["1"] + assert hitl_detail_model.multiple is False + assert hitl_detail_model.params == {"input_1": 1} + assert hitl_detail_model.response_at is None + assert hitl_detail_model.user_id is None + assert hitl_detail_model.chosen_options is None + assert hitl_detail_model.params_input == {} + + registered_trigger = session.scalar( + select(Trigger).where(Trigger.classpath == "airflow.providers.standard.triggers.hitl.HITLTrigger") + ) + assert registered_trigger.kwargs == { + "ti_id": ti.id, + "options": ["1", "2", "3", "4", "5"], + "defaults": ["1"], + "params": {"input_1": 1}, + "multiple": False, + "timeout_datetime": None, + "poke_interval": 5.0, + } + + @pytest.mark.parametrize( + "input_params, expected_params", + [ + (ParamsDict({"input": 1}), {"input": 1}), + ({"input": Param(5, type="integer", minimum=3)}, {"input": 5}), + (None, {}), + ], + ) + def test_serialzed_params(self, input_params, expected_params: dict[str, Any]) -> None: + hitl_op = HITLOperator( + task_id="hitl_test", + subject="This is subject", + body="This is body", + options=["1", "2", "3", "4", "5"], + params=input_params, + ) + assert hitl_op.serialzed_params == expected_params + + def test_execute_complete(self) -> None: + hitl_op = HITLOperator( + task_id="hitl_test", + subject="This is subject", + body="This is body", + options=["1", "2", "3", "4", "5"], + params={"input": 1}, + ) + + ret = hitl_op.execute_complete( + context={}, + event={"chosen_options": ["1"], "params_input": {"input": 2}}, + ) + + assert ret["chosen_options"] == ["1"] + assert ret["params_input"] == {"input": 2} + + def test_validate_chosen_options_with_invalid_content(self) -> None: + hitl_op = HITLOperator( + task_id="hitl_test", + subject="This is subject", + body="This is body", + options=["1", "2", "3", "4", "5"], + params={"input": 1}, + ) + + with pytest.raises(ValueError): + hitl_op.execute_complete( + context={}, + event={ + "chosen_options": ["not exists"], + "params_input": {"input": 2}, + }, + ) + + def test_validate_params_input_with_invalid_input(self) -> None: + hitl_op = HITLOperator( + task_id="hitl_test", + subject="This is subject", + body="This is body", + options=["1", "2", "3", "4", "5"], + params={"input": 1}, + ) + + with pytest.raises(ValueError): + hitl_op.execute_complete( + context={}, + event={ + "chosen_options": ["1"], + "params_input": {"no such key": 2, "input": 333}, + }, + ) + + +class TestApprovalOperator: + def test_init_with_options(self) -> None: + with pytest.raises(ValueError): + ApprovalOperator( + task_id="hitl_test", + subject="This is subject", + body="This is body", + options=["1", "2", "3", "4", "5"], + params={"input": 1}, + ) + + def test_init_with_multiple_set_to_true(self) -> None: + with pytest.raises(ValueError): + ApprovalOperator( + task_id="hitl_test", + subject="This is subject", + params={"input": 1}, + multiple=True, + ) + + def test_execute_complete(self) -> None: + hitl_op = ApprovalOperator( + task_id="hitl_test", + subject="This is subject", + ) + + ret = hitl_op.execute_complete( + context={}, + event={"chosen_options": ["Approve"], "params_input": {}}, + ) + + assert ret == { + "chosen_options": ["Approve"], + "params_input": {}, + } + + def test_execute_complete_with_downstream_tasks(self, dag_maker) -> None: + with dag_maker("hitl_test_dag", serialized=True): + hitl_op = ApprovalOperator( + task_id="hitl_test", + subject="This is subject", + ) + (hitl_op >> EmptyOperator(task_id="op1")) + + dr = dag_maker.create_dagrun() + ti = dr.get_task_instance("hitl_test") + + with pytest.raises(DownstreamTasksSkipped) as exc_info: + hitl_op.execute_complete( + context={"ti": ti, "task": ti.task}, + event={"chosen_options": ["Reject"], "params_input": {}}, + ) + assert set(exc_info.value.tasks) == {"op1"} + + +class TestHITLEntryOperator: + def test_init(self) -> None: + op = HITLEntryOperator( + task_id="hitl_test", + subject="This is subject", + body="This is body", + params={"input": 1}, + ) + + assert op.options == ["OK"] + assert op.defaults == ["OK"] diff --git a/providers/standard/tests/unit/standard/triggers/test_hitl.py b/providers/standard/tests/unit/standard/triggers/test_hitl.py new file mode 100644 index 0000000000000..ac96d9eed1e07 --- /dev/null +++ b/providers/standard/tests/unit/standard/triggers/test_hitl.py @@ -0,0 +1,161 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import pytest + +from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS + +if not AIRFLOW_V_3_1_PLUS: + pytest.skip("Human in the loop public API compatible with Airflow >= 3.0.1", allow_module_level=True) + +import asyncio +from datetime import timedelta +from unittest import mock + +from uuid6 import uuid7 + +from airflow.api_fastapi.execution_api.datamodels.hitl import HITLDetailResponse +from airflow.providers.standard.triggers.hitl import ( + HITLTrigger, + HITLTriggerEventFailurePayload, + HITLTriggerEventSuccessPayload, +) +from airflow.triggers.base import TriggerEvent +from airflow.utils.timezone import utcnow + +TI_ID = uuid7() + + +class TestHITLTrigger: + def test_serialization(self): + trigger = HITLTrigger( + ti_id=TI_ID, + options=["1", "2", "3", "4", "5"], + params={"input": 1}, + defaults=["1"], + multiple=False, + timeout_datetime=None, + poke_interval=50.0, + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.standard.triggers.hitl.HITLTrigger" + assert kwargs == { + "ti_id": TI_ID, + "options": ["1", "2", "3", "4", "5"], + "params": {"input": 1}, + "defaults": ["1"], + "multiple": False, + "timeout_datetime": None, + "poke_interval": 50.0, + } + + @pytest.mark.db_test + @pytest.mark.asyncio + @mock.patch("airflow.sdk.execution_time.hitl.update_htil_detail_response") + async def test_run_failed_due_to_timeout(self, mock_update, mock_supervisor_comms): + trigger = HITLTrigger( + ti_id=TI_ID, + options=["1", "2", "3", "4", "5"], + params={"input": 1}, + multiple=False, + timeout_datetime=utcnow() + timedelta(seconds=0.1), + poke_interval=5, + ) + mock_supervisor_comms.send.return_value = HITLDetailResponse( + response_received=False, + user_id=None, + response_at=None, + chosen_options=None, + params_input={}, + ) + + gen = trigger.run() + trigger_task = asyncio.create_task(gen.__anext__()) + await asyncio.sleep(0.3) + event = await trigger_task + assert event == TriggerEvent( + HITLTriggerEventFailurePayload( + error="The timeout has passed, and the response has not yet been received.", + error_type="timeout", + ) + ) + + @pytest.mark.db_test + @pytest.mark.asyncio + @mock.patch("airflow.sdk.execution_time.hitl.update_htil_detail_response") + async def test_run_fallback_to_default_due_to_timeout(self, mock_update, mock_supervisor_comms): + trigger = HITLTrigger( + ti_id=TI_ID, + options=["1", "2", "3", "4", "5"], + params={"input": 1}, + defaults=["1"], + multiple=False, + timeout_datetime=utcnow() + timedelta(seconds=0.1), + poke_interval=5, + ) + mock_supervisor_comms.send.return_value = HITLDetailResponse( + response_received=False, + user_id=None, + response_at=None, + chosen_options=None, + params_input={}, + ) + + gen = trigger.run() + trigger_task = asyncio.create_task(gen.__anext__()) + await asyncio.sleep(0.3) + event = await trigger_task + assert event == TriggerEvent( + HITLTriggerEventSuccessPayload( + chosen_options=["1"], + params_input={"input": 1}, + ) + ) + + @pytest.mark.db_test + @pytest.mark.asyncio + @mock.patch("airflow.sdk.execution_time.hitl.update_htil_detail_response") + async def test_run(self, mock_update, mock_supervisor_comms): + trigger = HITLTrigger( + ti_id=TI_ID, + options=["1", "2", "3", "4", "5"], + params={"input": 1}, + defaults=["1"], + multiple=False, + timeout_datetime=None, + poke_interval=5, + ) + mock_supervisor_comms.send.return_value = HITLDetailResponse( + response_received=True, + user_id="test", + response_at=utcnow(), + chosen_options=["3"], + params_input={"input": 50}, + ) + + gen = trigger.run() + trigger_task = asyncio.create_task(gen.__anext__()) + await asyncio.sleep(0.3) + event = await trigger_task + assert event == TriggerEvent( + HITLTriggerEventSuccessPayload( + chosen_options=["3"], + params_input={"input": 50}, + ) + ) diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 179bb9a443f5a..c8e45c10b3f66 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -38,8 +38,10 @@ AssetEventsResponse, AssetResponse, ConnectionResponse, + CreateHITLDetailPayload, DagRunStateResponse, DagRunType, + HITLDetailResponse, InactiveAssetsResponse, PrevSuccessfulDagRunResponse, TaskInstanceState, @@ -55,6 +57,7 @@ TISuccessStatePayload, TITerminalStatePayload, TriggerDAGRunPayload, + UpdateHITLDetail, ValidationError as RemoteValidationError, VariablePostBody, VariableResponse, @@ -66,6 +69,7 @@ from airflow.sdk.execution_time.comms import ( DRCount, ErrorResponse, + HITLDetailRequestResult, OKResponse, SkipDownstreamTasks, TaskRescheduleStartDate, @@ -618,6 +622,70 @@ def get_count( return DRCount(count=resp.json()) +class HITLOperations: + """ + Operations related to Human in the loop. Require Airflow 3.1+. + + :meta: private + """ + + __slots__ = ("client",) + + def __init__(self, client: Client) -> None: + self.client = client + + def add_response( + self, + *, + ti_id: uuid.UUID, + options: list[str], + subject: str, + body: str | None = None, + defaults: list[str] | None = None, + multiple: bool = False, + params: dict[str, Any] | None = None, + ) -> HITLDetailRequestResult: + """Add a Human-in-the-loop response that waits for human response for a specific Task Instance.""" + payload = CreateHITLDetailPayload( + ti_id=ti_id, + options=options, + subject=subject, + body=body, + defaults=defaults, + multiple=multiple, + params=params, + ) + resp = self.client.post( + f"/hitl-details/{ti_id}", + content=payload.model_dump_json(), + ) + return HITLDetailRequestResult.model_validate_json(resp.read()) + + def update_response( + self, + *, + ti_id: uuid.UUID, + chosen_options: list[str], + params_input: dict[str, Any], + ) -> HITLDetailResponse: + """Update an existing Human-in-the-loop response.""" + payload = UpdateHITLDetail( + ti_id=ti_id, + chosen_options=chosen_options, + params_input=params_input, + ) + resp = self.client.patch( + f"/hitl-details/{ti_id}", + content=payload.model_dump_json(), + ) + return HITLDetailResponse.model_validate_json(resp.read()) + + def get_detail_response(self, ti_id: uuid.UUID) -> HITLDetailResponse: + """Get content part of a Human-in-the-loop response for a specific Task Instance.""" + resp = self.client.get(f"/hitl-details/{ti_id}") + return HITLDetailResponse.model_validate_json(resp.read()) + + class BearerAuth(httpx.Auth): def __init__(self, token: str): self.token: str = token @@ -751,6 +819,12 @@ def asset_events(self) -> AssetEventOperations: """Operations related to Asset Events.""" return AssetEventOperations(self) + @lru_cache() # type: ignore[misc] + @property + def hitl(self): + """Operations related to HITL Responses.""" + return HITLOperations(self) + # This is only used for parsing. ServerResponseError is raised instead class _ErrorBody(BaseModel): diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index f0d3d45d1684a..13e41bd5b6a0a 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -102,6 +102,23 @@ class ConnectionResponse(BaseModel): extra: Annotated[str | None, Field(title="Extra")] = None +class CreateHITLDetailPayload(BaseModel): + """ + Add the input request part of a Human-in-the-loop response. + """ + + ti_id: Annotated[UUID, Field(title="Ti Id")] + options: Annotated[list[str], Field(title="Options")] + subject: Annotated[str, Field(title="Subject")] + body: Annotated[str | None, Field(title="Body")] = None + defaults: Annotated[list[str] | None, Field(title="Defaults")] = None + multiple: Annotated[bool | None, Field(title="Multiple")] = False + params: Annotated[dict[str, Any] | None, Field(title="Params")] = None + type: Annotated[Literal["CreateHITLDetailPayload"] | None, Field(title="Type")] = ( + "CreateHITLDetailPayload" + ) + + class DagRunAssetReference(BaseModel): """ DagRun serializer for asset responses. @@ -154,6 +171,32 @@ class DagRunType(str, Enum): ASSET_TRIGGERED = "asset_triggered" +class HITLDetailRequest(BaseModel): + """ + Schema for the request part of a Human-in-the-loop detail for a specific task instance. + """ + + ti_id: Annotated[UUID, Field(title="Ti Id")] + options: Annotated[list[str], Field(title="Options")] + subject: Annotated[str, Field(title="Subject")] + body: Annotated[str | None, Field(title="Body")] = None + defaults: Annotated[list[str] | None, Field(title="Defaults")] = None + multiple: Annotated[bool | None, Field(title="Multiple")] = False + params: Annotated[dict[str, Any] | None, Field(title="Params")] = None + + +class HITLDetailResponse(BaseModel): + """ + Schema for the response part of a Human-in-the-loop detail for a specific task instance. + """ + + response_received: Annotated[bool, Field(title="Response Received")] + user_id: Annotated[str | None, Field(title="User Id")] = None + response_at: Annotated[AwareDatetime | None, Field(title="Response At")] = None + chosen_options: Annotated[list[str] | None, Field(title="Chosen Options")] = None + params_input: Annotated[dict[str, Any] | None, Field(title="Params Input")] = None + + class InactiveAssetsResponse(BaseModel): """ Response for inactive assets. @@ -325,6 +368,17 @@ class TriggerDAGRunPayload(BaseModel): reset_dag_run: Annotated[bool | None, Field(title="Reset Dag Run")] = False +class UpdateHITLDetail(BaseModel): + """ + Update the response content part of an existing Human-in-the-loop response. + """ + + ti_id: Annotated[UUID, Field(title="Ti Id")] + chosen_options: Annotated[list[str], Field(title="Chosen Options")] + params_input: Annotated[dict[str, Any] | None, Field(title="Params Input")] = None + type: Annotated[Literal["UpdateHITLDetail"] | None, Field(title="Type")] = "UpdateHITLDetail" + + class ValidationError(BaseModel): loc: Annotated[list[str | int], Field(title="Location")] msg: Annotated[str, Field(title="Message")] diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index d52e9bb3cab5f..fb9bdbaf6ba61 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -63,6 +63,10 @@ from fastapi import Body from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, field_serializer +from airflow.api_fastapi.execution_api.datamodels.hitl import ( + GetHITLDetailResponsePayload, + UpdateHITLDetailPayload, +) from airflow.sdk.api.datamodels._generated import ( AssetEventDagRunReference, AssetEventResponse, @@ -71,6 +75,7 @@ BundleInfo, ConnectionResponse, DagRunStateResponse, + HITLDetailRequest, InactiveAssetsResponse, PrevSuccessfulDagRunResponse, TaskInstance, @@ -96,6 +101,7 @@ # Available on Unix and Windows (so "everywhere") but lets be safe recv_fds = None # type: ignore[assignment] + if TYPE_CHECKING: from structlog.typing import FilteringBoundLogger as Logger @@ -557,6 +563,18 @@ class SentFDs(BaseModel): fds: list[int] +class CreateHITLDetailPayload(HITLDetailRequest): + """Add the input request part of a Human-in-the-loop response.""" + + type: Literal["CreateHITLDetailPayload"] = "CreateHITLDetailPayload" + + +class HITLDetailRequestResult(HITLDetailRequest): + """Response to CreateHITLDetailPayload request.""" + + type: Literal["HITLDetailRequestResult"] = "HITLDetailRequestResult" + + ToTask = Annotated[ AssetResult | AssetEventsResult @@ -576,6 +594,8 @@ class SentFDs(BaseModel): | XComSequenceIndexResult | XComSequenceSliceResult | InactiveAssetsResult + | CreateHITLDetailPayload + | HITLDetailRequestResult | OKResponse, Field(discriminator="type"), ] @@ -838,6 +858,18 @@ class GetDRCount(BaseModel): type: Literal["GetDRCount"] = "GetDRCount" +class GetHITLDetailResponse(GetHITLDetailResponsePayload): + """Get the response content part of a Human-in-the-loop response.""" + + type: Literal["GetHITLDetailResponse"] = "GetHITLDetailResponse" + + +class UpdateHITLDetail(UpdateHITLDetailPayload): + """Update the response content part of an existing Human-in-the-loop response.""" + + type: Literal["UpdateHITLDetail"] = "UpdateHITLDetail" + + ToSupervisor = Annotated[ DeferTask | DeleteXCom @@ -868,6 +900,9 @@ class GetDRCount(BaseModel): | TaskState | TriggerDagRun | DeleteVariable - | ResendLoggingFD, + | ResendLoggingFD + | CreateHITLDetailPayload + | UpdateHITLDetail + | GetHITLDetailResponse, Field(discriminator="type"), ] diff --git a/task-sdk/src/airflow/sdk/execution_time/hitl.py b/task-sdk/src/airflow/sdk/execution_time/hitl.py new file mode 100644 index 0000000000000..6f6409b3dcadb --- /dev/null +++ b/task-sdk/src/airflow/sdk/execution_time/hitl.py @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any +from uuid import UUID + +from airflow.sdk.execution_time.comms import ( + CreateHITLDetailPayload, + GetHITLDetailResponse, + UpdateHITLDetail, +) + +if TYPE_CHECKING: + from airflow.api_fastapi.execution_api.datamodels.hitl import HITLDetailResponse + + +def add_hitl_detail( + ti_id: UUID, + options: list[str], + subject: str, + body: str | None = None, + defaults: list[str] | None = None, + multiple: bool = False, + params: dict[str, Any] | None = None, +) -> None: + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + SUPERVISOR_COMMS.send( + msg=CreateHITLDetailPayload( + ti_id=ti_id, + options=options, + subject=subject, + body=body, + defaults=defaults, + params=params, + multiple=multiple, + ) + ) + + +def update_htil_detail_response( + ti_id: UUID, + chosen_options: list[str], + params_input: dict[str, Any], +) -> HITLDetailResponse: + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + response = SUPERVISOR_COMMS.send( + msg=UpdateHITLDetail( + ti_id=ti_id, + chosen_options=chosen_options, + params_input=params_input, + ), + ) + if TYPE_CHECKING: + assert isinstance(response, HITLDetailResponse) + return response + + +def get_hitl_detail_content_detail(ti_id: UUID) -> HITLDetailResponse: + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + response = SUPERVISOR_COMMS.send(msg=GetHITLDetailResponse(ti_id=ti_id)) + + if TYPE_CHECKING: + assert isinstance(response, HITLDetailResponse) + return response diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 690411979865e..f5eede6274df4 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -68,6 +68,7 @@ AssetEventsResult, AssetResult, ConnectionResult, + CreateHITLDetailPayload, DagRunStateResult, DeferTask, DeleteVariable, @@ -1230,6 +1231,17 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: self._send_new_log_fd(req_id) # Since we've sent the message, return. Nothing else in this ifelse/switch should return directly return + elif isinstance(msg, CreateHITLDetailPayload): + resp = self.client.hitl.add_response( + ti_id=msg.ti_id, + options=msg.options, + subject=msg.subject, + body=msg.body, + defaults=msg.defaults, + params=msg.params, + multiple=msg.multiple, + ) + self.send_msg(resp, request_id=req_id, error=None, **dump_opts) else: log.error("Unhandled request", msg=msg) self.send_msg( diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index 7fa17ba1c4d38..4f4af1e137242 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -19,12 +19,15 @@ import json import pickle +from datetime import datetime +from typing import TYPE_CHECKING from unittest import mock import httpx import pytest import uuid6 from task_sdk import make_client, make_client_w_dry_run, make_client_w_responses +from uuid6 import uuid7 from airflow.sdk.api.client import RemoteValidationError, ServerResponseError from airflow.sdk.api.datamodels._generated import ( @@ -33,6 +36,7 @@ ConnectionResponse, DagRunState, DagRunStateResponse, + HITLDetailResponse, VariableResponse, XComResponse, ) @@ -40,6 +44,7 @@ from airflow.sdk.execution_time.comms import ( DeferTask, ErrorResponse, + HITLDetailRequestResult, OKResponse, RescheduleTask, TaskRescheduleStartDate, @@ -47,6 +52,9 @@ from airflow.utils import timezone from airflow.utils.state import TerminalTIState +if TYPE_CHECKING: + from time_machine import TimeMachineFixture + class TestClient: @pytest.mark.parametrize( @@ -1150,3 +1158,101 @@ def handle_request(request: httpx.Request) -> httpx.Response: assert isinstance(result, TaskRescheduleStartDate) assert result.start_date == "2024-01-01T00:00:00Z" + + +class TestHITLOperations: + def test_add_response(self) -> None: + ti_id = uuid7() + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path in (f"/hitl-details/{ti_id}"): + return httpx.Response( + status_code=201, + json={ + "ti_id": str(ti_id), + "options": ["Approval", "Reject"], + "subject": "This is subject", + "body": "This is body", + "defaults": ["Approval"], + "params": None, + "multiple": False, + }, + ) + return httpx.Response(status_code=400, json={"detail": "Bad Request"}) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.hitl.add_response( + ti_id=ti_id, + options=["Approval", "Reject"], + subject="This is subject", + body="This is body", + defaults=["Approval"], + params=None, + multiple=False, + ) + assert isinstance(result, HITLDetailRequestResult) + assert result.ti_id == ti_id + assert result.options == ["Approval", "Reject"] + assert result.subject == "This is subject" + assert result.body == "This is body" + assert result.defaults == ["Approval"] + assert result.params is None + assert result.multiple is False + + def test_update_response(self, time_machine: TimeMachineFixture) -> None: + time_machine.move_to(datetime(2025, 7, 3, 0, 0, 0)) + ti_id = uuid7() + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path in (f"/hitl-details/{ti_id}"): + return httpx.Response( + status_code=200, + json={ + "chosen_options": ["Approval"], + "params_input": {}, + "user_id": "admin", + "response_received": True, + "response_at": "2025-07-03T00:00:00Z", + }, + ) + return httpx.Response(status_code=400, json={"detail": "Bad Request"}) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.hitl.update_response( + ti_id=ti_id, + chosen_options=["Approve"], + params_input={}, + ) + assert isinstance(result, HITLDetailResponse) + assert result.response_received is True + assert result.chosen_options == ["Approval"] + assert result.params_input == {} + assert result.user_id == "admin" + assert result.response_at == timezone.datetime(2025, 7, 3, 0, 0, 0) + + def test_get_detail_response(self, time_machine: TimeMachineFixture) -> None: + time_machine.move_to(datetime(2025, 7, 3, 0, 0, 0)) + ti_id = uuid7() + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path in (f"/hitl-details/{ti_id}"): + return httpx.Response( + status_code=200, + json={ + "chosen_options": ["Approval"], + "params_input": {}, + "user_id": "admin", + "response_received": True, + "response_at": "2025-07-03T00:00:00Z", + }, + ) + return httpx.Response(status_code=400, json={"detail": "Bad Request"}) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.hitl.get_detail_response(ti_id=ti_id) + assert isinstance(result, HITLDetailResponse) + assert result.response_received is True + assert result.chosen_options == ["Approval"] + assert result.params_input == {} + assert result.user_id == "admin" + assert result.response_at == timezone.datetime(2025, 7, 3, 0, 0, 0) diff --git a/task-sdk/tests/task_sdk/execution_time/test_hitl.py b/task-sdk/tests/task_sdk/execution_time/test_hitl.py new file mode 100644 index 0000000000000..cab17e30bacec --- /dev/null +++ b/task-sdk/tests/task_sdk/execution_time/test_hitl.py @@ -0,0 +1,95 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from uuid6 import uuid7 + +from airflow.sdk.api.datamodels._generated import HITLDetailResponse +from airflow.sdk.execution_time.comms import CreateHITLDetailPayload +from airflow.sdk.execution_time.hitl import ( + add_hitl_detail, + get_hitl_detail_content_detail, + update_htil_detail_response, +) +from airflow.utils import timezone + +TI_ID = uuid7() + + +def test_add_hitl_detail(mock_supervisor_comms) -> None: + add_hitl_detail( + ti_id=TI_ID, + options=["Approve", "Reject"], + subject="Subject", + body="Optional body", + defaults=["Approve", "Reject"], + params={"input_1": 1}, + multiple=False, + ) + mock_supervisor_comms.send.assert_called_with( + msg=CreateHITLDetailPayload( + ti_id=TI_ID, + options=["Approve", "Reject"], + subject="Subject", + body="Optional body", + defaults=["Approve", "Reject"], + params={"input_1": 1}, + multiple=False, + ) + ) + + +def test_update_htil_detail_response(mock_supervisor_comms) -> None: + timestamp = timezone.utcnow() + mock_supervisor_comms.send.return_value = HITLDetailResponse( + response_received=True, + chosen_options=["Approve"], + response_at=timestamp, + user_id="admin", + params_input={"input_1": 1}, + ) + resp = update_htil_detail_response( + ti_id=TI_ID, + chosen_options=["Approve"], + params_input={"input_1": 1}, + ) + assert resp == HITLDetailResponse( + response_received=True, + chosen_options=["Approve"], + response_at=timestamp, + user_id="admin", + params_input={"input_1": 1}, + ) + + +def test_get_hitl_detail_content_detail(mock_supervisor_comms) -> None: + mock_supervisor_comms.send.return_value = HITLDetailResponse( + response_received=False, + chosen_options=None, + response_at=None, + user_id=None, + params_input={}, + ) + resp = get_hitl_detail_content_detail(TI_ID) + assert resp == HITLDetailResponse( + response_received=False, + chosen_options=None, + response_at=None, + user_id=None, + params_input={}, + ) diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 67f565ff9f4c1..db71ab76f8207 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -60,6 +60,7 @@ AssetResult, CommsDecoder, ConnectionResult, + CreateHITLDetailPayload, DagRunStateResult, DeferTask, DeleteVariable, @@ -81,6 +82,7 @@ GetXCom, GetXComSequenceItem, GetXComSequenceSlice, + HITLDetailRequestResult, InactiveAssetsResult, OKResponse, PrevSuccessfulDagRunResult, @@ -1770,6 +1772,49 @@ def watched_subprocess(self, mocker): None, id="get_xcom_seq_slice", ), + pytest.param( + CreateHITLDetailPayload( + ti_id=TI_ID, + options=["Approve", "Reject"], + subject="This is subject", + body="This is body", + defaults=["Approve"], + multiple=False, + params={}, + ), + { + "ti_id": str(TI_ID), + "options": ["Approve", "Reject"], + "subject": "This is subject", + "body": "This is body", + "defaults": ["Approve"], + "multiple": False, + "params": {}, + "type": "HITLDetailRequestResult", + }, + "hitl.add_response", + (), + { + "body": "This is body", + "defaults": ["Approve"], + "multiple": False, + "options": ["Approve", "Reject"], + "params": {}, + "subject": "This is subject", + "ti_id": TI_ID, + }, + HITLDetailRequestResult( + ti_id=TI_ID, + options=["Approve", "Reject"], + subject="This is subject", + body="This is body", + defaults=["Approve"], + multiple=False, + params={}, + ), + None, + id="create_hitl_detail_payload", + ), ], ) def test_handle_requests(