157 changes: 83 additions & 74 deletions airflow-core/src/airflow/jobs/triggerer_job_runner.py
@@ -184,14 +184,12 @@ class StartTriggerer(BaseModel):
requests_fd: int
type: Literal["StartTriggerer"] = "StartTriggerer"

class CancelTriggers(BaseModel):
"""Request to cancel running triggers."""

ids: Iterable[int]
type: Literal["CancelTriggersMessage"] = "CancelTriggersMessage"

class TriggerStateChanges(BaseModel):
"""Report state change about triggers back to the TriggerRunnerSupervisor."""
"""
Report state change about triggers back to the TriggerRunnerSupervisor.

The supervisor will respond with a TriggerStateSync message.
"""

type: Literal["TriggerStateChanges"] = "TriggerStateChanges"
events: Annotated[
@@ -204,12 +202,17 @@ class TriggerStateChanges(BaseModel):
failures: list[tuple[int, list[str] | None]] | None = None
finished: list[int] | None = None

class TriggerStateSync(BaseModel):
type: Literal["TriggerStateSync"] = "TriggerStateSync"

to_create: list[workloads.RunTrigger]
to_cancel: set[int]


ToTriggerRunner = Annotated[
Union[
workloads.RunTrigger,
messages.CancelTriggers,
messages.StartTriggerer,
messages.TriggerStateSync,
ConnectionResult,
VariableResult,
XComResult,
@@ -236,9 +239,9 @@ class TriggerStateChanges(BaseModel):
class TriggerLoggingFactory:
log_path: str

ti: RuntimeTI
ti: RuntimeTI = attrs.field(repr=False)

bound_logger: WrappedLogger = attrs.field(init=False)
bound_logger: WrappedLogger = attrs.field(init=False, repr=False)

def __call__(self, processors: Iterable[structlog.typing.Processor]) -> WrappedLogger:
if hasattr(self, "bound_logger"):
@@ -302,6 +305,10 @@ class TriggerRunnerSupervisor(WatchedSubprocess):
# FinishedTriggers message
cancelling_triggers: set[int] = attrs.field(factory=set, init=False)

# A list of RunTrigger workloads to send to the async process when it next checks in. We can't send
# them directly, as all comms have to be initiated by the subprocess.
creating_triggers: deque[workloads.RunTrigger] = attrs.field(factory=deque, init=False)

# Outbound queue of events
events: deque[tuple[int, events.TriggerEvent]] = attrs.field(factory=deque, init=False)
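A minimal, self-contained sketch of the hand-off pattern behind `creating_triggers` (the `enqueue`/`on_child_checkin` names are hypothetical illustrations, not the supervisor's real API): the parent only ever enqueues, and the queue is drained while answering a request that the subprocess initiated.

from collections import deque

pending: deque[str] = deque()

def enqueue(workload: str) -> None:
    # The parent may call this at any time; nothing is pushed to the child.
    pending.append(workload)

def on_child_checkin() -> list[str]:
    # Runs while handling a child-initiated request; everything queued so
    # far rides back in the response.
    batch: list[str] = []
    while pending:
        batch.append(pending.popleft())
    return batch

enqueue("trigger-42")
assert on_child_checkin() == ["trigger-42"]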

@@ -323,7 +330,7 @@ def start(  # type: ignore[override]
proc = super().start(id=job.id, job=job, target=cls.run_in_process, logger=logger, **kwargs)

msg = messages.StartTriggerer(requests_fd=proc._requests_fd)
proc._send(msg)
proc.stdin.write(msg.model_dump_json().encode() + b"\n")
return proc

@functools.cached_property
Expand All @@ -341,7 +348,6 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger) -
resp = None

if isinstance(msg, messages.TriggerStateChanges):
log.debug("State change from async process", state=msg)
if msg.events:
self.events.extend(msg.events)
if msg.failures:
@@ -353,6 +359,19 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger) -
# only need to remove the last reference to it to close the open FH
if factory := self.logger_cache.pop(id, None):
factory.upload_to_remote()

response = messages.TriggerStateSync(
to_create=[],
to_cancel=self.cancelling_triggers,
)

# Pull out of this deque in a thread-safe manner
while self.creating_triggers:
workload = self.creating_triggers.popleft()
response.to_create.append(workload)
self.running_triggers.update(m.id for m in response.to_create)
resp = response.model_dump_json().encode()
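# Illustration (an assumption about the wire format, not new behavior): the
# reply is a single JSON document along the lines of
#     {"type": "TriggerStateSync", "to_create": [], "to_cancel": [123]}
# which the subprocess decodes with its ToTriggerRunner adapter.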

elif isinstance(msg, GetConnection):
conn = self.client.connections.get(msg.conn_id)
if isinstance(conn, ConnectionResponse):
@@ -391,10 +410,11 @@ def run(self) -> None:
log.error("Trigger runner process has died! Exiting.")
break
with Trace.start_span(span_name="triggerer_job_loop", component="TriggererJobRunner"):
self.load_triggers()

# Wait for up to 1 second for activity
self._service_subprocess(1)

self.load_triggers()
self.handle_events()
self.handle_failed_triggers()
self.clean_unused()
@@ -462,9 +482,6 @@ def emit_metrics(self):
}
)

def _send(self, msg: BaseModel):
self.stdin.write(msg.model_dump_json().encode("utf-8") + b"\n")

def update_triggers(self, requested_trigger_ids: set[int]):
"""
Request that we update what triggers we're running.
@@ -533,14 +550,11 @@ def update_triggers(self, requested_trigger_ids: set[int]):

to_create.append(workload)

for workload in to_create:
self._send(workload)
self.running_triggers.update(m.id for m in to_create)
self.creating_triggers.extend(to_create)

if cancel_trigger_ids:
# Enqueue orphaned triggers for cancellation
self.cancelling_triggers.update(cancel_trigger_ids)
self._send(messages.CancelTriggers(ids=cancel_trigger_ids))

def _register_pipe_readers(self, stdout: socket, stderr: socket, requests: socket, logs: socket):
super()._register_pipe_readers(stdout, stderr, requests, logs)
@@ -655,6 +669,9 @@ class TriggerRunner:
log: FilteringBoundLogger = structlog.get_logger()

requests_sock: asyncio.StreamWriter
response_sock: asyncio.StreamReader

decoder: TypeAdapter[ToTriggerRunner]

def __init__(self):
super().__init__()
@@ -665,28 +682,10 @@ def __init__(self):
self.events = deque()
self.failed_triggers = deque()
self.job_id = None

def init_comms(self):
"""Init supervisor comms."""
from airflow.sdk.execution_time import task_runner

comms_decoder = task_runner.CommsDecoder[ToTriggerRunner, ToTriggerSupervisor](
input=sys.stdin,
decoder=TypeAdapter[ToTriggerRunner](ToTriggerRunner),
)

msg = comms_decoder.get_message()
if not isinstance(msg, messages.StartTriggerer):
raise RuntimeError(f"Required first message to be a messages.StartTriggerer, it was {msg}")
comms_decoder.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0)

task_runner.SUPERVISOR_COMMS = comms_decoder
self.decoder = TypeAdapter(ToTriggerRunner)

def run(self):
"""Sync entrypoint - just run a run in an async loop."""
# Make sure comms are initialized before allowing any Triggers to run
self.init_comms()

asyncio.run(self.arun())

async def arun(self):
@@ -695,25 +694,25 @@ async def arun(self):

Actual triggers run in their own separate coroutines.
"""
# Make sure comms are initialized before allowing any Triggers to run
await self.init_comms()

watchdog = asyncio.create_task(self.block_watchdog())
ready_event = asyncio.Event()
read_workloads = asyncio.create_task(self.read_workloads(ready_event))

await ready_event.wait()
last_status = time.monotonic()
try:
while not self.stop:
# Raise exceptions from the tasks
if read_workloads.done():
read_workloads.result()
if watchdog.done():
watchdog.result()

# Run core logic
await self.create_triggers()
await self.cancel_triggers()

finished_ids = await self.cleanup_finished_triggers()
# This also loads the triggers we need to create or cancel
await self.sync_state_to_supervisor(finished_ids)
await self.create_triggers()
await self.cancel_triggers()
# Sleep for a bit
await asyncio.sleep(1)
# Every minute, log status
@@ -723,57 +722,54 @@
last_status = now

except Exception:
log.exception("Trigger runner failed")
try:
await log.aexception("Trigger runner failed")
except BrokenPipeError:
pass
self.stop = True
raise
read_workloads.cancel()
# Wait for supporting tasks to complete
await watchdog
await read_workloads

async def read_workloads(self, ready_event: asyncio.Event):
async def init_comms(self):
"""
Read the triggers to run on stdin.
Set up the communications pipe between this process and the supervisor.

This reads and decodes the JSON lines sent by the TriggerRunnerSupervisor to us on our stdin.
This also sets up the SUPERVISOR_COMMS so that TaskSDK code can work as expected too (but that will
need to be wrapped in a ``sync_to_async()`` call).
"""
from airflow.sdk.execution_time import task_runner

loop = asyncio.get_event_loop()

task = asyncio.current_task(loop=loop)
if TYPE_CHECKING:
assert task
# Set the event in a done callback so that if this fn fails, arun wakes up and we catch the exception
task.add_done_callback(lambda _: ready_event.set())
comms_decoder = task_runner.CommsDecoder[ToTriggerRunner, ToTriggerSupervisor](
input=sys.stdin,
decoder=self.decoder,
)

task_runner.SUPERVISOR_COMMS = comms_decoder

async def connect_stdin() -> asyncio.StreamReader:
reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(reader)
await loop.connect_read_pipe(lambda: protocol, sys.stdin)
return reader

stdin = await connect_stdin()
self.response_sock = await connect_stdin()

decoder = TypeAdapter[ToTriggerRunner](ToTriggerRunner)
line = await self.response_sock.readline()

msg = self.decoder.validate_json(line)
if not isinstance(msg, messages.StartTriggerer):
raise RuntimeError(f"Required first message to be a messages.StartTriggerer, it was {msg}")

comms_decoder.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0)
writer_transport, writer_protocol = await loop.connect_write_pipe(
lambda: asyncio.streams.FlowControlMixin(loop=loop),
task_runner.SUPERVISOR_COMMS.request_socket,
comms_decoder.request_socket,
)
self.requests_sock = asyncio.streams.StreamWriter(writer_transport, writer_protocol, None, loop)

# Tell `arun` it can start the main loop now
ready_event.set()

async for line in stdin:
msg = decoder.validate_json(line)

if isinstance(msg, workloads.RunTrigger):
self.to_create.append(msg)
elif isinstance(msg, messages.CancelTriggers):
self.to_cancel.extend(msg.ids)

async def create_triggers(self):
"""Drain the to_create queue and create all new triggers that have been requested in the DB."""
while self.to_create:
Expand Down Expand Up @@ -882,6 +878,8 @@ async def cleanup_finished_triggers(self) -> list[int]:
return finished_ids

async def sync_state_to_supervisor(self, finished_ids: list[int]):
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

# Copy out of our deques in threadsafe manner to sync state with parent
events_to_send = []
while self.events:
Expand All @@ -898,7 +896,6 @@ async def sync_state_to_supervisor(self, finished_ids: list[int]):
events=events_to_send, finished=finished_ids, failures=failures_to_send
)

# Only send a message if there is anything to say
if not events_to_send:
msg.events = None

@@ -908,9 +905,21 @@ async def sync_state_to_supervisor(self, finished_ids: list[int]):
if not finished_ids:
msg.finished = None

if msg.events or msg.finished or msg.failures:
# Block triggers from making any requests for the duration of this
async with SUPERVISOR_COMMS.lock:
# Tell the monitor that we've finished triggers so it can update things
self.requests_sock.write(msg.model_dump_json(exclude_none=True).encode() + b"\n")
line = await self.response_sock.readline()

if line == b"": # EoF received!
if task := asyncio.current_task():
task.cancel("EOF - shutting down")

resp = self.decoder.validate_json(line)
if not isinstance(resp, messages.TriggerStateSync):
raise RuntimeError(f"Expected to get a TriggerStateSync message, instead we got f{type(msg)}")
self.to_create.extend(resp.to_create)
self.to_cancel.extend(resp.to_cancel)
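
The exchange above is one newline-delimited JSON request/response over a pipe. A self-contained sketch of that round trip, assuming plain dict messages rather than the real pydantic models:

import asyncio
import json

async def round_trip(
    writer: asyncio.StreamWriter, reader: asyncio.StreamReader, msg: dict
) -> dict:
    # Write exactly one JSON line, then block until the peer answers with
    # one JSON line of its own.
    writer.write(json.dumps(msg).encode() + b"\n")
    await writer.drain()
    line = await reader.readline()
    if line == b"":
        # EOF: the other side of the pipe has gone away.
        raise ConnectionError("EOF while waiting for response")
    return json.loads(line)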

async def block_watchdog(self):
"""