Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,8 @@ def run(
log: Logger,
) -> tuple[TaskInstanceState, ToSupervisor | None, BaseException | None]:
"""Run the task in this process."""
import signal

from airflow.exceptions import (
AirflowException,
AirflowFailException,
Expand All @@ -849,6 +851,17 @@ def run(
assert ti.task is not None
assert isinstance(ti.task, BaseOperator)

parent_pid = os.getpid()

def _on_term(signum, frame):
    """Signal handler that forwards termination to the task's ``on_kill`` hook.

    The hook is only invoked when the current pid equals ``parent_pid`` (the
    pid recorded by the enclosing function before the handler was installed).
    """
    # Any other pid (presumably a forked child that inherited this handler —
    # TODO confirm) falls through without touching the task.
    if os.getpid() == parent_pid:
        ti.task.on_kill()

signal.signal(signal.SIGTERM, _on_term)

msg: ToSupervisor | None = None
state: TaskInstanceState
error: BaseException | None = None
Expand Down
81 changes: 81 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from uuid6 import uuid7

from airflow.executors.workloads import BundleInfo
from airflow.sdk import BaseOperator
from airflow.sdk.api import client as sdk_client
from airflow.sdk.api.client import ServerResponseError
from airflow.sdk.api.datamodels._generated import (
Expand Down Expand Up @@ -113,6 +114,7 @@
set_supervisor_comms,
supervise,
)
from airflow.sdk.execution_time.task_runner import run
from airflow.utils import timezone, timezone as tz

if TYPE_CHECKING:
Expand Down Expand Up @@ -330,6 +332,85 @@ def subprocess_main():
]
)

def test_on_kill_hook_called_when_sigkilled(
    self,
    client_with_ti_start,
    mocked_parse,
    make_ti_context,
    mock_supervisor_comms,
    create_runtime_ti,
    make_ti_context_dict,
    capfd,
):
    """Check that an operator's ``on_kill`` output shows up on the captured stdout.

    A supervised subprocess runs a long-looping operator through ``run``. The
    mocked API answers every heartbeat with a 409 "not_running" response; once
    the subprocess has been waited on, the captured stdout must contain the
    text printed by ``on_kill``.

    NOTE(review): despite "sigkilled" in the name, the only signal sent by the
    code in this test is ``SIGTERM``.
    """
    ti_id = "4d828a62-a417-4936-a7a6-2b3fabacecab"
    supervisor_pid = os.getpid()
    heartbeat_path = f"/task-instances/{ti_id}/heartbeat"
    run_path = f"/task-instances/{ti_id}/run"

    def handle_request(request: httpx.Request) -> httpx.Response:
        """Mock API: reject heartbeats, serve the TI context on /run, 204 otherwise."""
        path = request.url.path
        if path == heartbeat_path:
            # Presumably this is what makes the supervisor terminate the child —
            # verify against the supervisor's heartbeat handling.
            body = {
                "detail": {
                    "reason": "not_running",
                    "message": "TI is no longer in the 'running' state. Task state might be externally set and task should terminate",
                    "current_state": "failed",
                }
            }
            return httpx.Response(status_code=409, json=body)
        if path == run_path:
            return httpx.Response(200, json=make_ti_context_dict())
        return httpx.Response(status_code=204)

    def subprocess_main():
        """Entry point executed inside the supervised child process."""
        # Ensure we follow the "protocol" and get the startup message before we do anything
        CommsDecoder()._get_response()

        class CustomOperator(BaseOperator):
            def execute(self, context):
                # Long enough that the task is still looping when it gets interrupted.
                for i in range(1000):
                    print(f"Iteration {i}")
                    sleep(1)

            def on_kill(self) -> None:
                print("On kill hook called!")

        dag_conf = {
            "x": 3,
            "text": "Hello World!",
            "flag": False,
            "a_simple_list": ["one", "two", "three", "actually one value is made per line"],
        }
        operator = CustomOperator(task_id="print-params")
        runtime_ti = create_runtime_ti(dag_id="c", task=operator, conf=dag_conf)
        template_context = runtime_ti.get_template_context()
        run(runtime_ti, context=template_context, log=mock.MagicMock())

        child_pid = os.getpid()
        assert child_pid != supervisor_pid
        os.kill(os.getpid(), signal.SIGTERM)
        # Ensure that the signal is serviced before we finish and exit the subprocess.
        sleep(0.5)

    task_instance = TaskInstance(
        id=ti_id,
        task_id="b",
        dag_id="c",
        run_id="d",
        try_number=1,
        dag_version_id=uuid7(),
    )
    api_client = make_client(transport=httpx.MockTransport(handle_request))
    proc = ActivitySubprocess.start(
        dag_rel_path=os.devnull,
        bundle_info=FAKE_BUNDLE,
        what=task_instance,
        client=api_client,
        target=subprocess_main,
    )

    proc.wait()
    stdout_text = capfd.readouterr().out
    assert "On kill hook called!" in stdout_text

def test_subprocess_sigkilled(self, client_with_ti_start):
main_pid = os.getpid()

Expand Down