AIP-72: Inline DAG injection for task runner tests
closes: apache#44805
Dependent on apache#44786

Each time we port handling for a different TI state to the task runner, it is usually followed by an integration test that exercises the end-to-end flow for that state. For example:
1. For the skipped state, we use the DAG https://github.com/apache/airflow/pull/44786/files#diff-cabbddd33130ce1a769412f5fc55dd23e4af4d0fa75f8981689daae769e0680dR1 and test it using the unit test in the task runner: https://github.com/apache/airflow/pull/44786/files#diff-413c3c59636a3c7b41b8bb822827d18a959778d0b6331532e0db175c829dbfd2R141-R161
2. For the deferred state, we use the DAG https://github.com/apache/airflow/pull/44241/files#diff-2152ed5392424771e27a69173b3c18caae717939719df8f5dbbbdfee5f9efd9bR1 and test it using the unit test in the task runner: https://github.com/apache/airflow/pull/44241/files#diff-413c3c59636a3c7b41b8bb822827d18a959778d0b6331532e0db175c829dbfd2R93-R127

Because of this, as new TI states (and their tests) are added, the `task_sdk/tests/dags` folder keeps accumulating DAG files and could soon become unmanageable.

The solution has two parts:
1. The first part is the ability to create dynamic, inline DAGs, implemented with a DAG-factory style function:
```
def get_inline_dag(dag_id: str, task: BaseOperator) -> DAG:
    dag = DAG(
        dag_id=dag_id,
        default_args={"start_date": timezone.datetime(2024, 12, 3)},
    )
    task.dag = dag

    return dag
```
This function currently accepts a single task, creates a DAG from it, and returns the DAG object, which is enough for our current testing needs. If the need arises, it can be extended to support multiple tasks and their relationships.
Usage:
```
    task = PythonOperator(
        task_id="skip",
        python_callable=lambda: (_ for _ in ()).throw(
            AirflowSkipException("This task is being skipped intentionally."),
        ),
    )

    dag = get_inline_dag("basic_skipped", task)
```
The usage is as simple as creating any task from any operator and passing it down to this function.
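If the factory is ever extended to multiple tasks, it could look like the sketch below. This is purely illustrative (the `Task`/`Dag` classes here are minimal stand-ins, not the real Airflow `BaseOperator`/`DAG`), showing one possible shape for passing several tasks plus explicit dependency pairs:

```python
# Minimal stand-ins for illustration only -- not the real Airflow classes.
class Task:
    def __init__(self, task_id: str):
        self.task_id = task_id
        self.dag = None
        self.downstream: list[str] = []


class Dag:
    def __init__(self, dag_id: str):
        self.dag_id = dag_id
        self.task_dict: dict[str, Task] = {}


def get_inline_dag(dag_id: str, *tasks: Task, dependencies=()) -> Dag:
    """Attach every task to a fresh DAG, then wire up (upstream, downstream) pairs."""
    dag = Dag(dag_id)
    for task in tasks:
        task.dag = dag
        dag.task_dict[task.task_id] = task
    for upstream, downstream in dependencies:
        dag.task_dict[upstream].downstream.append(downstream)
    return dag


extract, load = Task("extract"), Task("load")
dag = get_inline_dag("multi_task", extract, load, dependencies=[("extract", "load")])
```

The dependency pairs mirror what `task_a >> task_b` would express in real Airflow code.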

2. Mocking the `parse` function using the KGB `spy_agency` library: https://pypi.org/project/kgb/
The idea is to use a spy agency to substitute the `parse` function with a mock parser that does the bare minimum of what the actual parser does. We chose `spy_agency` over the `mock` library primarily for two reasons:
a) With `spy_agency`, you can mock specific methods or functions without affecting the entire class or module.
b) Minimal disruption and ease of use.

1. Replaced all usages of "actual" DAG files with inline DAGs in the task runner tests that either parse or run tasks.
2. Deleted two DAG files.
3. The other two DAGs cannot be removed yet because the test_supervisor.py tests currently depend on the DAG path; this can be taken up in a follow-up if needed. Example:
![image](https://github.com/user-attachments/assets/01baa82a-7b43-4ff1-bc7e-c2fc20cef50d)

1. No need to create any more DAG files for task runner integration tests, which are frequent at the current development pace of AIP-72.
2. Inline DAGs can be created easily.

Basic DAG
![image](https://github.com/user-attachments/assets/cf7a94b5-6c4c-4103-99a0-32047207a9b2)

deferred DAG
![image](https://github.com/user-attachments/assets/328f99d0-4483-48c5-9127-dd7812f47ae0)

Co-Authored-By: Kaxil Naik <kaxilnaik@gmail.com>
amoghrajesh and kaxil committed Dec 10, 2024
1 parent e122b20 commit 016c22b
Showing 3 changed files with 96 additions and 95 deletions.
36 changes: 0 additions & 36 deletions task_sdk/tests/dags/basic_skipped.py

This file was deleted.

37 changes: 0 additions & 37 deletions task_sdk/tests/dags/basic_templated_dag.py

This file was deleted.

118 changes: 96 additions & 22 deletions task_sdk/tests/execution_time/test_task_runner.py
@@ -26,13 +26,62 @@
import pytest
from uuid6 import uuid7

from airflow.exceptions import AirflowSkipException
from airflow.sdk import DAG, BaseOperator
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.execution_time.comms import DeferTask, SetRenderedFields, StartupDetails, TaskState
from airflow.sdk.execution_time.task_runner import CommsDecoder, parse, run, startup
from airflow.sdk.execution_time.task_runner import CommsDecoder, RuntimeTaskInstance, parse, run, startup
from airflow.utils import timezone


def get_inline_dag(dag_id: str, task: BaseOperator) -> DAG:
"""Creates an inline dag and returns it based on dag_id and task."""
dag = DAG(dag_id=dag_id, start_date=timezone.datetime(2024, 12, 3))
task.dag = dag

return dag


@pytest.fixture
def mocked_parse(spy_agency):
"""
Fixture to set up an inline DAG and use it in a stubbed `parse` function. Use this fixture if you
want to isolate and test `parse` or `run` logic without having to define a DAG file.
This fixture returns a helper function `set_dag` that:
1. Creates an in line DAG with the given `dag_id` and `task` (limited to one task)
2. Constructs a `RuntimeTaskInstance` based on the provided `StartupDetails` and task.
3. Stubs the `parse` function using `spy_agency`, to return the mocked `RuntimeTaskInstance`.
After adding the fixture in your test function signature, you can use it like this ::
mocked_parse(
StartupDetails(
ti=TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1),
file="",
requests_fd=0,
),
"example_dag_id",
CustomOperator(task_id="hello"),
)
"""

def set_dag(what: StartupDetails, dag_id: str, task: BaseOperator) -> RuntimeTaskInstance:
dag = get_inline_dag(dag_id, task)
t = dag.task_dict[task.task_id]
ti = RuntimeTaskInstance.model_construct(**what.ti.model_dump(exclude_unset=True), task=t)
spy_agency.spy_on(parse, call_fake=lambda _: ti)
return ti

return set_dag


class CustomOperator(BaseOperator):
def execute(self, context):
task_id = context["task_instance"].task_id
print(f"Hello World {task_id}!")


class TestCommsDecoder:
"""Test the communication between the subprocess and the "supervisor"."""

@@ -64,6 +113,7 @@ def test_recv_StartupDetails(self):


def test_parse(test_dags_dir: Path):
"""Test that checks parsing of a basic dag with an un-mocked parse."""
what = StartupDetails(
ti=TaskInstance(id=uuid7(), task_id="a", dag_id="super_basic", run_id="c", try_number=1),
file=str(test_dags_dir / "super_basic.py"),
@@ -78,42 +128,48 @@ def test_parse(test_dags_dir: Path):
assert isinstance(ti.task.dag, DAG)


def test_run_basic(test_dags_dir: Path, time_machine):
def test_run_basic(time_machine, mocked_parse):
"""Test running a basic task."""
what = StartupDetails(
ti=TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1),
file=str(test_dags_dir / "super_basic_run.py"),
file="",
requests_fd=0,
)

ti = parse(what)

instant = timezone.datetime(2024, 12, 3, 10, 0)
time_machine.move_to(instant, tick=False)

with mock.patch(
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
) as mock_supervisor_comms:
ti = mocked_parse(what, "super_basic_run", CustomOperator(task_id="hello"))
run(ti, log=mock.MagicMock())

mock_supervisor_comms.send_request.assert_called_once_with(
msg=TaskState(state=TerminalTIState.SUCCESS, end_date=instant), log=mock.ANY
)


def test_run_deferred_basic(test_dags_dir: Path, time_machine):
def test_run_deferred_basic(time_machine, mocked_parse):
"""Test that a task can transition to a deferred state."""
what = StartupDetails(
ti=TaskInstance(
id=uuid7(), task_id="async", dag_id="super_basic_deferred_run", run_id="c", try_number=1
),
file=str(test_dags_dir / "super_basic_deferred_run.py"),
requests_fd=0,
)
import datetime

from airflow.providers.standard.sensors.date_time import DateTimeSensorAsync

# Use the time machine to set the current time
instant = timezone.datetime(2024, 11, 22)
task = DateTimeSensorAsync(
task_id="async",
target_time=str(instant + datetime.timedelta(seconds=3)),
poke_interval=60,
timeout=600,
)
time_machine.move_to(instant, tick=False)
what = StartupDetails(
ti=TaskInstance(id=uuid7(), task_id="async", dag_id="basic_deferred_run", run_id="c", try_number=1),
file="",
requests_fd=0,
)

# Expected DeferTask
expected_defer_task = DeferTask(
@@ -131,22 +187,31 @@ def test_run_deferred_basic(test_dags_dir: Path, time_machine):
with mock.patch(
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
) as mock_supervisor_comms:
ti = parse(what)
ti = mocked_parse(what, "basic_deferred_run", task)
run(ti, log=mock.MagicMock())

# send_request will only be called when the TaskDeferred exception is raised
mock_supervisor_comms.send_request.assert_called_once_with(msg=expected_defer_task, log=mock.ANY)


def test_run_basic_skipped(test_dags_dir: Path, time_machine):
def test_run_basic_skipped(time_machine, mocked_parse):
"""Test running a basic task that marks itself skipped."""
from airflow.providers.standard.operators.python import PythonOperator

task = PythonOperator(
task_id="skip",
python_callable=lambda: (_ for _ in ()).throw(
AirflowSkipException("This task is being skipped intentionally."),
),
)

what = StartupDetails(
ti=TaskInstance(id=uuid7(), task_id="skip", dag_id="basic_skipped", run_id="c", try_number=1),
file=str(test_dags_dir / "basic_skipped.py"),
file="",
requests_fd=0,
)

ti = parse(what)
ti = mocked_parse(what, "basic_skipped", task)

instant = timezone.datetime(2024, 12, 3, 10, 0)
time_machine.move_to(instant, tick=False)
@@ -161,14 +226,23 @@ def test_run_basic_skipped(test_dags_dir: Path, time_machine):
)


def test_startup_basic_templated_dag(test_dags_dir: Path):
"""Test running a basic task."""
def test_startup_basic_templated_dag(mocked_parse):
"""Test running a DAG with templated task."""
from airflow.providers.standard.operators.bash import BashOperator

task = BashOperator(
task_id="templated_task",
bash_command="echo 'Logical date is {{ logical_date }}'",
)

what = StartupDetails(
ti=TaskInstance(id=uuid7(), task_id="task1", dag_id="basic_templated_dag", run_id="c", try_number=1),
file=str(test_dags_dir / "basic_templated_dag.py"),
ti=TaskInstance(
id=uuid7(), task_id="templated_task", dag_id="basic_templated_dag", run_id="c", try_number=1
),
file="",
requests_fd=0,
)
parse(what)
mocked_parse(what, "basic_templated_dag", task)

with mock.patch(
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
