diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 6a96ea161cb99..5e6d6d613b47f 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -49,6 +49,8 @@ CommsDecoder, ConnectionResult, DagRunStateResult, + DeleteVariable, + DeleteXCom, DRCount, ErrorResponse, GetConnection, @@ -59,6 +61,9 @@ GetTICount, GetVariable, GetXCom, + OKResponse, + PutVariable, + SetXCom, TaskStatesResult, TICount, UpdateHITLDetail, @@ -240,7 +245,8 @@ def from_api_response(cls, response: HITLDetailResponse) -> HITLDetailResponseRe | TICount | TaskStatesResult | HITLDetailResponseResult - | ErrorResponse, + | ErrorResponse + | OKResponse, Field(discriminator="type"), ] """ @@ -252,8 +258,12 @@ def from_api_response(cls, response: HITLDetailResponse) -> HITLDetailResponseRe ToTriggerSupervisor = Annotated[ messages.TriggerStateChanges | GetConnection + | DeleteVariable | GetVariable + | PutVariable + | DeleteXCom | GetXCom + | SetXCom | GetTICount | GetTaskStates | GetDagRunState @@ -419,6 +429,8 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r dump_opts = {"exclude_unset": True, "by_alias": True} else: resp = conn + elif isinstance(msg, DeleteVariable): + resp = self.client.variables.delete(msg.key) elif isinstance(msg, GetVariable): var = self.client.variables.get(msg.key) if isinstance(var, VariableResponse): @@ -427,6 +439,10 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r dump_opts = {"exclude_unset": True} else: resp = var + elif isinstance(msg, PutVariable): + self.client.variables.set(msg.key, msg.value, msg.description) + elif isinstance(msg, DeleteXCom): + self.client.xcoms.delete(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index) elif isinstance(msg, GetXCom): xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index) if isinstance(xcom, XComResponse): @@ -435,6 +451,10 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r dump_opts = {"exclude_unset": True} else: resp = xcom + elif isinstance(msg, SetXCom): + self.client.xcoms.set( + msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.value, msg.map_index, msg.mapped_length + ) elif isinstance(msg, GetDRCount): dr_count = self.client.dag_runs.get_count( dag_id=msg.dag_id, diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index 384bf6a76fc60..ef065cf2c8d16 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -648,19 +648,64 @@ async def run(self, **args) -> AsyncIterator[TriggerEvent]: conn = await sync_to_async(BaseHook.get_connection)("test_connection") self.log.info("Loaded conn %s", conn.conn_id) - variable = await sync_to_async(Variable.get)("test_variable") - self.log.info("Loaded variable %s", variable) + get_variable_value = await sync_to_async(Variable.get)("test_get_variable") + self.log.info("Loaded variable %s", get_variable_value) - xcom = await sync_to_async(XCom.get_one)( - key="test_xcom", + get_xcom_value = await sync_to_async(XCom.get_one)( + key="test_get_xcom", dag_id=self.dag_id, run_id=self.run_id, task_id=self.task_id, map_index=self.map_index, ) - self.log.info("Loaded XCom %s", xcom) + self.log.info("Loaded XCom %s", get_xcom_value) - yield TriggerEvent({"connection": attrs.asdict(conn), "variable": variable, "xcom": xcom}) + set_variable_key = "test_set_variable" + set_variable_value = "set_value" + await sync_to_async(Variable.set)(key=set_variable_key, value=set_variable_value) + self.log.info("Set variable with key %s and value %s", set_variable_key, set_variable_value) + + set_xcom_key = "test_set_xcom" + set_xcom_value = "set_xcom" + await sync_to_async(XCom.set)( + key=set_xcom_key, + dag_id=self.dag_id, + run_id=self.run_id, + task_id=self.task_id, + map_index=self.map_index, + value=set_xcom_value, + ) + self.log.info("Set xcom with key %s and value %s", set_xcom_key, set_xcom_value) + + delete_variable_key = "test_delete_variable" + await sync_to_async(Variable.delete)(delete_variable_key) + self.log.info("Deleted variable with key %s", delete_variable_key) + + delete_xcom_key = "test_delete_xcom" + await sync_to_async(XCom.delete)( + key=delete_xcom_key, + dag_id=self.dag_id, + run_id=self.run_id, + task_id=self.task_id, + map_index=self.map_index, + ) + self.log.info("Delete xcom with key %s", delete_xcom_key) + + yield TriggerEvent( + { + "connection": attrs.asdict(conn), + "variable": { + "get_variable": get_variable_value, + "set_variable": set_variable_value, + "delete_variable": delete_variable_key, + }, + "xcom": { + "get_xcom": get_xcom_value, + "set_xcom": set_xcom_value, + "delete_xcom": delete_xcom_key, + }, + } + ) def serialize(self) -> tuple[str, dict[str, Any]]: return ( @@ -687,8 +732,8 @@ def handle_events(self): @pytest.mark.asyncio @pytest.mark.execution_timeout(20) -async def test_trigger_can_access_variables_connections_and_xcoms(session, dag_maker): - """Checks that the trigger will successfully access Variables, Connections and XComs.""" +async def test_trigger_can_call_variables_connections_and_xcoms_methods(session, dag_maker): + """Checks that the trigger will successfully call Variables, Connections and XComs methods.""" # Create the test DAG and task with dag_maker(dag_id="trigger_accessing_variable_connection_and_xcom", session=session): EmptyOperator(task_id="dummy1") @@ -704,7 +749,7 @@ async def test_trigger_can_access_variables_connections_and_xcoms(session, dag_m kwargs={"dag_id": dr.dag_id, "run_id": dr.run_id, "task_id": task_instance.task_id, "map_index": -1}, ) session.add(trigger_orm) - session.commit() + session.flush() task_instance.trigger_id = trigger_orm.id # Create the appropriate Connection, Variable and XCom @@ -718,9 +763,25 @@ async def test_trigger_can_access_variables_connections_and_xcoms(session, dag_m port=443, host="example.com", ) - variable = Variable(key="test_variable", val="some_variable_value") + get_variable = Variable(key="test_get_variable", val="some_variable_value") + delete_variable = Variable(key="test_delete_variable", val="delete_value") + + session.add(connection) + session.add(get_variable) + session.add(delete_variable) + XComModel.set( - key="test_xcom", + key="test_get_xcom", + value="some_xcom_value", + task_id=task_instance.task_id, + dag_id=dr.dag_id, + run_id=dr.run_id, + map_index=-1, + session=session, + ) + + XComModel.set( + key="test_delete_xcom", value="some_xcom_value", task_id=task_instance.task_id, dag_id=dr.dag_id, @@ -728,8 +789,6 @@ async def test_trigger_can_access_variables_connections_and_xcoms(session, dag_m map_index=-1, session=session, ) - session.add(connection) - session.add(variable) job = Job() session.add(job) @@ -741,7 +800,7 @@ async def test_trigger_can_access_variables_connections_and_xcoms(session, dag_m task_instance.refresh_from_db() assert task_instance.state == TaskInstanceState.SCHEDULED assert task_instance.next_method != "__fail__" - assert task_instance.next_kwargs == { + expected_event = { "event": { "connection": { "conn_id": "test_connection", @@ -754,10 +813,19 @@ async def test_trigger_can_access_variables_connections_and_xcoms(session, dag_m "port": 443, "extra": '{"key": "value"}', }, - "variable": "some_variable_value", - "xcom": '"some_xcom_value"', + "variable": { + "get_variable": "some_variable_value", + "set_variable": "set_value", + "delete_variable": "test_delete_variable", + }, + "xcom": { + "get_xcom": '"some_xcom_value"', + "set_xcom": "set_xcom", + "delete_xcom": "test_delete_xcom", + }, } } + assert task_instance.next_kwargs == expected_event class CustomTriggerDagRun(BaseTrigger):