diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py index 0c9f6a7de8ad7..4fd26cd66f07b 100644 --- a/airflow-core/src/airflow/dag_processing/processor.py +++ b/airflow-core/src/airflow/dag_processing/processor.py @@ -43,6 +43,7 @@ GetPreviousDagRun, GetPrevSuccessfulDagRun, GetVariable, + MaskSecret, OKResponse, PreviousDagRunResult, PrevSuccessfulDagRunResult, @@ -106,6 +107,7 @@ class DagFileParsingResult(BaseModel): DeleteVariable, GetPrevSuccessfulDagRun, GetPreviousDagRun, + MaskSecret, ], Field(discriminator="type"), ] @@ -431,6 +433,10 @@ def _handle_request(self, msg: ToManager, log: FilteringBoundLogger, req_id: int dagrun_result = PrevSuccessfulDagRunResult.from_dagrun_response(dagrun_resp) resp = dagrun_result dump_opts = {"exclude_unset": True} + elif isinstance(msg, MaskSecret): + from airflow.sdk.execution_time.secrets_masker import mask_secret + + mask_secret(msg.value, msg.name) else: log.error("Unhandled request", msg=msg) self.send_msg( diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 916854fdf426a..ddcee486680e6 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -58,6 +58,7 @@ GetTICount, GetVariable, GetXCom, + MaskSecret, OKResponse, PutVariable, SetXCom, @@ -250,6 +251,7 @@ class TriggerStateSync(BaseModel): GetTaskStates, GetDagRunState, GetDRCount, + MaskSecret, ], Field(discriminator="type"), ] @@ -472,6 +474,10 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r resp = TaskStatesResult.from_api_response(run_id_task_state_map) else: resp = run_id_task_state_map + elif isinstance(msg, MaskSecret): + from airflow.sdk.execution_time.secrets_masker import mask_secret + + mask_secret(msg.value, msg.name) else: raise ValueError(f"Unknown message type {type(msg)}") diff --git a/airflow-core/tests/unit/hooks/test_base.py b/airflow-core/tests/unit/hooks/test_base.py index 6d54827e67ced..9cba1cbbd17b8 100644 --- a/airflow-core/tests/unit/hooks/test_base.py +++ b/airflow-core/tests/unit/hooks/test_base.py @@ -17,12 +17,14 @@ # under the License. from __future__ import annotations +from unittest.mock import call + import pytest from airflow.exceptions import AirflowNotFoundException from airflow.hooks.base import BaseHook from airflow.sdk.exceptions import ErrorType -from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse, GetConnection +from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse, GetConnection, MaskSecret from tests_common.test_utils.config import conf_vars @@ -56,8 +58,12 @@ def test_get_connection(self, mock_supervisor_comms): hook = BaseHook(logger_name="") hook.get_connection(conn_id="test_conn") - mock_supervisor_comms.send.assert_called_once_with( - msg=GetConnection(conn_id="test_conn"), + mock_supervisor_comms.send.assert_has_calls( + [ + call(GetConnection(conn_id="test_conn", type="GetConnection")), + call(MaskSecret(value="password", name=None, type="MaskSecret")), + call(MaskSecret(value='{"extra_key": "extra_value"}', name=None, type="MaskSecret")), + ] ) def test_get_connection_not_found(self, mock_supervisor_comms): diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index 89f5823f658d6..4265d6e2c40a7 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -626,11 +626,13 @@ async def run(self, **args) -> AsyncIterator[TriggerEvent]: from airflow.sdk import Variable from airflow.sdk.execution_time.xcom import XCom + from airflow.sdk.log import mask_secret conn = await sync_to_async(BaseHook.get_connection)("test_connection") self.log.info("Loaded conn %s", conn.conn_id) get_variable_value = await sync_to_async(Variable.get)("test_get_variable") + await sync_to_async(mask_secret)(get_variable_value) self.log.info("Loaded variable %s", get_variable_value) get_xcom_value = await sync_to_async(XCom.get_one)( diff --git a/task-sdk/docs/api.rst b/task-sdk/docs/api.rst index 8605b234792a6..f0a9f7499fed2 100644 --- a/task-sdk/docs/api.rst +++ b/task-sdk/docs/api.rst @@ -114,11 +114,12 @@ I/O Helpers Execution Time Components ------------------------- .. rubric:: Context + .. autoapiclass:: airflow.sdk.Context -.. autoapimodule:: airflow.sdk.execution_time.context - :members: - :undoc-members: +.. rubric:: Logging + +.. autofunction:: airflow.sdk.log.mask_secret Everything else --------------- diff --git a/task-sdk/src/airflow/sdk/definitions/connection.py b/task-sdk/src/airflow/sdk/definitions/connection.py index 837a1d4e25df2..0a0d5ca7ad374 100644 --- a/task-sdk/src/airflow/sdk/definitions/connection.py +++ b/task-sdk/src/airflow/sdk/definitions/connection.py @@ -150,11 +150,15 @@ def get(cls, conn_id: str) -> Any: @property def extra_dejson(self) -> dict: """Deserialize `extra` property to JSON.""" + from airflow.sdk.execution_time.secrets_masker import mask_secret + extra = {} if self.extra: try: extra = json.loads(self.extra) except JSONDecodeError: log.exception("Failed to deserialize extra property `extra`, returning empty dictionary") - # TODO: Mask sensitive keys from this list or revisit if it will be done in server + else: + mask_secret(extra) + return extra diff --git a/task-sdk/src/airflow/sdk/definitions/variable.py b/task-sdk/src/airflow/sdk/definitions/variable.py index 9e5d6a069d9d0..1386aee258e1a 100644 --- a/task-sdk/src/airflow/sdk/definitions/variable.py +++ b/task-sdk/src/airflow/sdk/definitions/variable.py @@ -23,6 +23,7 @@ import attrs from airflow.sdk.definitions._internal.types import NOTSET +from airflow.sdk.log import mask_secret log = logging.getLogger(__name__) @@ -53,6 +54,7 @@ def get(cls, key: str, default: Any = NOTSET, deserialize_json: bool = False): return _get_variable(key, deserialize_json=deserialize_json) except AirflowRuntimeError as e: if e.error.error == ErrorType.VARIABLE_NOT_FOUND and default is not NOTSET: + mask_secret(default, name=key) return default raise diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 3069490521783..d0f0eae04c88c 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -49,7 +49,7 @@ from __future__ import annotations import itertools -from collections.abc import Iterator +from collections.abc import Iterable, Iterator from datetime import datetime from functools import cached_property from pathlib import Path @@ -858,6 +858,17 @@ class GetDRCount(BaseModel): type: Literal["GetDRCount"] = "GetDRCount" +class MaskSecret(BaseModel): + """Add a new value to be redacted in task logs.""" + + # This is needed since calls to `mask_secret` in the Task process will otherwise only add the mask value + # to the child process, but the redaction happens in the parent. + + value: str | dict | Iterable + name: str | None = None + type: Literal["MaskSecret"] = "MaskSecret" + + ToSupervisor = Annotated[ Union[ DeferTask, @@ -891,6 +902,7 @@ class GetDRCount(BaseModel): TriggerDagRun, DeleteVariable, ResendLoggingFD, + MaskSecret, ], Field(discriminator="type"), ] diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index 583c5efcb6eec..adcafbf8cd878 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -130,6 +130,11 @@ def _get_connection(conn_id: str) -> Connection: try: conn = secrets_backend.get_connection(conn_id=conn_id) if conn: + # TODO: this should probably be in get conn + if conn.password: + mask_secret(conn.password) + if conn.extra: + mask_secret(conn.extra) return conn except Exception: log.exception( diff --git a/task-sdk/src/airflow/sdk/execution_time/secrets_masker.py b/task-sdk/src/airflow/sdk/execution_time/secrets_masker.py index b9235a2c2db39..9d75c097bd22a 100644 --- a/task-sdk/src/airflow/sdk/execution_time/secrets_masker.py +++ b/task-sdk/src/airflow/sdk/execution_time/secrets_masker.py @@ -107,19 +107,28 @@ def should_hide_value_for_key(name): def mask_secret(secret: str | dict | Iterable, name: str | None = None) -> None: """ - Mask a secret from appearing in the task logs. + Mask a secret from appearing in the logs. - If ``name`` is provided, then it will only be masked if the name matches - one of the configured "sensitive" names. + If ``name`` is provided, then it will only be masked if the name matches one of the configured "sensitive" + names. - If ``secret`` is a dict or a iterable (excluding str) then it will be - recursively walked and keys with sensitive names will be hidden. + If ``secret`` is a dict or a iterable (excluding str) then it will be recursively walked and keys with + sensitive names will be hidden. + + If the secret value is too short (by default 5 characters or fewer, configurable via the + :ref:`[logging] min_length_masked_secret ` setting) it will not + be masked """ - # Filtering all log messages is not a free process, so we only do it when - # running tasks if not secret: return + from airflow.sdk.execution_time import task_runner + from airflow.sdk.execution_time.comms import MaskSecret + + if comms := getattr(task_runner, "SUPERVISOR_COMMS", None): + # Tell the parent, the process which handles all logs writing and output, about the values to mask + comms.send(MaskSecret(value=secret, name=name)) + _secrets_masker().add_mask(secret, name) diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 2145053747f8a..2d083150920ab 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -94,6 +94,7 @@ GetXComSequenceItem, GetXComSequenceSlice, InactiveAssetsResult, + MaskSecret, PrevSuccessfulDagRunResult, PutVariable, RescheduleTask, @@ -1064,7 +1065,10 @@ def final_state(self): return TaskInstanceState.FAILED def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: int): - log.debug("Received message from task runner", msg=msg) + if isinstance(msg, MaskSecret): + log.debug("Received message from task runner (body omitted)", msg=type(msg)) + else: + log.debug("Received message from task runner", msg=msg) resp: BaseModel | None = None dump_opts = {} if isinstance(msg, TaskState): @@ -1253,6 +1257,8 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: self._send_new_log_fd(req_id) # Since we've sent the message, return. Nothing else in this ifelse/switch should return directly return + elif isinstance(msg, MaskSecret): + mask_secret(msg.value, msg.name) else: log.error("Unhandled request", msg=msg) self.send_msg( diff --git a/task-sdk/src/airflow/sdk/log.py b/task-sdk/src/airflow/sdk/log.py index 0f8ca6097fae6..4cfd255e46b86 100644 --- a/task-sdk/src/airflow/sdk/log.py +++ b/task-sdk/src/airflow/sdk/log.py @@ -32,15 +32,19 @@ import structlog if TYPE_CHECKING: + from collections.abc import Callable + from structlog.typing import EventDict, ExcInfo, FilteringBoundLogger, Processor from airflow.logging_config import RemoteLogIO + from airflow.sdk.execution_time.secrets_masker import mask_secret as mask_secret from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI __all__ = [ "configure_logging", "reset_logging", + "mask_secret", ] @@ -568,3 +572,12 @@ def upload_to_remote(logger: FilteringBoundLogger, ti: RuntimeTI): log_relative_path = relative_path.as_posix() handler.upload(log_relative_path, ti) + + +def __getattr__(name: str): + if name == "mask_secret": + from airflow.sdk.execution_time.secrets_masker import mask_secret + + globals()["mask_secret"] = mask_secret + return mask_secret + raise AttributeError(f"module {__name__!r} has no attribute {name!r}")