|
101 | 101 | from airflow.sdk.definitions.context import Context |
102 | 102 |
|
103 | 103 |
|
104 | | -@contextlib.contextmanager |
105 | | -def operator_defaults(overrides): |
| 104 | +@pytest.fixture |
| 105 | +def operator_defaults(monkeypatch): |
106 | 106 | """ |
107 | | - Temporarily patches OPERATOR_DEFAULTS, restoring original values after context exit. |
| 107 | + Fixture that provides a context manager to temporarily patch OPERATOR_DEFAULTS, any matching DEFAULT_* module constants, and the client-defaults cache. |
108 | 108 |
|
109 | | - Example: |
110 | | - with operator_defaults({"retries": 2, "retry_delay": 200.0}): |
111 | | - # Test code with modified operator defaults |
| 109 | + Usage: |
| 110 | + def test_something(operator_defaults): |
| 111 | + with operator_defaults({"retries": 2, "retry_delay": 200.0}): |
| 112 | + # Test code with modified operator defaults |
112 | 113 | """ |
| 114 | + import airflow.sdk.definitions._internal.abstractoperator as abstract_op_module |
113 | 115 | from airflow.sdk.bases.operator import OPERATOR_DEFAULTS |
| 116 | + from airflow.serialization.serialized_objects import SerializedBaseOperator |
114 | 117 |
|
115 | | - original_values = {} |
116 | | - try: |
117 | | - # Store original values and apply overrides |
| 118 | + @contextlib.contextmanager |
| 119 | + def _operator_defaults(overrides): |
| 120 | + # Patch OPERATOR_DEFAULTS |
118 | 121 | for key, value in overrides.items(): |
119 | | - original_values[key] = OPERATOR_DEFAULTS.get(key) |
120 | | - OPERATOR_DEFAULTS[key] = value |
| 122 | + monkeypatch.setitem(OPERATOR_DEFAULTS, key, value) |
| 123 | + |
| 124 | + # Patch the matching module-level DEFAULT_* constant, if present |
| 125 | + const_name = f"DEFAULT_{key.upper()}" |
| 126 | + if hasattr(abstract_op_module, const_name): |
| 127 | + monkeypatch.setattr(abstract_op_module, const_name, value) |
121 | 128 |
|
122 | 129 | # Clear the cache to ensure fresh generation |
123 | 130 | SerializedBaseOperator.generate_client_defaults.cache_clear() |
124 | 131 |
|
125 | | - yield |
126 | | - finally: |
127 | | - # Cleanup: restore original values |
128 | | - for key, original_value in original_values.items(): |
129 | | - if original_value is None and key in OPERATOR_DEFAULTS: |
130 | | - # Key didn't exist originally, remove it |
131 | | - del OPERATOR_DEFAULTS[key] |
132 | | - else: |
133 | | - # Restore original value |
134 | | - OPERATOR_DEFAULTS[key] = original_value |
| 132 | + try: |
| 133 | + yield |
| 134 | + finally: |
| 135 | + # Clear the cache again; monkeypatch itself restores the patched values at test teardown |
| 136 | + SerializedBaseOperator.generate_client_defaults.cache_clear() |
135 | 137 |
|
136 | | - # Clear cache again to restore normal behavior |
137 | | - SerializedBaseOperator.generate_client_defaults.cache_clear() |
| 138 | + return _operator_defaults |
138 | 139 |
|
139 | 140 |
|
140 | 141 | AIRFLOW_REPO_ROOT_PATH = Path(airflow.__file__).parents[3] |
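Worth noting why the fixture shape above works: `monkeypatch.setitem` and `monkeypatch.setattr` record the original values and undo the patches automatically at test teardown, so the only manual cleanup left is invalidating the `lru_cache`. A minimal, self-contained sketch of the same pattern; `DEFAULTS` and `expensive_defaults` are illustrative stand-ins, not Airflow APIs:

```python
import contextlib
import functools

import pytest

DEFAULTS = {"retries": 0}


@functools.lru_cache(maxsize=1)
def expensive_defaults():
    # Stand-in for SerializedBaseOperator.generate_client_defaults
    return tuple(sorted(DEFAULTS.items()))


@pytest.fixture
def patched_defaults(monkeypatch):
    @contextlib.contextmanager
    def _patch(overrides):
        for key, value in overrides.items():
            # monkeypatch records the original entry and restores it at test teardown
            monkeypatch.setitem(DEFAULTS, key, value)
        expensive_defaults.cache_clear()  # drop anything cached before the patch
        try:
            yield
        finally:
            expensive_defaults.cache_clear()  # drop anything cached while patched

    return _patch


def test_patched(patched_defaults):
    with patched_defaults({"retries": 2}):
        assert dict(expensive_defaults())["retries"] == 2
```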
@@ -4045,77 +4046,104 @@ def test_apply_defaults_to_encoded_op_none_inputs(self): |
4045 | 4046 | result = SerializedBaseOperator._apply_defaults_to_encoded_op(encoded_op, None) |
4046 | 4047 | assert result == encoded_op |
4047 | 4048 |
|
4048 | | - @operator_defaults({"retries": 2}) |
4049 | | - def test_multiple_tasks_share_client_defaults(self): |
| 4049 | + def test_multiple_tasks_share_client_defaults(self, operator_defaults): |
4050 | 4050 | """Test that multiple tasks can share the same client_defaults when there are actually non-default values.""" |
4051 | | - with DAG(dag_id="test_dag") as dag: |
4052 | | - BashOperator(task_id="task1", bash_command="echo 1") |
4053 | | - BashOperator(task_id="task2", bash_command="echo 2") |
| 4051 | + with operator_defaults({"retries": 2}): |
| 4052 | + with DAG(dag_id="test_dag") as dag: |
| 4053 | + BashOperator(task_id="task1", bash_command="echo 1") |
| 4054 | + BashOperator(task_id="task2", bash_command="echo 2") |
4054 | 4055 |
|
4055 | | - serialized = SerializedDAG.to_dict(dag) |
| 4056 | + serialized = SerializedDAG.to_dict(dag) |
4056 | 4057 |
|
4057 | | - # Should have one client_defaults section for all tasks |
4058 | | - assert "client_defaults" in serialized |
4059 | | - assert "tasks" in serialized["client_defaults"] |
| 4058 | + # Should have one client_defaults section for all tasks |
| 4059 | + assert "client_defaults" in serialized |
| 4060 | + assert "tasks" in serialized["client_defaults"] |
4060 | 4061 |
|
4061 | | - # All tasks should benefit from the same client_defaults |
4062 | | - client_defaults = serialized["client_defaults"]["tasks"] |
| 4062 | + # All tasks should benefit from the same client_defaults |
| 4063 | + client_defaults = serialized["client_defaults"]["tasks"] |
4063 | 4064 |
|
4064 | | - # Deserialize and check both tasks get the defaults |
4065 | | - deserialized_dag = SerializedDAG.from_dict(serialized) |
4066 | | - deserialized_task1 = deserialized_dag.get_task("task1") |
4067 | | - deserialized_task2 = deserialized_dag.get_task("task2") |
| 4065 | + # Deserialize and check both tasks get the defaults |
| 4066 | + deserialized_dag = SerializedDAG.from_dict(serialized) |
| 4067 | + deserialized_task1 = deserialized_dag.get_task("task1") |
| 4068 | + deserialized_task2 = deserialized_dag.get_task("task2") |
4068 | 4069 |
|
4069 | | - # Both tasks should have retries=2 from client_defaults |
4070 | | - assert deserialized_task1.retries == 2 |
4071 | | - assert deserialized_task2.retries == 2 |
| 4070 | + # Both tasks should have retries=2 from client_defaults |
| 4071 | + assert deserialized_task1.retries == 2 |
| 4072 | + assert deserialized_task2.retries == 2 |
4072 | 4073 |
|
4073 | | - # Both tasks should have the same default values from client_defaults |
4074 | | - for field in client_defaults: |
4075 | | - if hasattr(deserialized_task1, field) and hasattr(deserialized_task2, field): |
4076 | | - value1 = getattr(deserialized_task1, field) |
4077 | | - value2 = getattr(deserialized_task2, field) |
4078 | | - assert value1 == value2, f"Tasks have different values for {field}: {value1} vs {value2}" |
| 4074 | + # Both tasks should have the same default values from client_defaults |
| 4075 | + for field in client_defaults: |
| 4076 | + if hasattr(deserialized_task1, field) and hasattr(deserialized_task2, field): |
| 4077 | + value1 = getattr(deserialized_task1, field) |
| 4078 | + value2 = getattr(deserialized_task2, field) |
| 4079 | + assert value1 == value2, f"Tasks have different values for {field}: {value1} vs {value2}" |
| 4080 | + |
| 4081 | + def test_default_args_when_equal_to_schema_defaults(self, operator_defaults): |
| 4082 | + """Test that explicitly set values matching schema defaults are preserved when client_defaults differ.""" |
| 4083 | + with operator_defaults({"retries": 3}): |
| 4084 | + with DAG(dag_id="test_explicit_schema_default", default_args={"retries": 0}) as dag: |
| 4085 | + BashOperator(task_id="task1", bash_command="echo 1") |
| 4086 | + BashOperator(task_id="task2", bash_command="echo 1", retries=2) |
| 4087 | + |
| 4088 | + serialized = SerializedDAG.to_dict(dag) |
| 4089 | + |
| 4090 | + # Verify client_defaults has retries=3 |
| 4091 | + assert "client_defaults" in serialized |
| 4092 | + assert "tasks" in serialized["client_defaults"] |
| 4093 | + client_defaults = serialized["client_defaults"]["tasks"] |
| 4094 | + assert client_defaults["retries"] == 3 |
| 4095 | + |
| 4096 | + task1_data = serialized["dag"]["tasks"][0]["__var"] |
| 4097 | + assert task1_data.get("retries", -1) == 0 |
| 4098 | + |
| 4099 | + task2_data = serialized["dag"]["tasks"][1]["__var"] |
| 4100 | + assert task2_data.get("retries", -1) == 2 |
| 4101 | + |
| 4102 | + deserialized_task1 = SerializedDAG.from_dict(serialized).get_task("task1") |
| 4103 | + assert deserialized_task1.retries == 0 |
| 4104 | + |
| 4105 | + deserialized_task2 = SerializedDAG.from_dict(serialized).get_task("task2") |
| 4106 | + assert deserialized_task2.retries == 2 |
4079 | 4107 |
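The new test above pins down the precedence that deserialization has to honor: a field explicitly present on the encoded task wins even when its value equals the schema default, client_defaults fill the remaining gaps, and schema defaults come last. A hedged sketch of that resolution order, not Airflow's actual `_apply_defaults_to_encoded_op` implementation:

```python
# Illustration of the precedence exercised by test_default_args_when_equal_to_schema_defaults.
# The dicts and field names here are assumptions for the example.
def resolve_field(field, encoded_op, client_defaults, schema_defaults):
    if field in encoded_op:  # explicitly serialized, even if it equals the schema default
        return encoded_op[field]
    if field in client_defaults:  # shared, DAG-level client defaults
        return client_defaults[field]
    return schema_defaults.get(field)  # schema default as the last resort


# task1 set retries=0 explicitly, so client_defaults' retries=3 must not leak in:
assert resolve_field("retries", {"retries": 0}, {"retries": 3}, {"retries": 0}) == 0
# A task that never set retries would pick up the client default:
assert resolve_field("retries", {}, {"retries": 3}, {"retries": 0}) == 3
```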
|
4080 | 4108 |
|
4081 | 4109 | class TestMappedOperatorSerializationAndClientDefaults: |
4082 | 4110 | """Test MappedOperator serialization with client defaults and callback properties.""" |
4083 | 4111 |
|
4084 | | - @operator_defaults({"retry_delay": 200.0}) |
4085 | | - def test_mapped_operator_client_defaults_application(self): |
| 4112 | + def test_mapped_operator_client_defaults_application(self, operator_defaults): |
4086 | 4113 | """Test that client_defaults are correctly applied to MappedOperator during deserialization.""" |
4087 | | - with DAG(dag_id="test_mapped_dag") as dag: |
4088 | | - # Create a mapped operator |
4089 | | - BashOperator.partial( |
4090 | | - task_id="mapped_task", |
4091 | | - retries=5, # Override default |
4092 | | - ).expand(bash_command=["echo 1", "echo 2", "echo 3"]) |
4093 | | - |
4094 | | - # Serialize the DAG |
4095 | | - serialized_dag = SerializedDAG.to_dict(dag) |
| 4114 | + with operator_defaults({"retry_delay": 200.0}): |
| 4115 | + with DAG(dag_id="test_mapped_dag") as dag: |
| 4116 | + # Create a mapped operator |
| 4117 | + BashOperator.partial( |
| 4118 | + task_id="mapped_task", |
| 4119 | + retries=5, # Override default |
| 4120 | + ).expand(bash_command=["echo 1", "echo 2", "echo 3"]) |
| 4121 | + |
| 4122 | + # Serialize the DAG |
| 4123 | + serialized_dag = SerializedDAG.to_dict(dag) |
4096 | 4124 |
|
4097 | | - # Should have client_defaults section |
4098 | | - assert "client_defaults" in serialized_dag |
4099 | | - assert "tasks" in serialized_dag["client_defaults"] |
| 4125 | + # Should have client_defaults section |
| 4126 | + assert "client_defaults" in serialized_dag |
| 4127 | + assert "tasks" in serialized_dag["client_defaults"] |
4100 | 4128 |
|
4101 | | - # Deserialize and check that client_defaults are applied |
4102 | | - deserialized_dag = SerializedDAG.from_dict(serialized_dag) |
4103 | | - deserialized_task = deserialized_dag.get_task("mapped_task") |
| 4129 | + # Deserialize and check that client_defaults are applied |
| 4130 | + deserialized_dag = SerializedDAG.from_dict(serialized_dag) |
| 4131 | + deserialized_task = deserialized_dag.get_task("mapped_task") |
4104 | 4132 |
|
4105 | | - # Verify it's still a MappedOperator |
4106 | | - from airflow.models.mappedoperator import MappedOperator as SchedulerMappedOperator |
| 4133 | + # Verify it's still a MappedOperator |
| 4134 | + from airflow.models.mappedoperator import MappedOperator as SchedulerMappedOperator |
4107 | 4135 |
|
4108 | | - assert isinstance(deserialized_task, SchedulerMappedOperator) |
| 4136 | + assert isinstance(deserialized_task, SchedulerMappedOperator) |
4109 | 4137 |
|
4110 | | - # Check that client_defaults values are applied (e.g., retry_delay from client_defaults) |
4111 | | - client_defaults = serialized_dag["client_defaults"]["tasks"] |
4112 | | - if "retry_delay" in client_defaults: |
4113 | | - # If retry_delay wasn't explicitly set, it should come from client_defaults |
4114 | | - # Since we can't easily convert timedelta back, check the serialized format |
4115 | | - assert hasattr(deserialized_task, "retry_delay") |
| 4138 | + # Check that client_defaults values are applied (e.g., retry_delay from client_defaults) |
| 4139 | + client_defaults = serialized_dag["client_defaults"]["tasks"] |
| 4140 | + if "retry_delay" in client_defaults: |
| 4141 | + # If retry_delay wasn't explicitly set, it should come from client_defaults |
| 4142 | + # Comparing the timedelta value directly is awkward, so just check the attribute exists |
| 4143 | + assert hasattr(deserialized_task, "retry_delay") |
4116 | 4144 |
|
4117 | | - # Explicit values should override client_defaults |
4118 | | - assert deserialized_task.retries == 5 # Explicitly set value |
| 4145 | + # Explicit values should override client_defaults |
| 4146 | + assert deserialized_task.retries == 5 # Explicitly set value |
4119 | 4147 |
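The mapped-operator test relies on the same layering as the plain-operator tests: at deserialization, client_defaults are applied first and the explicitly serialized `partial_kwargs` are layered on top, which is why `retries=5` survives while `retry_delay` arrives from the shared defaults. A hedged sketch of that merge, illustrative only and not Airflow's code:

```python
client_defaults = {"retries": 0, "retry_delay": 200.0}
partial_kwargs = {"retries": 5}  # what BashOperator.partial(..., retries=5) serialized

effective = {**client_defaults, **partial_kwargs}
assert effective["retries"] == 5  # explicit value wins
assert effective["retry_delay"] == 200.0  # filled in from client_defaults
```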
|
4120 | 4148 | @pytest.mark.parametrize( |
4121 | 4149 | ["task_config", "dag_id", "task_id", "non_default_fields"], |
@@ -4146,45 +4174,45 @@ def test_mapped_operator_client_defaults_application(self): |
4146 | 4174 | ), |
4147 | 4175 | ], |
4148 | 4176 | ) |
4149 | | - @operator_defaults({"retry_delay": 200.0}) |
4150 | 4177 | def test_mapped_operator_client_defaults_optimization( |
4151 | | - self, task_config, dag_id, task_id, non_default_fields |
| 4178 | + self, task_config, dag_id, task_id, non_default_fields, operator_defaults |
4152 | 4179 | ): |
4153 | 4180 | """Test that MappedOperator serialization optimizes using client defaults.""" |
4154 | | - with DAG(dag_id=dag_id) as dag: |
4155 | | - # Create mapped operator with specified configuration |
4156 | | - BashOperator.partial( |
4157 | | - task_id=task_id, |
4158 | | - **task_config, |
4159 | | - ).expand(bash_command=["echo 1", "echo 2", "echo 3"]) |
| 4181 | + with operator_defaults({"retry_delay": 200.0}): |
| 4182 | + with DAG(dag_id=dag_id) as dag: |
| 4183 | + # Create mapped operator with specified configuration |
| 4184 | + BashOperator.partial( |
| 4185 | + task_id=task_id, |
| 4186 | + **task_config, |
| 4187 | + ).expand(bash_command=["echo 1", "echo 2", "echo 3"]) |
4160 | 4188 |
|
4161 | | - serialized_dag = SerializedDAG.to_dict(dag) |
4162 | | - mapped_task_serialized = serialized_dag["dag"]["tasks"][0]["__var"] |
| 4189 | + serialized_dag = SerializedDAG.to_dict(dag) |
| 4190 | + mapped_task_serialized = serialized_dag["dag"]["tasks"][0]["__var"] |
4163 | 4191 |
|
4164 | | - assert mapped_task_serialized is not None |
4165 | | - assert mapped_task_serialized.get("_is_mapped") is True |
| 4192 | + assert mapped_task_serialized is not None |
| 4193 | + assert mapped_task_serialized.get("_is_mapped") is True |
4166 | 4194 |
|
4167 | | - # Check optimization behavior |
4168 | | - client_defaults = serialized_dag["client_defaults"]["tasks"] |
4169 | | - partial_kwargs = mapped_task_serialized["partial_kwargs"] |
| 4195 | + # Check optimization behavior |
| 4196 | + client_defaults = serialized_dag["client_defaults"]["tasks"] |
| 4197 | + partial_kwargs = mapped_task_serialized["partial_kwargs"] |
4170 | 4198 |
|
4171 | | - # Check that all fields are optimized correctly |
4172 | | - for field, default_value in client_defaults.items(): |
4173 | | - if field in non_default_fields: |
4174 | | - # Non-default fields should be present in partial_kwargs |
4175 | | - assert field in partial_kwargs, ( |
4176 | | - f"Field '{field}' should be in partial_kwargs as it's non-default" |
4177 | | - ) |
4178 | | - # And have different values than defaults |
4179 | | - assert partial_kwargs[field] != default_value, ( |
4180 | | - f"Field '{field}' should have non-default value" |
4181 | | - ) |
4182 | | - else: |
4183 | | - # Default fields should either not be present or have different values if present |
4184 | | - if field in partial_kwargs: |
| 4199 | + # Check that all fields are optimized correctly |
| 4200 | + for field, default_value in client_defaults.items(): |
| 4201 | + if field in non_default_fields: |
| 4202 | + # Non-default fields should be present in partial_kwargs |
| 4203 | + assert field in partial_kwargs, ( |
| 4204 | + f"Field '{field}' should be in partial_kwargs as it's non-default" |
| 4205 | + ) |
| 4206 | + # And have different values than defaults |
4185 | 4207 | assert partial_kwargs[field] != default_value, ( |
4186 | | - f"Field '{field}' with default value should be optimized out" |
| 4208 | + f"Field '{field}' should have non-default value" |
4187 | 4209 | ) |
| 4210 | + else: |
| 4211 | + # Default fields should either not be present or have different values if present |
| 4212 | + if field in partial_kwargs: |
| 4213 | + assert partial_kwargs[field] != default_value, ( |
| 4214 | + f"Field '{field}' with default value should be optimized out" |
| 4215 | + ) |
4188 | 4216 |
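The parametrized test above verifies a size optimization: any `partial_kwargs` entry whose value matches the shared client_defaults can be dropped from the serialized task and recovered on deserialization. A hedged sketch of that filtering step, illustrative only and not the actual serializer implementation:

```python
def optimize_partial_kwargs(partial_kwargs, client_defaults):
    sentinel = object()  # distinguishes "absent" from a stored None
    return {k: v for k, v in partial_kwargs.items() if client_defaults.get(k, sentinel) != v}


# retry_delay matches the client default and is optimized out; retries is kept:
assert optimize_partial_kwargs(
    {"retries": 5, "retry_delay": 200.0}, {"retry_delay": 200.0}
) == {"retries": 5}
```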
|
4189 | 4217 | def test_mapped_operator_expand_input_preservation(self): |
4190 | 4218 | """Test that expand_input is correctly preserved during serialization.""" |
|