2020from pathlib import Path
2121from typing import Any
2222from typing import AsyncGenerator
23+ from typing import Callable
2324from typing import Optional
2425from typing import Union
2526from urllib .parse import urlparse
@@ -125,12 +126,16 @@ def __init__(
125126 self ,
126127 name : str ,
127128 agent_card : Union [AgentCard , str ],
129+ * ,
128130 description : str = "" ,
129131 httpx_client : Optional [httpx .AsyncClient ] = None ,
130132 timeout : float = DEFAULT_TIMEOUT ,
131133 genai_part_converter : GenAIPartToA2APartConverter = convert_genai_part_to_a2a_part ,
132134 a2a_part_converter : A2APartToGenAIPartConverter = convert_a2a_part_to_genai_part ,
133135 a2a_client_factory : Optional [A2AClientFactory ] = None ,
136+ a2a_request_meta_provider : Optional [
137+ Callable [[InvocationContext , A2AMessage ], dict [str , Any ]]
138+ ] = None ,
134139 ** kwargs : Any ,
135140 ) -> None :
136141 """Initialize RemoteA2aAgent.
@@ -144,6 +149,9 @@ def __init__(
144149 timeout: HTTP timeout in seconds
145150 a2a_client_factory: Optional A2AClientFactory object (will create own if
146151 not provided)
152+ a2a_request_meta_provider: Optional callable that takes InvocationContext
153+ and A2AMessage and returns a metadata object to attach to the A2A
154+ request.
147155 **kwargs: Additional arguments passed to BaseAgent
148156
149157 Raises:
@@ -169,6 +177,7 @@ def __init__(
169177 self ._genai_part_converter = genai_part_converter
170178 self ._a2a_part_converter = a2a_part_converter
171179 self ._a2a_client_factory : Optional [A2AClientFactory ] = a2a_client_factory
180+ self ._a2a_request_meta_provider = a2a_request_meta_provider
172181
173182 # Validate and store agent card reference
174183 if isinstance (agent_card , AgentCard ):
@@ -318,7 +327,7 @@ async def _ensure_resolved(self) -> None:
318327
319328 def _create_a2a_request_for_user_function_response (
320329 self , ctx : InvocationContext
321- ) -> tuple [ Optional [A2AMessage ], Optional [ dict [ str , Any ]] ]:
330+ ) -> Optional [A2AMessage ]:
322331 """Create A2A request for user function response if applicable.
323332
324333 Args:
@@ -328,26 +337,24 @@ def _create_a2a_request_for_user_function_response(
328337 SendMessageRequest if function response found, None otherwise
329338 """
330339 if not ctx .session .events or ctx .session .events [- 1 ].author != "user" :
331- return None , None
340+ return None
332341 function_call_event = find_matching_function_call (ctx .session .events )
333342 if not function_call_event :
334- return None , None
343+ return None
335344
336345 a2a_message = convert_event_to_a2a_message (
337346 ctx .session .events [- 1 ], ctx , Role .user , self ._genai_part_converter
338347 )
339- message_metadata = None
340348 if function_call_event .custom_metadata :
341349 metadata = function_call_event .custom_metadata
342350 a2a_message .task_id = metadata .get (A2A_METADATA_PREFIX + "task_id" )
343351 a2a_message .context_id = metadata .get (A2A_METADATA_PREFIX + "context_id" )
344- message_metadata = metadata .get (A2A_METADATA_PREFIX + "metadata" )
345352
346- return a2a_message , message_metadata
353+ return a2a_message
347354
348355 def _construct_message_parts_from_session (
349356 self , ctx : InvocationContext
350- ) -> tuple [list [A2APart ], Optional [str ], Optional [ dict [ str , Any ]] ]:
357+ ) -> tuple [list [A2APart ], Optional [str ]]:
351358 """Construct A2A message parts from session events.
352359
353360 Args:
@@ -359,7 +366,6 @@ def _construct_message_parts_from_session(
359366 """
360367 message_parts : list [A2APart ] = []
361368 context_id = None
362- request_metadata = None
363369
364370 events_to_process = []
365371 for event in reversed (ctx .session .events ):
@@ -369,7 +375,6 @@ def _construct_message_parts_from_session(
369375 if event .custom_metadata :
370376 metadata = event .custom_metadata
371377 context_id = metadata .get (A2A_METADATA_PREFIX + "context_id" )
372- request_metadata = metadata .get (A2A_METADATA_PREFIX + "metadata" )
373378 break
374379 events_to_process .append (event )
375380
@@ -390,7 +395,7 @@ def _construct_message_parts_from_session(
390395 else :
391396 logger .warning ("Failed to convert part to A2A format: %s" , part )
392397
393- return message_parts , context_id , request_metadata
398+ return message_parts , context_id
394399
395400 async def _handle_a2a_response (
396401 self , a2a_response : A2AClientEvent | A2AMessage , ctx : InvocationContext
@@ -498,12 +503,10 @@ async def _run_async_impl(
498503 return
499504
500505 # Create A2A request for function response or regular message
501- a2a_request , request_metadata = (
502- self ._create_a2a_request_for_user_function_response (ctx )
503- )
506+ a2a_request = self ._create_a2a_request_for_user_function_response (ctx )
504507 if not a2a_request :
505- message_parts , context_id , request_metadata = (
506- self . _construct_message_parts_from_session ( ctx )
508+ message_parts , context_id = self . _construct_message_parts_from_session (
509+ ctx
507510 )
508511
509512 if not message_parts :
@@ -528,6 +531,10 @@ async def _run_async_impl(
528531 logger .debug (build_a2a_request_log (a2a_request ))
529532
530533 try :
534+ request_metadata = None
535+ if self ._a2a_request_meta_provider :
536+ request_metadata = self ._a2a_request_meta_provider (ctx , a2a_request )
537+
531538 async for a2a_response in self ._a2a_client .send_message (
532539 request = a2a_request ,
533540 request_metadata = request_metadata ,
0 commit comments