6 changes: 6 additions & 0 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
@@ -312,6 +312,12 @@ def get_author_for_event(llm_response):
author=get_author_for_event(llm_response),
)

trace_call_llm(
invocation_context,
model_response_event.id,
llm_request,
llm_response,
)
async with Aclosing(
self._postprocess_live(
invocation_context,
110 changes: 104 additions & 6 deletions src/google/adk/models/gemini_llm_connection.py
@@ -152,7 +152,11 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
content = message.server_content.model_turn
if content and content.parts:
llm_response = LlmResponse(
content=content, interrupted=message.server_content.interrupted
content=content,
interrupted=message.server_content.interrupted,
usage_metadata=self._fix_usage_metadata(
[Reviewer comment (Collaborator)] will this cause duplicated usage metadata since we're adding it to all the LlmResponses? i.e. in the case that a message contains both content.parts and message.server_content.input_transcription

getattr(message, 'usage_metadata', None)
),
)
if content.parts[0].text:
text += content.parts[0].text
@@ -167,15 +171,21 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
and message.server_content.input_transcription.text
):
llm_response = LlmResponse(
input_transcription=message.server_content.input_transcription,
input_transcription=message.server_content.input_transcription,
usage_metadata=self._fix_usage_metadata(
getattr(message, 'usage_metadata', None)
),
)
yield llm_response
if (
message.server_content.output_transcription
and message.server_content.output_transcription.text
):
llm_response = LlmResponse(
output_transcription=message.server_content.output_transcription
output_transcription=message.server_content.output_transcription,
usage_metadata=self._fix_usage_metadata(
getattr(message, 'usage_metadata', None)
),
)
yield llm_response
if message.server_content.turn_complete:
@@ -185,6 +195,9 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
yield LlmResponse(
turn_complete=True,
interrupted=message.server_content.interrupted,
usage_metadata=self._fix_usage_metadata(
getattr(message, 'usage_metadata', None)
),
)
break
# in case of empty content or parts, we still surface it
@@ -194,7 +207,12 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
if message.server_content.interrupted and text:
yield self.__build_full_text_response(text)
text = ''
yield LlmResponse(interrupted=message.server_content.interrupted)
yield LlmResponse(
interrupted=message.server_content.interrupted,
usage_metadata=self._fix_usage_metadata(
getattr(message, 'usage_metadata', None)
),
)
if message.tool_call:
if text:
yield self.__build_full_text_response(text)
@@ -203,15 +221,95 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
types.Part(function_call=function_call)
for function_call in message.tool_call.function_calls
]
yield LlmResponse(content=types.Content(role='model', parts=parts))
yield LlmResponse(
content=types.Content(role='model', parts=parts),
usage_metadata=self._fix_usage_metadata(
getattr(message, 'usage_metadata', None)
),
)
if message.session_resumption_update:
logger.info('Received session resumption message: %s', message)
yield (
LlmResponse(
live_session_resumption_update=message.session_resumption_update
live_session_resumption_update=message.session_resumption_update,
usage_metadata=self._fix_usage_metadata(
getattr(message, 'usage_metadata', None)
),
)
)

def _fix_usage_metadata(self, usage_metadata):
"""
Fix missing candidates_token_count in Gemini Live API responses.

The Gemini Live API inconsistently returns usage metadata. While it typically
provides total_token_count and prompt_token_count, it often leaves
candidates_token_count as None. This creates incomplete telemetry data which
affects billing reporting and token usage monitoring.

This method calculates the missing candidates_token_count using the formula:
candidates_token_count = total_token_count - prompt_token_count

Args:
usage_metadata: The usage metadata from the Live API response, which may
have missing candidates_token_count.

Returns:
Fixed usage metadata with calculated candidates_token_count, or the
original metadata if no fix is needed/possible.
"""
if not usage_metadata:
return usage_metadata

# Safely get token counts using getattr with defaults
total_tokens = getattr(usage_metadata, 'total_token_count', None)
prompt_tokens = getattr(usage_metadata, 'prompt_token_count', None)
candidates_tokens = getattr(usage_metadata, 'candidates_token_count', None)

# Only fix if we have total and prompt but missing candidates
if (
total_tokens is not None
and prompt_tokens is not None
and candidates_tokens is None
):
# Calculate candidates tokens as: total - prompt
calculated_candidates = total_tokens - prompt_tokens

if calculated_candidates > 0:
# Create a new usage metadata object with the calculated value.
return types.GenerateContentResponseUsageMetadata(
total_token_count=total_tokens,
prompt_token_count=prompt_tokens,
candidates_token_count=calculated_candidates,
# Copy other fields if they exist
cache_tokens_details=getattr(
usage_metadata, 'cache_tokens_details', None
),
cached_content_token_count=getattr(
usage_metadata, 'cached_content_token_count', None
),
candidates_tokens_details=getattr(
usage_metadata, 'candidates_tokens_details', None
),
prompt_tokens_details=getattr(
usage_metadata, 'prompt_tokens_details', None
),
thoughts_token_count=getattr(
usage_metadata, 'thoughts_token_count', None
),
tool_use_prompt_token_count=getattr(
usage_metadata, 'tool_use_prompt_token_count', None
),
tool_use_prompt_tokens_details=getattr(
usage_metadata, 'tool_use_prompt_tokens_details', None
),
traffic_type=getattr(usage_metadata, 'traffic_type', None),
)

return usage_metadata

async def close(self):
"""Closes the llm server connection."""

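As a reviewing aid, here is a minimal sketch of how the new helper behaves. It assumes the class in gemini_llm_connection.py is named GeminiLlmConnection and skips constructing a real connection, since _fix_usage_metadata does not use self; the token values are illustrative:

from unittest import mock

from google.genai import types

from google.adk.models.gemini_llm_connection import GeminiLlmConnection

# Metadata as the Live API often returns it: candidates_token_count is missing.
partial = types.GenerateContentResponseUsageMetadata(
    total_token_count=120,
    prompt_token_count=80,
)

# _fix_usage_metadata does not touch self, so a stand-in object is enough here.
fixed = GeminiLlmConnection._fix_usage_metadata(mock.MagicMock(), partial)

assert fixed.candidates_token_count == 40  # 120 - 80
assert fixed.total_token_count == 120
assert fixed.prompt_token_count == 80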
37 changes: 30 additions & 7 deletions src/google/adk/telemetry/tracing.py
@@ -230,9 +230,31 @@ def trace_call_llm(
llm_request: The LLM request object.
llm_response: The LLM response object.
"""
span = trace.get_current_span()
# Special standard OpenTelemetry GenAI attributes that indicate
# that this is a span related to a Generative AI system.
# For live events with usage metadata, create a new span for each event
[Reviewer comment (Collaborator)] hm, what's the reasoning for creating a new span for every live event? Just wondering if it would cause too much overhead by generating too many spans.

# For regular events or live events without usage data, use the current span
if (
hasattr(invocation_context, 'live_request_queue')
and invocation_context.live_request_queue
and llm_response.usage_metadata is not None
):
# Live mode with usage data: create new span for each event
span_name = f'llm_call_live_event [{event_id[:8]}]'
with tracer.start_as_current_span(span_name) as span:
_set_llm_span_attributes(
span, invocation_context, event_id, llm_request, llm_response
)
else:
# Regular mode or live mode without usage data: use current span
span = trace.get_current_span()
_set_llm_span_attributes(
span, invocation_context, event_id, llm_request, llm_response
)


def _set_llm_span_attributes(
span, invocation_context, event_id, llm_request, llm_response
):
"""Set LLM span attributes."""
span.set_attribute('gen_ai.system', 'gcp.vertex.agent')
span.set_attribute('gen_ai.request.model', llm_request.model)
span.set_attribute(
@@ -271,10 +293,11 @@ )
)

if llm_response.usage_metadata is not None:
span.set_attribute(
'gen_ai.usage.input_tokens',
llm_response.usage_metadata.prompt_token_count,
)
if llm_response.usage_metadata.prompt_token_count is not None:
span.set_attribute(
'gen_ai.usage.input_tokens',
llm_response.usage_metadata.prompt_token_count,
)
if llm_response.usage_metadata.candidates_token_count is not None:
span.set_attribute(
'gen_ai.usage.output_tokens',
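For context on the span-per-event branch, this is the standard OpenTelemetry pattern it builds on; a minimal standalone sketch (the tracer name, function, and attribute values are illustrative, not ADK code):

from opentelemetry import trace

tracer = trace.get_tracer('illustrative.live.telemetry')


def record_live_usage(event_id: str, input_tokens: int, output_tokens: int) -> None:
  # One short-lived span per live event keeps token counts attributable to the
  # event that produced them, at the cost of one extra span per traced event.
  with tracer.start_as_current_span(f'llm_call_live_event [{event_id[:8]}]') as span:
    span.set_attribute('gen_ai.usage.input_tokens', input_tokens)
    span.set_attribute('gen_ai.usage.output_tokens', output_tokens)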
143 changes: 143 additions & 0 deletions tests/unittests/flows/llm_flows/test_base_llm_flow_realtime.py
@@ -199,3 +199,146 @@ async def test_send_to_model_with_text_content(mock_llm_connection):
# Verify send_content was called instead of send_realtime
mock_llm_connection.send_content.assert_called_once_with(content)
mock_llm_connection.send_realtime.assert_not_called()


@pytest.mark.asyncio
async def test_receive_from_model_calls_telemetry_trace(monkeypatch):
"""Test that _receive_from_model calls trace_call_llm for telemetry."""
# Mock the trace_call_llm function
mock_trace_call_llm = mock.Mock()
monkeypatch.setattr(
'google.adk.flows.llm_flows.base_llm_flow.trace_call_llm',
mock_trace_call_llm,
)

# Create mock LLM connection that yields responses
mock_llm_connection = mock.AsyncMock()

# Create test LLM response with usage metadata
from google.adk.models.llm_response import LlmResponse

test_llm_response = LlmResponse(
content=types.Content(
role='model', parts=[types.Part.from_text(text='Test response')]
),
usage_metadata=types.GenerateContentResponseUsageMetadata(
total_token_count=100,
prompt_token_count=50,
candidates_token_count=50,
),
)

# Mock the receive method to yield our test response
async def mock_receive():
yield test_llm_response

mock_llm_connection.receive = mock_receive

# Create agent and invocation context
agent = Agent(name='test_agent', model='mock')
invocation_context = await testing_utils.create_invocation_context(
agent=agent, user_content='test message'
)
invocation_context.live_request_queue = LiveRequestQueue()

# Create flow and test data
flow = TestBaseLlmFlow()
event_id = 'test_event_123'
llm_request = LlmRequest()

# Call _receive_from_model and consume the generator
events = []
async for event in flow._receive_from_model(
mock_llm_connection, event_id, invocation_context, llm_request
):
events.append(event)
break # Exit after first event to avoid infinite loop

# Verify trace_call_llm was called
mock_trace_call_llm.assert_called()

# Verify the call arguments
call_args = mock_trace_call_llm.call_args
assert call_args[0][0] == invocation_context # First arg: invocation_context
assert call_args[0][2] == llm_request # Third arg: llm_request
assert call_args[0][3] == test_llm_response # Fourth arg: llm_response

# Second arg should be the event ID from the generated event
assert len(call_args[0][1]) > 0 # Event ID should be non-empty string


@pytest.mark.asyncio
async def test_receive_from_model_telemetry_integration_with_live_queue(
monkeypatch,
):
"""Test telemetry integration in live mode with actual live request queue."""
# Mock the telemetry tracer to capture span creation
mock_tracer = mock.MagicMock()
mock_span = mock.MagicMock()
mock_tracer.start_as_current_span.return_value.__enter__.return_value = (
mock_span
)

monkeypatch.setattr('google.adk.telemetry.tracing.tracer', mock_tracer)

# Create mock LLM connection
mock_llm_connection = mock.AsyncMock()

# Create test responses - one with usage metadata, one without
from google.adk.models.llm_response import LlmResponse

response_with_usage = LlmResponse(
content=types.Content(
role='model', parts=[types.Part.from_text(text='Response 1')]
),
usage_metadata=types.GenerateContentResponseUsageMetadata(
total_token_count=100,
prompt_token_count=50,
candidates_token_count=50,
),
)

response_without_usage = LlmResponse(
content=types.Content(
role='model', parts=[types.Part.from_text(text='Response 2')]
),
usage_metadata=None,
)

# Mock receive to yield both responses
async def mock_receive():
yield response_with_usage
yield response_without_usage

mock_llm_connection.receive = mock_receive

# Create agent and invocation context with live request queue
agent = Agent(name='test_agent', model='mock')
invocation_context = await testing_utils.create_invocation_context(
agent=agent, user_content='test message'
)
invocation_context.live_request_queue = LiveRequestQueue()

# Create flow
flow = TestBaseLlmFlow()
event_id = 'test_event_integration'
llm_request = LlmRequest()

# Process events from _receive_from_model
events = []
async for event in flow._receive_from_model(
mock_llm_connection, event_id, invocation_context, llm_request
):
events.append(event)
if len(events) >= 2: # Stop after processing both responses
break

# Verify new spans were created for live events with usage metadata
assert mock_tracer.start_as_current_span.call_count >= 1

# Check that at least one span was created with live event naming
span_calls = mock_tracer.start_as_current_span.call_args_list
live_event_spans = [
call for call in span_calls if 'llm_call_live_event' in call[0][0]
]
assert len(live_event_spans) >= 1, 'Should create live event spans'