Skip to content

Commit

Permalink
Experimental: Support custom weight_rule implementation to calculate …
Browse files Browse the repository at this point in the history
…the TI priority_weight
  • Loading branch information
hussein-awala committed Mar 17, 2024
1 parent 8839e0a commit 5ab7a40
Show file tree
Hide file tree
Showing 19 changed files with 488 additions and 32 deletions.
13 changes: 9 additions & 4 deletions airflow/api_connexion/schemas/common_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

from airflow.models.mappedoperator import MappedOperator
from airflow.serialization.serialized_objects import SerializedBaseOperator
from airflow.utils.weight_rule import WeightRule


class CronExpression(typing.NamedTuple):
Expand Down Expand Up @@ -138,9 +137,15 @@ def __init__(self, **metadata):
class WeightRuleField(fields.String):
"""Schema for WeightRule."""

def __init__(self, **metadata):
super().__init__(**metadata)
self.validators = [validate.OneOf(WeightRule.all_weight_rules()), *self.validators]
def _serialize(self, value, attr, obj, **kwargs):
from airflow.serialization.serialized_objects import encode_priority_weight_strategy

return encode_priority_weight_strategy(value)

def _deserialize(self, value, attr, data, **kwargs):
from airflow.serialization.serialized_objects import decode_priority_weight_strategy

return decode_priority_weight_strategy(value)


class TimezoneField(fields.String):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

from airflow.plugins_manager import AirflowPlugin
from airflow.task.priority_strategy import PriorityWeightStrategy

if TYPE_CHECKING:
from airflow.models import TaskInstance


# [START custom_priority_weight_strategy]
class DecreasingPriorityStrategy(PriorityWeightStrategy):
"""A priority weight strategy that decreases the priority weight with each attempt of the DAG task."""

def get_weight(self, ti: TaskInstance):
return max(3 - ti._try_number + 1, 1)


class DecreasingPriorityWeightStrategyPlugin(AirflowPlugin):
name = "decreasing_priority_weight_strategy_plugin"
priority_weight_strategies = [DecreasingPriorityStrategy]


# [END custom_priority_weight_strategy]
2 changes: 1 addition & 1 deletion airflow/executors/base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def queue_task_instance(
self.queue_command(
task_instance,
command_list_to_run,
priority=task_instance.task.priority_weight_total,
priority=task_instance.priority_weight,
queue=task_instance.task.queue,
)

Expand Down
2 changes: 1 addition & 1 deletion airflow/executors/debug_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def queue_task_instance(
self.queue_command(
task_instance,
[str(task_instance)], # Just for better logging, it's not used anywhere
priority=task_instance.task.priority_weight_total,
priority=task_instance.priority_weight,
queue=task_instance.task.queue,
)
# Save params for TaskInstance._run_raw_task
Expand Down
15 changes: 11 additions & 4 deletions airflow/models/abstractoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from airflow.models.mappedoperator import MappedOperator
from airflow.models.operator import Operator
from airflow.models.taskinstance import TaskInstance
from airflow.task.priority_strategy import PriorityWeightStrategy
from airflow.utils.task_group import TaskGroup

DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner")
Expand Down Expand Up @@ -97,7 +98,7 @@ class AbstractOperator(Templater, DAGNode):

operator_class: type[BaseOperator] | dict[str, Any]

weight_rule: str
weight_rule: PriorityWeightStrategy
priority_weight: int

# Defines the operator level extra links.
Expand Down Expand Up @@ -397,11 +398,17 @@ def priority_weight_total(self) -> int:
- WeightRule.DOWNSTREAM - adds priority weight of all downstream tasks
- WeightRule.UPSTREAM - adds priority weight of all upstream tasks
"""
if self.weight_rule == WeightRule.ABSOLUTE:
from airflow.task.priority_strategy import (
_AbsolutePriorityWeightStrategy,
_DownstreamPriorityWeightStrategy,
_UpstreamPriorityWeightStrategy,
)

if type(self.weight_rule) == _AbsolutePriorityWeightStrategy:
return self.priority_weight
elif self.weight_rule == WeightRule.DOWNSTREAM:
elif type(self.weight_rule) == _DownstreamPriorityWeightStrategy:
upstream = False
elif self.weight_rule == WeightRule.UPSTREAM:
elif type(self.weight_rule) == _UpstreamPriorityWeightStrategy:
upstream = True
else:
upstream = False
Expand Down
21 changes: 11 additions & 10 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
from airflow.models.taskinstance import TaskInstance, clear_task_instances
from airflow.models.taskmixin import DependencyMixin
from airflow.serialization.enums import DagAttributeTypes
from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy
from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep
from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep
from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep
Expand All @@ -94,7 +95,6 @@
from airflow.utils.setup_teardown import SetupTeardownContext
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import NOTSET
from airflow.utils.weight_rule import WeightRule
from airflow.utils.xcom import XCOM_RETURN_KEY

if TYPE_CHECKING:
Expand Down Expand Up @@ -244,7 +244,7 @@ def partial(
retry_delay: timedelta | float | ArgNotSet = NOTSET,
retry_exponential_backoff: bool | ArgNotSet = NOTSET,
priority_weight: int | ArgNotSet = NOTSET,
weight_rule: str | ArgNotSet = NOTSET,
weight_rule: str | PriorityWeightStrategy | ArgNotSet = NOTSET,
sla: timedelta | None | ArgNotSet = NOTSET,
map_index_template: str | None | ArgNotSet = NOTSET,
max_active_tis_per_dag: int | None | ArgNotSet = NOTSET,
Expand Down Expand Up @@ -575,6 +575,13 @@ class derived from this one results in the creation of a task object,
significantly speeding up the task creation process as for very large
DAGs. Options can be set as string or using the constants defined in
the static class ``airflow.utils.WeightRule``
|experimental|
Since 2.9.0, Airflow allows to define custom priority weight strategy,
by creating a subclass of
``airflow.task.priority_strategy.PriorityWeightStrategy`` and registering
in a plugin, then providing the class path or the class instance via
``weight_rule`` parameter. The custom priority weight strategy will be
used to calculate the effective total priority weight of the task instance.
:param queue: which queue to target when running this job. Not
all executors implement queue management, the CeleryExecutor
does support targeting specific queues.
Expand Down Expand Up @@ -767,7 +774,7 @@ def __init__(
params: collections.abc.MutableMapping | None = None,
default_args: dict | None = None,
priority_weight: int = DEFAULT_PRIORITY_WEIGHT,
weight_rule: str = DEFAULT_WEIGHT_RULE,
weight_rule: str | PriorityWeightStrategy = DEFAULT_WEIGHT_RULE,
queue: str = DEFAULT_QUEUE,
pool: str | None = None,
pool_slots: int = DEFAULT_POOL_SLOTS,
Expand Down Expand Up @@ -918,13 +925,7 @@ def __init__(
f"received '{type(priority_weight)}'."
)
self.priority_weight = priority_weight
if not WeightRule.is_valid(weight_rule):
raise AirflowException(
f"The weight_rule must be one of "
f"{WeightRule.all_weight_rules},'{dag.dag_id if dag else ''}.{task_id}'; "
f"received '{weight_rule}'."
)
self.weight_rule = weight_rule
self.weight_rule = validate_and_load_priority_weight_strategy(weight_rule)
self.resources = coerce_resources(resources)
if task_concurrency and not max_active_tis_per_dag:
# TODO: Remove in Airflow 3.0
Expand Down
11 changes: 7 additions & 4 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
)
from airflow.models.pool import Pool
from airflow.serialization.enums import DagAttributeTypes
from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy
from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded
from airflow.typing_compat import Literal
from airflow.utils.context import context_update_for_unmapped
Expand Down Expand Up @@ -534,12 +535,14 @@ def priority_weight(self, value: int) -> None:
self.partial_kwargs["priority_weight"] = value

@property
def weight_rule(self) -> str: # type: ignore[override]
return self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)
def weight_rule(self) -> PriorityWeightStrategy: # type: ignore[override]
return validate_and_load_priority_weight_strategy(
self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)
)

@weight_rule.setter
def weight_rule(self, value: str) -> None:
self.partial_kwargs["weight_rule"] = value
def weight_rule(self, value: str | PriorityWeightStrategy) -> None:
self.partial_kwargs["weight_rule"] = validate_and_load_priority_weight_strategy(value)

@property
def sla(self) -> datetime.timedelta | None:
Expand Down
12 changes: 10 additions & 2 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,7 +917,11 @@ def _refresh_from_task(
task_instance.queue = task.queue
task_instance.pool = pool_override or task.pool
task_instance.pool_slots = task.pool_slots
task_instance.priority_weight = task.priority_weight_total
with contextlib.suppress(Exception):
# This method is called from the different places, and sometimes the TI is not fully initialized
task_instance.priority_weight = task_instance.task.weight_rule.get_weight(
task_instance # type: ignore[arg-type]
)
task_instance.run_as_user = task.run_as_user
# Do not set max_tries to task.retries here because max_tries is a cumulative
# value that needs to be stored in the db.
Expand Down Expand Up @@ -1421,6 +1425,10 @@ def insert_mapping(run_id: str, task: Operator, map_index: int) -> dict[str, Any
:meta private:
"""
priority_weight = task.weight_rule.get_weight(
TaskInstance(task=task, run_id=run_id, map_index=map_index)
)

return {
"dag_id": task.dag_id,
"task_id": task.task_id,
Expand All @@ -1431,7 +1439,7 @@ def insert_mapping(run_id: str, task: Operator, map_index: int) -> dict[str, Any
"queue": task.queue,
"pool": task.pool,
"pool_slots": task.pool_slots,
"priority_weight": task.priority_weight_total,
"priority_weight": priority_weight,
"run_as_user": task.run_as_user,
"max_tries": task.retries,
"executor_config": task.executor_config,
Expand Down
36 changes: 35 additions & 1 deletion airflow/plugins_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
from typing import TYPE_CHECKING, Any, Iterable

from airflow import settings
from airflow.task.priority_strategy import (
PriorityWeightStrategy,
airflow_priority_weight_strategies,
)
from airflow.utils.entry_points import entry_points_with_dist
from airflow.utils.file import find_path_from_directory
from airflow.utils.module_loading import import_string, qualname
Expand Down Expand Up @@ -68,6 +72,7 @@
registered_operator_link_classes: dict[str, type] | None = None
registered_ti_dep_classes: dict[str, type] | None = None
timetable_classes: dict[str, type[Timetable]] | None = None
priority_weight_strategy_classes: dict[str, type[PriorityWeightStrategy]] | None = None
"""
Mapping of class names to class of OperatorLinks registered by plugins.
Expand All @@ -89,6 +94,7 @@
"ti_deps",
"timetables",
"listeners",
"priority_weight_strategies",
}


Expand Down Expand Up @@ -169,6 +175,9 @@ class AirflowPlugin:

listeners: list[ModuleType | object] = []

# A list of priority weight strategy classes that can be used for calculating tasks weight priority.
priority_weight_strategies: list[type[PriorityWeightStrategy]] = []

@classmethod
def validate(cls):
"""Validate if plugin has a name."""
Expand Down Expand Up @@ -556,7 +565,7 @@ def get_plugin_info(attrs_to_dump: Iterable[str] | None = None) -> list[dict[str
for attr in attrs_to_dump:
if attr in ("global_operator_extra_links", "operator_extra_links"):
info[attr] = [f"<{qualname(d.__class__)} object>" for d in getattr(plugin, attr)]
elif attr in ("macros", "timetables", "hooks", "executors"):
elif attr in ("macros", "timetables", "hooks", "executors", "priority_weight_strategies"):
info[attr] = [qualname(d) for d in getattr(plugin, attr)]
elif attr == "listeners":
# listeners may be modules or class instances
Expand All @@ -577,3 +586,28 @@ def get_plugin_info(attrs_to_dump: Iterable[str] | None = None) -> list[dict[str
info[attr] = getattr(plugin, attr)
plugins_info.append(info)
return plugins_info


def initialize_priority_weight_strategy_plugins():
"""Collect priority weight strategy classes registered by plugins."""
global priority_weight_strategy_classes

if priority_weight_strategy_classes is not None:
return

ensure_plugins_loaded()

if plugins is None:
raise AirflowPluginException("Can't load plugins.")

log.debug("Initialize extra priority weight strategy plugins")

plugins_priority_weight_strategy_classes = {
qualname(priority_weight_strategy_class): priority_weight_strategy_class
for plugin in plugins
for priority_weight_strategy_class in plugin.priority_weight_strategies
}
priority_weight_strategy_classes = {
**airflow_priority_weight_strategies,
**plugins_priority_weight_strategy_classes,
}
Loading

0 comments on commit 5ab7a40

Please sign in to comment.