diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py index ec2ad8699e953..d680dc11bcc6c 100644 --- a/airflow-core/src/airflow/dag_processing/processor.py +++ b/airflow-core/src/airflow/dag_processing/processor.py @@ -43,6 +43,7 @@ GetPreviousDagRun, GetPrevSuccessfulDagRun, GetVariable, + MaskSecret, OKResponse, PreviousDagRunResult, PrevSuccessfulDagRunResult, @@ -104,7 +105,8 @@ class DagFileParsingResult(BaseModel): | PutVariable | DeleteVariable | GetPrevSuccessfulDagRun - | GetPreviousDagRun, + | GetPreviousDagRun + | MaskSecret, Field(discriminator="type"), ] @@ -427,6 +429,10 @@ def _handle_request(self, msg: ToManager, log: FilteringBoundLogger, req_id: int dagrun_result = PrevSuccessfulDagRunResult.from_dagrun_response(dagrun_resp) resp = dagrun_result dump_opts = {"exclude_unset": True} + elif isinstance(msg, MaskSecret): + from airflow.sdk.execution_time.secrets_masker import mask_secret + + mask_secret(msg.value, msg.name) else: log.error("Unhandled request", msg=msg) self.send_msg( diff --git a/task-sdk/docs/api.rst b/task-sdk/docs/api.rst index 42fcc76dca5a3..976de591bf82b 100644 --- a/task-sdk/docs/api.rst +++ b/task-sdk/docs/api.rst @@ -125,11 +125,12 @@ I/O Helpers Execution Time Components ------------------------- .. rubric:: Context + .. autoapiclass:: airflow.sdk.Context -.. autoapimodule:: airflow.sdk.execution_time.context - :members: - :undoc-members: +.. rubric:: Logging + +.. autofunction:: airflow.sdk.log.mask_secret Everything else --------------- diff --git a/task-sdk/src/airflow/sdk/definitions/connection.py b/task-sdk/src/airflow/sdk/definitions/connection.py index 1ed851e32e071..cb7cce6445666 100644 --- a/task-sdk/src/airflow/sdk/definitions/connection.py +++ b/task-sdk/src/airflow/sdk/definitions/connection.py @@ -160,11 +160,15 @@ def get(cls, conn_id: str) -> Any: @property def extra_dejson(self) -> dict: """Deserialize `extra` property to JSON.""" + from airflow.sdk.execution_time.secrets_masker import mask_secret + extra = {} if self.extra: try: extra = json.loads(self.extra) except JSONDecodeError: log.exception("Failed to deserialize extra property `extra`, returning empty dictionary") - # TODO: Mask sensitive keys from this list or revisit if it will be done in server + else: + mask_secret(extra) + return extra diff --git a/task-sdk/src/airflow/sdk/definitions/variable.py b/task-sdk/src/airflow/sdk/definitions/variable.py index 9e5d6a069d9d0..1386aee258e1a 100644 --- a/task-sdk/src/airflow/sdk/definitions/variable.py +++ b/task-sdk/src/airflow/sdk/definitions/variable.py @@ -23,6 +23,7 @@ import attrs from airflow.sdk.definitions._internal.types import NOTSET +from airflow.sdk.log import mask_secret log = logging.getLogger(__name__) @@ -53,6 +54,7 @@ def get(cls, key: str, default: Any = NOTSET, deserialize_json: bool = False): return _get_variable(key, deserialize_json=deserialize_json) except AirflowRuntimeError as e: if e.error.error == ErrorType.VARIABLE_NOT_FOUND and default is not NOTSET: + mask_secret(default, name=key) return default raise diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 1e97ed71db7be..b35342f8dc2e0 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -49,7 +49,7 @@ from __future__ import annotations import itertools -from collections.abc import Iterator +from collections.abc import Iterable, Iterator from datetime import datetime from functools import cached_property from pathlib import Path @@ -886,6 +886,17 @@ class UpdateHITLDetail(UpdateHITLDetailPayload): type: Literal["UpdateHITLDetail"] = "UpdateHITLDetail" +class MaskSecret(BaseModel): + """Add a new value to be redacted in task logs.""" + + # This is needed since calls to `mask_secret` in the Task process will otherwise only add the mask value + # to the child process, but the redaction happens in the parent. + + value: str | dict | Iterable + name: str | None = None + type: Literal["MaskSecret"] = "MaskSecret" + + ToSupervisor = Annotated[ DeferTask | DeleteXCom @@ -920,6 +931,7 @@ class UpdateHITLDetail(UpdateHITLDetailPayload): | ResendLoggingFD | CreateHITLDetailPayload | UpdateHITLDetail - | GetHITLDetailResponse, + | GetHITLDetailResponse + | MaskSecret, Field(discriminator="type"), ] diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index 8757e7819e370..0dbf86aa4e301 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -130,6 +130,11 @@ def _get_connection(conn_id: str) -> Connection: try: conn = secrets_backend.get_connection(conn_id=conn_id) if conn: + # TODO: this should probably be in get conn + if conn.password: + mask_secret(conn.password) + if conn.extra: + mask_secret(conn.extra) return conn except Exception: log.exception( diff --git a/task-sdk/src/airflow/sdk/execution_time/secrets_masker.py b/task-sdk/src/airflow/sdk/execution_time/secrets_masker.py index 273095401a1df..4c0e5311759f1 100644 --- a/task-sdk/src/airflow/sdk/execution_time/secrets_masker.py +++ b/task-sdk/src/airflow/sdk/execution_time/secrets_masker.py @@ -104,19 +104,28 @@ def should_hide_value_for_key(name): def mask_secret(secret: str | dict | Iterable, name: str | None = None) -> None: """ - Mask a secret from appearing in the task logs. + Mask a secret from appearing in the logs. - If ``name`` is provided, then it will only be masked if the name matches - one of the configured "sensitive" names. + If ``name`` is provided, then it will only be masked if the name matches one of the configured "sensitive" + names. - If ``secret`` is a dict or a iterable (excluding str) then it will be - recursively walked and keys with sensitive names will be hidden. + If ``secret`` is a dict or a iterable (excluding str) then it will be recursively walked and keys with + sensitive names will be hidden. + + If the secret value is too short (by default 5 characters or fewer, configurable via the + :ref:`[logging] min_length_masked_secret ` setting) it will not + be masked """ - # Filtering all log messages is not a free process, so we only do it when - # running tasks if not secret: return + from airflow.sdk.execution_time import task_runner + from airflow.sdk.execution_time.comms import MaskSecret + + if comms := getattr(task_runner, "SUPERVISOR_COMMS", None): + # Tell the parent, the process which handles all logs writing and output, about the values to mask + comms.send(MaskSecret(value=secret, name=name)) + _secrets_masker().add_mask(secret, name) diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 7b137e8f93c8d..b343e2845dee9 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -94,6 +94,7 @@ GetXComSequenceItem, GetXComSequenceSlice, InactiveAssetsResult, + MaskSecret, PrevSuccessfulDagRunResult, PutVariable, RescheduleTask, @@ -1064,7 +1065,10 @@ def final_state(self): return TaskInstanceState.FAILED def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: int): - log.debug("Received message from task runner", msg=msg) + if isinstance(msg, MaskSecret): + log.debug("Received message from task runner (body omitted)", msg=type(msg)) + else: + log.debug("Received message from task runner", msg=msg) resp: BaseModel | None = None dump_opts = {} if isinstance(msg, TaskState): @@ -1264,6 +1268,8 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: multiple=msg.multiple, ) self.send_msg(resp, request_id=req_id, error=None, **dump_opts) + elif isinstance(msg, MaskSecret): + mask_secret(msg.value, msg.name) else: log.error("Unhandled request", msg=msg) self.send_msg( diff --git a/task-sdk/src/airflow/sdk/log.py b/task-sdk/src/airflow/sdk/log.py index ec0c75fa5307d..98b70b9e51c93 100644 --- a/task-sdk/src/airflow/sdk/log.py +++ b/task-sdk/src/airflow/sdk/log.py @@ -24,7 +24,6 @@ import re import sys import warnings -from collections.abc import Callable from functools import cache from pathlib import Path from typing import TYPE_CHECKING, Any, BinaryIO, Generic, TextIO, TypeVar, cast @@ -33,15 +32,19 @@ import structlog if TYPE_CHECKING: + from collections.abc import Callable + from structlog.typing import EventDict, ExcInfo, FilteringBoundLogger, Processor from airflow.logging_config import RemoteLogIO + from airflow.sdk.execution_time.secrets_masker import mask_secret as mask_secret from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI __all__ = [ "configure_logging", "reset_logging", + "mask_secret", ] @@ -569,3 +572,12 @@ def upload_to_remote(logger: FilteringBoundLogger, ti: RuntimeTI): log_relative_path = relative_path.as_posix() handler.upload(log_relative_path, ti) + + +def __getattr__(name: str): + if name == "mask_secret": + from airflow.sdk.execution_time.secrets_masker import mask_secret + + globals()["mask_secret"] = mask_secret + return mask_secret + raise AttributeError(f"module {__name__!r} has no attribute {name!r}")