Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,17 @@ def skip_all_except(
if isinstance(branch_task_ids, str):
branch_task_id_set = {branch_task_ids}
elif isinstance(branch_task_ids, Iterable):
# Handle the case where invalid values are passed as elements of an Iterable
# Non-string values are considered invalid elements
branch_task_id_set = set(branch_task_ids)
invalid_task_ids_type = {
(bti, type(bti).__name__) for bti in branch_task_id_set if not isinstance(bti, str)
}
if invalid_task_ids_type:
raise AirflowException(
f"'branch_task_ids' expected all task IDs are strings. "
f"Invalid tasks found: {invalid_task_ids_type}."
f"Unable to branch to the specified tasks. "
f"The branching function returned invalid 'branch_task_ids': {invalid_task_ids_type}. "
f"Please check that your function returns an Iterable of valid task IDs that exist in your DAG."
)
elif branch_task_ids is None:
branch_task_id_set = set()
Expand Down
19 changes: 15 additions & 4 deletions providers/standard/tests/unit/standard/utils/test_skipmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,10 +327,21 @@ def test_raise_exception_on_not_accepted_iterable_branch_task_ids_type(self, dag
ti1 = TI(task, run_id=DEFAULT_DAG_RUN_ID, dag_version_id=dag_version.id)
else:
ti1 = TI(task, run_id=DEFAULT_DAG_RUN_ID)
error_message = (
r"'branch_task_ids' expected all task IDs are strings. "
r"Invalid tasks found: \{\(42, 'int'\)\}\."
)

if AIRFLOW_V_3_0_PLUS:
# Improved error message for Airflow 3.0+
error_message = (
r"Unable to branch to the specified tasks\. "
r"The branching function returned invalid 'branch_task_ids': \{\(42, 'int'\)\}\. "
r"Please check that your function returns an Iterable of valid task IDs that exist in your DAG\."
)
else:
# Old error message for Airflow 2.x
error_message = (
r"'branch_task_ids' expected all task IDs are strings\. "
r"Invalid tasks found: \{\(42, 'int'\)\}\."
)

with pytest.raises(AirflowException, match=error_message):
SkipMixin().skip_all_except(ti=ti1, branch_task_ids=["task", 42])

Expand Down
Loading