diff --git a/task_sdk/tests/dags/basic_skipped.py b/task_sdk/tests/dags/basic_skipped.py deleted file mode 100644 index c8fefd1baa8f8..0000000000000 --- a/task_sdk/tests/dags/basic_skipped.py +++ /dev/null @@ -1,36 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from __future__ import annotations - -from airflow.exceptions import AirflowSkipException -from airflow.providers.standard.operators.python import PythonOperator -from airflow.sdk.definitions.dag import dag - - -@dag() -def basic_skipped(): - def skip_task(): - raise AirflowSkipException("This task is being skipped intentionally.") - - PythonOperator( - task_id="skip", - python_callable=skip_task, - ) - - -basic_skipped() diff --git a/task_sdk/tests/dags/basic_templated_dag.py b/task_sdk/tests/dags/basic_templated_dag.py deleted file mode 100644 index 02db62cf801de..0000000000000 --- a/task_sdk/tests/dags/basic_templated_dag.py +++ /dev/null @@ -1,37 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from airflow.providers.standard.operators.bash import BashOperator -from airflow.sdk.definitions.dag import dag - -default_args = { - "owner": "airflow", - "depends_on_past": False, - "retries": 1, -} - - -@dag() -def basic_templated_dag(): - BashOperator( - task_id="task1", - bash_command="echo 'Logical date is {{ logical_date }}'", - ) - - -basic_templated_dag() diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index 8b9c2d5d932a4..517157e0a7a90 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -26,13 +26,62 @@ import pytest from uuid6 import uuid7 +from airflow.exceptions import AirflowSkipException from airflow.sdk import DAG, BaseOperator from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState from airflow.sdk.execution_time.comms import DeferTask, SetRenderedFields, StartupDetails, TaskState -from airflow.sdk.execution_time.task_runner import CommsDecoder, parse, run, startup +from airflow.sdk.execution_time.task_runner import CommsDecoder, RuntimeTaskInstance, parse, run, startup from airflow.utils import timezone +def get_inline_dag(dag_id: str, task: BaseOperator) -> DAG: + """Creates an inline dag and returns it based on dag_id and task.""" + dag = DAG(dag_id=dag_id, start_date=timezone.datetime(2024, 12, 3)) + task.dag = dag + + return dag + + +@pytest.fixture +def mocked_parse(spy_agency): + """ + Fixture to set up an inline DAG and use it in a stubbed `parse` function. Use this fixture if you + want to isolate and test `parse` or `run` logic without having to define a DAG file. + + This fixture returns a helper function `set_dag` that: + 1. Creates an in line DAG with the given `dag_id` and `task` (limited to one task) + 2. Constructs a `RuntimeTaskInstance` based on the provided `StartupDetails` and task. + 3. Stubs the `parse` function using `spy_agency`, to return the mocked `RuntimeTaskInstance`. + + After adding the fixture in your test function signature, you can use it like this :: + + mocked_parse( + StartupDetails( + ti=TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1), + file="", + requests_fd=0, + ), + "example_dag_id", + CustomOperator(task_id="hello"), + ) + """ + + def set_dag(what: StartupDetails, dag_id: str, task: BaseOperator) -> RuntimeTaskInstance: + dag = get_inline_dag(dag_id, task) + t = dag.task_dict[task.task_id] + ti = RuntimeTaskInstance.model_construct(**what.ti.model_dump(exclude_unset=True), task=t) + spy_agency.spy_on(parse, call_fake=lambda _: ti) + return ti + + return set_dag + + +class CustomOperator(BaseOperator): + def execute(self, context): + task_id = context["task_instance"].task_id + print(f"Hello World {task_id}!") + + class TestCommsDecoder: """Test the communication between the subprocess and the "supervisor".""" @@ -64,6 +113,7 @@ def test_recv_StartupDetails(self): def test_parse(test_dags_dir: Path): + """Test that checks parsing of a basic dag with an un-mocked parse.""" what = StartupDetails( ti=TaskInstance(id=uuid7(), task_id="a", dag_id="super_basic", run_id="c", try_number=1), file=str(test_dags_dir / "super_basic.py"), @@ -78,22 +128,21 @@ def test_parse(test_dags_dir: Path): assert isinstance(ti.task.dag, DAG) -def test_run_basic(test_dags_dir: Path, time_machine): +def test_run_basic(time_machine, mocked_parse): """Test running a basic task.""" what = StartupDetails( ti=TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1), - file=str(test_dags_dir / "super_basic_run.py"), + file="", requests_fd=0, ) - ti = parse(what) - instant = timezone.datetime(2024, 12, 3, 10, 0) time_machine.move_to(instant, tick=False) with mock.patch( "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True ) as mock_supervisor_comms: + ti = mocked_parse(what, "super_basic_run", CustomOperator(task_id="hello")) run(ti, log=mock.MagicMock()) mock_supervisor_comms.send_request.assert_called_once_with( @@ -101,19 +150,26 @@ def test_run_basic(test_dags_dir: Path, time_machine): ) -def test_run_deferred_basic(test_dags_dir: Path, time_machine): +def test_run_deferred_basic(time_machine, mocked_parse): """Test that a task can transition to a deferred state.""" - what = StartupDetails( - ti=TaskInstance( - id=uuid7(), task_id="async", dag_id="super_basic_deferred_run", run_id="c", try_number=1 - ), - file=str(test_dags_dir / "super_basic_deferred_run.py"), - requests_fd=0, - ) + import datetime + + from airflow.providers.standard.sensors.date_time import DateTimeSensorAsync # Use the time machine to set the current time instant = timezone.datetime(2024, 11, 22) + task = DateTimeSensorAsync( + task_id="async", + target_time=str(instant + datetime.timedelta(seconds=3)), + poke_interval=60, + timeout=600, + ) time_machine.move_to(instant, tick=False) + what = StartupDetails( + ti=TaskInstance(id=uuid7(), task_id="async", dag_id="basic_deferred_run", run_id="c", try_number=1), + file="", + requests_fd=0, + ) # Expected DeferTask expected_defer_task = DeferTask( @@ -131,22 +187,31 @@ def test_run_deferred_basic(test_dags_dir: Path, time_machine): with mock.patch( "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True ) as mock_supervisor_comms: - ti = parse(what) + ti = mocked_parse(what, "basic_deferred_run", task) run(ti, log=mock.MagicMock()) # send_request will only be called when the TaskDeferred exception is raised mock_supervisor_comms.send_request.assert_called_once_with(msg=expected_defer_task, log=mock.ANY) -def test_run_basic_skipped(test_dags_dir: Path, time_machine): +def test_run_basic_skipped(time_machine, mocked_parse): """Test running a basic task that marks itself skipped.""" + from airflow.providers.standard.operators.python import PythonOperator + + task = PythonOperator( + task_id="skip", + python_callable=lambda: (_ for _ in ()).throw( + AirflowSkipException("This task is being skipped intentionally."), + ), + ) + what = StartupDetails( ti=TaskInstance(id=uuid7(), task_id="skip", dag_id="basic_skipped", run_id="c", try_number=1), - file=str(test_dags_dir / "basic_skipped.py"), + file="", requests_fd=0, ) - ti = parse(what) + ti = mocked_parse(what, "basic_skipped", task) instant = timezone.datetime(2024, 12, 3, 10, 0) time_machine.move_to(instant, tick=False) @@ -161,14 +226,23 @@ def test_run_basic_skipped(test_dags_dir: Path, time_machine): ) -def test_startup_basic_templated_dag(test_dags_dir: Path): - """Test running a basic task.""" +def test_startup_basic_templated_dag(mocked_parse): + """Test running a DAG with templated task.""" + from airflow.providers.standard.operators.bash import BashOperator + + task = BashOperator( + task_id="templated_task", + bash_command="echo 'Logical date is {{ logical_date }}'", + ) + what = StartupDetails( - ti=TaskInstance(id=uuid7(), task_id="task1", dag_id="basic_templated_dag", run_id="c", try_number=1), - file=str(test_dags_dir / "basic_templated_dag.py"), + ti=TaskInstance( + id=uuid7(), task_id="templated_task", dag_id="basic_templated_dag", run_id="c", try_number=1 + ), + file="", requests_fd=0, ) - parse(what) + mocked_parse(what, "basic_templated_dag", task) with mock.patch( "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True