Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 47 additions & 58 deletions task-sdk/src/airflow/sdk/bases/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
)
from airflow.sdk.definitions._internal.decorators import fixup_decorator_warning_stack
from airflow.sdk.definitions._internal.node import validate_key
from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, validate_instance_args
from airflow.sdk.definitions._internal.types import NOTSET, validate_instance_args
from airflow.sdk.definitions.edges import EdgeModifier
from airflow.sdk.definitions.mappedoperator import OperatorPartial, validate_mapping_kwargs
from airflow.sdk.definitions.param import ParamsDict
Expand Down Expand Up @@ -232,63 +232,48 @@ def partial(
task_id: str,
dag: DAG | None = None,
task_group: TaskGroup | None = None,
start_date: datetime | ArgNotSet = NOTSET,
end_date: datetime | ArgNotSet = NOTSET,
owner: str | ArgNotSet = NOTSET,
email: None | str | Iterable[str] | ArgNotSet = NOTSET,
start_date: datetime = ...,
end_date: datetime = ...,
owner: str = ...,
email: None | str | Iterable[str] = ...,
params: collections.abc.MutableMapping | None = None,
resources: dict[str, Any] | None | ArgNotSet = NOTSET,
trigger_rule: str | ArgNotSet = NOTSET,
depends_on_past: bool | ArgNotSet = NOTSET,
ignore_first_depends_on_past: bool | ArgNotSet = NOTSET,
wait_for_past_depends_before_skipping: bool | ArgNotSet = NOTSET,
wait_for_downstream: bool | ArgNotSet = NOTSET,
retries: int | None | ArgNotSet = NOTSET,
queue: str | ArgNotSet = NOTSET,
pool: str | ArgNotSet = NOTSET,
pool_slots: int | ArgNotSet = NOTSET,
execution_timeout: timedelta | None | ArgNotSet = NOTSET,
max_retry_delay: None | timedelta | float | ArgNotSet = NOTSET,
retry_delay: timedelta | float | ArgNotSet = NOTSET,
retry_exponential_backoff: bool | ArgNotSet = NOTSET,
priority_weight: int | ArgNotSet = NOTSET,
weight_rule: str | PriorityWeightStrategy | ArgNotSet = NOTSET,
sla: timedelta | None | ArgNotSet = NOTSET,
map_index_template: str | None | ArgNotSet = NOTSET,
max_active_tis_per_dag: int | None | ArgNotSet = NOTSET,
max_active_tis_per_dagrun: int | None | ArgNotSet = NOTSET,
on_execute_callback: None
| TaskStateChangeCallback
| list[TaskStateChangeCallback]
| ArgNotSet = NOTSET,
on_failure_callback: None
| TaskStateChangeCallback
| list[TaskStateChangeCallback]
| ArgNotSet = NOTSET,
on_success_callback: None
| TaskStateChangeCallback
| list[TaskStateChangeCallback]
| ArgNotSet = NOTSET,
on_retry_callback: None
| TaskStateChangeCallback
| list[TaskStateChangeCallback]
| ArgNotSet = NOTSET,
on_skipped_callback: None
| TaskStateChangeCallback
| list[TaskStateChangeCallback]
| ArgNotSet = NOTSET,
run_as_user: str | None | ArgNotSet = NOTSET,
executor: str | None | ArgNotSet = NOTSET,
executor_config: dict | None | ArgNotSet = NOTSET,
inlets: Any | None | ArgNotSet = NOTSET,
outlets: Any | None | ArgNotSet = NOTSET,
doc: str | None | ArgNotSet = NOTSET,
doc_md: str | None | ArgNotSet = NOTSET,
doc_json: str | None | ArgNotSet = NOTSET,
doc_yaml: str | None | ArgNotSet = NOTSET,
doc_rst: str | None | ArgNotSet = NOTSET,
task_display_name: str | None | ArgNotSet = NOTSET,
logger_name: str | None | ArgNotSet = NOTSET,
resources: dict[str, Any] | None = ...,
trigger_rule: str = ...,
depends_on_past: bool = ...,
ignore_first_depends_on_past: bool = ...,
wait_for_past_depends_before_skipping: bool = ...,
wait_for_downstream: bool = ...,
retries: int | None = ...,
queue: str = ...,
pool: str = ...,
pool_slots: int = ...,
execution_timeout: timedelta | None = ...,
max_retry_delay: None | timedelta | float = ...,
retry_delay: timedelta | float = ...,
retry_exponential_backoff: bool = ...,
priority_weight: int = ...,
weight_rule: str | PriorityWeightStrategy = ...,
sla: timedelta | None = ...,
map_index_template: str | None = ...,
max_active_tis_per_dag: int | None = ...,
max_active_tis_per_dagrun: int | None = ...,
on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
on_skipped_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
run_as_user: str | None = ...,
executor: str | None = ...,
executor_config: dict | None = ...,
inlets: Any | None = ...,
outlets: Any | None = ...,
doc: str | None = ...,
doc_md: str | None = ...,
doc_json: str | None = ...,
doc_yaml: str | None = ...,
doc_rst: str | None = ...,
task_display_name: str | None = ...,
logger_name: str | None = ...,
allow_nested_operators: bool = True,
**kwargs,
) -> OperatorPartial: ...
Expand Down Expand Up @@ -330,8 +315,12 @@ def partial(
}

# Inject DAG-level default args into args provided to this function.
# Most of the default args will be retrieved during unmapping; here we
# only ensure base properties are correctly set for the scheduler.
partial_kwargs.update(
(k, v) for k, v in dag_default_args.items() if partial_kwargs.get(k, NOTSET) is NOTSET
(k, v)
for k, v in dag_default_args.items()
if k not in partial_kwargs and k in BaseOperator.__init__._BaseOperatorMeta__param_names
)

# Fill fields not provided by the user with default values.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -898,3 +898,36 @@ def say_hello(**context):
"timestamp": mock.ANY,
},
]


def test_partial_default_args():
class MockOperator(BaseOperator):
def __init__(self, arg1, arg2, arg3, **kwargs):
self.arg1 = arg1
self.arg2 = arg2
self.arg3 = arg3
self.kwargs = kwargs
super().__init__(**kwargs)

with DAG(
dag_id="test_partial_default_args",
default_args={"queue": "THIS", "arg1": 1, "arg2": 2, "arg3": 3, "arg4": 4},
):
t1 = BaseOperator(task_id="t1")
t2 = MockOperator.partial(task_id="t2", arg2="b").expand(arg1=t1.output)

# Only default_args recognized by BaseOperator are applied.
assert t2.partial_kwargs["queue"] == "THIS"
assert "arg1" not in t2.partial_kwargs
assert t2.partial_kwargs["arg2"] == "b"
assert "arg3" not in t2.partial_kwargs
assert "arg4" not in t2.partial_kwargs

# Simulate resolving mapped operator. This should apply all default_args.
op = t2.unmap({"arg1": "a"})
assert isinstance(op, MockOperator)
assert "arg4" not in op.kwargs # Not recognized by any class; never passed.
assert op.arg1 == "a"
assert op.arg2 == "b"
assert op.arg3 == 3
assert op.queue == "THIS"