18 changes: 9 additions & 9 deletions airflow-core/src/airflow/serialization/schema.json
@@ -175,8 +175,8 @@
"value": { "$ref": "#/definitions/dict" }
}
},
"catchup": { "type": "boolean" },
"fail_fast": { "type": "boolean" },
"catchup": { "type": "boolean", "default": false },
"fail_fast": { "type": "boolean", "default": false },
"fileloc": { "type" : "string"},
"relative_fileloc": { "type" : "string"},
"_processor_dags_folder": {
@@ -198,27 +198,27 @@
]
},
"_concurrency": { "type" : "number"},
"max_active_tasks": { "type" : "number"},
"max_active_runs": { "type" : "number"},
"max_consecutive_failed_dag_runs": { "type" : "number"},
"max_active_tasks": { "type" : "number", "default": 16},
"max_active_runs": { "type" : "number", "default": 16},
"max_consecutive_failed_dag_runs": { "type" : "number", "default": 0},
"default_args": { "$ref": "#/definitions/dict" },
"start_date": { "$ref": "#/definitions/datetime" },
"end_date": { "$ref": "#/definitions/datetime" },
"dagrun_timeout": { "$ref": "#/definitions/timedelta" },
"doc_md": { "type" : "string"},
"access_control": {"$ref": "#/definitions/dict" },
"is_paused_upon_creation": { "type": "boolean" },
"has_on_success_callback": { "type": "boolean" },
"has_on_failure_callback": { "type": "boolean" },
"render_template_as_native_obj": { "type": "boolean" },
"has_on_success_callback": { "type": "boolean", "default": false },
"has_on_failure_callback": { "type": "boolean", "default": false },
"render_template_as_native_obj": { "type": "boolean", "default": false },
"tags": { "type": "array" },
"task_group": {"anyOf": [
{ "type": "null" },
{ "$ref": "#/definitions/task_group" }
]},
"edge_info": { "$ref": "#/definitions/edge_info" },
"dag_dependencies": { "$ref": "#/definitions/dag_dependencies" },
"disable_bundle_versioning": {"type": "boolean"}
"disable_bundle_versioning": {"type": "boolean", "default": false }
},
"required": [
"dag_id",
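The new "default" annotations give the serializer a single source of truth for per-field defaults: any DAG field whose value equals its schema default can be dropped from the serialized payload and restored on deserialization. A minimal sketch of how such defaults can be collected from a definition block like the one above; the helper name and schema path are illustrative rather than the PR's exact get_schema_defaults implementation:

# Minimal sketch: collect the "default" values declared for one definition in the schema.
# The schema path and helper name are assumptions for illustration only.
import json
from pathlib import Path
from typing import Any


def collect_schema_defaults(schema_path: Path, object_type: str) -> dict[str, Any]:
    """Return {field: default} for every property of ``object_type`` that declares a default."""
    schema = json.loads(schema_path.read_text())
    definition = schema.get("definitions", {}).get(object_type, {})
    return {
        name: prop["default"]
        for name, prop in definition.get("properties", {}).items()
        if isinstance(prop, dict) and "default" in prop
    }


# For the "dag" definition above this yields, among others:
# {"catchup": False, "fail_fast": False, "max_active_tasks": 16, "max_active_runs": 16, ...}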
80 changes: 55 additions & 25 deletions airflow-core/src/airflow/serialization/serialized_objects.py
@@ -1094,21 +1094,7 @@ def _deserialize_params_dict(cls, encoded_params: list[tuple[str, dict]]) -> Par
return ParamsDict(op_params)

@classmethod
def get_operator_optional_fields_from_schema(cls) -> set[str]:
schema_loader = cls._json_schema

if schema_loader is None:
return set()

schema_data = schema_loader.schema
operator_def = schema_data.get("definitions", {}).get("operator", {})
operator_fields = set(operator_def.get("properties", {}).keys())
required_fields = set(operator_def.get("required", []))

optional_fields = operator_fields - required_fields
return optional_fields

@classmethod
@lru_cache(maxsize=4) # Cache for "operator", "dag", and a few others
def get_schema_defaults(cls, object_type: str) -> dict[str, Any]:
"""
Extract default values from JSON schema for any object type.
@@ -1713,6 +1699,22 @@ def set_task_dag_references(task: SerializedOperator | MappedOperator, dag: Seri
# Bypass set_upstream etc here - it does more than we want
dag.task_dict[task_id].upstream_task_ids.add(task.task_id)

@classmethod
@lru_cache(maxsize=1) # Only one type: "operator"
def get_operator_optional_fields_from_schema(cls) -> set[str]:
schema_loader = cls._json_schema

if schema_loader is None:
return set()

schema_data = schema_loader.schema
operator_def = schema_data.get("definitions", {}).get("operator", {})
operator_fields = set(operator_def.get("properties", {}).keys())
required_fields = set(operator_def.get("required", []))

optional_fields = operator_fields - required_fields
return optional_fields
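In other words, the optional-field set is simply the schema's properties minus its required list; anything in that set may legally be absent from a serialized operator. A toy illustration of the same set arithmetic (the field names below are made up):

# Toy illustration only; these field names do not come from the real operator definition.
operator_def = {
    "properties": {"task_id": {}, "retries": {"default": 0}, "owner": {}},
    "required": ["task_id"],
}
optional_fields = set(operator_def["properties"]) - set(operator_def["required"])
assert optional_fields == {"retries", "owner"}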

@classmethod
def deserialize_operator(
cls,
@@ -1814,7 +1816,7 @@ def detect_dependencies(cls, op: SdkOperator) -> set[DagDependency]:
return deps

@classmethod
def _matches_client_defaults(cls, var: Any, attrname: str, op: DAGNode) -> bool:
def _matches_client_defaults(cls, var: Any, attrname: str) -> bool:
"""
Check if a field value matches client_defaults and should be excluded.

@@ -1823,7 +1825,6 @@ def _matches_client_defaults(cls, var: Any, attrname: str, op: DAGNode) -> bool:

:param var: The value to check
:param attrname: The attribute name
:param op: The operator instance
:return: True if value matches client_defaults and should be excluded
"""
try:
@@ -1851,7 +1852,7 @@ def _is_excluded(cls, var: Any, attrname: str, op: DAGNode):
:return: True if a variable is excluded, False otherwise.
"""
# Check if value matches client_defaults (hierarchical defaults optimization)
if cls._matches_client_defaults(var, attrname, op):
if cls._matches_client_defaults(var, attrname):
return True
schema_defaults = cls.get_schema_defaults("operator")

@@ -2384,25 +2385,23 @@ class SerializedDAG(BaseSerialization):
_processor_dags_folder: str

def __init__(self, *, dag_id: str) -> None:
self.catchup = airflow_conf.getboolean("scheduler", "catchup_by_default")
self.catchup = False # Schema default
self.dag_id = self.dag_display_name = dag_id
self.dagrun_timeout = None
self.deadline = None
self.default_args = {}
self.description = None
self.disable_bundle_versioning = airflow_conf.getboolean("dag_processor", "disable_bundle_versioning")
self.disable_bundle_versioning = False
self.doc_md = None
self.edge_info = {}
self.end_date = None
self.fail_fast = False
self.has_on_failure_callback = False
self.has_on_success_callback = False
self.is_paused_upon_creation = None
self.max_active_runs = airflow_conf.getint("core", "max_active_runs_per_dag")
self.max_active_tasks = airflow_conf.getint("core", "max_active_tasks_per_dag")
self.max_consecutive_failed_dag_runs = airflow_conf.getint(
"core", "max_consecutive_failed_dag_runs_per_dag"
)
self.max_active_runs = 16 # Schema default
self.max_active_tasks = 16 # Schema default
self.max_consecutive_failed_dag_runs = 0 # Schema default
self.owner_links = {}
self.params = ParamsDict()
self.partial = False
@@ -2624,8 +2623,38 @@ def _is_excluded(cls, var: Any, attrname: str, op: DAGNode):
return False
if attrname == "dag_display_name" and var == op.dag_id:
return True

# DAG schema defaults exclusion (same pattern as SerializedBaseOperator)
dag_schema_defaults = cls.get_schema_defaults("dag")
if attrname in dag_schema_defaults:
if dag_schema_defaults[attrname] == var:
return True

optional_fields = cls.get_dag_optional_fields_from_schema()
if var is None:
return True
if attrname in optional_fields:
if var in [[], (), set(), {}]:
return True

return super()._is_excluded(var, attrname, op)
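The DAG-side check therefore layers three tests on top of the base-class logic: drop a value that equals its schema default, drop None, and drop empty containers for optional fields, before deferring to super()._is_excluded. A standalone sketch of that decision order, with the defaults and optional-field sets passed in as assumed inputs rather than looked up from the schema:

# Standalone sketch of the exclusion order shown above; the real method looks the
# defaults and optional fields up from the schema and then falls back to super().
from typing import Any


def should_exclude(var: Any, attrname: str, schema_defaults: dict, optional_fields: set) -> bool:
    if attrname in schema_defaults and schema_defaults[attrname] == var:
        return True  # value equals its schema default, so it is not serialized
    if var is None:
        return True  # nulls are never written out
    if attrname in optional_fields and var in ([], (), set(), {}):
        return True  # empty containers for optional fields are dropped
    return False  # otherwise the real code defers to super()._is_excluded


assert should_exclude(False, "catchup", {"catchup": False}, set()) is True
assert should_exclude(True, "catchup", {"catchup": False}, set()) is False
assert should_exclude([], "tags", {}, {"tags"}) is True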

@classmethod
@lru_cache(maxsize=1) # Only one type: "dag"
def get_dag_optional_fields_from_schema(cls) -> set[str]:
schema_loader = cls._json_schema

if schema_loader is None:
return set()

schema_data = schema_loader.schema
operator_def = schema_data.get("definitions", {}).get("dag", {})
operator_fields = set(operator_def.get("properties", {}).keys())
required_fields = set(operator_def.get("required", []))

optional_fields = operator_fields - required_fields
return optional_fields

@classmethod
def to_dict(cls, var: Any) -> dict:
"""Stringifies DAGs and operators contained by var and returns a dict of var."""
@@ -3798,6 +3827,7 @@ class LazyDeserializedDAG(pydantic.BaseModel):
"dag_display_name",
"has_on_success_callback",
"has_on_failure_callback",
"tags",
# Attr properties that are nullable, or have a default that loads from config
"description",
"start_date",
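With the constructor now carrying the schema defaults as constants instead of reading airflow.cfg, a serialized DAG that omits those keys round-trips to the same values on any installation. A hedged end-to-end sketch using the public to_dict/from_dict helpers; it mirrors the new test_dag_schema_defaults_optimization below, and the DAG import path is an assumption that may differ between Airflow versions:

# Round-trip sketch mirroring the new test below; not additional behaviour.
from datetime import datetime

from airflow.sdk import DAG  # assumed import path; may differ between Airflow versions
from airflow.serialization.serialized_objects import SerializedDAG

dag = DAG(dag_id="defaults_demo", start_date=datetime(2023, 1, 1), catchup=False, max_active_runs=16)
payload = SerializedDAG.to_dict(dag)

# Fields equal to their schema defaults are no longer written out...
assert "catchup" not in payload["dag"]
assert "max_active_runs" not in payload["dag"]

# ...and deserialization restores them from the schema/constructor defaults.
restored = SerializedDAG.from_dict(payload)
assert restored.catchup is False
assert restored.max_active_runs == 16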
110 changes: 94 additions & 16 deletions airflow-core/tests/unit/serialization/test_dag_serialization.py
@@ -152,13 +152,8 @@
"downstream_task_ids": [],
},
"is_paused_upon_creation": False,
"max_active_runs": 16,
"max_active_tasks": 16,
"max_consecutive_failed_dag_runs": 0,
"dag_id": "simple_dag",
"deadline": None,
"catchup": False,
"disable_bundle_versioning": False,
"doc_md": "### DAG Tutorial Documentation",
"fileloc": None,
"_processor_dags_folder": (
@@ -269,7 +264,6 @@
},
],
"params": [],
"tags": [],
},
}

@@ -2177,7 +2171,8 @@ def test_dag_disable_bundle_versioning_roundtrip(self, dag_arg, conf_arg, expect
"""
with conf_vars({("dag_processor", "disable_bundle_versioning"): conf_arg}):
kwargs = {}
kwargs["disable_bundle_versioning"] = dag_arg
if dag_arg is not None:
kwargs["disable_bundle_versioning"] = dag_arg
dag = DAG(
dag_id="test_dag_disable_bundle_versioning_roundtrip",
schedule=None,
@@ -3299,17 +3294,34 @@ def test_handle_v1_serdag():
SerializedDAG.conversion_v1_to_v2(v1)
SerializedDAG.conversion_v2_to_v3(v1)

# Update a few subtle differences
v1["dag"]["tags"] = []
v1["dag"]["catchup"] = False
v1["dag"]["disable_bundle_versioning"] = False
dag = SerializedDAG.from_dict(v1)

expected = copy.deepcopy(serialized_simple_dag_ground_truth)
expected["dag"]["dag_dependencies"] = expected_dag_dependencies
del expected["dag"]["tasks"][1]["__var"]["_operator_extra_links"]
expected_sdag = copy.deepcopy(serialized_simple_dag_ground_truth)
expected = SerializedDAG.from_dict(expected_sdag)

fields_to_verify = set(vars(expected).keys()) - {
"task_group", # Tested separately
"dag_dependencies", # Tested separately
"last_loaded", # Dynamically set to utcnow
}

for f in fields_to_verify:
dag_value = getattr(dag, f)
expected_value = getattr(expected, f)

assert dag_value == expected_value, (
f"V2 DAG field '{f}' differs from V3: V2={dag_value!r} != V3={expected_value!r}"
)

for f in set(vars(expected.task_group).keys()) - {"dag"}:
dag_tg_value = getattr(dag.task_group, f)
expected_tg_value = getattr(expected.task_group, f)

assert dag_tg_value == expected_tg_value, (
f"V2 task_group field '{f}' differs: V2={dag_tg_value!r} != V3={expected_tg_value!r}"
)

del expected["client_defaults"]
assert v1 == expected
assert getattr(dag, "dag_dependencies") == expected_dag_dependencies


def test_handle_v2_serdag():
@@ -3514,6 +3526,72 @@ def test_handle_v2_serdag():
)


def test_dag_schema_defaults_optimization():
"""Test that DAG fields matching schema defaults are excluded from serialization."""

# Create DAG with all schema default values
dag_with_defaults = DAG(
dag_id="test_defaults_dag",
start_date=datetime(2023, 1, 1),
# These should match schema defaults and be excluded
catchup=False,
fail_fast=False,
max_active_runs=16,
max_active_tasks=16,
max_consecutive_failed_dag_runs=0,
render_template_as_native_obj=False,
disable_bundle_versioning=False,
# These should be excluded as None
description=None,
doc_md=None,
)

# Serialize and check exclusions
serialized = SerializedDAG.to_dict(dag_with_defaults)
dag_data = serialized["dag"]

# Schema default fields should be excluded
for field in SerializedDAG.get_schema_defaults("dag").keys():
assert field not in dag_data, f"Schema default field '{field}' should be excluded"

# None fields should also be excluded
none_fields = ["description", "doc_md"]
for field in none_fields:
assert field not in dag_data, f"None field '{field}' should be excluded"

# Test deserialization restores defaults correctly
deserialized_dag = SerializedDAG.from_dict(serialized)

# Verify schema defaults are restored
assert deserialized_dag.catchup is False
assert deserialized_dag.fail_fast is False
assert deserialized_dag.max_active_runs == 16
assert deserialized_dag.max_active_tasks == 16
assert deserialized_dag.max_consecutive_failed_dag_runs == 0
assert deserialized_dag.render_template_as_native_obj is False
assert deserialized_dag.disable_bundle_versioning is False

# Test with non-default values (should be included)
dag_non_defaults = DAG(
dag_id="test_non_defaults_dag",
start_date=datetime(2023, 1, 1),
catchup=True, # Non-default
max_active_runs=32, # Non-default
description="Test description", # Non-None
)

serialized_non_defaults = SerializedDAG.to_dict(dag_non_defaults)
dag_non_defaults_data = serialized_non_defaults["dag"]

# Non-default values should be included
assert "catchup" in dag_non_defaults_data
assert dag_non_defaults_data["catchup"] is True
assert "max_active_runs" in dag_non_defaults_data
assert dag_non_defaults_data["max_active_runs"] == 32
assert "description" in dag_non_defaults_data
assert dag_non_defaults_data["description"] == "Test description"


def test_email_optimization_removes_email_attrs_when_email_empty():
"""Test that email_on_failure and email_on_retry are removed when email is empty."""
with DAG(dag_id="test_email_optimization") as dag: