From 370fb681e3bee14cb9c232a2e647d2b5605e841a Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Wed, 14 May 2025 16:56:01 +0800 Subject: [PATCH] [v3-0-test] Fix default_args application in operator partial (#50525) (cherry picked from commit 7d6a18a129d9f11d94f8117062a2f1cc546f2f66) Co-authored-by: Tzu-ping Chung --- task-sdk/src/airflow/sdk/bases/operator.py | 105 ++++++++---------- ...{test_baseoperator.py => test_operator.py} | 33 ++++++ 2 files changed, 80 insertions(+), 58 deletions(-) rename task-sdk/tests/task_sdk/bases/{test_baseoperator.py => test_operator.py} (96%) diff --git a/task-sdk/src/airflow/sdk/bases/operator.py b/task-sdk/src/airflow/sdk/bases/operator.py index 09ca8c6c03e5c..1444c89e49605 100644 --- a/task-sdk/src/airflow/sdk/bases/operator.py +++ b/task-sdk/src/airflow/sdk/bases/operator.py @@ -55,7 +55,7 @@ ) from airflow.sdk.definitions._internal.decorators import fixup_decorator_warning_stack from airflow.sdk.definitions._internal.node import validate_key -from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, validate_instance_args +from airflow.sdk.definitions._internal.types import NOTSET, validate_instance_args from airflow.sdk.definitions.edges import EdgeModifier from airflow.sdk.definitions.mappedoperator import OperatorPartial, validate_mapping_kwargs from airflow.sdk.definitions.param import ParamsDict @@ -232,63 +232,48 @@ def partial( task_id: str, dag: DAG | None = None, task_group: TaskGroup | None = None, - start_date: datetime | ArgNotSet = NOTSET, - end_date: datetime | ArgNotSet = NOTSET, - owner: str | ArgNotSet = NOTSET, - email: None | str | Iterable[str] | ArgNotSet = NOTSET, + start_date: datetime = ..., + end_date: datetime = ..., + owner: str = ..., + email: None | str | Iterable[str] = ..., params: collections.abc.MutableMapping | None = None, - resources: dict[str, Any] | None | ArgNotSet = NOTSET, - trigger_rule: str | ArgNotSet = NOTSET, - depends_on_past: bool | ArgNotSet = NOTSET, - ignore_first_depends_on_past: bool | ArgNotSet = NOTSET, - wait_for_past_depends_before_skipping: bool | ArgNotSet = NOTSET, - wait_for_downstream: bool | ArgNotSet = NOTSET, - retries: int | None | ArgNotSet = NOTSET, - queue: str | ArgNotSet = NOTSET, - pool: str | ArgNotSet = NOTSET, - pool_slots: int | ArgNotSet = NOTSET, - execution_timeout: timedelta | None | ArgNotSet = NOTSET, - max_retry_delay: None | timedelta | float | ArgNotSet = NOTSET, - retry_delay: timedelta | float | ArgNotSet = NOTSET, - retry_exponential_backoff: bool | ArgNotSet = NOTSET, - priority_weight: int | ArgNotSet = NOTSET, - weight_rule: str | PriorityWeightStrategy | ArgNotSet = NOTSET, - sla: timedelta | None | ArgNotSet = NOTSET, - map_index_template: str | None | ArgNotSet = NOTSET, - max_active_tis_per_dag: int | None | ArgNotSet = NOTSET, - max_active_tis_per_dagrun: int | None | ArgNotSet = NOTSET, - on_execute_callback: None - | TaskStateChangeCallback - | list[TaskStateChangeCallback] - | ArgNotSet = NOTSET, - on_failure_callback: None - | TaskStateChangeCallback - | list[TaskStateChangeCallback] - | ArgNotSet = NOTSET, - on_success_callback: None - | TaskStateChangeCallback - | list[TaskStateChangeCallback] - | ArgNotSet = NOTSET, - on_retry_callback: None - | TaskStateChangeCallback - | list[TaskStateChangeCallback] - | ArgNotSet = NOTSET, - on_skipped_callback: None - | TaskStateChangeCallback - | list[TaskStateChangeCallback] - | ArgNotSet = NOTSET, - run_as_user: str | None | ArgNotSet = NOTSET, - executor: str | None | ArgNotSet = NOTSET, - executor_config: dict | None | ArgNotSet = NOTSET, - inlets: Any | None | ArgNotSet = NOTSET, - outlets: Any | None | ArgNotSet = NOTSET, - doc: str | None | ArgNotSet = NOTSET, - doc_md: str | None | ArgNotSet = NOTSET, - doc_json: str | None | ArgNotSet = NOTSET, - doc_yaml: str | None | ArgNotSet = NOTSET, - doc_rst: str | None | ArgNotSet = NOTSET, - task_display_name: str | None | ArgNotSet = NOTSET, - logger_name: str | None | ArgNotSet = NOTSET, + resources: dict[str, Any] | None = ..., + trigger_rule: str = ..., + depends_on_past: bool = ..., + ignore_first_depends_on_past: bool = ..., + wait_for_past_depends_before_skipping: bool = ..., + wait_for_downstream: bool = ..., + retries: int | None = ..., + queue: str = ..., + pool: str = ..., + pool_slots: int = ..., + execution_timeout: timedelta | None = ..., + max_retry_delay: None | timedelta | float = ..., + retry_delay: timedelta | float = ..., + retry_exponential_backoff: bool = ..., + priority_weight: int = ..., + weight_rule: str | PriorityWeightStrategy = ..., + sla: timedelta | None = ..., + map_index_template: str | None = ..., + max_active_tis_per_dag: int | None = ..., + max_active_tis_per_dagrun: int | None = ..., + on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ..., + on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ..., + on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ..., + on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ..., + on_skipped_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ..., + run_as_user: str | None = ..., + executor: str | None = ..., + executor_config: dict | None = ..., + inlets: Any | None = ..., + outlets: Any | None = ..., + doc: str | None = ..., + doc_md: str | None = ..., + doc_json: str | None = ..., + doc_yaml: str | None = ..., + doc_rst: str | None = ..., + task_display_name: str | None = ..., + logger_name: str | None = ..., allow_nested_operators: bool = True, **kwargs, ) -> OperatorPartial: ... @@ -330,8 +315,12 @@ def partial( } # Inject DAG-level default args into args provided to this function. + # Most of the default args will be retrieved during unmapping; here we + # only ensure base properties are correctly set for the scheduler. partial_kwargs.update( - (k, v) for k, v in dag_default_args.items() if partial_kwargs.get(k, NOTSET) is NOTSET + (k, v) + for k, v in dag_default_args.items() + if k not in partial_kwargs and k in BaseOperator.__init__._BaseOperatorMeta__param_names ) # Fill fields not provided by the user with default values. diff --git a/task-sdk/tests/task_sdk/bases/test_baseoperator.py b/task-sdk/tests/task_sdk/bases/test_operator.py similarity index 96% rename from task-sdk/tests/task_sdk/bases/test_baseoperator.py rename to task-sdk/tests/task_sdk/bases/test_operator.py index b2c1440aeb824..dafce2400c8e0 100644 --- a/task-sdk/tests/task_sdk/bases/test_baseoperator.py +++ b/task-sdk/tests/task_sdk/bases/test_operator.py @@ -898,3 +898,36 @@ def say_hello(**context): "timestamp": mock.ANY, }, ] + + +def test_partial_default_args(): + class MockOperator(BaseOperator): + def __init__(self, arg1, arg2, arg3, **kwargs): + self.arg1 = arg1 + self.arg2 = arg2 + self.arg3 = arg3 + self.kwargs = kwargs + super().__init__(**kwargs) + + with DAG( + dag_id="test_partial_default_args", + default_args={"queue": "THIS", "arg1": 1, "arg2": 2, "arg3": 3, "arg4": 4}, + ): + t1 = BaseOperator(task_id="t1") + t2 = MockOperator.partial(task_id="t2", arg2="b").expand(arg1=t1.output) + + # Only default_args recognized by BaseOperator are applied. + assert t2.partial_kwargs["queue"] == "THIS" + assert "arg1" not in t2.partial_kwargs + assert t2.partial_kwargs["arg2"] == "b" + assert "arg3" not in t2.partial_kwargs + assert "arg4" not in t2.partial_kwargs + + # Simulate resolving mapped operator. This should apply all default_args. + op = t2.unmap({"arg1": "a"}) + assert isinstance(op, MockOperator) + assert "arg4" not in op.kwargs # Not recognized by any class; never passed. + assert op.arg1 == "a" + assert op.arg2 == "b" + assert op.arg3 == 3 + assert op.queue == "THIS"