diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index bdd3b240dafb8..62aa7d37b7f7a 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -832,6 +832,8 @@ def run( log: Logger, ) -> tuple[TaskInstanceState, ToSupervisor | None, BaseException | None]: """Run the task in this process.""" + import signal + from airflow.exceptions import ( AirflowException, AirflowFailException, @@ -849,6 +851,17 @@ def run( assert ti.task is not None assert isinstance(ti.task, BaseOperator) + parent_pid = os.getpid() + + def _on_term(signum, frame): + pid = os.getpid() + if pid != parent_pid: + return + + ti.task.on_kill() + + signal.signal(signal.SIGTERM, _on_term) + msg: ToSupervisor | None = None state: TaskInstanceState error: BaseException | None = None diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 27b655a323db2..d1109b68e4ebb 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -44,6 +44,7 @@ from uuid6 import uuid7 from airflow.executors.workloads import BundleInfo +from airflow.sdk import BaseOperator from airflow.sdk.api import client as sdk_client from airflow.sdk.api.client import ServerResponseError from airflow.sdk.api.datamodels._generated import ( @@ -113,6 +114,7 @@ set_supervisor_comms, supervise, ) +from airflow.sdk.execution_time.task_runner import run from airflow.utils import timezone, timezone as tz if TYPE_CHECKING: @@ -330,6 +332,85 @@ def subprocess_main(): ] ) + def test_on_kill_hook_called_when_sigkilled( + self, + client_with_ti_start, + mocked_parse, + make_ti_context, + mock_supervisor_comms, + create_runtime_ti, + make_ti_context_dict, + capfd, + ): + main_pid = os.getpid() + ti_id = "4d828a62-a417-4936-a7a6-2b3fabacecab" + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == f"/task-instances/{ti_id}/heartbeat": + return httpx.Response( + status_code=409, + json={ + "detail": { + "reason": "not_running", + "message": "TI is no longer in the 'running' state. Task state might be externally set and task should terminate", + "current_state": "failed", + } + }, + ) + if request.url.path == f"/task-instances/{ti_id}/run": + return httpx.Response(200, json=make_ti_context_dict()) + return httpx.Response(status_code=204) + + def subprocess_main(): + # Ensure we follow the "protocol" and get the startup message before we do anything + CommsDecoder()._get_response() + + class CustomOperator(BaseOperator): + def execute(self, context): + for i in range(1000): + print(f"Iteration {i}") + sleep(1) + + def on_kill(self) -> None: + print("On kill hook called!") + + task = CustomOperator(task_id="print-params") + runtime_ti = create_runtime_ti( + dag_id="c", + task=task, + conf={ + "x": 3, + "text": "Hello World!", + "flag": False, + "a_simple_list": ["one", "two", "three", "actually one value is made per line"], + }, + ) + run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) + + assert os.getpid() != main_pid + os.kill(os.getpid(), signal.SIGTERM) + # Ensure that the signal is serviced before we finish and exit the subprocess. + sleep(0.5) + + proc = ActivitySubprocess.start( + dag_rel_path=os.devnull, + bundle_info=FAKE_BUNDLE, + what=TaskInstance( + id=ti_id, + task_id="b", + dag_id="c", + run_id="d", + try_number=1, + dag_version_id=uuid7(), + ), + client=make_client(transport=httpx.MockTransport(handle_request)), + target=subprocess_main, + ) + + proc.wait() + captured = capfd.readouterr() + assert "On kill hook called!" in captured.out + def test_subprocess_sigkilled(self, client_with_ti_start): main_pid = os.getpid()