diff --git a/providers/standard/src/airflow/providers/standard/operators/hitl.py b/providers/standard/src/airflow/providers/standard/operators/hitl.py index d469e580a09ed..b0d73ff02cdd6 100644 --- a/providers/standard/src/airflow/providers/standard/operators/hitl.py +++ b/providers/standard/src/airflow/providers/standard/operators/hitl.py @@ -30,6 +30,7 @@ from typing import TYPE_CHECKING, Any from airflow.providers.standard.exceptions import HITLTimeoutError, HITLTriggerEventError +from airflow.providers.standard.operators.branch import BranchMixIn from airflow.providers.standard.triggers.hitl import HITLTrigger, HITLTriggerEventSuccessPayload from airflow.providers.standard.utils.skipmixin import SkipMixin from airflow.providers.standard.version_compat import BaseOperator @@ -218,14 +219,15 @@ def get_tasks_to_skip(): return ret -class HITLBranchOperator(HITLOperator): +class HITLBranchOperator(HITLOperator, BranchMixIn): """BranchOperator based on Human-in-the-loop Response.""" - def __init__(self, **kwargs) -> None: - super().__init__(**kwargs) + inherits_from_skipmixin = True - def execute_complete(self, context: Context, event: dict[str, Any]) -> None: - raise NotImplementedError + def execute_complete(self, context: Context, event: dict[str, Any]) -> Any: + ret = super().execute_complete(context=context, event=event) + chosen_options = ret["chosen_options"] + return self.do_branch(context=context, branches_to_execute=chosen_options) class HITLEntryOperator(HITLOperator): diff --git a/providers/standard/tests/unit/standard/operators/test_hitl.py b/providers/standard/tests/unit/standard/operators/test_hitl.py index 70e6e7e8fc8dd..fb83740a79468 100644 --- a/providers/standard/tests/unit/standard/operators/test_hitl.py +++ b/providers/standard/tests/unit/standard/operators/test_hitl.py @@ -21,10 +21,12 @@ from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS if not AIRFLOW_V_3_1_PLUS: - pytest.skip("Human in the loop public API compatible with Airflow >= 3.1.0", allow_module_level=True) + pytest.skip("Human in the loop is only compatible with Airflow >= 3.1.0", allow_module_level=True) +import datetime from typing import TYPE_CHECKING, Any +import pytest from sqlalchemy import select from airflow.exceptions import DownstreamTasksSkipped @@ -33,10 +35,11 @@ from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.operators.hitl import ( ApprovalOperator, + HITLBranchOperator, HITLEntryOperator, HITLOperator, ) -from airflow.sdk import Param +from airflow.sdk import Param, timezone from airflow.sdk.definitions.param import ParamsDict if TYPE_CHECKING: @@ -46,6 +49,9 @@ pytestmark = pytest.mark.db_test +DEFAULT_DATE = timezone.datetime(2016, 1, 1) +INTERVAL = datetime.timedelta(hours=12) + class TestHITLOperator: def test_validate_defaults(self) -> None: @@ -291,3 +297,50 @@ def test_init_without_default(self) -> None: assert op.options == ["OK", "NOT OK"] assert op.defaults is None + + +class TestHITLBranchOperator: + def test_execute_complete(self, dag_maker) -> None: + with dag_maker("hitl_test_dag", serialized=True): + branch_op = HITLBranchOperator( + task_id="make_choice", + subject="This is subject", + options=[f"branch_{i}" for i in range(1, 6)], + ) + + branch_op >> [EmptyOperator(task_id=f"branch_{i}") for i in range(1, 6)] + + dr = dag_maker.create_dagrun() + ti = dr.get_task_instance("make_choice") + with pytest.raises(DownstreamTasksSkipped) as exc_info: + branch_op.execute_complete( + context={"ti": ti, "task": ti.task}, + event={ + "chosen_options": ["branch_1"], + "params_input": {}, + }, + ) + assert set(exc_info.value.tasks) == set((f"branch_{i}", -1) for i in range(2, 6)) + + def test_execute_complete_with_multiple_branches(self, dag_maker) -> None: + with dag_maker("hitl_test_dag", serialized=True): + branch_op = HITLBranchOperator( + task_id="make_choice", + subject="This is subject", + multiple=True, + options=[f"branch_{i}" for i in range(1, 6)], + ) + + branch_op >> [EmptyOperator(task_id=f"branch_{i}") for i in range(1, 6)] + + dr = dag_maker.create_dagrun() + ti = dr.get_task_instance("make_choice") + with pytest.raises(DownstreamTasksSkipped) as exc_info: + branch_op.execute_complete( + context={"ti": ti, "task": ti.task}, + event={ + "chosen_options": [f"branch_{i}" for i in range(1, 4)], + "params_input": {}, + }, + ) + assert set(exc_info.value.tasks) == set((f"branch_{i}", -1) for i in range(4, 6))