diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index d0d7b7d9eb1fe..0328525b0c816 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -1034,7 +1034,7 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): var = self.client.variables.get(msg.key) if isinstance(var, VariableResponse): if var.value: - mask_secret(var.value) + mask_secret(var.value, var.key) var_result = VariableResult.from_variable_response(var) resp = var_result dump_opts = {"exclude_unset": True} diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 1e6aec6bed06d..1d19d19021e96 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -92,7 +92,6 @@ VariableResult, XComResult, ) -from airflow.sdk.execution_time.secrets_masker import SecretsMasker from airflow.sdk.execution_time.supervisor import ( BUFFER_SIZE, ActivitySubprocess, @@ -977,9 +976,17 @@ def watched_subprocess(self, mocker): return subprocess, read_end - @patch("airflow.sdk.execution_time.secrets_masker._secrets_masker") + @patch("airflow.sdk.execution_time.supervisor.mask_secret") @pytest.mark.parametrize( - ["message", "expected_buffer", "client_attr_path", "method_arg", "method_kwarg", "mock_response"], + [ + "message", + "expected_buffer", + "client_attr_path", + "method_arg", + "method_kwarg", + "mock_response", + "mask_secret_args", + ], [ pytest.param( GetConnection(conn_id="test_conn"), @@ -988,8 +995,19 @@ def watched_subprocess(self, mocker): ("test_conn",), {}, ConnectionResult(conn_id="test_conn", conn_type="mysql"), + None, id="get_connection", ), + pytest.param( + GetConnection(conn_id="test_conn"), + b'{"conn_id":"test_conn","conn_type":"mysql","password":"password","type":"ConnectionResult"}\n', + "connections.get", + ("test_conn",), + {}, + ConnectionResult(conn_id="test_conn", conn_type="mysql", password="password"), + ["password"], + id="get_connection_with_password", + ), pytest.param( GetConnection(conn_id="test_conn"), b'{"conn_id":"test_conn","conn_type":"mysql","schema":"mysql","type":"ConnectionResult"}\n', @@ -997,6 +1015,7 @@ def watched_subprocess(self, mocker): ("test_conn",), {}, ConnectionResult(conn_id="test_conn", conn_type="mysql", schema="mysql"), # type: ignore[call-arg] + None, id="get_connection_with_alias", ), pytest.param( @@ -1006,6 +1025,7 @@ def watched_subprocess(self, mocker): ("test_key",), {}, VariableResult(key="test_key", value="test_value"), + ["test_value", "test_key"], id="get_variable", ), pytest.param( @@ -1015,6 +1035,7 @@ def watched_subprocess(self, mocker): ("test_key", "test_value", "test_description"), {}, OKResponse(ok=True), + None, id="set_variable", ), pytest.param( @@ -1024,6 +1045,7 @@ def watched_subprocess(self, mocker): ("test_key",), {}, OKResponse(ok=True), + None, id="delete_variable", ), pytest.param( @@ -1033,6 +1055,7 @@ def watched_subprocess(self, mocker): (TI_ID, DeferTask(next_method="execute_callback", classpath="my-classpath")), {}, "", + None, id="patch_task_instance_to_deferred", ), pytest.param( @@ -1051,6 +1074,7 @@ def watched_subprocess(self, mocker): ), {}, "", + None, id="patch_task_instance_to_up_for_reschedule", ), pytest.param( @@ -1060,6 +1084,7 @@ def watched_subprocess(self, mocker): ("test_dag", "test_run", "test_task", "test_key", None, False), {}, XComResult(key="test_key", value="test_value"), + None, id="get_xcom", ), pytest.param( @@ -1071,6 +1096,7 @@ def watched_subprocess(self, mocker): ("test_dag", "test_run", "test_task", "test_key", 2, False), {}, XComResult(key="test_key", value="test_value"), + None, id="get_xcom_map_index", ), pytest.param( @@ -1080,6 +1106,7 @@ def watched_subprocess(self, mocker): ("test_dag", "test_run", "test_task", "test_key", None, False), {}, XComResult(key="test_key", value=None, type="XComResult"), + None, id="get_xcom_not_found", ), pytest.param( @@ -1095,6 +1122,7 @@ def watched_subprocess(self, mocker): ("test_dag", "test_run", "test_task", "test_key", None, True), {}, XComResult(key="test_key", value=None, type="XComResult"), + None, id="get_xcom_include_prior_dates", ), pytest.param( @@ -1118,6 +1146,7 @@ def watched_subprocess(self, mocker): ), {}, OKResponse(ok=True), + None, id="set_xcom", ), pytest.param( @@ -1142,6 +1171,7 @@ def watched_subprocess(self, mocker): ), {}, OKResponse(ok=True), + None, id="set_xcom_with_map_index", ), pytest.param( @@ -1167,6 +1197,7 @@ def watched_subprocess(self, mocker): ), {}, OKResponse(ok=True), + None, id="set_xcom_with_map_index_and_mapped_length", ), pytest.param( @@ -1188,6 +1219,7 @@ def watched_subprocess(self, mocker): ), {}, OKResponse(ok=True), + None, id="delete_xcom", ), # we aren't adding all states under TaskInstanceState here, because this test's scope is only to check @@ -1199,6 +1231,7 @@ def watched_subprocess(self, mocker): (), {}, "", + None, id="patch_task_instance_to_skipped", ), pytest.param( @@ -1214,6 +1247,7 @@ def watched_subprocess(self, mocker): "rendered_map_index": "test retry task", }, "", + None, id="up_for_retry", ), pytest.param( @@ -1223,6 +1257,7 @@ def watched_subprocess(self, mocker): (TI_ID, {"field1": "rendered_value1", "field2": "rendered_value2"}), {}, OKResponse(ok=True), + None, id="set_rtif", ), pytest.param( @@ -1232,6 +1267,7 @@ def watched_subprocess(self, mocker): [], {"name": "asset"}, AssetResult(name="asset", uri="s3://bucket/obj", group="asset"), + None, id="get_asset_by_name", ), pytest.param( @@ -1241,6 +1277,7 @@ def watched_subprocess(self, mocker): [], {"uri": "s3://bucket/obj"}, AssetResult(name="asset", uri="s3://bucket/obj", group="asset"), + None, id="get_asset_by_uri", ), pytest.param( @@ -1263,6 +1300,7 @@ def watched_subprocess(self, mocker): ) ] ), + None, id="get_asset_events_by_uri_and_name", ), pytest.param( @@ -1285,6 +1323,7 @@ def watched_subprocess(self, mocker): ) ] ), + None, id="get_asset_events_by_uri", ), pytest.param( @@ -1307,6 +1346,7 @@ def watched_subprocess(self, mocker): ) ] ), + None, id="get_asset_events_by_name", ), pytest.param( @@ -1329,6 +1369,7 @@ def watched_subprocess(self, mocker): ) ] ), + None, id="get_asset_events_by_asset_alias", ), pytest.param( @@ -1346,6 +1387,7 @@ def watched_subprocess(self, mocker): "rendered_map_index": "test success task", }, "", + None, id="succeed_task", ), pytest.param( @@ -1364,6 +1406,7 @@ def watched_subprocess(self, mocker): data_interval_start=timezone.parse("2025-01-10T12:00:00Z"), data_interval_end=timezone.parse("2025-01-10T14:00:00Z"), ), + None, id="get_prev_successful_dagrun", ), pytest.param( @@ -1379,6 +1422,7 @@ def watched_subprocess(self, mocker): ("test_dag", "test_run", {"key": "value"}, timezone.datetime(2025, 1, 1), True), {}, OKResponse(ok=True), + None, id="dag_run_trigger", ), pytest.param( @@ -1388,6 +1432,7 @@ def watched_subprocess(self, mocker): ("test_dag", "test_run", None, None, False), {}, ErrorResponse(error=ErrorType.DAGRUN_ALREADY_EXISTS), + None, id="dag_run_trigger_already_exists", ), pytest.param( @@ -1397,6 +1442,7 @@ def watched_subprocess(self, mocker): ("test_dag", "test_run"), {}, DagRunStateResult(state=DagRunState.RUNNING), + None, id="get_dag_run_state", ), pytest.param( @@ -1406,6 +1452,7 @@ def watched_subprocess(self, mocker): (TI_ID, 1), {}, TaskRescheduleStartDate(start_date=timezone.parse("2024-10-31T12:00:00Z")), + None, id="get_task_reschedule_start_date", ), pytest.param( @@ -1423,6 +1470,7 @@ def watched_subprocess(self, mocker): "task_ids": ["task1", "task2"], }, TICount(count=2), + None, id="get_ti_count", ), pytest.param( @@ -1437,6 +1485,7 @@ def watched_subprocess(self, mocker): "states": ["success", "failed"], }, DRCount(count=2), + None, id="get_dr_count", ), pytest.param( @@ -1453,6 +1502,7 @@ def watched_subprocess(self, mocker): "task_group_id": "test_group", }, TaskStatesResult(task_states={"run_id": {"task1": "success", "task2": "failed"}}), + None, id="get_task_states", ), pytest.param( @@ -1468,6 +1518,7 @@ def watched_subprocess(self, mocker): ("test_dag", "test_run", "test_task", "test_key", 0), {}, XComResult(key="test_key", value="test_value"), + None, id="get_xcom_seq_item", ), pytest.param( @@ -1483,13 +1534,14 @@ def watched_subprocess(self, mocker): ("test_dag", "test_run", "test_task", "test_key", 2), {}, ErrorResponse(error=ErrorType.XCOM_NOT_FOUND), + None, id="get_xcom_seq_item_not_found", ), ], ) def test_handle_requests( self, - mock_secrets_masker, + mock_mask_secret, watched_subprocess, mocker, time_machine, @@ -1499,6 +1551,7 @@ def test_handle_requests( method_arg, method_kwarg, mock_response, + mask_secret_args, ): """ Test handling of different messages to the subprocess. For any new message type, add a @@ -1511,7 +1564,6 @@ def test_handle_requests( 3. Checks that the buffer is updated with the expected response. 4. Verifies that the response is correctly decoded. """ - mock_secrets_masker.return_value = SecretsMasker() watched_subprocess, read_socket = watched_subprocess # Mock the client method. E.g. `client.variables.get` or `client.connections.get` @@ -1524,6 +1576,10 @@ def test_handle_requests( next(generator) msg = message.model_dump_json().encode() + b"\n" generator.send(msg) + + if mask_secret_args: + mock_mask_secret.assert_called_with(*mask_secret_args) + time_machine.move_to(timezone.datetime(2024, 10, 31), tick=False) # Verify the correct client method was called