diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index e10bcd1d5bde9..2bf035c0392c1 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -1826,7 +1826,7 @@ def _disable_redact(request: pytest.FixtureRequest, mocker): ) mocked_redact = mocker.patch(target) - mocked_redact.side_effect = lambda item, name=None, max_depth=None: item + mocked_redact.side_effect = lambda item, *args, **kwargs: item with pytest.MonkeyPatch.context() as mp_ctx: mp_ctx.setattr(settings, "MASK_SECRETS_IN_LOGS", False) yield diff --git a/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py b/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py index a3dd26d686bca..6ab450eb12151 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py +++ b/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py @@ -830,7 +830,7 @@ def from_masker(cls, other: SecretsMasker) -> OpenLineageRedactor: instance.replacer = other.replacer return instance - def _redact(self, item: Redactable, name: str | None, depth: int, max_depth: int) -> Redacted: + def _redact(self, item: Redactable, name: str | None, depth: int, max_depth: int, **kwargs) -> Redacted: # type: ignore[override] if AIRFLOW_V_3_0_PLUS: # Keep compatibility for Airflow 2.x, remove when Airflow 3.0 is the minimum version class AirflowContextDeprecationWarning(UserWarning): @@ -886,7 +886,7 @@ class AirflowContextDeprecationWarning(UserWarning): ), ) return item - return super()._redact(item, name, depth, max_depth) + return super()._redact(item, name, depth, max_depth, **kwargs) except Exception as exc: log.warning("Unable to redact %r. 
Error was: %s: %s", item, type(exc).__name__, exc) return item diff --git a/task-sdk/src/airflow/sdk/execution_time/secrets_masker.py b/task-sdk/src/airflow/sdk/execution_time/secrets_masker.py index 04f2c0e8e821b..273095401a1df 100644 --- a/task-sdk/src/airflow/sdk/execution_time/secrets_masker.py +++ b/task-sdk/src/airflow/sdk/execution_time/secrets_masker.py @@ -20,6 +20,8 @@ import collections.abc import contextlib +import functools +import inspect import logging import re import sys @@ -118,9 +120,11 @@ def mask_secret(secret: str | dict | Iterable, name: str | None = None) -> None: _secrets_masker().add_mask(secret, name) -def redact(value: Redactable, name: str | None = None, max_depth: int | None = None) -> Redacted: - """Redact any secrets found in ``value``.""" - return _secrets_masker().redact(value, name, max_depth) +def redact( + value: Redactable, name: str | None = None, max_depth: int | None = None, replacement: str = "***" +) -> Redacted: + """Redact any secrets found in ``value`` with the given replacement.""" + return _secrets_masker().redact(value, name, max_depth, replacement=replacement) @overload @@ -198,6 +202,29 @@ def __init__(self): super().__init__() self.patterns = set() + @classmethod + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + + if cls._redact is not SecretsMasker._redact: + sig = inspect.signature(cls._redact) + # Compat for older versions of the OpenLineage plugin that subclass this class -- call their + # ``_redact`` override without the ``replacement`` keyword argument, which it does not accept. + for param in sig.parameters.values(): + if param.name == "replacement" or param.kind == param.VAR_KEYWORD: + break + else: + # This ``else`` branch only runs if the loop above finished without a ``break``. + + f = cls._redact + + @functools.wraps(f) + def _redact(*args, replacement: str = "***", **kwargs): + return f(*args, **kwargs) + + cls._redact = _redact + ... 
+ @cached_property def _record_attrs_to_ignore(self) -> Iterable[str]: # Doing log.info(..., extra={'foo': 2}) sets extra properties on @@ -251,21 +278,35 @@ def filter(self, record) -> bool: # Default on `max_depth` is to support versions of the OpenLineage plugin (not the provider) which called # this function directly. New versions of that provider, and this class itself call it with a value - def _redact_all(self, item: Redactable, depth: int, max_depth: int = MAX_RECURSION_DEPTH) -> Redacted: + def _redact_all( + self, + item: Redactable, + depth: int, + max_depth: int = MAX_RECURSION_DEPTH, + *, + replacement: str = "***", + ) -> Redacted: if depth > max_depth or isinstance(item, str): - return "***" + return replacement if isinstance(item, dict): return { - dict_key: self._redact_all(subval, depth + 1, max_depth) for dict_key, subval in item.items() + dict_key: self._redact_all(subval, depth + 1, max_depth, replacement=replacement) + for dict_key, subval in item.items() } if isinstance(item, (tuple, set)): # Turn set in to tuple! - return tuple(self._redact_all(subval, depth + 1, max_depth) for subval in item) + return tuple( + self._redact_all(subval, depth + 1, max_depth, replacement=replacement) for subval in item + ) if isinstance(item, list): - return list(self._redact_all(subval, depth + 1, max_depth) for subval in item) + return list( + self._redact_all(subval, depth + 1, max_depth, replacement=replacement) for subval in item + ) return item - def _redact(self, item: Redactable, name: str | None, depth: int, max_depth: int) -> Redacted: + def _redact( + self, item: Redactable, name: str | None, depth: int, max_depth: int, replacement: str = "***" + ) -> Redacted: # Avoid spending too much effort on redacting on deeply nested # structures. This also avoid infinite recursion if a structure has # reference to self. 
@@ -273,37 +314,49 @@ def _redact(self, item: Redactable, name: str | None, depth: int, max_depth: int return item try: if name and should_hide_value_for_key(name): - return self._redact_all(item, depth, max_depth) + return self._redact_all(item, depth, max_depth, replacement=replacement) if isinstance(item, dict): to_return = { - dict_key: self._redact(subval, name=dict_key, depth=(depth + 1), max_depth=max_depth) + dict_key: self._redact( + subval, name=dict_key, depth=(depth + 1), max_depth=max_depth, replacement=replacement + ) for dict_key, subval in item.items() } return to_return if isinstance(item, Enum): - return self._redact(item=item.value, name=name, depth=depth, max_depth=max_depth) + return self._redact( + item=item.value, name=name, depth=depth, max_depth=max_depth, replacement=replacement + ) if _is_v1_env_var(item): tmp = item.to_dict() if should_hide_value_for_key(tmp.get("name", "")) and "value" in tmp: - tmp["value"] = "***" + tmp["value"] = replacement else: - return self._redact(item=tmp, name=name, depth=depth, max_depth=max_depth) + return self._redact( + item=tmp, name=name, depth=depth, max_depth=max_depth, replacement=replacement + ) return tmp if isinstance(item, str): if self.replacer: # We can't replace specific values, but the key-based redacting # can still happen, so we can't short-circuit, we need to walk # the structure. - return self.replacer.sub("***", str(item)) + return self.replacer.sub(replacement, str(item)) return item if isinstance(item, (tuple, set)): # Turn set in to tuple! 
return tuple( - self._redact(subval, name=None, depth=(depth + 1), max_depth=max_depth) for subval in item + self._redact( + subval, name=None, depth=(depth + 1), max_depth=max_depth, replacement=replacement + ) + for subval in item ) if isinstance(item, list): return [ - self._redact(subval, name=None, depth=(depth + 1), max_depth=max_depth) for subval in item + self._redact( + subval, name=None, depth=(depth + 1), max_depth=max_depth, replacement=replacement + ) + for subval in item ] return item # I think this should never happen, but it does not hurt to leave it just in case @@ -325,10 +378,12 @@ def _merge( self, new_item: Redacted, old_item: Redactable, + *, name: str | None, depth: int, max_depth: int, force_sensitive: bool = False, + replacement: str, ) -> Redacted: """Merge a redacted item with its original unredacted counterpart.""" if depth > max_depth: @@ -353,6 +408,7 @@ def _merge( depth=depth + 1, max_depth=max_depth, force_sensitive=is_sensitive, + replacement=replacement, ) else: merged[key] = new_item[key] @@ -374,6 +430,7 @@ def _merge( depth=depth + 1, max_depth=max_depth, force_sensitive=is_sensitive, + replacement=replacement, ) ) else: @@ -398,7 +455,13 @@ def _merge( except (TypeError, AttributeError, ValueError): return new_item - def redact(self, item: Redactable, name: str | None = None, max_depth: int | None = None) -> Redacted: + def redact( + self, + item: Redactable, + name: str | None = None, + max_depth: int | None = None, + replacement: str = "***", + ) -> Redacted: """ Redact an any secrets found in ``item``, if it is a string. @@ -406,17 +469,24 @@ def redact(self, item: Redactable, name: str | None = None, max_depth: int | Non :func:`should_hide_value_for_key`) then all string values in the item is redacted. 
""" - return self._redact(item, name, depth=0, max_depth=max_depth or self.MAX_RECURSION_DEPTH) + return self._redact( + item, name, depth=0, max_depth=max_depth or self.MAX_RECURSION_DEPTH, replacement=replacement + ) def merge( - self, new_item: Redacted, old_item: Redactable, name: str | None = None, max_depth: int | None = None + self, + new_item: Redacted, + old_item: Redactable, + name: str | None = None, + max_depth: int | None = None, + replacement: str = "***", ) -> Redacted: """ Merge a redacted item with its original unredacted counterpart. Takes a user-modified redacted item and merges it with the original unredacted item. - For sensitive fields that still contain "***" (unchanged), the original value is restored. - For fields that have been updated, the new value is preserved. + For sensitive fields that still contain the ``replacement`` marker (``"***"`` by default), the + original value is restored. For fields that have been updated, the new value is preserved. """ return self._merge( new_item, @@ -425,6 +495,7 @@ def merge( depth=0, max_depth=max_depth or self.MAX_RECURSION_DEPTH, force_sensitive=False, + replacement=replacement, ) @cached_property diff --git a/task-sdk/tests/task_sdk/definitions/test_secrets_masker.py b/task-sdk/tests/task_sdk/definitions/test_secrets_masker.py index 2e95d60b94397..5dd24a800a3f2 100644 --- a/task-sdk/tests/task_sdk/definitions/test_secrets_masker.py +++ b/task-sdk/tests/task_sdk/definitions/test_secrets_masker.py @@ -287,6 +287,26 @@ def test_redact(self, patterns, name, value, expected): assert filt.redact(value, name) == expected + @pytest.mark.parametrize( + ("name", "value", "expected"), + [ + ("api_key", "pass", "*️⃣*️⃣*️⃣"), + ("api_key", ("pass",), ("*️⃣*️⃣*️⃣",)), + (None, {"data": {"secret": "secret"}}, {"data": {"secret": "*️⃣*️⃣*️⃣"}}), + # Non string dict keys + (None, {1: {"secret": "secret"}}, {1: {"secret": "*️⃣*️⃣*️⃣"}}), + ( + "api_key", + {"other": "innoent", "nested": ["x", "y"]}, 
{"other": "*️⃣*️⃣*️⃣", "nested": ["*️⃣*️⃣*️⃣", "*️⃣*️⃣*️⃣"]}, + ), + ], + ) + def test_redact_replacement(self, name, value, expected): + filt = SecretsMasker() + + assert filt.redact(value, name, replacement="*️⃣*️⃣*️⃣") == expected + def test_redact_filehandles(self, caplog): filt = SecretsMasker() with open("/dev/null", "w") as handle: @@ -699,7 +719,7 @@ def test_redact_all_directly(self): "nested": {"tuple": ("a", "b", "c"), "set": {"x", "y", "z"}}, } - result = secrets_masker._redact_all(test_data, depth=0) + result = secrets_masker._redact_all(test_data, depth=0, replacement="***") assert result["string"] == "***" assert result["number"] == 12345