From a300b5274e7b3744c17e65c0aca0ea74961be1dd Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 29 Jul 2025 06:58:09 +0300 Subject: [PATCH] feat(hitl): add "timedout" column to HITLTriggerEventSuccessPayload so that this information can be used in the following tasks --- .../src/airflow/providers/standard/triggers/hitl.py | 3 +++ .../standard/tests/unit/standard/triggers/test_hitl.py | 6 ++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/providers/standard/src/airflow/providers/standard/triggers/hitl.py b/providers/standard/src/airflow/providers/standard/triggers/hitl.py index 63cea15363717..b29654e4c5538 100644 --- a/providers/standard/src/airflow/providers/standard/triggers/hitl.py +++ b/providers/standard/src/airflow/providers/standard/triggers/hitl.py @@ -43,6 +43,7 @@ class HITLTriggerEventSuccessPayload(TypedDict, total=False): chosen_options: list[str] params_input: dict[str, Any] + timedout: bool class HITLTriggerEventFailurePayload(TypedDict): @@ -115,6 +116,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: HITLTriggerEventSuccessPayload( chosen_options=self.defaults, params_input=self.params, + timedout=True, ) ) return @@ -126,6 +128,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: HITLTriggerEventSuccessPayload( chosen_options=resp.chosen_options, params_input=resp.params_input, + timedout=False, ) ) return diff --git a/providers/standard/tests/unit/standard/triggers/test_hitl.py b/providers/standard/tests/unit/standard/triggers/test_hitl.py index ac96d9eed1e07..da5952ac154f7 100644 --- a/providers/standard/tests/unit/standard/triggers/test_hitl.py +++ b/providers/standard/tests/unit/standard/triggers/test_hitl.py @@ -122,10 +122,7 @@ async def test_run_fallback_to_default_due_to_timeout(self, mock_update, mock_su await asyncio.sleep(0.3) event = await trigger_task assert event == TriggerEvent( - HITLTriggerEventSuccessPayload( - chosen_options=["1"], - params_input={"input": 1}, - ) + HITLTriggerEventSuccessPayload(chosen_options=["1"], params_input={"input": 1}, timedout=True) ) @pytest.mark.db_test @@ -157,5 +154,6 @@ async def test_run(self, mock_update, mock_supervisor_comms): HITLTriggerEventSuccessPayload( chosen_options=["3"], params_input={"input": 50}, + timedout=False, ) )