diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 0a3de02894150..02886cecc18e6 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -48,6 +48,7 @@ ConnectionResponse, DagRunStateResponse, DagRunType, + HITLDetailRequest, HITLDetailResponse, HITLUser, InactiveAssetsResponse, @@ -77,7 +78,6 @@ CreateHITLDetailPayload, DRCount, ErrorResponse, - HITLDetailRequestResult, OKResponse, PreviousDagRunResult, SkipDownstreamTasks, @@ -754,7 +754,7 @@ def add_response( multiple: bool = False, params: dict[str, Any] | None = None, assigned_users: list[HITLUser] | None = None, - ) -> HITLDetailRequestResult: + ) -> HITLDetailRequest: """Add a Human-in-the-loop response that waits for human response for a specific Task Instance.""" payload = CreateHITLDetailPayload( ti_id=ti_id, @@ -770,7 +770,7 @@ def add_response( f"/hitlDetails/{ti_id}", content=payload.model_dump_json(), ) - return HITLDetailRequestResult.model_validate_json(resp.read()) + return HITLDetailRequest.model_validate_json(resp.read()) def update_response( self, diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 618a29e36ed93..45986bd398fee 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -584,6 +584,17 @@ class HITLDetailRequestResult(HITLDetailRequest): type: Literal["HITLDetailRequestResult"] = "HITLDetailRequestResult" + @classmethod + def from_api_response(cls, hitl_request: HITLDetailRequest) -> HITLDetailRequestResult: + """ + Get HITLDetailRequestResult from HITLDetailRequest (API response). + + HITLDetailRequest is the API response model. We convert it to HITLDetailRequestResult + for communication between the Supervisor and task process, adding the discriminator field + required for the tagged union deserialization. + """ + return cls(**hitl_request.model_dump(exclude_defaults=True), type="HITLDetailRequestResult") + ToTask = Annotated[ AssetResult diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 2a20e4eec4e0b..48086a79f2f05 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -94,6 +94,7 @@ GetXComCount, GetXComSequenceItem, GetXComSequenceSlice, + HITLDetailRequestResult, InactiveAssetsResult, MaskSecret, PrevSuccessfulDagRunResult, @@ -1383,7 +1384,7 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: # Since we've sent the message, return. Nothing else in this ifelse/switch should return directly return elif isinstance(msg, CreateHITLDetailPayload): - resp = self.client.hitl.add_response( + hitl_detail_request = self.client.hitl.add_response( ti_id=msg.ti_id, options=msg.options, subject=msg.subject, @@ -1393,7 +1394,8 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: multiple=msg.multiple, assigned_users=msg.assigned_users, ) - self.send_msg(resp, request_id=req_id, error=None, **dump_opts) + resp = HITLDetailRequestResult.from_api_response(hitl_detail_request) + dump_opts = {"exclude_unset": True} elif isinstance(msg, MaskSecret): mask_secret(msg.value, msg.name) else: diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index ed3997aefc6cc..51a28fcd196a9 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -39,6 +39,7 @@ ConnectionResponse, DagRunState, DagRunStateResponse, + HITLDetailRequest, HITLDetailResponse, HITLUser, TerminalTIState, @@ -49,7 +50,6 @@ from airflow.sdk.execution_time.comms import ( DeferTask, ErrorResponse, - HITLDetailRequestResult, OKResponse, PreviousDagRunResult, RescheduleTask, @@ -1300,7 +1300,7 @@ def handle_request(request: httpx.Request) -> httpx.Response: params=None, multiple=False, ) - assert isinstance(result, HITLDetailRequestResult) + assert isinstance(result, HITLDetailRequest) assert result.ti_id == ti_id assert result.options == ["Approval", "Reject"] assert result.subject == "This is subject" diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 4669e2c7ef2d9..8c219a4eebf74 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -2242,9 +2242,7 @@ class RequestTestCase: "subject": "This is subject", "body": "This is body", "defaults": ["Approve"], - "multiple": False, "params": {}, - "assigned_users": None, "type": "HITLDetailRequestResult", }, client_mock=ClientMock(