diff --git a/airflow/example_dags/example_xcomargs.py b/airflow/example_dags/example_xcomargs.py index 42faadd0e5585..24df48ed8f9a6 100644 --- a/airflow/example_dags/example_xcomargs.py +++ b/airflow/example_dags/example_xcomargs.py @@ -20,6 +20,7 @@ import logging from airflow import DAG +from airflow.operators.bash import BashOperator from airflow.operators.python import PythonOperator, get_current_context, task from airflow.utils.dates import days_ago @@ -43,7 +44,7 @@ def print_value(value): default_args={'owner': 'airflow'}, start_date=days_ago(2), schedule_interval=None, - tags=['example'] + tags=['example'], ) as dag: task1 = PythonOperator( task_id='generate_value', @@ -51,3 +52,18 @@ def print_value(value): ) print_value(task1.output) + + +with DAG( + "example_xcom_args_with_operators", + default_args={'owner': 'airflow'}, + start_date=days_ago(2), + schedule_interval=None, + tags=['example'], +) as dag2: + bash_op1 = BashOperator(task_id="c", bash_command="echo c") + bash_op2 = BashOperator(task_id="d", bash_command="echo c") + xcom_args_a = print_value("first!") # type: ignore + xcom_args_b = print_value("second!") # type: ignore + + bash_op1 >> xcom_args_a >> xcom_args_b >> bash_op2 diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py index 5036899bbcbfc..b85d638fc14c3 100644 --- a/airflow/models/xcom_arg.py +++ b/airflow/models/xcom_arg.py @@ -70,13 +70,29 @@ def __lshift__(self, other): Implements XComArg << op """ self.set_upstream(other) - return self + return other def __rshift__(self, other): """ Implements XComArg >> op """ self.set_downstream(other) + return other + + def __rrshift__(self, other): + """ + Called for XComArg >> [XComArg] because list don't have + __rshift__ operators. + """ + self.__lshift__(other) + return self + + def __rlshift__(self, other): + """ + Called for XComArg >> [XComArg] because list don't have + __lshift__ operators. + """ + self.__rshift__(other) return self def __getitem__(self, item): diff --git a/tests/models/test_xcom_arg.py b/tests/models/test_xcom_arg.py index f1a261fe61965..2f38ed72bdafc 100644 --- a/tests/models/test_xcom_arg.py +++ b/tests/models/test_xcom_arg.py @@ -78,29 +78,31 @@ def test_set_downstream(self): with DAG("test_set_downstream", default_args=DEFAULT_ARGS): op_a = BashOperator(task_id="a", bash_command="echo a") op_b = BashOperator(task_id="b", bash_command="echo b") - bash_op = BashOperator(task_id="c", bash_command="echo c") + bash_op1 = BashOperator(task_id="c", bash_command="echo c") + bash_op2 = BashOperator(task_id="d", bash_command="echo c") xcom_args_a = XComArg(op_a) xcom_args_b = XComArg(op_b) - xcom_args_a >> xcom_args_b >> bash_op + bash_op1 >> xcom_args_a >> xcom_args_b >> bash_op2 - assert len(op_a.downstream_list) == 2 + assert op_a in bash_op1.downstream_list assert op_b in op_a.downstream_list - assert bash_op in op_a.downstream_list + assert bash_op2 in op_b.downstream_list def test_set_upstream(self): with DAG("test_set_upstream", default_args=DEFAULT_ARGS): op_a = BashOperator(task_id="a", bash_command="echo a") op_b = BashOperator(task_id="b", bash_command="echo b") - bash_op = BashOperator(task_id="c", bash_command="echo c") + bash_op1 = BashOperator(task_id="c", bash_command="echo c") + bash_op2 = BashOperator(task_id="d", bash_command="echo c") xcom_args_a = XComArg(op_a) xcom_args_b = XComArg(op_b) - xcom_args_a << xcom_args_b << bash_op + bash_op1 << xcom_args_a << xcom_args_b << bash_op2 - assert len(op_a.upstream_list) == 2 + assert op_a in bash_op1.upstream_list assert op_b in op_a.upstream_list - assert bash_op in op_a.upstream_list + assert bash_op2 in op_b.upstream_list def test_xcom_arg_property_of_base_operator(self): with DAG("test_xcom_arg_property_of_base_operator", default_args=DEFAULT_ARGS):