Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from typing import TYPE_CHECKING, Any

from airflow.providers.standard.exceptions import HITLTimeoutError, HITLTriggerEventError
from airflow.providers.standard.operators.branch import BranchMixIn
from airflow.providers.standard.triggers.hitl import HITLTrigger, HITLTriggerEventSuccessPayload
from airflow.providers.standard.utils.skipmixin import SkipMixin
from airflow.providers.standard.version_compat import BaseOperator
Expand Down Expand Up @@ -218,14 +219,15 @@ def get_tasks_to_skip():
return ret


class HITLBranchOperator(HITLOperator):
class HITLBranchOperator(HITLOperator, BranchMixIn):
"""BranchOperator based on Human-in-the-loop Response."""

def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
inherits_from_skipmixin = True

def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
raise NotImplementedError
def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
ret = super().execute_complete(context=context, event=event)
chosen_options = ret["chosen_options"]
return self.do_branch(context=context, branches_to_execute=chosen_options)


class HITLEntryOperator(HITLOperator):
Expand Down
57 changes: 55 additions & 2 deletions providers/standard/tests/unit/standard/operators/test_hitl.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@
from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS

if not AIRFLOW_V_3_1_PLUS:
pytest.skip("Human in the loop public API compatible with Airflow >= 3.1.0", allow_module_level=True)
pytest.skip("Human in the loop is only compatible with Airflow >= 3.1.0", allow_module_level=True)

import datetime
from typing import TYPE_CHECKING, Any

import pytest
from sqlalchemy import select

from airflow.exceptions import DownstreamTasksSkipped
Expand All @@ -33,10 +35,11 @@
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.providers.standard.operators.hitl import (
ApprovalOperator,
HITLBranchOperator,
HITLEntryOperator,
HITLOperator,
)
from airflow.sdk import Param
from airflow.sdk import Param, timezone
from airflow.sdk.definitions.param import ParamsDict

if TYPE_CHECKING:
Expand All @@ -46,6 +49,9 @@

pytestmark = pytest.mark.db_test

DEFAULT_DATE = timezone.datetime(2016, 1, 1)
INTERVAL = datetime.timedelta(hours=12)


class TestHITLOperator:
def test_validate_defaults(self) -> None:
Expand Down Expand Up @@ -291,3 +297,50 @@ def test_init_without_default(self) -> None:

assert op.options == ["OK", "NOT OK"]
assert op.defaults is None


class TestHITLBranchOperator:
def test_execute_complete(self, dag_maker) -> None:
with dag_maker("hitl_test_dag", serialized=True):
branch_op = HITLBranchOperator(
task_id="make_choice",
subject="This is subject",
options=[f"branch_{i}" for i in range(1, 6)],
)

branch_op >> [EmptyOperator(task_id=f"branch_{i}") for i in range(1, 6)]

dr = dag_maker.create_dagrun()
ti = dr.get_task_instance("make_choice")
with pytest.raises(DownstreamTasksSkipped) as exc_info:
branch_op.execute_complete(
context={"ti": ti, "task": ti.task},
event={
"chosen_options": ["branch_1"],
"params_input": {},
},
)
assert set(exc_info.value.tasks) == set((f"branch_{i}", -1) for i in range(2, 6))

def test_execute_complete_with_multiple_branches(self, dag_maker) -> None:
with dag_maker("hitl_test_dag", serialized=True):
branch_op = HITLBranchOperator(
task_id="make_choice",
subject="This is subject",
multiple=True,
options=[f"branch_{i}" for i in range(1, 6)],
)

branch_op >> [EmptyOperator(task_id=f"branch_{i}") for i in range(1, 6)]

dr = dag_maker.create_dagrun()
ti = dr.get_task_instance("make_choice")
with pytest.raises(DownstreamTasksSkipped) as exc_info:
branch_op.execute_complete(
context={"ti": ti, "task": ti.task},
event={
"chosen_options": [f"branch_{i}" for i in range(1, 4)],
"params_input": {},
},
)
assert set(exc_info.value.tasks) == set((f"branch_{i}", -1) for i in range(4, 6))