Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion airflow-core/src/airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,6 +1236,8 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):

_json_schema: ClassVar[Validator] = lazy_object_proxy.Proxy(load_dag_schema)

_const_fields: ClassVar[set[str] | None] = None

_can_skip_downstream: bool
_is_empty: bool
_needs_expansion: bool
Expand Down Expand Up @@ -1713,6 +1715,22 @@ def set_task_dag_references(task: SerializedOperator | MappedOperator, dag: Seri
# Bypass set_upstream etc here - it does more than we want
dag.task_dict[task_id].upstream_task_ids.add(task.task_id)

@classmethod
def get_operator_const_fields(cls) -> set[str]:
    """Return the operator field names whose JSON-schema definition has a truthy ``const``.

    Walks ``definitions.operator.properties`` of the class-level JSON schema
    (``cls._json_schema``). If the schema proxy is ``None`` (schema not
    loaded), an empty set is returned.

    NOTE(review): only *truthy* ``const`` values are collected, so a property
    declared ``const: false`` (or ``const: 0``) would not be reported — the
    known schema consts are booleans set to ``true``, so this matches intent.
    """
    schema_loader = cls._json_schema
    if schema_loader is None:
        return set()

    operator_properties = (
        schema_loader.schema.get("definitions", {}).get("operator", {}).get("properties", {})
    )

    const_fields: set[str] = set()
    for field_name, field_def in operator_properties.items():
        # Property definitions may be bare booleans in JSON schema; only
        # dict-shaped definitions can carry a "const" marker.
        if isinstance(field_def, dict) and field_def.get("const"):
            const_fields.add(field_name)
    return const_fields

@classmethod
@lru_cache(maxsize=1) # Only one type: "operator"
def get_operator_optional_fields_from_schema(cls) -> set[str]:
Expand Down Expand Up @@ -1868,10 +1886,39 @@ def _is_excluded(cls, var: Any, attrname: str, op: DAGNode):
# Check if value matches client_defaults (hierarchical defaults optimization)
if cls._matches_client_defaults(var, attrname):
return True
schema_defaults = cls.get_schema_defaults("operator")

# for const fields, we should always be excluded when False, regardless of client_defaults
# Use class-level cache for optimisation
if cls._const_fields is None:
cls._const_fields = cls.get_operator_const_fields()
if attrname in cls._const_fields and var is False:
return True

schema_defaults = cls.get_schema_defaults("operator")
if attrname in schema_defaults:
if schema_defaults[attrname] == var:
# If it also matches client_defaults, exclude (optimization)
client_defaults = cls.generate_client_defaults()
if attrname in client_defaults:
if client_defaults[attrname] == var:
return True
# If client_defaults differs, preserve explicit override from user
# Example: default_args={"retries": 0}, schema default=0, client_defaults={"retries": 3}
if client_defaults[attrname] != var:
if op.has_dag():
dag = op.dag
if dag and attrname in dag.default_args and dag.default_args[attrname] == var:
return False
if (
hasattr(op, "_BaseOperator__init_kwargs")
and attrname in op._BaseOperator__init_kwargs
and op._BaseOperator__init_kwargs[attrname] == var
):
return False

# If client_defaults doesn't have this field (matches schema default),
# exclude for optimization even if in default_args
# Example: default_args={"depends_on_past": False}, schema default=False
return True
optional_fields = cls.get_operator_optional_fields_from_schema()
if var is None:
Expand Down
238 changes: 133 additions & 105 deletions airflow-core/tests/unit/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,40 +101,41 @@
from airflow.sdk.definitions.context import Context


@contextlib.contextmanager
def operator_defaults(overrides):
@pytest.fixture
def operator_defaults(monkeypatch):
"""
Temporarily patches OPERATOR_DEFAULTS, restoring original values after context exit.
Fixture that provides a context manager to temporarily patch OPERATOR_DEFAULTS.

Example:
with operator_defaults({"retries": 2, "retry_delay": 200.0}):
# Test code with modified operator defaults
Usage:
def test_something(operator_defaults):
with operator_defaults({"retries": 2, "retry_delay": 200.0}):
# Test code with modified operator defaults
"""
import airflow.sdk.definitions._internal.abstractoperator as abstract_op_module
from airflow.sdk.bases.operator import OPERATOR_DEFAULTS
from airflow.serialization.serialized_objects import SerializedBaseOperator

original_values = {}
try:
# Store original values and apply overrides
@contextlib.contextmanager
def _operator_defaults(overrides):
# Patch OPERATOR_DEFAULTS
for key, value in overrides.items():
original_values[key] = OPERATOR_DEFAULTS.get(key)
OPERATOR_DEFAULTS[key] = value
monkeypatch.setitem(OPERATOR_DEFAULTS, key, value)

# Patch module-level constants
const_name = f"DEFAULT_{key.upper()}"
if hasattr(abstract_op_module, const_name):
monkeypatch.setattr(abstract_op_module, const_name, value)

# Clear the cache to ensure fresh generation
SerializedBaseOperator.generate_client_defaults.cache_clear()

yield
finally:
# Cleanup: restore original values
for key, original_value in original_values.items():
if original_value is None and key in OPERATOR_DEFAULTS:
# Key didn't exist originally, remove it
del OPERATOR_DEFAULTS[key]
else:
# Restore original value
OPERATOR_DEFAULTS[key] = original_value
try:
yield
finally:
# Clear cache again to restore normal behavior
SerializedBaseOperator.generate_client_defaults.cache_clear()

# Clear cache again to restore normal behavior
SerializedBaseOperator.generate_client_defaults.cache_clear()
return _operator_defaults


AIRFLOW_REPO_ROOT_PATH = Path(airflow.__file__).parents[3]
Expand Down Expand Up @@ -4045,77 +4046,104 @@ def test_apply_defaults_to_encoded_op_none_inputs(self):
result = SerializedBaseOperator._apply_defaults_to_encoded_op(encoded_op, None)
assert result == encoded_op

@operator_defaults({"retries": 2})
def test_multiple_tasks_share_client_defaults(self):
def test_multiple_tasks_share_client_defaults(self, operator_defaults):
"""Test that multiple tasks can share the same client_defaults when there are actually non-default values."""
with DAG(dag_id="test_dag") as dag:
BashOperator(task_id="task1", bash_command="echo 1")
BashOperator(task_id="task2", bash_command="echo 2")
with operator_defaults({"retries": 2}):
with DAG(dag_id="test_dag") as dag:
BashOperator(task_id="task1", bash_command="echo 1")
BashOperator(task_id="task2", bash_command="echo 2")

serialized = SerializedDAG.to_dict(dag)
serialized = SerializedDAG.to_dict(dag)

# Should have one client_defaults section for all tasks
assert "client_defaults" in serialized
assert "tasks" in serialized["client_defaults"]
# Should have one client_defaults section for all tasks
assert "client_defaults" in serialized
assert "tasks" in serialized["client_defaults"]

# All tasks should benefit from the same client_defaults
client_defaults = serialized["client_defaults"]["tasks"]
# All tasks should benefit from the same client_defaults
client_defaults = serialized["client_defaults"]["tasks"]

# Deserialize and check both tasks get the defaults
deserialized_dag = SerializedDAG.from_dict(serialized)
deserialized_task1 = deserialized_dag.get_task("task1")
deserialized_task2 = deserialized_dag.get_task("task2")
# Deserialize and check both tasks get the defaults
deserialized_dag = SerializedDAG.from_dict(serialized)
deserialized_task1 = deserialized_dag.get_task("task1")
deserialized_task2 = deserialized_dag.get_task("task2")

# Both tasks should have retries=2 from client_defaults
assert deserialized_task1.retries == 2
assert deserialized_task2.retries == 2
# Both tasks should have retries=2 from client_defaults
assert deserialized_task1.retries == 2
assert deserialized_task2.retries == 2

# Both tasks should have the same default values from client_defaults
for field in client_defaults:
if hasattr(deserialized_task1, field) and hasattr(deserialized_task2, field):
value1 = getattr(deserialized_task1, field)
value2 = getattr(deserialized_task2, field)
assert value1 == value2, f"Tasks have different values for {field}: {value1} vs {value2}"
# Both tasks should have the same default values from client_defaults
for field in client_defaults:
if hasattr(deserialized_task1, field) and hasattr(deserialized_task2, field):
value1 = getattr(deserialized_task1, field)
value2 = getattr(deserialized_task2, field)
assert value1 == value2, f"Tasks have different values for {field}: {value1} vs {value2}"

def test_default_args_when_equal_to_schema_defaults(self, operator_defaults):
    """An explicit default_args value equal to the schema default survives serialization.

    With client_defaults patched to retries=3, a DAG-level default_args of
    retries=0 (which equals the schema default) must still be written into
    the task's serialized form rather than optimized away, and an explicit
    per-task retries=2 must win over both.
    """
    with operator_defaults({"retries": 3}):
        with DAG(dag_id="test_explicit_schema_default", default_args={"retries": 0}) as dag:
            BashOperator(task_id="task1", bash_command="echo 1")
            BashOperator(task_id="task2", bash_command="echo 1", retries=2)

        serialized = SerializedDAG.to_dict(dag)

        # The shared client_defaults section must reflect the patched retries=3.
        assert "client_defaults" in serialized
        assert "tasks" in serialized["client_defaults"]
        assert serialized["client_defaults"]["tasks"]["retries"] == 3

        # Each task's serialized payload keeps its own explicit retries value.
        task_payloads = [task["__var"] for task in serialized["dag"]["tasks"]]
        assert task_payloads[0].get("retries", -1) == 0
        assert task_payloads[1].get("retries", -1) == 2

        # Round-trip: deserialized tasks report the explicit values, not the
        # client_defaults value.
        round_tripped = SerializedDAG.from_dict(serialized)
        assert round_tripped.get_task("task1").retries == 0
        assert round_tripped.get_task("task2").retries == 2


class TestMappedOperatorSerializationAndClientDefaults:
"""Test MappedOperator serialization with client defaults and callback properties."""

@operator_defaults({"retry_delay": 200.0})
def test_mapped_operator_client_defaults_application(self):
def test_mapped_operator_client_defaults_application(self, operator_defaults):
"""Test that client_defaults are correctly applied to MappedOperator during deserialization."""
with DAG(dag_id="test_mapped_dag") as dag:
# Create a mapped operator
BashOperator.partial(
task_id="mapped_task",
retries=5, # Override default
).expand(bash_command=["echo 1", "echo 2", "echo 3"])

# Serialize the DAG
serialized_dag = SerializedDAG.to_dict(dag)
with operator_defaults({"retry_delay": 200.0}):
with DAG(dag_id="test_mapped_dag") as dag:
# Create a mapped operator
BashOperator.partial(
task_id="mapped_task",
retries=5, # Override default
).expand(bash_command=["echo 1", "echo 2", "echo 3"])

# Serialize the DAG
serialized_dag = SerializedDAG.to_dict(dag)

# Should have client_defaults section
assert "client_defaults" in serialized_dag
assert "tasks" in serialized_dag["client_defaults"]
# Should have client_defaults section
assert "client_defaults" in serialized_dag
assert "tasks" in serialized_dag["client_defaults"]

# Deserialize and check that client_defaults are applied
deserialized_dag = SerializedDAG.from_dict(serialized_dag)
deserialized_task = deserialized_dag.get_task("mapped_task")
# Deserialize and check that client_defaults are applied
deserialized_dag = SerializedDAG.from_dict(serialized_dag)
deserialized_task = deserialized_dag.get_task("mapped_task")

# Verify it's still a MappedOperator
from airflow.models.mappedoperator import MappedOperator as SchedulerMappedOperator
# Verify it's still a MappedOperator
from airflow.models.mappedoperator import MappedOperator as SchedulerMappedOperator

assert isinstance(deserialized_task, SchedulerMappedOperator)
assert isinstance(deserialized_task, SchedulerMappedOperator)

# Check that client_defaults values are applied (e.g., retry_delay from client_defaults)
client_defaults = serialized_dag["client_defaults"]["tasks"]
if "retry_delay" in client_defaults:
# If retry_delay wasn't explicitly set, it should come from client_defaults
# Since we can't easily convert timedelta back, check the serialized format
assert hasattr(deserialized_task, "retry_delay")
# Check that client_defaults values are applied (e.g., retry_delay from client_defaults)
client_defaults = serialized_dag["client_defaults"]["tasks"]
if "retry_delay" in client_defaults:
# If retry_delay wasn't explicitly set, it should come from client_defaults
# Since we can't easily convert timedelta back, check the serialized format
assert hasattr(deserialized_task, "retry_delay")

# Explicit values should override client_defaults
assert deserialized_task.retries == 5 # Explicitly set value
# Explicit values should override client_defaults
assert deserialized_task.retries == 5 # Explicitly set value

@pytest.mark.parametrize(
["task_config", "dag_id", "task_id", "non_default_fields"],
Expand Down Expand Up @@ -4146,45 +4174,45 @@ def test_mapped_operator_client_defaults_application(self):
),
],
)
@operator_defaults({"retry_delay": 200.0})
def test_mapped_operator_client_defaults_optimization(
self, task_config, dag_id, task_id, non_default_fields
self, task_config, dag_id, task_id, non_default_fields, operator_defaults
):
"""Test that MappedOperator serialization optimizes using client defaults."""
with DAG(dag_id=dag_id) as dag:
# Create mapped operator with specified configuration
BashOperator.partial(
task_id=task_id,
**task_config,
).expand(bash_command=["echo 1", "echo 2", "echo 3"])
with operator_defaults({"retry_delay": 200.0}):
with DAG(dag_id=dag_id) as dag:
# Create mapped operator with specified configuration
BashOperator.partial(
task_id=task_id,
**task_config,
).expand(bash_command=["echo 1", "echo 2", "echo 3"])

serialized_dag = SerializedDAG.to_dict(dag)
mapped_task_serialized = serialized_dag["dag"]["tasks"][0]["__var"]
serialized_dag = SerializedDAG.to_dict(dag)
mapped_task_serialized = serialized_dag["dag"]["tasks"][0]["__var"]

assert mapped_task_serialized is not None
assert mapped_task_serialized.get("_is_mapped") is True
assert mapped_task_serialized is not None
assert mapped_task_serialized.get("_is_mapped") is True

# Check optimization behavior
client_defaults = serialized_dag["client_defaults"]["tasks"]
partial_kwargs = mapped_task_serialized["partial_kwargs"]
# Check optimization behavior
client_defaults = serialized_dag["client_defaults"]["tasks"]
partial_kwargs = mapped_task_serialized["partial_kwargs"]

# Check that all fields are optimized correctly
for field, default_value in client_defaults.items():
if field in non_default_fields:
# Non-default fields should be present in partial_kwargs
assert field in partial_kwargs, (
f"Field '{field}' should be in partial_kwargs as it's non-default"
)
# And have different values than defaults
assert partial_kwargs[field] != default_value, (
f"Field '{field}' should have non-default value"
)
else:
# Default fields should either not be present or have different values if present
if field in partial_kwargs:
# Check that all fields are optimized correctly
for field, default_value in client_defaults.items():
if field in non_default_fields:
# Non-default fields should be present in partial_kwargs
assert field in partial_kwargs, (
f"Field '{field}' should be in partial_kwargs as it's non-default"
)
# And have different values than defaults
assert partial_kwargs[field] != default_value, (
f"Field '{field}' with default value should be optimized out"
f"Field '{field}' should have non-default value"
)
else:
# Default fields should either not be present or have different values if present
if field in partial_kwargs:
assert partial_kwargs[field] != default_value, (
f"Field '{field}' with default value should be optimized out"
)

def test_mapped_operator_expand_input_preservation(self):
"""Test that expand_input is correctly preserved during serialization."""
Expand Down