diff --git a/providers/standard/src/airflow/providers/standard/utils/skipmixin.py b/providers/standard/src/airflow/providers/standard/utils/skipmixin.py index 432b287f3a190..f00bbfcb19d24 100644 --- a/providers/standard/src/airflow/providers/standard/utils/skipmixin.py +++ b/providers/standard/src/airflow/providers/standard/utils/skipmixin.py @@ -129,14 +129,17 @@ def skip_all_except( if isinstance(branch_task_ids, str): branch_task_id_set = {branch_task_ids} elif isinstance(branch_task_ids, Iterable): + # Handle the case where invalid values are passed as elements of an Iterable + # Non-string values are considered invalid elements branch_task_id_set = set(branch_task_ids) invalid_task_ids_type = { (bti, type(bti).__name__) for bti in branch_task_id_set if not isinstance(bti, str) } if invalid_task_ids_type: raise AirflowException( - f"'branch_task_ids' expected all task IDs are strings. " - f"Invalid tasks found: {invalid_task_ids_type}." + f"Unable to branch to the specified tasks. " + f"The branching function returned invalid 'branch_task_ids': {invalid_task_ids_type}. " + f"Please check that your function returns an Iterable of valid task IDs that exist in your DAG." ) elif branch_task_ids is None: branch_task_id_set = set() diff --git a/providers/standard/tests/unit/standard/utils/test_skipmixin.py b/providers/standard/tests/unit/standard/utils/test_skipmixin.py index db805f5f2a598..ac66c89d7e255 100644 --- a/providers/standard/tests/unit/standard/utils/test_skipmixin.py +++ b/providers/standard/tests/unit/standard/utils/test_skipmixin.py @@ -327,10 +327,21 @@ def test_raise_exception_on_not_accepted_iterable_branch_task_ids_type(self, dag ti1 = TI(task, run_id=DEFAULT_DAG_RUN_ID, dag_version_id=dag_version.id) else: ti1 = TI(task, run_id=DEFAULT_DAG_RUN_ID) - error_message = ( - r"'branch_task_ids' expected all task IDs are strings. " - r"Invalid tasks found: \{\(42, 'int'\)\}\." - ) + + if AIRFLOW_V_3_0_PLUS: + # Improved error message for Airflow 3.0+ + error_message = ( + r"Unable to branch to the specified tasks\. " + r"The branching function returned invalid 'branch_task_ids': \{\(42, 'int'\)\}\. " + r"Please check that your function returns an Iterable of valid task IDs that exist in your DAG\." + ) + else: + # Old error message for Airflow 2.x + error_message = ( + r"'branch_task_ids' expected all task IDs are strings\. " + r"Invalid tasks found: \{\(42, 'int'\)\}\." + ) + with pytest.raises(AirflowException, match=error_message): SkipMixin().skip_all_except(ti=ti1, branch_task_ids=["task", 42])