Skip to content
Merged
49 changes: 48 additions & 1 deletion airflow-core/src/airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1237,6 +1237,8 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):

_json_schema: ClassVar[Validator] = lazy_object_proxy.Proxy(load_dag_schema)

_const_fields: ClassVar[set[str] | None] = None

_can_skip_downstream: bool
_is_empty: bool
_needs_expansion: bool
Expand Down Expand Up @@ -1707,6 +1709,22 @@ def set_task_dag_references(task: SerializedOperator | MappedOperator, dag: Seri
# Bypass set_upstream etc here - it does more than we want
dag.task_dict[task_id].upstream_task_ids.add(task.task_id)

@classmethod
def get_operator_const_fields(cls) -> set[str]:
    """
    Collect the operator property names whose schema entry has a truthy ``const``.

    The lookup walks ``definitions -> operator -> properties`` of the schema held by
    ``cls._json_schema``. A missing level in that path is treated as empty, so the
    result is an empty set rather than an error.

    :return: Names of the operator properties declaring a truthy ``const`` value;
        an empty set when no schema is attached to the class.
    """
    validator = cls._json_schema
    if validator is None:
        return set()

    operator_properties = validator.schema.get("definitions", {}).get("operator", {}).get("properties", {})

    const_fields: set[str] = set()
    for name, spec in operator_properties.items():
        # Only dict-shaped property definitions can carry a "const" keyword.
        if not isinstance(spec, dict):
            continue
        # Truthiness check: a property declared with a falsy const (e.g. False) is not collected.
        if spec.get("const"):
            const_fields.add(name)
    return const_fields

@classmethod
@lru_cache(maxsize=1) # Only one type: "operator"
def get_operator_optional_fields_from_schema(cls) -> set[str]:
Expand Down Expand Up @@ -1862,10 +1880,39 @@ def _is_excluded(cls, var: Any, attrname: str, op: DAGNode):
# Check if value matches client_defaults (hierarchical defaults optimization)
if cls._matches_client_defaults(var, attrname):
return True
schema_defaults = cls.get_schema_defaults("operator")

# Const fields are always excluded when their value is False, regardless of client_defaults.
# The set of const fields is cached on the class so the schema is only inspected once.
if cls._const_fields is None:
cls._const_fields = cls.get_operator_const_fields()
if attrname in cls._const_fields and var is False:
return True

schema_defaults = cls.get_schema_defaults("operator")
if attrname in schema_defaults:
if schema_defaults[attrname] == var:
# If it also matches client_defaults, exclude (optimization)
client_defaults = cls.generate_client_defaults()
if attrname in client_defaults:
if client_defaults[attrname] == var:
return True
# If client_defaults differs, preserve explicit override from user
# Example: default_args={"retries": 0}, schema default=0, client_defaults={"retries": 3}
if client_defaults[attrname] != var:
if op.has_dag():
dag = op.dag
if dag and attrname in dag.default_args and dag.default_args[attrname] == var:
return False
if (
hasattr(op, "_BaseOperator__init_kwargs")
and attrname in op._BaseOperator__init_kwargs
and op._BaseOperator__init_kwargs[attrname] == var
):
return False

# If client_defaults doesn't have this field (matches schema default),
# exclude for optimization even if in default_args
# Example: default_args={"depends_on_past": False}, schema default=False
return True
optional_fields = cls.get_operator_optional_fields_from_schema()
if var is None:
Expand Down
238 changes: 133 additions & 105 deletions airflow-core/tests/unit/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,40 +104,41 @@
from airflow.sdk.definitions.context import Context


@contextlib.contextmanager
def operator_defaults(overrides):
@pytest.fixture
def operator_defaults(monkeypatch):
"""
Temporarily patches OPERATOR_DEFAULTS, restoring original values after context exit.
Fixture that provides a context manager to temporarily patch OPERATOR_DEFAULTS.

Example:
with operator_defaults({"retries": 2, "retry_delay": 200.0}):
# Test code with modified operator defaults
Usage:
def test_something(operator_defaults):
with operator_defaults({"retries": 2, "retry_delay": 200.0}):
# Test code with modified operator defaults
"""
import airflow.sdk.definitions._internal.abstractoperator as abstract_op_module
from airflow.sdk.bases.operator import OPERATOR_DEFAULTS
from airflow.serialization.serialized_objects import SerializedBaseOperator

original_values = {}
try:
# Store original values and apply overrides
@contextlib.contextmanager
def _operator_defaults(overrides):
# Patch OPERATOR_DEFAULTS
for key, value in overrides.items():
original_values[key] = OPERATOR_DEFAULTS.get(key)
OPERATOR_DEFAULTS[key] = value
monkeypatch.setitem(OPERATOR_DEFAULTS, key, value)

# Patch module-level constants
const_name = f"DEFAULT_{key.upper()}"
if hasattr(abstract_op_module, const_name):
monkeypatch.setattr(abstract_op_module, const_name, value)

# Clear the cache to ensure fresh generation
SerializedBaseOperator.generate_client_defaults.cache_clear()

yield
finally:
# Cleanup: restore original values
for key, original_value in original_values.items():
if original_value is None and key in OPERATOR_DEFAULTS:
# Key didn't exist originally, remove it
del OPERATOR_DEFAULTS[key]
else:
# Restore original value
OPERATOR_DEFAULTS[key] = original_value
try:
yield
finally:
# Clear cache again to restore normal behavior
SerializedBaseOperator.generate_client_defaults.cache_clear()

# Clear cache again to restore normal behavior
SerializedBaseOperator.generate_client_defaults.cache_clear()
return _operator_defaults


AIRFLOW_REPO_ROOT_PATH = Path(airflow.__file__).parents[3]
Expand Down Expand Up @@ -4107,77 +4108,104 @@ def test_apply_defaults_to_encoded_op_none_inputs(self):
result = SerializedBaseOperator._apply_defaults_to_encoded_op(encoded_op, None)
assert result == encoded_op

@operator_defaults({"retries": 2})
def test_multiple_tasks_share_client_defaults(self):
def test_multiple_tasks_share_client_defaults(self, operator_defaults):
"""Test that multiple tasks can share the same client_defaults when there are actually non-default values."""
with DAG(dag_id="test_dag") as dag:
BashOperator(task_id="task1", bash_command="echo 1")
BashOperator(task_id="task2", bash_command="echo 2")
with operator_defaults({"retries": 2}):
with DAG(dag_id="test_dag") as dag:
BashOperator(task_id="task1", bash_command="echo 1")
BashOperator(task_id="task2", bash_command="echo 2")

serialized = SerializedDAG.to_dict(dag)
serialized = SerializedDAG.to_dict(dag)

# Should have one client_defaults section for all tasks
assert "client_defaults" in serialized
assert "tasks" in serialized["client_defaults"]
# Should have one client_defaults section for all tasks
assert "client_defaults" in serialized
assert "tasks" in serialized["client_defaults"]

# All tasks should benefit from the same client_defaults
client_defaults = serialized["client_defaults"]["tasks"]
# All tasks should benefit from the same client_defaults
client_defaults = serialized["client_defaults"]["tasks"]

# Deserialize and check both tasks get the defaults
deserialized_dag = SerializedDAG.from_dict(serialized)
deserialized_task1 = deserialized_dag.get_task("task1")
deserialized_task2 = deserialized_dag.get_task("task2")
# Deserialize and check both tasks get the defaults
deserialized_dag = SerializedDAG.from_dict(serialized)
deserialized_task1 = deserialized_dag.get_task("task1")
deserialized_task2 = deserialized_dag.get_task("task2")

# Both tasks should have retries=2 from client_defaults
assert deserialized_task1.retries == 2
assert deserialized_task2.retries == 2
# Both tasks should have retries=2 from client_defaults
assert deserialized_task1.retries == 2
assert deserialized_task2.retries == 2

# Both tasks should have the same default values from client_defaults
for field in client_defaults:
if hasattr(deserialized_task1, field) and hasattr(deserialized_task2, field):
value1 = getattr(deserialized_task1, field)
value2 = getattr(deserialized_task2, field)
assert value1 == value2, f"Tasks have different values for {field}: {value1} vs {value2}"
# Both tasks should have the same default values from client_defaults
for field in client_defaults:
if hasattr(deserialized_task1, field) and hasattr(deserialized_task2, field):
value1 = getattr(deserialized_task1, field)
value2 = getattr(deserialized_task2, field)
assert value1 == value2, f"Tasks have different values for {field}: {value1} vs {value2}"

def test_default_args_when_equal_to_schema_defaults(self, operator_defaults):
    """
    Check that a value equal to the schema default is kept when client_defaults differ.

    ``retries=0`` is supplied through ``default_args`` while the patched operator default
    is 3, so the serialized task must still carry ``retries`` and round-trip to 0.
    """
    with operator_defaults({"retries": 3}):
        with DAG(dag_id="test_explicit_schema_default", default_args={"retries": 0}) as dag:
            BashOperator(task_id="task1", bash_command="echo 1")
            BashOperator(task_id="task2", bash_command="echo 1", retries=2)

        serialized = SerializedDAG.to_dict(dag)

        # The patched default must show up in the shared client_defaults section.
        assert "client_defaults" in serialized
        assert "tasks" in serialized["client_defaults"]
        assert serialized["client_defaults"]["tasks"]["retries"] == 3

        # Encoded form: -1 is a sentinel that would reveal "retries" being dropped.
        encoded_tasks = serialized["dag"]["tasks"]
        assert encoded_tasks[0]["__var"].get("retries", -1) == 0
        assert encoded_tasks[1]["__var"].get("retries", -1) == 2

        # Round trip: each task is read back from a freshly deserialized DAG.
        for task_id, expected_retries in (("task1", 0), ("task2", 2)):
            restored_task = SerializedDAG.from_dict(serialized).get_task(task_id)
            assert restored_task.retries == expected_retries


class TestMappedOperatorSerializationAndClientDefaults:
"""Test MappedOperator serialization with client defaults and callback properties."""

@operator_defaults({"retry_delay": 200.0})
def test_mapped_operator_client_defaults_application(self):
def test_mapped_operator_client_defaults_application(self, operator_defaults):
"""Test that client_defaults are correctly applied to MappedOperator during deserialization."""
with DAG(dag_id="test_mapped_dag") as dag:
# Create a mapped operator
BashOperator.partial(
task_id="mapped_task",
retries=5, # Override default
).expand(bash_command=["echo 1", "echo 2", "echo 3"])

# Serialize the DAG
serialized_dag = SerializedDAG.to_dict(dag)
with operator_defaults({"retry_delay": 200.0}):
with DAG(dag_id="test_mapped_dag") as dag:
# Create a mapped operator
BashOperator.partial(
task_id="mapped_task",
retries=5, # Override default
).expand(bash_command=["echo 1", "echo 2", "echo 3"])

# Serialize the DAG
serialized_dag = SerializedDAG.to_dict(dag)

# Should have client_defaults section
assert "client_defaults" in serialized_dag
assert "tasks" in serialized_dag["client_defaults"]
# Should have client_defaults section
assert "client_defaults" in serialized_dag
assert "tasks" in serialized_dag["client_defaults"]

# Deserialize and check that client_defaults are applied
deserialized_dag = SerializedDAG.from_dict(serialized_dag)
deserialized_task = deserialized_dag.get_task("mapped_task")
# Deserialize and check that client_defaults are applied
deserialized_dag = SerializedDAG.from_dict(serialized_dag)
deserialized_task = deserialized_dag.get_task("mapped_task")

# Verify it's still a MappedOperator
from airflow.models.mappedoperator import MappedOperator as SchedulerMappedOperator
# Verify it's still a MappedOperator
from airflow.models.mappedoperator import MappedOperator as SchedulerMappedOperator

assert isinstance(deserialized_task, SchedulerMappedOperator)
assert isinstance(deserialized_task, SchedulerMappedOperator)

# Check that client_defaults values are applied (e.g., retry_delay from client_defaults)
client_defaults = serialized_dag["client_defaults"]["tasks"]
if "retry_delay" in client_defaults:
# If retry_delay wasn't explicitly set, it should come from client_defaults
# Since we can't easily convert timedelta back, check the serialized format
assert hasattr(deserialized_task, "retry_delay")
# Check that client_defaults values are applied (e.g., retry_delay from client_defaults)
client_defaults = serialized_dag["client_defaults"]["tasks"]
if "retry_delay" in client_defaults:
# If retry_delay wasn't explicitly set, it should come from client_defaults
# Since we can't easily convert timedelta back, check the serialized format
assert hasattr(deserialized_task, "retry_delay")

# Explicit values should override client_defaults
assert deserialized_task.retries == 5 # Explicitly set value
# Explicit values should override client_defaults
assert deserialized_task.retries == 5 # Explicitly set value

@pytest.mark.parametrize(
("task_config", "dag_id", "task_id", "non_default_fields"),
Expand Down Expand Up @@ -4208,45 +4236,45 @@ def test_mapped_operator_client_defaults_application(self):
),
],
)
@operator_defaults({"retry_delay": 200.0})
def test_mapped_operator_client_defaults_optimization(
self, task_config, dag_id, task_id, non_default_fields
self, task_config, dag_id, task_id, non_default_fields, operator_defaults
):
"""Test that MappedOperator serialization optimizes using client defaults."""
with DAG(dag_id=dag_id) as dag:
# Create mapped operator with specified configuration
BashOperator.partial(
task_id=task_id,
**task_config,
).expand(bash_command=["echo 1", "echo 2", "echo 3"])
with operator_defaults({"retry_delay": 200.0}):
with DAG(dag_id=dag_id) as dag:
# Create mapped operator with specified configuration
BashOperator.partial(
task_id=task_id,
**task_config,
).expand(bash_command=["echo 1", "echo 2", "echo 3"])

serialized_dag = SerializedDAG.to_dict(dag)
mapped_task_serialized = serialized_dag["dag"]["tasks"][0]["__var"]
serialized_dag = SerializedDAG.to_dict(dag)
mapped_task_serialized = serialized_dag["dag"]["tasks"][0]["__var"]

assert mapped_task_serialized is not None
assert mapped_task_serialized.get("_is_mapped") is True
assert mapped_task_serialized is not None
assert mapped_task_serialized.get("_is_mapped") is True

# Check optimization behavior
client_defaults = serialized_dag["client_defaults"]["tasks"]
partial_kwargs = mapped_task_serialized["partial_kwargs"]
# Check optimization behavior
client_defaults = serialized_dag["client_defaults"]["tasks"]
partial_kwargs = mapped_task_serialized["partial_kwargs"]

# Check that all fields are optimized correctly
for field, default_value in client_defaults.items():
if field in non_default_fields:
# Non-default fields should be present in partial_kwargs
assert field in partial_kwargs, (
f"Field '{field}' should be in partial_kwargs as it's non-default"
)
# And have different values than defaults
assert partial_kwargs[field] != default_value, (
f"Field '{field}' should have non-default value"
)
else:
# Default fields should either not be present or have different values if present
if field in partial_kwargs:
# Check that all fields are optimized correctly
for field, default_value in client_defaults.items():
if field in non_default_fields:
# Non-default fields should be present in partial_kwargs
assert field in partial_kwargs, (
f"Field '{field}' should be in partial_kwargs as it's non-default"
)
# And have different values than defaults
assert partial_kwargs[field] != default_value, (
f"Field '{field}' with default value should be optimized out"
f"Field '{field}' should have non-default value"
)
else:
# Default fields should either not be present or have different values if present
if field in partial_kwargs:
assert partial_kwargs[field] != default_value, (
f"Field '{field}' with default value should be optimized out"
)

def test_mapped_operator_expand_input_preservation(self):
"""Test that expand_input is correctly preserved during serialization."""
Expand Down