diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index d0f0eae04c88c..41fb1d99da0d8 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -49,7 +49,7 @@ from __future__ import annotations import itertools -from collections.abc import Iterable, Iterator +from collections.abc import Iterator from datetime import datetime from functools import cached_property from pathlib import Path @@ -863,8 +863,9 @@ class MaskSecret(BaseModel): # This is needed since calls to `mask_secret` in the Task process will otherwise only add the mask value # to the child process, but the redaction happens in the parent. - - value: str | dict | Iterable + # We cannot use `string | Iterable | dict here` (would be more intuitive) because bug in Pydantic + # https://github.com/pydantic/pydantic/issues/9541 turns iterable into a ValidatorIterator + value: JsonValue name: str | None = None type: Literal["MaskSecret"] = "MaskSecret" diff --git a/task-sdk/src/airflow/sdk/execution_time/secrets_masker.py b/task-sdk/src/airflow/sdk/execution_time/secrets_masker.py index 9d75c097bd22a..391bf49701246 100644 --- a/task-sdk/src/airflow/sdk/execution_time/secrets_masker.py +++ b/task-sdk/src/airflow/sdk/execution_time/secrets_masker.py @@ -37,6 +37,10 @@ overload, ) +# We have to import this here, as it is used in the type annotations at runtime even if it seems it is +# not used in the code. This is because Pydantic uses type at runtime to validate the types of the fields. +from pydantic import JsonValue # noqa: TC002 + from airflow import settings if TYPE_CHECKING: @@ -105,7 +109,7 @@ def should_hide_value_for_key(name): return False -def mask_secret(secret: str | dict | Iterable, name: str | None = None) -> None: +def mask_secret(secret: JsonValue, name: str | None = None) -> None: """ Mask a secret from appearing in the logs. @@ -475,7 +479,7 @@ def _adaptations(self, secret: str) -> Generator[str, None, None]: else: yield secret_or_secrets - def add_mask(self, secret: str | dict | Iterable, name: str | None = None): + def add_mask(self, secret: JsonValue, name: str | None = None): """Add a new secret to be masked to this filter instance.""" if isinstance(secret, dict): for k, v in secret.items(): diff --git a/task-sdk/tests/task_sdk/definitions/test_secrets_masker.py b/task-sdk/tests/task_sdk/definitions/test_secrets_masker.py index e9599af4995de..e03266bbf5ff8 100644 --- a/task-sdk/tests/task_sdk/definitions/test_secrets_masker.py +++ b/task-sdk/tests/task_sdk/definitions/test_secrets_masker.py @@ -30,6 +30,7 @@ import pytest from airflow.models import Connection +from airflow.sdk.execution_time.comms import MaskSecret from airflow.sdk.execution_time.secrets_masker import ( RedactedIO, SecretsMasker, @@ -545,6 +546,36 @@ def test_add_mask_short_secrets_and_skip_keywords( if should_be_masked: assert filt.replacer is not None + @pytest.mark.parametrize( + "object_to_mask", + [ + { + "key_path": "/files/airflow-breeze-config/keys2/keys.json", + "scope": "https://www.googleapis.com/auth/cloud-platform", + "project": "project_id", + "num_retries": 6, + }, + ["iter1", "iter2", {"key": "value"}], + "string", + { + "key1": "value1", + }, + ], + ) + def test_mask_secret_with_objects(self, object_to_mask): + mask_secret_object = MaskSecret(value=object_to_mask, name="test_secret") + assert mask_secret_object.value == object_to_mask + + def test_mask_secret_with_list(self): + example_dict = ["test"] + mask_secret_object = MaskSecret(value=example_dict, name="test_secret") + assert mask_secret_object.value == example_dict + + def test_mask_secret_with_iterable(self): + example_dict = ["test"] + mask_secret_object = MaskSecret(value=example_dict, name="test_secret") + assert mask_secret_object.value == example_dict + class TestStructuredVsUnstructuredMasking: def test_structured_sensitive_fields_always_masked(self):