diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 2a41d40690068..aa95a7be7d7d3 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -74,10 +74,11 @@ UpdateHITLDetail, VariableResult, XComResult, + _new_encoder, _RequestFrame, ) from airflow.sdk.execution_time.supervisor import WatchedSubprocess, make_buffered_socket_reader -from airflow.triggers import base as events +from airflow.triggers.base import BaseEventTrigger, BaseTrigger, DiscrimatedTriggerEvent, TriggerEvent from airflow.utils.helpers import log_filename_template_renderer from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import provide_session @@ -90,7 +91,6 @@ from airflow.jobs.job import Job from airflow.sdk.api.client import Client from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI - from airflow.triggers.base import BaseTrigger logger = logging.getLogger(__name__) @@ -209,7 +209,7 @@ class TriggerStateChanges(BaseModel): type: Literal["TriggerStateChanges"] = "TriggerStateChanges" events: Annotated[ - list[tuple[int, events.DiscrimatedTriggerEvent]] | None, + list[tuple[int, DiscrimatedTriggerEvent]] | None, # We have to specify a default here, as otherwise Pydantic struggles to deal with the discriminated # union :shrug: Field(default=None), @@ -363,7 +363,7 @@ class TriggerRunnerSupervisor(WatchedSubprocess): creating_triggers: deque[workloads.RunTrigger] = attrs.field(factory=deque, init=False) # Outbound queue of events - events: deque[tuple[int, events.TriggerEvent]] = attrs.field(factory=deque, init=False) + events: deque[tuple[int, TriggerEvent]] = attrs.field(factory=deque, init=False) # Outbound queue of failed triggers failed_triggers: deque[tuple[int, list[str] | None]] = attrs.field(factory=deque, init=False) @@ -843,7 +843,7 @@ class TriggerRunner: to_cancel: deque[int] # Outbound queue of events - events: deque[tuple[int, events.TriggerEvent]] + events: deque[tuple[int, TriggerEvent]] # Outbound queue of failed triggers failed_triggers: deque[tuple[int, BaseException | None]] @@ -993,7 +993,7 @@ async def create_triggers(self): "task": asyncio.create_task( self.run_trigger(trigger_id, trigger_instance, workload.timeout_after), name=trigger_name ), - "is_watcher": isinstance(trigger_instance, events.BaseEventTrigger), + "is_watcher": isinstance(trigger_instance, BaseEventTrigger), "name": trigger_name, "events": 0, } @@ -1039,7 +1039,7 @@ async def cleanup_finished_triggers(self) -> list[int]: saved_exc = e else: # See if they foolishly returned a TriggerEvent - if isinstance(result, events.TriggerEvent): + if isinstance(result, TriggerEvent): self.log.error( "Trigger returned a TriggerEvent rather than yielding it", trigger=details["name"], @@ -1059,46 +1059,64 @@ async def cleanup_finished_triggers(self) -> list[int]: await asyncio.sleep(0) return finished_ids - async def sync_state_to_supervisor(self, finished_ids: list[int]): + def process_trigger_events(self, finished_ids: list[int]) -> messages.TriggerStateChanges: # Copy out of our dequeues in threadsafe manner to sync state with parent - events_to_send = [] + + req_encoder = _new_encoder() + events_to_send: list[tuple[int, DiscrimatedTriggerEvent]] = [] + failures_to_send: list[tuple[int, list[str] | None]] = [] + while self.events: - data = self.events.popleft() - events_to_send.append(data) + trigger_id, trigger_event = self.events.popleft() + + try: + req_encoder.encode(trigger_event) + except Exception as e: + logger.error( + "Trigger %s returned non-serializable result %r. Cancelling trigger.", + trigger_id, + trigger_event, + ) + self.failed_triggers.append((trigger_id, e)) + else: + events_to_send.append((trigger_id, trigger_event)) + - failures_to_send = [] while self.failed_triggers: - id, exc = self.failed_triggers.popleft() + trigger_id, exc = self.failed_triggers.popleft() tb = format_exception(type(exc), exc, exc.__traceback__) if exc else None - failures_to_send.append((id, tb)) + failures_to_send.append((trigger_id, tb)) - msg = messages.TriggerStateChanges( - events=events_to_send, finished=finished_ids, failures=failures_to_send + return messages.TriggerStateChanges( + events=events_to_send if events_to_send else None, + finished=finished_ids if finished_ids else None, + failures=failures_to_send if failures_to_send else None, ) - if not events_to_send: - msg.events = None + async def sync_state_to_supervisor(self, finished_ids: list[int]) -> None: + msg = self.process_trigger_events(finished_ids=finished_ids) - if not failures_to_send: - msg.failures = None + # Tell the monitor that we've finished triggers so it can update things + resp = await self.send_trigger_state_changes(msg) - if not finished_ids: - msg.finished = None + if resp: + self.to_create.extend(resp.to_create) + self.to_cancel.extend(resp.to_cancel) - # Tell the monitor that we've finished triggers so it can update things + async def send_trigger_state_changes(self, msg: messages.TriggerStateChanges) -> messages.TriggerStateSync | None: try: - resp = await self.comms_decoder.asend(msg) + response = await self.comms_decoder.asend(msg) + + if not isinstance(response, messages.TriggerStateSync): + raise RuntimeError(f"Expected to get a TriggerStateSync message, instead we got {type(msg)}") + + return response except asyncio.IncompleteReadError: if task := asyncio.current_task(): task.cancel("EOF - shutting down") - return + return None raise - if not isinstance(resp, messages.TriggerStateSync): - raise RuntimeError(f"Expected to get a TriggerStateSync message, instead we got {type(msg)}") - self.to_create.extend(resp.to_create) - self.to_cancel.extend(resp.to_cancel) - async def block_watchdog(self): """ Watchdog loop that detects blocking (badly-written) triggers. diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index df31c9225c272..aa342be195ce3 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -369,7 +369,7 @@ async def test_invalid_trigger(self, supervisor_builder): trigger_runner = TriggerRunner() trigger_runner.comms_decoder = AsyncMock(spec=TriggerCommsDecoder) trigger_runner.comms_decoder.asend.return_value = messages.TriggerStateSync( - to_create=[], to_cancel=[] + to_create=[], to_cancel=set() ) trigger_runner.to_create.append(workload) @@ -436,6 +436,23 @@ async def test_trigger_kwargs_serialization_cleanup(self, session): trigger_instance.cancel() await runner.cleanup_finished_triggers() + @pytest.mark.asyncio + @patch("airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True) + async def test_sync_state_to_supervisor(self, supervisor_builder): + trigger_runner = TriggerRunner() + trigger_runner.comms_decoder = AsyncMock(spec=TriggerCommsDecoder) + trigger_runner.comms_decoder.asend.side_effect = [ + messages.TriggerStateSync(to_create=[], to_cancel=set()), + ] + trigger_runner.events.append((1, TriggerEvent(payload={"status": "SUCCESS"}))) + trigger_runner.events.append((2, TriggerEvent(payload={"status": "FAILED"}))) + trigger_runner.events.append((3, TriggerEvent(payload={"status": "SUCCESS", "data": object()}))) + + await trigger_runner.sync_state_to_supervisor(finished_ids=[]) + + assert trigger_runner.comms_decoder.asend.call_count == 1 + assert len(trigger_runner.comms_decoder.asend.call_args_list[0].args[0].events) == 2 + @pytest.mark.asyncio @pytest.mark.usefixtures("testing_dag_bundle")