Skip to content

Commit 195886d

Browse files
author
Emil Milow
committed
feat: Adds telemetry and fixes usage metadata for live mode
1 parent d620bcb commit 195886d

File tree

6 files changed

+610
-17
lines changed

6 files changed

+610
-17
lines changed

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,12 @@ def get_author_for_event(llm_response):
251251
invocation_id=invocation_context.invocation_id,
252252
author=get_author_for_event(llm_response),
253253
)
254+
trace_call_llm(
255+
invocation_context,
256+
model_response_event.id,
257+
llm_request,
258+
llm_response,
259+
)
254260
async for event in self._postprocess_live(
255261
invocation_context,
256262
llm_request,

src/google/adk/models/gemini_llm_connection.py

Lines changed: 103 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,11 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
148148
content = message.server_content.model_turn
149149
if content and content.parts:
150150
llm_response = LlmResponse(
151-
content=content, interrupted=message.server_content.interrupted
151+
content=content,
152+
interrupted=message.server_content.interrupted,
153+
usage_metadata=self._fix_usage_metadata(
154+
getattr(message, 'usage_metadata', None)
155+
),
152156
)
153157
if content.parts[0].text:
154158
text += content.parts[0].text
@@ -169,7 +173,10 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
169173
)
170174
]
171175
llm_response = LlmResponse(
172-
content=types.Content(role='user', parts=parts)
176+
content=types.Content(role='user', parts=parts),
177+
usage_metadata=self._fix_usage_metadata(
178+
getattr(message, 'usage_metadata', None)
179+
),
173180
)
174181
yield llm_response
175182
if (
@@ -190,7 +197,11 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
190197
)
191198
]
192199
llm_response = LlmResponse(
193-
content=types.Content(role='model', parts=parts), partial=True
200+
content=types.Content(role='model', parts=parts),
201+
partial=True,
202+
usage_metadata=self._fix_usage_metadata(
203+
getattr(message, 'usage_metadata', None)
204+
),
194205
)
195206
yield llm_response
196207

@@ -199,7 +210,11 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
199210
yield self.__build_full_text_response(text)
200211
text = ''
201212
yield LlmResponse(
202-
turn_complete=True, interrupted=message.server_content.interrupted
213+
turn_complete=True,
214+
interrupted=message.server_content.interrupted,
215+
usage_metadata=self._fix_usage_metadata(
216+
getattr(message, 'usage_metadata', None)
217+
),
203218
)
204219
break
205220
# in case of empty content or parts, we still surface it
@@ -209,7 +224,12 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
209224
if message.server_content.interrupted and text:
210225
yield self.__build_full_text_response(text)
211226
text = ''
212-
yield LlmResponse(interrupted=message.server_content.interrupted)
227+
yield LlmResponse(
228+
interrupted=message.server_content.interrupted,
229+
usage_metadata=self._fix_usage_metadata(
230+
getattr(message, 'usage_metadata', None)
231+
),
232+
)
213233
if message.tool_call:
214234
if text:
215235
yield self.__build_full_text_response(text)
@@ -218,7 +238,84 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
218238
types.Part(function_call=function_call)
219239
for function_call in message.tool_call.function_calls
220240
]
221-
yield LlmResponse(content=types.Content(role='model', parts=parts))
241+
yield LlmResponse(
242+
content=types.Content(role='model', parts=parts),
243+
usage_metadata=self._fix_usage_metadata(
244+
getattr(message, 'usage_metadata', None)
245+
),
246+
)
247+
248+
def _fix_usage_metadata(self, usage_metadata):
249+
"""
250+
Fix missing candidates_token_count in Gemini Live API responses.
251+
252+
The Gemini Live API inconsistently returns usage metadata. While it typically
253+
provides total_token_count and prompt_token_count, it often leaves
254+
candidates_token_count as None. This creates incomplete telemetry data which
255+
affects billing reporting and token usage monitoring.
256+
257+
This method calculates the missing candidates_token_count using the formula:
258+
candidates_token_count = total_token_count - prompt_token_count
259+
260+
Args:
261+
usage_metadata: The usage metadata from the Live API response, which may
262+
have missing candidates_token_count.
263+
264+
Returns:
265+
Fixed usage metadata with calculated candidates_token_count, or the
266+
original metadata if no fix is needed/possible.
267+
"""
268+
if not usage_metadata:
269+
return usage_metadata
270+
271+
# Safely get token counts using getattr with defaults
272+
total_tokens = getattr(usage_metadata, 'total_token_count', None)
273+
prompt_tokens = getattr(usage_metadata, 'prompt_token_count', None)
274+
candidates_tokens = getattr(usage_metadata, 'candidates_token_count', None)
275+
276+
# Only fix if we have total and prompt but missing candidates
277+
if (
278+
total_tokens is not None
279+
and prompt_tokens is not None
280+
and candidates_tokens is None
281+
):
282+
# Calculate candidates tokens as: total - prompt
283+
calculated_candidates = total_tokens - prompt_tokens
284+
285+
if calculated_candidates > 0:
286+
# Create a new usage metadata object with the calculated value
287+
from google.genai import types
288+
289+
return types.GenerateContentResponseUsageMetadata(
290+
total_token_count=total_tokens,
291+
prompt_token_count=prompt_tokens,
292+
candidates_token_count=calculated_candidates,
293+
# Copy other fields if they exist
294+
cache_tokens_details=getattr(
295+
usage_metadata, 'cache_tokens_details', None
296+
),
297+
cached_content_token_count=getattr(
298+
usage_metadata, 'cached_content_token_count', None
299+
),
300+
candidates_tokens_details=getattr(
301+
usage_metadata, 'candidates_tokens_details', None
302+
),
303+
prompt_tokens_details=getattr(
304+
usage_metadata, 'prompt_tokens_details', None
305+
),
306+
thoughts_token_count=getattr(
307+
usage_metadata, 'thoughts_token_count', None
308+
),
309+
tool_use_prompt_token_count=getattr(
310+
usage_metadata, 'tool_use_prompt_token_count', None
311+
),
312+
tool_use_prompt_tokens_details=getattr(
313+
usage_metadata, 'tool_use_prompt_tokens_details', None
314+
),
315+
traffic_type=getattr(usage_metadata, 'traffic_type', None),
316+
)
317+
318+
return usage_metadata
222319

223320
async def close(self):
224321
"""Closes the llm server connection."""

src/google/adk/telemetry.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,31 @@ def trace_call_llm(
166166
llm_request: The LLM request object.
167167
llm_response: The LLM response object.
168168
"""
169-
span = trace.get_current_span()
170-
# Special standard Open Telemetry GenaI attributes that indicate
171-
# that this is a span related to a Generative AI system.
169+
# For live events with usage metadata, create a new span for each event
170+
# For regular events or live events without usage data, use the current span
171+
if (
172+
hasattr(invocation_context, 'live_request_queue')
173+
and invocation_context.live_request_queue
174+
and llm_response.usage_metadata is not None
175+
):
176+
# Live mode with usage data: create new span for each event
177+
span_name = f'llm_call_live_event [{event_id[:8]}]'
178+
with tracer.start_as_current_span(span_name) as span:
179+
_set_llm_span_attributes(
180+
span, invocation_context, event_id, llm_request, llm_response
181+
)
182+
else:
183+
# Regular mode or live mode without usage data: use current span
184+
span = trace.get_current_span()
185+
_set_llm_span_attributes(
186+
span, invocation_context, event_id, llm_request, llm_response
187+
)
188+
189+
190+
def _set_llm_span_attributes(
191+
span, invocation_context, event_id, llm_request, llm_response
192+
):
193+
"""Set LLM span attributes."""
172194
span.set_attribute('gen_ai.system', 'gcp.vertex.agent')
173195
span.set_attribute('gen_ai.request.model', llm_request.model)
174196
span.set_attribute(
@@ -196,14 +218,16 @@ def trace_call_llm(
196218
)
197219

198220
if llm_response.usage_metadata is not None:
199-
span.set_attribute(
200-
'gen_ai.usage.input_tokens',
201-
llm_response.usage_metadata.prompt_token_count,
202-
)
203-
span.set_attribute(
204-
'gen_ai.usage.output_tokens',
205-
llm_response.usage_metadata.candidates_token_count,
206-
)
221+
if llm_response.usage_metadata.prompt_token_count is not None:
222+
span.set_attribute(
223+
'gen_ai.usage.input_tokens',
224+
llm_response.usage_metadata.prompt_token_count,
225+
)
226+
if llm_response.usage_metadata.candidates_token_count is not None:
227+
span.set_attribute(
228+
'gen_ai.usage.output_tokens',
229+
llm_response.usage_metadata.candidates_token_count,
230+
)
207231

208232

209233
def trace_send_data(

tests/unittests/flows/llm_flows/test_base_llm_flow_realtime.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,3 +199,146 @@ async def test_send_to_model_with_text_content(mock_llm_connection):
199199
# Verify send_content was called instead of send_realtime
200200
mock_llm_connection.send_content.assert_called_once_with(content)
201201
mock_llm_connection.send_realtime.assert_not_called()
202+
203+
204+
@pytest.mark.asyncio
async def test_receive_from_model_calls_telemetry_trace(monkeypatch):
  """Verifies _receive_from_model reports each response via trace_call_llm."""
  from google.adk.models.llm_response import LlmResponse

  # Replace the telemetry hook so its arguments can be captured.
  traced = mock.AsyncMock()
  monkeypatch.setattr(
      'google.adk.flows.llm_flows.base_llm_flow.trace_call_llm',
      traced,
  )

  # A single model response carrying usage metadata.
  response = LlmResponse(
      content=types.Content(
          role='model', parts=[types.Part.from_text(text='Test response')]
      ),
      usage_metadata=types.GenerateContentResponseUsageMetadata(
          total_token_count=100,
          prompt_token_count=50,
          candidates_token_count=50,
      ),
  )

  connection = mock.AsyncMock()

  async def fake_receive():
    yield response

  connection.receive = fake_receive

  # Live-mode invocation context.
  agent = Agent(name='test_agent', model='mock')
  ctx = await testing_utils.create_invocation_context(
      agent=agent, user_content='test message'
  )
  ctx.live_request_queue = LiveRequestQueue()

  flow = TestBaseLlmFlow()
  request = LlmRequest()

  # Drain one event from the generator; breaking avoids an endless loop.
  received = []
  async for event in flow._receive_from_model(
      connection, 'test_event_123', ctx, request
  ):
    received.append(event)
    break

  traced.assert_called()

  # Positional args: (invocation_context, event_id, llm_request, llm_response)
  ctx_arg, event_id_arg, request_arg, response_arg = traced.call_args[0]
  assert ctx_arg == ctx
  assert request_arg == request
  assert response_arg == response
  # The event id generated for the traced event must be non-empty.
  assert len(event_id_arg) > 0
268+
269+
270+
@pytest.mark.asyncio
async def test_receive_from_model_telemetry_integration_with_live_queue(
    monkeypatch,
):
  """Checks that live-mode responses with usage data get dedicated spans."""
  from google.adk.models.llm_response import LlmResponse

  # Fake OpenTelemetry tracer so span creation can be observed.
  tracer_stub = mock.MagicMock()
  span_stub = mock.MagicMock()
  tracer_stub.start_as_current_span.return_value.__enter__.return_value = (
      span_stub
  )
  monkeypatch.setattr('google.adk.telemetry.tracer', tracer_stub)

  # One response with usage metadata, one without.
  with_usage = LlmResponse(
      content=types.Content(
          role='model', parts=[types.Part.from_text(text='Response 1')]
      ),
      usage_metadata=types.GenerateContentResponseUsageMetadata(
          total_token_count=100,
          prompt_token_count=50,
          candidates_token_count=50,
      ),
  )
  without_usage = LlmResponse(
      content=types.Content(
          role='model', parts=[types.Part.from_text(text='Response 2')]
      ),
      usage_metadata=None,
  )

  connection = mock.AsyncMock()

  async def fake_receive():
    yield with_usage
    yield without_usage

  connection.receive = fake_receive

  # Live-mode invocation context with an actual live request queue.
  agent = Agent(name='test_agent', model='mock')
  ctx = await testing_utils.create_invocation_context(
      agent=agent, user_content='test message'
  )
  ctx.live_request_queue = LiveRequestQueue()

  flow = TestBaseLlmFlow()
  request = LlmRequest()

  # Consume both responses, then stop.
  seen = []
  async for event in flow._receive_from_model(
      connection, 'test_event_integration', ctx, request
  ):
    seen.append(event)
    if len(seen) >= 2:
      break

  # New spans must be created for live events with usage metadata.
  assert tracer_stub.start_as_current_span.call_count >= 1

  # At least one span must follow the live-event naming scheme.
  live_spans = [
      c
      for c in tracer_stub.start_as_current_span.call_args_list
      if 'llm_call_live_event' in c[0][0]
  ]
  assert len(live_spans) >= 1, 'Should create live event spans'

0 commit comments

Comments
 (0)