From 982dd177f438a3fa4764c79408a9e7d35913563c Mon Sep 17 00:00:00 2001 From: Amogh Date: Mon, 9 Dec 2024 16:33:07 +0530 Subject: [PATCH 1/5] AIP-72: Handling skipped tasks in task_sdk --- .../airflow/sdk/execution_time/task_runner.py | 5 ++- task_sdk/tests/dags/super_basic_skipped.py | 36 +++++++++++++++++++ .../tests/execution_time/test_supervisor.py | 14 ++++++-- .../tests/execution_time/test_task_runner.py | 24 +++++++++++++ 4 files changed, 76 insertions(+), 3 deletions(-) create mode 100644 task_sdk/tests/dags/super_basic_skipped.py diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index f4b8292a88eec..410f6dadea6ca 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -178,7 +178,10 @@ def run(ti: RuntimeTaskInstance, log: Logger): trigger_timeout=timeout, ) except AirflowSkipException: - ... + msg = TaskState( + state=TerminalTIState.SKIPPED, + end_date=datetime.now(tz=timezone.utc), + ) except AirflowRescheduleException: ... except (AirflowFailException, AirflowSensorTimeout): diff --git a/task_sdk/tests/dags/super_basic_skipped.py b/task_sdk/tests/dags/super_basic_skipped.py new file mode 100644 index 0000000000000..86d93672acade --- /dev/null +++ b/task_sdk/tests/dags/super_basic_skipped.py @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from airflow.exceptions import AirflowSkipException +from airflow.providers.standard.operators.python import PythonOperator +from airflow.sdk.definitions.dag import dag + + +@dag() +def super_basic_skipped(): + def skip_task(): + raise AirflowSkipException("This task is being skipped intentionally.") + + PythonOperator( + task_id="skip", + python_callable=skip_task, + ) + + +super_basic_skipped() diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index da2ea85589f59..606146f4d7477 100644 --- a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -46,11 +46,12 @@ GetXCom, PutVariable, SetXCom, + TaskState, VariableResult, XComResult, ) from airflow.sdk.execution_time.supervisor import WatchedSubprocess, supervise -from airflow.utils import timezone as tz +from airflow.utils import timezone, timezone as tz from task_sdk.tests.api.test_client import make_client @@ -848,6 +849,14 @@ def watched_subprocess(self, mocker): {"ok": True}, id="set_xcom_with_map_index", ), + pytest.param( + TaskState(state="skipped", end_date=timezone.parse("2024-10-31T12:00:00Z")), + b"", + "", + (), + "", + id="patch_task_instance_to_skipped", + ), ], ) def test_handle_requests( @@ -883,7 +892,8 @@ def test_handle_requests( generator.send(msg) # Verify the correct client method was called - mock_client_method.assert_called_once_with(*method_arg) + if client_attr_path: + mock_client_method.assert_called_once_with(*method_arg) # Verify the response was added to the buffer assert watched_subprocess.stdin.getvalue() == expected_buffer diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index 5cf681cd67051..d6bbbed9a948b 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -137,3 +137,27 @@ def test_run_deferred_basic(test_dags_dir: Path, time_machine): # send_request will only be called when the TaskDeferred exception is raised mock_supervisor_comms.send_request.assert_called_once_with(msg=expected_defer_task, log=mock.ANY) + + +def test_run_basic_skipped(test_dags_dir: Path, time_machine): + """Test running a basic task that marks itself skipped.""" + what = StartupDetails( + ti=TaskInstance(id=uuid7(), task_id="skip", dag_id="super_basic_skipped", run_id="c", try_number=1), + file=str(test_dags_dir / "super_basic_skipped.py"), + requests_fd=0, + ) + + ti = parse(what) + + instant = timezone.datetime(2024, 12, 3, 10, 0) + time_machine.move_to(instant, tick=False) + + with mock.patch( + "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True + ) as mock_supervisor_comms: + mock_supervisor_comms.send_request = mock.Mock() + run(ti, log=mock.MagicMock()) + + mock_supervisor_comms.send_request.assert_called_once_with( + msg=TaskState(state=TerminalTIState.SKIPPED, end_date=instant), log=mock.ANY + ) From b0b006810206a00c8c1d53f6a0b1dcc97afe2cf0 Mon Sep 17 00:00:00 2001 From: Amogh Date: Mon, 9 Dec 2024 16:56:33 +0530 Subject: [PATCH 2/5] dag aint super_basic --- .../tests/dags/{super_basic_skipped.py => basic_skipped.py} | 4 ++-- task_sdk/tests/execution_time/test_task_runner.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) rename task_sdk/tests/dags/{super_basic_skipped.py => basic_skipped.py} (96%) diff --git a/task_sdk/tests/dags/super_basic_skipped.py b/task_sdk/tests/dags/basic_skipped.py similarity index 96% rename from task_sdk/tests/dags/super_basic_skipped.py rename to task_sdk/tests/dags/basic_skipped.py index 86d93672acade..c8fefd1baa8f8 100644 --- a/task_sdk/tests/dags/super_basic_skipped.py +++ b/task_sdk/tests/dags/basic_skipped.py @@ -23,7 +23,7 @@ @dag() -def super_basic_skipped(): +def basic_skipped(): def skip_task(): raise AirflowSkipException("This task is being skipped intentionally.") @@ -33,4 +33,4 @@ def skip_task(): ) -super_basic_skipped() +basic_skipped() diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index d6bbbed9a948b..a48f75503adb2 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -142,8 +142,8 @@ def test_run_deferred_basic(test_dags_dir: Path, time_machine): def test_run_basic_skipped(test_dags_dir: Path, time_machine): """Test running a basic task that marks itself skipped.""" what = StartupDetails( - ti=TaskInstance(id=uuid7(), task_id="skip", dag_id="super_basic_skipped", run_id="c", try_number=1), - file=str(test_dags_dir / "super_basic_skipped.py"), + ti=TaskInstance(id=uuid7(), task_id="skip", dag_id="basic_skipped", run_id="c", try_number=1), + file=str(test_dags_dir / "basic_skipped.py"), requests_fd=0, ) From c7725c732f120b366565b81986d5ec285770c6df Mon Sep 17 00:00:00 2001 From: Amogh Date: Mon, 9 Dec 2024 17:23:36 +0530 Subject: [PATCH 3/5] nits --- .../tests/execution_time/test_supervisor.py | 194 +++++++++--------- .../tests/execution_time/test_task_runner.py | 2 - 2 files changed, 93 insertions(+), 103 deletions(-) diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index 606146f4d7477..661893ed3c9f7 100644 --- a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -39,16 +39,8 @@ from airflow.sdk.api.client import ServerResponseError from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState from airflow.sdk.execution_time.comms import ( - ConnectionResult, DeferTask, - GetConnection, - GetVariable, - GetXCom, - PutVariable, - SetXCom, TaskState, - VariableResult, - XComResult, ) from airflow.sdk.execution_time.supervisor import WatchedSubprocess, supervise from airflow.utils import timezone, timezone as tz @@ -756,99 +748,99 @@ def watched_subprocess(self, mocker): @pytest.mark.parametrize( ["message", "expected_buffer", "client_attr_path", "method_arg", "mock_response"], [ - pytest.param( - GetConnection(conn_id="test_conn"), - b'{"conn_id":"test_conn","conn_type":"mysql"}\n', - "connections.get", - ("test_conn",), - ConnectionResult(conn_id="test_conn", conn_type="mysql"), - id="get_connection", - ), - pytest.param( - GetVariable(key="test_key"), - b'{"key":"test_key","value":"test_value"}\n', - "variables.get", - ("test_key",), - VariableResult(key="test_key", value="test_value"), - id="get_variable", - ), - pytest.param( - PutVariable(key="test_key", value="test_value", description="test_description"), - b"", - "variables.set", - ("test_key", "test_value", "test_description"), - {"ok": True}, - id="set_variable", - ), - pytest.param( - DeferTask(next_method="execute_callback", classpath="my-classpath"), - b"", - "task_instances.defer", - (TI_ID, DeferTask(next_method="execute_callback", classpath="my-classpath")), - "", - id="patch_task_instance_to_deferred", - ), - pytest.param( - GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"), - b'{"key":"test_key","value":"test_value"}\n', - "xcoms.get", - ("test_dag", "test_run", "test_task", "test_key", -1), - XComResult(key="test_key", value="test_value"), - id="get_xcom", - ), - pytest.param( - GetXCom( - dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key", map_index=2 - ), - b'{"key":"test_key","value":"test_value"}\n', - "xcoms.get", - ("test_dag", "test_run", "test_task", "test_key", 2), - XComResult(key="test_key", value="test_value"), - id="get_xcom_map_index", - ), - pytest.param( - SetXCom( - dag_id="test_dag", - run_id="test_run", - task_id="test_task", - key="test_key", - value='{"key": "test_key", "value": {"key2": "value2"}}', - ), - b"", - "xcoms.set", - ( - "test_dag", - "test_run", - "test_task", - "test_key", - '{"key": "test_key", "value": {"key2": "value2"}}', - None, - ), - {"ok": True}, - id="set_xcom", - ), - pytest.param( - SetXCom( - dag_id="test_dag", - run_id="test_run", - task_id="test_task", - key="test_key", - value='{"key": "test_key", "value": {"key2": "value2"}}', - map_index=2, - ), - b"", - "xcoms.set", - ( - "test_dag", - "test_run", - "test_task", - "test_key", - '{"key": "test_key", "value": {"key2": "value2"}}', - 2, - ), - {"ok": True}, - id="set_xcom_with_map_index", - ), + # pytest.param( + # GetConnection(conn_id="test_conn"), + # b'{"conn_id":"test_conn","conn_type":"mysql"}\n', + # "connections.get", + # ("test_conn",), + # ConnectionResult(conn_id="test_conn", conn_type="mysql"), + # id="get_connection", + # ), + # pytest.param( + # GetVariable(key="test_key"), + # b'{"key":"test_key","value":"test_value"}\n', + # "variables.get", + # ("test_key",), + # VariableResult(key="test_key", value="test_value"), + # id="get_variable", + # ), + # pytest.param( + # PutVariable(key="test_key", value="test_value", description="test_description"), + # b"", + # "variables.set", + # ("test_key", "test_value", "test_description"), + # {"ok": True}, + # id="set_variable", + # ), + # pytest.param( + # DeferTask(next_method="execute_callback", classpath="my-classpath"), + # b"", + # "task_instances.defer", + # (TI_ID, DeferTask(next_method="execute_callback", classpath="my-classpath")), + # "", + # id="patch_task_instance_to_deferred", + # ), + # pytest.param( + # GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"), + # b'{"key":"test_key","value":"test_value"}\n', + # "xcoms.get", + # ("test_dag", "test_run", "test_task", "test_key", -1), + # XComResult(key="test_key", value="test_value"), + # id="get_xcom", + # ), + # pytest.param( + # GetXCom( + # dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key", map_index=2 + # ), + # b'{"key":"test_key","value":"test_value"}\n', + # "xcoms.get", + # ("test_dag", "test_run", "test_task", "test_key", 2), + # XComResult(key="test_key", value="test_value"), + # id="get_xcom_map_index", + # ), + # pytest.param( + # SetXCom( + # dag_id="test_dag", + # run_id="test_run", + # task_id="test_task", + # key="test_key", + # value='{"key": "test_key", "value": {"key2": "value2"}}', + # ), + # b"", + # "xcoms.set", + # ( + # "test_dag", + # "test_run", + # "test_task", + # "test_key", + # '{"key": "test_key", "value": {"key2": "value2"}}', + # None, + # ), + # {"ok": True}, + # id="set_xcom", + # ), + # pytest.param( + # SetXCom( + # dag_id="test_dag", + # run_id="test_run", + # task_id="test_task", + # key="test_key", + # value='{"key": "test_key", "value": {"key2": "value2"}}', + # map_index=2, + # ), + # b"", + # "xcoms.set", + # ( + # "test_dag", + # "test_run", + # "test_task", + # "test_key", + # '{"key": "test_key", "value": {"key2": "value2"}}', + # 2, + # ), + # {"ok": True}, + # id="set_xcom_with_map_index", + # ), pytest.param( TaskState(state="skipped", end_date=timezone.parse("2024-10-31T12:00:00Z")), b"", diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index a48f75503adb2..ac834f74aa6dc 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -94,7 +94,6 @@ def test_run_basic(test_dags_dir: Path, time_machine): with mock.patch( "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True ) as mock_supervisor_comms: - mock_supervisor_comms.send_request = mock.Mock() run(ti, log=mock.MagicMock()) mock_supervisor_comms.send_request.assert_called_once_with( @@ -155,7 +154,6 @@ def test_run_basic_skipped(test_dags_dir: Path, time_machine): with mock.patch( "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True ) as mock_supervisor_comms: - mock_supervisor_comms.send_request = mock.Mock() run(ti, log=mock.MagicMock()) mock_supervisor_comms.send_request.assert_called_once_with( From ee89c668d11f0d79dfa8322f4b736ea9eda25d2d Mon Sep 17 00:00:00 2001 From: Amogh Date: Mon, 9 Dec 2024 18:08:21 +0530 Subject: [PATCH 4/5] uncommenting tests --- .../tests/execution_time/test_supervisor.py | 194 +++++++++--------- 1 file changed, 101 insertions(+), 93 deletions(-) diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index 661893ed3c9f7..606146f4d7477 100644 --- a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -39,8 +39,16 @@ from airflow.sdk.api.client import ServerResponseError from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState from airflow.sdk.execution_time.comms import ( + ConnectionResult, DeferTask, + GetConnection, + GetVariable, + GetXCom, + PutVariable, + SetXCom, TaskState, + VariableResult, + XComResult, ) from airflow.sdk.execution_time.supervisor import WatchedSubprocess, supervise from airflow.utils import timezone, timezone as tz @@ -748,99 +756,99 @@ def watched_subprocess(self, mocker): @pytest.mark.parametrize( ["message", "expected_buffer", "client_attr_path", "method_arg", "mock_response"], [ - # pytest.param( - # GetConnection(conn_id="test_conn"), - # b'{"conn_id":"test_conn","conn_type":"mysql"}\n', - # "connections.get", - # ("test_conn",), - # ConnectionResult(conn_id="test_conn", conn_type="mysql"), - # id="get_connection", - # ), - # pytest.param( - # GetVariable(key="test_key"), - # b'{"key":"test_key","value":"test_value"}\n', - # "variables.get", - # ("test_key",), - # VariableResult(key="test_key", value="test_value"), - # id="get_variable", - # ), - # pytest.param( - # PutVariable(key="test_key", value="test_value", description="test_description"), - # b"", - # "variables.set", - # ("test_key", "test_value", "test_description"), - # {"ok": True}, - # id="set_variable", - # ), - # pytest.param( - # DeferTask(next_method="execute_callback", classpath="my-classpath"), - # b"", - # "task_instances.defer", - # (TI_ID, DeferTask(next_method="execute_callback", classpath="my-classpath")), - # "", - # id="patch_task_instance_to_deferred", - # ), - # pytest.param( - # GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"), - # b'{"key":"test_key","value":"test_value"}\n', - # "xcoms.get", - # ("test_dag", "test_run", "test_task", "test_key", -1), - # XComResult(key="test_key", value="test_value"), - # id="get_xcom", - # ), - # pytest.param( - # GetXCom( - # dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key", map_index=2 - # ), - # b'{"key":"test_key","value":"test_value"}\n', - # "xcoms.get", - # ("test_dag", "test_run", "test_task", "test_key", 2), - # XComResult(key="test_key", value="test_value"), - # id="get_xcom_map_index", - # ), - # pytest.param( - # SetXCom( - # dag_id="test_dag", - # run_id="test_run", - # task_id="test_task", - # key="test_key", - # value='{"key": "test_key", "value": {"key2": "value2"}}', - # ), - # b"", - # "xcoms.set", - # ( - # "test_dag", - # "test_run", - # "test_task", - # "test_key", - # '{"key": "test_key", "value": {"key2": "value2"}}', - # None, - # ), - # {"ok": True}, - # id="set_xcom", - # ), - # pytest.param( - # SetXCom( - # dag_id="test_dag", - # run_id="test_run", - # task_id="test_task", - # key="test_key", - # value='{"key": "test_key", "value": {"key2": "value2"}}', - # map_index=2, - # ), - # b"", - # "xcoms.set", - # ( - # "test_dag", - # "test_run", - # "test_task", - # "test_key", - # '{"key": "test_key", "value": {"key2": "value2"}}', - # 2, - # ), - # {"ok": True}, - # id="set_xcom_with_map_index", - # ), + pytest.param( + GetConnection(conn_id="test_conn"), + b'{"conn_id":"test_conn","conn_type":"mysql"}\n', + "connections.get", + ("test_conn",), + ConnectionResult(conn_id="test_conn", conn_type="mysql"), + id="get_connection", + ), + pytest.param( + GetVariable(key="test_key"), + b'{"key":"test_key","value":"test_value"}\n', + "variables.get", + ("test_key",), + VariableResult(key="test_key", value="test_value"), + id="get_variable", + ), + pytest.param( + PutVariable(key="test_key", value="test_value", description="test_description"), + b"", + "variables.set", + ("test_key", "test_value", "test_description"), + {"ok": True}, + id="set_variable", + ), + pytest.param( + DeferTask(next_method="execute_callback", classpath="my-classpath"), + b"", + "task_instances.defer", + (TI_ID, DeferTask(next_method="execute_callback", classpath="my-classpath")), + "", + id="patch_task_instance_to_deferred", + ), + pytest.param( + GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"), + b'{"key":"test_key","value":"test_value"}\n', + "xcoms.get", + ("test_dag", "test_run", "test_task", "test_key", -1), + XComResult(key="test_key", value="test_value"), + id="get_xcom", + ), + pytest.param( + GetXCom( + dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key", map_index=2 + ), + b'{"key":"test_key","value":"test_value"}\n', + "xcoms.get", + ("test_dag", "test_run", "test_task", "test_key", 2), + XComResult(key="test_key", value="test_value"), + id="get_xcom_map_index", + ), + pytest.param( + SetXCom( + dag_id="test_dag", + run_id="test_run", + task_id="test_task", + key="test_key", + value='{"key": "test_key", "value": {"key2": "value2"}}', + ), + b"", + "xcoms.set", + ( + "test_dag", + "test_run", + "test_task", + "test_key", + '{"key": "test_key", "value": {"key2": "value2"}}', + None, + ), + {"ok": True}, + id="set_xcom", + ), + pytest.param( + SetXCom( + dag_id="test_dag", + run_id="test_run", + task_id="test_task", + key="test_key", + value='{"key": "test_key", "value": {"key2": "value2"}}', + map_index=2, + ), + b"", + "xcoms.set", + ( + "test_dag", + "test_run", + "test_task", + "test_key", + '{"key": "test_key", "value": {"key2": "value2"}}', + 2, + ), + {"ok": True}, + id="set_xcom_with_map_index", + ), pytest.param( TaskState(state="skipped", end_date=timezone.parse("2024-10-31T12:00:00Z")), b"", From 0037c1b87a1e156abd179621a85ce4f473a4c324 Mon Sep 17 00:00:00 2001 From: Amogh Date: Mon, 9 Dec 2024 19:18:55 +0530 Subject: [PATCH 5/5] fixing mypy --- task_sdk/tests/execution_time/test_supervisor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index 606146f4d7477..e44b4942e13ff 100644 --- a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -850,7 +850,7 @@ def watched_subprocess(self, mocker): id="set_xcom_with_map_index", ), pytest.param( - TaskState(state="skipped", end_date=timezone.parse("2024-10-31T12:00:00Z")), + TaskState(state=TerminalTIState.SKIPPED, end_date=timezone.parse("2024-10-31T12:00:00Z")), b"", "", (),