diff --git a/airflow-core/src/airflow/serialization/schema.json b/airflow-core/src/airflow/serialization/schema.json index 7707494fce546..d79c7477297e6 100644 --- a/airflow-core/src/airflow/serialization/schema.json +++ b/airflow-core/src/airflow/serialization/schema.json @@ -283,7 +283,7 @@ "pool": { "type": "string", "default": "default_pool" }, "pool_slots": { "type": "number", "default": 1 }, "execution_timeout": { "$ref": "#/definitions/timedelta" }, - "retry_delay": { "$ref": "#/definitions/timedelta" }, + "retry_delay": { "$ref": "#/definitions/timedelta", "default": 300.0 }, "retry_exponential_backoff": { "type": "boolean", "default": false }, "max_retry_delay": { "$ref": "#/definitions/timedelta" }, "params": { "$ref": "#/definitions/params" }, diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py index df79fdb2a4aae..47391b074df37 100644 --- a/airflow-core/src/airflow/serialization/serialized_objects.py +++ b/airflow-core/src/airflow/serialization/serialized_objects.py @@ -1290,7 +1290,7 @@ class SerializedBaseOperator(DAGNode, BaseSerialization): resources: dict[str, Any] | None = None retries: int = 0 - retry_delay: datetime.timedelta + retry_delay: datetime.timedelta = datetime.timedelta(seconds=300) retry_exponential_backoff: bool = False run_as_user: str | None = None @@ -2056,19 +2056,26 @@ def generate_client_defaults(cls) -> dict[str, Any]: for k, v in OPERATOR_DEFAULTS.items(): if k not in cls.get_serialized_fields(): continue - # Exclude values that are the same as the schema defaults - if k in schema_defaults and schema_defaults[k] == v: - continue # Exclude values that are None or empty collections if v is None or v in [[], (), set(), {}]: continue + # Check schema defaults first with raw value comparison (fast path) + if k in schema_defaults and schema_defaults[k] == v: + continue + # Use the existing serialize method to ensure consistent format 
@contextlib.contextmanager
def operator_defaults(overrides):
    """
    Temporarily patch OPERATOR_DEFAULTS, restoring the original state on exit.

    Each key in ``overrides`` is written into ``OPERATOR_DEFAULTS`` for the
    duration of the context. On exit (including when the body raises) every
    touched key is put back exactly as it was: keys that were present get their
    previous value back -- even when that value was ``None`` -- and keys that
    were absent are removed again. The ``generate_client_defaults`` cache is
    cleared both after patching and after restoring so that no stale result
    survives either transition.

    :param overrides: mapping of OPERATOR_DEFAULTS key -> temporary value.

    Example:
        with operator_defaults({"retries": 2, "retry_delay": 200.0}):
            # Test code with modified operator defaults
    """
    from airflow.sdk.bases.operator import OPERATOR_DEFAULTS

    # Unique marker for "key was not present". Using None for this purpose
    # would make an existing key whose value is None indistinguishable from a
    # missing key, and the cleanup would delete it instead of restoring it.
    missing = object()
    original_values = {}
    try:
        # Store original values and apply overrides
        for key, value in overrides.items():
            original_values[key] = OPERATOR_DEFAULTS.get(key, missing)
            OPERATOR_DEFAULTS[key] = value

        # Clear the cache to ensure fresh generation
        SerializedBaseOperator.generate_client_defaults.cache_clear()

        yield
    finally:
        # Cleanup: only keys recorded above are touched, so a failure midway
        # through patching still restores everything that was changed.
        for key, original_value in original_values.items():
            if original_value is missing:
                # Key didn't exist originally, remove it
                OPERATOR_DEFAULTS.pop(key, None)
            else:
                # Restore original value (which may legitimately be None)
                OPERATOR_DEFAULTS[key] = original_value

        # Clear cache again to restore normal behavior
        SerializedBaseOperator.generate_client_defaults.cache_clear()
timedelta(minutes=5), + "retry_delay": timedelta(minutes=4), "max_retry_delay": timedelta(minutes=10), "depends_on_past": False, }, @@ -3072,7 +3109,7 @@ def test_handle_v1_serdag(): "__var": { "depends_on_past": False, "retries": 1, - "retry_delay": {"__type": "timedelta", "__var": 300.0}, + "retry_delay": {"__type": "timedelta", "__var": 240.0}, "max_retry_delay": {"__type": "timedelta", "__var": 600.0}, "sla": {"__type": "timedelta", "__var": 100.0}, }, @@ -3110,7 +3147,7 @@ def test_handle_v1_serdag(): "__var": { "task_id": "bash_task", "retries": 1, - "retry_delay": 300.0, + "retry_delay": 240.0, "max_retry_delay": 600.0, "sla": 100.0, "downstream_task_ids": [], @@ -3173,7 +3210,7 @@ def test_handle_v1_serdag(): "__var": { "task_id": "custom_task", "retries": 1, - "retry_delay": 300.0, + "retry_delay": 240.0, "max_retry_delay": 600.0, "sla": 100.0, "downstream_task_ids": [], @@ -3383,7 +3420,7 @@ def test_handle_v2_serdag(): "__var": { "depends_on_past": False, "retries": 1, - "retry_delay": {"__type": "timedelta", "__var": 300.0}, + "retry_delay": {"__type": "timedelta", "__var": 240.0}, "max_retry_delay": {"__type": "timedelta", "__var": 600.0}, }, }, @@ -3425,7 +3462,7 @@ def test_handle_v2_serdag(): "__var": { "task_id": "bash_task", "retries": 1, - "retry_delay": 300.0, + "retry_delay": 240.0, "max_retry_delay": 600.0, "downstream_task_ids": [], "ui_color": "#f0ede4", @@ -3491,7 +3528,7 @@ def test_handle_v2_serdag(): "__var": { "task_id": "custom_task", "retries": 1, - "retry_delay": 300.0, + "retry_delay": 240.0, "max_retry_delay": 600.0, "downstream_task_ids": [], "_operator_extra_links": {"Google Custom": "_link_CustomOpLink"}, @@ -4004,8 +4041,9 @@ def test_apply_defaults_to_encoded_op_none_inputs(self): result = SerializedBaseOperator._apply_defaults_to_encoded_op(encoded_op, None) assert result == encoded_op + @operator_defaults({"retries": 2}) def test_multiple_tasks_share_client_defaults(self): - """Test that multiple tasks can share the same 
client_defaults.""" + """Test that multiple tasks can share the same client_defaults when there are actually non-default values.""" with DAG(dag_id="test_dag") as dag: BashOperator(task_id="task1", bash_command="echo 1") BashOperator(task_id="task2", bash_command="echo 2") @@ -4024,6 +4062,10 @@ def test_multiple_tasks_share_client_defaults(self): deserialized_task1 = deserialized_dag.get_task("task1") deserialized_task2 = deserialized_dag.get_task("task2") + # Both tasks should have retries=2 from client_defaults + assert deserialized_task1.retries == 2 + assert deserialized_task2.retries == 2 + # Both tasks should have the same default values from client_defaults for field in client_defaults: if hasattr(deserialized_task1, field) and hasattr(deserialized_task2, field): @@ -4035,6 +4077,7 @@ def test_multiple_tasks_share_client_defaults(self): class TestMappedOperatorSerializationAndClientDefaults: """Test MappedOperator serialization with client defaults and callback properties.""" + @operator_defaults({"retry_delay": 200.0}) def test_mapped_operator_client_defaults_application(self): """Test that client_defaults are correctly applied to MappedOperator during deserialization.""" with DAG(dag_id="test_mapped_dag") as dag: @@ -4099,6 +4142,7 @@ def test_mapped_operator_client_defaults_application(self): ), ], ) + @operator_defaults({"retry_delay": 200.0}) def test_mapped_operator_client_defaults_optimization( self, task_config, dag_id, task_id, non_default_fields ): diff --git a/scripts/in_container/run_schema_defaults_check.py b/scripts/in_container/run_schema_defaults_check.py index bc7c1e2844cb3..e9744134360d5 100755 --- a/scripts/in_container/run_schema_defaults_check.py +++ b/scripts/in_container/run_schema_defaults_check.py @@ -28,6 +28,7 @@ import json import sys +from datetime import timedelta from pathlib import Path from typing import Any @@ -80,6 +81,8 @@ def get_server_side_operator_defaults() -> dict[str, Any]: if isinstance(default_value, (set, 
tuple)): # Convert to list since schema.json is pure JSON default_value = list(default_value) + elif isinstance(default_value, timedelta): + default_value = default_value.total_seconds() server_defaults[field_name] = default_value return server_defaults