diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index f4b8292a88eec..410f6dadea6ca 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -178,7 +178,10 @@ def run(ti: RuntimeTaskInstance, log: Logger): trigger_timeout=timeout, ) except AirflowSkipException: - ... + msg = TaskState( + state=TerminalTIState.SKIPPED, + end_date=datetime.now(tz=timezone.utc), + ) except AirflowRescheduleException: ... except (AirflowFailException, AirflowSensorTimeout): diff --git a/task_sdk/tests/dags/basic_skipped.py b/task_sdk/tests/dags/basic_skipped.py new file mode 100644 index 0000000000000..c8fefd1baa8f8 --- /dev/null +++ b/task_sdk/tests/dags/basic_skipped.py @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from airflow.exceptions import AirflowSkipException +from airflow.providers.standard.operators.python import PythonOperator +from airflow.sdk.definitions.dag import dag + + +@dag() +def basic_skipped(): + def skip_task(): + raise AirflowSkipException("This task is being skipped intentionally.") + + PythonOperator( + task_id="skip", + python_callable=skip_task, + ) + + +basic_skipped() diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index da2ea85589f59..e44b4942e13ff 100644 --- a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -46,11 +46,12 @@ GetXCom, PutVariable, SetXCom, + TaskState, VariableResult, XComResult, ) from airflow.sdk.execution_time.supervisor import WatchedSubprocess, supervise -from airflow.utils import timezone as tz +from airflow.utils import timezone, timezone as tz from task_sdk.tests.api.test_client import make_client @@ -848,6 +849,14 @@ def watched_subprocess(self, mocker): {"ok": True}, id="set_xcom_with_map_index", ), + pytest.param( + TaskState(state=TerminalTIState.SKIPPED, end_date=timezone.parse("2024-10-31T12:00:00Z")), + b"", + "", + (), + "", + id="patch_task_instance_to_skipped", + ), ], ) def test_handle_requests( @@ -883,7 +892,8 @@ def test_handle_requests( generator.send(msg) # Verify the correct client method was called - mock_client_method.assert_called_once_with(*method_arg) + if client_attr_path: + mock_client_method.assert_called_once_with(*method_arg) # Verify the response was added to the buffer assert watched_subprocess.stdin.getvalue() == expected_buffer diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index 5cf681cd67051..ac834f74aa6dc 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -94,7 +94,6 @@ def test_run_basic(test_dags_dir: Path, time_machine): with mock.patch( "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True ) as mock_supervisor_comms: - mock_supervisor_comms.send_request = mock.Mock() run(ti, log=mock.MagicMock()) mock_supervisor_comms.send_request.assert_called_once_with( @@ -137,3 +136,26 @@ def test_run_deferred_basic(test_dags_dir: Path, time_machine): # send_request will only be called when the TaskDeferred exception is raised mock_supervisor_comms.send_request.assert_called_once_with(msg=expected_defer_task, log=mock.ANY) + + +def test_run_basic_skipped(test_dags_dir: Path, time_machine): + """Test running a basic task that marks itself skipped.""" + what = StartupDetails( + ti=TaskInstance(id=uuid7(), task_id="skip", dag_id="basic_skipped", run_id="c", try_number=1), + file=str(test_dags_dir / "basic_skipped.py"), + requests_fd=0, + ) + + ti = parse(what) + + instant = timezone.datetime(2024, 12, 3, 10, 0) + time_machine.move_to(instant, tick=False) + + with mock.patch( + "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True + ) as mock_supervisor_comms: + run(ti, log=mock.MagicMock()) + + mock_supervisor_comms.send_request.assert_called_once_with( + msg=TaskState(state=TerminalTIState.SKIPPED, end_date=instant), log=mock.ANY + )