Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
chosen_options=self.defaults,
params_input=self.params,
)
self.log.info(
"[HITL] timeout reached before receiving response, fallback to default %s", self.defaults
)
yield TriggerEvent(
HITLTriggerEventSuccessPayload(
chosen_options=self.defaults,
Expand All @@ -121,7 +124,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]:

resp = await sync_to_async(get_hitl_detail_content_detail)(ti_id=self.ti_id)
if resp.response_received and resp.chosen_options:
self.log.info("Responded by %s at %s", resp.user_id, resp.response_at)
self.log.info(
"[HITL] user=%s options=%s at %s", resp.user_id, resp.chosen_options, resp.response_at
)
yield TriggerEvent(
HITLTriggerEventSuccessPayload(
chosen_options=resp.chosen_options,
Expand Down
50 changes: 30 additions & 20 deletions providers/standard/tests/unit/standard/triggers/test_hitl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,33 +25,36 @@
pytest.skip("Human in the loop public API compatible with Airflow >= 3.0.1", allow_module_level=True)

import asyncio
from datetime import timedelta
from datetime import datetime, timedelta
from unittest import mock

from uuid6 import uuid7

from airflow._shared.timezones.timezone import utc, utcnow
from airflow.api_fastapi.execution_api.datamodels.hitl import HITLDetailResponse
from airflow.providers.standard.triggers.hitl import (
HITLTrigger,
HITLTriggerEventFailurePayload,
HITLTriggerEventSuccessPayload,
)
from airflow.triggers.base import TriggerEvent
from airflow.utils.timezone import utcnow

TI_ID = uuid7()
default_trigger_args = {
"ti_id": TI_ID,
"options": ["1", "2", "3", "4", "5"],
"params": {"input": 1},
"multiple": False,
}


class TestHITLTrigger:
def test_serialization(self):
trigger = HITLTrigger(
ti_id=TI_ID,
options=["1", "2", "3", "4", "5"],
params={"input": 1},
defaults=["1"],
multiple=False,
timeout_datetime=None,
poke_interval=50.0,
**default_trigger_args,
)
classpath, kwargs = trigger.serialize()
assert classpath == "airflow.providers.standard.triggers.hitl.HITLTrigger"
Expand All @@ -70,12 +73,9 @@ def test_serialization(self):
@mock.patch("airflow.sdk.execution_time.hitl.update_htil_detail_response")
async def test_run_failed_due_to_timeout(self, mock_update, mock_supervisor_comms):
trigger = HITLTrigger(
ti_id=TI_ID,
options=["1", "2", "3", "4", "5"],
params={"input": 1},
multiple=False,
timeout_datetime=utcnow() + timedelta(seconds=0.1),
poke_interval=5,
**default_trigger_args,
)
mock_supervisor_comms.send.return_value = HITLDetailResponse(
response_received=False,
Expand All @@ -98,16 +98,14 @@ async def test_run_failed_due_to_timeout(self, mock_update, mock_supervisor_comm

@pytest.mark.db_test
@pytest.mark.asyncio
@mock.patch.object(HITLTrigger, "log")
@mock.patch("airflow.sdk.execution_time.hitl.update_htil_detail_response")
async def test_run_fallback_to_default_due_to_timeout(self, mock_update, mock_supervisor_comms):
async def test_run_fallback_to_default_due_to_timeout(self, mock_update, mock_log, mock_supervisor_comms):
trigger = HITLTrigger(
ti_id=TI_ID,
options=["1", "2", "3", "4", "5"],
params={"input": 1},
defaults=["1"],
multiple=False,
timeout_datetime=utcnow() + timedelta(seconds=0.1),
poke_interval=5,
**default_trigger_args,
)
mock_supervisor_comms.send.return_value = HITLDetailResponse(
response_received=False,
Expand All @@ -121,25 +119,30 @@ async def test_run_fallback_to_default_due_to_timeout(self, mock_update, mock_su
trigger_task = asyncio.create_task(gen.__anext__())
await asyncio.sleep(0.3)
event = await trigger_task

assert event == TriggerEvent(
HITLTriggerEventSuccessPayload(
chosen_options=["1"],
params_input={"input": 1},
)
)

assert mock_log.info.call_args == mock.call(
"[HITL] timeout reached before receiving response, fallback to default %s", ["1"]
)

@pytest.mark.db_test
@pytest.mark.asyncio
@mock.patch.object(HITLTrigger, "log")
@mock.patch("airflow.sdk.execution_time.hitl.update_htil_detail_response")
async def test_run(self, mock_update, mock_supervisor_comms):
async def test_run(self, mock_update, mock_log, mock_supervisor_comms, time_machine):
time_machine.move_to(datetime(2025, 7, 29, 2, 0, 0))

trigger = HITLTrigger(
ti_id=TI_ID,
options=["1", "2", "3", "4", "5"],
params={"input": 1},
defaults=["1"],
multiple=False,
timeout_datetime=None,
poke_interval=5,
**default_trigger_args,
)
mock_supervisor_comms.send.return_value = HITLDetailResponse(
response_received=True,
Expand All @@ -159,3 +162,10 @@ async def test_run(self, mock_update, mock_supervisor_comms):
params_input={"input": 50},
)
)

assert mock_log.info.call_args == mock.call(
"[HITL] user=%s options=%s at %s",
"test",
["3"],
datetime(2025, 7, 29, 2, 0, 0, tzinfo=utc),
)