From eb97006f2b5cd9887ad13d45eccf0d5d31ad9c72 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Fri, 15 Aug 2025 01:58:39 +0100 Subject: [PATCH] Decouple serialization and deserialization code for tasks Remove Task SDK dependencies from airflow-core deserialization by establishing a schema-based contract between client and server components. This change enables independent deployment and upgrades while laying the foundation for multi-language SDK support. Key Decoupling Achievements: - Replace dynamic get_serialized_fields() calls with hardcoded class methods - Add schema-driven default resolution with get_operator_defaults_from_schema() - Remove OPERATOR_DEFAULTS import dependency from airflow-core - Implement SerializedBaseOperator class attributes for all operator defaults - Update _is_excluded() logic to use schema defaults for efficient serialization Serialization Optimizations: - Unified partial_kwargs optimization supporting both encoded/non-encoded formats - Intelligent default exclusion reducing storage redundancy - MappedOperator.operator_class memory optimization (~90-95% reduction) - Comprehensive client_defaults system with hierarchical resolution Compatibility & Performance: - Significant size reduction for typical DAGs with mapped operators - Minimal overhead for client_defaults section (excellent efficiency) - All existing serialized DAGs continue to work unchanged Technical Implementation: - Add generate_client_defaults() with LRU caching for optimal performance - Implement _deserialize_partial_kwargs() supporting dual formats - Centralized field deserialization eliminating code duplication - Consolidated preprocessing logic in _preprocess_encoded_operator() - Callback field preprocessing for backward compatibility Testing & Validation: - Added TestMappedOperatorSerializationAndClientDefaults with 9 comprehensive tests - Parameterized tests for multiple serialization formats - End-to-end validation of serialization/deserialization workflows - Backward compatibility validation for callback field migration This decoupling enables independent deployment/upgrades and provides the foundation for multi-language SDK ecosystem alongside the Task Execution API. 
Part of https://github.com/apache/airflow/issues/45428 --- .pre-commit-config.yaml | 8 + .../dag-serialization.rst | 137 ++++ .../src/airflow/jobs/scheduler_job_runner.py | 4 +- airflow-core/src/airflow/models/dagrun.py | 4 +- .../src/airflow/models/mappedoperator.py | 101 ++- .../src/airflow/serialization/json_schema.py | 2 + .../src/airflow/serialization/schema.json | 87 +- .../serialization/serialized_objects.py | 707 ++++++++++++---- .../serialization/test_dag_serialization.py | 761 +++++++++++++++--- .../providers/openlineage/utils/utils.py | 2 +- .../unit/openlineage/utils/test_utils.py | 2 +- scripts/ci/prek/check_schema_defaults.py | 44 + .../in_container/run_schema_defaults_check.py | 166 ++++ task-sdk/src/airflow/sdk/bases/operator.py | 46 +- .../airflow/sdk/definitions/mappedoperator.py | 36 +- 15 files changed, 1736 insertions(+), 371 deletions(-) create mode 100755 scripts/ci/prek/check_schema_defaults.py create mode 100755 scripts/in_container/run_schema_defaults_check.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 35c062e6e252c..cea777e49ed76 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1540,6 +1540,7 @@ repos: ^airflow-core/tests/unit/core/test_configuration\.py$| ^airflow-core/tests/unit/models/test_renderedtifields\.py$| ^airflow-core/tests/unit/models/test_variable\.py$ + - id: check-sdk-imports name: Check for SDK imports in core files entry: ./scripts/ci/prek/check_sdk_imports.py @@ -1638,3 +1639,10 @@ repos: ^airflow-core/src/airflow/utils/trigger_rule\.py$| ^airflow-core/src/airflow/utils/types\.py$ ## ONLY ADD PREK HOOKS HERE THAT REQUIRE CI IMAGE + - id: check-schema-defaults + name: Check schema defaults match server-side defaults + entry: ./scripts/ci/prek/check_schema_defaults.py + language: python + files: ^airflow-core/src/airflow/serialization/schema\.json$|^airflow-core/src/airflow/serialization/serialized_objects\.py$ + pass_filenames: false + require_serial: true diff --git a/airflow-core/docs/administration-and-deployment/dag-serialization.rst b/airflow-core/docs/administration-and-deployment/dag-serialization.rst index 3bdb918400794..7c1d9c5435cde 100644 --- a/airflow-core/docs/administration-and-deployment/dag-serialization.rst +++ b/airflow-core/docs/administration-and-deployment/dag-serialization.rst @@ -119,3 +119,140 @@ define a ``json`` variable in local Airflow settings (``airflow_local_settings.p See :ref:`Configuring local settings ` for details on how to configure local settings. + + +.. _dag-serialization-defaults: + +DAG Serialization with Default Values (Airflow 3.1+) +------------------------------------------------------ + +Starting with Airflow 3.1, DAG serialization establishes a versioned contract between Task SDKs +and Airflow server components (Scheduler & API-Server). Combined with the Task Execution API, this +decouples client and server components, enabling independent deployments and upgrades while maintaining +backward compatibility and automatic default value resolution. + +How Default Values Work +~~~~~~~~~~~~~~~~~~~~~~~ + +When Airflow processes DAGs, it applies default values in a specific order of precedence for the server: + +1. **Schema defaults**: Built-in Airflow defaults (lowest priority) +2. **Client defaults**: SDK-specific defaults +3. **DAG default_args**: DAG-level settings (existing behavior) +4. **Partial arguments**: MappedOperator shared values +5. 
**Task values**: Explicit task settings (highest priority) + +This means you can set defaults at different levels and more specific settings will override +more general ones. + +JSON Structure +~~~~~~~~~~~~~~ + +Serialized DAGs now include a ``client_defaults`` section that contains common default values: + +.. code-block:: json + + { + "__version": 2, + "client_defaults": { + "tasks": { + "retry_delay": 300.0, + "owner": "data_team" + } + }, + "dag": { + "dag_id": "example_dag", + "default_args": { + "retries": 3 + }, + "tasks": [{ + "task_id": "example_task", + "task_type": "BashOperator", + "_task_module": "airflow.operators.bash", + "bash_command": "echo hello", + "owner": "specific_owner" + }] + } + } + +How Values Are Applied +~~~~~~~~~~~~~~~~~~~~~~ + +In the example above, the task ``example_task`` will have these final values: + +- **retry_delay**: 300.0 (from client_defaults.tasks) +- **owner**: "data_team" (from client_defaults.tasks) +- **retries**: 3 (from dag.default_args, overrides client_defaults) +- **bash_command**: "echo hello" (explicit task value) +- **pool**: "default_pool" (from schema defaults) + +The system automatically fills in any missing values by walking up the hierarchy. + +MappedOperator Default Handling +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +MappedOperators (dynamic task mapping) also participate in the default value system: + +.. code-block:: python + + # DAG Definition + BashOperator.partial(task_id="mapped_task", retries=2, owner="team_lead").expand( + bash_command=["echo 1", "echo 2", "echo 3"] + ) + +In this example, each of the three generated task instances will inherit: + +- **retries**: 2 (from partial arguments) +- **owner**: "team_lead" (from partial arguments) +- **pool**: "default_pool" (from client_defaults, since not specified in partial) +- **bash_command**: "echo 1", "echo 2", or "echo 3" respectively (from expand) + +Independent Deployment Architecture +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +**Decoupled Components:** +The serialization contract, combined with the Task Execution API, enables complete separation between: + +- **Server Components** (Scheduler, API-Server): Handle orchestration, don't run user code +- **Client Components** (Task SDK, DAG processing): Run user code in isolated environments + +**Key Benefits:** + +- **Independent upgrades**: Upgrade server components without touching user environments +- **Version compatibility**: Single server version supports multiple SDK versions simultaneously +- **Deployment flexibility**: Server and client components can be deployed and scaled separately +- **Security isolation**: User code runs only in client environments, never on server components +- **Multi-language SDK support**: Any language can implement a compliant Task SDK + +**SDK Requirements:** +Any Task SDK implementation must: + +1. **Follow published schemas**: + - DAG serialization: Produce JSON that validates against schema. Example: ``https://airflow.apache.org/schemas/dag-serialization/v2.json`` + - Task execution: Support runtime communication via Execution API schema. Example: ``https://airflow.apache.org/schemas/execution-api/2025-05-20.json`` +2. **Include client_defaults**: Optionally, provide SDK-specific defaults in the ``client_defaults.tasks`` section +3. 
**Use proper versioning**: Include ``__version`` field to indicate serialization format + +**Server Guarantees:** +As long as SDKs conform to both schema contracts, Airflow server components will: + +- Correctly deserialize DAGs from any compliant SDK +- Support task execution communication during runtime +- Apply appropriate default values according to the hierarchy +- Maintain compatibility across SDK versions and languages + +Implementation Status +~~~~~~~~~~~~~~~~~~~~~ + +**Current State (Airflow 3.1):** +The serialization contract establishes the foundation for client/server decoupling. While some +server components still contain Task SDK code (and vice versa), the contract ensures that: + +- **Schema compliance** enables independent deployment when components are separated +- **Version compatibility** works regardless of code coupling +- **Deployment separation** is architecturally supported even if not yet fully implemented + +**Future Evolution:** +Complete code decoupling between server and client components is planned for future releases. +The schema contract provides the stable interface that will remain consistent as this evolution +continues. diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index 20a92efeb8609..8128761077e92 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -897,7 +897,7 @@ def process_executor_events( ti.set_state(state) continue ti.task = task - if task.on_retry_callback or task.on_failure_callback: + if task.has_on_retry_callback or task.has_on_failure_callback: # Only log the error/extra info here, since the `ti.handle_failure()` path will log it # too, which would lead to double logging cls.logger().error(msg) @@ -2033,7 +2033,7 @@ def _maybe_requeue_stuck_ti(self, *, ti, session, executor): exc_info=True, ) else: - if task.on_failure_callback: + if task.has_on_failure_callback: if inspect(ti).detached: ti = session.merge(ti) request = TaskCallbackRequest( diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index ccf01c65e19da..2b79da564235c 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -1998,8 +1998,8 @@ def schedule_tis( assert isinstance(task, Operator) if ( task.inherits_from_empty_operator - and not task.on_execute_callback - and not task.on_success_callback + and not task.has_on_execute_callback + and not task.has_on_success_callback and not task.outlets and not task.inlets ): diff --git a/airflow-core/src/airflow/models/mappedoperator.py b/airflow-core/src/airflow/models/mappedoperator.py index 01f0c04205b70..6e086beeaaae3 100644 --- a/airflow-core/src/airflow/models/mappedoperator.py +++ b/airflow-core/src/airflow/models/mappedoperator.py @@ -30,19 +30,7 @@ from airflow.exceptions import AirflowException from airflow.sdk import BaseOperator as TaskSDKBaseOperator from airflow.sdk.definitions._internal.abstractoperator import ( - DEFAULT_EXECUTOR, - DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, - DEFAULT_OWNER, - DEFAULT_POOL_NAME, - DEFAULT_POOL_SLOTS, - DEFAULT_PRIORITY_WEIGHT, - DEFAULT_QUEUE, - DEFAULT_RETRIES, - DEFAULT_RETRY_DELAY, - DEFAULT_TRIGGER_RULE, - DEFAULT_WEIGHT_RULE, NotMapped, - TaskStateChangeCallbackAttrType, ) from airflow.sdk.definitions._internal.node import DAGNode from airflow.sdk.definitions.mappedoperator import MappedOperator as TaskSDKMappedOperator @@ -91,6 +79,7 @@ def 
is_mapped(task: Operator) -> TypeGuard[MappedOperator]: class MappedOperator(DAGNode): """Object representing a mapped operator in a DAG.""" + # Stores minimal class type information (task_type, _operator_name) instead of full serialized data operator_class: dict[str, Any] partial_kwargs: dict[str, Any] = attrs.field(init=False, factory=dict) @@ -107,10 +96,10 @@ class MappedOperator(DAGNode): _can_skip_downstream: bool = attrs.field(alias="can_skip_downstream") _is_sensor: bool = attrs.field(alias="is_sensor", default=False) _task_module: str - _task_type: str + task_type: str _operator_name: str - start_trigger_args: StartTriggerArgs | None - start_from_trigger: bool + start_trigger_args: StartTriggerArgs | None = None + start_from_trigger: bool = False _needs_expansion: bool = True dag: SchedulerDAG = attrs.field(init=False) @@ -154,11 +143,6 @@ def leaves(self) -> Sequence[DAGNode]: # TODO (GH-52141): Review if any of the properties below are used in the # SDK and the scheduler, and remove those not needed. - @property - def task_type(self) -> str: - """Implementing Operator.""" - return self._task_type - @property def operator_name(self) -> str: return self._operator_name @@ -186,11 +170,11 @@ def inherits_from_skipmixin(self) -> bool: @property def owner(self) -> str: - return self.partial_kwargs.get("owner", DEFAULT_OWNER) + return self.partial_kwargs.get("owner", SerializedBaseOperator.owner) @property def trigger_rule(self) -> TriggerRule: - return self.partial_kwargs.get("trigger_rule", DEFAULT_TRIGGER_RULE) + return self.partial_kwargs.get("trigger_rule", SerializedBaseOperator.trigger_rule) @property def is_setup(self) -> bool: @@ -206,7 +190,9 @@ def depends_on_past(self) -> bool: @property def ignore_first_depends_on_past(self) -> bool: - value = self.partial_kwargs.get("ignore_first_depends_on_past", DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST) + value = self.partial_kwargs.get( + "ignore_first_depends_on_past", SerializedBaseOperator.ignore_first_depends_on_past + ) return bool(value) @property @@ -215,19 +201,19 @@ def wait_for_downstream(self) -> bool: @property def retries(self) -> int: - return self.partial_kwargs.get("retries", DEFAULT_RETRIES) + return self.partial_kwargs.get("retries", SerializedBaseOperator.retries) @property def queue(self) -> str: - return self.partial_kwargs.get("queue", DEFAULT_QUEUE) + return self.partial_kwargs.get("queue", SerializedBaseOperator.queue) @property def pool(self) -> str: - return self.partial_kwargs.get("pool", DEFAULT_POOL_NAME) + return self.partial_kwargs.get("pool", SerializedBaseOperator.pool) @property def pool_slots(self) -> int: - return self.partial_kwargs.get("pool_slots", DEFAULT_POOL_SLOTS) + return self.partial_kwargs.get("pool_slots", SerializedBaseOperator.pool_slots) @property def resources(self) -> Resources | None: @@ -242,24 +228,24 @@ def max_active_tis_per_dagrun(self) -> int | None: return self.partial_kwargs.get("max_active_tis_per_dagrun") @property - def on_execute_callback(self) -> TaskStateChangeCallbackAttrType: - return self.partial_kwargs.get("on_execute_callback") or [] + def has_on_execute_callback(self) -> bool: + return bool(self.partial_kwargs.get("has_on_execute_callback", False)) @property - def on_failure_callback(self) -> TaskStateChangeCallbackAttrType: - return self.partial_kwargs.get("on_failure_callback") or [] + def has_on_failure_callback(self) -> bool: + return bool(self.partial_kwargs.get("has_on_failure_callback", False)) @property - def on_retry_callback(self) -> 
TaskStateChangeCallbackAttrType: - return self.partial_kwargs.get("on_retry_callback") or [] + def has_on_retry_callback(self) -> bool: + return bool(self.partial_kwargs.get("has_on_retry_callback", False)) @property - def on_success_callback(self) -> TaskStateChangeCallbackAttrType: - return self.partial_kwargs.get("on_success_callback") or [] + def has_on_success_callback(self) -> bool: + return bool(self.partial_kwargs.get("has_on_success_callback", False)) @property - def on_skipped_callback(self) -> TaskStateChangeCallbackAttrType: - return self.partial_kwargs.get("on_skipped_callback") or [] + def has_on_skipped_callback(self) -> bool: + return bool(self.partial_kwargs.get("has_on_skipped_callback", False)) @property def run_as_user(self) -> str | None: @@ -267,11 +253,11 @@ def run_as_user(self) -> str | None: @property def priority_weight(self) -> int: - return self.partial_kwargs.get("priority_weight", DEFAULT_PRIORITY_WEIGHT) + return self.partial_kwargs.get("priority_weight", SerializedBaseOperator.priority_weight) @property def retry_delay(self) -> datetime.timedelta: - return self.partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY) + return self.partial_kwargs["retry_delay"] @property def retry_exponential_backoff(self) -> bool: @@ -280,12 +266,12 @@ def retry_exponential_backoff(self) -> bool: @property def weight_rule(self) -> PriorityWeightStrategy: return validate_and_load_priority_weight_strategy( - self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE) + self.partial_kwargs.get("weight_rule", SerializedBaseOperator._weight_rule) ) @property def executor(self) -> str | None: - return self.partial_kwargs.get("executor", DEFAULT_EXECUTOR) + return self.partial_kwargs.get("executor") @property def executor_config(self) -> dict: @@ -311,8 +297,37 @@ def on_failure_fail_dagrun(self) -> bool: def on_failure_fail_dagrun(self, v) -> None: self.partial_kwargs["on_failure_fail_dagrun"] = bool(v) - def get_serialized_fields(self): - return TaskSDKMappedOperator.get_serialized_fields() + @classmethod + def get_serialized_fields(cls): + """Fields to extract from JSON-Serialized DAG.""" + return frozenset( + { + "_disallow_kwargs_override", + "_expand_input_attr", + "_is_sensor", + "_needs_expansion", + "_operator_name", + "_task_module", + "downstream_task_ids", + "end_date", + "operator_extra_links", + "params", + "partial_kwargs", + "start_date", + "start_from_trigger", + "start_trigger_args", + "task_id", + "task_type", + "template_ext", + "template_fields", + "template_fields_renderers", + "ui_color", + "ui_fgcolor", + # TODO: Need to verify if the following two are needed on the server side. + "expand_input", + "op_kwargs_expand_input", + } + ) @functools.cached_property def operator_extra_link_dict(self) -> dict[str, BaseOperatorLink]: diff --git a/airflow-core/src/airflow/serialization/json_schema.py b/airflow-core/src/airflow/serialization/json_schema.py index e2b237e3b67d9..c74f207048765 100644 --- a/airflow-core/src/airflow/serialization/json_schema.py +++ b/airflow-core/src/airflow/serialization/json_schema.py @@ -39,6 +39,8 @@ class Validator(Protocol): Hence, you can not have ``type: Draft7Validator``. """ + schema: dict + def is_valid(self, instance) -> bool: """Check if the instance is valid under the current schema.""" ... 
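The `schema: dict` attribute added to the Validator protocol above is what lets the serialization layer introspect the loaded JSON schema instead of importing defaults from the Task SDK. A minimal, self-contained sketch of that pattern follows; it is an illustration only (not part of this patch), uses a hypothetical trimmed-down schema dict, and mirrors what `get_schema_defaults()` in serialized_objects.py does further down in this diff:

    from typing import Any


    def extract_defaults(schema: dict[str, Any], object_type: str = "operator") -> dict[str, Any]:
        """Collect every property of ``object_type`` that declares a "default" in the schema."""
        properties = schema.get("definitions", {}).get(object_type, {}).get("properties", {})
        return {
            name: field["default"]
            for name, field in properties.items()
            if isinstance(field, dict) and "default" in field
        }


    # Hypothetical, trimmed-down schema mirroring the shape of schema.json below:
    example_schema = {
        "definitions": {
            "operator": {
                "properties": {
                    "retries": {"type": "number", "default": 0},
                    "pool": {"type": "string", "default": "default_pool"},
                    "owner": {"type": "string", "default": "airflow"},
                    "task_id": {"type": "string"},  # no default declared, so it is skipped
                }
            }
        }
    }

    assert extract_defaults(example_schema) == {"retries": 0, "pool": "default_pool", "owner": "airflow"}

Keeping the default values in the published schema contract, rather than in Python constants shared between client and server, is what allows the deserialization side to resolve defaults without calling back into the Task SDK.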
diff --git a/airflow-core/src/airflow/serialization/schema.json b/airflow-core/src/airflow/serialization/schema.json index aa13f618f844d..0ce253e2262c2 100644 --- a/airflow-core/src/airflow/serialization/schema.json +++ b/airflow-core/src/airflow/serialization/schema.json @@ -261,58 +261,78 @@ "template_fields" ], "properties": { - "task_type": { "type": "string" }, + "task_type": { "type": "string", "default": "BaseOperator"}, "_task_module": { "type": "string" }, "_operator_extra_links": { "$ref": "#/definitions/extra_links" }, "task_id": { "type": "string" }, - "task_display_name": { "type": "string" }, - "label": { "type": "string" }, - "owner": { "type": "string" }, + "_task_display_name": { "type": "string" }, + "owner": { "type": "string", "default": "airflow" }, "start_date": { "$ref": "#/definitions/datetime" }, "end_date": { "$ref": "#/definitions/datetime" }, - "trigger_rule": { "type": "string" }, - "depends_on_past": { "type": "boolean" }, - "ignore_first_depends_on_past": { "type": "boolean" }, - "wait_for_past_depends_before_skipping": { "type": "boolean" }, - "wait_for_downstream": { "type": "boolean" }, - "retries": { "type": "number" }, - "queue": { "type": "string" }, - "pool": { "type": "string" }, - "pool_slots": { "type": "number" }, + "trigger_rule": { "type": "string", "default": "all_success" }, + "depends_on_past": { "type": "boolean", "default": false }, + "ignore_first_depends_on_past": { "type": "boolean", "default": false }, + "wait_for_past_depends_before_skipping": { "type": "boolean", "default": false }, + "wait_for_downstream": { "type": "boolean", "default": false }, + "retries": { "type": "number", "default": 0 }, + "queue": { "type": "string", "default": "default" }, + "pool": { "type": "string", "default": "default_pool" }, + "pool_slots": { "type": "number", "default": 1 }, "execution_timeout": { "$ref": "#/definitions/timedelta" }, "retry_delay": { "$ref": "#/definitions/timedelta" }, - "retry_exponential_backoff": { "type": "boolean" }, + "retry_exponential_backoff": { "type": "boolean", "default": false }, "max_retry_delay": { "$ref": "#/definitions/timedelta" }, "params": { "$ref": "#/definitions/params" }, - "priority_weight": { "type": "number" }, - "weight_rule": { "type": "string" }, + "priority_weight": { "type": "number", "default": 1 }, + "weight_rule": { "type": "string", "default": "downstream" }, "executor": { "type": "string" }, "executor_config": { "$ref": "#/definitions/dict" }, - "do_xcom_push": { "type": "boolean" }, - "ui_color": { "$ref": "#/definitions/color" }, - "ui_fgcolor": { "$ref": "#/definitions/color" }, + "do_xcom_push": { "type": "boolean", "default": true }, + "email_on_failure": { "type": "boolean", "default": true }, + "email_on_retry": { "type": "boolean", "default": true }, + "ui_color": { "type": "string", "default": "#fff" }, + "ui_fgcolor": { "type": "string", "default": "#000" }, "template_fields": { "type": "array", - "items": { "type": "string" } + "items": { "type": "string" }, + "default": [] }, + "template_ext": {"type": "array", "default": []}, + "template_fields_renderers": {"$ref": "#/definitions/dict", "default": {}}, "downstream_task_ids": { "type": "array", - "items": { "type": "string" } + "items": { "type": "string" }, + "default": [] }, - "_is_dummy": { "type": "boolean" }, "doc": { "type": "string" }, "doc_md": { "type": "string" }, "doc_json": { "type": "string" }, "doc_yaml": { "type": "string" }, "doc_rst": { "type": "string" }, "_logger_name": { "type": "string" }, - 
"_log_config_logger_name": { "type": "string" }, - "_is_mapped": { "const": true, "$comment": "only present when True" }, - "_is_sensor": { "const": true, "$comment": "only present when True" }, - "expand_input": { "type": "object" }, + "_needs_expansion": { "type": "boolean"}, + "_is_mapped": { "const": true, "$comment": "only present when True", "default": false }, + "_is_sensor": { "const": true, "$comment": "only present when True", "default": false }, "partial_kwargs": { "type": "object" }, + "_disallow_kwargs_override": { "type": "boolean"}, + "_expand_input_attr": { "type": "string" }, "map_index_template": { "type": "string" }, - "allow_nested_operators": { "type": "boolean" } + "allow_nested_operators": { "type": "boolean", "default": true }, + "inlets": {"type": "array", "default": []}, + "outlets": {"type": "array", "default": []}, + "has_on_execute_callback": {"type": "boolean", "default": false}, + "has_on_failure_callback": {"type": "boolean", "default": false}, + "has_on_skipped_callback": {"type": "boolean", "default": false}, + "has_on_success_callback": {"type": "boolean", "default": false}, + "has_on_retry_callback": {"type": "boolean", "default": false}, + "multiple_outputs": {"type": "boolean", "default": false}, + "start_from_trigger": {"type": "boolean", "default": false}, + "start_trigger_args": {"type": "object", "default": null}, + "is_setup": {"type": "boolean", "default": false}, + "is_teardown": {"type": "boolean", "default": false}, + "on_failure_fail_dagrun": {"type": "boolean", "default": false}, + "max_active_tis_per_dag": {"type": "integer"}, + "max_active_tis_per_dagrun": {"type": "integer"} }, "dependencies": { "expand_input": ["partial_kwargs", "_is_mapped"], @@ -391,7 +411,18 @@ "type": "integer", "exclusiveMinimum": 0 }, - "dag": { "$ref": "#/definitions/dag" } + "dag": { "$ref": "#/definitions/dag" }, + "client_defaults": { + "type": "object", + "description": "SDK-specific default values that differ from schema defaults", + "properties": { + "tasks": { + "type": "object", + "description": "Task-level default overrides" + } + }, + "additionalProperties": false + } }, "additionalProperties": false, "required": [ "__version", "dag" ] diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py index 3f60d059edc81..0a33a1ecf25b8 100644 --- a/airflow-core/src/airflow/serialization/serialized_objects.py +++ b/airflow-core/src/airflow/serialization/serialized_objects.py @@ -28,7 +28,7 @@ import math import weakref from collections.abc import Collection, Generator, Iterable, Iterator, Mapping, Sequence -from functools import cache, cached_property +from functools import cached_property, lru_cache from inspect import signature from textwrap import dedent from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple, TypeAlias, TypeVar, cast, overload @@ -52,7 +52,6 @@ from airflow.models.xcom_arg import SchedulerXComArg, deserialize_xcom_arg from airflow.sdk import Asset, AssetAlias, AssetAll, AssetAny, AssetWatcher, BaseOperator, XComArg from airflow.sdk.bases.operator import OPERATOR_DEFAULTS # TODO: Copy this into the scheduler? 
-from airflow.sdk.definitions._internal.expandinput import EXPAND_INPUT_EMPTY from airflow.sdk.definitions._internal.node import DAGNode from airflow.sdk.definitions.asset import ( AssetAliasEvent, @@ -77,6 +76,7 @@ PriorityWeightStrategy, airflow_priority_weight_strategies, airflow_priority_weight_strategies_classes, + validate_and_load_priority_weight_strategy, ) from airflow.ti_deps.deps.mapped_task_upstream_dep import MappedTaskUpstreamDep from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep @@ -134,21 +134,6 @@ log = logging.getLogger(__name__) -@cache -def _get_default_mapped_partial() -> dict[str, Any]: - """ - Get default partial kwargs in a mapped operator. - - This is used to simplify a serialized mapped operator by excluding default - values supplied in the implementation from the serialized dict. Since those - are defaults, they are automatically supplied on de-serialization, so we - don't need to store them. - """ - # Use the private _expand() method to avoid the empty kwargs check. - default = BaseOperator.partial(task_id="_")._expand(EXPAND_INPUT_EMPTY, strict=False).partial_kwargs - return BaseSerialization.serialize(default)[Encoding.VAR] - - def encode_relativedelta(var: relativedelta.relativedelta) -> dict[str, Any]: """Encode a relativedelta object.""" encoded = {k: v for k, v in var.__dict__.items() if not k.startswith("_") and v} @@ -691,7 +676,13 @@ def serialize_to_json( elif key == "timetable" and value is not None: serialized_object[key] = encode_timetable(value) elif key == "weight_rule" and value is not None: - serialized_object[key] = encode_priority_weight_strategy(value) + encoded_priority_weight_strategy = encode_priority_weight_strategy(value) + + # Exclude if it is just default + default_pri_weight_stra = cls.get_schema_defaults("operator").get(key, None) + if default_pri_weight_stra != encoded_priority_weight_strategy: + serialized_object[key] = encoded_priority_weight_strategy + else: value = cls.serialize(value) if isinstance(value, dict) and Encoding.TYPE in value: @@ -1075,6 +1066,47 @@ def _deserialize_params_dict(cls, encoded_params: list[tuple[str, dict]]) -> Par return ParamsDict(op_params) + @classmethod + def get_operator_optional_fields_from_schema(cls) -> set[str]: + schema_loader = cls._json_schema + + if schema_loader is None: + return set() + + schema_data = schema_loader.schema + operator_def = schema_data.get("definitions", {}).get("operator", {}) + operator_fields = set(operator_def.get("properties", {}).keys()) + required_fields = set(operator_def.get("required", [])) + + optional_fields = operator_fields - required_fields + return optional_fields + + @classmethod + def get_schema_defaults(cls, object_type: str) -> dict[str, Any]: + """ + Extract default values from JSON schema for any object type. 
+ + :param object_type: The object type to get defaults for (e.g., "operator", "dag") + :return: Dictionary of field name -> default value + """ + # Load schema if needed (handles lazy loading) + schema_loader = cls._json_schema + + if schema_loader is None: + return {} + + # Access the schema definitions (trigger lazy loading) + schema_data = schema_loader.schema + object_def = schema_data.get("definitions", {}).get(object_type, {}) + properties = object_def.get("properties", {}) + + defaults = {} + for field_name, field_def in properties.items(): + if isinstance(field_def, dict) and "default" in field_def: + defaults[field_name] = field_def["default"] + + return defaults + class DependencyDetector: """ @@ -1185,41 +1217,91 @@ class SerializedBaseOperator(DAGNode, BaseSerialization): _decorated_fields = {"executor_config"} - _CONSTRUCTOR_PARAMS = { - k: v.default - for k, v in signature(BaseOperator.__init__).parameters.items() - if v.default is not v.empty - } + _CONSTRUCTOR_PARAMS = {} + + _json_schema: Validator = lazy_object_proxy.Proxy(load_dag_schema) _can_skip_downstream: bool _is_empty: bool _needs_expansion: bool _task_display_name: str | None - depends_on_past: bool + _weight_rule: str | PriorityWeightStrategy = "downstream" + + dag: DAG | None = None + task_group: TaskGroup | None = None + + allow_nested_operators: bool = True + depends_on_past: bool = False + do_xcom_push: bool = True + doc: str | None = None + doc_md: str | None = None + doc_json: str | None = None + doc_yaml: str | None = None + doc_rst: str | None = None + downstream_task_ids: set[str] = set() email: str | Sequence[str] | None + + # Following 2 should be deprecated + email_on_retry: bool = True + email_on_failure: bool = True + execution_timeout: datetime.timedelta | None executor: str | None - executor_config: dict | None - ignore_first_depends_on_past: bool - inlets: Sequence - is_setup: bool - is_teardown: bool - on_execute_callback: Sequence - on_failure_callback: Sequence - on_retry_callback: Sequence - on_success_callback: Sequence - outlets: Sequence - pool: str - pool_slots: int - priority_weight: int - queue: str - retries: int | None - run_as_user: str | None - start_from_trigger: bool - start_trigger_args: StartTriggerArgs - trigger_rule: TriggerRule - wait_for_downstream: bool - weight_rule: PriorityWeightStrategy + executor_config: dict | None = {} + ignore_first_depends_on_past: bool = False + + inlets: Sequence = [] + is_setup: bool = False + is_teardown: bool = False + + map_index_template: str | None = None + max_active_tis_per_dag: int | None = None + max_active_tis_per_dagrun: int | None = None + max_retry_delay: datetime.timedelta | float | None = None + multiple_outputs: bool = False + + # Boolean flags for callback existence + has_on_execute_callback: bool = False + has_on_failure_callback: bool = False + has_on_retry_callback: bool = False + has_on_success_callback: bool = False + has_on_skipped_callback: bool = False + + operator_extra_links: Collection[BaseOperatorLink] = () + on_failure_fail_dagrun: bool = False + + outlets: Sequence = [] + owner: str = "airflow" + pool: str = "default_pool" + pool_slots: int = 1 + priority_weight: int = 1 + queue: str = "default" + + resources: dict[str, Any] | None = None + retries: int = 0 + retry_delay: datetime.timedelta + retry_exponential_backoff: bool = False + run_as_user: str | None = None + + start_date: datetime.datetime | None = None + end_date: datetime.datetime | None = None + + start_from_trigger: bool = False + start_trigger_args: 
StartTriggerArgs | None = None + + task_type: str = "BaseOperator" + template_ext: Sequence[str] = [] + template_fields: Collection[str] = [] + template_fields_renderers: ClassVar[dict[str, str]] = {} + + trigger_rule: str | TriggerRule = "all_success" + + # TODO: Remove the following, they aren't used anymore + ui_color: str = "#fff" + ui_fgcolor: str = "#000" + + wait_for_downstream: bool = False + wait_for_past_depends_before_skipping: bool = False is_mapped = False @@ -1231,20 +1313,11 @@ def __init__( _airflow_from_mapped: bool = False, ) -> None: super().__init__() - self.__dict__.update(self._CONSTRUCTOR_PARAMS) - self.__dict__.update(OPERATOR_DEFAULTS) + self._BaseOperator__from_mapped = _airflow_from_mapped self.task_id = task_id self.params = ParamsDict(params) - # task_type is used by UI to display the correct class type, because UI only - # receives BaseOperator from deserialized DAGs. - self._task_type = "BaseOperator" # Move class attributes into object attributes. - self.ui_color = BaseOperator.ui_color - self.ui_fgcolor = BaseOperator.ui_fgcolor - self.template_ext = BaseOperator.template_ext - self.template_fields = BaseOperator.template_fields - self.operator_extra_links = BaseOperator.operator_extra_links self.deps = DEFAULT_OPERATOR_DEPS self._operator_name: str | None = None @@ -1308,17 +1381,6 @@ def get_extra_links(self, ti: TaskInstance, name: str) -> str | None: return None return link.get_link(self, ti_key=ti.key) # type: ignore[arg-type] # TODO: GH-52141 - BaseOperatorLink.get_link expects BaseOperator but receives SerializedBaseOperator - @property - def task_type(self) -> str: - # Overwrites task_type of BaseOperator to use _task_type instead of - # __class__.__name__. - - return self._task_type - - @task_type.setter - def task_type(self, task_type: str): - self._task_type = task_type - @property def operator_name(self) -> str: # Overwrites operator_name of BaseOperator to use _operator_name instead of @@ -1333,18 +1395,13 @@ def operator_name(self, operator_name: str): def task_display_name(self) -> str: return self._task_display_name or self.task_id - # TODO (GH-52141): For compatibility... can we just rename this? - @property - def on_failure_fail_dagrun(self): - return self._on_failure_fail_dagrun - - @on_failure_fail_dagrun.setter - def on_failure_fail_dagrun(self, value): - self._on_failure_fail_dagrun = value - def expand_start_trigger_args(self, *, context: Context) -> StartTriggerArgs | None: return self.start_trigger_args + @property + def weight_rule(self) -> PriorityWeightStrategy: + return validate_and_load_priority_weight_strategy(self._weight_rule) + def __getattr__(self, name): # Handle missing attributes with task_type instead of SerializedBaseOperator # Don't intercept special methods that Python internals might check @@ -1366,16 +1423,17 @@ def serialize_mapped_operator(cls, op: MappedOperator) -> dict[str, Any]: "value": cls.serialize(expansion_kwargs.value), } - # Simplify partial_kwargs by comparing it to the most barebone object. - # Remove all entries that are simply default values. 
- serialized_partial = serialized_op["partial_kwargs"] - for k, default in _get_default_mapped_partial().items(): - try: - v = serialized_partial[k] - except KeyError: - continue - if v == default: - del serialized_partial[k] + if op.partial_kwargs: + serialized_op["partial_kwargs"] = {} + for k, v in op.partial_kwargs.items(): + if cls._is_excluded(v, k, op): + continue + + if k in [f"on_{x}_callback" for x in ("execute", "failure", "success", "retry", "skipped")]: + if bool(v): + serialized_op["partial_kwargs"][f"has_{k}"] = True + continue + serialized_op["partial_kwargs"].update({k: cls.serialize(v)}) serialized_op["_is_mapped"] = True return serialized_op @@ -1389,6 +1447,12 @@ def _serialize_node(cls, op: SdkOperator) -> dict[str, Any]: """Serialize operator into a JSON object.""" serialize_op = cls.serialize_to_json(op, cls._decorated_fields) + if not op.email: + # If "email" is empty, we do not need to include other email attrs + for attr in ["email_on_failure", "email_on_retry"]: + if attr in serialize_op: + del serialize_op[attr] + # Detect if there's a change in python callable name python_callable = getattr(op, "python_callable", None) if python_callable: @@ -1408,10 +1472,8 @@ def _serialize_node(cls, op: SdkOperator) -> dict[str, Any]: if op.inherits_from_skipmixin: serialize_op["_can_skip_downstream"] = True - serialize_op["start_trigger_args"] = ( - encode_start_trigger_args(op.start_trigger_args) if op.start_trigger_args else None - ) - serialize_op["start_from_trigger"] = op.start_from_trigger + if op.start_trigger_args: + serialize_op["start_trigger_args"] = encode_start_trigger_args(op.start_trigger_args) if op.operator_extra_links: serialize_op["_operator_extra_links"] = cls._serialize_operator_extra_links( @@ -1423,7 +1485,7 @@ def _serialize_node(cls, op: SdkOperator) -> dict[str, Any]: # Store all template_fields as they are if there are JSON Serializable # If not, store them as strings # And raise an exception if the field is not templateable - forbidden_fields = set(SerializedBaseOperator._CONSTRUCTOR_PARAMS.keys()) + forbidden_fields = set(signature(BaseOperator.__init__).parameters.keys()) # Though allow some of the BaseOperator fields to be templated anyway forbidden_fields.difference_update({"email"}) if op.template_fields: @@ -1445,7 +1507,12 @@ def _serialize_node(cls, op: SdkOperator) -> dict[str, Any]: return serialize_op @classmethod - def populate_operator(cls, op: SchedulerOperator, encoded_op: dict[str, Any]) -> None: + def populate_operator( + cls, + op: SchedulerOperator, + encoded_op: dict[str, Any], + client_defaults: dict[str, Any] | None = None, + ) -> None: """ Populate operator attributes with serialized values. @@ -1454,6 +1521,11 @@ def populate_operator(cls, op: SchedulerOperator, encoded_op: dict[str, Any]) -> done in ``set_task_dag_references`` instead, which is called after the DAG is hydrated. 
""" + # Apply defaults by merging them into encoded_op BEFORE main deserialization + encoded_op = cls._apply_defaults_to_encoded_op(encoded_op, client_defaults) + + # Preprocess and upgrade all field names for backward compatibility and consistency + encoded_op = cls._preprocess_encoded_operator(encoded_op) # Extra Operator Links defined in Plugins op_extra_links_from_plugin = {} @@ -1485,35 +1557,12 @@ def populate_operator(cls, op: SchedulerOperator, encoded_op: dict[str, Any]) -> list(op_extra_links_from_plugin.values()), ) - for k, v in encoded_op.items(): - # python_callable_name only serves to detect function name changes - if k == "python_callable_name": - continue - if k in ("_outlets", "_inlets"): - # `_outlets` -> `outlets` - k = k[1:] - elif k == "task_type": - k = "_task_type" - if k == "_downstream_task_ids": - # Upgrade from old format/name - k = "downstream_task_ids" + deserialized_partial_kwarg_defaults = {} - if k == "label": - # Label shouldn't be set anymore -- it's computed from task_id now - continue - if k == "downstream_task_ids": - v = set(v) - elif k in {"retry_delay", "execution_timeout", "max_retry_delay"}: - # If operator's execution_timeout is None and core.default_task_execution_timeout is not None, - # v will be None so do not deserialize into timedelta - if v is not None: - v = cls._deserialize_timedelta(v) - elif k in encoded_op["template_fields"]: - pass - elif k == "resources": - v = Resources.from_dict(v) - elif k.endswith("_date"): - v = cls._deserialize_datetime(v) + for k, v in encoded_op.items(): + # Use centralized field deserialization logic + if k in encoded_op.get("template_fields", []): + pass # Template fields are handled separately elif k == "_operator_extra_links": if cls._load_operator_extra_links: op_predefined_extra_links = cls._deserialize_operator_extra_links(v) @@ -1532,7 +1581,8 @@ def populate_operator(cls, op: SchedulerOperator, encoded_op: dict[str, Any]) -> v, new = op.params, v v.update(new) elif k == "partial_kwargs": - v = {arg: cls.deserialize(value) for arg, value in v.items()} + # Use unified deserializer that supports both encoded and non-encoded values + v = cls._deserialize_partial_kwargs(v, client_defaults) elif k in {"expand_input", "op_kwargs_expand_input"}: v = _ExpandInputRef(v["type"], cls.deserialize(v["value"])) elif k == "operator_class": @@ -1550,16 +1600,37 @@ def populate_operator(cls, op: SchedulerOperator, encoded_op: dict[str, Any]) -> or k in ("outlets", "inlets") ): v = cls.deserialize(v) - elif k == "on_failure_fail_dagrun": - k = "_on_failure_fail_dagrun" + elif k == "_on_failure_fail_dagrun": + k = "on_failure_fail_dagrun" elif k == "weight_rule": + k = "_weight_rule" v = decode_priority_weight_strategy(v) + else: + # Apply centralized deserialization for all other fields + v = cls._deserialize_field_value(k, v) + + # Handle field differences between SerializedBaseOperator and MappedOperator + # Fields that exist in SerializedBaseOperator but not in MappedOperator need to go to partial_kwargs + if ( + op.is_mapped + and k in SerializedBaseOperator.get_serialized_fields() + and k not in op.get_serialized_fields() + ): + # This field belongs to SerializedBaseOperator but not MappedOperator + # Store it in partial_kwargs where it belongs + deserialized_partial_kwarg_defaults[k] = v + continue # else use v as it is - setattr(op, k, v) - for k in op.get_serialized_fields() - encoded_op.keys() - cls._CONSTRUCTOR_PARAMS.keys(): + # Apply the fields that belong in partial_kwargs for MappedOperator + if 
op.is_mapped: + for k, v in deserialized_partial_kwarg_defaults.items(): + if k not in op.partial_kwargs: + op.partial_kwargs[k] = v + + for k in op.get_serialized_fields() - encoded_op.keys(): # TODO: refactor deserialization of BaseOperator and MappedOperator (split it out), then check # could go away. if not hasattr(op, k): @@ -1609,13 +1680,14 @@ def set_task_dag_references(task: SchedulerOperator, dag: DAG) -> None: dag.task_dict[task_id].upstream_task_ids.add(task.task_id) @classmethod - def deserialize_operator(cls, encoded_op: dict[str, Any]) -> SchedulerOperator: + def deserialize_operator( + cls, + encoded_op: dict[str, Any], + client_defaults: dict[str, Any] | None = None, + ) -> SchedulerOperator: """Deserializes an operator from a JSON object.""" op: SchedulerOperator if encoded_op.get("_is_mapped", False): - # Most of these will be loaded later, these are just some stand-ins. - op_data = {k: v for k, v in encoded_op.items() if k in BaseOperator.get_serialized_fields()} - from airflow.models.mappedoperator import MappedOperator as SchedulerMappedOperator try: @@ -1623,15 +1695,22 @@ def deserialize_operator(cls, encoded_op: dict[str, Any]) -> SchedulerOperator: except KeyError: operator_name = encoded_op["task_type"] + # Only store minimal class type information instead of full operator data + # This significantly reduces memory usage for mapped operators + operator_class_info = { + "task_type": encoded_op["task_type"], + "_operator_name": operator_name, + } + op = SchedulerMappedOperator( - operator_class=op_data, + operator_class=operator_class_info, task_id=encoded_op["task_id"], - operator_extra_links=BaseOperator.operator_extra_links, - template_ext=BaseOperator.template_ext, - template_fields=BaseOperator.template_fields, - template_fields_renderers=BaseOperator.template_fields_renderers, - ui_color=BaseOperator.ui_color, - ui_fgcolor=BaseOperator.ui_fgcolor, + operator_extra_links=SerializedBaseOperator.operator_extra_links, + template_ext=SerializedBaseOperator.template_ext, + template_fields=SerializedBaseOperator.template_fields, + template_fields_renderers=SerializedBaseOperator.template_fields_renderers, + ui_color=SerializedBaseOperator.ui_color, + ui_fgcolor=SerializedBaseOperator.ui_fgcolor, is_sensor=encoded_op.get("_is_sensor", False), can_skip_downstream=encoded_op.get("_can_skip_downstream", False), task_module=encoded_op["_task_module"], @@ -1644,10 +1723,55 @@ def deserialize_operator(cls, encoded_op: dict[str, Any]) -> SchedulerOperator: ) else: op = SerializedBaseOperator(task_id=encoded_op["task_id"]) - cls.populate_operator(op, encoded_op) + + cls.populate_operator(op, encoded_op, client_defaults) return op + @classmethod + def _preprocess_encoded_operator(cls, encoded_op: dict[str, Any]) -> dict[str, Any]: + """ + Preprocess and upgrade all field names for backward compatibility and consistency. 
+ + This consolidates all field name transformations in one place: + - Callback field renaming (on_*_callback -> has_on_*_callback) + - Other field upgrades and renames + - Field exclusions + """ + preprocessed = encoded_op.copy() + + # Handle callback field renaming for backward compatibility + for callback_type in ("execute", "failure", "success", "retry", "skipped"): + old_key = f"on_{callback_type}_callback" + new_key = f"has_{old_key}" + if old_key in preprocessed: + preprocessed[new_key] = bool(preprocessed[old_key]) + del preprocessed[old_key] + + # Handle other field renames and upgrades from old format/name + field_renames = { + "task_display_name": "_task_display_name", + "_downstream_task_ids": "downstream_task_ids", + "_task_type": "task_type", + "_outlets": "outlets", + "_inlets": "inlets", + } + + for old_name, new_name in field_renames.items(): + if old_name in preprocessed: + preprocessed[new_name] = preprocessed.pop(old_name) + + # Remove fields that shouldn't be processed + fields_to_exclude = { + "python_callable_name", # Only serves to detect function name changes + "label", # Shouldn't be set anymore - computed from task_id now + } + + for field in fields_to_exclude: + preprocessed.pop(field, None) + + return preprocessed + @classmethod def detect_dependencies(cls, op: SdkOperator) -> set[DagDependency]: """Detect between DAG dependencies for the operator.""" @@ -1655,15 +1779,67 @@ def detect_dependencies(cls, op: SdkOperator) -> set[DagDependency]: deps = set(dependency_detector.detect_task_dependencies(op)) return deps + @classmethod + def _matches_client_defaults(cls, var: Any, attrname: str, op: DAGNode) -> bool: + """ + Check if a field value matches client_defaults and should be excluded. + + This implements the hierarchical defaults optimization where values that match + client_defaults are omitted from individual task serialization. + + :param var: The value to check + :param attrname: The attribute name + :param op: The operator instance + :return: True if value matches client_defaults and should be excluded + """ + try: + # Get cached client defaults for tasks + task_defaults = cls.generate_client_defaults() + + # Check if this field is in client_defaults and values match + if attrname in task_defaults and var == task_defaults[attrname]: + return True + + except Exception: + # If anything goes wrong with client_defaults, fall back to normal logic + pass + + return False + @classmethod def _is_excluded(cls, var: Any, attrname: str, op: DAGNode): + """ + Determine if a variable is excluded from the serialized object. + + :param var: The value to check. [var == getattr(op, attrname)] + :param attrname: The name of the attribute to check. + :param op: The operator to check. + :return: True if a variable is excluded, False otherwise. + """ + # Check if value matches client_defaults (hierarchical defaults optimization) + if cls._matches_client_defaults(var, attrname, op): + return True + schema_defaults = cls.get_schema_defaults("operator") + + if attrname in schema_defaults: + if schema_defaults[attrname] == var: + return True + optional_fields = cls.get_operator_optional_fields_from_schema() + if var is None: + return True + if attrname in optional_fields: + if var in [[], (), set(), {}]: + return True + if var is not None and op.has_dag() and attrname.endswith("_date"): # If this date is the same as the matching field in the dag, then # don't store it again at the task level. 
dag_date = getattr(op.dag, attrname, None) if var is dag_date or var == dag_date: return True - return super()._is_excluded(var, attrname, op) + + # If none of the exclusion conditions are met, don't exclude the field + return False @classmethod def _deserialize_operator_extra_links( @@ -1752,8 +1928,216 @@ def expand_start_from_trigger(self, *, context: Context) -> bool: """ return self.start_from_trigger - def get_serialized_fields(self): - return BaseOperator.get_serialized_fields() + @classmethod + def get_serialized_fields(cls): + """Fields to deserialize from the serialized JSON object.""" + return frozenset( + { + "_logger_name", + "_needs_expansion", + "_task_display_name", + "allow_nested_operators", + "depends_on_past", + "do_xcom_push", + "doc", + "doc_json", + "doc_md", + "doc_rst", + "doc_yaml", + "downstream_task_ids", + "email", + "email_on_failure", + "email_on_retry", + "end_date", + "execution_timeout", + "executor", + "executor_config", + "ignore_first_depends_on_past", + "inlets", + "is_setup", + "is_teardown", + "map_index_template", + "max_active_tis_per_dag", + "max_active_tis_per_dagrun", + "max_retry_delay", + "multiple_outputs", + "has_on_execute_callback", + "has_on_failure_callback", + "has_on_retry_callback", + "has_on_skipped_callback", + "has_on_success_callback", + "on_failure_fail_dagrun", + "outlets", + "owner", + "params", + "pool", + "pool_slots", + "priority_weight", + "queue", + "resources", + "retries", + "retry_delay", + "retry_exponential_backoff", + "run_as_user", + "start_date", + "start_from_trigger", + "start_trigger_args", + "task_id", + "task_type", + "template_ext", + "template_fields", + "template_fields_renderers", + "trigger_rule", + "ui_color", + "ui_fgcolor", + "wait_for_downstream", + "wait_for_past_depends_before_skipping", + "weight_rule", + } + ) + + @classmethod + @lru_cache(maxsize=1) + def generate_client_defaults(cls) -> dict[str, Any]: + """ + Generate `client_defaults` section that only includes values differing from schema defaults. + + This optimizes serialization size by avoiding redundant storage of schema defaults. + Uses OPERATOR_DEFAULTS as the source of truth for task default values. + + :return: client_defaults dictionary with only non-schema values + """ + # Get schema defaults for comparison + schema_defaults = cls.get_schema_defaults("operator") + + client_defaults = {} + + # Only include OPERATOR_DEFAULTS values that differ from schema defaults + for k, v in OPERATOR_DEFAULTS.items(): + if k not in cls.get_serialized_fields(): + continue + # Exclude values that are the same as the schema defaults + if k in schema_defaults and schema_defaults[k] == v: + continue + + # Exclude values that are None or empty collections + if v is None or v in [[], (), set(), {}]: + continue + + # Use the existing serialize method to ensure consistent format + serialized_value = cls.serialize(v) + # Extract just the value part, consistent with serialize_to_json behavior + if isinstance(serialized_value, dict) and Encoding.TYPE in serialized_value: + serialized_value = serialized_value[Encoding.VAR] + client_defaults[k] = serialized_value + + return client_defaults + + @classmethod + def _deserialize_field_value(cls, field_name: str, value: Any) -> Any: + """ + Deserialize a single field value using the same logic as populate_operator. + + This method centralizes field-specific deserialization logic to avoid duplication. 
+ + :param field_name: The name of the field being deserialized + :param value: The value to deserialize + :return: The deserialized value + """ + if field_name == "downstream_task_ids": + return set(value) if value is not None else set() + elif field_name in [ + f"has_on_{x}_callback" for x in ("execute", "failure", "success", "retry", "skipped") + ]: + return bool(value) + elif field_name in {"retry_delay", "execution_timeout", "max_retry_delay"}: + # Reuse existing timedelta deserialization logic + if value is not None: + return cls._deserialize_timedelta(value) + return None + elif field_name == "resources": + return Resources.from_dict(value) if value is not None else None + elif field_name.endswith("_date"): + return cls._deserialize_datetime(value) if value is not None else None + else: + # For all other fields, return as-is (strings, ints, bools, etc.) + return value + + @classmethod + def _deserialize_partial_kwargs( + cls, partial_kwargs_data: dict[str, Any], client_defaults: dict[str, Any] | None = None + ) -> dict[str, Any]: + """ + Deserialize partial_kwargs supporting both encoded and non-encoded values. + + This method can handle: + 1. Encoded values: {"__type": "timedelta", "__var": 300.0} + 2. Non-encoded values: 300.0 (for optimization) + + It also applies client_defaults for missing fields. + + :param partial_kwargs_data: The partial_kwargs data from serialized JSON + :param client_defaults: Client defaults to apply for missing fields + :return: Deserialized partial_kwargs dict + """ + deserialized = {} + + for k, v in partial_kwargs_data.items(): + # Check if this is an encoded value (has __type and __var structure) + if isinstance(v, dict) and Encoding.TYPE in v and Encoding.VAR in v: + # This is encoded - use full deserialization + deserialized[k] = cls.deserialize(v) + else: + # This is non-encoded (optimized format) + # Reuse the same deserialization logic from populate_operator + deserialized[k] = cls._deserialize_field_value(k, v) + + # Apply client_defaults for missing fields if provided + if client_defaults and "tasks" in client_defaults: + task_defaults = client_defaults["tasks"] + for k, default_value in task_defaults.items(): + if k not in deserialized: + # Apply the same deserialization logic to client_defaults + deserialized[k] = cls._deserialize_field_value(k, default_value) + + return deserialized + + @classmethod + def _apply_defaults_to_encoded_op( + cls, + encoded_op: dict[str, Any], + client_defaults: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """ + Apply client defaults to encoded operator before deserialization. + + Args: + encoded_op: The serialized operator data (already includes applied default_args) + client_defaults: SDK-specific defaults from client_defaults section + + Note: DAG default_args are already applied during task creation in the SDK, + so encoded_op contains the final resolved values. + + Hierarchy (lowest to highest priority): + 1. client_defaults.tasks (SDK-wide defaults for size optimization) + 2. Explicit task values (already in encoded_op, includes applied default_args) + + Returns a new dict with defaults merged in. 
+ """ + # Build hierarchy from lowest to highest priority + result = {} + + # Level 1: Apply client_defaults.tasks (lowest priority) + # Values are already serialized in generate_client_defaults() + if client_defaults: + task_defaults = client_defaults.get("tasks", {}) + result.update(task_defaults) + + # Level 2: Apply explicit task values (highest priority - overrides everything) + # Note: encoded_op already contains default_args applied during task creation + result.update(encoded_op) + + return result def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]: """ @@ -1887,7 +2271,7 @@ def __get_constructor_defaults(): _CONSTRUCTOR_PARAMS = __get_constructor_defaults.__func__() # type: ignore del __get_constructor_defaults - _json_schema = lazy_object_proxy.Proxy(load_dag_schema) + _json_schema: Validator = lazy_object_proxy.Proxy(load_dag_schema) @classmethod def serialize_dag(cls, dag: SdkDag) -> dict: @@ -1924,7 +2308,9 @@ def serialize_dag(cls, dag: SdkDag) -> dict: raise SerializationError(f"Failed to serialize DAG {dag.dag_id!r}: {e}") @classmethod - def deserialize_dag(cls, encoded_dag: dict[str, Any]) -> SerializedDAG: + def deserialize_dag( + cls, encoded_dag: dict[str, Any], client_defaults: dict[str, Any] | None = None + ) -> SerializedDAG: """Deserializes a DAG from a JSON object.""" if "dag_id" not in encoded_dag: raise RuntimeError( @@ -1933,6 +2319,8 @@ def deserialize_dag(cls, encoded_dag: dict[str, Any]) -> SerializedDAG: dag = SerializedDAG(dag_id=encoded_dag["dag_id"], schedule=None) + # Note: Context is passed explicitly through method parameters, no class attributes needed + for k, v in encoded_dag.items(): if k == "_downstream_task_ids": v = set(v) @@ -1941,7 +2329,9 @@ def deserialize_dag(cls, encoded_dag: dict[str, Any]) -> SerializedDAG: tasks = {} for obj in v: if obj.get(Encoding.TYPE) == DAT.OP: - deser = SerializedBaseOperator.deserialize_operator(obj[Encoding.VAR]) + deser = SerializedBaseOperator.deserialize_operator( + obj[Encoding.VAR], client_defaults + ) tasks[deser.task_id] = deser k = "task_dict" v = tasks @@ -2019,8 +2409,18 @@ def _is_excluded(cls, var: Any, attrname: str, op: DAGNode): @classmethod def to_dict(cls, var: Any) -> dict: """Stringifies DAGs and operators contained by var and returns a dict of var.""" + # Clear any cached client_defaults to ensure fresh generation for this DAG + # Clear lru_cache for client defaults + SerializedBaseOperator.generate_client_defaults.cache_clear() + json_dict = {"__version": cls.SERIALIZER_VERSION, "dag": cls.serialize_dag(var)} + # Add client_defaults section with only values that differ from schema defaults + # for tasks + client_defaults = SerializedBaseOperator.generate_client_defaults() + if client_defaults: + json_dict["client_defaults"] = {"tasks": client_defaults} + # Validate Serialized DAG with Json Schema. 
Raises Error if it mismatches cls.validate_schema(json_dict) return json_dict @@ -2033,7 +2433,7 @@ def conversion_v1_to_v2(ser_obj: dict): ("_task_group", "task_group"), ("_access_control", "access_control"), ] - task_renames = [("_task_type", "task_type")] + task_renames = [("_task_type", "task_type"), ("task_display_name", "_task_display_name")] # tasks_remove = [ "_log_config_logger_name", @@ -2136,7 +2536,8 @@ def _create_compat_timetable(value): for k in tasks_remove: task_var.pop(k, None) for old, new in task_renames: - task_var[new] = task_var.pop(old) + if old in task_var: + task_var[new] = task_var.pop(old) for item in itertools.chain(*(task_var.get(key, []) for key in ("inlets", "outlets"))): original_item_type = item["__type"] if isinstance(item, dict) and "__type" in item: @@ -2147,6 +2548,11 @@ def _create_compat_timetable(value): var_["name"] = var_["uri"] var_["group"] = "asset" + for k, v in list(task_var.items()): + op_defaults = SerializedDAG.get_schema_defaults("operator") + if k in op_defaults and v == op_defaults[k]: + del task_var[k] + # Set on the root TG dag_dict["task_group"]["group_display_name"] = "" @@ -2158,7 +2564,12 @@ def from_dict(cls, serialized_obj: dict) -> SerializedDAG: raise ValueError(f"Unsure how to deserialize version {ver!r}") if ver == 1: cls.conversion_v1_to_v2(serialized_obj) - return cls.deserialize_dag(serialized_obj["dag"]) + + # Extract client_defaults for hierarchical defaults resolution + client_defaults = serialized_obj.get("client_defaults", {}) + + # Pass client_defaults directly to deserialize_dag + return cls.deserialize_dag(serialized_obj["dag"], client_defaults) class TaskGroupSerialization(BaseSerialization): @@ -2314,7 +2725,11 @@ def access_control(self) -> Mapping[str, Mapping[str, Collection[str]] | Collect @cached_property def _real_dag(self): - return SerializedDAG.from_dict(self.data) + try: + return SerializedDAG.from_dict(self.data) + except Exception: + log.exception("Failed to deserialize DAG") + raise def __getattr__(self, name: str, /) -> Any: if name in self.NULLABLE_PROPERTIES: diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py index 7384d09268606..e4f0535493359 100644 --- a/airflow-core/tests/unit/serialization/test_dag_serialization.py +++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py @@ -64,7 +64,7 @@ from airflow.providers.standard.operators.bash import BashOperator from airflow.sdk import AssetAlias, BaseHook, teardown from airflow.sdk.bases.decorator import DecoratedOperator -from airflow.sdk.bases.operator import BaseOperator +from airflow.sdk.bases.operator import OPERATOR_DEFAULTS, BaseOperator from airflow.sdk.definitions._internal.expandinput import EXPAND_INPUT_EMPTY from airflow.sdk.definitions.asset import Asset, AssetUniqueKey from airflow.sdk.definitions.operator_resources import Resources @@ -121,6 +121,7 @@ VAR = Encoding.VAR serialized_simple_dag_ground_truth = { "__version": 2, + "client_defaults": {"tasks": {"retry_delay": 300.0}}, "dag": { "default_args": { "__type": "dict", @@ -172,9 +173,7 @@ "retries": 1, "retry_delay": 300.0, "max_retry_delay": 600.0, - "downstream_task_ids": [], "ui_color": "#f0ede4", - "ui_fgcolor": "#000", "template_ext": [".sh", ".bash"], "template_fields": ["bash_command", "env", "cwd"], "template_fields_renderers": { @@ -184,11 +183,9 @@ "bash_command": "echo {{ task.task_id }}", "task_type": "BashOperator", "_task_module": 
"airflow.providers.standard.operators.bash", - "owner": "airflow", - "pool": "default_pool", - "is_setup": False, - "is_teardown": False, - "on_failure_fail_dagrun": False, + "_task_display_name": "my_bash_task", + "owner": "airflow1", + "pool": "pool1", "executor_config": { "__type": "dict", "__var": { @@ -200,9 +197,6 @@ }, "doc_md": "### Task Tutorial Documentation", "_needs_expansion": False, - "weight_rule": "downstream", - "start_trigger_args": None, - "start_from_trigger": False, "inlets": [ { "__type": "asset", @@ -238,24 +232,12 @@ "retries": 1, "retry_delay": 300.0, "max_retry_delay": 600.0, - "downstream_task_ids": [], "_operator_extra_links": {"Google Custom": "_link_CustomOpLink"}, - "ui_color": "#fff", - "ui_fgcolor": "#000", - "template_ext": [], "template_fields": ["bash_command"], - "template_fields_renderers": {}, "task_type": "CustomOperator", "_operator_name": "@custom", "_task_module": "tests_common.test_utils.mock_operators", - "pool": "default_pool", - "is_setup": False, - "is_teardown": False, - "on_failure_fail_dagrun": False, "_needs_expansion": False, - "weight_rule": "downstream", - "start_trigger_args": None, - "start_from_trigger": False, }, }, ], @@ -332,11 +314,13 @@ def make_simple_dag(): BashOperator( task_id="bash_task", bash_command="echo {{ task.task_id }}", - owner="airflow", + owner="airflow1", executor_config={"pod_override": executor_config_pod}, doc_md="### Task Tutorial Documentation", inlets=[Asset("asset-1"), AssetAlias(name="alias-name")], outlets=Asset("asset-2"), + pool="pool1", + task_display_name="my_bash_task", ) return dag @@ -750,10 +734,6 @@ def validate_deserialized_task( # We store the string, real dag has the actual code "_pre_execute_hook", "_post_execute_hook", - "on_execute_callback", - "on_failure_callback", - "on_success_callback", - "on_retry_callback", # Checked separately "resources", "on_failure_fail_dagrun", @@ -808,9 +788,8 @@ def validate_deserialized_task( assert serialized_task.params.dump() == task.params.dump() if isinstance(task, MappedOperator): - # MappedOperator.operator_class holds a backup of the serialized - # data; checking its entirety basically duplicates this validation - # function, so we just do some sanity checks. + # MappedOperator.operator_class now stores only minimal type information + # for memory efficiency (task_type and _operator_name). serialized_task.operator_class["task_type"] == type(task).__name__ if isinstance(serialized_task.operator_class, DecoratedOperator): serialized_task.operator_class["_operator_name"] == task._operator_name @@ -820,11 +799,23 @@ def validate_deserialized_task( default_partial_kwargs = ( BaseOperator.partial(task_id="_")._expand(EXPAND_INPUT_EMPTY, strict=False).partial_kwargs ) + + # These are added in `_TaskDecorator` e.g. 
when @setup or @teardown task is passed + default_decorator_partial_kwargs = { + "is_setup": False, + "is_teardown": False, + "on_failure_fail_dagrun": False, + } serialized_partial_kwargs = { **default_partial_kwargs, + **default_decorator_partial_kwargs, **serialized_task.partial_kwargs, } - original_partial_kwargs = {**default_partial_kwargs, **task.partial_kwargs} + original_partial_kwargs = { + **default_partial_kwargs, + **default_decorator_partial_kwargs, + **task.partial_kwargs, + } assert serialized_partial_kwargs == original_partial_kwargs # ExpandInputs have different classes between scheduler and definition @@ -1424,6 +1415,11 @@ def test_no_new_fields_added_to_base_operator(self): "execution_timeout": None, "executor": None, "executor_config": {}, + "has_on_execute_callback": False, + "has_on_failure_callback": False, + "has_on_retry_callback": False, + "has_on_skipped_callback": False, + "has_on_success_callback": False, "ignore_first_depends_on_past": False, "is_setup": False, "is_teardown": False, @@ -1432,12 +1428,7 @@ def test_no_new_fields_added_to_base_operator(self): "max_active_tis_per_dag": None, "max_active_tis_per_dagrun": None, "max_retry_delay": None, - "on_execute_callback": [], "on_failure_fail_dagrun": False, - "on_failure_callback": [], - "on_retry_callback": [], - "on_skipped_callback": [], - "on_success_callback": [], "outlets": [], "owner": "airflow", "params": {}, @@ -2499,9 +2490,6 @@ def test_operator_expand_serde(): "_is_mapped": True, "_task_module": "airflow.providers.standard.operators.bash", "task_type": "BashOperator", - "start_trigger_args": None, - "start_from_trigger": False, - "downstream_task_ids": [], "expand_input": { "type": "dict-of-lists", "value": { @@ -2514,14 +2502,13 @@ def test_operator_expand_serde(): "__type": "dict", "__var": {"dict": {"__type": "dict", "__var": {"sub": "value"}}}, }, + "retry_delay": {"__type": "timedelta", "__var": 300.0}, }, "task_id": "a", - "operator_extra_links": [], "template_fields": ["bash_command", "env", "cwd"], "template_ext": [".sh", ".bash"], "template_fields_renderers": {"bash_command": "bash", "env": "json"}, "ui_color": "#f0ede4", - "ui_fgcolor": "#000", "_disallow_kwargs_override": False, "_expand_input_attr": "expand_input", } @@ -2529,17 +2516,10 @@ def test_operator_expand_serde(): op = BaseSerialization.deserialize(serialized) assert isinstance(op, MappedOperator) + # operator_class now stores only minimal type information for memory efficiency assert op.operator_class == { "task_type": "BashOperator", - "start_trigger_args": None, - "start_from_trigger": False, - "downstream_task_ids": [], - "task_id": "a", - "template_ext": [".sh", ".bash"], - "template_fields": ["bash_command", "env", "cwd"], - "template_fields_renderers": {"bash_command": "bash", "env": "json"}, - "ui_color": "#f0ede4", - "ui_fgcolor": "#000", + "_operator_name": "BashOperator", } assert op.expand_input.value["bash_command"] == literal assert op.partial_kwargs["executor_config"] == {"dict": {"sub": "value"}} @@ -2559,7 +2539,6 @@ def test_operator_expand_xcomarg_serde(): "_is_mapped": True, "_task_module": "tests_common.test_utils.mock_operators", "task_type": "MockOperator", - "downstream_task_ids": [], "expand_input": { "type": "dict-of-lists", "value": { @@ -2572,18 +2551,13 @@ def test_operator_expand_xcomarg_serde(): }, }, }, - "partial_kwargs": {}, + "partial_kwargs": { + "retry_delay": {"__type": "timedelta", "__var": 300.0}, + }, "task_id": "task_2", "template_fields": ["arg1", "arg2"], - "template_ext": [], - 
"template_fields_renderers": {}, - "operator_extra_links": [], - "ui_color": "#fff", - "ui_fgcolor": "#000", "_disallow_kwargs_override": False, "_expand_input_attr": "expand_input", - "start_trigger_args": None, - "start_from_trigger": False, } op = BaseSerialization.deserialize(serialized) @@ -2616,7 +2590,6 @@ def test_operator_expand_kwargs_literal_serde(strict): "_is_mapped": True, "_task_module": "tests_common.test_utils.mock_operators", "task_type": "MockOperator", - "downstream_task_ids": [], "expand_input": { "type": "list-of-dicts", "value": [ @@ -2632,18 +2605,13 @@ def test_operator_expand_kwargs_literal_serde(strict): }, ], }, - "partial_kwargs": {}, + "partial_kwargs": { + "retry_delay": {"__type": "timedelta", "__var": 300.0}, + }, "task_id": "task_2", "template_fields": ["arg1", "arg2"], - "template_ext": [], - "template_fields_renderers": {}, - "operator_extra_links": [], - "ui_color": "#fff", - "ui_fgcolor": "#000", "_disallow_kwargs_override": strict, "_expand_input_attr": "expand_input", - "start_trigger_args": None, - "start_from_trigger": False, } op = BaseSerialization.deserialize(serialized) @@ -2681,7 +2649,6 @@ def test_operator_expand_kwargs_xcomarg_serde(strict): "_is_mapped": True, "_task_module": "tests_common.test_utils.mock_operators", "task_type": "MockOperator", - "downstream_task_ids": [], "expand_input": { "type": "list-of-dicts", "value": { @@ -2689,18 +2656,13 @@ def test_operator_expand_kwargs_xcomarg_serde(strict): "__var": {"task_id": "op1", "key": "return_value"}, }, }, - "partial_kwargs": {}, + "partial_kwargs": { + "retry_delay": {"__type": "timedelta", "__var": 300.0}, + }, "task_id": "task_2", "template_fields": ["arg1", "arg2"], - "template_ext": [], - "template_fields_renderers": {}, - "operator_extra_links": [], - "ui_color": "#fff", - "ui_fgcolor": "#000", "_disallow_kwargs_override": strict, "_expand_input_attr": "expand_input", - "start_trigger_args": None, - "start_from_trigger": False, } op = BaseSerialization.deserialize(serialized) @@ -2737,22 +2699,8 @@ def test_task_resources_serde(): } -@pytest.fixture(params=[None, timedelta(hours=1)]) -def default_task_execution_timeout(request): - """ - Mock setting core.default_task_execution_timeout in airflow.cfg. - """ - from airflow.serialization.serialized_objects import SerializedBaseOperator - - DEFAULT_TASK_EXECUTION_TIMEOUT = request.param - with mock.patch.dict( - SerializedBaseOperator._CONSTRUCTOR_PARAMS, {"execution_timeout": DEFAULT_TASK_EXECUTION_TIMEOUT} - ): - yield DEFAULT_TASK_EXECUTION_TIMEOUT - - @pytest.mark.parametrize("execution_timeout", [None, timedelta(hours=1)]) -def test_task_execution_timeout_serde(execution_timeout, default_task_execution_timeout): +def test_task_execution_timeout_serde(execution_timeout): """ Test task execution_timeout serialization/deserialization. 
""" @@ -2762,7 +2710,7 @@ def test_task_execution_timeout_serde(execution_timeout, default_task_execution_ task = EmptyOperator(task_id="task1", execution_timeout=execution_timeout) serialized = BaseSerialization.serialize(task) - if execution_timeout != default_task_execution_timeout: + if execution_timeout: assert "execution_timeout" in serialized["__var"] deserialized = BaseSerialization.deserialize(serialized) @@ -2792,11 +2740,7 @@ def x(arg1, arg2, arg3): "_task_module": "airflow.providers.standard.decorators.python", "task_type": "_PythonDecoratedOperator", "_operator_name": "@task", - "downstream_task_ids": [], "partial_kwargs": { - "is_setup": False, - "is_teardown": False, - "on_failure_fail_dagrun": False, "op_args": [], "op_kwargs": { "__type": "dict", @@ -2817,11 +2761,8 @@ def x(arg1, arg2, arg3): }, }, }, - "operator_extra_links": [], "ui_color": "#ffefeb", - "ui_fgcolor": "#000", "task_id": "x", - "template_ext": [], "template_fields": ["templates_dict", "op_args", "op_kwargs"], "template_fields_renderers": { "templates_dict": "json", @@ -2831,8 +2772,6 @@ def x(arg1, arg2, arg3): "_disallow_kwargs_override": False, "_expand_input_attr": "op_kwargs_expand_input", "python_callable_name": qualname(x), - "start_trigger_args": None, - "start_from_trigger": False, } deserialized = BaseSerialization.deserialize(serialized) @@ -2848,9 +2787,6 @@ def x(arg1, arg2, arg3): }, ) assert deserialized.partial_kwargs == { - "is_setup": False, - "is_teardown": False, - "on_failure_fail_dagrun": False, "op_args": [], "op_kwargs": {"arg1": [1, 2, {"a": "b"}]}, "retry_delay": timedelta(seconds=30), @@ -2873,9 +2809,6 @@ def x(arg1, arg2, arg3): }, ) assert pickled.partial_kwargs == { - "is_setup": False, - "is_teardown": False, - "on_failure_fail_dagrun": False, "op_args": [], "op_kwargs": {"arg1": [1, 2, {"a": "b"}]}, "retry_delay": timedelta(seconds=30), @@ -2906,13 +2839,7 @@ def x(arg1, arg2, arg3): "task_type": "_PythonDecoratedOperator", "_operator_name": "@task", "python_callable_name": qualname(x), - "start_trigger_args": None, - "start_from_trigger": False, - "downstream_task_ids": [], "partial_kwargs": { - "is_setup": False, - "is_teardown": False, - "on_failure_fail_dagrun": False, "op_args": [], "op_kwargs": { "__type": "dict", @@ -2927,11 +2854,8 @@ def x(arg1, arg2, arg3): "__var": {"task_id": "op1", "key": "return_value"}, }, }, - "operator_extra_links": [], "ui_color": "#ffefeb", - "ui_fgcolor": "#000", "task_id": "x", - "template_ext": [], "template_fields": ["templates_dict", "op_args", "op_kwargs"], "template_fields_renderers": { "templates_dict": "json", @@ -2953,9 +2877,6 @@ def x(arg1, arg2, arg3): value=_XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY}), ) assert deserialized.partial_kwargs == { - "is_setup": False, - "is_teardown": False, - "on_failure_fail_dagrun": False, "op_args": [], "op_kwargs": {"arg1": [1, 2, {"a": "b"}]}, "retry_delay": timedelta(seconds=30), @@ -2975,9 +2896,6 @@ def x(arg1, arg2, arg3): _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY}), ) assert pickled.partial_kwargs == { - "is_setup": False, - "is_teardown": False, - "on_failure_fail_dagrun": False, "op_args": [], "op_kwargs": {"arg1": [1, 2, {"a": "b"}]}, "retry_delay": timedelta(seconds=30), @@ -3052,21 +2970,16 @@ def operator_extra_links(self): "type": "dict-of-lists", "value": {"__type": "dict", "__var": {"inputs": [1, 2, 3]}}, }, - "partial_kwargs": {}, + "partial_kwargs": { + "retry_delay": {"__type": "timedelta", "__var": 300.0}, + }, "_disallow_kwargs_override": False, 
"_expand_input_attr": "expand_input", - "downstream_task_ids": [], "_operator_extra_links": {"airflow": "_link_AirflowLink2"}, - "ui_color": "#fff", - "ui_fgcolor": "#000", - "template_ext": [], "template_fields": [], - "template_fields_renderers": {}, "task_type": "_DummyOperator", "_task_module": "unit.serialization.test_dag_serialization", "_is_mapped": True, - "start_trigger_args": None, - "start_from_trigger": False, } deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag[Encoding.VAR]) # operator defined links have to be instances of XComOperatorLink @@ -3139,8 +3052,9 @@ def test_handle_v1_serdag(): "_task_type": "BashOperator", # Slightly difference from v2-10-stable here, we manually changed this path "_task_module": "airflow.providers.standard.operators.bash", - "owner": "airflow", - "pool": "default_pool", + "owner": "airflow1", + "pool": "pool1", + "task_display_name": "my_bash_task", "is_setup": False, "is_teardown": False, "on_failure_fail_dagrun": False, @@ -3366,4 +3280,583 @@ def test_handle_v1_serdag(): expected["dag"]["dag_dependencies"] = expected_dag_dependencies del expected["dag"]["tasks"][1]["__var"]["_operator_extra_links"] + del expected["client_defaults"] assert v1 == expected + + +def test_email_optimization_removes_email_attrs_when_email_empty(): + """Test that email_on_failure and email_on_retry are removed when email is empty.""" + with DAG(dag_id="test_email_optimization") as dag: + BashOperator( + task_id="test_task", + bash_command="echo test", + email=None, # Empty email + email_on_failure=True, # This should be removed during serialization + email_on_retry=True, # This should be removed during serialization + ) + + serialized_dag = SerializedDAG.to_dict(dag) + task_serialized = serialized_dag["dag"]["tasks"][0]["__var"] + assert task_serialized is not None + + assert "email_on_failure" not in task_serialized + assert "email_on_retry" not in task_serialized + + # But they should be present when email is not empty + with DAG(dag_id="test_email_with_attrs") as dag_with_email: + BashOperator( + task_id="test_task_with_email", + bash_command="echo test", + email="test@example.com", # Non-empty email + email_on_failure=True, + email_on_retry=True, + ) + + serialized_dag_with_email = SerializedDAG.to_dict(dag_with_email) + task_with_email_serialized = serialized_dag_with_email["dag"]["tasks"][0]["__var"] + + assert task_with_email_serialized is not None + + # email_on_failure and email_on_retry SHOULD be in the serialized task + # since email is not empty + assert "email" in task_with_email_serialized + assert task_with_email_serialized["email"] == "test@example.com" + + +def dummy_callback(): + pass + + +@pytest.mark.parametrize( + "callback_config,expected_flags,is_mapped", + [ + # Regular operator tests + ( + { + "on_failure_callback": dummy_callback, + "on_retry_callback": [dummy_callback, dummy_callback], + "on_success_callback": dummy_callback, + }, + {"has_on_failure_callback": True, "has_on_retry_callback": True, "has_on_success_callback": True}, + False, + ), + ( + {}, # No callbacks + { + "has_on_failure_callback": False, + "has_on_retry_callback": False, + "has_on_success_callback": False, + }, + False, + ), + ( + {"on_failure_callback": [], "on_success_callback": None}, # Empty callbacks + {"has_on_failure_callback": False, "has_on_success_callback": False}, + False, + ), + # Mapped operator tests + ( + {"on_failure_callback": dummy_callback, "on_success_callback": [dummy_callback, dummy_callback]}, + {"has_on_failure_callback": True, 
"has_on_success_callback": True}, + True, + ), + ( + {}, # Mapped operator without callbacks + {"has_on_failure_callback": False, "has_on_success_callback": False}, + True, + ), + ], +) +def test_task_callback_boolean_optimization(callback_config, expected_flags, is_mapped): + """Test that task callbacks are optimized using has_on_*_callback boolean flags.""" + dag = DAG(dag_id="test_callback_dag") + + if is_mapped: + # Create mapped operator + task = BashOperator.partial(task_id="test_task", dag=dag, **callback_config).expand( + bash_command=["echo 1", "echo 2"] + ) + + serialized = BaseSerialization.serialize(task) + deserialized = BaseSerialization.deserialize(serialized) + + # For mapped operators, check partial_kwargs + serialized_data = serialized.get("__var", {}).get("partial_kwargs", {}) + + # Test serialization + for flag, expected in expected_flags.items(): + if expected: + assert flag in serialized_data + assert serialized_data[flag] is True + else: + assert serialized_data.get(flag, False) is False + + # Test deserialized properties + for flag, expected in expected_flags.items(): + assert getattr(deserialized, flag) is expected + + else: + # Create regular operator + task = BashOperator(task_id="test_task", bash_command="echo test", dag=dag, **callback_config) + + serialized = BaseSerialization.serialize(task) + deserialized = BaseSerialization.deserialize(serialized) + + # For regular operators, check top-level + serialized_data = serialized.get("__var", {}) + + # Test serialization (only True values are stored) + for flag, expected in expected_flags.items(): + if expected: + assert serialized_data.get(flag, False) is True + else: + assert serialized_data.get(flag, False) is False + + # Test deserialized properties + for flag, expected in expected_flags.items(): + assert getattr(deserialized, flag) is expected + + +def test_task_callback_properties_exist(): + """Test that all callback boolean properties exist on both regular and mapped operators.""" + dag = DAG(dag_id="test_dag") + + regular_task = BashOperator(task_id="regular", bash_command="echo test", dag=dag) + mapped_task = BashOperator.partial(task_id="mapped", dag=dag).expand(bash_command=["echo 1"]) + + callback_properties = [ + "has_on_execute_callback", + "has_on_failure_callback", + "has_on_success_callback", + "has_on_retry_callback", + "has_on_skipped_callback", + ] + + for prop in callback_properties: + assert hasattr(regular_task, prop), f"Regular operator missing {prop}" + assert hasattr(mapped_task, prop), f"Mapped operator missing {prop}" + + serialized_regular = BaseSerialization.deserialize(BaseSerialization.serialize(regular_task)) + serialized_mapped = BaseSerialization.deserialize(BaseSerialization.serialize(mapped_task)) + + assert hasattr(serialized_regular, prop), f"Deserialized regular operator missing {prop}" + assert hasattr(serialized_mapped, prop), f"Deserialized mapped operator missing {prop}" + + +@pytest.mark.parametrize( + "old_callback_name,new_callback_name", + [ + ("on_execute_callback", "has_on_execute_callback"), + ("on_failure_callback", "has_on_failure_callback"), + ("on_success_callback", "has_on_success_callback"), + ("on_retry_callback", "has_on_retry_callback"), + ("on_skipped_callback", "has_on_skipped_callback"), + ], +) +def test_task_callback_backward_compatibility(old_callback_name, new_callback_name): + """Test that old serialized DAGs with on_*_callback keys are correctly converted to has_on_*_callback.""" + + old_serialized_task = { + "is_setup": False, + 
old_callback_name: [ + " def dumm_callback(*args, **kwargs):\n # hello\n pass\n" + ], + "is_teardown": False, + "task_type": "BaseOperator", + "pool": "default_pool", + "task_id": "simple_task", + "template_fields": [], + "on_failure_fail_dagrun": False, + "downstream_task_ids": [], + "template_ext": [], + "ui_fgcolor": "#000", + "weight_rule": "downstream", + "ui_color": "#fff", + "template_fields_renderers": {}, + "_needs_expansion": False, + "start_from_trigger": False, + "_task_module": "airflow.sdk.bases.operator", + "start_trigger_args": None, + } + + # Test deserialization converts old format to new format + deserialized_task = SerializedBaseOperator.deserialize_operator(old_serialized_task) + + # Verify the new format is present and correct + assert hasattr(deserialized_task, new_callback_name) + assert getattr(deserialized_task, new_callback_name) is True + assert not hasattr(deserialized_task, old_callback_name) + + # Test with empty/None callback (should convert to False) + old_serialized_task[old_callback_name] = None + + deserialized_task_empty = SerializedBaseOperator.deserialize_operator(old_serialized_task) + assert getattr(deserialized_task_empty, new_callback_name) is False + + +class TestClientDefaultsGeneration: + """Test client defaults generation functionality.""" + + def test_generate_client_defaults_basic(self): + """Test basic client defaults generation.""" + client_defaults = SerializedBaseOperator.generate_client_defaults() + + assert isinstance(client_defaults, dict) + + # Should only include serializable fields + serialized_fields = SerializedBaseOperator.get_serialized_fields() + for field in client_defaults: + assert field in serialized_fields, f"Field {field} not in serialized fields" + + def test_generate_client_defaults_excludes_schema_defaults(self): + """Test that client defaults excludes values that match schema defaults.""" + client_defaults = SerializedBaseOperator.generate_client_defaults() + schema_defaults = SerializedBaseOperator.get_schema_defaults("operator") + + # Check that values matching schema defaults are excluded + for field, value in client_defaults.items(): + if field in schema_defaults: + assert value != schema_defaults[field], ( + f"Field {field} has value {value!r} which matches schema default {schema_defaults[field]!r}" + ) + + def test_generate_client_defaults_excludes_none_and_empty(self): + """Test that client defaults excludes None and empty collection values.""" + client_defaults = SerializedBaseOperator.generate_client_defaults() + + for field, value in client_defaults.items(): + assert value is not None, f"Field {field} has None value" + assert value not in [[], (), set(), {}], f"Field {field} has empty collection value: {value!r}" + + def test_generate_client_defaults_caching(self): + """Test that client defaults generation is cached.""" + # Clear cache first + SerializedBaseOperator.generate_client_defaults.cache_clear() + + # First call + client_defaults_1 = SerializedBaseOperator.generate_client_defaults() + + # Second call should return same object (cached) + client_defaults_2 = SerializedBaseOperator.generate_client_defaults() + + assert client_defaults_1 is client_defaults_2, "Client defaults should be cached" + + # Check cache info + cache_info = SerializedBaseOperator.generate_client_defaults.cache_info() + assert cache_info.hits >= 1, "Cache should have at least one hit" + + def test_generate_client_defaults_only_operator_defaults_fields(self): + """Test that only fields from OPERATOR_DEFAULTS are considered.""" + 
client_defaults = SerializedBaseOperator.generate_client_defaults() + + # All fields in client_defaults should originate from OPERATOR_DEFAULTS + for field in client_defaults: + assert field in OPERATOR_DEFAULTS, f"Field {field} not in OPERATOR_DEFAULTS" + + +class TestSchemaDefaults: + """Test schema defaults functionality.""" + + def test_get_schema_defaults_operator(self): + """Test getting schema defaults for operator type.""" + schema_defaults = SerializedBaseOperator.get_schema_defaults("operator") + + assert isinstance(schema_defaults, dict) + + # Should contain expected operator defaults + expected_fields = [ + "owner", + "trigger_rule", + "depends_on_past", + "retries", + "queue", + "pool", + "pool_slots", + "priority_weight", + "weight_rule", + "do_xcom_push", + ] + + for field in expected_fields: + assert field in schema_defaults, f"Expected field {field} not in schema defaults" + + def test_get_schema_defaults_nonexistent_type(self): + """Test getting schema defaults for nonexistent type.""" + schema_defaults = SerializedBaseOperator.get_schema_defaults("nonexistent") + assert schema_defaults == {} + + def test_get_operator_optional_fields_from_schema(self): + """Test getting optional fields from schema.""" + optional_fields = SerializedBaseOperator.get_operator_optional_fields_from_schema() + + assert isinstance(optional_fields, set) + + # Should not contain required fields + required_fields = { + "task_type", + "_task_module", + "task_id", + "ui_color", + "ui_fgcolor", + "template_fields", + } + overlap = optional_fields & required_fields + assert not overlap, f"Optional fields should not overlap with required fields: {overlap}" + + +class TestDeserializationDefaultsResolution: + """Test defaults resolution during deserialization.""" + + def test_apply_defaults_to_encoded_op(self): + encoded_op = {"task_id": "test_task", "task_type": "BashOperator", "retries": 10} + client_defaults = {"tasks": {"retry_delay": 300.0, "retries": 2}} # Fix: wrap in "tasks" + + result = SerializedBaseOperator._apply_defaults_to_encoded_op(encoded_op, client_defaults) + + # Should merge in order: client_defaults, encoded_op + assert result["retry_delay"] == 300.0 # From client_defaults + assert result["task_id"] == "test_task" # From encoded_op (highest priority) + assert result["retries"] == 10 + + def test_apply_defaults_to_encoded_op_none_inputs(self): + """Test defaults application with None inputs.""" + encoded_op = {"task_id": "test_task"} + + # With None client_defaults + result = SerializedBaseOperator._apply_defaults_to_encoded_op(encoded_op, None) + assert result == encoded_op + + def test_multiple_tasks_share_client_defaults(self): + """Test that multiple tasks can share the same client_defaults.""" + with DAG(dag_id="test_dag") as dag: + BashOperator(task_id="task1", bash_command="echo 1") + BashOperator(task_id="task2", bash_command="echo 2") + + serialized = SerializedDAG.to_dict(dag) + + # Should have one client_defaults section for all tasks + assert "client_defaults" in serialized + assert "tasks" in serialized["client_defaults"] + + # All tasks should benefit from the same client_defaults + client_defaults = serialized["client_defaults"]["tasks"] + + # Deserialize and check both tasks get the defaults + deserialized_dag = SerializedDAG.from_dict(serialized) + deserialized_task1 = deserialized_dag.get_task("task1") + deserialized_task2 = deserialized_dag.get_task("task2") + + # Both tasks should have the same default values from client_defaults + for field in client_defaults: + if 
hasattr(deserialized_task1, field) and hasattr(deserialized_task2, field): + value1 = getattr(deserialized_task1, field) + value2 = getattr(deserialized_task2, field) + assert value1 == value2, f"Tasks have different values for {field}: {value1} vs {value2}" + + +class TestMappedOperatorSerializationAndClientDefaults: + """Test MappedOperator serialization with client defaults and callback properties.""" + + def test_mapped_operator_client_defaults_application(self): + """Test that client_defaults are correctly applied to MappedOperator during deserialization.""" + with DAG(dag_id="test_mapped_dag") as dag: + # Create a mapped operator + BashOperator.partial( + task_id="mapped_task", + retries=5, # Override default + ).expand(bash_command=["echo 1", "echo 2", "echo 3"]) + + # Serialize the DAG + serialized_dag = SerializedDAG.to_dict(dag) + + # Should have client_defaults section + assert "client_defaults" in serialized_dag + assert "tasks" in serialized_dag["client_defaults"] + + # Deserialize and check that client_defaults are applied + deserialized_dag = SerializedDAG.from_dict(serialized_dag) + deserialized_task = deserialized_dag.get_task("mapped_task") + + # Verify it's still a MappedOperator + from airflow.models.mappedoperator import MappedOperator as SchedulerMappedOperator + + assert isinstance(deserialized_task, SchedulerMappedOperator) + + # Check that client_defaults values are applied (e.g., retry_delay from client_defaults) + client_defaults = serialized_dag["client_defaults"]["tasks"] + if "retry_delay" in client_defaults: + # If retry_delay wasn't explicitly set, it should come from client_defaults + # Since we can't easily convert timedelta back, check the serialized format + assert hasattr(deserialized_task, "retry_delay") + + # Explicit values should override client_defaults + assert deserialized_task.retries == 5 # Explicitly set value + + @pytest.mark.parametrize( + ["task_config", "dag_id", "task_id", "non_default_fields"], + [ + # Test case 1: Size optimization with non-default values + pytest.param( + {"retries": 3}, # Only set non-default values + "test_mapped_size", + "mapped_size_test", + {"retries"}, + id="non_default_fields", + ), + # Test case 2: No duplication with default values + pytest.param( + {"retries": 0}, # This should match client_defaults and be optimized out + "test_no_duplication", + "mapped_task", + set(), # No fields should be non-default (all optimized out) + id="duplicate_fields", + ), + # Test case 3: Mixed default/non-default values + pytest.param( + {"retries": 2, "max_active_tis_per_dag": 16}, # Mix of default and non-default + "test_mixed_optimization", + "mixed_task", + {"retries", "max_active_tis_per_dag"}, # Both should be preserved as they're non-default + id="test_mixed_optimization", + ), + ], + ) + def test_mapped_operator_client_defaults_optimization( + self, task_config, dag_id, task_id, non_default_fields + ): + """Test that MappedOperator serialization optimizes using client defaults.""" + with DAG(dag_id=dag_id) as dag: + # Create mapped operator with specified configuration + BashOperator.partial( + task_id=task_id, + **task_config, + ).expand(bash_command=["echo 1", "echo 2", "echo 3"]) + + serialized_dag = SerializedDAG.to_dict(dag) + mapped_task_serialized = serialized_dag["dag"]["tasks"][0]["__var"] + + assert mapped_task_serialized is not None + assert mapped_task_serialized.get("_is_mapped") is True + + # Check optimization behavior + client_defaults = serialized_dag["client_defaults"]["tasks"] + partial_kwargs = 
mapped_task_serialized["partial_kwargs"] + + # Check that all fields are optimized correctly + for field, default_value in client_defaults.items(): + if field in non_default_fields: + # Non-default fields should be present in partial_kwargs + assert field in partial_kwargs, ( + f"Field '{field}' should be in partial_kwargs as it's non-default" + ) + # And have different values than defaults + assert partial_kwargs[field] != default_value, ( + f"Field '{field}' should have non-default value" + ) + else: + # Default fields should either not be present or have different values if present + if field in partial_kwargs: + assert partial_kwargs[field] != default_value, ( + f"Field '{field}' with default value should be optimized out" + ) + + def test_mapped_operator_expand_input_preservation(self): + """Test that expand_input is correctly preserved during serialization.""" + with DAG(dag_id="test_expand_input"): + mapped_task = BashOperator.partial(task_id="test_expand").expand( + bash_command=["echo 1", "echo 2", "echo 3"], env={"VAR1": "value1", "VAR2": "value2"} + ) + + # Serialize and deserialize + serialized = BaseSerialization.serialize(mapped_task) + deserialized = BaseSerialization.deserialize(serialized) + + # Check expand_input structure + assert hasattr(deserialized, "expand_input") + expand_input = deserialized.expand_input + + # Verify the expand_input contains the expected data + assert hasattr(expand_input, "value") + expand_value = expand_input.value + + assert "bash_command" in expand_value + assert "env" in expand_value + assert expand_value["bash_command"] == ["echo 1", "echo 2", "echo 3"] + assert expand_value["env"] == {"VAR1": "value1", "VAR2": "value2"} + + @pytest.mark.parametrize( + ["partial_kwargs_data", "expected_results"], + [ + # Test case 1: Encoded format with client defaults + pytest.param( + { + "retry_delay": {"__type": "timedelta", "__var": 600.0}, + "execution_timeout": {"__type": "timedelta", "__var": 1800.0}, + "owner": "test_user", + }, + { + "retry_delay": timedelta(seconds=600), + "execution_timeout": timedelta(seconds=1800), + "owner": "test_user", + }, + id="encoded_with_client_defaults", + ), + # Test case 2: Non-encoded format (optimized) + pytest.param( + { + "retry_delay": 600.0, + "execution_timeout": 1800.0, + }, + { + "retry_delay": timedelta(seconds=600), + "execution_timeout": timedelta(seconds=1800), + }, + id="non_encoded_optimized", + ), + # Test case 3: Mixed format (some encoded, some not) + pytest.param( + { + "retry_delay": {"__type": "timedelta", "__var": 600.0}, # Encoded + "execution_timeout": 1800.0, # Non-encoded + }, + { + "retry_delay": timedelta(seconds=600), + "execution_timeout": timedelta(seconds=1800), + }, + id="mixed_encoded_non_encoded", + ), + ], + ) + def test_partial_kwargs_deserialization_formats(self, partial_kwargs_data, expected_results): + """Test deserialization of partial_kwargs in various formats (encoded, non-encoded, mixed).""" + result = SerializedBaseOperator._deserialize_partial_kwargs(partial_kwargs_data) + + # Verify all expected results + for key, expected_value in expected_results.items(): + assert key in result, f"Missing key '{key}' in result" + assert result[key] == expected_value, f"key '{key}': expected {expected_value}, got {result[key]}" + + def test_partial_kwargs_end_to_end_deserialization(self): + """Test end-to-end partial_kwargs deserialization with real MappedOperator.""" + with DAG(dag_id="test_e2e_partial_kwargs") as dag: + BashOperator.partial( + task_id="mapped_task", + 
retry_delay=timedelta(seconds=600), # Non-default value + owner="custom_owner", # Non-default value + # retries not specified, should potentially get from client_defaults + ).expand(bash_command=["echo 1", "echo 2"]) + + # Serialize and deserialize the DAG + serialized_dag = SerializedDAG.to_dict(dag) + deserialized_dag = SerializedDAG.from_dict(serialized_dag) + deserialized_task = deserialized_dag.get_task("mapped_task") + + # Verify the task has correct values after round-trip + assert deserialized_task.retry_delay == timedelta(seconds=600) + assert deserialized_task.owner == "custom_owner" + + # Verify partial_kwargs were deserialized correctly + assert "retry_delay" in deserialized_task.partial_kwargs + assert "owner" in deserialized_task.partial_kwargs + assert deserialized_task.partial_kwargs["retry_delay"] == timedelta(seconds=600) + assert deserialized_task.partial_kwargs["owner"] == "custom_owner" diff --git a/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py b/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py index 93da49186e065..6d184173aadbb 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py +++ b/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py @@ -316,7 +316,7 @@ def get_user_provided_run_facets(ti: TaskInstance, ti_state: TaskInstanceState) def get_fully_qualified_class_name(operator: BaseOperator | MappedOperator) -> str: if isinstance(operator, (MappedOperator, SerializedBaseOperator)): # as in airflow.api_connexion.schemas.common_schema.ClassReferenceSchema - return operator._task_module + "." + operator._task_type + return operator._task_module + "." + operator.task_type op_class = get_operator_class(operator) return op_class.__module__ + "." + op_class.__name__ diff --git a/providers/openlineage/tests/unit/openlineage/utils/test_utils.py b/providers/openlineage/tests/unit/openlineage/utils/test_utils.py index 913e124c29bf8..d99889aedd584 100644 --- a/providers/openlineage/tests/unit/openlineage/utils/test_utils.py +++ b/providers/openlineage/tests/unit/openlineage/utils/test_utils.py @@ -274,7 +274,7 @@ def test_get_fully_qualified_class_name_serialized_operator(): op_path_after_deserialization = get_fully_qualified_class_name(deserialized) assert op_path_after_deserialization == f"{op_module_path}.{op_name}" assert deserialized._task_module == op_module_path - assert deserialized._task_type == op_name + assert deserialized.task_type == op_name def test_get_fully_qualified_class_name_mapped_operator(): diff --git a/scripts/ci/prek/check_schema_defaults.py b/scripts/ci/prek/check_schema_defaults.py new file mode 100755 index 0000000000000..ebe102136e1fc --- /dev/null +++ b/scripts/ci/prek/check_schema_defaults.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "packaging>=25", +# ] +# /// +from __future__ import annotations + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.resolve())) +from common_prek_utils import ( + initialize_breeze_prek, + run_command_via_breeze_shell, + validate_cmd_result, +) + +initialize_breeze_prek(__name__, __file__) + +cmd_result = run_command_via_breeze_shell( + ["python3", "/opt/airflow/scripts/in_container/run_schema_defaults_check.py"], + backend="sqlite", + warn_image_upgrade_needed=True, +) + +validate_cmd_result(cmd_result, include_ci_env_check=True) diff --git a/scripts/in_container/run_schema_defaults_check.py b/scripts/in_container/run_schema_defaults_check.py new file mode 100755 index 0000000000000..ef672c297ca0c --- /dev/null +++ b/scripts/in_container/run_schema_defaults_check.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Check that defaults in Schema JSON match the server-side SerializedBaseOperator defaults. + +This ensures that the schema accurately reflects the actual default values +used by the server-side serialization layer. 
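+
+The companion prek wrapper (scripts/ci/prek/check_schema_defaults.py) runs this script inside the
+Breeze container, roughly::
+
+    python3 /opt/airflow/scripts/in_container/run_schema_defaults_check.py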
+""" + +from __future__ import annotations + +import json +import sys +from pathlib import Path +from typing import Any + + +def load_schema_defaults() -> dict[str, Any]: + """Load default values from the JSON schema.""" + schema_path = Path("airflow-core/src/airflow/serialization/schema.json") + + if not schema_path.exists(): + print(f"Error: Schema file not found at {schema_path}") + sys.exit(1) + + with open(schema_path) as f: + schema = json.load(f) + + # Extract defaults from the operator definition + operator_def = schema.get("definitions", {}).get("operator", {}) + properties = operator_def.get("properties", {}) + + defaults = {} + for field_name, field_def in properties.items(): + if "default" in field_def: + defaults[field_name] = field_def["default"] + + return defaults + + +def get_server_side_defaults() -> dict[str, Any]: + """Get default values from server-side SerializedBaseOperator class.""" + try: + from airflow.serialization.serialized_objects import SerializedBaseOperator + + # Get all serializable fields + serialized_fields = SerializedBaseOperator.get_serialized_fields() + + # Field name mappings from external API names to internal class attribute names + field_mappings = { + "weight_rule": "_weight_rule", + } + + server_defaults = {} + for field_name in serialized_fields: + # Use the mapped internal name if it exists, otherwise use the field name + attr_name = field_mappings.get(field_name, field_name) + + if hasattr(SerializedBaseOperator, attr_name): + default_value = getattr(SerializedBaseOperator, attr_name) + # Only include actual default values, not methods/properties/descriptors + if not callable(default_value) and not isinstance(default_value, (property, type)): + if isinstance(default_value, (set, tuple)): + # Convert to list since schema.json is pure JSON + default_value = list(default_value) + server_defaults[field_name] = default_value + + return server_defaults + + except ImportError as e: + print(f"Error importing SerializedBaseOperator: {e}") + sys.exit(1) + except Exception as e: + print(f"Error getting server-side defaults: {e}") + sys.exit(1) + + +def compare_defaults() -> list[str]: + """Compare schema defaults with server-side defaults and return discrepancies.""" + schema_defaults = load_schema_defaults() + server_defaults = get_server_side_defaults() + errors = [] + + print(f"Found {len(schema_defaults)} schema defaults") + print(f"Found {len(server_defaults)} server-side defaults") + + # Check each server default against schema + for field_name, server_value in server_defaults.items(): + schema_value = schema_defaults.get(field_name) + + # Check if field exists in schema + if field_name not in schema_defaults: + # Some server fields might not need defaults in schema (like None values) + if server_value is not None and server_value not in [[], {}, (), set()]: + errors.append( + f"Server field '{field_name}' has default {server_value!r} but no schema default" + ) + continue + + # Direct comparison - no complex normalization needed + if schema_value != server_value: + errors.append( + f"Field '{field_name}': schema default is {schema_value!r}, " + f"server default is {server_value!r}" + ) + + # Check for schema defaults that don't have corresponding server defaults + for field_name, schema_value in schema_defaults.items(): + if field_name not in server_defaults: + # Some schema fields are structural and don't need server defaults + schema_only_fields = { + "task_type", + "_task_module", + "task_id", + "_task_display_name", + "_is_mapped", + 
"_is_sensor", + } + if field_name not in schema_only_fields: + errors.append( + f"Schema has default for '{field_name}' = {schema_value!r} but no corresponding server default" + ) + + return errors + + +def main(): + """Main function to run the schema defaults check.""" + print("Checking schema defaults against server-side SerializedBaseOperator...") + + errors = compare_defaults() + + if errors: + print("❌ Found discrepancies between schema and server defaults:") + for error in errors: + print(f" • {error}") + print() + print("To fix these issues:") + print("1. Update airflow-core/src/airflow/serialization/schema.json to match server defaults, OR") + print( + "2. Update airflow-core/src/airflow/serialization/serialized_objects.py class defaults to match schema" + ) + sys.exit(1) + else: + print("✅ All schema defaults match server-side defaults!") + + +if __name__ == "__main__": + main() diff --git a/task-sdk/src/airflow/sdk/bases/operator.py b/task-sdk/src/airflow/sdk/bases/operator.py index 3ece4b9937288..581b2b7c20ca7 100644 --- a/task-sdk/src/airflow/sdk/bases/operator.py +++ b/task-sdk/src/airflow/sdk/bases/operator.py @@ -916,11 +916,11 @@ def say_hello_world(**context): "wait_for_downstream", "priority_weight", "execution_timeout", - "on_execute_callback", - "on_failure_callback", - "on_success_callback", - "on_retry_callback", - "on_skipped_callback", + "has_on_execute_callback", + "has_on_failure_callback", + "has_on_success_callback", + "has_on_retry_callback", + "has_on_skipped_callback", "do_xcom_push", "multiple_outputs", "allow_nested_operators", @@ -1479,6 +1479,12 @@ def get_serialized_fields(cls): "on_failure_fail_dagrun", "task_group", "_task_type", + "operator_extra_links", + "on_execute_callback", + "on_failure_callback", + "on_success_callback", + "on_retry_callback", + "on_skipped_callback", } | { # Class level defaults, or `@property` need to be added to this list "start_date", @@ -1498,6 +1504,11 @@ def get_serialized_fields(cls): "_needs_expansion", "start_from_trigger", "max_retry_delay", + "has_on_execute_callback", + "has_on_failure_callback", + "has_on_success_callback", + "has_on_retry_callback", + "has_on_skipped_callback", } ) DagContext.pop() @@ -1634,6 +1645,31 @@ def dry_run(self) -> None: self.log.info("Rendering template for %s", f) self.log.info(content) + @property + def has_on_execute_callback(self) -> bool: + """Return True if the task has execute callbacks.""" + return bool(self.on_execute_callback) + + @property + def has_on_failure_callback(self) -> bool: + """Return True if the task has failure callbacks.""" + return bool(self.on_failure_callback) + + @property + def has_on_success_callback(self) -> bool: + """Return True if the task has success callbacks.""" + return bool(self.on_success_callback) + + @property + def has_on_retry_callback(self) -> bool: + """Return True if the task has retry callbacks.""" + return bool(self.on_retry_callback) + + @property + def has_on_skipped_callback(self) -> bool: + """Return True if the task has skipped callbacks.""" + return bool(self.on_skipped_callback) + def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None: r""" diff --git a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py index febc922ba265c..8cdd879eece91 100644 --- a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py +++ b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py @@ -300,7 +300,7 @@ class MappedOperator(AbstractOperator): _can_skip_downstream: bool = 
attrs.field(alias="can_skip_downstream")
     _is_sensor: bool = attrs.field(alias="is_sensor", default=False)
     _task_module: str
-    _task_type: str
+    task_type: str
     _operator_name: str
     start_trigger_args: StartTriggerArgs | None
     start_from_trigger: bool
@@ -334,7 +334,7 @@ def __hash__(self):
         return id(self)
 
     def __repr__(self):
-        return f"<Mapped({self._task_type}): {self.task_id}>"
+        return f"<Mapped({self.task_type}): {self.task_id}>"
 
     def __attrs_post_init__(self):
         from airflow.sdk.definitions.xcom_arg import XComArg
@@ -355,10 +355,9 @@ def __attrs_post_init__(self):
     @classmethod
     def get_serialized_fields(cls):
         # Not using 'cls' here since we only want to serialize base fields.
-        return (frozenset(attrs.fields_dict(MappedOperator)) | {"task_type"}) - {
+        return (frozenset(attrs.fields_dict(MappedOperator))) - {
             "_is_empty",
             "_can_skip_downstream",
-            "_task_type",
             "dag",
             "deps",
             "expand_input",  # This is needed to be able to accept XComArg.
@@ -367,13 +366,12 @@ def get_serialized_fields(cls):
             "_is_setup",
             "_is_teardown",
             "_on_failure_fail_dagrun",
+            "operator_class",
+            "_needs_expansion",
+            "partial_kwargs",
+            "operator_extra_links",
         }
 
-    @property
-    def task_type(self) -> str:
-        """Implementing Operator."""
-        return self._task_type
-
     @property
     def operator_name(self) -> str:
         return self._operator_name
@@ -614,6 +612,26 @@ def on_skipped_callback(self) -> TaskStateChangeCallbackAttrType:
     def on_skipped_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
         self.partial_kwargs["on_skipped_callback"] = value or []
 
+    @property
+    def has_on_execute_callback(self) -> bool:
+        return bool(self.on_execute_callback)
+
+    @property
+    def has_on_failure_callback(self) -> bool:
+        return bool(self.on_failure_callback)
+
+    @property
+    def has_on_retry_callback(self) -> bool:
+        return bool(self.on_retry_callback)
+
+    @property
+    def has_on_success_callback(self) -> bool:
+        return bool(self.on_success_callback)
+
+    @property
+    def has_on_skipped_callback(self) -> bool:
+        return bool(self.on_skipped_callback)
+
     @property
     def run_as_user(self) -> str | None:
         return self.partial_kwargs.get("run_as_user")