diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
index 0fcb6530608e9..3a5590a57592d 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
@@ -24,7 +24,6 @@
 from fastapi.exceptions import RequestValidationError
 from pydantic import ValidationError
 from sqlalchemy import or_, select
-from sqlalchemy.exc import MultipleResultsFound
 from sqlalchemy.orm import joinedload
 from sqlalchemy.sql.selectable import Select
 
@@ -736,9 +735,9 @@ def _patch_ti_validate_request(
     dag_bag: DagBagDep,
     body: PatchTaskInstanceBody,
     session: SessionDep,
-    map_index: int = -1,
+    map_index: int | None = -1,
     update_mask: list[str] | None = Query(None),
-) -> tuple[DAG, TI, dict]:
+) -> tuple[DAG, list[TI], dict]:
     dag = dag_bag.get_dag(dag_id)
     if not dag:
         raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG {dag_id} not found")
@@ -752,20 +751,15 @@ def _patch_ti_validate_request(
         .join(TI.dag_run)
         .options(joinedload(TI.rendered_task_instance_fields))
     )
-    query = query.where(TI.map_index == map_index)
+    if map_index is not None:
+        query = query.where(TI.map_index == map_index)
 
-    try:
-        ti = session.scalar(query)
-    except MultipleResultsFound:
-        raise HTTPException(
-            status.HTTP_400_BAD_REQUEST,
-            "Multiple task instances found. As the TI is mapped, add the map_index value to the URL",
-        )
+    tis = session.scalars(query).all()
 
     err_msg_404 = (
         f"The Task Instance with dag_id: `{dag_id}`, run_id: `{dag_run_id}`, task_id: `{task_id}` and map_index: `{map_index}` was not found",
     )
-    if ti is None:
+    if len(tis) == 0:
         raise HTTPException(status.HTTP_404_NOT_FOUND, err_msg_404)
 
     fields_to_update = body.model_fields_set
@@ -777,7 +771,7 @@ def _patch_ti_validate_request(
     except ValidationError as e:
         raise RequestValidationError(errors=e.errors())
 
-    return dag, ti, body.model_dump(include=fields_to_update, by_alias=True)
+    return dag, list(tis), body.model_dump(include=fields_to_update, by_alias=True)
 
 
 @task_instances_router.patch(
@@ -807,21 +801,16 @@ def patch_task_instance_dry_run(
     update_mask: list[str] | None = Query(None),
 ) -> TaskInstanceCollectionResponse:
     """Update a task instance dry_run mode."""
-    if map_index is None:
-        map_index = -1
-
-    dag, ti, data = _patch_ti_validate_request(
+    dag, tis, data = _patch_ti_validate_request(
         dag_id, dag_run_id, task_id, dag_bag, body, session, map_index, update_mask
     )
 
-    tis: list[TI] = []
-
     if data.get("new_state"):
         tis = (
             dag.set_task_instance_state(
                 task_id=task_id,
                 run_id=dag_run_id,
-                map_indexes=[map_index],
+                map_indexes=[map_index] if map_index is not None else None,
                 state=data["new_state"],
                 upstream=body.include_upstream,
                 downstream=body.include_downstream,
@@ -833,9 +822,6 @@ def patch_task_instance_dry_run(
             or []
         )
 
-    elif "note" in data:
-        tis = [ti]
-
     return TaskInstanceCollectionResponse(
         task_instances=[
             TaskInstanceResponse.model_validate(
@@ -881,19 +867,16 @@ def patch_task_instance(
     update_mask: list[str] | None = Query(None),
 ) -> TaskInstanceCollectionResponse:
     """Update a task instance."""
-    if map_index is None:
-        map_index = -1
-
-    dag, ti, data = _patch_ti_validate_request(
+    dag, tis, data = _patch_ti_validate_request(
        dag_id, dag_run_id, task_id, dag_bag, body, session, map_index, update_mask
     )
 
     for key, _ in data.items():
         if key == "new_state":
-            tis: list[TI] = dag.set_task_instance_state(
+            tis = dag.set_task_instance_state(
                 task_id=task_id,
                 run_id=dag_run_id,
-                map_indexes=[map_index],
+                map_indexes=[map_index] if map_index is not None else None,
                 state=data["new_state"],
                 upstream=body.include_upstream,
                 downstream=body.include_downstream,
@@ -906,37 +889,39 @@
                 raise HTTPException(
                     status.HTTP_409_CONFLICT, f"Task id {task_id} is already in {data['new_state']} state"
                 )
-            ti = tis[0] if isinstance(tis, list) else tis
-            try:
-                if data["new_state"] == TaskInstanceState.SUCCESS:
-                    get_listener_manager().hook.on_task_instance_success(
-                        previous_state=None, task_instance=ti
-                    )
-                elif data["new_state"] == TaskInstanceState.FAILED:
-                    get_listener_manager().hook.on_task_instance_failed(
-                        previous_state=None,
-                        task_instance=ti,
-                        error=f"TaskInstance's state was manually set to `{TaskInstanceState.FAILED}`.",
-                    )
-            except Exception:
-                log.exception("error calling listener")
+
+            for ti in tis:
+                try:
+                    if data["new_state"] == TaskInstanceState.SUCCESS:
+                        get_listener_manager().hook.on_task_instance_success(
+                            previous_state=None, task_instance=ti
+                        )
+                    elif data["new_state"] == TaskInstanceState.FAILED:
+                        get_listener_manager().hook.on_task_instance_failed(
+                            previous_state=None,
+                            task_instance=ti,
+                            error=f"TaskInstance's state was manually set to `{TaskInstanceState.FAILED}`.",
+                        )
+                except Exception:
+                    log.exception("error calling listener")
 
         elif key == "note":
-            if update_mask or body.note is not None:
-                if ti.task_instance_note is None:
-                    ti.note = (body.note, user.get_id())
-                else:
-                    ti.task_instance_note.content = body.note
-                    ti.task_instance_note.user_id = user.get_id()
-                session.commit()
+            for ti in tis:
+                if update_mask or body.note is not None:
+                    if ti.task_instance_note is None:
+                        ti.note = (body.note, user.get_id())
+                    else:
+                        ti.task_instance_note.content = body.note
+                        ti.task_instance_note.user_id = user.get_id()
 
     return TaskInstanceCollectionResponse(
         task_instances=[
             TaskInstanceResponse.model_validate(
                 ti,
             )
+            for ti in tis
         ],
-        total_entries=1,
+        total_entries=len(tis),
     )
 
 
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
index e0b694f748da6..3e12df5239c20 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
@@ -3103,7 +3103,7 @@ def test_should_call_mocked_api(self, mock_set_ti_state, test_client, session):
                 TaskInstance.run_id == self.RUN_ID,
                 TaskInstance.map_index == -1,
             )
-        ).one_or_none()
+        ).all()
 
         response = test_client.patch(
             self.ENDPOINT_URL,
@@ -3158,7 +3158,7 @@ def test_should_call_mocked_api(self, mock_set_ti_state, test_client, session):
             downstream=False,
             upstream=False,
             future=False,
-            map_indexes=[-1],
+            map_indexes=None,
             past=False,
             run_id=self.RUN_ID,
             session=mock.ANY,
@@ -3201,6 +3201,29 @@ def test_should_update_mapped_task_instance_state(self, test_client, session):
         assert response2.status_code == 200
         assert response2.json()["state"] == self.NEW_STATE
 
+    def test_should_update_mapped_task_instance_summary_state(self, test_client, session):
+        tis = self.create_task_instances(session)
+
+        for map_index in [1, 2, 3]:
+            ti = TaskInstance(task=tis[0].task, run_id=tis[0].run_id, map_index=map_index)
+            ti.rendered_task_instance_fields = RTIF(ti, render_templates=False)
+            session.add(ti)
+        tis[0].map_index = 0
+        session.commit()
+
+        response = test_client.patch(
+            f"{self.ENDPOINT_URL}",
+            json={
+                "new_state": self.NEW_STATE,
+            },
+        )
+        assert response.status_code == 200
+
+        response_data = response.json()
+        assert response_data["total_entries"] == 4
+        for map_index in range(4):
+            assert response_data["task_instances"][map_index]["state"] == self.NEW_STATE
+
     def test_should_respond_401(self, unauthenticated_test_client):
         response = unauthenticated_test_client.patch(
             self.ENDPOINT_URL,
@@ -3224,7 +3247,7 @@ def test_should_respond_403(self, unauthorized_test_client):
         [
             [
                 [
-                    "The Task Instance with dag_id: `example_python_operator`, run_id: `TEST_DAG_RUN_ID`, task_id: `print_the_context` and map_index: `-1` was not found",
+                    "The Task Instance with dag_id: `example_python_operator`, run_id: `TEST_DAG_RUN_ID`, task_id: `print_the_context` and map_index: `None` was not found",
                 ],
                 404,
                 {
@@ -3413,7 +3436,7 @@ def test_update_mask_should_call_mocked_api(
                 TaskInstance.run_id == self.RUN_ID,
                 TaskInstance.map_index == -1,
             )
-        ).one_or_none()
+        ).all()
 
         response = test_client.patch(
             self.ENDPOINT_URL,
@@ -3546,7 +3569,7 @@ def test_set_note_should_respond_200(self, test_client, session):
             session, response_data["task_instances"][0]["id"], {"content": new_note_value, "user_id": "test"}
         )
 
-    def test_set_note_should_respond_200_mapped_task_instance_with_rtif(self, test_client, session):
+    def test_set_note_should_respond_200_mapped_task_with_rtif(self, test_client, session):
         """Verify we don't duplicate rows through join to RTIF"""
         tis = self.create_task_instances(session)
         old_ti = tis[0]
@@ -3616,6 +3639,70 @@ def test_set_note_should_respond_200_mapped_task_instance_with_rtif(self, test_c
             {"content": new_note_value, "user_id": "test"},
         )
 
+    def test_set_note_should_respond_200_mapped_task_summary_with_rtif(self, test_client, session):
+        """Verify we don't duplicate rows through join to RTIF"""
+        tis = self.create_task_instances(session)
+        old_ti = tis[0]
+        for idx in (1, 2):
+            ti = TaskInstance(task=old_ti.task, run_id=old_ti.run_id, map_index=idx)
+            ti.rendered_task_instance_fields = RTIF(ti, render_templates=False)
+            for attr in ["duration", "end_date", "pid", "start_date", "state", "queue", "note"]:
+                setattr(ti, attr, getattr(old_ti, attr))
+            session.add(ti)
+        session.commit()
+
+        new_note_value = "My super cool TaskInstance note"
+        response = test_client.patch(
+            "/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context",
+            json={"note": new_note_value},
+        )
+        assert response.status_code == 200, response.text
+        response_data = response.json()
+
+        assert response_data["total_entries"] == 3
+
+        for map_index in range(1, 3):
+            response_ti = response_data["task_instances"][map_index]
+            assert response_ti == {
+                "dag_id": self.DAG_ID,
+                "dag_display_name": self.DAG_DISPLAY_NAME,
+                "dag_version": None,
+                "duration": 10000.0,
+                "end_date": "2020-01-03T00:00:00Z",
+                "logical_date": "2020-01-01T00:00:00Z",
+                "id": mock.ANY,
+                "executor": None,
+                "executor_config": "{}",
+                "hostname": "",
+                "map_index": map_index,
+                "max_tries": 0,
+                "note": new_note_value,
+                "operator": "PythonOperator",
+                "pid": 100,
+                "pool": "default_pool",
+                "pool_slots": 1,
+                "priority_weight": 9,
+                "queue": "default_queue",
+                "queued_when": None,
+                "scheduled_when": None,
+                "start_date": "2020-01-02T00:00:00Z",
+                "state": "running",
+                "task_id": self.TASK_ID,
+                "task_display_name": self.TASK_ID,
+                "try_number": 0,
+                "unixname": getuser(),
+                "dag_run_id": self.RUN_ID,
+                "rendered_fields": {"op_args": [], "op_kwargs": {}, "templates_dict": None},
+                "rendered_map_index": str(map_index),
+                "run_after": "2020-01-01T00:00:00Z",
+                "trigger": None,
+                "triggerer_job": None,
+            }
+
+            _check_task_instance_note(
+                session, response_ti["id"], {"content": new_note_value, "user_id": "test"}
+            )
+
     def test_set_note_should_respond_200_when_note_is_empty(self, test_client, session):
         tis = self.create_task_instances(session)
         for ti in tis:
@@ -3663,16 +3750,14 @@ class TestPatchTaskInstanceDryRun(TestTaskInstanceEndpoint):
     def test_should_call_mocked_api(self, mock_set_ti_state, test_client, session):
         self.create_task_instances(session)
 
-        mock_set_ti_state.return_value = [
-            session.scalars(
-                select(TaskInstance).where(
-                    TaskInstance.dag_id == self.DAG_ID,
-                    TaskInstance.task_id == self.TASK_ID,
-                    TaskInstance.run_id == self.RUN_ID,
-                    TaskInstance.map_index == -1,
-                )
-            ).one_or_none()
-        ]
+        mock_set_ti_state.return_value = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.task_id == self.TASK_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.map_index == -1,
+            )
+        ).all()
 
         response = test_client.patch(
             f"{self.ENDPOINT_URL}/dry_run",
@@ -3727,7 +3812,7 @@ def test_should_call_mocked_api(self, mock_set_ti_state, test_client, session):
             downstream=False,
             upstream=False,
             future=False,
-            map_indexes=[-1],
+            map_indexes=None,
             past=False,
             run_id=self.RUN_ID,
             session=mock.ANY,
@@ -3808,6 +3893,38 @@ def test_should_not_update_mapped_task_instance(self, test_client, session):
             assert task_before == task_after
             _check_task_instance_note(session, task_after["id"], None)
 
+    def test_should_not_update_mapped_task_instance_summary(self, test_client, session):
+        map_indexes = [1, 2, 3]
+        tis = self.create_task_instances(session)
+        for map_index in map_indexes:
+            ti = TaskInstance(
+                task=tis[0].task,
+                run_id=tis[0].run_id,
+                map_index=map_index,
+                state="running",
+            )
+            ti.rendered_task_instance_fields = RTIF(ti, render_templates=False)
+            session.add(ti)
+
+        session.delete(tis[0])
+        session.commit()
+
+        response = test_client.patch(
+            f"{self.ENDPOINT_URL}/dry_run",
+            json={
+                "new_state": self.NEW_STATE,
+            },
+        )
+
+        assert response.status_code == 200
+        assert response.json()["total_entries"] == len(map_indexes)
+
+        for map_index in map_indexes:
+            task_after = test_client.get(f"{self.ENDPOINT_URL}/{map_index}").json()
+            assert task_after["note"] is None
+            assert task_after["state"] == "running"
+            _check_task_instance_note(session, task_after["id"], None)
+
     @pytest.mark.parametrize(
         "error, code, payload",
         [
@@ -3824,7 +3941,7 @@ def test_should_not_update_mapped_task_instance(self, test_client, session):
     )
     def test_should_handle_errors(self, error, code, payload, test_client, session):
         response = test_client.patch(
-            f"{self.ENDPOINT_URL}/dry_run",
+            f"{self.ENDPOINT_URL}/dry_run?map_index=-1",
             json=payload,
         )
         assert response.status_code == code
@@ -3995,16 +4112,14 @@ def test_update_mask_should_call_mocked_api(
     ):
         self.create_task_instances(session)
 
-        mock_set_ti_state.return_value = [
-            session.scalars(
-                select(TaskInstance).where(
-                    TaskInstance.dag_id == self.DAG_ID,
-                    TaskInstance.task_id == self.TASK_ID,
-                    TaskInstance.run_id == self.RUN_ID,
-                    TaskInstance.map_index == -1,
-                )
-            ).one_or_none()
-        ]
+        mock_set_ti_state.return_value = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.task_id == self.TASK_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.map_index == -1,
+            )
+        ).all()
 
         response = test_client.patch(
             f"{self.ENDPOINT_URL}/dry_run",