Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions airflow-core/src/airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -2507,6 +2507,23 @@ def serialize_dag(cls, dag: DAG) -> dict:
serialized_dag["has_on_success_callback"] = True
if dag.has_on_failure_callback:
serialized_dag["has_on_failure_callback"] = True

# TODO: Move this logic to a better place -- ideally before serializing contents of default_args.
# There is some duplication with this and SerializedBaseOperator.partial_kwargs serialization.
# Ideally default_args goes through same logic as fields of SerializedBaseOperator.
if serialized_dag.get("default_args", {}):
default_args_dict = serialized_dag["default_args"][Encoding.VAR]
callbacks_to_remove = []
for k, v in list(default_args_dict.items()):
if k in [
f"on_{x}_callback" for x in ("execute", "failure", "success", "retry", "skipped")
]:
if bool(v):
default_args_dict[f"has_{k}"] = True
callbacks_to_remove.append(k)
for k in callbacks_to_remove:
del default_args_dict[k]

return serialized_dag
except SerializationError:
raise
Expand Down
62 changes: 62 additions & 0 deletions airflow-core/tests/unit/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4292,3 +4292,65 @@ def test_partial_kwargs_end_to_end_deserialization(self):
assert "owner" in deserialized_task.partial_kwargs
assert deserialized_task.partial_kwargs["retry_delay"] == timedelta(seconds=600)
assert deserialized_task.partial_kwargs["owner"] == "custom_owner"


@pytest.mark.parametrize(
["callbacks", "expected_has_flags", "absent_keys"],
[
pytest.param(
{
"on_failure_callback": lambda ctx: None,
"on_success_callback": lambda ctx: None,
"on_retry_callback": lambda ctx: None,
},
["has_on_failure_callback", "has_on_success_callback", "has_on_retry_callback"],
["on_failure_callback", "on_success_callback", "on_retry_callback"],
id="multiple_callbacks",
),
pytest.param(
{"on_failure_callback": lambda ctx: None},
["has_on_failure_callback"],
["on_failure_callback", "has_on_success_callback", "on_success_callback"],
id="single_callback",
),
pytest.param(
{"on_failure_callback": lambda ctx: None, "on_execute_callback": None},
["has_on_failure_callback"],
["on_failure_callback", "has_on_execute_callback", "on_execute_callback"],
id="callback_with_none",
),
pytest.param(
{},
[],
[
"has_on_execute_callback",
"has_on_failure_callback",
"has_on_success_callback",
"has_on_retry_callback",
"has_on_skipped_callback",
],
id="no_callbacks",
),
],
)
def test_dag_default_args_callbacks_serialization(callbacks, expected_has_flags, absent_keys):
"""Test callbacks in DAG default_args are serialized as boolean flags."""
default_args = {"owner": "test_owner", "retries": 2, **callbacks}

with DAG(dag_id="test_default_args_callbacks", default_args=default_args) as dag:
BashOperator(task_id="task1", bash_command="echo 1", dag=dag)

serialized_dag_dict = SerializedDAG.serialize_dag(dag)
default_args_dict = serialized_dag_dict["default_args"][Encoding.VAR]

for flag in expected_has_flags:
assert default_args_dict.get(flag) is True

for key in absent_keys:
assert key not in default_args_dict

assert default_args_dict["owner"] == "test_owner"
assert default_args_dict["retries"] == 2

deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag_dict)
assert deserialized_dag.dag_id == "test_default_args_callbacks"