From 22df107c4c5108bd1f05fd5dadd59099d5082b5c Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Mon, 22 Dec 2025 12:32:36 +0530 Subject: [PATCH] [v3-1-test] Redact secrets in rendered templates properly when truncating it (#59566) (cherry picked from commit 8defa759f9f569d1fe8bc2d6de9c0bb957ec505b) Co-authored-by: Amogh Desai --- .../airflow/sdk/execution_time/task_runner.py | 80 +++++++++++++++++-- .../execution_time/test_task_runner.py | 44 ++++++++++ 2 files changed, 119 insertions(+), 5 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index dbcd6329bce5d..79df1beb605ad 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -763,17 +763,87 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: return ti, ti.get_template_context(), log +def _serialize_template_field(template_field: Any, name: str) -> str | dict | list | int | float: + """ + Return a serializable representation of the templated field. + + If ``templated_field`` contains a class or instance that requires recursive + templating, store them as strings. Otherwise simply return the field as-is. + + Used sdk secrets masker to redact secrets in the serialized output. + """ + import json + + from airflow.sdk._shared.secrets_masker import redact + + def is_jsonable(x): + try: + json.dumps(x) + except (TypeError, OverflowError): + return False + else: + return True + + def translate_tuples_to_lists(obj: Any): + """Recursively convert tuples to lists.""" + if isinstance(obj, tuple): + return [translate_tuples_to_lists(item) for item in obj] + if isinstance(obj, list): + return [translate_tuples_to_lists(item) for item in obj] + if isinstance(obj, dict): + return {key: translate_tuples_to_lists(value) for key, value in obj.items()} + return obj + + def sort_dict_recursively(obj: Any) -> Any: + """Recursively sort dictionaries to ensure consistent ordering.""" + if isinstance(obj, dict): + return {k: sort_dict_recursively(v) for k, v in sorted(obj.items())} + if isinstance(obj, list): + return [sort_dict_recursively(item) for item in obj] + if isinstance(obj, tuple): + return tuple(sort_dict_recursively(item) for item in obj) + return obj + + max_length = conf.getint("core", "max_templated_field_length") + + if not is_jsonable(template_field): + try: + serialized = template_field.serialize() + except AttributeError: + serialized = str(template_field) + if len(serialized) > max_length: + rendered = redact(serialized, name) + return ( + "Truncated. You can change this behaviour in [core]max_templated_field_length. " + f"{rendered[: max_length - 79]!r}... " + ) + return serialized + if not template_field and not isinstance(template_field, tuple): + # Avoid unnecessary serialization steps for empty fields unless they are tuples + # and need to be converted to lists + return template_field + template_field = translate_tuples_to_lists(template_field) + # Sort dictionaries recursively to ensure consistent string representation + # This prevents hash inconsistencies when dict ordering varies + if isinstance(template_field, dict): + template_field = sort_dict_recursively(template_field) + serialized = str(template_field) + if len(serialized) > max_length: + rendered = redact(serialized, name) + return ( + "Truncated. You can change this behaviour in [core]max_templated_field_length. " + f"{rendered[: max_length - 79]!r}... " + ) + return template_field + + def _serialize_rendered_fields(task: AbstractOperator) -> dict[str, JsonValue]: - # TODO: Port one of the following to Task SDK - # airflow.serialization.helpers.serialize_template_field or - # airflow.models.renderedtifields.get_serialized_template_fields from airflow.sdk._shared.secrets_masker import redact - from airflow.serialization.helpers import serialize_template_field rendered_fields = {} for field in task.template_fields: value = getattr(task, field) - serialized = serialize_template_field(value, field) + serialized = _serialize_template_field(value, field) # Redact secrets in the task process itself before sending to API server # This ensures that the secrets those are registered via mask_secret() on workers / dag processor are properly masked # on the UI. diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 96128bc810de8..cb0086195acef 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -2155,6 +2155,50 @@ def dict_task(): pulled = runtime_ti.xcom_pull(key="key/slash", task_ids="dict_task") assert pulled == "Some Value" + @pytest.mark.enable_redact + def test_rendered_templates_mask_secrets_with_truncation(self, create_runtime_ti, mock_supervisor_comms): + """Test that secrets are masked before truncation when rendered fields exceed max_templated_field_length.""" + from airflow.sdk._shared.secrets_masker import _secrets_masker + + secret_url = "postgresql+psycopg2://username:testpass123@test.domain.com/testdb" + _secrets_masker().add_mask(secret_url, None) + + class CustomOperator(BaseOperator): + template_fields = ("env_vars", "region") + + def __init__(self, env_vars, region, *args, **kwargs): + super().__init__(*args, **kwargs) + self.env_vars = env_vars + self.region = region + + def execute(self, context): + pass + + # generate 50 env_vars to exceed default char limit of 4096 (50 * 87 chars ≈ 4350 chars) + env_vars = {f"TEST_URL_{i}": secret_url for i in range(50)} + + task = CustomOperator( + task_id="test_truncation_masking", + env_vars=env_vars, + region="us-west-2", + ) + + runtime_ti = create_runtime_ti(task=task, dag_id="test_truncation_masking_dag") + run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) + + assert ( + call( + msg=SetRenderedFields( + rendered_fields={ + "env_vars": "Truncated. You can change this behaviour in [core]max_templated_field_length. \"{'TEST_URL_0': '***', 'TEST_URL_1': '***', 'TEST_URL_10': '***', 'TEST_URL_11': '***', 'TEST_URL_12': '***', 'TEST_URL_13': '***', 'TEST_URL_14': '***', 'TEST_URL_15': '***', 'TEST_URL_16': '***', 'TEST_URL_17': '***', 'TEST_URL_18': '***', 'TEST_URL_19': '***', 'TEST_URL_2': '***', 'TEST_URL_20': '***', 'TEST_URL_21': '***', 'TEST_URL_22': '***', 'TEST_URL_23': '***', 'TEST_URL_24': '***', 'TEST_URL_25': '***', 'TEST_URL_26': '***', 'TEST_URL_27': '***', 'TEST_URL_28': '***', 'TEST_URL_29': '***', 'TEST_URL_3': '***', 'TEST_URL_30': '***', 'TEST_URL_31': '***', 'TEST_URL_32': '***', 'TEST_URL_33': '***', 'TEST_URL_34': '***', 'TEST_URL_35': '***', 'TEST_URL_36': '***', 'TEST_URL_37': '***', 'TEST_URL_38': '***', 'TEST_URL_39': '***', 'TEST_URL_4': '***', 'TEST_URL_40': '***', 'TEST_URL_41': '***', 'TEST_URL_42': '***', 'TEST_URL_43': '***', 'TEST_URL_44': '***', 'TEST_URL_45': '***', 'TEST_URL_46': '***', 'TEST_URL_47': '***', 'TEST_URL_48': '***', 'TEST_URL_49': '***', 'TEST_URL_5': '***', 'TEST_URL_6': '***', 'TEST_URL_7': '***', 'TEST_URL_8': '***', 'TEST_URL_9': '***'}\"... ", + "region": "us-west-2", + }, + type="SetRenderedFields", + ) + ) + in mock_supervisor_comms.send.mock_calls + ) + class TestXComAfterTaskExecution: @pytest.mark.parametrize(