diff --git a/airflow-core/src/airflow/serialization/schema.json b/airflow-core/src/airflow/serialization/schema.json index c4740f346f3f8..7707494fce546 100644 --- a/airflow-core/src/airflow/serialization/schema.json +++ b/airflow-core/src/airflow/serialization/schema.json @@ -175,8 +175,8 @@ "value": { "$ref": "#/definitions/dict" } } }, - "catchup": { "type": "boolean" }, - "fail_fast": { "type": "boolean" }, + "catchup": { "type": "boolean", "default": false }, + "fail_fast": { "type": "boolean", "default": false }, "fileloc": { "type" : "string"}, "relative_fileloc": { "type" : "string"}, "_processor_dags_folder": { @@ -198,9 +198,9 @@ ] }, "_concurrency": { "type" : "number"}, - "max_active_tasks": { "type" : "number"}, - "max_active_runs": { "type" : "number"}, - "max_consecutive_failed_dag_runs": { "type" : "number"}, + "max_active_tasks": { "type" : "number", "default": 16}, + "max_active_runs": { "type" : "number", "default": 16}, + "max_consecutive_failed_dag_runs": { "type" : "number", "default": 0}, "default_args": { "$ref": "#/definitions/dict" }, "start_date": { "$ref": "#/definitions/datetime" }, "end_date": { "$ref": "#/definitions/datetime" }, @@ -208,9 +208,9 @@ "doc_md": { "type" : "string"}, "access_control": {"$ref": "#/definitions/dict" }, "is_paused_upon_creation": { "type": "boolean" }, - "has_on_success_callback": { "type": "boolean" }, - "has_on_failure_callback": { "type": "boolean" }, - "render_template_as_native_obj": { "type": "boolean" }, + "has_on_success_callback": { "type": "boolean", "default": false }, + "has_on_failure_callback": { "type": "boolean", "default": false }, + "render_template_as_native_obj": { "type": "boolean", "default": false }, "tags": { "type": "array" }, "task_group": {"anyOf": [ { "type": "null" }, @@ -218,7 +218,7 @@ ]}, "edge_info": { "$ref": "#/definitions/edge_info" }, "dag_dependencies": { "$ref": "#/definitions/dag_dependencies" }, - "disable_bundle_versioning": {"type": "boolean"} + "disable_bundle_versioning": {"type": "boolean", "default": false } }, "required": [ "dag_id", diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py index ad85dcc380642..d7a4007e5888e 100644 --- a/airflow-core/src/airflow/serialization/serialized_objects.py +++ b/airflow-core/src/airflow/serialization/serialized_objects.py @@ -1094,21 +1094,7 @@ def _deserialize_params_dict(cls, encoded_params: list[tuple[str, dict]]) -> Par return ParamsDict(op_params) @classmethod - def get_operator_optional_fields_from_schema(cls) -> set[str]: - schema_loader = cls._json_schema - - if schema_loader is None: - return set() - - schema_data = schema_loader.schema - operator_def = schema_data.get("definitions", {}).get("operator", {}) - operator_fields = set(operator_def.get("properties", {}).keys()) - required_fields = set(operator_def.get("required", [])) - - optional_fields = operator_fields - required_fields - return optional_fields - - @classmethod + @lru_cache(maxsize=4) # Cache for "operator", "dag", and a few others def get_schema_defaults(cls, object_type: str) -> dict[str, Any]: """ Extract default values from JSON schema for any object type. 
@@ -1713,6 +1699,22 @@ def set_task_dag_references(task: SerializedOperator | MappedOperator, dag: Seri # Bypass set_upstream etc here - it does more than we want dag.task_dict[task_id].upstream_task_ids.add(task.task_id) + @classmethod + @lru_cache(maxsize=1) # Only one type: "operator" + def get_operator_optional_fields_from_schema(cls) -> set[str]: + schema_loader = cls._json_schema + + if schema_loader is None: + return set() + + schema_data = schema_loader.schema + operator_def = schema_data.get("definitions", {}).get("operator", {}) + operator_fields = set(operator_def.get("properties", {}).keys()) + required_fields = set(operator_def.get("required", [])) + + optional_fields = operator_fields - required_fields + return optional_fields + @classmethod def deserialize_operator( cls, @@ -1814,7 +1816,7 @@ def detect_dependencies(cls, op: SdkOperator) -> set[DagDependency]: return deps @classmethod - def _matches_client_defaults(cls, var: Any, attrname: str, op: DAGNode) -> bool: + def _matches_client_defaults(cls, var: Any, attrname: str) -> bool: """ Check if a field value matches client_defaults and should be excluded. @@ -1823,7 +1825,6 @@ def _matches_client_defaults(cls, var: Any, attrname: str, op: DAGNode) -> bool: :param var: The value to check :param attrname: The attribute name - :param op: The operator instance :return: True if value matches client_defaults and should be excluded """ try: @@ -1851,7 +1852,7 @@ def _is_excluded(cls, var: Any, attrname: str, op: DAGNode): :return: True if a variable is excluded, False otherwise. """ # Check if value matches client_defaults (hierarchical defaults optimization) - if cls._matches_client_defaults(var, attrname, op): + if cls._matches_client_defaults(var, attrname): return True schema_defaults = cls.get_schema_defaults("operator") @@ -2384,13 +2385,13 @@ class SerializedDAG(BaseSerialization): _processor_dags_folder: str def __init__(self, *, dag_id: str) -> None: - self.catchup = airflow_conf.getboolean("scheduler", "catchup_by_default") + self.catchup = False # Schema default self.dag_id = self.dag_display_name = dag_id self.dagrun_timeout = None self.deadline = None self.default_args = {} self.description = None - self.disable_bundle_versioning = airflow_conf.getboolean("dag_processor", "disable_bundle_versioning") + self.disable_bundle_versioning = False self.doc_md = None self.edge_info = {} self.end_date = None @@ -2398,11 +2399,9 @@ def __init__(self, *, dag_id: str) -> None: self.has_on_failure_callback = False self.has_on_success_callback = False self.is_paused_upon_creation = None - self.max_active_runs = airflow_conf.getint("core", "max_active_runs_per_dag") - self.max_active_tasks = airflow_conf.getint("core", "max_active_tasks_per_dag") - self.max_consecutive_failed_dag_runs = airflow_conf.getint( - "core", "max_consecutive_failed_dag_runs_per_dag" - ) + self.max_active_runs = 16 # Schema default + self.max_active_tasks = 16 # Schema default + self.max_consecutive_failed_dag_runs = 0 # Schema default self.owner_links = {} self.params = ParamsDict() self.partial = False @@ -2624,8 +2623,38 @@ def _is_excluded(cls, var: Any, attrname: str, op: DAGNode): return False if attrname == "dag_display_name" and var == op.dag_id: return True + + # DAG schema defaults exclusion (same pattern as SerializedBaseOperator) + dag_schema_defaults = cls.get_schema_defaults("dag") + if attrname in dag_schema_defaults: + if dag_schema_defaults[attrname] == var: + return True + + optional_fields = cls.get_dag_optional_fields_from_schema() + if 
var is None:
+            return True
+        if attrname in optional_fields:
+            if var in [[], (), set(), {}]:
+                return True
         return super()._is_excluded(var, attrname, op)
 
+    @classmethod
+    @lru_cache(maxsize=1)  # Only one type: "dag"
+    def get_dag_optional_fields_from_schema(cls) -> set[str]:
+        schema_loader = cls._json_schema
+
+        if schema_loader is None:
+            return set()
+
+        schema_data = schema_loader.schema
+        dag_def = schema_data.get("definitions", {}).get("dag", {})
+        dag_fields = set(dag_def.get("properties", {}).keys())
+        required_fields = set(dag_def.get("required", []))
+
+        optional_fields = dag_fields - required_fields
+        return optional_fields
+
     @classmethod
     def to_dict(cls, var: Any) -> dict:
         """Stringifies DAGs and operators contained by var and returns a dict of var."""
@@ -3798,6 +3827,7 @@ class LazyDeserializedDAG(pydantic.BaseModel):
         "dag_display_name",
         "has_on_success_callback",
         "has_on_failure_callback",
+        "tags",
         # Attr properties that are nullable, or have a default that loads from config
         "description",
         "start_date",
diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py
index 4ed84c32fbe4c..102b092770f17 100644
--- a/airflow-core/tests/unit/serialization/test_dag_serialization.py
+++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py
@@ -152,13 +152,8 @@
             "downstream_task_ids": [],
         },
         "is_paused_upon_creation": False,
-        "max_active_runs": 16,
-        "max_active_tasks": 16,
-        "max_consecutive_failed_dag_runs": 0,
         "dag_id": "simple_dag",
         "deadline": None,
-        "catchup": False,
-        "disable_bundle_versioning": False,
         "doc_md": "### DAG Tutorial Documentation",
         "fileloc": None,
         "_processor_dags_folder": (
@@ -269,7 +264,6 @@
             },
         ],
         "params": [],
-        "tags": [],
     },
 }
 
@@ -2177,7 +2171,8 @@ def test_dag_disable_bundle_versioning_roundtrip(self, dag_arg, conf_arg, expect
         """
         with conf_vars({("dag_processor", "disable_bundle_versioning"): conf_arg}):
             kwargs = {}
-            kwargs["disable_bundle_versioning"] = dag_arg
+            if dag_arg is not None:
+                kwargs["disable_bundle_versioning"] = dag_arg
             dag = DAG(
                 dag_id="test_dag_disable_bundle_versioning_roundtrip",
                 schedule=None,
@@ -3299,17 +3294,34 @@ def test_handle_v1_serdag():
     SerializedDAG.conversion_v1_to_v2(v1)
     SerializedDAG.conversion_v2_to_v3(v1)
 
-    # Update a few subtle differences
-    v1["dag"]["tags"] = []
-    v1["dag"]["catchup"] = False
-    v1["dag"]["disable_bundle_versioning"] = False
+    dag = SerializedDAG.from_dict(v1)
 
-    expected = copy.deepcopy(serialized_simple_dag_ground_truth)
-    expected["dag"]["dag_dependencies"] = expected_dag_dependencies
-    del expected["dag"]["tasks"][1]["__var"]["_operator_extra_links"]
+    expected_sdag = copy.deepcopy(serialized_simple_dag_ground_truth)
+    expected = SerializedDAG.from_dict(expected_sdag)
+
+    fields_to_verify = set(vars(expected).keys()) - {
+        "task_group",  # Tested separately
+        "dag_dependencies",  # Tested separately
+        "last_loaded",  # Dynamically set to utcnow
+    }
+
+    for f in fields_to_verify:
+        dag_value = getattr(dag, f)
+        expected_value = getattr(expected, f)
+
+        assert dag_value == expected_value, (
+            f"Upgraded v1 DAG field '{f}' differs from ground truth: "
+            f"got={dag_value!r} != expected={expected_value!r}"
+        )
+
+    for f in set(vars(expected.task_group).keys()) - {"dag"}:
+        dag_tg_value = getattr(dag.task_group, f)
+        expected_tg_value = getattr(expected.task_group, f)
+
+        assert dag_tg_value == expected_tg_value, (
+            f"Upgraded v1 task_group field '{f}' differs from ground truth: "
+            f"got={dag_tg_value!r} != expected={expected_tg_value!r}"
+        )
 
-    del 
expected["client_defaults"] - assert v1 == expected + assert getattr(dag, "dag_dependencies") == expected_dag_dependencies def test_handle_v2_serdag(): @@ -3514,6 +3526,72 @@ def test_handle_v2_serdag(): ) +def test_dag_schema_defaults_optimization(): + """Test that DAG fields matching schema defaults are excluded from serialization.""" + + # Create DAG with all schema default values + dag_with_defaults = DAG( + dag_id="test_defaults_dag", + start_date=datetime(2023, 1, 1), + # These should match schema defaults and be excluded + catchup=False, + fail_fast=False, + max_active_runs=16, + max_active_tasks=16, + max_consecutive_failed_dag_runs=0, + render_template_as_native_obj=False, + disable_bundle_versioning=False, + # These should be excluded as None + description=None, + doc_md=None, + ) + + # Serialize and check exclusions + serialized = SerializedDAG.to_dict(dag_with_defaults) + dag_data = serialized["dag"] + + # Schema default fields should be excluded + for field in SerializedDAG.get_schema_defaults("dag").keys(): + assert field not in dag_data, f"Schema default field '{field}' should be excluded" + + # None fields should also be excluded + none_fields = ["description", "doc_md"] + for field in none_fields: + assert field not in dag_data, f"None field '{field}' should be excluded" + + # Test deserialization restores defaults correctly + deserialized_dag = SerializedDAG.from_dict(serialized) + + # Verify schema defaults are restored + assert deserialized_dag.catchup is False + assert deserialized_dag.fail_fast is False + assert deserialized_dag.max_active_runs == 16 + assert deserialized_dag.max_active_tasks == 16 + assert deserialized_dag.max_consecutive_failed_dag_runs == 0 + assert deserialized_dag.render_template_as_native_obj is False + assert deserialized_dag.disable_bundle_versioning is False + + # Test with non-default values (should be included) + dag_non_defaults = DAG( + dag_id="test_non_defaults_dag", + start_date=datetime(2023, 1, 1), + catchup=True, # Non-default + max_active_runs=32, # Non-default + description="Test description", # Non-None + ) + + serialized_non_defaults = SerializedDAG.to_dict(dag_non_defaults) + dag_non_defaults_data = serialized_non_defaults["dag"] + + # Non-default values should be included + assert "catchup" in dag_non_defaults_data + assert dag_non_defaults_data["catchup"] is True + assert "max_active_runs" in dag_non_defaults_data + assert dag_non_defaults_data["max_active_runs"] == 32 + assert "description" in dag_non_defaults_data + assert dag_non_defaults_data["description"] == "Test description" + + def test_email_optimization_removes_email_attrs_when_email_empty(): """Test that email_on_failure and email_on_retry are removed when email is empty.""" with DAG(dag_id="test_email_optimization") as dag: diff --git a/scripts/in_container/run_schema_defaults_check.py b/scripts/in_container/run_schema_defaults_check.py index ef672c297ca0c..bc7c1e2844cb3 100755 --- a/scripts/in_container/run_schema_defaults_check.py +++ b/scripts/in_container/run_schema_defaults_check.py @@ -32,7 +32,7 @@ from typing import Any -def load_schema_defaults() -> dict[str, Any]: +def load_schema_defaults(object_type: str = "operator") -> dict[str, Any]: """Load default values from the JSON schema.""" schema_path = Path("airflow-core/src/airflow/serialization/schema.json") @@ -43,9 +43,9 @@ def load_schema_defaults() -> dict[str, Any]: with open(schema_path) as f: schema = json.load(f) - # Extract defaults from the operator definition - operator_def = 
schema.get("definitions", {}).get("operator", {}) - properties = operator_def.get("properties", {}) + # Extract defaults from the specified object type definition + object_def = schema.get("definitions", {}).get(object_type, {}) + properties = object_def.get("properties", {}) defaults = {} for field_name, field_def in properties.items(): @@ -55,7 +55,7 @@ def load_schema_defaults() -> dict[str, Any]: return defaults -def get_server_side_defaults() -> dict[str, Any]: +def get_server_side_operator_defaults() -> dict[str, Any]: """Get default values from server-side SerializedBaseOperator class.""" try: from airflow.serialization.serialized_objects import SerializedBaseOperator @@ -92,14 +92,46 @@ def get_server_side_defaults() -> dict[str, Any]: sys.exit(1) -def compare_defaults() -> list[str]: - """Compare schema defaults with server-side defaults and return discrepancies.""" - schema_defaults = load_schema_defaults() - server_defaults = get_server_side_defaults() +def get_server_side_dag_defaults() -> dict[str, Any]: + """Get default values from server-side SerializedDAG class.""" + try: + from airflow.serialization.serialized_objects import SerializedDAG + + # DAG defaults are set in __init__, so we create a temporary instance + temp_dag = SerializedDAG(dag_id="temp") + + # Get all serializable DAG fields from the server-side class + serialized_fields = SerializedDAG.get_serialized_fields() + + server_defaults = {} + for field_name in serialized_fields: + if hasattr(temp_dag, field_name): + default_value = getattr(temp_dag, field_name) + # Only include actual default values that are not None, callables, or descriptors + if not callable(default_value) and not isinstance(default_value, (property, type)): + if isinstance(default_value, (set, tuple)): + # Convert to list since schema.json is pure JSON + default_value = list(default_value) + server_defaults[field_name] = default_value + + return server_defaults + + except ImportError as e: + print(f"Error importing SerializedDAG: {e}") + sys.exit(1) + except Exception as e: + print(f"Error getting server-side DAG defaults: {e}") + sys.exit(1) + + +def compare_operator_defaults() -> list[str]: + """Compare operator schema defaults with server-side defaults and return discrepancies.""" + schema_defaults = load_schema_defaults("operator") + server_defaults = get_server_side_operator_defaults() errors = [] - print(f"Found {len(schema_defaults)} schema defaults") - print(f"Found {len(server_defaults)} server-side defaults") + print(f"Found {len(schema_defaults)} operator schema defaults") + print(f"Found {len(server_defaults)} operator server-side defaults") # Check each server default against schema for field_name, server_value in server_defaults.items(): @@ -141,25 +173,82 @@ def compare_defaults() -> list[str]: return errors +def compare_dag_defaults() -> list[str]: + """Compare DAG schema defaults with server-side defaults and return discrepancies.""" + schema_defaults = load_schema_defaults("dag") + server_defaults = get_server_side_dag_defaults() + errors = [] + + print(f"Found {len(schema_defaults)} DAG schema defaults") + print(f"Found {len(server_defaults)} DAG server-side defaults") + + # Check each server default against schema + for field_name, server_value in server_defaults.items(): + schema_value = schema_defaults.get(field_name) + + # Check if field exists in schema + if field_name not in schema_defaults: + # Some server fields don't need defaults in schema (like None values, empty collections, or computed fields) + if ( + 
server_value is not None
+                and server_value not in [[], {}, (), set()]
+                and field_name not in ["dag_id", "dag_display_name"]
+            ):
+                errors.append(
+                    f"DAG server field '{field_name}' has default {server_value!r} but no schema default"
+                )
+            continue
+
+        # Direct comparison
+        if schema_value != server_value:
+            errors.append(
+                f"DAG field '{field_name}': schema default is {schema_value!r}, "
+                f"server default is {server_value!r}"
+            )
+
+    # Check for schema defaults that don't have corresponding server defaults
+    for field_name, schema_value in schema_defaults.items():
+        if field_name not in server_defaults:
+            # Some schema fields are computed properties (like has_on_*_callback)
+            computed_properties = {
+                "has_on_success_callback",
+                "has_on_failure_callback",
+            }
+            if field_name not in computed_properties:
+                errors.append(
+                    f"DAG schema has default for '{field_name}' = {schema_value!r} but no corresponding server default"
+                )
+
+    return errors
+
+
 def main():
     """Main function to run the schema defaults check."""
-    print("Checking schema defaults against server-side SerializedBaseOperator...")
+    print("Checking schema defaults against server-side serialization classes...")
+
+    # Check Operator defaults
+    print("\n1. Checking Operator defaults...")
+    operator_errors = compare_operator_defaults()
+
+    # Check DAG defaults
+    print("\n2. Checking DAG defaults...")
+    dag_errors = compare_dag_defaults()
 
-    errors = compare_defaults()
+    all_errors = operator_errors + dag_errors
 
-    if errors:
-        print("❌ Found discrepancies between schema and server defaults:")
-        for error in errors:
+    if all_errors:
+        print("\n❌ Found discrepancies between schema and server defaults:")
+        for error in all_errors:
             print(f"  • {error}")
         print()
         print("To fix these issues:")
         print("1. Update airflow-core/src/airflow/serialization/schema.json to match server defaults, OR")
         print(
-            "2. Update airflow-core/src/airflow/serialization/serialized_objects.py class defaults to match schema"
+            "2. Update airflow-core/src/airflow/serialization/serialized_objects.py class/init defaults to match schema"
         )
         sys.exit(1)
     else:
-        print("✅ All schema defaults match server-side defaults!")
+        print("\n✅ All schema defaults match server-side defaults!")
 
 
 if __name__ == "__main__":
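
Reviewer note, not part of the patch: the sketch below is a minimal, self-contained illustration of the round-trip contract this optimization relies on, write-side stripping of values that equal a schema default plus read-side re-injection of those defaults. The miniature SCHEMA dict and the strip_defaults/apply_defaults helpers are hypothetical stand-ins for schema.json and for the real get_schema_defaults/_is_excluded/from_dict machinery.

import json

# Hypothetical miniature of schema.json's "dag" definition (illustration only;
# the real file lives at airflow-core/src/airflow/serialization/schema.json).
SCHEMA = json.loads(
    """
    {
      "definitions": {
        "dag": {
          "properties": {
            "dag_id": {"type": "string"},
            "catchup": {"type": "boolean", "default": false},
            "max_active_runs": {"type": "number", "default": 16}
          },
          "required": ["dag_id"]
        }
      }
    }
    """
)


def get_schema_defaults(object_type: str) -> dict:
    """Collect {field: default} for one definition, mirroring the cached helper in the diff."""
    properties = SCHEMA["definitions"].get(object_type, {}).get("properties", {})
    return {name: spec["default"] for name, spec in properties.items() if "default" in spec}


def strip_defaults(serialized: dict, object_type: str) -> dict:
    """Write side: drop any field whose value equals its schema default."""
    defaults = get_schema_defaults(object_type)
    sentinel = object()  # never equal to a real value, so fields without defaults are kept
    return {k: v for k, v in serialized.items() if defaults.get(k, sentinel) != v}


def apply_defaults(serialized: dict, object_type: str) -> dict:
    """Read side: re-inject schema defaults for fields the writer omitted."""
    return {**get_schema_defaults(object_type), **serialized}


dag = {"dag_id": "demo", "catchup": False, "max_active_runs": 32}
compact = strip_defaults(dag, "dag")  # {"dag_id": "demo", "max_active_runs": 32}
assert apply_defaults(compact, "dag") == dag  # lossless round-trip

Sourcing the write-side predicate and the read-side defaults from the same schema definition is what makes the omission lossless, and that shared-source invariant is exactly what run_schema_defaults_check.py now enforces for both the "operator" and "dag" definitions.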