diff --git a/airflow/api_fastapi/core_api/datamodels/task_instances.py b/airflow/api_fastapi/core_api/datamodels/task_instances.py index 85ab3d650225..7cdbd5c975f6 100644 --- a/airflow/api_fastapi/core_api/datamodels/task_instances.py +++ b/airflow/api_fastapi/core_api/datamodels/task_instances.py @@ -17,10 +17,9 @@ from __future__ import annotations from datetime import datetime -from typing import Annotated +from typing import Annotated, Any from pydantic import ( - AliasChoices, AliasPath, AwareDatetime, BaseModel, @@ -29,6 +28,7 @@ Field, NonNegativeInt, field_validator, + model_validator, ) from airflow.api_fastapi.core_api.datamodels.job import JobResponse @@ -154,25 +154,34 @@ class TaskInstanceHistoryCollectionResponse(BaseModel): total_entries: int -class TaskInstanceReferenceResponse(BaseModel): - """Task Instance Reference serializer for responses.""" - - task_id: str - dag_run_id: str = Field(validation_alias=AliasChoices("run_id")) - dag_id: str - logical_date: datetime - - class PatchTaskInstanceBody(BaseModel): """Request body for Clear Task Instances endpoint.""" dry_run: bool = True - new_state: str + new_state: str | None = None + note: str | None = None + + @model_validator(mode="before") + @classmethod + def validate_model(cls, data: Any) -> Any: + if data.get("note") is None and data.get("new_state") is None: + raise ValueError("new_state is required.") + return data + + @field_validator("note", mode="before") + @classmethod + def validate_note(cls, note: str | None) -> str | None: + """Validate note.""" + if note is None: + return None + if len(note) > 1000: + raise ValueError("Note length should not exceed 1000 characters.") + return note @field_validator("new_state", mode="before") @classmethod def validate_new_state(cls, ns: str) -> str: - """Convert timezone attribute to string representation.""" + """Validate new_state.""" valid_states = [ vs.name.lower() for vs in (TaskInstanceState.SUCCESS, TaskInstanceState.FAILED, TaskInstanceState.SKIPPED) diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index 413f29c6b6a7..d6421e950475 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -3434,7 +3434,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/TaskInstanceReferenceResponse' + $ref: '#/components/schemas/TaskInstanceResponse' '401': content: application/json: @@ -3447,6 +3447,12 @@ paths: schema: $ref: '#/components/schemas/HTTPExceptionResponse' description: Forbidden + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Bad Request '404': content: application/json: @@ -3892,7 +3898,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/TaskInstanceReferenceResponse' + $ref: '#/components/schemas/TaskInstanceResponse' '401': content: application/json: @@ -3905,6 +3911,12 @@ paths: schema: $ref: '#/components/schemas/HTTPExceptionResponse' description: Forbidden + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Bad Request '404': content: application/json: @@ -6553,11 +6565,16 @@ components: title: Dry Run default: true new_state: - type: string + anyOf: + - type: string + - type: 'null' title: New State + note: + anyOf: + - type: string + - type: 'null' + title: Note type: object - required: - - new_state title: PatchTaskInstanceBody description: Request body for Clear Task Instances endpoint. PluginCollectionResponse: @@ -7038,29 +7055,6 @@ components: - executor_config title: TaskInstanceHistoryResponse description: TaskInstanceHistory serializer for responses. - TaskInstanceReferenceResponse: - properties: - task_id: - type: string - title: Task Id - dag_run_id: - type: string - title: Dag Run Id - dag_id: - type: string - title: Dag Id - logical_date: - type: string - format: date-time - title: Logical Date - type: object - required: - - task_id - - dag_run_id - - dag_id - - logical_date - title: TaskInstanceReferenceResponse - description: Task Instance Reference serializer for responses. TaskInstanceResponse: properties: id: diff --git a/airflow/api_fastapi/core_api/routes/public/task_instances.py b/airflow/api_fastapi/core_api/routes/public/task_instances.py index fcf9552f21a7..7decb57d76ea 100644 --- a/airflow/api_fastapi/core_api/routes/public/task_instances.py +++ b/airflow/api_fastapi/core_api/routes/public/task_instances.py @@ -20,8 +20,9 @@ from typing import Annotated, Literal from fastapi import Depends, HTTPException, Request, status +from sqlalchemy import or_, select +from sqlalchemy.exc import MultipleResultsFound from sqlalchemy.orm import Session, joinedload -from sqlalchemy.sql import select from airflow.api_fastapi.common.db.common import get_session, paginated_select from airflow.api_fastapi.common.parameters import ( @@ -52,7 +53,6 @@ TaskDependencyCollectionResponse, TaskInstanceCollectionResponse, TaskInstanceHistoryResponse, - TaskInstanceReferenceResponse, TaskInstanceResponse, TaskInstancesBatchBody, ) @@ -488,11 +488,11 @@ def get_mapped_task_instance_try_details( @task_instances_router.patch( "/{task_id}", - responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), + responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST]), ) @task_instances_router.patch( "/{task_id}/{map_index}", - responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), + responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST]), ) def patch_task_instance( dag_id: str, @@ -502,7 +502,7 @@ def patch_task_instance( body: PatchTaskInstanceBody, session: Annotated[Session, Depends(get_session)], map_index: int = -1, -) -> TaskInstanceReferenceResponse: +) -> TaskInstanceResponse: """Update the state of a task instance.""" dag = request.app.state.dag_bag.get_dag(dag_id) if not dag: @@ -511,13 +511,27 @@ def patch_task_instance( if not dag.has_task(task_id): raise HTTPException(status.HTTP_404_NOT_FOUND, f"Task '{task_id}' not found in DAG '{dag_id}'") - ti: TI | None = session.scalars( - select(TI).where( - TI.dag_id == dag_id, TI.task_id == task_id, TI.run_id == dag_run_id, TI.map_index == map_index + query = ( + select(TI) + .where(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == task_id) + .join(TI.dag_run) + .options(joinedload(TI.rendered_task_instance_fields)) + ) + if map_index == -1: + query = query.where(or_(TI.map_index == -1, TI.map_index is None)) + else: + query = query.where(TI.map_index == map_index) + + try: + ti = session.scalar(query) + except MultipleResultsFound: + raise HTTPException( + status.HTTP_400_BAD_REQUEST, + "Multiple task instances found. As the TI is mapped, add the map_index value to the URL", ) - ).one_or_none() - err_msg_404 = f"Task instance not found for task '{task_id}' on DAG run with ID '{dag_run_id}'" - if not ti: + + err_msg_404 = f"Task Instance not found for dag_id={dag_id}, run_id={dag_run_id}, task_id={task_id}" + if ti is None: raise HTTPException(status.HTTP_404_NOT_FOUND, err_msg_404) if not body.dry_run: @@ -532,4 +546,16 @@ def patch_task_instance( if not ti: raise HTTPException(status.HTTP_404_NOT_FOUND, err_msg_404) ti = tis[0] if isinstance(tis, list) else tis - return TaskInstanceReferenceResponse.model_validate(ti, from_attributes=True) + + # Set new note to the task instance if available in body. + if body.note is not None: + # @TODO: replace None passed for user_id with actual user id when + # permissions and auth is in place. + if ti.task_instance_note is None: + ti.note = (body.note, None) + else: + ti.task_instance_note.content = body.note + ti.task_instance_note.user_id = None + session.commit() + + return TaskInstanceResponse.model_validate(ti, from_attributes=True) diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts index b2fe160ae4e7..4e0c70b5aae8 100644 --- a/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow/ui/openapi-gen/queries/queries.ts @@ -2851,7 +2851,7 @@ export const usePoolServicePatchPool = < * @param data.taskId * @param data.requestBody * @param data.mapIndex - * @returns TaskInstanceReferenceResponse Successful Response + * @returns TaskInstanceResponse Successful Response * @throws ApiError */ export const useTaskInstanceServicePatchTaskInstance = < @@ -2906,7 +2906,7 @@ export const useTaskInstanceServicePatchTaskInstance = < * @param data.taskId * @param data.mapIndex * @param data.requestBody - * @returns TaskInstanceReferenceResponse Successful Response + * @returns TaskInstanceResponse Successful Response * @throws ApiError */ export const useTaskInstanceServicePatchTaskInstance1 = < diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow/ui/openapi-gen/requests/schemas.gen.ts index f8ce2d4d6ed2..0cdcddff2f1b 100644 --- a/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -2660,12 +2660,29 @@ export const $PatchTaskInstanceBody = { default: true, }, new_state: { - type: "string", + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], title: "New State", }, + note: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Note", + }, }, type: "object", - required: ["new_state"], title: "PatchTaskInstanceBody", description: "Request body for Clear Task Instances endpoint.", } as const; @@ -3365,32 +3382,6 @@ export const $TaskInstanceHistoryResponse = { description: "TaskInstanceHistory serializer for responses.", } as const; -export const $TaskInstanceReferenceResponse = { - properties: { - task_id: { - type: "string", - title: "Task Id", - }, - dag_run_id: { - type: "string", - title: "Dag Run Id", - }, - dag_id: { - type: "string", - title: "Dag Id", - }, - logical_date: { - type: "string", - format: "date-time", - title: "Logical Date", - }, - }, - type: "object", - required: ["task_id", "dag_run_id", "dag_id", "logical_date"], - title: "TaskInstanceReferenceResponse", - description: "Task Instance Reference serializer for responses.", -} as const; - export const $TaskInstanceResponse = { properties: { id: { diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts index e5786e0e9f41..b33a105c9a49 100644 --- a/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow/ui/openapi-gen/requests/services.gen.ts @@ -1869,7 +1869,7 @@ export class TaskInstanceService { * @param data.taskId * @param data.requestBody * @param data.mapIndex - * @returns TaskInstanceReferenceResponse Successful Response + * @returns TaskInstanceResponse Successful Response * @throws ApiError */ public static patchTaskInstance( @@ -1889,6 +1889,7 @@ export class TaskInstanceService { body: data.requestBody, mediaType: "application/json", errors: { + 400: "Bad Request", 401: "Unauthorized", 403: "Forbidden", 404: "Not Found", @@ -2070,7 +2071,7 @@ export class TaskInstanceService { * @param data.taskId * @param data.mapIndex * @param data.requestBody - * @returns TaskInstanceReferenceResponse Successful Response + * @returns TaskInstanceResponse Successful Response * @throws ApiError */ public static patchTaskInstance1( @@ -2088,6 +2089,7 @@ export class TaskInstanceService { body: data.requestBody, mediaType: "application/json", errors: { + 400: "Bad Request", 401: "Unauthorized", 403: "Forbidden", 404: "Not Found", diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index 2759bb6c3cca..6d00198e26c3 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -663,7 +663,8 @@ export type JobResponse = { */ export type PatchTaskInstanceBody = { dry_run?: boolean; - new_state: string; + new_state?: string | null; + note?: string | null; }; /** @@ -844,16 +845,6 @@ export type TaskInstanceHistoryResponse = { executor_config: string; }; -/** - * Task Instance Reference serializer for responses. - */ -export type TaskInstanceReferenceResponse = { - task_id: string; - dag_run_id: string; - dag_id: string; - logical_date: string; -}; - /** * TaskInstance serializer for responses. */ @@ -1551,7 +1542,7 @@ export type PatchTaskInstanceData = { taskId: string; }; -export type PatchTaskInstanceResponse = TaskInstanceReferenceResponse; +export type PatchTaskInstanceResponse = TaskInstanceResponse; export type GetMappedTaskInstancesData = { dagId: string; @@ -1615,7 +1606,7 @@ export type PatchTaskInstance1Data = { taskId: string; }; -export type PatchTaskInstance1Response = TaskInstanceReferenceResponse; +export type PatchTaskInstance1Response = TaskInstanceResponse; export type GetTaskInstancesData = { dagId: string; @@ -3207,7 +3198,11 @@ export type $OpenApiTs = { /** * Successful Response */ - 200: TaskInstanceReferenceResponse; + 200: TaskInstanceResponse; + /** + * Bad Request + */ + 400: HTTPExceptionResponse; /** * Unauthorized */ @@ -3340,7 +3335,11 @@ export type $OpenApiTs = { /** * Successful Response */ - 200: TaskInstanceReferenceResponse; + 200: TaskInstanceResponse; + /** + * Bad Request + */ + 400: HTTPExceptionResponse; /** * Unauthorized */ diff --git a/tests/api_fastapi/core_api/routes/public/test_task_instances.py b/tests/api_fastapi/core_api/routes/public/test_task_instances.py index 4aacb763ffc8..c06b589d20b5 100644 --- a/tests/api_fastapi/core_api/routes/public/test_task_instances.py +++ b/tests/api_fastapi/core_api/routes/public/test_task_instances.py @@ -1338,7 +1338,7 @@ def test_should_respond_200_dag_ids_filter(self, test_client, payload, expected_ assert len(response.json()["task_instances"]) == expected_ti assert response.json()["total_entries"] == total_ti - def test_should_raise_400_for_no_json(self, test_client): + def test_should_raise_422_for_no_json(self, test_client): response = test_client.post( "/public/dags/~/dagRuns/~/taskInstances/list", ) @@ -1376,7 +1376,7 @@ def test_should_respond_422_for_non_wildcard_path_parameters(self, test_client): ({"logical_date_lte": "2020-11-10T12:42:39.442973"}, "Input should have timezone info"), ], ) - def test_should_raise_400_for_naive_and_bad_datetime(self, test_client, payload, expected, session): + def test_should_raise_422_for_naive_and_bad_datetime(self, test_client, payload, expected, session): self.create_task_instances(session) response = test_client.post( "/public/dags/~/dagRuns/~/taskInstances/list", @@ -1698,6 +1698,31 @@ def test_should_call_mocked_api(self, mock_set_task_instance_state, test_client, "dag_run_id": "TEST_DAG_RUN_ID", "logical_date": "2020-01-01T00:00:00Z", "task_id": "print_the_context", + "duration": 10000.0, + "end_date": "2020-01-03T00:00:00Z", + "executor": None, + "executor_config": "{}", + "hostname": "", + "id": mock.ANY, + "map_index": -1, + "max_tries": 0, + "note": "placeholder-note", + "operator": "PythonOperator", + "pid": 100, + "pool": "default_pool", + "pool_slots": 1, + "priority_weight": 9, + "queue": "default_queue", + "queued_when": None, + "start_date": "2020-01-02T00:00:00Z", + "state": "running", + "task_display_name": "print_the_context", + "try_number": 0, + "unixname": getuser(), + "rendered_fields": {}, + "rendered_map_index": None, + "trigger": None, + "triggerer_job": None, } mock_set_task_instance_state.assert_called_once() @@ -1729,6 +1754,31 @@ def test_should_not_call_mocked_api_for_dry_run(self, mock_set_task_instance_sta "dag_run_id": "TEST_DAG_RUN_ID", "logical_date": "2020-01-01T00:00:00Z", "task_id": "print_the_context", + "duration": 10000.0, + "end_date": "2020-01-03T00:00:00Z", + "executor": None, + "executor_config": "{}", + "hostname": "", + "id": mock.ANY, + "map_index": -1, + "max_tries": 0, + "note": "placeholder-note", + "operator": "PythonOperator", + "pid": 100, + "pool": "default_pool", + "pool_slots": 1, + "priority_weight": 9, + "queue": "default_queue", + "queued_when": None, + "start_date": "2020-01-02T00:00:00Z", + "state": "running", + "task_display_name": "print_the_context", + "try_number": 0, + "unixname": getuser(), + "rendered_fields": {}, + "rendered_map_index": None, + "trigger": None, + "triggerer_job": None, } mock_set_task_instance_state.assert_not_called() @@ -1792,7 +1842,10 @@ def test_should_update_mapped_task_instance_state(self, test_client, session): "error, code, payload", [ [ - "Task instance not found for task 'print_the_context' on DAG run with ID 'TEST_DAG_RUN_ID'", + ( + "Task Instance not found for dag_id=example_python_operator" + ", run_id=TEST_DAG_RUN_ID, task_id=print_the_context" + ), 404, { "dry_run": True, @@ -1883,7 +1936,7 @@ def test_should_raise_404_not_found_task(self, test_client): ), ], ) - def test_should_raise_400_for_invalid_task_instance_state(self, payload, expected, test_client, session): + def test_should_raise_422_for_invalid_task_instance_state(self, payload, expected, test_client, session): self.create_task_instances(session) response = test_client.patch( self.ENDPOINT_URL, @@ -1901,3 +1954,114 @@ def test_should_raise_400_for_invalid_task_instance_state(self, payload, expecte } ] } + + def test_set_note_should_respond_200(self, test_client, session): + self.create_task_instances(session) + new_note_value = "My super cool TaskInstance note." + response = test_client.patch( + "/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", + json={"note": new_note_value}, + ) + assert response.status_code == 200, response.text + assert response.json() == { + "dag_id": "example_python_operator", + "duration": 10000.0, + "end_date": "2020-01-03T00:00:00Z", + "logical_date": "2020-01-01T00:00:00Z", + "id": mock.ANY, + "executor": None, + "executor_config": "{}", + "hostname": "", + "map_index": -1, + "max_tries": 0, + "note": new_note_value, + "operator": "PythonOperator", + "pid": 100, + "pool": "default_pool", + "pool_slots": 1, + "priority_weight": 9, + "queue": "default_queue", + "queued_when": None, + "start_date": "2020-01-02T00:00:00Z", + "state": "running", + "task_id": "print_the_context", + "task_display_name": "print_the_context", + "try_number": 0, + "unixname": getuser(), + "dag_run_id": "TEST_DAG_RUN_ID", + "rendered_fields": {}, + "rendered_map_index": None, + "trigger": None, + "triggerer_job": None, + } + # @TODO: Uncomment the 2 lines below when permissions and auth is in place. + # ti = session.scalars(select(TaskInstance).where(TaskInstance.task_id == "print_the_context")).one() + # assert ti.task_instance_note.user_id is not None + + def test_set_note_should_respond_200_mapped_task_instance_with_rtif(self, test_client, session): + """Verify we don't duplicate rows through join to RTIF""" + tis = self.create_task_instances(session) + old_ti = tis[0] + for idx in (1, 2): + ti = TaskInstance(task=old_ti.task, run_id=old_ti.run_id, map_index=idx) + ti.rendered_task_instance_fields = RTIF(ti, render_templates=False) + for attr in ["duration", "end_date", "pid", "start_date", "state", "queue", "note"]: + setattr(ti, attr, getattr(old_ti, attr)) + session.add(ti) + session.commit() + + # in each loop, we should get the right mapped TI back + for map_index in (1, 2): + new_note_value = f"My super cool TaskInstance note {map_index}" + response = test_client.patch( + "/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/" + f"print_the_context/{map_index}", + json={"note": new_note_value}, + ) + assert response.status_code == 200, response.text + + assert response.json() == { + "dag_id": "example_python_operator", + "duration": 10000.0, + "end_date": "2020-01-03T00:00:00Z", + "logical_date": "2020-01-01T00:00:00Z", + "id": mock.ANY, + "executor": None, + "executor_config": "{}", + "hostname": "", + "map_index": map_index, + "max_tries": 0, + "note": new_note_value, + "operator": "PythonOperator", + "pid": 100, + "pool": "default_pool", + "pool_slots": 1, + "priority_weight": 9, + "queue": "default_queue", + "queued_when": None, + "start_date": "2020-01-02T00:00:00Z", + "state": "running", + "task_id": "print_the_context", + "task_display_name": "print_the_context", + "try_number": 0, + "unixname": getuser(), + "dag_run_id": "TEST_DAG_RUN_ID", + "rendered_fields": {"op_args": [], "op_kwargs": {}, "templates_dict": None}, + "rendered_map_index": None, + "trigger": None, + "triggerer_job": None, + } + + def test_set_note_should_respond_200_when_note_is_empty(self, test_client, session): + tis = self.create_task_instances(session) + for ti in tis: + ti.task_instance_note = None + session.add(ti) + session.commit() + new_note_value = "My super cool TaskInstance note." + response = test_client.patch( + "/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", + json={"note": new_note_value}, + ) + assert response.status_code == 200, response.text + assert response.json()["note"] == new_note_value