Skip to content

Commit

Permalink
Refactor TriggerRule & WeightRule classes to inherit from Enum (#21264)
Browse files Browse the repository at this point in the history
closes: #19905
related: #5302,#18627

Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>

Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>
  • Loading branch information
chenglongyan and uranusjr authored Feb 20, 2022
1 parent d8ae7df commit 9ad4de8
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 25 deletions.
21 changes: 8 additions & 13 deletions airflow/utils/trigger_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from enum import Enum
from typing import Set


class TriggerRule:
class TriggerRule(str, Enum):
"""Class with task's trigger rules."""

ALL_SUCCESS = 'all_success'
Expand All @@ -34,20 +34,15 @@ class TriggerRule:
ALWAYS = 'always'
NONE_FAILED_MIN_ONE_SUCCESS = "none_failed_min_one_success"

_ALL_TRIGGER_RULES: Set[str] = set()

@classmethod
def is_valid(cls, trigger_rule):
def is_valid(cls, trigger_rule: str) -> bool:
"""Validates a trigger rule."""
return trigger_rule in cls.all_triggers()

@classmethod
def all_triggers(cls):
def all_triggers(cls) -> Set[str]:
"""Returns all trigger rules."""
if not cls._ALL_TRIGGER_RULES:
cls._ALL_TRIGGER_RULES = {
getattr(cls, attr)
for attr in dir(cls)
if not attr.startswith("_") and not callable(getattr(cls, attr))
}
return cls._ALL_TRIGGER_RULES
return set(cls.__members__.values())

def __str__(self) -> str:
return self.value
22 changes: 10 additions & 12 deletions airflow/utils/weight_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,29 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from enum import Enum
from typing import Set

from airflow.compat.functools import cache


class WeightRule:
class WeightRule(str, Enum):
"""Weight rules."""

DOWNSTREAM = 'downstream'
UPSTREAM = 'upstream'
ABSOLUTE = 'absolute'

_ALL_WEIGHT_RULES: Set[str] = set()

@classmethod
def is_valid(cls, weight_rule):
def is_valid(cls, weight_rule: str) -> bool:
"""Check if weight rule is valid."""
return weight_rule in cls.all_weight_rules()

@classmethod
@cache
def all_weight_rules(cls) -> Set[str]:
"""Returns all weight rules"""
if not cls._ALL_WEIGHT_RULES:
cls._ALL_WEIGHT_RULES = {
getattr(cls, attr)
for attr in dir(cls)
if not attr.startswith("_") and not callable(getattr(cls, attr))
}
return cls._ALL_WEIGHT_RULES
return set(cls.__members__.values())

def __str__(self) -> str:
return self.value
5 changes: 5 additions & 0 deletions tests/utils/test_trigger_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import unittest

import pytest

from airflow.utils.trigger_rule import TriggerRule


Expand All @@ -35,3 +37,6 @@ def test_valid_trigger_rules(self):
assert TriggerRule.is_valid(TriggerRule.ALWAYS)
assert TriggerRule.is_valid(TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS)
assert len(TriggerRule.all_triggers()) == 11

with pytest.raises(ValueError):
TriggerRule("NOT_EXIST_TRIGGER_RULE")
5 changes: 5 additions & 0 deletions tests/utils/test_weight_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import unittest

import pytest

from airflow.utils.weight_rule import WeightRule


Expand All @@ -27,3 +29,6 @@ def test_valid_weight_rules(self):
assert WeightRule.is_valid(WeightRule.UPSTREAM)
assert WeightRule.is_valid(WeightRule.ABSOLUTE)
assert len(WeightRule.all_weight_rules()) == 3

with pytest.raises(ValueError):
WeightRule("NOT_EXIST_WEIGHT_RULE")

0 comments on commit 9ad4de8

Please sign in to comment.