Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 61 additions & 28 deletions airflow-core/src/airflow/jobs/triggerer_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,13 @@
UpdateHITLDetail,
VariableResult,
XComResult,
_new_encoder,
_RequestFrame,
)
from airflow.sdk.execution_time.supervisor import WatchedSubprocess, make_buffered_socket_reader
from airflow.stats import Stats
from airflow.traces.tracer import DebugTrace, Trace, add_debug_span
from airflow.triggers import base as events
from airflow.triggers.base import BaseEventTrigger, BaseTrigger, DiscrimatedTriggerEvent, TriggerEvent
from airflow.utils.helpers import log_filename_template_renderer
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.module_loading import import_string
Expand Down Expand Up @@ -203,7 +204,7 @@ class TriggerStateChanges(BaseModel):

type: Literal["TriggerStateChanges"] = "TriggerStateChanges"
events: Annotated[
list[tuple[int, events.DiscrimatedTriggerEvent]] | None,
list[tuple[int, DiscrimatedTriggerEvent]] | None,
# We have to specify a default here, as otherwise Pydantic struggles to deal with the discriminated
# union :shrug:
Field(default=None),
Expand Down Expand Up @@ -355,7 +356,7 @@ class TriggerRunnerSupervisor(WatchedSubprocess):
creating_triggers: deque[workloads.RunTrigger] = attrs.field(factory=deque, init=False)

# Outbound queue of events
events: deque[tuple[int, events.TriggerEvent]] = attrs.field(factory=deque, init=False)
events: deque[tuple[int, TriggerEvent]] = attrs.field(factory=deque, init=False)

# Outbound queue of failed triggers
failed_triggers: deque[tuple[int, list[str] | None]] = attrs.field(factory=deque, init=False)
Expand Down Expand Up @@ -821,7 +822,7 @@ class TriggerRunner:
to_cancel: deque[int]

# Outbound queue of events
events: deque[tuple[int, events.TriggerEvent]]
events: deque[tuple[int, TriggerEvent]]

# Outbound queue of failed triggers
failed_triggers: deque[tuple[int, BaseException | None]]
Expand Down Expand Up @@ -971,7 +972,7 @@ async def create_triggers(self):
"task": asyncio.create_task(
self.run_trigger(trigger_id, trigger_instance, workload.timeout_after), name=trigger_name
),
"is_watcher": isinstance(trigger_instance, events.BaseEventTrigger),
"is_watcher": isinstance(trigger_instance, BaseEventTrigger),
"name": trigger_name,
"events": 0,
}
Expand Down Expand Up @@ -1017,7 +1018,7 @@ async def cleanup_finished_triggers(self) -> list[int]:
saved_exc = e
else:
# See if they foolishly returned a TriggerEvent
if isinstance(result, events.TriggerEvent):
if isinstance(result, TriggerEvent):
self.log.error(
"Trigger returned a TriggerEvent rather than yielding it",
trigger=details["name"],
Expand All @@ -1037,46 +1038,78 @@ async def cleanup_finished_triggers(self) -> list[int]:
await asyncio.sleep(0)
return finished_ids

def process_trigger_events(self, finished_ids: list[int]) -> messages.TriggerStateChanges:
    """
    Drain the outbound queues into a single ``TriggerStateChanges`` message.

    :param finished_ids: IDs of the triggers that finished since the last sync.
    :return: The message to send to the supervisor. Collections that ended up
        empty are sent as ``None`` rather than as empty lists.
    """
    # Copy out of our deques in threadsafe manner to sync state with parent
    events_to_send: list[tuple[int, DiscrimatedTriggerEvent]] = []
    failures_to_send: list[tuple[int, list[str] | None]] = []

    while self.events:
        trigger_id, trigger_event = self.events.popleft()
        events_to_send.append((trigger_id, trigger_event))

    while self.failed_triggers:
        trigger_id, exc = self.failed_triggers.popleft()
        # Exceptions are forwarded as formatted traceback lines, not as exception objects.
        tb = format_exception(type(exc), exc, exc.__traceback__) if exc else None
        failures_to_send.append((trigger_id, tb))

    return messages.TriggerStateChanges(
        events=events_to_send or None,
        finished=finished_ids or None,
        failures=failures_to_send or None,
    )


def sanitize_trigger_events(self, msg: messages.TriggerStateChanges) -> messages.TriggerStateChanges:
    """
    Return a copy of ``msg`` that only carries the events the request encoder accepts.

    Every event is test-encoded on its own. An event that fails to encode is
    dropped from the message, and its trigger is appended to
    ``self.failed_triggers`` together with the encoding error.

    :param msg: The message whose events should be checked.
    :return: A new message with the encodable events (``None`` if none are left)
        and the original ``finished`` and ``failures`` values.
    """
    req_encoder = _new_encoder()
    events_to_send: list[tuple[int, DiscrimatedTriggerEvent]] = []

    for trigger_id, trigger_event in msg.events or ():
        try:
            req_encoder.encode(trigger_event)
        except Exception as e:
            logger.error(
                "Trigger %s returned non-serializable result %r. Cancelling trigger.",
                trigger_id,
                trigger_event,
            )
            # Not part of the message returned below; it stays queued on the runner.
            self.failed_triggers.append((trigger_id, e))
        else:
            events_to_send.append((trigger_id, trigger_event))

    return messages.TriggerStateChanges(
        events=events_to_send or None,
        finished=msg.finished,
        failures=msg.failures,
    )


async def sync_state_to_supervisor(self, finished_ids: list[int]) -> None:
    """
    Send the pending state changes to the supervisor and apply its response.

    :param finished_ids: IDs of the triggers that finished since the last sync.
    """
    msg = self.process_trigger_events(finished_ids=finished_ids)

    # Tell the monitor that we've finished triggers so it can update things
    try:
        resp = await self.asend(msg)
    except NotImplementedError:
        # A non-serializable trigger event was detected, remove it and fail associated trigger
        resp = await self.asend(self.sanitize_trigger_events(msg))

    if resp:
        self.to_create.extend(resp.to_create)
        self.to_cancel.extend(resp.to_cancel)


async def asend(self, msg: messages.TriggerStateChanges) -> messages.TriggerStateSync | None:
    """
    Send ``msg`` through ``self.comms_decoder`` and return the validated response.

    :param msg: The state changes to send.
    :return: The ``TriggerStateSync`` response, or ``None`` when the read ended
        early and the current task was cancelled.
    :raises RuntimeError: If the response is not a ``TriggerStateSync``.
    """
    try:
        response = await self.comms_decoder.asend(msg)
    except asyncio.IncompleteReadError:
        if task := asyncio.current_task():
            task.cancel("EOF - shutting down")
            return None
        raise

    if not isinstance(response, messages.TriggerStateSync):
        # Report the type of what we actually received, not of the request we sent.
        raise RuntimeError(f"Expected to get a TriggerStateSync message, instead we got {type(response)}")

    return response

async def block_watchdog(self):
"""
Watchdog loop that detects blocking (badly-written) triggers.
Expand Down
28 changes: 27 additions & 1 deletion airflow-core/tests/unit/jobs/test_triggerer_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ async def test_invalid_trigger(self, supervisor_builder):
trigger_runner = TriggerRunner()
trigger_runner.comms_decoder = AsyncMock(spec=TriggerCommsDecoder)
trigger_runner.comms_decoder.asend.return_value = messages.TriggerStateSync(
to_create=[], to_cancel=[]
to_create=[], to_cancel=set()
)

trigger_runner.to_create.append(workload)
Expand Down Expand Up @@ -438,6 +438,32 @@ async def test_trigger_kwargs_serialization_cleanup(self, session):
trigger_instance.cancel()
await runner.cleanup_finished_triggers()

@pytest.mark.asyncio
@patch("airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True)
async def test_sync_state_to_supervisor(self, supervisor_builder):
    """A non-serializable event is dropped on retry and its trigger is queued as failed."""
    trigger_runner = TriggerRunner()
    trigger_runner.comms_decoder = AsyncMock(spec=TriggerCommsDecoder)
    trigger_runner.events.append((1, TriggerEvent(payload={"status": "SUCCESS"})))
    trigger_runner.events.append((2, TriggerEvent(payload={"status": "FAILED"})))
    # ``object()`` in the payload is what makes the third event non-serializable.
    trigger_runner.events.append((3, TriggerEvent(payload={"status": "SUCCESS", "data": object()})))

    async def asend_side_effect(msg):
        # Only the first, unsanitized message carries all three events.
        if msg.events and len(msg.events) == 3:
            raise NotImplementedError("Simulate non-serializable event")
        return messages.TriggerStateSync(to_create=[], to_cancel=set())

    trigger_runner.comms_decoder.asend.side_effect = asend_side_effect

    await trigger_runner.sync_state_to_supervisor(finished_ids=[])

    assert trigger_runner.comms_decoder.asend.call_count == 2

    first_call = trigger_runner.comms_decoder.asend.call_args_list[0].args[0]
    second_call = trigger_runner.comms_decoder.asend.call_args_list[1].args[0]

    assert len(first_call.events) == 3
    assert len(second_call.events) == 2
    # The retry must keep exactly the serializable events, in order.
    assert [trigger_id for trigger_id, _ in second_call.events] == [1, 2]
    # The offending trigger must be queued as failed so a later sync reports it.
    assert [trigger_id for trigger_id, _ in trigger_runner.failed_triggers] == [3]


@pytest.mark.asyncio
async def test_trigger_create_race_condition_38599(session, supervisor_builder, testing_dag_bundle):
Expand Down