Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -350,9 +350,57 @@ class HITLBranchOperator(HITLOperator, BranchMixIn):

inherits_from_skipmixin = True

def __init__(self, *, options_mapping: dict[str, str] | None = None, **kwargs) -> None:
    """
    Create the branch operator and validate its optional label-to-value mapping.

    Args:
        options_mapping:
            Optional dictionary whose keys are option labels (each must appear in
            `self.options`) and whose values are strings (e.g., task IDs). When it
            is omitted or empty, an empty dict is stored instead.

    Raises:
        ValueError:
            - If `options_mapping` has a key that is not present in `self.options`.
            - If `options_mapping` has a value that is not a string.
    """
    super().__init__(**kwargs)
    # Any falsy argument (None or an empty mapping) is replaced by a new empty dict.
    self.options_mapping = options_mapping if options_mapping else {}
    # Validation runs after the parent initializer because it reads `self.options`.
    self.validate_options_mapping()

def validate_options_mapping(self) -> None:
    """
    Validate that `options_mapping` keys are in `self.options` and all values are strings.

    An empty mapping is always considered valid and skips every check.

    Raises:
        ValueError: If any key is not in `self.options` or any value is not a string.
    """
    if not self.options_mapping:
        return

    # Every key of the mapping must be one of the selectable options.
    invalid_keys = set(self.options_mapping) - set(self.options)
    if invalid_keys:
        # Sort on the string form: a plain `sorted()` raises TypeError when the
        # offending keys have mixed, non-comparable types (e.g. an int and a str),
        # which would mask the ValueError this method is documented to raise.
        # For all-string keys the resulting order is unchanged.
        raise ValueError(
            f"`options_mapping` contains keys that are not in `options`: {sorted(invalid_keys, key=str)}"
        )

    # Every value must be a string; the error message reports each offending
    # entry together with the name of its actual type.
    invalid_entries = {
        k: (v, type(v).__name__) for k, v in self.options_mapping.items() if not isinstance(v, str)
    }
    if invalid_entries:
        raise ValueError(
            f"`options_mapping` values must be strings (task_ids).\nInvalid entries: {invalid_entries}"
        )

def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
    """Complete the parent operator, translate the chosen options and branch on them."""
    result = super().execute_complete(context=context, event=event)
    selected = result["chosen_options"]

    # Each chosen option is replaced by its mapped value; an option without a
    # mapping entry is passed through unchanged.
    branches = [self.options_mapping.get(label, label) for label in selected]
    return self.do_branch(context=context, branches_to_execute=branches)


Expand Down
110 changes: 109 additions & 1 deletion providers/standard/tests/unit/standard/operators/test_hitl.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import pytest
from sqlalchemy import select

from airflow.exceptions import DownstreamTasksSkipped
from airflow.exceptions import AirflowException, DownstreamTasksSkipped
from airflow.models import TaskInstance, Trigger
from airflow.models.hitl import HITLDetail
from airflow.providers.standard.operators.empty import EmptyOperator
Expand Down Expand Up @@ -505,3 +505,111 @@ def test_execute_complete_with_multiple_branches(self, dag_maker) -> None:
},
)
assert set(exc_info.value.tasks) == set((f"branch_{i}", -1) for i in range(4, 6))

def test_mapping_applies_for_single_choice(self, dag_maker):
    """A single chosen label is translated through `options_mapping` before branching."""
    # "Approve" is mapped to the "publish" task.
    with dag_maker("hitl_map_dag", serialized=True):
        op = HITLBranchOperator(
            task_id="choose",
            subject="S",
            options=["Approve", "Reject"],
            options_mapping={"Approve": "publish"},
        )
        downstream = [EmptyOperator(task_id="publish"), EmptyOperator(task_id="archive")]
        op >> downstream

    dag_run = dag_maker.create_dagrun()
    task_instance = dag_run.get_task_instance("choose")

    with pytest.raises(DownstreamTasksSkipped) as exc:
        op.execute_complete(
            context={"ti": task_instance, "task": task_instance.task},
            event={"chosen_options": ["Approve"], "params_input": {}},
        )

    # Only the branch that was not selected ("archive") is reported as skipped.
    skipped = set(exc.value.tasks)
    assert skipped == {("archive", -1)}

def test_mapping_with_multiple_choices(self, dag_maker):
    """With `multiple=True`, every chosen label is mapped individually."""
    mapping = {"Approve": "publish", "KeepAsIs": "keep"}
    with dag_maker("hitl_map_dag", serialized=True):
        op = HITLBranchOperator(
            task_id="choose",
            subject="S",
            multiple=True,
            options=["Approve", "KeepAsIs"],
            options_mapping=mapping,
        )
        downstream = [
            EmptyOperator(task_id="publish"),
            EmptyOperator(task_id="keep"),
            EmptyOperator(task_id="other"),
        ]
        op >> downstream

    dag_run = dag_maker.create_dagrun()
    task_instance = dag_run.get_task_instance("choose")

    with pytest.raises(DownstreamTasksSkipped) as exc:
        op.execute_complete(
            context={"ti": task_instance, "task": task_instance.task},
            event={"chosen_options": ["Approve", "KeepAsIs"], "params_input": {}},
        )

    # Both "publish" and "keep" were chosen, so "other" is the only skipped task.
    skipped = set(exc.value.tasks)
    assert skipped == {("other", -1)}

def test_fallback_to_option_when_not_mapped(self, dag_maker):
    """Without `options_mapping`, the chosen option itself is used as the branch task_id."""
    # The operator is created without any mapping, so each option label has to
    # be the task_id of a downstream task.
    with dag_maker("hitl_map_dag", serialized=True):
        op = HITLBranchOperator(
            task_id="choose",
            subject="S",
            options=["branch_1", "branch_2"],
        )
        downstream = [EmptyOperator(task_id="branch_1"), EmptyOperator(task_id="branch_2")]
        op >> downstream

    dag_run = dag_maker.create_dagrun()
    task_instance = dag_run.get_task_instance("choose")

    with pytest.raises(DownstreamTasksSkipped) as exc:
        op.execute_complete(
            context={"ti": task_instance, "task": task_instance.task},
            event={"chosen_options": ["branch_2"], "params_input": {}},
        )

    # "branch_2" was chosen, which leaves "branch_1" as the skipped task.
    skipped = set(exc.value.tasks)
    assert skipped == {("branch_1", -1)}

def test_error_if_mapped_branch_not_direct_downstream(self, dag_maker):
    """Mapping an option to a task that is not downstream raises an AirflowException."""
    with dag_maker("hitl_map_dag", serialized=True):
        # The mapped target "not_a_downstream" is deliberately never added to the DAG.
        op = HITLBranchOperator(
            task_id="choose",
            subject="S",
            options=["Approve"],
            options_mapping={"Approve": "not_a_downstream"},
        )

    dag_run = dag_maker.create_dagrun()
    task_instance = dag_run.get_task_instance("choose")

    expected_message = "downstream|not found"
    with pytest.raises(AirflowException, match=expected_message):
        op.execute_complete(
            context={"ti": task_instance, "task": task_instance.task},
            event={"chosen_options": ["Approve"], "params_input": {}},
        )

@pytest.mark.parametrize("bad", [123, ["publish"], {"x": "y"}, b"publish"])
def test_options_mapping_non_string_value_raises(self, bad):
    """Constructing the operator with a non-string mapping value raises ValueError."""
    expected_message = r"values must be strings \(task_ids\)"
    mapping = {"Approve": bad}

    with pytest.raises(ValueError, match=expected_message):
        HITLBranchOperator(
            task_id="choose",
            subject="S",
            options=["Approve"],
            options_mapping=mapping,
        )

def test_options_mapping_key_not_in_options_raises(self):
    """Constructing the operator with a mapping key missing from `options` raises ValueError."""
    expected_message = "contains keys that are not in `options`"
    # "NotAnOption" is not one of the declared options below.
    mapping = {"NotAnOption": "publish"}

    with pytest.raises(ValueError, match=expected_message):
        HITLBranchOperator(
            task_id="choose",
            subject="S",
            options=["Approve", "Reject"],
            options_mapping=mapping,
        )