Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion airflow-core/src/airflow/jobs/triggerer_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
CommsDecoder,
ConnectionResult,
DagRunStateResult,
DeleteVariable,
DeleteXCom,
DRCount,
ErrorResponse,
GetConnection,
Expand All @@ -59,6 +61,9 @@
GetTICount,
GetVariable,
GetXCom,
OKResponse,
PutVariable,
SetXCom,
TaskStatesResult,
TICount,
UpdateHITLDetail,
Expand Down Expand Up @@ -240,7 +245,8 @@ def from_api_response(cls, response: HITLDetailResponse) -> HITLDetailResponseRe
| TICount
| TaskStatesResult
| HITLDetailResponseResult
| ErrorResponse,
| ErrorResponse
| OKResponse,
Field(discriminator="type"),
]
"""
Expand All @@ -252,8 +258,12 @@ def from_api_response(cls, response: HITLDetailResponse) -> HITLDetailResponseRe
ToTriggerSupervisor = Annotated[
messages.TriggerStateChanges
| GetConnection
| DeleteVariable
| GetVariable
| PutVariable
| DeleteXCom
| GetXCom
| SetXCom
| GetTICount
| GetTaskStates
| GetDagRunState
Expand Down Expand Up @@ -419,6 +429,8 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r
dump_opts = {"exclude_unset": True, "by_alias": True}
else:
resp = conn
elif isinstance(msg, DeleteVariable):
resp = self.client.variables.delete(msg.key)
elif isinstance(msg, GetVariable):
var = self.client.variables.get(msg.key)
if isinstance(var, VariableResponse):
Expand All @@ -427,6 +439,10 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r
dump_opts = {"exclude_unset": True}
else:
resp = var
elif isinstance(msg, PutVariable):
self.client.variables.set(msg.key, msg.value, msg.description)
elif isinstance(msg, DeleteXCom):
self.client.xcoms.delete(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index)
elif isinstance(msg, GetXCom):
xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index)
if isinstance(xcom, XComResponse):
Expand All @@ -435,6 +451,10 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r
dump_opts = {"exclude_unset": True}
else:
resp = xcom
elif isinstance(msg, SetXCom):
self.client.xcoms.set(
msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.value, msg.map_index, msg.mapped_length
)
elif isinstance(msg, GetDRCount):
dr_count = self.client.dag_runs.get_count(
dag_id=msg.dag_id,
Expand Down
100 changes: 84 additions & 16 deletions airflow-core/tests/unit/jobs/test_triggerer_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,19 +648,64 @@ async def run(self, **args) -> AsyncIterator[TriggerEvent]:
conn = await sync_to_async(BaseHook.get_connection)("test_connection")
self.log.info("Loaded conn %s", conn.conn_id)

variable = await sync_to_async(Variable.get)("test_variable")
self.log.info("Loaded variable %s", variable)
get_variable_value = await sync_to_async(Variable.get)("test_get_variable")
self.log.info("Loaded variable %s", get_variable_value)

xcom = await sync_to_async(XCom.get_one)(
key="test_xcom",
get_xcom_value = await sync_to_async(XCom.get_one)(
key="test_get_xcom",
dag_id=self.dag_id,
run_id=self.run_id,
task_id=self.task_id,
map_index=self.map_index,
)
self.log.info("Loaded XCom %s", xcom)
self.log.info("Loaded XCom %s", get_xcom_value)

yield TriggerEvent({"connection": attrs.asdict(conn), "variable": variable, "xcom": xcom})
set_variable_key = "test_set_variable"
set_variable_value = "set_value"
await sync_to_async(Variable.set)(key=set_variable_key, value=set_variable_value)
self.log.info("Set variable with key %s and value %s", set_variable_key, set_variable_value)

set_xcom_key = "test_set_xcom"
set_xcom_value = "set_xcom"
await sync_to_async(XCom.set)(
key=set_xcom_key,
dag_id=self.dag_id,
run_id=self.run_id,
task_id=self.task_id,
map_index=self.map_index,
value=set_xcom_value,
)
self.log.info("Set xcom with key %s and value %s", set_xcom_key, set_xcom_value)

delete_variable_key = "test_delete_variable"
await sync_to_async(Variable.delete)(delete_variable_key)
self.log.info("Deleted variable with key %s", delete_variable_key)

delete_xcom_key = "test_delete_xcom"
await sync_to_async(XCom.delete)(
key=delete_xcom_key,
dag_id=self.dag_id,
run_id=self.run_id,
task_id=self.task_id,
map_index=self.map_index,
)
self.log.info("Delete xcom with key %s", delete_xcom_key)

yield TriggerEvent(
{
"connection": attrs.asdict(conn),
"variable": {
"get_variable": get_variable_value,
"set_variable": set_variable_value,
"delete_variable": delete_variable_key,
},
"xcom": {
"get_xcom": get_xcom_value,
"set_xcom": set_xcom_value,
"delete_xcom": delete_xcom_key,
},
}
)

def serialize(self) -> tuple[str, dict[str, Any]]:
return (
Expand All @@ -687,8 +732,8 @@ def handle_events(self):

@pytest.mark.asyncio
@pytest.mark.execution_timeout(20)
async def test_trigger_can_access_variables_connections_and_xcoms(session, dag_maker):
"""Checks that the trigger will successfully access Variables, Connections and XComs."""
async def test_trigger_can_call_variables_connections_and_xcoms_methods(session, dag_maker):
"""Checks that the trigger will successfully call Variables, Connections and XComs methods."""
# Create the test DAG and task
with dag_maker(dag_id="trigger_accessing_variable_connection_and_xcom", session=session):
EmptyOperator(task_id="dummy1")
Expand All @@ -704,7 +749,7 @@ async def test_trigger_can_access_variables_connections_and_xcoms(session, dag_m
kwargs={"dag_id": dr.dag_id, "run_id": dr.run_id, "task_id": task_instance.task_id, "map_index": -1},
)
session.add(trigger_orm)
session.commit()
session.flush()
task_instance.trigger_id = trigger_orm.id

# Create the appropriate Connection, Variable and XCom
Expand All @@ -718,18 +763,32 @@ async def test_trigger_can_access_variables_connections_and_xcoms(session, dag_m
port=443,
host="example.com",
)
variable = Variable(key="test_variable", val="some_variable_value")
get_variable = Variable(key="test_get_variable", val="some_variable_value")
delete_variable = Variable(key="test_delete_variable", val="delete_value")

session.add(connection)
session.add(get_variable)
session.add(delete_variable)

XComModel.set(
key="test_xcom",
key="test_get_xcom",
value="some_xcom_value",
task_id=task_instance.task_id,
dag_id=dr.dag_id,
run_id=dr.run_id,
map_index=-1,
session=session,
)

XComModel.set(
key="test_delete_xcom",
value="some_xcom_value",
task_id=task_instance.task_id,
dag_id=dr.dag_id,
run_id=dr.run_id,
map_index=-1,
session=session,
)
session.add(connection)
session.add(variable)

job = Job()
session.add(job)
Expand All @@ -741,7 +800,7 @@ async def test_trigger_can_access_variables_connections_and_xcoms(session, dag_m
task_instance.refresh_from_db()
assert task_instance.state == TaskInstanceState.SCHEDULED
assert task_instance.next_method != "__fail__"
assert task_instance.next_kwargs == {
expected_event = {
"event": {
"connection": {
"conn_id": "test_connection",
Expand All @@ -754,10 +813,19 @@ async def test_trigger_can_access_variables_connections_and_xcoms(session, dag_m
"port": 443,
"extra": '{"key": "value"}',
},
"variable": "some_variable_value",
"xcom": '"some_xcom_value"',
"variable": {
"get_variable": "some_variable_value",
"set_variable": "set_value",
"delete_variable": "test_delete_variable",
},
"xcom": {
"get_xcom": '"some_xcom_value"',
"set_xcom": "set_xcom",
"delete_xcom": "test_delete_xcom",
},
}
}
assert task_instance.next_kwargs == expected_event


class CustomTriggerDagRun(BaseTrigger):
Expand Down
Loading