Skip to content

Commit d85a301

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Add support for passing request metadata in RemoteA2AAgent
This change updates `RemoteA2AAgent` to extract and forward custom metadata from session events to the `a2a-sdk`'s `send_message` method. The metadata is looked for under the key `A2A_METADATA_PREFIX + "metadata"` within the `custom_metadata` of the relevant session events. The `a2a-sdk` dependency is also updated to a version that supports this feature. This feature was added in v0.3.11 of the a2a-sdk library: https://github.com/a2aproject/a2a-python/releases/tag/v0.3.11 PiperOrigin-RevId: 831120978
1 parent 249216e commit d85a301

File tree

3 files changed

+84
-57
lines changed

3 files changed

+84
-57
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ dev = [
9292

9393
a2a = [
9494
# go/keep-sorted start
95-
"a2a-sdk>=0.3.4,<0.4.0;python_version>='3.10'",
95+
"a2a-sdk>=0.3.11,<0.4.0;python_version>='3.10'",
9696
# go/keep-sorted end
9797
]
9898

src/google/adk/agents/remote_a2a_agent.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ async def _ensure_resolved(self) -> None:
318318

319319
def _create_a2a_request_for_user_function_response(
320320
self, ctx: InvocationContext
321-
) -> Optional[A2AMessage]:
321+
) -> tuple[Optional[A2AMessage], Optional[dict[str, Any]]]:
322322
"""Create A2A request for user function response if applicable.
323323
324324
Args:
@@ -328,34 +328,38 @@ def _create_a2a_request_for_user_function_response(
328328
SendMessageRequest if function response found, None otherwise
329329
"""
330330
if not ctx.session.events or ctx.session.events[-1].author != "user":
331-
return None
331+
return None, None
332332
function_call_event = find_matching_function_call(ctx.session.events)
333333
if not function_call_event:
334-
return None
334+
return None, None
335335

336336
a2a_message = convert_event_to_a2a_message(
337337
ctx.session.events[-1], ctx, Role.user, self._genai_part_converter
338338
)
339+
message_metadata = None
339340
if function_call_event.custom_metadata:
340341
metadata = function_call_event.custom_metadata
341342
a2a_message.task_id = metadata.get(A2A_METADATA_PREFIX + "task_id")
342343
a2a_message.context_id = metadata.get(A2A_METADATA_PREFIX + "context_id")
344+
message_metadata = metadata.get(A2A_METADATA_PREFIX + "metadata")
343345

344-
return a2a_message
346+
return a2a_message, message_metadata
345347

346348
def _construct_message_parts_from_session(
347349
self, ctx: InvocationContext
348-
) -> tuple[list[A2APart], Optional[str]]:
350+
) -> tuple[list[A2APart], Optional[str], Optional[dict[str, Any]]]:
349351
"""Construct A2A message parts from session events.
350352
351353
Args:
352354
ctx: The invocation context
353355
354356
Returns:
355-
List of A2A parts extracted from session events, context ID
357+
List of A2A parts extracted from session events, context ID,
358+
request metadata
356359
"""
357360
message_parts: list[A2APart] = []
358361
context_id = None
362+
request_metadata = None
359363

360364
events_to_process = []
361365
for event in reversed(ctx.session.events):
@@ -365,6 +369,7 @@ def _construct_message_parts_from_session(
365369
if event.custom_metadata:
366370
metadata = event.custom_metadata
367371
context_id = metadata.get(A2A_METADATA_PREFIX + "context_id")
372+
request_metadata = metadata.get(A2A_METADATA_PREFIX + "metadata")
368373
break
369374
events_to_process.append(event)
370375

@@ -385,7 +390,7 @@ def _construct_message_parts_from_session(
385390
else:
386391
logger.warning("Failed to convert part to A2A format: %s", part)
387392

388-
return message_parts, context_id
393+
return message_parts, context_id, request_metadata
389394

390395
async def _handle_a2a_response(
391396
self, a2a_response: A2AClientEvent | A2AMessage, ctx: InvocationContext
@@ -493,10 +498,12 @@ async def _run_async_impl(
493498
return
494499

495500
# Create A2A request for function response or regular message
496-
a2a_request = self._create_a2a_request_for_user_function_response(ctx)
501+
a2a_request, request_metadata = (
502+
self._create_a2a_request_for_user_function_response(ctx)
503+
)
497504
if not a2a_request:
498-
message_parts, context_id = self._construct_message_parts_from_session(
499-
ctx
505+
message_parts, context_id, request_metadata = (
506+
self._construct_message_parts_from_session(ctx)
500507
)
501508

502509
if not message_parts:
@@ -522,7 +529,8 @@ async def _run_async_impl(
522529

523530
try:
524531
async for a2a_response in self._a2a_client.send_message(
525-
request=a2a_request
532+
request=a2a_request,
533+
request_metadata=request_metadata,
526534
):
527535
logger.debug(build_a2a_response_log(a2a_response))
528536

tests/unittests/agents/test_remote_a2a_agent.py

Lines changed: 64 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,7 @@ def test_create_a2a_request_for_user_function_response_no_function_call(self):
582582
) as mock_find:
583583
mock_find.return_value = None
584584

585-
result = self.agent._create_a2a_request_for_user_function_response(
585+
result, _ = self.agent._create_a2a_request_for_user_function_response(
586586
self.mock_context
587587
)
588588

@@ -593,7 +593,8 @@ def test_create_a2a_request_for_user_function_response_success(self):
593593
# Mock function call event
594594
mock_function_event = Mock()
595595
mock_function_event.custom_metadata = {
596-
A2A_METADATA_PREFIX + "task_id": "task-123"
596+
A2A_METADATA_PREFIX + "task_id": "task-123",
597+
A2A_METADATA_PREFIX + "metadata": {"foo": "bar"},
597598
}
598599

599600
# Mock latest event with function response - set proper author
@@ -614,13 +615,17 @@ def test_create_a2a_request_for_user_function_response_success(self):
614615
mock_a2a_message.task_id = None # Will be set by the method
615616
mock_convert.return_value = mock_a2a_message
616617

617-
result = self.agent._create_a2a_request_for_user_function_response(
618-
self.mock_context
618+
result, metadata = (
619+
self.agent._create_a2a_request_for_user_function_response(
620+
self.mock_context
621+
)
619622
)
620623

621624
assert result is not None
622625
assert result == mock_a2a_message
623626
assert mock_a2a_message.task_id == "task-123"
627+
assert metadata is not None
628+
assert metadata == {"foo": "bar"}
624629

625630
def test_construct_message_parts_from_session_success(self):
626631
"""Test successful message parts construction from session."""
@@ -644,14 +649,14 @@ def test_construct_message_parts_from_session_success(self):
644649
mock_a2a_part = Mock()
645650
self.mock_genai_part_converter.return_value = mock_a2a_part
646651

647-
result = self.agent._construct_message_parts_from_session(
648-
self.mock_context
652+
parts, context_id, metadata = (
653+
self.agent._construct_message_parts_from_session(self.mock_context)
649654
)
650655

651-
assert len(result) == 2 # Returns tuple of (parts, context_id)
652-
assert len(result[0]) == 1 # parts list
653-
assert result[0][0] == mock_a2a_part
654-
assert result[1] is None # context_id
656+
assert len(parts) == 1
657+
assert parts[0] == mock_a2a_part
658+
assert context_id is None
659+
assert metadata is None
655660

656661
def test_construct_message_parts_from_session_success_multiple_parts(self):
657662
"""Test successful message parts construction from session."""
@@ -679,24 +684,25 @@ def test_construct_message_parts_from_session_success_multiple_parts(self):
679684
mock_a2a_part2,
680685
]
681686

682-
result = self.agent._construct_message_parts_from_session(
683-
self.mock_context
687+
parts, context_id, metadata = (
688+
self.agent._construct_message_parts_from_session(self.mock_context)
684689
)
685690

686-
assert len(result) == 2 # Returns tuple of (parts, context_id)
687-
assert len(result[0]) == 2 # parts list
688-
assert result[0] == [mock_a2a_part1, mock_a2a_part2]
689-
assert result[1] is None # context_id
691+
assert parts == [mock_a2a_part1, mock_a2a_part2]
692+
assert context_id is None
693+
assert metadata is None
690694

691695
def test_construct_message_parts_from_session_empty_events(self):
692696
"""Test message parts construction with empty events."""
693697
self.mock_session.events = []
694698

695-
result = self.agent._construct_message_parts_from_session(self.mock_context)
699+
parts, context_id, metadata = (
700+
self.agent._construct_message_parts_from_session(self.mock_context)
701+
)
696702

697-
assert len(result) == 2 # Returns tuple of (parts, context_id)
698-
assert result[0] == [] # empty parts list
699-
assert result[1] is None # context_id
703+
assert parts == []
704+
assert context_id is None
705+
assert metadata is None
700706

701707
@pytest.mark.asyncio
702708
async def test_handle_a2a_response_success_with_message(self):
@@ -818,14 +824,14 @@ def mock_converter(part):
818824

819825
self.mock_genai_part_converter.side_effect = mock_converter
820826

821-
result = self.agent._construct_message_parts_from_session(
822-
self.mock_context
827+
parts, context_id, metadata = (
828+
self.agent._construct_message_parts_from_session(self.mock_context)
823829
)
824830

825831
# Verify the parts are in correct order
826-
assert len(result) == 2 # Returns tuple of (parts, context_id)
827-
assert len(result[0]) == 3 # 1 user part + 2 other agent parts
828-
assert result[1] is None # context_id
832+
assert len(parts) == 3 # 1 user part + 2 other agent parts
833+
assert context_id is None
834+
assert metadata is None
829835

830836
# Verify order: user part, then "For context:", then agent message
831837
assert converted_parts[0].original_text == "User question"
@@ -1068,7 +1074,7 @@ def test_create_a2a_request_for_user_function_response_no_function_call(self):
10681074
) as mock_find:
10691075
mock_find.return_value = None
10701076

1071-
result = self.agent._create_a2a_request_for_user_function_response(
1077+
result, _ = self.agent._create_a2a_request_for_user_function_response(
10721078
self.mock_context
10731079
)
10741080

@@ -1079,7 +1085,8 @@ def test_create_a2a_request_for_user_function_response_success(self):
10791085
# Mock function call event
10801086
mock_function_event = Mock()
10811087
mock_function_event.custom_metadata = {
1082-
A2A_METADATA_PREFIX + "task_id": "task-123"
1088+
A2A_METADATA_PREFIX + "task_id": "task-123",
1089+
A2A_METADATA_PREFIX + "metadata": {"foo": "bar"},
10831090
}
10841091

10851092
# Mock latest event with function response - set proper author
@@ -1100,13 +1107,17 @@ def test_create_a2a_request_for_user_function_response_success(self):
11001107
mock_a2a_message.task_id = None # Will be set by the method
11011108
mock_convert.return_value = mock_a2a_message
11021109

1103-
result = self.agent._create_a2a_request_for_user_function_response(
1104-
self.mock_context
1110+
result, metadata = (
1111+
self.agent._create_a2a_request_for_user_function_response(
1112+
self.mock_context
1113+
)
11051114
)
11061115

11071116
assert result is not None
11081117
assert result == mock_a2a_message
11091118
assert mock_a2a_message.task_id == "task-123"
1119+
assert metadata is not None
1120+
assert metadata == {"foo": "bar"}
11101121

11111122
def test_construct_message_parts_from_session_success(self):
11121123
"""Test successful message parts construction from session."""
@@ -1133,24 +1144,26 @@ def test_construct_message_parts_from_session_success(self):
11331144
mock_a2a_part = Mock()
11341145
mock_convert_part.return_value = mock_a2a_part
11351146

1136-
result = self.agent._construct_message_parts_from_session(
1137-
self.mock_context
1147+
parts, context_id, metadata = (
1148+
self.agent._construct_message_parts_from_session(self.mock_context)
11381149
)
11391150

1140-
assert len(result) == 2 # Returns tuple of (parts, context_id)
1141-
assert len(result[0]) == 1 # parts list
1142-
assert result[0][0] == mock_a2a_part
1143-
assert result[1] is None # context_id
1151+
assert len(parts) == 1
1152+
assert parts[0] == mock_a2a_part
1153+
assert context_id is None
1154+
assert metadata is None
11441155

11451156
def test_construct_message_parts_from_session_empty_events(self):
11461157
"""Test message parts construction with empty events."""
11471158
self.mock_session.events = []
11481159

1149-
result = self.agent._construct_message_parts_from_session(self.mock_context)
1160+
parts, context_id, metadata = (
1161+
self.agent._construct_message_parts_from_session(self.mock_context)
1162+
)
11501163

1151-
assert len(result) == 2 # Returns tuple of (parts, context_id)
1152-
assert result[0] == [] # empty parts list
1153-
assert result[1] is None # context_id
1164+
assert parts == []
1165+
assert context_id is None
1166+
assert metadata is None
11541167

11551168
@pytest.mark.asyncio
11561169
async def test_handle_a2a_response_success_with_message(self):
@@ -1469,14 +1482,15 @@ async def test_run_async_impl_no_message_parts(self):
14691482
with patch.object(
14701483
self.agent, "_create_a2a_request_for_user_function_response"
14711484
) as mock_create_func:
1472-
mock_create_func.return_value = None
1485+
mock_create_func.return_value = (None, None)
14731486

14741487
with patch.object(
14751488
self.agent, "_construct_message_parts_from_session"
14761489
) as mock_construct:
14771490
mock_construct.return_value = (
14781491
[],
14791492
None,
1493+
None,
14801494
) # Tuple with empty parts and no context_id
14811495

14821496
events = []
@@ -1494,7 +1508,7 @@ async def test_run_async_impl_successful_request(self):
14941508
with patch.object(
14951509
self.agent, "_create_a2a_request_for_user_function_response"
14961510
) as mock_create_func:
1497-
mock_create_func.return_value = None
1511+
mock_create_func.return_value = (None, None)
14981512

14991513
with patch.object(
15001514
self.agent, "_construct_message_parts_from_session"
@@ -1507,6 +1521,7 @@ async def test_run_async_impl_successful_request(self):
15071521
mock_construct.return_value = (
15081522
[mock_a2a_part],
15091523
"context-123",
1524+
{"foo": "bar"},
15101525
) # Tuple with parts and context_id
15111526

15121527
# Mock A2A client
@@ -1567,7 +1582,7 @@ async def test_run_async_impl_a2a_client_error(self):
15671582
with patch.object(
15681583
self.agent, "_create_a2a_request_for_user_function_response"
15691584
) as mock_create_func:
1570-
mock_create_func.return_value = None
1585+
mock_create_func.return_value = None, None
15711586

15721587
with patch.object(
15731588
self.agent, "_construct_message_parts_from_session"
@@ -1579,6 +1594,7 @@ async def test_run_async_impl_a2a_client_error(self):
15791594
mock_construct.return_value = (
15801595
[mock_a2a_part],
15811596
"context-123",
1597+
{"foo": "bar"},
15821598
) # Tuple with parts and context_id
15831599

15841600
# Mock A2A client that throws an exception
@@ -1660,14 +1676,15 @@ async def test_run_async_impl_no_message_parts(self):
16601676
with patch.object(
16611677
self.agent, "_create_a2a_request_for_user_function_response"
16621678
) as mock_create_func:
1663-
mock_create_func.return_value = None
1679+
mock_create_func.return_value = None, None
16641680

16651681
with patch.object(
16661682
self.agent, "_construct_message_parts_from_session"
16671683
) as mock_construct:
16681684
mock_construct.return_value = (
16691685
[],
16701686
None,
1687+
None,
16711688
) # Tuple with empty parts and no context_id
16721689

16731690
events = []
@@ -1685,7 +1702,7 @@ async def test_run_async_impl_successful_request(self):
16851702
with patch.object(
16861703
self.agent, "_create_a2a_request_for_user_function_response"
16871704
) as mock_create_func:
1688-
mock_create_func.return_value = None
1705+
mock_create_func.return_value = None, None
16891706

16901707
with patch.object(
16911708
self.agent, "_construct_message_parts_from_session"
@@ -1698,6 +1715,7 @@ async def test_run_async_impl_successful_request(self):
16981715
mock_construct.return_value = (
16991716
[mock_a2a_part],
17001717
"context-123",
1718+
None,
17011719
) # Tuple with parts and context_id
17021720

17031721
# Mock A2A client
@@ -1760,7 +1778,7 @@ async def test_run_async_impl_a2a_client_error(self):
17601778
with patch.object(
17611779
self.agent, "_create_a2a_request_for_user_function_response"
17621780
) as mock_create_func:
1763-
mock_create_func.return_value = None
1781+
mock_create_func.return_value = None, None
17641782

17651783
with patch.object(
17661784
self.agent, "_construct_message_parts_from_session"
@@ -1772,6 +1790,7 @@ async def test_run_async_impl_a2a_client_error(self):
17721790
mock_construct.return_value = (
17731791
[mock_a2a_part],
17741792
"context-123",
1793+
None,
17751794
) # Tuple with parts and context_id
17761795

17771796
# Mock A2A client that throws an exception

0 commit comments

Comments
 (0)