diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py
index 727e2685daa9e..6a13ae1cdf72c 100644
--- a/airflow-core/src/airflow/serialization/serialized_objects.py
+++ b/airflow-core/src/airflow/serialization/serialized_objects.py
@@ -1236,6 +1236,8 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
 
     _json_schema: ClassVar[Validator] = lazy_object_proxy.Proxy(load_dag_schema)
 
+    _const_fields: ClassVar[set[str] | None] = None
+
     _can_skip_downstream: bool
     _is_empty: bool
     _needs_expansion: bool
@@ -1713,6 +1715,22 @@ def set_task_dag_references(task: SerializedOperator | MappedOperator, dag: Seri
             # Bypass set_upstream etc here - it does more than we want
             dag.task_dict[task_id].upstream_task_ids.add(task.task_id)
 
+    @classmethod
+    def get_operator_const_fields(cls) -> set[str]:
+        """Get the set of operator fields that are marked as const in the JSON schema."""
+        if (schema_loader := cls._json_schema) is None:
+            return set()
+
+        schema_data = schema_loader.schema
+        operator_def = schema_data.get("definitions", {}).get("operator", {})
+        properties = operator_def.get("properties", {})
+
+        return {
+            field_name
+            for field_name, field_def in properties.items()
+            if isinstance(field_def, dict) and field_def.get("const")
+        }
+
     @classmethod
     @lru_cache(maxsize=1)  # Only one type: "operator"
     def get_operator_optional_fields_from_schema(cls) -> set[str]:
@@ -1868,10 +1886,39 @@ def _is_excluded(cls, var: Any, attrname: str, op: DAGNode):
         # Check if value matches client_defaults (hierarchical defaults optimization)
         if cls._matches_client_defaults(var, attrname):
             return True
-        schema_defaults = cls.get_schema_defaults("operator")
 
+        # Const fields should always be excluded when False, regardless of client_defaults.
+        # Use a class-level cache for optimization.
+        if cls._const_fields is None:
+            cls._const_fields = cls.get_operator_const_fields()
+        if attrname in cls._const_fields and var is False:
+            return True
+
+        schema_defaults = cls.get_schema_defaults("operator")
         if attrname in schema_defaults:
             if schema_defaults[attrname] == var:
+                # If it also matches client_defaults, exclude (optimization)
+                client_defaults = cls.generate_client_defaults()
+                if attrname in client_defaults:
+                    if client_defaults[attrname] == var:
+                        return True
+                    # If client_defaults differs, preserve the user's explicit override.
+                    # Example: default_args={"retries": 0}, schema default=0, client_defaults={"retries": 3}
+                    if client_defaults[attrname] != var:
+                        if op.has_dag():
+                            dag = op.dag
+                            if dag and attrname in dag.default_args and dag.default_args[attrname] == var:
+                                return False
+                        if (
+                            hasattr(op, "_BaseOperator__init_kwargs")
+                            and attrname in op._BaseOperator__init_kwargs
+                            and op._BaseOperator__init_kwargs[attrname] == var
+                        ):
+                            return False
+
+                # If client_defaults doesn't have this field (it matches the schema default),
+                # exclude for optimization even if it appears in default_args.
+                # Example: default_args={"depends_on_past": False}, schema default=False
                 return True
         optional_fields = cls.get_operator_optional_fields_from_schema()
         if var is None:
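Note on the `const` lookup above: it walks `definitions.operator.properties` and keeps only fields whose `const` value is truthy. A minimal self-contained sketch of that traversal, run against a hand-written schema fragment (the field names and values here are hypothetical, not taken from Airflow's real DAG schema):

```python
# Illustrative schema fragment -- not the real Airflow schema.json.
schema_data = {
    "definitions": {
        "operator": {
            "properties": {
                "_operator_marker": {"const": True},           # truthy const -> collected
                "retries": {"type": "integer", "default": 0},  # no const -> ignored
            }
        }
    }
}

operator_def = schema_data.get("definitions", {}).get("operator", {})
properties = operator_def.get("properties", {})
const_fields = {
    name
    for name, field_def in properties.items()
    if isinstance(field_def, dict) and field_def.get("const")
}
assert const_fields == {"_operator_marker"}
```

Note the truthiness check: a field declared with `"const": false` in the schema would not be collected by this comprehension.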
diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py
index 3373f238dd55c..337d5cd725460 100644
--- a/airflow-core/tests/unit/serialization/test_dag_serialization.py
+++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py
@@ -101,40 +101,41 @@ from airflow.sdk.definitions.context import Context
 
 
-@contextlib.contextmanager
-def operator_defaults(overrides):
+@pytest.fixture
+def operator_defaults(monkeypatch):
     """
-    Temporarily patches OPERATOR_DEFAULTS, restoring original values after context exit.
+    Fixture that provides a context manager to temporarily patch OPERATOR_DEFAULTS.
 
-    Example:
-        with operator_defaults({"retries": 2, "retry_delay": 200.0}):
-            # Test code with modified operator defaults
+    Usage:
+        def test_something(operator_defaults):
+            with operator_defaults({"retries": 2, "retry_delay": 200.0}):
+                # Test code with modified operator defaults
     """
+    import airflow.sdk.definitions._internal.abstractoperator as abstract_op_module
+
     from airflow.sdk.bases.operator import OPERATOR_DEFAULTS
+    from airflow.serialization.serialized_objects import SerializedBaseOperator
 
-    original_values = {}
-    try:
-        # Store original values and apply overrides
+    @contextlib.contextmanager
+    def _operator_defaults(overrides):
+        # Patch OPERATOR_DEFAULTS
         for key, value in overrides.items():
-            original_values[key] = OPERATOR_DEFAULTS.get(key)
-            OPERATOR_DEFAULTS[key] = value
+            monkeypatch.setitem(OPERATOR_DEFAULTS, key, value)
+
+            # Patch module-level constants
+            const_name = f"DEFAULT_{key.upper()}"
+            if hasattr(abstract_op_module, const_name):
+                monkeypatch.setattr(abstract_op_module, const_name, value)
 
         # Clear the cache to ensure fresh generation
         SerializedBaseOperator.generate_client_defaults.cache_clear()
-        yield
-    finally:
-        # Cleanup: restore original values
-        for key, original_value in original_values.items():
-            if original_value is None and key in OPERATOR_DEFAULTS:
-                # Key didn't exist originally, remove it
-                del OPERATOR_DEFAULTS[key]
-            else:
-                # Restore original value
-                OPERATOR_DEFAULTS[key] = original_value
+        try:
+            yield
+        finally:
+            # Clear cache again to restore normal behavior
+            SerializedBaseOperator.generate_client_defaults.cache_clear()
 
-        # Clear cache again to restore normal behavior
-        SerializedBaseOperator.generate_client_defaults.cache_clear()
+    return _operator_defaults
 
 
 AIRFLOW_REPO_ROOT_PATH = Path(airflow.__file__).parents[3]
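The rewritten fixture leans on `monkeypatch` instead of the old manual bookkeeping: `setitem`/`setattr` record each original value and pytest undoes them automatically at test teardown, so only the `lru_cache` still needs explicit clearing. A self-contained sketch of the auto-restore mechanism the fixture relies on (using pytest's public `MonkeyPatch` API; the dictionary is a stand-in, not `OPERATOR_DEFAULTS`):

```python
from pytest import MonkeyPatch

defaults = {"retries": 0}

mp = MonkeyPatch()
mp.setitem(defaults, "retries", 2)  # records the original entry, then overwrites it
assert defaults["retries"] == 2

mp.undo()                           # pytest calls this for us after each test
assert defaults["retries"] == 0
```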
@@ -4045,77 +4046,104 @@ def test_apply_defaults_to_encoded_op_none_inputs(self):
         result = SerializedBaseOperator._apply_defaults_to_encoded_op(encoded_op, None)
         assert result == encoded_op
 
-    @operator_defaults({"retries": 2})
-    def test_multiple_tasks_share_client_defaults(self):
+    def test_multiple_tasks_share_client_defaults(self, operator_defaults):
         """Test that multiple tasks can share the same client_defaults when there are actually non-default values."""
-        with DAG(dag_id="test_dag") as dag:
-            BashOperator(task_id="task1", bash_command="echo 1")
-            BashOperator(task_id="task2", bash_command="echo 2")
+        with operator_defaults({"retries": 2}):
+            with DAG(dag_id="test_dag") as dag:
+                BashOperator(task_id="task1", bash_command="echo 1")
+                BashOperator(task_id="task2", bash_command="echo 2")
 
-        serialized = SerializedDAG.to_dict(dag)
+            serialized = SerializedDAG.to_dict(dag)
 
-        # Should have one client_defaults section for all tasks
-        assert "client_defaults" in serialized
-        assert "tasks" in serialized["client_defaults"]
+            # Should have one client_defaults section for all tasks
+            assert "client_defaults" in serialized
+            assert "tasks" in serialized["client_defaults"]
 
-        # All tasks should benefit from the same client_defaults
-        client_defaults = serialized["client_defaults"]["tasks"]
+            # All tasks should benefit from the same client_defaults
+            client_defaults = serialized["client_defaults"]["tasks"]
 
-        # Deserialize and check both tasks get the defaults
-        deserialized_dag = SerializedDAG.from_dict(serialized)
-        deserialized_task1 = deserialized_dag.get_task("task1")
-        deserialized_task2 = deserialized_dag.get_task("task2")
+            # Deserialize and check both tasks get the defaults
+            deserialized_dag = SerializedDAG.from_dict(serialized)
+            deserialized_task1 = deserialized_dag.get_task("task1")
+            deserialized_task2 = deserialized_dag.get_task("task2")
 
-        # Both tasks should have retries=2 from client_defaults
-        assert deserialized_task1.retries == 2
-        assert deserialized_task2.retries == 2
+            # Both tasks should have retries=2 from client_defaults
+            assert deserialized_task1.retries == 2
+            assert deserialized_task2.retries == 2
 
-        # Both tasks should have the same default values from client_defaults
-        for field in client_defaults:
-            if hasattr(deserialized_task1, field) and hasattr(deserialized_task2, field):
-                value1 = getattr(deserialized_task1, field)
-                value2 = getattr(deserialized_task2, field)
-                assert value1 == value2, f"Tasks have different values for {field}: {value1} vs {value2}"
+            # Both tasks should have the same default values from client_defaults
+            for field in client_defaults:
+                if hasattr(deserialized_task1, field) and hasattr(deserialized_task2, field):
+                    value1 = getattr(deserialized_task1, field)
+                    value2 = getattr(deserialized_task2, field)
+                    assert value1 == value2, f"Tasks have different values for {field}: {value1} vs {value2}"
+
+    def test_default_args_when_equal_to_schema_defaults(self, operator_defaults):
+        """Test that explicitly set values matching schema defaults are preserved when client_defaults differ."""
+        with operator_defaults({"retries": 3}):
+            with DAG(dag_id="test_explicit_schema_default", default_args={"retries": 0}) as dag:
+                BashOperator(task_id="task1", bash_command="echo 1")
+                BashOperator(task_id="task2", bash_command="echo 1", retries=2)
+
+            serialized = SerializedDAG.to_dict(dag)
+
+            # Verify client_defaults has retries=3
+            assert "client_defaults" in serialized
+            assert "tasks" in serialized["client_defaults"]
+            client_defaults = serialized["client_defaults"]["tasks"]
+            assert client_defaults["retries"] == 3
+
+            task1_data = serialized["dag"]["tasks"][0]["__var"]
+            assert task1_data.get("retries", -1) == 0
+
+            task2_data = serialized["dag"]["tasks"][1]["__var"]
+            assert task2_data.get("retries", -1) == 2
+
+            deserialized_task1 = SerializedDAG.from_dict(serialized).get_task("task1")
+            assert deserialized_task1.retries == 0
+
+            deserialized_task2 = SerializedDAG.from_dict(serialized).get_task("task2")
+            assert deserialized_task2.retries == 2
 
 
 class TestMappedOperatorSerializationAndClientDefaults:
     """Test MappedOperator serialization with client defaults and callback properties."""
 
-    @operator_defaults({"retry_delay": 200.0})
-    def test_mapped_operator_client_defaults_application(self):
+    def test_mapped_operator_client_defaults_application(self, operator_defaults):
         """Test that client_defaults are correctly applied to MappedOperator during deserialization."""
-        with DAG(dag_id="test_mapped_dag") as dag:
-            # Create a mapped operator
-            BashOperator.partial(
-                task_id="mapped_task",
-                retries=5,  # Override default
-            ).expand(bash_command=["echo 1", "echo 2", "echo 3"])
-
-        # Serialize the DAG
-        serialized_dag = SerializedDAG.to_dict(dag)
+        with operator_defaults({"retry_delay": 200.0}):
+            with DAG(dag_id="test_mapped_dag") as dag:
+                # Create a mapped operator
+                BashOperator.partial(
+                    task_id="mapped_task",
+                    retries=5,  # Override default
+                ).expand(bash_command=["echo 1", "echo 2", "echo 3"])
+
+            # Serialize the DAG
+            serialized_dag = SerializedDAG.to_dict(dag)
 
-        # Should have client_defaults section
-        assert "client_defaults" in serialized_dag
-        assert "tasks" in serialized_dag["client_defaults"]
+            # Should have client_defaults section
+            assert "client_defaults" in serialized_dag
+            assert "tasks" in serialized_dag["client_defaults"]
 
-        # Deserialize and check that client_defaults are applied
-        deserialized_dag = SerializedDAG.from_dict(serialized_dag)
-        deserialized_task = deserialized_dag.get_task("mapped_task")
+            # Deserialize and check that client_defaults are applied
+            deserialized_dag = SerializedDAG.from_dict(serialized_dag)
+            deserialized_task = deserialized_dag.get_task("mapped_task")
 
-        # Verify it's still a MappedOperator
-        from airflow.models.mappedoperator import MappedOperator as SchedulerMappedOperator
+            # Verify it's still a MappedOperator
+            from airflow.models.mappedoperator import MappedOperator as SchedulerMappedOperator
 
-        assert isinstance(deserialized_task, SchedulerMappedOperator)
+            assert isinstance(deserialized_task, SchedulerMappedOperator)
 
-        # Check that client_defaults values are applied (e.g., retry_delay from client_defaults)
-        client_defaults = serialized_dag["client_defaults"]["tasks"]
-        if "retry_delay" in client_defaults:
-            # If retry_delay wasn't explicitly set, it should come from client_defaults
-            # Since we can't easily convert timedelta back, check the serialized format
-            assert hasattr(deserialized_task, "retry_delay")
+            # Check that client_defaults values are applied (e.g., retry_delay from client_defaults)
+            client_defaults = serialized_dag["client_defaults"]["tasks"]
+            if "retry_delay" in client_defaults:
+                # If retry_delay wasn't explicitly set, it should come from client_defaults
+                # Since we can't easily convert timedelta back, check the serialized format
+                assert hasattr(deserialized_task, "retry_delay")
 
-        # Explicit values should override client_defaults
-        assert deserialized_task.retries == 5  # Explicitly set value
+            # Explicit values should override client_defaults
+            assert deserialized_task.retries == 5  # Explicitly set value
 
     @pytest.mark.parametrize(
         ["task_config", "dag_id", "task_id", "non_default_fields"],
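For reference, the envelope that `test_default_args_when_equal_to_schema_defaults` above picks apart looks roughly like this (only the keys the test asserts on; all other serialized fields are trimmed and the values are illustrative):

```python
serialized = {
    "client_defaults": {"tasks": {"retries": 3}},  # shared defaults from patched OPERATOR_DEFAULTS
    "dag": {
        "tasks": [
            {"__var": {"task_id": "task1", "retries": 0}},  # explicit schema-default override kept
            {"__var": {"task_id": "task2", "retries": 2}},  # plain non-default value kept
        ]
    },
}

assert serialized["dag"]["tasks"][0]["__var"].get("retries", -1) == 0
assert serialized["dag"]["tasks"][1]["__var"].get("retries", -1) == 2
```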
@@ -4146,45 +4174,45 @@ def test_mapped_operator_client_defaults_application(self):
             ),
         ],
     )
-    @operator_defaults({"retry_delay": 200.0})
     def test_mapped_operator_client_defaults_optimization(
-        self, task_config, dag_id, task_id, non_default_fields
+        self, task_config, dag_id, task_id, non_default_fields, operator_defaults
     ):
         """Test that MappedOperator serialization optimizes using client defaults."""
-        with DAG(dag_id=dag_id) as dag:
-            # Create mapped operator with specified configuration
-            BashOperator.partial(
-                task_id=task_id,
-                **task_config,
-            ).expand(bash_command=["echo 1", "echo 2", "echo 3"])
+        with operator_defaults({"retry_delay": 200.0}):
+            with DAG(dag_id=dag_id) as dag:
+                # Create mapped operator with specified configuration
+                BashOperator.partial(
+                    task_id=task_id,
+                    **task_config,
+                ).expand(bash_command=["echo 1", "echo 2", "echo 3"])
 
-        serialized_dag = SerializedDAG.to_dict(dag)
-        mapped_task_serialized = serialized_dag["dag"]["tasks"][0]["__var"]
+            serialized_dag = SerializedDAG.to_dict(dag)
+            mapped_task_serialized = serialized_dag["dag"]["tasks"][0]["__var"]
 
-        assert mapped_task_serialized is not None
-        assert mapped_task_serialized.get("_is_mapped") is True
+            assert mapped_task_serialized is not None
+            assert mapped_task_serialized.get("_is_mapped") is True
 
-        # Check optimization behavior
-        client_defaults = serialized_dag["client_defaults"]["tasks"]
-        partial_kwargs = mapped_task_serialized["partial_kwargs"]
+            # Check optimization behavior
+            client_defaults = serialized_dag["client_defaults"]["tasks"]
+            partial_kwargs = mapped_task_serialized["partial_kwargs"]
 
-        # Check that all fields are optimized correctly
-        for field, default_value in client_defaults.items():
-            if field in non_default_fields:
-                # Non-default fields should be present in partial_kwargs
-                assert field in partial_kwargs, (
-                    f"Field '{field}' should be in partial_kwargs as it's non-default"
-                )
-                # And have different values than defaults
-                assert partial_kwargs[field] != default_value, (
-                    f"Field '{field}' should have non-default value"
-                )
-            else:
-                # Default fields should either not be present or have different values if present
-                if field in partial_kwargs:
+            # Check that all fields are optimized correctly
+            for field, default_value in client_defaults.items():
+                if field in non_default_fields:
+                    # Non-default fields should be present in partial_kwargs
+                    assert field in partial_kwargs, (
+                        f"Field '{field}' should be in partial_kwargs as it's non-default"
+                    )
+                    # And have different values than defaults
                     assert partial_kwargs[field] != default_value, (
-                        f"Field '{field}' with default value should be optimized out"
+                        f"Field '{field}' should have non-default value"
                     )
+                else:
+                    # Default fields should either not be present or have different values if present
+                    if field in partial_kwargs:
+                        assert partial_kwargs[field] != default_value, (
+                            f"Field '{field}' with default value should be optimized out"
+                        )
 
     def test_mapped_operator_expand_input_preservation(self):
         """Test that expand_input is correctly preserved during serialization."""
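Taken together, the serializer-side rule these tests pin down can be summarized in a standalone decision sketch. This is not the real `_is_excluded` (which also handles const fields, optional fields, and `None` values); it only walks through the `retries` scenario from the tests above:

```python
# Standalone sketch of the retries decision exercised by the tests above.
schema_default = 0                 # default declared in the JSON schema
client_default = 3                 # patched OPERATOR_DEFAULTS, shipped as client_defaults
dag_default_args = {"retries": 0}  # user's explicit DAG-level override

def keep_on_serialized_task(value: int) -> bool:
    """Return True if the value must be kept on the serialized task."""
    if value != schema_default:
        return True                # non-default values are always serialized
    if client_default == schema_default:
        return False               # nothing to disambiguate; optimize away
    # Value equals the schema default while client_defaults differs: keep it if
    # the user set it explicitly, otherwise deserialization would resurrect
    # retries=3 from client_defaults.
    return dag_default_args.get("retries") == value

assert keep_on_serialized_task(0) is True  # task1: explicit retries=0 survives
assert keep_on_serialized_task(2) is True  # task2: non-default retries=2 survives
```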