diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index b97f88695c422..eb748b66eeedd 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -19,10 +19,12 @@ import datetime import inspect -from functools import cached_property +import warnings +from functools import cached_property, wraps from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Iterable, Iterator, Sequence from sqlalchemy import select +from typing_extensions import final from airflow.compat.functools import cache from airflow.configuration import conf @@ -78,6 +80,29 @@ ) +def _priority_weight_limiter(prop: Callable[..., int]) -> Callable[..., int]: + weight_upper_bound = 2**31 - 1 + weight_lower_bound = -(2**31) + + @wraps(prop) + def wrapper(self: AbstractOperator) -> int: + weight_total = prop(self) + if not (weight_upper_bound >= weight_total >= weight_lower_bound): + msg = f"Task {self.task_id!r} total priority {weight_total:,} " + if dag := self.get_dag(): + msg += f"in dag {dag.dag_id!r} " + weight_total = weight_upper_bound if weight_total > weight_upper_bound else weight_lower_bound + msg += ( + f"exceeds allowed priority weight range [{weight_upper_bound:,}..{weight_lower_bound:,}]. " + f"Fallback to {weight_total:,}." + ) + warnings.warn(msg, UserWarning, stacklevel=2) + + return weight_total + + return wrapper + + class NotMapped(Exception): """Raise if a task is neither mapped nor has any parent mapped groups.""" @@ -386,7 +411,9 @@ def unmap(self, resolve: None | dict[str, Any] | tuple[Context, Session]) -> Bas """ raise NotImplementedError() + @final @property + @_priority_weight_limiter def priority_weight_total(self) -> int: """ Total priority weight for the task. It might include all upstream or downstream tasks. diff --git a/docs/apache-airflow/administration-and-deployment/priority-weight.rst b/docs/apache-airflow/administration-and-deployment/priority-weight.rst index 3807b3ee5ddd9..5935ad2e4ddc0 100644 --- a/docs/apache-airflow/administration-and-deployment/priority-weight.rst +++ b/docs/apache-airflow/administration-and-deployment/priority-weight.rst @@ -24,6 +24,11 @@ Priority Weights bumped to any integer. Moreover, each task has a true ``priority_weight`` that is calculated based on its ``weight_rule`` which defines the weighting method used for the effective total priority weight of the task. +.. versionadded:: 2.9.0 + + Total priority weight should be in range between **-2,147,483,648** and **2,147,483,647**. + In case of overflow it fallback to the boundaries of allowed range. + Below are the weighting methods. By default, Airflow's weighting method is ``downstream``. diff --git a/tests/core/test_overflow_weighted_priority.py b/tests/core/test_overflow_weighted_priority.py new file mode 100644 index 0000000000000..454201d32c62d --- /dev/null +++ b/tests/core/test_overflow_weighted_priority.py @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import contextlib +import warnings + +import pytest + +from airflow.models.baseoperator import BaseOperator +from airflow.operators.empty import EmptyOperator +from tests.test_utils.db import clear_db_dags, clear_db_serialized_dags + +INT32_MAX = 2147483647 +INT32_MIN = -2147483648 + + +@contextlib.contextmanager +def _warning_not_expected(): + with warnings.catch_warnings(): + warnings.filterwarnings( + "error", message=".*exceeds allowed priority weight range.*", category=UserWarning + ) + yield + + +@pytest.fixture +def _clear_dags(): + clear_db_dags() + clear_db_serialized_dags() + yield + clear_db_dags() + clear_db_serialized_dags() + + +class TestDagTaskParameterOverflow: + @_warning_not_expected() + def test_priority_weight_default(self): + assert EmptyOperator(task_id="empty").priority_weight_total + + @pytest.mark.parametrize( + "priority_weight", + [ + 42, + pytest.param(INT32_MIN, id="lower-bound"), + pytest.param(INT32_MAX, id="upper-bound"), + ], + ) + @_warning_not_expected() + def test_priority_weight_absolute(self, priority_weight): + EmptyOperator(task_id="empty", priority_weight=priority_weight) + + @pytest.mark.parametrize( + "priority_weight, priority_weight_total", + [ + pytest.param(INT32_MIN - 1, INT32_MIN, id="less-than-lower-bound"), + pytest.param(INT32_MAX + 1, INT32_MAX, id="greater-than-upper-bound"), + ], + ) + def test_priority_weight_absolute_overflow(self, priority_weight, priority_weight_total): + op = EmptyOperator(task_id="empty", priority_weight=priority_weight) + with pytest.warns(UserWarning, match="exceeds allowed priority weight range"): + assert op.priority_weight_total == priority_weight_total + + @pytest.mark.db_test + @pytest.mark.parametrize( + "priority, bound_priority", + [ + pytest.param(-10, INT32_MIN, id="less-than-lower-bound"), + pytest.param(10, INT32_MAX, id="greater-than-upper-bound"), + ], + ) + def test_priority_weight_sum_up_overflow( + self, priority: int, bound_priority: int, dag_maker, _clear_dags + ): + class TestOp(BaseOperator): + def __init__(self, value, **kwargs): + super().__init__(**kwargs) + self.value = value + + with dag_maker(dag_id="test_priority_weight_sum_up_overflow"): + op1 = EmptyOperator(task_id="op1", priority_weight=priority) + op2 = TestOp.partial(task_id="op2", priority_weight=bound_priority).expand(value=[1, 2, 3]) + op3 = EmptyOperator(task_id="op3", priority_weight=priority) + op1 >> op2 >> op3 + + with pytest.warns(UserWarning, match="exceeds allowed priority weight range"): + dr = dag_maker.create_dagrun() + + tis_priorities = {ti.task_id: ti.priority_weight for ti in dr.task_instances} + assert tis_priorities["op3"] == priority + assert tis_priorities["op2"] == bound_priority + assert tis_priorities["op1"] == bound_priority