diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index bc735f684747e..f4476322e8a3f 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -43,6 +43,7 @@ ConnectionResponse, DagRunStateResponse, DagRunType, + HITLDetailRequest, HITLDetailResponse, HITLUser, InactiveAssetsResponse, @@ -72,7 +73,6 @@ CreateHITLDetailPayload, DRCount, ErrorResponse, - HITLDetailRequestResult, OKResponse, PreviousDagRunResult, SkipDownstreamTasks, @@ -725,7 +725,7 @@ def add_response( multiple: bool = False, params: dict[str, Any] | None = None, assigned_users: list[HITLUser] | None = None, - ) -> HITLDetailRequestResult: + ) -> HITLDetailRequest: """Add a Human-in-the-loop response that waits for human response for a specific Task Instance.""" payload = CreateHITLDetailPayload( ti_id=ti_id, @@ -741,7 +741,7 @@ def add_response( f"/hitlDetails/{ti_id}", content=payload.model_dump_json(), ) - return HITLDetailRequestResult.model_validate_json(resp.read()) + return HITLDetailRequest.model_validate_json(resp.read()) def update_response( self, diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index bfa9fb012ae17..fe6435ca7f8a1 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -584,6 +584,17 @@ class HITLDetailRequestResult(HITLDetailRequest): type: Literal["HITLDetailRequestResult"] = "HITLDetailRequestResult" + @classmethod + def from_api_response(cls, hitl_request: HITLDetailRequest) -> HITLDetailRequestResult: + """ + Get HITLDetailRequestResult from HITLDetailRequest (API response). + + HITLDetailRequest is the API response model. We convert it to HITLDetailRequestResult + for communication between the Supervisor and task process, adding the discriminator field + required for the tagged union deserialization. + """ + return cls(**hitl_request.model_dump(exclude_defaults=True), type="HITLDetailRequestResult") + ToTask = Annotated[ AssetResult diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 6bf609f113937..fafd56209bdf2 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -93,6 +93,7 @@ GetXComCount, GetXComSequenceItem, GetXComSequenceSlice, + HITLDetailRequestResult, InactiveAssetsResult, MaskSecret, PrevSuccessfulDagRunResult, @@ -1352,7 +1353,7 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: # Since we've sent the message, return. Nothing else in this ifelse/switch should return directly return elif isinstance(msg, CreateHITLDetailPayload): - resp = self.client.hitl.add_response( + hitl_detail_request = self.client.hitl.add_response( ti_id=msg.ti_id, options=msg.options, subject=msg.subject, @@ -1362,7 +1363,8 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: multiple=msg.multiple, assigned_users=msg.assigned_users, ) - self.send_msg(resp, request_id=req_id, error=None, **dump_opts) + resp = HITLDetailRequestResult.from_api_response(hitl_detail_request) + dump_opts = {"exclude_unset": True} elif isinstance(msg, MaskSecret): mask_secret(msg.value, msg.name) else: diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index e73f05eeac8b9..e5322519fbf6e 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -38,6 +38,7 @@ ConnectionResponse, DagRunState, DagRunStateResponse, + HITLDetailRequest, HITLDetailResponse, HITLUser, VariableResponse, @@ -47,7 +48,6 @@ from airflow.sdk.execution_time.comms import ( DeferTask, ErrorResponse, - HITLDetailRequestResult, OKResponse, PreviousDagRunResult, RescheduleTask, @@ -1300,7 +1300,7 @@ def handle_request(request: httpx.Request) -> httpx.Response: params=None, multiple=False, ) - assert isinstance(result, HITLDetailRequestResult) + assert isinstance(result, HITLDetailRequest) assert result.ti_id == ti_id assert result.options == ["Approval", "Reject"] assert result.subject == "This is subject" diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 7c9828a29fd79..bfd8caa3b476b 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -2030,9 +2030,7 @@ class RequestTestCase: "subject": "This is subject", "body": "This is body", "defaults": ["Approve"], - "multiple": False, "params": {}, - "assigned_users": None, "type": "HITLDetailRequestResult", }, client_mock=ClientMock(