Skip to content

Commit 9f1f4c3

Browse files
github-actions[bot] and amoghrajesh
authored and committed
[v3-1-test] Respect default_args in DAG when its set to a "falsy" value (#57853) (#58396)
(cherry picked from commit c03fc79) Co-authored-by: Amogh Desai <amoghrajesh1999@gmail.com>
1 parent fc46b72 commit 9f1f4c3

File tree

2 files changed

+181
-106
lines changed

2 files changed

+181
-106
lines changed

airflow-core/src/airflow/serialization/serialized_objects.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1236,6 +1236,8 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
12361236

12371237
_json_schema: ClassVar[Validator] = lazy_object_proxy.Proxy(load_dag_schema)
12381238

1239+
_const_fields: ClassVar[set[str] | None] = None
1240+
12391241
_can_skip_downstream: bool
12401242
_is_empty: bool
12411243
_needs_expansion: bool
@@ -1713,6 +1715,22 @@ def set_task_dag_references(task: SerializedOperator | MappedOperator, dag: Seri
17131715
# Bypass set_upstream etc here - it does more than we want
17141716
dag.task_dict[task_id].upstream_task_ids.add(task.task_id)
17151717

1718+
@classmethod
1719+
def get_operator_const_fields(cls) -> set[str]:
1720+
"""Get the set of operator fields that are marked as const in the JSON schema."""
1721+
if (schema_loader := cls._json_schema) is None:
1722+
return set()
1723+
1724+
schema_data = schema_loader.schema
1725+
operator_def = schema_data.get("definitions", {}).get("operator", {})
1726+
properties = operator_def.get("properties", {})
1727+
1728+
return {
1729+
field_name
1730+
for field_name, field_def in properties.items()
1731+
if isinstance(field_def, dict) and field_def.get("const")
1732+
}
1733+
17161734
@classmethod
17171735
@lru_cache(maxsize=1) # Only one type: "operator"
17181736
def get_operator_optional_fields_from_schema(cls) -> set[str]:
@@ -1868,10 +1886,39 @@ def _is_excluded(cls, var: Any, attrname: str, op: DAGNode):
18681886
# Check if value matches client_defaults (hierarchical defaults optimization)
18691887
if cls._matches_client_defaults(var, attrname):
18701888
return True
1871-
schema_defaults = cls.get_schema_defaults("operator")
18721889

1890+
# for const fields, we should always be excluded when False, regardless of client_defaults
1891+
# Use class-level cache for optimisation
1892+
if cls._const_fields is None:
1893+
cls._const_fields = cls.get_operator_const_fields()
1894+
if attrname in cls._const_fields and var is False:
1895+
return True
1896+
1897+
schema_defaults = cls.get_schema_defaults("operator")
18731898
if attrname in schema_defaults:
18741899
if schema_defaults[attrname] == var:
1900+
# If it also matches client_defaults, exclude (optimization)
1901+
client_defaults = cls.generate_client_defaults()
1902+
if attrname in client_defaults:
1903+
if client_defaults[attrname] == var:
1904+
return True
1905+
# If client_defaults differs, preserve explicit override from user
1906+
# Example: default_args={"retries": 0}, schema default=0, client_defaults={"retries": 3}
1907+
if client_defaults[attrname] != var:
1908+
if op.has_dag():
1909+
dag = op.dag
1910+
if dag and attrname in dag.default_args and dag.default_args[attrname] == var:
1911+
return False
1912+
if (
1913+
hasattr(op, "_BaseOperator__init_kwargs")
1914+
and attrname in op._BaseOperator__init_kwargs
1915+
and op._BaseOperator__init_kwargs[attrname] == var
1916+
):
1917+
return False
1918+
1919+
# If client_defaults doesn't have this field (matches schema default),
1920+
# exclude for optimization even if in default_args
1921+
# Example: default_args={"depends_on_past": False}, schema default=False
18751922
return True
18761923
optional_fields = cls.get_operator_optional_fields_from_schema()
18771924
if var is None:

airflow-core/tests/unit/serialization/test_dag_serialization.py

Lines changed: 133 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -101,40 +101,41 @@
101101
from airflow.sdk.definitions.context import Context
102102

103103

104-
@contextlib.contextmanager
105-
def operator_defaults(overrides):
104+
@pytest.fixture
105+
def operator_defaults(monkeypatch):
106106
"""
107-
Temporarily patches OPERATOR_DEFAULTS, restoring original values after context exit.
107+
Fixture that provides a context manager to temporarily patch OPERATOR_DEFAULTS.
108108
109-
Example:
110-
with operator_defaults({"retries": 2, "retry_delay": 200.0}):
111-
# Test code with modified operator defaults
109+
Usage:
110+
def test_something(operator_defaults):
111+
with operator_defaults({"retries": 2, "retry_delay": 200.0}):
112+
# Test code with modified operator defaults
112113
"""
114+
import airflow.sdk.definitions._internal.abstractoperator as abstract_op_module
113115
from airflow.sdk.bases.operator import OPERATOR_DEFAULTS
116+
from airflow.serialization.serialized_objects import SerializedBaseOperator
114117

115-
original_values = {}
116-
try:
117-
# Store original values and apply overrides
118+
@contextlib.contextmanager
119+
def _operator_defaults(overrides):
120+
# Patch OPERATOR_DEFAULTS
118121
for key, value in overrides.items():
119-
original_values[key] = OPERATOR_DEFAULTS.get(key)
120-
OPERATOR_DEFAULTS[key] = value
122+
monkeypatch.setitem(OPERATOR_DEFAULTS, key, value)
123+
124+
# Patch module-level constants
125+
const_name = f"DEFAULT_{key.upper()}"
126+
if hasattr(abstract_op_module, const_name):
127+
monkeypatch.setattr(abstract_op_module, const_name, value)
121128

122129
# Clear the cache to ensure fresh generation
123130
SerializedBaseOperator.generate_client_defaults.cache_clear()
124131

125-
yield
126-
finally:
127-
# Cleanup: restore original values
128-
for key, original_value in original_values.items():
129-
if original_value is None and key in OPERATOR_DEFAULTS:
130-
# Key didn't exist originally, remove it
131-
del OPERATOR_DEFAULTS[key]
132-
else:
133-
# Restore original value
134-
OPERATOR_DEFAULTS[key] = original_value
132+
try:
133+
yield
134+
finally:
135+
# Clear cache again to restore normal behavior
136+
SerializedBaseOperator.generate_client_defaults.cache_clear()
135137

136-
# Clear cache again to restore normal behavior
137-
SerializedBaseOperator.generate_client_defaults.cache_clear()
138+
return _operator_defaults
138139

139140

140141
AIRFLOW_REPO_ROOT_PATH = Path(airflow.__file__).parents[3]
@@ -4045,77 +4046,104 @@ def test_apply_defaults_to_encoded_op_none_inputs(self):
40454046
result = SerializedBaseOperator._apply_defaults_to_encoded_op(encoded_op, None)
40464047
assert result == encoded_op
40474048

4048-
@operator_defaults({"retries": 2})
4049-
def test_multiple_tasks_share_client_defaults(self):
4049+
def test_multiple_tasks_share_client_defaults(self, operator_defaults):
40504050
"""Test that multiple tasks can share the same client_defaults when there are actually non-default values."""
4051-
with DAG(dag_id="test_dag") as dag:
4052-
BashOperator(task_id="task1", bash_command="echo 1")
4053-
BashOperator(task_id="task2", bash_command="echo 2")
4051+
with operator_defaults({"retries": 2}):
4052+
with DAG(dag_id="test_dag") as dag:
4053+
BashOperator(task_id="task1", bash_command="echo 1")
4054+
BashOperator(task_id="task2", bash_command="echo 2")
40544055

4055-
serialized = SerializedDAG.to_dict(dag)
4056+
serialized = SerializedDAG.to_dict(dag)
40564057

4057-
# Should have one client_defaults section for all tasks
4058-
assert "client_defaults" in serialized
4059-
assert "tasks" in serialized["client_defaults"]
4058+
# Should have one client_defaults section for all tasks
4059+
assert "client_defaults" in serialized
4060+
assert "tasks" in serialized["client_defaults"]
40604061

4061-
# All tasks should benefit from the same client_defaults
4062-
client_defaults = serialized["client_defaults"]["tasks"]
4062+
# All tasks should benefit from the same client_defaults
4063+
client_defaults = serialized["client_defaults"]["tasks"]
40634064

4064-
# Deserialize and check both tasks get the defaults
4065-
deserialized_dag = SerializedDAG.from_dict(serialized)
4066-
deserialized_task1 = deserialized_dag.get_task("task1")
4067-
deserialized_task2 = deserialized_dag.get_task("task2")
4065+
# Deserialize and check both tasks get the defaults
4066+
deserialized_dag = SerializedDAG.from_dict(serialized)
4067+
deserialized_task1 = deserialized_dag.get_task("task1")
4068+
deserialized_task2 = deserialized_dag.get_task("task2")
40684069

4069-
# Both tasks should have retries=2 from client_defaults
4070-
assert deserialized_task1.retries == 2
4071-
assert deserialized_task2.retries == 2
4070+
# Both tasks should have retries=2 from client_defaults
4071+
assert deserialized_task1.retries == 2
4072+
assert deserialized_task2.retries == 2
40724073

4073-
# Both tasks should have the same default values from client_defaults
4074-
for field in client_defaults:
4075-
if hasattr(deserialized_task1, field) and hasattr(deserialized_task2, field):
4076-
value1 = getattr(deserialized_task1, field)
4077-
value2 = getattr(deserialized_task2, field)
4078-
assert value1 == value2, f"Tasks have different values for {field}: {value1} vs {value2}"
4074+
# Both tasks should have the same default values from client_defaults
4075+
for field in client_defaults:
4076+
if hasattr(deserialized_task1, field) and hasattr(deserialized_task2, field):
4077+
value1 = getattr(deserialized_task1, field)
4078+
value2 = getattr(deserialized_task2, field)
4079+
assert value1 == value2, f"Tasks have different values for {field}: {value1} vs {value2}"
4080+
4081+
def test_default_args_when_equal_to_schema_defaults(self, operator_defaults):
4082+
"""Test that explicitly set values matching schema defaults are preserved when client_defaults differ."""
4083+
with operator_defaults({"retries": 3}):
4084+
with DAG(dag_id="test_explicit_schema_default", default_args={"retries": 0}) as dag:
4085+
BashOperator(task_id="task1", bash_command="echo 1")
4086+
BashOperator(task_id="task2", bash_command="echo 1", retries=2)
4087+
4088+
serialized = SerializedDAG.to_dict(dag)
4089+
4090+
# verify client_defaults has retries=3
4091+
assert "client_defaults" in serialized
4092+
assert "tasks" in serialized["client_defaults"]
4093+
client_defaults = serialized["client_defaults"]["tasks"]
4094+
assert client_defaults["retries"] == 3
4095+
4096+
task1_data = serialized["dag"]["tasks"][0]["__var"]
4097+
assert task1_data.get("retries", -1) == 0
4098+
4099+
task2_data = serialized["dag"]["tasks"][1]["__var"]
4100+
assert task2_data.get("retries", -1) == 2
4101+
4102+
deserialized_task1 = SerializedDAG.from_dict(serialized).get_task("task1")
4103+
assert deserialized_task1.retries == 0
4104+
4105+
deserialized_task2 = SerializedDAG.from_dict(serialized).get_task("task2")
4106+
assert deserialized_task2.retries == 2
40794107

40804108

40814109
class TestMappedOperatorSerializationAndClientDefaults:
40824110
"""Test MappedOperator serialization with client defaults and callback properties."""
40834111

4084-
@operator_defaults({"retry_delay": 200.0})
4085-
def test_mapped_operator_client_defaults_application(self):
4112+
def test_mapped_operator_client_defaults_application(self, operator_defaults):
40864113
"""Test that client_defaults are correctly applied to MappedOperator during deserialization."""
4087-
with DAG(dag_id="test_mapped_dag") as dag:
4088-
# Create a mapped operator
4089-
BashOperator.partial(
4090-
task_id="mapped_task",
4091-
retries=5, # Override default
4092-
).expand(bash_command=["echo 1", "echo 2", "echo 3"])
4093-
4094-
# Serialize the DAG
4095-
serialized_dag = SerializedDAG.to_dict(dag)
4114+
with operator_defaults({"retry_delay": 200.0}):
4115+
with DAG(dag_id="test_mapped_dag") as dag:
4116+
# Create a mapped operator
4117+
BashOperator.partial(
4118+
task_id="mapped_task",
4119+
retries=5, # Override default
4120+
).expand(bash_command=["echo 1", "echo 2", "echo 3"])
4121+
4122+
# Serialize the DAG
4123+
serialized_dag = SerializedDAG.to_dict(dag)
40964124

4097-
# Should have client_defaults section
4098-
assert "client_defaults" in serialized_dag
4099-
assert "tasks" in serialized_dag["client_defaults"]
4125+
# Should have client_defaults section
4126+
assert "client_defaults" in serialized_dag
4127+
assert "tasks" in serialized_dag["client_defaults"]
41004128

4101-
# Deserialize and check that client_defaults are applied
4102-
deserialized_dag = SerializedDAG.from_dict(serialized_dag)
4103-
deserialized_task = deserialized_dag.get_task("mapped_task")
4129+
# Deserialize and check that client_defaults are applied
4130+
deserialized_dag = SerializedDAG.from_dict(serialized_dag)
4131+
deserialized_task = deserialized_dag.get_task("mapped_task")
41044132

4105-
# Verify it's still a MappedOperator
4106-
from airflow.models.mappedoperator import MappedOperator as SchedulerMappedOperator
4133+
# Verify it's still a MappedOperator
4134+
from airflow.models.mappedoperator import MappedOperator as SchedulerMappedOperator
41074135

4108-
assert isinstance(deserialized_task, SchedulerMappedOperator)
4136+
assert isinstance(deserialized_task, SchedulerMappedOperator)
41094137

4110-
# Check that client_defaults values are applied (e.g., retry_delay from client_defaults)
4111-
client_defaults = serialized_dag["client_defaults"]["tasks"]
4112-
if "retry_delay" in client_defaults:
4113-
# If retry_delay wasn't explicitly set, it should come from client_defaults
4114-
# Since we can't easily convert timedelta back, check the serialized format
4115-
assert hasattr(deserialized_task, "retry_delay")
4138+
# Check that client_defaults values are applied (e.g., retry_delay from client_defaults)
4139+
client_defaults = serialized_dag["client_defaults"]["tasks"]
4140+
if "retry_delay" in client_defaults:
4141+
# If retry_delay wasn't explicitly set, it should come from client_defaults
4142+
# Since we can't easily convert timedelta back, check the serialized format
4143+
assert hasattr(deserialized_task, "retry_delay")
41164144

4117-
# Explicit values should override client_defaults
4118-
assert deserialized_task.retries == 5 # Explicitly set value
4145+
# Explicit values should override client_defaults
4146+
assert deserialized_task.retries == 5 # Explicitly set value
41194147

41204148
@pytest.mark.parametrize(
41214149
["task_config", "dag_id", "task_id", "non_default_fields"],
@@ -4146,45 +4174,45 @@ def test_mapped_operator_client_defaults_application(self):
41464174
),
41474175
],
41484176
)
4149-
@operator_defaults({"retry_delay": 200.0})
41504177
def test_mapped_operator_client_defaults_optimization(
4151-
self, task_config, dag_id, task_id, non_default_fields
4178+
self, task_config, dag_id, task_id, non_default_fields, operator_defaults
41524179
):
41534180
"""Test that MappedOperator serialization optimizes using client defaults."""
4154-
with DAG(dag_id=dag_id) as dag:
4155-
# Create mapped operator with specified configuration
4156-
BashOperator.partial(
4157-
task_id=task_id,
4158-
**task_config,
4159-
).expand(bash_command=["echo 1", "echo 2", "echo 3"])
4181+
with operator_defaults({"retry_delay": 200.0}):
4182+
with DAG(dag_id=dag_id) as dag:
4183+
# Create mapped operator with specified configuration
4184+
BashOperator.partial(
4185+
task_id=task_id,
4186+
**task_config,
4187+
).expand(bash_command=["echo 1", "echo 2", "echo 3"])
41604188

4161-
serialized_dag = SerializedDAG.to_dict(dag)
4162-
mapped_task_serialized = serialized_dag["dag"]["tasks"][0]["__var"]
4189+
serialized_dag = SerializedDAG.to_dict(dag)
4190+
mapped_task_serialized = serialized_dag["dag"]["tasks"][0]["__var"]
41634191

4164-
assert mapped_task_serialized is not None
4165-
assert mapped_task_serialized.get("_is_mapped") is True
4192+
assert mapped_task_serialized is not None
4193+
assert mapped_task_serialized.get("_is_mapped") is True
41664194

4167-
# Check optimization behavior
4168-
client_defaults = serialized_dag["client_defaults"]["tasks"]
4169-
partial_kwargs = mapped_task_serialized["partial_kwargs"]
4195+
# Check optimization behavior
4196+
client_defaults = serialized_dag["client_defaults"]["tasks"]
4197+
partial_kwargs = mapped_task_serialized["partial_kwargs"]
41704198

4171-
# Check that all fields are optimized correctly
4172-
for field, default_value in client_defaults.items():
4173-
if field in non_default_fields:
4174-
# Non-default fields should be present in partial_kwargs
4175-
assert field in partial_kwargs, (
4176-
f"Field '{field}' should be in partial_kwargs as it's non-default"
4177-
)
4178-
# And have different values than defaults
4179-
assert partial_kwargs[field] != default_value, (
4180-
f"Field '{field}' should have non-default value"
4181-
)
4182-
else:
4183-
# Default fields should either not be present or have different values if present
4184-
if field in partial_kwargs:
4199+
# Check that all fields are optimized correctly
4200+
for field, default_value in client_defaults.items():
4201+
if field in non_default_fields:
4202+
# Non-default fields should be present in partial_kwargs
4203+
assert field in partial_kwargs, (
4204+
f"Field '{field}' should be in partial_kwargs as it's non-default"
4205+
)
4206+
# And have different values than defaults
41854207
assert partial_kwargs[field] != default_value, (
4186-
f"Field '{field}' with default value should be optimized out"
4208+
f"Field '{field}' should have non-default value"
41874209
)
4210+
else:
4211+
# Default fields should either not be present or have different values if present
4212+
if field in partial_kwargs:
4213+
assert partial_kwargs[field] != default_value, (
4214+
f"Field '{field}' with default value should be optimized out"
4215+
)
41884216

41894217
def test_mapped_operator_expand_input_preservation(self):
41904218
"""Test that expand_input is correctly preserved during serialization."""

0 commit comments

Comments (0)