From 1160eba4c135598b9e820769f6b952512be2e6b9 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Thu, 21 Aug 2025 16:15:39 +0100 Subject: [PATCH] Improve test maintainability for supervisor request handling tests Refactor large parametrized test from individual parameters to dataclass structure for better readability and maintainability. The test now uses a `RequestTestCase` dataclass instead of 7 separate parameters, making it much easier to add new test cases and understand existing ones. This change makes the test suite more maintainable for developers working on the supervisor communication protocol. All the existing tests have been ported over. Tests: ``` task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[get_connection] PASSED [ 2%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[get_connection_with_password] PASSED [ 4%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[get_connection_with_alias] PASSED [ 7%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[get_variable] PASSED [ 9%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[set_variable] PASSED [ 12%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[delete_variable] PASSED [ 14%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[patch_task_instance_to_deferred] PASSED [ 17%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[patch_task_instance_to_up_for_reschedule] PASSED [ 19%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[get_xcom] PASSED [ 21%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[get_xcom_map_index] PASSED [ 24%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[get_xcom_not_found] PASSED [ 26%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[get_xcom_include_prior_dates] PASSED [ 29%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[set_xcom] PASSED [ 31%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[set_xcom_with_map_index] PASSED [ 34%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[set_xcom_with_map_index_and_mapped_length] PASSED [ 36%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[delete_xcom] PASSED [ 39%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[up_for_retry] PASSED [ 41%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[set_rtif] PASSED [ 43%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[succeed_task] PASSED [ 46%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[get_asset_by_name] PASSED [ 48%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[get_asset_by_uri] PASSED [ 51%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[get_asset_events_by_uri_and_name] PASSED [ 53%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[get_asset_events_by_uri] PASSED [ 56%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[get_asset_events_by_name] PASSED [ 58%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[get_asset_events_by_asset_alias] PASSED [ 60%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[validate_inlets_and_outlets] PASSED [ 63%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[get_prev_successful_dagrun] PASSED [ 65%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[dag_run_trigger] PASSED [ 68%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[dag_run_trigger_already_exists] PASSED [ 70%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[get_dag_run_state] PASSED [ 73%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[get_previous_dagrun] PASSED [ 75%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[get_previous_dagrun_with_state] PASSED [ 78%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[get_task_reschedule_start_date] PASSED [ 80%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[get_ti_count] PASSED [ 82%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[get_dr_count] PASSED [ 85%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[get_task_states] PASSED [ 87%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[get_xcom_seq_item] PASSED [ 90%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[get_xcom_seq_item_not_found] PASSED [ 92%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[get_xcom_seq_slice] PASSED [ 95%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[patch_task_instance_to_skipped] PASSED [ 97%] task-sdk/tests/task_sdk/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[create_hitl_detail_payload] PASSED [100%] ``` --- task-sdk/src/airflow/sdk/api/client.py | 5 +- .../airflow/sdk/execution_time/supervisor.py | 4 +- .../execution_time/test_supervisor.py | 1566 +++++++++-------- 3 files changed, 828 insertions(+), 747 deletions(-) diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 9f1a1f8195397..be27c28d3ff96 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -78,6 +78,7 @@ TaskRescheduleStartDate, TICount, UpdateHITLDetail, + XComCountResponse, ) if TYPE_CHECKING: @@ -414,7 +415,7 @@ class XComOperations: def __init__(self, client: Client): self.client = client - def head(self, dag_id: str, run_id: str, task_id: str, key: str) -> int: + def head(self, dag_id: str, run_id: str, task_id: str, key: str) -> XComCountResponse: """Get the number of mapped XCom values.""" resp = self.client.head(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}") @@ -423,7 +424,7 @@ def head(self, dag_id: str, run_id: str, task_id: str, key: str) -> int: "map_indexes " ): raise RuntimeError(f"Unable to parse Content-Range header from HEAD {resp.request.url}") - return int(content_range[len("map_indexes ") :]) + return XComCountResponse(len=int(content_range[len("map_indexes ") :])) def get( self, diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 03e8c1ee345b7..306b96380c756 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -113,7 +113,6 @@ TriggerDagRun, ValidateInletsAndOutlets, VariableResult, - XComCountResponse, XComResult, XComSequenceIndexResult, XComSequenceSliceResult, @@ -1202,8 +1201,7 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: xcom_result = XComResult.from_xcom_response(xcom) resp = xcom_result elif isinstance(msg, GetXComCount): - xcom_count = self.client.xcoms.head(msg.dag_id, msg.run_id, msg.task_id, msg.key) - resp = XComCountResponse(len=xcom_count) + resp = self.client.xcoms.head(msg.dag_id, msg.run_id, msg.task_id, msg.key) elif isinstance(msg, GetXComSequenceItem): xcom = self.client.xcoms.get_sequence_item( msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.offset diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 8905124fc7d48..67b90a1846a90 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -28,10 +28,11 @@ import sys import time from contextlib import nullcontext +from dataclasses import dataclass, field from operator import attrgetter from random import randint from time import sleep -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from unittest import mock from unittest.mock import MagicMock, patch @@ -78,6 +79,7 @@ GetConnection, GetDagRunState, GetDRCount, + GetHITLDetailResponse, GetPreviousDagRun, GetPrevSuccessfulDagRun, GetTaskRescheduleStartDate, @@ -85,10 +87,12 @@ GetTICount, GetVariable, GetXCom, + GetXComCount, GetXComSequenceItem, GetXComSequenceSlice, HITLDetailRequestResult, InactiveAssetsResult, + MaskSecret, OKResponse, PreviousDagRunResult, PrevSuccessfulDagRunResult, @@ -99,14 +103,18 @@ SentFDs, SetRenderedFields, SetXCom, + SkipDownstreamTasks, SucceedTask, TaskRescheduleStartDate, TaskState, TaskStatesResult, TICount, + ToSupervisor, TriggerDagRun, + UpdateHITLDetail, ValidateInletsAndOutlets, VariableResult, + XComCountResponse, XComResult, XComSequenceIndexResult, XComSequenceSliceResult, @@ -1276,773 +1284,813 @@ def test_max_wait_time_calculation_edge_cases( assert actual_timeout >= expected_min_timeout -class TestHandleRequest: - @pytest.fixture - def watched_subprocess(self, mocker): - read_end, write_end = socket.socketpair() +@dataclass +class ClientMock: + """Configuration for mocking client method calls.""" - subprocess = ActivitySubprocess( - process_log=mocker.MagicMock(), - id=TI_ID, - pid=12345, - stdin=write_end, - client=mocker.Mock(), - process=mocker.Mock(), - ) + method_path: str + """Path to the client method to mock (e.g., 'connections.get', 'variables.set').""" - return subprocess, read_end + args: tuple = field(default_factory=tuple) + """Positional arguments the client method should be called with.""" - @patch("airflow.sdk.execution_time.supervisor.mask_secret") - @pytest.mark.parametrize( - [ - "message", - "expected_body", - "client_attr_path", - "method_arg", - "method_kwarg", - "mock_response", - "mask_secret_args", - ], - [ - pytest.param( - GetConnection(conn_id="test_conn"), - {"conn_id": "test_conn", "conn_type": "mysql", "type": "ConnectionResult"}, - "connections.get", - ("test_conn",), - {}, - ConnectionResult(conn_id="test_conn", conn_type="mysql"), - None, - id="get_connection", - ), - pytest.param( - GetConnection(conn_id="test_conn"), - { - "conn_id": "test_conn", - "conn_type": "mysql", - "password": "password", - "type": "ConnectionResult", - }, - "connections.get", - ("test_conn",), - {}, - ConnectionResult(conn_id="test_conn", conn_type="mysql", password="password"), - ["password"], - id="get_connection_with_password", - ), - pytest.param( - GetConnection(conn_id="test_conn"), - {"conn_id": "test_conn", "conn_type": "mysql", "schema": "mysql", "type": "ConnectionResult"}, - "connections.get", - ("test_conn",), - {}, - ConnectionResult(conn_id="test_conn", conn_type="mysql", schema="mysql"), # type: ignore[call-arg] - None, - id="get_connection_with_alias", - ), - pytest.param( - GetVariable(key="test_key"), - {"key": "test_key", "value": "test_value", "type": "VariableResult"}, - "variables.get", - ("test_key",), - {}, - VariableResult(key="test_key", value="test_value"), - ["test_value", "test_key"], - id="get_variable", - ), - pytest.param( - PutVariable(key="test_key", value="test_value", description="test_description"), - None, - "variables.set", - ("test_key", "test_value", "test_description"), - {}, - OKResponse(ok=True), - None, - id="set_variable", - ), - pytest.param( - DeleteVariable(key="test_key"), - {"ok": True, "type": "OKResponse"}, - "variables.delete", - ("test_key",), - {}, - OKResponse(ok=True), - None, - id="delete_variable", - ), - pytest.param( - DeferTask(next_method="execute_callback", classpath="my-classpath"), - None, - "task_instances.defer", - (TI_ID, DeferTask(next_method="execute_callback", classpath="my-classpath")), - {}, - "", - None, - id="patch_task_instance_to_deferred", - ), - pytest.param( + kwargs: dict = field(default_factory=dict) + """Keyword arguments the client method should be called with.""" + + response: Any = None + """What the mocked client method should return when called.""" + + +@dataclass +class RequestTestCase: + """Test case data for request handling tests in `TestHandleRequest` class.""" + + message: Any + """The request message to send to the supervisor (e.g., GetConnection, SetXCom).""" + + test_id: str + """Unique identifier for this test case, used in pytest parameterization.""" + + client_mock: ClientMock | None = None + """Client method mocking configuration. None for messages that don't require client calls.""" + + expected_body: dict | None = None + """Expected response body from supervisor. None if no response body expected.""" + + mask_secret_args: tuple | None = None + """Arguments that should be passed to the secret masker for redaction.""" + + +# Test cases for request handling +REQUEST_TEST_CASES = [ + RequestTestCase( + message=GetConnection(conn_id="test_conn"), + test_id="get_connection", + client_mock=ClientMock( + method_path="connections.get", + args=("test_conn",), + response=ConnectionResult(conn_id="test_conn", conn_type="mysql"), + ), + expected_body={"conn_id": "test_conn", "conn_type": "mysql", "type": "ConnectionResult"}, + ), + RequestTestCase( + message=GetConnection(conn_id="test_conn"), + test_id="get_connection_with_password", + client_mock=ClientMock( + method_path="connections.get", + args=("test_conn",), + response=ConnectionResult(conn_id="test_conn", conn_type="mysql", password="password"), + ), + expected_body={ + "conn_id": "test_conn", + "conn_type": "mysql", + "password": "password", + "type": "ConnectionResult", + }, + mask_secret_args=("password",), + ), + RequestTestCase( + message=GetConnection(conn_id="test_conn"), + test_id="get_connection_with_alias", + client_mock=ClientMock( + method_path="connections.get", + args=("test_conn",), + response=ConnectionResult(conn_id="test_conn", conn_type="mysql", schema="mysql"), # type: ignore[call-arg] + ), + expected_body={ + "conn_id": "test_conn", + "conn_type": "mysql", + "schema": "mysql", + "type": "ConnectionResult", + }, + ), + RequestTestCase( + message=GetVariable(key="test_key"), + test_id="get_variable", + client_mock=ClientMock( + method_path="variables.get", + args=("test_key",), + response=VariableResult(key="test_key", value="test_value"), + ), + expected_body={"key": "test_key", "value": "test_value", "type": "VariableResult"}, + mask_secret_args=("test_value", "test_key"), + ), + RequestTestCase( + message=PutVariable(key="test_key", value="test_value", description="test_description"), + test_id="set_variable", + client_mock=ClientMock( + method_path="variables.set", + args=("test_key", "test_value", "test_description"), + response=OKResponse(ok=True), + ), + ), + RequestTestCase( + message=DeleteVariable(key="test_key"), + test_id="delete_variable", + client_mock=ClientMock( + method_path="variables.delete", + args=("test_key",), + response=OKResponse(ok=True), + ), + expected_body={"ok": True, "type": "OKResponse"}, + ), + RequestTestCase( + message=DeferTask(next_method="execute_callback", classpath="my-classpath"), + test_id="patch_task_instance_to_deferred", + client_mock=ClientMock( + method_path="task_instances.defer", + args=(TI_ID, DeferTask(next_method="execute_callback", classpath="my-classpath")), + ), + ), + RequestTestCase( + message=RescheduleTask( + reschedule_date=timezone.parse("2024-10-31T12:00:00Z"), + end_date=timezone.parse("2024-10-31T12:00:00Z"), + ), + test_id="patch_task_instance_to_up_for_reschedule", + client_mock=ClientMock( + method_path="task_instances.reschedule", + args=( + TI_ID, RescheduleTask( reschedule_date=timezone.parse("2024-10-31T12:00:00Z"), end_date=timezone.parse("2024-10-31T12:00:00Z"), ), - None, - "task_instances.reschedule", - ( - TI_ID, - RescheduleTask( - reschedule_date=timezone.parse("2024-10-31T12:00:00Z"), - end_date=timezone.parse("2024-10-31T12:00:00Z"), - ), - ), - {}, - "", - None, - id="patch_task_instance_to_up_for_reschedule", ), - pytest.param( - GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"), - {"key": "test_key", "value": "test_value", "type": "XComResult"}, - "xcoms.get", - ("test_dag", "test_run", "test_task", "test_key", None, False), - {}, - XComResult(key="test_key", value="test_value"), - None, - id="get_xcom", - ), - pytest.param( - GetXCom( - dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key", map_index=2 - ), - {"key": "test_key", "value": "test_value", "type": "XComResult"}, - "xcoms.get", - ("test_dag", "test_run", "test_task", "test_key", 2, False), - {}, - XComResult(key="test_key", value="test_value"), - None, - id="get_xcom_map_index", - ), - pytest.param( - GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"), - {"key": "test_key", "value": None, "type": "XComResult"}, - "xcoms.get", - ("test_dag", "test_run", "test_task", "test_key", None, False), - {}, - XComResult(key="test_key", value=None, type="XComResult"), - None, - id="get_xcom_not_found", - ), - pytest.param( - GetXCom( - dag_id="test_dag", - run_id="test_run", - task_id="test_task", - key="test_key", - include_prior_dates=True, - ), - {"key": "test_key", "value": None, "type": "XComResult"}, - "xcoms.get", - ("test_dag", "test_run", "test_task", "test_key", None, True), - {}, - XComResult(key="test_key", value=None, type="XComResult"), - None, - id="get_xcom_include_prior_dates", - ), - pytest.param( - SetXCom( - dag_id="test_dag", - run_id="test_run", - task_id="test_task", - key="test_key", - value='{"key": "test_key", "value": {"key2": "value2"}}', - ), - None, - "xcoms.set", - ( - "test_dag", - "test_run", - "test_task", - "test_key", - '{"key": "test_key", "value": {"key2": "value2"}}', - None, - None, - ), - {}, - OKResponse(ok=True), - None, - id="set_xcom", - ), - pytest.param( - SetXCom( - dag_id="test_dag", - run_id="test_run", - task_id="test_task", - key="test_key", - value='{"key": "test_key", "value": {"key2": "value2"}}', - map_index=2, - ), - None, - "xcoms.set", - ( - "test_dag", - "test_run", - "test_task", - "test_key", - '{"key": "test_key", "value": {"key2": "value2"}}', - 2, - None, - ), - {}, - OKResponse(ok=True), - None, - id="set_xcom_with_map_index", - ), - pytest.param( - SetXCom( - dag_id="test_dag", - run_id="test_run", - task_id="test_task", - key="test_key", - value='{"key": "test_key", "value": {"key2": "value2"}}', - map_index=2, - mapped_length=3, - ), - None, - "xcoms.set", - ( - "test_dag", - "test_run", - "test_task", - "test_key", - '{"key": "test_key", "value": {"key2": "value2"}}', - 2, - 3, - ), - {}, - OKResponse(ok=True), - None, - id="set_xcom_with_map_index_and_mapped_length", - ), - pytest.param( - DeleteXCom( - dag_id="test_dag", - run_id="test_run", - task_id="test_task", - key="test_key", - map_index=2, - ), - None, - "xcoms.delete", - ("test_dag", "test_run", "test_task", "test_key", 2), - {}, - OKResponse(ok=True), - None, - id="delete_xcom", - ), - # we aren't adding all states under TaskInstanceState here, because this test's scope is only to check - # if it can handle TaskState message - pytest.param( - TaskState(state=TaskInstanceState.SKIPPED, end_date=timezone.parse("2024-10-31T12:00:00Z")), - None, - "", - (), - {}, - "", - None, - id="patch_task_instance_to_skipped", - ), - pytest.param( - RetryTask( - end_date=timezone.parse("2024-10-31T12:00:00Z"), rendered_map_index="test retry task" - ), - None, - "task_instances.retry", - (), - { - "id": TI_ID, - "end_date": timezone.parse("2024-10-31T12:00:00Z"), - "rendered_map_index": "test retry task", - }, - "", - None, - id="up_for_retry", - ), - pytest.param( - SetRenderedFields(rendered_fields={"field1": "rendered_value1", "field2": "rendered_value2"}), - None, - "task_instances.set_rtif", - (TI_ID, {"field1": "rendered_value1", "field2": "rendered_value2"}), - {}, - OKResponse(ok=True), - None, - id="set_rtif", - ), - pytest.param( - GetAssetByName(name="asset"), - {"name": "asset", "uri": "s3://bucket/obj", "group": "asset", "type": "AssetResult"}, - "assets.get", - [], - {"name": "asset"}, - AssetResult(name="asset", uri="s3://bucket/obj", group="asset"), - None, - id="get_asset_by_name", - ), - pytest.param( - GetAssetByUri(uri="s3://bucket/obj"), - {"name": "asset", "uri": "s3://bucket/obj", "group": "asset", "type": "AssetResult"}, - "assets.get", - [], - {"uri": "s3://bucket/obj"}, - AssetResult(name="asset", uri="s3://bucket/obj", group="asset"), - None, - id="get_asset_by_uri", - ), - pytest.param( - GetAssetEventByAsset(uri="s3://bucket/obj", name="test"), - { - "asset_events": [ - { - "id": 1, - "timestamp": timezone.parse("2024-10-31T12:00:00Z"), - "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, - "created_dagruns": [], - } - ], - "type": "AssetEventsResult", - }, - "asset_events.get", - [], - {"uri": "s3://bucket/obj", "name": "test"}, - AssetEventsResult( - asset_events=[ - AssetEventResponse( - id=1, - asset=AssetResponse(name="asset", uri="s3://bucket/obj", group="asset"), - created_dagruns=[], - timestamp=timezone.parse("2024-10-31T12:00:00Z"), - ) - ] - ), - None, - id="get_asset_events_by_uri_and_name", - ), - pytest.param( - GetAssetEventByAsset(uri="s3://bucket/obj", name=None), - { - "asset_events": [ - { - "id": 1, - "timestamp": timezone.parse("2024-10-31T12:00:00Z"), - "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, - "created_dagruns": [], - } - ], - "type": "AssetEventsResult", - }, - "asset_events.get", - [], - {"uri": "s3://bucket/obj", "name": None}, - AssetEventsResult( - asset_events=[ - AssetEventResponse( - id=1, - asset=AssetResponse(name="asset", uri="s3://bucket/obj", group="asset"), - created_dagruns=[], - timestamp=timezone.parse("2024-10-31T12:00:00Z"), - ) - ] - ), - None, - id="get_asset_events_by_uri", - ), - pytest.param( - GetAssetEventByAsset(uri=None, name="test"), - { - "asset_events": [ - { - "id": 1, - "timestamp": timezone.parse("2024-10-31T12:00:00Z"), - "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, - "created_dagruns": [], - } - ], - "type": "AssetEventsResult", - }, - "asset_events.get", - [], - {"uri": None, "name": "test"}, - AssetEventsResult( - asset_events=[ - AssetEventResponse( - id=1, - asset=AssetResponse(name="asset", uri="s3://bucket/obj", group="asset"), - created_dagruns=[], - timestamp=timezone.parse("2024-10-31T12:00:00Z"), - ) - ] - ), - None, - id="get_asset_events_by_name", - ), - pytest.param( - GetAssetEventByAssetAlias(alias_name="test_alias"), - { - "asset_events": [ - { - "id": 1, - "timestamp": timezone.parse("2024-10-31T12:00:00Z"), - "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, - "created_dagruns": [], - } - ], - "type": "AssetEventsResult", - }, - "asset_events.get", - [], - {"alias_name": "test_alias"}, - AssetEventsResult( - asset_events=[ - AssetEventResponse( - id=1, - asset=AssetResponse(name="asset", uri="s3://bucket/obj", group="asset"), - created_dagruns=[], - timestamp=timezone.parse("2024-10-31T12:00:00Z"), - ) - ] - ), - None, - id="get_asset_events_by_asset_alias", - ), - pytest.param( - ValidateInletsAndOutlets(ti_id=TI_ID), - { - "inactive_assets": [{"name": "asset_name", "uri": "asset_uri", "type": "asset"}], - "type": "InactiveAssetsResult", - }, - "task_instances.validate_inlets_and_outlets", - (TI_ID,), - {}, - InactiveAssetsResult( - inactive_assets=[AssetProfile(name="asset_name", uri="asset_uri", type="asset")] - ), - None, - id="validate_inlets_and_outlets", - ), - pytest.param( - SucceedTask( - end_date=timezone.parse("2024-10-31T12:00:00Z"), rendered_map_index="test success task" - ), - None, - "task_instances.succeed", - (), - { - "id": TI_ID, - "outlet_events": None, - "task_outlets": None, - "when": timezone.parse("2024-10-31T12:00:00Z"), - "rendered_map_index": "test success task", - }, - "", - None, - id="succeed_task", - ), - pytest.param( - GetPrevSuccessfulDagRun(ti_id=TI_ID), - { - "data_interval_start": timezone.parse("2025-01-10T12:00:00Z"), - "data_interval_end": timezone.parse("2025-01-10T14:00:00Z"), - "start_date": timezone.parse("2025-01-10T12:00:00Z"), - "end_date": timezone.parse("2025-01-10T14:00:00Z"), - "type": "PrevSuccessfulDagRunResult", - }, - "task_instances.get_previous_successful_dagrun", - (TI_ID,), - {}, - PrevSuccessfulDagRunResult( - start_date=timezone.parse("2025-01-10T12:00:00Z"), - end_date=timezone.parse("2025-01-10T14:00:00Z"), - data_interval_start=timezone.parse("2025-01-10T12:00:00Z"), - data_interval_end=timezone.parse("2025-01-10T14:00:00Z"), - ), - None, - id="get_prev_successful_dagrun", - ), - pytest.param( - TriggerDagRun( - dag_id="test_dag", - run_id="test_run", - conf={"key": "value"}, - logical_date=timezone.datetime(2025, 1, 1), - reset_dag_run=True, - ), - {"ok": True, "type": "OKResponse"}, - "dag_runs.trigger", - ("test_dag", "test_run", {"key": "value"}, timezone.datetime(2025, 1, 1), True), - {}, - OKResponse(ok=True), - None, - id="dag_run_trigger", - ), - pytest.param( - # TODO: This should be raise an exception, not returning an ErrorResponse. Fix this before PR - TriggerDagRun(dag_id="test_dag", run_id="test_run"), - {"error": "DAGRUN_ALREADY_EXISTS", "detail": None, "type": "ErrorResponse"}, - "dag_runs.trigger", - ("test_dag", "test_run", None, None, False), - {}, - ErrorResponse(error=ErrorType.DAGRUN_ALREADY_EXISTS), + ), + ), + RequestTestCase( + message=GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"), + test_id="get_xcom", + client_mock=ClientMock( + method_path="xcoms.get", + args=("test_dag", "test_run", "test_task", "test_key", None, False), + response=XComResult(key="test_key", value="test_value"), + ), + expected_body={"key": "test_key", "value": "test_value", "type": "XComResult"}, + ), + RequestTestCase( + message=GetXCom( + dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key", map_index=2 + ), + test_id="get_xcom_map_index", + client_mock=ClientMock( + method_path="xcoms.get", + args=("test_dag", "test_run", "test_task", "test_key", 2, False), + response=XComResult(key="test_key", value="test_value"), + ), + expected_body={"key": "test_key", "value": "test_value", "type": "XComResult"}, + ), + RequestTestCase( + message=GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"), + test_id="get_xcom_not_found", + client_mock=ClientMock( + method_path="xcoms.get", + args=("test_dag", "test_run", "test_task", "test_key", None, False), + response=XComResult(key="test_key", value=None, type="XComResult"), + ), + expected_body={"key": "test_key", "value": None, "type": "XComResult"}, + ), + RequestTestCase( + message=GetXCom( + dag_id="test_dag", + run_id="test_run", + task_id="test_task", + key="test_key", + include_prior_dates=True, + ), + test_id="get_xcom_include_prior_dates", + client_mock=ClientMock( + method_path="xcoms.get", + args=("test_dag", "test_run", "test_task", "test_key", None, True), + response=XComResult(key="test_key", value=None, type="XComResult"), + ), + expected_body={"key": "test_key", "value": None, "type": "XComResult"}, + ), + RequestTestCase( + message=SetXCom( + dag_id="test_dag", + run_id="test_run", + task_id="test_task", + key="test_key", + value='{"key": "test_key", "value": {"key2": "value2"}}', + ), + client_mock=ClientMock( + method_path="xcoms.set", + args=( + "test_dag", + "test_run", + "test_task", + "test_key", + '{"key": "test_key", "value": {"key2": "value2"}}', None, - id="dag_run_trigger_already_exists", - ), - pytest.param( - GetDagRunState(dag_id="test_dag", run_id="test_run"), - {"state": "running", "type": "DagRunStateResult"}, - "dag_runs.get_state", - ("test_dag", "test_run"), - {}, - DagRunStateResult(state=DagRunState.RUNNING), None, - id="get_dag_run_state", ), - pytest.param( - GetTaskRescheduleStartDate(ti_id=TI_ID), - {"start_date": timezone.parse("2024-10-31T12:00:00Z"), "type": "TaskRescheduleStartDate"}, - "task_instances.get_reschedule_start_date", - (TI_ID, 1), - {}, - TaskRescheduleStartDate(start_date=timezone.parse("2024-10-31T12:00:00Z")), + response=OKResponse(ok=True), + ), + test_id="set_xcom", + ), + RequestTestCase( + message=SetXCom( + dag_id="test_dag", + run_id="test_run", + task_id="test_task", + key="test_key", + value='{"key": "test_key", "value": {"key2": "value2"}}', + map_index=2, + ), + client_mock=ClientMock( + method_path="xcoms.set", + args=( + "test_dag", + "test_run", + "test_task", + "test_key", + '{"key": "test_key", "value": {"key2": "value2"}}', + 2, None, - id="get_task_reschedule_start_date", ), - pytest.param( - GetTICount(dag_id="test_dag", task_ids=["task1", "task2"]), - {"count": 2, "type": "TICount"}, - "task_instances.get_count", - (), - { - "dag_id": "test_dag", - "map_index": None, - "logical_dates": None, - "run_ids": None, - "states": None, - "task_group_id": None, - "task_ids": ["task1", "task2"], - }, - TICount(count=2), - None, - id="get_ti_count", + response=OKResponse(ok=True), + ), + test_id="set_xcom_with_map_index", + ), + RequestTestCase( + message=SetXCom( + dag_id="test_dag", + run_id="test_run", + task_id="test_task", + key="test_key", + value='{"key": "test_key", "value": {"key2": "value2"}}', + map_index=2, + mapped_length=3, + ), + client_mock=ClientMock( + method_path="xcoms.set", + args=( + "test_dag", + "test_run", + "test_task", + "test_key", + '{"key": "test_key", "value": {"key2": "value2"}}', + 2, + 3, ), - pytest.param( - GetDRCount(dag_id="test_dag", states=["success", "failed"]), - {"count": 2, "type": "DRCount"}, - "dag_runs.get_count", - (), + response=OKResponse(ok=True), + ), + test_id="set_xcom_with_map_index_and_mapped_length", + ), + RequestTestCase( + message=DeleteXCom( + dag_id="test_dag", + run_id="test_run", + task_id="test_task", + key="test_key", + map_index=2, + ), + client_mock=ClientMock( + method_path="xcoms.delete", + args=("test_dag", "test_run", "test_task", "test_key", 2), + response=OKResponse(ok=True), + ), + test_id="delete_xcom", + ), + RequestTestCase( + message=RetryTask( + end_date=timezone.parse("2024-10-31T12:00:00Z"), rendered_map_index="test retry task" + ), + client_mock=ClientMock( + method_path="task_instances.retry", + kwargs={ + "id": TI_ID, + "end_date": timezone.parse("2024-10-31T12:00:00Z"), + "rendered_map_index": "test retry task", + }, + response=OKResponse(ok=True), + ), + test_id="up_for_retry", + ), + RequestTestCase( + message=SetRenderedFields(rendered_fields={"field1": "rendered_value1", "field2": "rendered_value2"}), + client_mock=ClientMock( + method_path="task_instances.set_rtif", + args=(TI_ID, {"field1": "rendered_value1", "field2": "rendered_value2"}), + response=OKResponse(ok=True), + ), + test_id="set_rtif", + ), + RequestTestCase( + message=SucceedTask( + end_date=timezone.parse("2024-10-31T12:00:00Z"), rendered_map_index="test success task" + ), + client_mock=ClientMock( + method_path="task_instances.succeed", + kwargs={ + "id": TI_ID, + "outlet_events": None, + "task_outlets": None, + "when": timezone.parse("2024-10-31T12:00:00Z"), + "rendered_map_index": "test success task", + }, + ), + test_id="succeed_task", + ), + RequestTestCase( + message=GetAssetByName(name="asset"), + expected_body={"name": "asset", "uri": "s3://bucket/obj", "group": "asset", "type": "AssetResult"}, + client_mock=ClientMock( + method_path="assets.get", + kwargs={"name": "asset"}, + response=AssetResult(name="asset", uri="s3://bucket/obj", group="asset"), + ), + test_id="get_asset_by_name", + ), + RequestTestCase( + message=GetAssetByUri(uri="s3://bucket/obj"), + expected_body={"name": "asset", "uri": "s3://bucket/obj", "group": "asset", "type": "AssetResult"}, + client_mock=ClientMock( + method_path="assets.get", + kwargs={"uri": "s3://bucket/obj"}, + response=AssetResult(name="asset", uri="s3://bucket/obj", group="asset"), + ), + test_id="get_asset_by_uri", + ), + RequestTestCase( + message=GetAssetEventByAsset(uri="s3://bucket/obj", name="test"), + expected_body={ + "asset_events": [ { - "dag_id": "test_dag", - "logical_dates": None, - "run_ids": None, - "states": ["success", "failed"], - }, - DRCount(count=2), - None, - id="get_dr_count", + "id": 1, + "timestamp": timezone.parse("2024-10-31T12:00:00Z"), + "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, + "created_dagruns": [], + } + ], + "type": "AssetEventsResult", + }, + client_mock=ClientMock( + method_path="asset_events.get", + kwargs={"uri": "s3://bucket/obj", "name": "test"}, + response=AssetEventsResult( + asset_events=[ + AssetEventResponse( + id=1, + asset=AssetResponse(name="asset", uri="s3://bucket/obj", group="asset"), + created_dagruns=[], + timestamp=timezone.parse("2024-10-31T12:00:00Z"), + ), + ], ), - pytest.param( - GetPreviousDagRun( - dag_id="test_dag", - logical_date=timezone.parse("2024-01-15T12:00:00Z"), - ), - { - "dag_run": { - "dag_id": "test_dag", - "run_id": "prev_run", - "logical_date": timezone.parse("2024-01-14T12:00:00Z"), - "run_type": "scheduled", - "start_date": timezone.parse("2024-01-15T12:00:00Z"), - "run_after": timezone.parse("2024-01-15T12:00:00Z"), - "consumed_asset_events": [], - "state": "success", - "data_interval_start": None, - "data_interval_end": None, - "end_date": None, - "clear_number": 0, - "conf": None, - }, - "type": "PreviousDagRunResult", - }, - "dag_runs.get_previous", - (), + ), + test_id="get_asset_events_by_uri_and_name", + ), + RequestTestCase( + message=GetAssetEventByAsset(uri="s3://bucket/obj", name=None), + expected_body={ + "asset_events": [ { - "dag_id": "test_dag", - "logical_date": timezone.parse("2024-01-15T12:00:00Z"), - "state": None, - }, - PreviousDagRunResult( - dag_run=DagRun( - dag_id="test_dag", - run_id="prev_run", - logical_date=timezone.parse("2024-01-14T12:00:00Z"), - run_type=DagRunType.SCHEDULED, - start_date=timezone.parse("2024-01-15T12:00:00Z"), - run_after=timezone.parse("2024-01-15T12:00:00Z"), - consumed_asset_events=[], - state=DagRunState.SUCCESS, + "id": 1, + "timestamp": timezone.parse("2024-10-31T12:00:00Z"), + "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, + "created_dagruns": [], + } + ], + "type": "AssetEventsResult", + }, + client_mock=ClientMock( + method_path="asset_events.get", + kwargs={"uri": "s3://bucket/obj", "name": None}, + response=AssetEventsResult( + asset_events=[ + AssetEventResponse( + id=1, + asset=AssetResponse(name="asset", uri="s3://bucket/obj", group="asset"), + created_dagruns=[], + timestamp=timezone.parse("2024-10-31T12:00:00Z"), ) - ), - None, - id="get_previous_dagrun", + ], ), - pytest.param( - GetPreviousDagRun( - dag_id="test_dag", - logical_date=timezone.parse("2024-01-15T12:00:00Z"), - state="success", - ), - { - "dag_run": None, - "type": "PreviousDagRunResult", - }, - "dag_runs.get_previous", - (), + ), + test_id="get_asset_events_by_uri", + ), + RequestTestCase( + message=GetAssetEventByAsset(uri=None, name="test"), + expected_body={ + "asset_events": [ { - "dag_id": "test_dag", - "logical_date": timezone.parse("2024-01-15T12:00:00Z"), - "state": "success", - }, - PreviousDagRunResult(dag_run=None), - None, - id="get_previous_dagrun_with_state", + "id": 1, + "timestamp": timezone.parse("2024-10-31T12:00:00Z"), + "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, + "created_dagruns": [], + } + ], + "type": "AssetEventsResult", + }, + client_mock=ClientMock( + method_path="asset_events.get", + kwargs={"uri": None, "name": "test"}, + response=AssetEventsResult( + asset_events=[ + AssetEventResponse( + id=1, + asset=AssetResponse(name="asset", uri="s3://bucket/obj", group="asset"), + created_dagruns=[], + timestamp=timezone.parse("2024-10-31T12:00:00Z"), + ) + ] ), - pytest.param( - GetTaskStates(dag_id="test_dag", task_group_id="test_group"), - { - "task_states": {"run_id": {"task1": "success", "task2": "failed"}}, - "type": "TaskStatesResult", - }, - "task_instances.get_task_states", - (), + ), + test_id="get_asset_events_by_name", + ), + RequestTestCase( + message=GetAssetEventByAssetAlias(alias_name="test_alias"), + expected_body={ + "asset_events": [ { - "dag_id": "test_dag", - "map_index": None, - "task_ids": None, - "logical_dates": None, - "run_ids": None, - "task_group_id": "test_group", - }, - TaskStatesResult(task_states={"run_id": {"task1": "success", "task2": "failed"}}), - None, - id="get_task_states", + "id": 1, + "timestamp": timezone.parse("2024-10-31T12:00:00Z"), + "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, + "created_dagruns": [], + } + ], + "type": "AssetEventsResult", + }, + client_mock=ClientMock( + method_path="asset_events.get", + kwargs={"alias_name": "test_alias"}, + response=AssetEventsResult( + asset_events=[ + AssetEventResponse( + id=1, + asset=AssetResponse(name="asset", uri="s3://bucket/obj", group="asset"), + created_dagruns=[], + timestamp=timezone.parse("2024-10-31T12:00:00Z"), + ) + ] ), - pytest.param( - GetXComSequenceItem( - key="test_key", - dag_id="test_dag", - run_id="test_run", - task_id="test_task", - offset=0, - ), - {"root": "test_value", "type": "XComSequenceIndexResult"}, - "xcoms.get_sequence_item", - ("test_dag", "test_run", "test_task", "test_key", 0), - {}, - XComSequenceIndexResult(root="test_value"), - None, - id="get_xcom_seq_item", + ), + test_id="get_asset_events_by_asset_alias", + ), + RequestTestCase( + message=ValidateInletsAndOutlets(ti_id=TI_ID), + expected_body={ + "inactive_assets": [{"name": "asset_name", "uri": "asset_uri", "type": "asset"}], + "type": "InactiveAssetsResult", + }, + client_mock=ClientMock( + method_path="task_instances.validate_inlets_and_outlets", + args=(TI_ID,), + response=InactiveAssetsResult( + inactive_assets=[AssetProfile(name="asset_name", uri="asset_uri", type="asset")] ), - pytest.param( - # TODO: This should be raise an exception, not returning an ErrorResponse. Fix this before PR - GetXComSequenceItem( - key="test_key", - dag_id="test_dag", - run_id="test_run", - task_id="test_task", - offset=2, - ), - {"error": "XCOM_NOT_FOUND", "detail": None, "type": "ErrorResponse"}, - "xcoms.get_sequence_item", - ("test_dag", "test_run", "test_task", "test_key", 2), - {}, - ErrorResponse(error=ErrorType.XCOM_NOT_FOUND), - None, - id="get_xcom_seq_item_not_found", + ), + test_id="validate_inlets_and_outlets", + ), + RequestTestCase( + message=GetPrevSuccessfulDagRun(ti_id=TI_ID), + expected_body={ + "data_interval_start": timezone.parse("2025-01-10T12:00:00Z"), + "data_interval_end": timezone.parse("2025-01-10T14:00:00Z"), + "start_date": timezone.parse("2025-01-10T12:00:00Z"), + "end_date": timezone.parse("2025-01-10T14:00:00Z"), + "type": "PrevSuccessfulDagRunResult", + }, + client_mock=ClientMock( + method_path="task_instances.get_previous_successful_dagrun", + args=(TI_ID,), + response=PrevSuccessfulDagRunResult( + start_date=timezone.parse("2025-01-10T12:00:00Z"), + end_date=timezone.parse("2025-01-10T14:00:00Z"), + data_interval_start=timezone.parse("2025-01-10T12:00:00Z"), + data_interval_end=timezone.parse("2025-01-10T14:00:00Z"), ), - pytest.param( - GetXComSequenceSlice( - key="test_key", + ), + test_id="get_prev_successful_dagrun", + ), + RequestTestCase( + message=TriggerDagRun( + dag_id="test_dag", + run_id="test_run", + conf={"key": "value"}, + logical_date=timezone.datetime(2025, 1, 1), + reset_dag_run=True, + ), + expected_body={"ok": True, "type": "OKResponse"}, + client_mock=ClientMock( + method_path="dag_runs.trigger", + args=("test_dag", "test_run", {"key": "value"}, timezone.datetime(2025, 1, 1), True), + response=OKResponse(ok=True), + ), + test_id="dag_run_trigger", + ), + RequestTestCase( + message=TriggerDagRun(dag_id="test_dag", run_id="test_run"), + expected_body={"error": "DAGRUN_ALREADY_EXISTS", "detail": None, "type": "ErrorResponse"}, + client_mock=ClientMock( + method_path="dag_runs.trigger", + args=("test_dag", "test_run", None, None, False), + response=ErrorResponse(error=ErrorType.DAGRUN_ALREADY_EXISTS), + ), + test_id="dag_run_trigger_already_exists", + ), + RequestTestCase( + message=GetDagRunState(dag_id="test_dag", run_id="test_run"), + expected_body={"state": "running", "type": "DagRunStateResult"}, + client_mock=ClientMock( + method_path="dag_runs.get_state", + args=("test_dag", "test_run"), + response=DagRunStateResult(state=DagRunState.RUNNING), + ), + test_id="get_dag_run_state", + ), + RequestTestCase( + message=GetPreviousDagRun( + dag_id="test_dag", + logical_date=timezone.parse("2024-01-15T12:00:00Z"), + ), + expected_body={ + "dag_run": { + "dag_id": "test_dag", + "run_id": "prev_run", + "logical_date": timezone.parse("2024-01-14T12:00:00Z"), + "run_type": "scheduled", + "start_date": timezone.parse("2024-01-15T12:00:00Z"), + "run_after": timezone.parse("2024-01-15T12:00:00Z"), + "consumed_asset_events": [], + "state": "success", + "data_interval_start": None, + "data_interval_end": None, + "end_date": None, + "clear_number": 0, + "conf": None, + }, + "type": "PreviousDagRunResult", + }, + client_mock=ClientMock( + method_path="dag_runs.get_previous", + kwargs={ + "dag_id": "test_dag", + "logical_date": timezone.parse("2024-01-15T12:00:00Z"), + "state": None, + }, + response=PreviousDagRunResult( + dag_run=DagRun( dag_id="test_dag", - run_id="test_run", - task_id="test_task", - start=None, - stop=None, - step=None, - include_prior_dates=False, - ), - {"root": ["foo", "bar"], "type": "XComSequenceSliceResult"}, - "xcoms.get_sequence_slice", - ("test_dag", "test_run", "test_task", "test_key", None, None, None, False), - {}, - XComSequenceSliceResult(root=["foo", "bar"]), - None, - id="get_xcom_seq_slice", + run_id="prev_run", + logical_date=timezone.parse("2024-01-14T12:00:00Z"), + run_type=DagRunType.SCHEDULED, + start_date=timezone.parse("2024-01-15T12:00:00Z"), + run_after=timezone.parse("2024-01-15T12:00:00Z"), + consumed_asset_events=[], + state=DagRunState.SUCCESS, + ) ), - pytest.param( - CreateHITLDetailPayload( - ti_id=TI_ID, - options=["Approve", "Reject"], - subject="This is subject", - body="This is body", - defaults=["Approve"], - multiple=False, - params={}, - ), - { - "ti_id": str(TI_ID), - "options": ["Approve", "Reject"], - "subject": "This is subject", - "body": "This is body", - "defaults": ["Approve"], - "multiple": False, - "params": {}, - "respondents": None, - "type": "HITLDetailRequestResult", - }, - "hitl.add_response", - (), - { - "body": "This is body", - "defaults": ["Approve"], - "multiple": False, - "options": ["Approve", "Reject"], - "params": {}, - "respondents": None, - "subject": "This is subject", - "ti_id": TI_ID, - }, - HITLDetailRequestResult( - ti_id=TI_ID, - options=["Approve", "Reject"], - subject="This is subject", - body="This is body", - defaults=["Approve"], - multiple=False, - params={}, - ), - None, - id="create_hitl_detail_payload", + ), + test_id="get_previous_dagrun", + ), + RequestTestCase( + message=GetPreviousDagRun( + dag_id="test_dag", + logical_date=timezone.parse("2024-01-15T12:00:00Z"), + state="success", + ), + expected_body={ + "dag_run": None, + "type": "PreviousDagRunResult", + }, + client_mock=ClientMock( + method_path="dag_runs.get_previous", + kwargs={ + "dag_id": "test_dag", + "logical_date": timezone.parse("2024-01-15T12:00:00Z"), + "state": "success", + }, + response=PreviousDagRunResult(dag_run=None), + ), + test_id="get_previous_dagrun_with_state", + ), + RequestTestCase( + message=GetTaskRescheduleStartDate(ti_id=TI_ID), + expected_body={ + "start_date": timezone.parse("2024-10-31T12:00:00Z"), + "type": "TaskRescheduleStartDate", + }, + client_mock=ClientMock( + method_path="task_instances.get_reschedule_start_date", + args=(TI_ID, 1), + response=TaskRescheduleStartDate(start_date=timezone.parse("2024-10-31T12:00:00Z")), + ), + test_id="get_task_reschedule_start_date", + ), + RequestTestCase( + message=GetTICount(dag_id="test_dag", task_ids=["task1", "task2"]), + expected_body={"count": 2, "type": "TICount"}, + client_mock=ClientMock( + method_path="task_instances.get_count", + kwargs={ + "dag_id": "test_dag", + "map_index": None, + "logical_dates": None, + "run_ids": None, + "states": None, + "task_group_id": None, + "task_ids": ["task1", "task2"], + }, + response=TICount(count=2), + ), + test_id="get_ti_count", + ), + RequestTestCase( + message=GetDRCount(dag_id="test_dag", states=["success", "failed"]), + expected_body={"count": 2, "type": "DRCount"}, + client_mock=ClientMock( + method_path="dag_runs.get_count", + kwargs={ + "dag_id": "test_dag", + "logical_dates": None, + "run_ids": None, + "states": ["success", "failed"], + }, + response=DRCount(count=2), + ), + test_id="get_dr_count", + ), + RequestTestCase( + message=GetTaskStates(dag_id="test_dag", task_group_id="test_group"), + expected_body={ + "task_states": {"run_id": {"task1": "success", "task2": "failed"}}, + "type": "TaskStatesResult", + }, + client_mock=ClientMock( + method_path="task_instances.get_task_states", + kwargs={ + "dag_id": "test_dag", + "map_index": None, + "task_ids": None, + "logical_dates": None, + "run_ids": None, + "task_group_id": "test_group", + }, + response=TaskStatesResult(task_states={"run_id": {"task1": "success", "task2": "failed"}}), + ), + test_id="get_task_states", + ), + RequestTestCase( + message=GetXComSequenceItem( + key="test_key", + dag_id="test_dag", + run_id="test_run", + task_id="test_task", + offset=0, + ), + expected_body={"root": "test_value", "type": "XComSequenceIndexResult"}, + client_mock=ClientMock( + method_path="xcoms.get_sequence_item", + args=("test_dag", "test_run", "test_task", "test_key", 0), + response=XComSequenceIndexResult(root="test_value"), + ), + test_id="get_xcom_seq_item", + ), + RequestTestCase( + message=GetXComSequenceItem( + key="test_key", + dag_id="test_dag", + run_id="test_run", + task_id="test_task", + offset=2, + ), + expected_body={"error": "XCOM_NOT_FOUND", "detail": None, "type": "ErrorResponse"}, + client_mock=ClientMock( + method_path="xcoms.get_sequence_item", + args=("test_dag", "test_run", "test_task", "test_key", 2), + response=ErrorResponse(error=ErrorType.XCOM_NOT_FOUND), + ), + test_id="get_xcom_seq_item_not_found", + ), + RequestTestCase( + message=GetXComSequenceSlice( + key="test_key", + dag_id="test_dag", + run_id="test_run", + task_id="test_task", + start=None, + stop=None, + step=None, + include_prior_dates=False, + ), + expected_body={"root": ["foo", "bar"], "type": "XComSequenceSliceResult"}, + client_mock=ClientMock( + method_path="xcoms.get_sequence_slice", + args=("test_dag", "test_run", "test_task", "test_key", None, None, None, False), + response=XComSequenceSliceResult(root=["foo", "bar"]), + ), + test_id="get_xcom_seq_slice", + ), + RequestTestCase( + message=TaskState(state=TaskInstanceState.SKIPPED, end_date=timezone.parse("2024-10-31T12:00:00Z")), + test_id="patch_task_instance_to_skipped", + ), + RequestTestCase( + message=CreateHITLDetailPayload( + ti_id=TI_ID, + options=["Approve", "Reject"], + subject="This is subject", + body="This is body", + defaults=["Approve"], + multiple=False, + params={}, + ), + expected_body={ + "ti_id": str(TI_ID), + "options": ["Approve", "Reject"], + "subject": "This is subject", + "body": "This is body", + "defaults": ["Approve"], + "multiple": False, + "params": {}, + "respondents": None, + "type": "HITLDetailRequestResult", + }, + client_mock=ClientMock( + method_path="hitl.add_response", + kwargs={ + "body": "This is body", + "defaults": ["Approve"], + "multiple": False, + "options": ["Approve", "Reject"], + "params": {}, + "respondents": None, + "subject": "This is subject", + "ti_id": TI_ID, + }, + response=HITLDetailRequestResult( + ti_id=TI_ID, + options=["Approve", "Reject"], + subject="This is subject", + body="This is body", + defaults=["Approve"], + multiple=False, + params={}, ), - ], - ) + ), + test_id="create_hitl_detail_payload", + ), + RequestTestCase( + message=MaskSecret(value=["iter1", "iter2", {"key": "value"}], name="test_secret"), + mask_secret_args=(["iter1", "iter2", {"key": "value"}], "test_secret"), + test_id="mask_secret_list", + ), + RequestTestCase( + message=GetXComCount(key="test_key", dag_id="test_dag", run_id="test_run", task_id="test_task"), + expected_body={"len": 5, "type": "XComLengthResponse"}, + client_mock=ClientMock( + method_path="xcoms.head", + args=("test_dag", "test_run", "test_task", "test_key"), + response=XComCountResponse(len=5), + ), + test_id="get_xcom_count", + ), + RequestTestCase( + message=ResendLoggingFD(), + expected_body={"fds": mock.ANY, "type": "SentFDs"}, + test_id="resend_logging_fd", + ), + RequestTestCase( + message=SkipDownstreamTasks(tasks=["task1", "task2"]), + client_mock=ClientMock( + method_path="task_instances.skip_downstream_tasks", + args=(TI_ID, SkipDownstreamTasks(tasks=["task1", "task2"])), + response=OKResponse(ok=True), + ), + test_id="skip_downstream_tasks", + ), +] + + +class TestHandleRequest: + @pytest.fixture + def watched_subprocess(self, mocker): + read_end, write_end = socket.socketpair() + + subprocess = ActivitySubprocess( + process_log=mocker.MagicMock(), + id=TI_ID, + pid=12345, + stdin=write_end, + client=mocker.Mock(), + process=mocker.Mock(), + ) + + return subprocess, read_end + + @patch("airflow.sdk.execution_time.supervisor.mask_secret") + @pytest.mark.parametrize("test_case", REQUEST_TEST_CASES, ids=lambda tc: tc.test_id) def test_handle_requests( self, mock_mask_secret, watched_subprocess, mocker, time_machine, - message, - expected_body, - client_attr_path, - method_arg, - method_kwarg, - mock_response, - mask_secret_args, + test_case: RequestTestCase, ): """ Test handling of different messages to the subprocess. For any new message type, add a @@ -2055,11 +2103,19 @@ def test_handle_requests( 3. Checks that the buffer is updated with the expected response. 4. Verifies that the response is correctly decoded. """ + # Extract values from test_case + message = test_case.message + expected_body = test_case.expected_body + client_mock = test_case.client_mock + mask_secret_args = test_case.mask_secret_args + + # Rest of test implementation (copied from original) watched_subprocess, read_socket = watched_subprocess # Mock the client method. E.g. `client.variables.get` or `client.connections.get` - mock_client_method = attrgetter(client_attr_path)(watched_subprocess.client) - mock_client_method.return_value = mock_response + if client_mock: + mock_client_method = attrgetter(client_mock.method_path)(watched_subprocess.client) + mock_client_method.return_value = client_mock.response # Simulate the generator generator = watched_subprocess.handle_requests(log=mocker.Mock()) @@ -2069,14 +2125,14 @@ def test_handle_requests( req_frame = _RequestFrame(id=randint(1, 2**32 - 1), body=message.model_dump()) generator.send(req_frame) - if mask_secret_args: + if mask_secret_args is not None: mock_mask_secret.assert_called_with(*mask_secret_args) time_machine.move_to(timezone.datetime(2024, 10, 31), tick=False) # Verify the correct client method was called - if client_attr_path: - mock_client_method.assert_called_once_with(*method_arg, **method_kwarg) + if client_mock: + mock_client_method.assert_called_once_with(*client_mock.args, **client_mock.kwargs) # Read response from the read end of the socket read_socket.settimeout(0.1) @@ -2093,9 +2149,35 @@ def test_handle_requests( # This is important because the subprocess/task runner will read the response # and deserialize it to the correct message type - if frame.body is not None: - decoder = CommsDecoder(socket=None).body_decoder - assert decoder.validate_python(frame.body) == mock_response + if frame.body is not None and client_mock: + decoder = CommsDecoder(socket=None).body_decoder # type: ignore[var-annotated, arg-type] + assert decoder.validate_python(frame.body) == client_mock.response + + def test_all_to_supervisor_messages_are_covered(self): + """Ensure all ToSupervisor message types have test coverage.""" + + # Extract the individual message types from the Union + union_type = ToSupervisor.__args__[0] + supervisor_message_types = set(union_type.__args__) + + # Get all message types covered in our test cases + tested_message_types = {type(test_case.message) for test_case in REQUEST_TEST_CASES} + + # Message types which are excluded for a good reason + excluded_message_types = { + GetHITLDetailResponse, # Only used in Triggerer, not needed in worker + UpdateHITLDetail, # Only used in Triggerer, not needed in worker + } + + untested_types = supervisor_message_types - tested_message_types - excluded_message_types + + # Assert all types are covered + assert not untested_types, ( + f"Missing test coverage for {len(untested_types)}/{len(supervisor_message_types)} " + f"ToSupervisor message types:\n" + + "\n".join(f" - {t.__name__}" for t in sorted(untested_types, key=lambda x: x.__name__)) + + "\n\nPlease add test cases to REQUEST_TEST_CASES." + ) def test_handle_requests_api_server_error(self, watched_subprocess, mocker): """Test that API server errors are properly handled and sent back to the task."""