From 195886d4483dc9b91e2093c322ef105da0e0f642 Mon Sep 17 00:00:00 2001 From: Emil Milow Date: Mon, 4 Aug 2025 15:10:35 +0200 Subject: [PATCH] feat: Adds telemetry and fixes usage metadata for live mode --- .../adk/flows/llm_flows/base_llm_flow.py | 6 + .../adk/models/gemini_llm_connection.py | 109 +++++++++- src/google/adk/telemetry.py | 46 +++-- .../llm_flows/test_base_llm_flow_realtime.py | 143 ++++++++++++++ .../models/test_gemini_llm_connection.py | 137 +++++++++++++ tests/unittests/test_telemetry.py | 186 ++++++++++++++++++ 6 files changed, 610 insertions(+), 17 deletions(-) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index b388667100..b7e45fe879 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -251,6 +251,12 @@ def get_author_for_event(llm_response): invocation_id=invocation_context.invocation_id, author=get_author_for_event(llm_response), ) + trace_call_llm( + invocation_context, + model_response_event.id, + llm_request, + llm_response, + ) async for event in self._postprocess_live( invocation_context, llm_request, diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py index 3a902c562d..a16a93e2f7 100644 --- a/src/google/adk/models/gemini_llm_connection.py +++ b/src/google/adk/models/gemini_llm_connection.py @@ -148,7 +148,11 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: content = message.server_content.model_turn if content and content.parts: llm_response = LlmResponse( - content=content, interrupted=message.server_content.interrupted + content=content, + interrupted=message.server_content.interrupted, + usage_metadata=self._fix_usage_metadata( + getattr(message, 'usage_metadata', None) + ), ) if content.parts[0].text: text += content.parts[0].text @@ -169,7 +173,10 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: ) ] llm_response = LlmResponse( - content=types.Content(role='user', parts=parts) + content=types.Content(role='user', parts=parts), + usage_metadata=self._fix_usage_metadata( + getattr(message, 'usage_metadata', None) + ), ) yield llm_response if ( @@ -190,7 +197,11 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: ) ] llm_response = LlmResponse( - content=types.Content(role='model', parts=parts), partial=True + content=types.Content(role='model', parts=parts), + partial=True, + usage_metadata=self._fix_usage_metadata( + getattr(message, 'usage_metadata', None) + ), ) yield llm_response @@ -199,7 +210,11 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: yield self.__build_full_text_response(text) text = '' yield LlmResponse( - turn_complete=True, interrupted=message.server_content.interrupted + turn_complete=True, + interrupted=message.server_content.interrupted, + usage_metadata=self._fix_usage_metadata( + getattr(message, 'usage_metadata', None) + ), ) break # in case of empty content or parts, we sill surface it @@ -209,7 +224,12 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: if message.server_content.interrupted and text: yield self.__build_full_text_response(text) text = '' - yield LlmResponse(interrupted=message.server_content.interrupted) + yield LlmResponse( + interrupted=message.server_content.interrupted, + usage_metadata=self._fix_usage_metadata( + getattr(message, 'usage_metadata', None) + ), + ) if message.tool_call: if text: yield self.__build_full_text_response(text) @@ -218,7 +238,84 @@ 
async def receive(self) -> AsyncGenerator[LlmResponse, None]: types.Part(function_call=function_call) for function_call in message.tool_call.function_calls ] - yield LlmResponse(content=types.Content(role='model', parts=parts)) + yield LlmResponse( + content=types.Content(role='model', parts=parts), + usage_metadata=self._fix_usage_metadata( + getattr(message, 'usage_metadata', None) + ), + ) + + def _fix_usage_metadata(self, usage_metadata): + """ + Fix missing candidates_token_count in Gemini Live API responses. + + The Gemini Live API inconsistently returns usage metadata. While it typically + provides total_token_count and prompt_token_count, it often leaves + candidates_token_count as None. This creates incomplete telemetry data which + affects billing reporting and token usage monitoring. + + This method calculates the missing candidates_token_count using the formula: + candidates_token_count = total_token_count - prompt_token_count + + Args: + usage_metadata: The usage metadata from the Live API response, which may + have missing candidates_token_count. + + Returns: + Fixed usage metadata with calculated candidates_token_count, or the + original metadata if no fix is needed/possible. + """ + if not usage_metadata: + return usage_metadata + + # Safely get token counts using getattr with defaults + total_tokens = getattr(usage_metadata, 'total_token_count', None) + prompt_tokens = getattr(usage_metadata, 'prompt_token_count', None) + candidates_tokens = getattr(usage_metadata, 'candidates_token_count', None) + + # Only fix if we have total and prompt but missing candidates + if ( + total_tokens is not None + and prompt_tokens is not None + and candidates_tokens is None + ): + # Calculate candidates tokens as: total - prompt + calculated_candidates = total_tokens - prompt_tokens + + if calculated_candidates > 0: + # Create a new usage metadata object with the calculated value + from google.genai import types + + return types.GenerateContentResponseUsageMetadata( + total_token_count=total_tokens, + prompt_token_count=prompt_tokens, + candidates_token_count=calculated_candidates, + # Copy other fields if they exist + cache_tokens_details=getattr( + usage_metadata, 'cache_tokens_details', None + ), + cached_content_token_count=getattr( + usage_metadata, 'cached_content_token_count', None + ), + candidates_tokens_details=getattr( + usage_metadata, 'candidates_tokens_details', None + ), + prompt_tokens_details=getattr( + usage_metadata, 'prompt_tokens_details', None + ), + thoughts_token_count=getattr( + usage_metadata, 'thoughts_token_count', None + ), + tool_use_prompt_token_count=getattr( + usage_metadata, 'tool_use_prompt_token_count', None + ), + tool_use_prompt_tokens_details=getattr( + usage_metadata, 'tool_use_prompt_tokens_details', None + ), + traffic_type=getattr(usage_metadata, 'traffic_type', None), + ) + + return usage_metadata async def close(self): """Closes the llm server connection.""" diff --git a/src/google/adk/telemetry.py b/src/google/adk/telemetry.py index 10ac583990..9bf41e1c1c 100644 --- a/src/google/adk/telemetry.py +++ b/src/google/adk/telemetry.py @@ -166,9 +166,31 @@ def trace_call_llm( llm_request: The LLM request object. llm_response: The LLM response object. """ - span = trace.get_current_span() - # Special standard Open Telemetry GenaI attributes that indicate - # that this is a span related to a Generative AI system. 
+ # For live events with usage metadata, create a new span for each event + # For regular events or live events without usage data, use the current span + if ( + hasattr(invocation_context, 'live_request_queue') + and invocation_context.live_request_queue + and llm_response.usage_metadata is not None + ): + # Live mode with usage data: create new span for each event + span_name = f'llm_call_live_event [{event_id[:8]}]' + with tracer.start_as_current_span(span_name) as span: + _set_llm_span_attributes( + span, invocation_context, event_id, llm_request, llm_response + ) + else: + # Regular mode or live mode without usage data: use current span + span = trace.get_current_span() + _set_llm_span_attributes( + span, invocation_context, event_id, llm_request, llm_response + ) + + +def _set_llm_span_attributes( + span, invocation_context, event_id, llm_request, llm_response +): + """Set LLM span attributes.""" span.set_attribute('gen_ai.system', 'gcp.vertex.agent') span.set_attribute('gen_ai.request.model', llm_request.model) span.set_attribute( @@ -196,14 +218,16 @@ def trace_call_llm( ) if llm_response.usage_metadata is not None: - span.set_attribute( - 'gen_ai.usage.input_tokens', - llm_response.usage_metadata.prompt_token_count, - ) - span.set_attribute( - 'gen_ai.usage.output_tokens', - llm_response.usage_metadata.candidates_token_count, - ) + if llm_response.usage_metadata.prompt_token_count is not None: + span.set_attribute( + 'gen_ai.usage.input_tokens', + llm_response.usage_metadata.prompt_token_count, + ) + if llm_response.usage_metadata.candidates_token_count is not None: + span.set_attribute( + 'gen_ai.usage.output_tokens', + llm_response.usage_metadata.candidates_token_count, + ) def trace_send_data( diff --git a/tests/unittests/flows/llm_flows/test_base_llm_flow_realtime.py b/tests/unittests/flows/llm_flows/test_base_llm_flow_realtime.py index d6033450c2..a7857d413c 100644 --- a/tests/unittests/flows/llm_flows/test_base_llm_flow_realtime.py +++ b/tests/unittests/flows/llm_flows/test_base_llm_flow_realtime.py @@ -199,3 +199,146 @@ async def test_send_to_model_with_text_content(mock_llm_connection): # Verify send_content was called instead of send_realtime mock_llm_connection.send_content.assert_called_once_with(content) mock_llm_connection.send_realtime.assert_not_called() + + +@pytest.mark.asyncio +async def test_receive_from_model_calls_telemetry_trace(monkeypatch): + """Test that _receive_from_model calls trace_call_llm for telemetry.""" + # Mock the trace_call_llm function + mock_trace_call_llm = mock.AsyncMock() + monkeypatch.setattr( + 'google.adk.flows.llm_flows.base_llm_flow.trace_call_llm', + mock_trace_call_llm, + ) + + # Create mock LLM connection that yields responses + mock_llm_connection = mock.AsyncMock() + + # Create test LLM response with usage metadata + from google.adk.models.llm_response import LlmResponse + + test_llm_response = LlmResponse( + content=types.Content( + role='model', parts=[types.Part.from_text(text='Test response')] + ), + usage_metadata=types.GenerateContentResponseUsageMetadata( + total_token_count=100, + prompt_token_count=50, + candidates_token_count=50, + ), + ) + + # Mock the receive method to yield our test response + async def mock_receive(): + yield test_llm_response + + mock_llm_connection.receive = mock_receive + + # Create agent and invocation context + agent = Agent(name='test_agent', model='mock') + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='test message' + ) + 
invocation_context.live_request_queue = LiveRequestQueue() + + # Create flow and test data + flow = TestBaseLlmFlow() + event_id = 'test_event_123' + llm_request = LlmRequest() + + # Call _receive_from_model and consume the generator + events = [] + async for event in flow._receive_from_model( + mock_llm_connection, event_id, invocation_context, llm_request + ): + events.append(event) + break # Exit after first event to avoid infinite loop + + # Verify trace_call_llm was called + mock_trace_call_llm.assert_called() + + # Verify the call arguments + call_args = mock_trace_call_llm.call_args + assert call_args[0][0] == invocation_context # First arg: invocation_context + assert call_args[0][2] == llm_request # Third arg: llm_request + assert call_args[0][3] == test_llm_response # Fourth arg: llm_response + + # Second arg should be the event ID from the generated event + assert len(call_args[0][1]) > 0 # Event ID should be non-empty string + + +@pytest.mark.asyncio +async def test_receive_from_model_telemetry_integration_with_live_queue( + monkeypatch, +): + """Test telemetry integration in live mode with actual live request queue.""" + # Mock the telemetry tracer to capture span creation + mock_tracer = mock.MagicMock() + mock_span = mock.MagicMock() + mock_tracer.start_as_current_span.return_value.__enter__.return_value = ( + mock_span + ) + + monkeypatch.setattr('google.adk.telemetry.tracer', mock_tracer) + + # Create mock LLM connection + mock_llm_connection = mock.AsyncMock() + + # Create test responses - one with usage metadata, one without + from google.adk.models.llm_response import LlmResponse + + response_with_usage = LlmResponse( + content=types.Content( + role='model', parts=[types.Part.from_text(text='Response 1')] + ), + usage_metadata=types.GenerateContentResponseUsageMetadata( + total_token_count=100, + prompt_token_count=50, + candidates_token_count=50, + ), + ) + + response_without_usage = LlmResponse( + content=types.Content( + role='model', parts=[types.Part.from_text(text='Response 2')] + ), + usage_metadata=None, + ) + + # Mock receive to yield both responses + async def mock_receive(): + yield response_with_usage + yield response_without_usage + + mock_llm_connection.receive = mock_receive + + # Create agent and invocation context with live request queue + agent = Agent(name='test_agent', model='mock') + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='test message' + ) + invocation_context.live_request_queue = LiveRequestQueue() + + # Create flow + flow = TestBaseLlmFlow() + event_id = 'test_event_integration' + llm_request = LlmRequest() + + # Process events from _receive_from_model + events = [] + async for event in flow._receive_from_model( + mock_llm_connection, event_id, invocation_context, llm_request + ): + events.append(event) + if len(events) >= 2: # Stop after processing both responses + break + + # Verify new spans were created for live events with usage metadata + assert mock_tracer.start_as_current_span.call_count >= 1 + + # Check that at least one span was created with live event naming + span_calls = mock_tracer.start_as_current_span.call_args_list + live_event_spans = [ + call for call in span_calls if 'llm_call_live_event' in call[0][0] + ] + assert len(live_event_spans) >= 1, 'Should create live event spans' diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py index 2327115033..224336fd70 100644 --- 
a/tests/unittests/models/test_gemini_llm_connection.py +++ b/tests/unittests/models/test_gemini_llm_connection.py @@ -109,3 +109,140 @@ async def test_close(gemini_connection, mock_gemini_session): await gemini_connection.close() mock_gemini_session.close.assert_called_once() + + +def test_fix_usage_metadata_with_missing_candidates_token_count( + gemini_connection, +): + """Test _fix_usage_metadata with missing candidates_token_count.""" + # Create usage metadata with missing candidates_token_count + usage_metadata = types.GenerateContentResponseUsageMetadata( + total_token_count=100, + prompt_token_count=60, + candidates_token_count=None, + ) + + result = gemini_connection._fix_usage_metadata(usage_metadata) + + # Should calculate candidates_token_count as total - prompt = 100 - 60 = 40 + assert result.total_token_count == 100 + assert result.prompt_token_count == 60 + assert result.candidates_token_count == 40 + + +def test_fix_usage_metadata_with_none_input(gemini_connection): + """Test _fix_usage_metadata with None input.""" + result = gemini_connection._fix_usage_metadata(None) + assert result is None + + +def test_fix_usage_metadata_with_existing_candidates_token_count( + gemini_connection, +): + """Test _fix_usage_metadata when candidates_token_count already exists.""" + usage_metadata = types.GenerateContentResponseUsageMetadata( + total_token_count=100, + prompt_token_count=60, + candidates_token_count=40, + ) + + result = gemini_connection._fix_usage_metadata(usage_metadata) + + # Should return the original metadata unchanged + assert result is usage_metadata + assert result.total_token_count == 100 + assert result.prompt_token_count == 60 + assert result.candidates_token_count == 40 + + +def test_fix_usage_metadata_with_missing_total_token_count(gemini_connection): + """Test _fix_usage_metadata with missing total_token_count.""" + usage_metadata = types.GenerateContentResponseUsageMetadata( + total_token_count=None, + prompt_token_count=60, + candidates_token_count=None, + ) + + result = gemini_connection._fix_usage_metadata(usage_metadata) + + # Should return the original metadata unchanged + assert result is usage_metadata + assert result.total_token_count is None + assert result.prompt_token_count == 60 + assert result.candidates_token_count is None + + +def test_fix_usage_metadata_with_missing_prompt_token_count(gemini_connection): + """Test _fix_usage_metadata with missing prompt_token_count.""" + usage_metadata = types.GenerateContentResponseUsageMetadata( + total_token_count=100, + prompt_token_count=None, + candidates_token_count=None, + ) + + result = gemini_connection._fix_usage_metadata(usage_metadata) + + # Should return the original metadata unchanged + assert result is usage_metadata + assert result.total_token_count == 100 + assert result.prompt_token_count is None + assert result.candidates_token_count is None + + +def test_fix_usage_metadata_with_zero_calculated_candidates(gemini_connection): + """Test _fix_usage_metadata when calculated candidates would be zero.""" + usage_metadata = types.GenerateContentResponseUsageMetadata( + total_token_count=60, + prompt_token_count=60, + candidates_token_count=None, + ) + + result = gemini_connection._fix_usage_metadata(usage_metadata) + + # Should not create new metadata when calculated candidates is 0 or negative + assert result is usage_metadata + assert result.total_token_count == 60 + assert result.prompt_token_count == 60 + assert result.candidates_token_count is None + + +def 
test_fix_usage_metadata_with_negative_calculated_candidates( + gemini_connection, +): + """Test _fix_usage_metadata when calculated candidates would be negative.""" + usage_metadata = types.GenerateContentResponseUsageMetadata( + total_token_count=50, + prompt_token_count=60, + candidates_token_count=None, + ) + + result = gemini_connection._fix_usage_metadata(usage_metadata) + + # Should not create new metadata when calculated candidates is negative + assert result is usage_metadata + assert result.total_token_count == 50 + assert result.prompt_token_count == 60 + assert result.candidates_token_count is None + + +def test_fix_usage_metadata_preserves_other_fields(gemini_connection): + """Test _fix_usage_metadata preserves other optional fields.""" + from google.genai import types + + usage_metadata = types.GenerateContentResponseUsageMetadata( + total_token_count=100, + prompt_token_count=60, + candidates_token_count=None, + cached_content_token_count=10, + thoughts_token_count=5, + ) + + result = gemini_connection._fix_usage_metadata(usage_metadata) + + # Should create new metadata with calculated candidates and preserved fields + assert result is not usage_metadata # New object created + assert result.total_token_count == 100 + assert result.prompt_token_count == 60 + assert result.candidates_token_count == 40 + assert result.cached_content_token_count == 10 + assert result.thoughts_token_count == 5 diff --git a/tests/unittests/test_telemetry.py b/tests/unittests/test_telemetry.py index 8a3964b219..867e787bd0 100644 --- a/tests/unittests/test_telemetry.py +++ b/tests/unittests/test_telemetry.py @@ -326,3 +326,189 @@ def test_trace_merged_tool_calls_sets_correct_attributes( expected_calls, any_order=True ) mock_event_fixture.model_dumps_json.assert_called_once_with(exclude_none=True) + + +@pytest.mark.asyncio +async def test_trace_call_llm_live_mode_with_usage_metadata_creates_new_span( + monkeypatch, +): + """Test that live mode with usage metadata creates new spans.""" + mock_tracer = mock.MagicMock() + mock_span = mock.MagicMock() + mock_tracer.start_as_current_span.return_value.__enter__.return_value = ( + mock_span + ) + + monkeypatch.setattr('google.adk.telemetry.tracer', mock_tracer) + + agent = LlmAgent(name='test_agent') + invocation_context = await _create_invocation_context(agent) + + # Set up live request queue to indicate live mode + from google.adk.agents.live_request_queue import LiveRequestQueue + + invocation_context.live_request_queue = LiveRequestQueue() + + llm_request = LlmRequest( + config=types.GenerateContentConfig(system_instruction=''), + ) + llm_response = LlmResponse( + turn_complete=True, + usage_metadata=types.GenerateContentResponseUsageMetadata( + total_token_count=100, + prompt_token_count=50, + candidates_token_count=50, + ), + ) + + event_id = 'test_event_id_123' + trace_call_llm(invocation_context, event_id, llm_request, llm_response) + + # Should create new span with live event naming + expected_span_name = f'llm_call_live_event [{event_id[:8]}]' + mock_tracer.start_as_current_span.assert_called_once_with(expected_span_name) + + # Verify span attributes were set on the new span + expected_calls = [ + mock.call('gen_ai.system', 'gcp.vertex.agent'), + mock.call('gen_ai.request.model', llm_request.model), + mock.call('gen_ai.usage.input_tokens', 50), + mock.call('gen_ai.usage.output_tokens', 50), + ] + mock_span.set_attribute.assert_has_calls(expected_calls, any_order=True) + + +@pytest.mark.asyncio +async def 
test_trace_call_llm_live_mode_without_usage_metadata_uses_current_span( + monkeypatch, mock_span_fixture +): + """Test that live mode without usage metadata uses current span.""" + mock_tracer = mock.MagicMock() + monkeypatch.setattr('google.adk.telemetry.tracer', mock_tracer) + monkeypatch.setattr( + 'opentelemetry.trace.get_current_span', lambda: mock_span_fixture + ) + + agent = LlmAgent(name='test_agent') + invocation_context = await _create_invocation_context(agent) + + # Set up live request queue to indicate live mode + from google.adk.agents.live_request_queue import LiveRequestQueue + + invocation_context.live_request_queue = LiveRequestQueue() + + llm_request = LlmRequest( + config=types.GenerateContentConfig(system_instruction=''), + ) + llm_response = LlmResponse( + turn_complete=True, + usage_metadata=None, # No usage metadata + ) + + event_id = 'test_event_id_456' + trace_call_llm(invocation_context, event_id, llm_request, llm_response) + + # Should NOT create new span + mock_tracer.start_as_current_span.assert_not_called() + + # Should use current span + expected_calls = [ + mock.call('gen_ai.system', 'gcp.vertex.agent'), + mock.call('gen_ai.request.model', llm_request.model), + ] + mock_span_fixture.set_attribute.assert_has_calls( + expected_calls, any_order=True + ) + + +@pytest.mark.asyncio +async def test_trace_call_llm_regular_mode_uses_current_span( + monkeypatch, mock_span_fixture +): + """Test that regular mode (no live_request_queue) uses current span.""" + mock_tracer = mock.MagicMock() + monkeypatch.setattr('google.adk.telemetry.tracer', mock_tracer) + monkeypatch.setattr( + 'opentelemetry.trace.get_current_span', lambda: mock_span_fixture + ) + + agent = LlmAgent(name='test_agent') + invocation_context = await _create_invocation_context(agent) + + # No live_request_queue - regular mode + + llm_request = LlmRequest( + config=types.GenerateContentConfig(system_instruction=''), + ) + llm_response = LlmResponse( + turn_complete=True, + usage_metadata=types.GenerateContentResponseUsageMetadata( + total_token_count=100, + prompt_token_count=50, + candidates_token_count=50, + ), + ) + + event_id = 'test_event_id_789' + trace_call_llm(invocation_context, event_id, llm_request, llm_response) + + # Should NOT create new span even with usage metadata + mock_tracer.start_as_current_span.assert_not_called() + + # Should use current span + expected_calls = [ + mock.call('gen_ai.system', 'gcp.vertex.agent'), + mock.call('gen_ai.request.model', llm_request.model), + mock.call('gen_ai.usage.input_tokens', 50), + mock.call('gen_ai.usage.output_tokens', 50), + ] + mock_span_fixture.set_attribute.assert_has_calls( + expected_calls, any_order=True + ) + + +@pytest.mark.asyncio +async def test_trace_call_llm_live_mode_span_name_formatting(monkeypatch): + """Test that live mode span names are formatted correctly.""" + mock_tracer = mock.MagicMock() + mock_span = mock.MagicMock() + mock_tracer.start_as_current_span.return_value.__enter__.return_value = ( + mock_span + ) + + monkeypatch.setattr('google.adk.telemetry.tracer', mock_tracer) + + agent = LlmAgent(name='test_agent') + invocation_context = await _create_invocation_context(agent) + + # Set up live request queue + from google.adk.agents.live_request_queue import LiveRequestQueue + + invocation_context.live_request_queue = LiveRequestQueue() + + llm_request = LlmRequest( + config=types.GenerateContentConfig(system_instruction=''), + ) + llm_response = LlmResponse( + turn_complete=True, + 
usage_metadata=types.GenerateContentResponseUsageMetadata( + total_token_count=100, + ), + ) + + # Test with different event ID lengths + test_cases = [ + ('12345678', 'llm_call_live_event [12345678]'), + ( + '123456789012', + 'llm_call_live_event [12345678]', + ), # Should truncate to 8 chars + ('1234', 'llm_call_live_event [1234]'), # Shorter than 8 chars + ] + + for event_id, expected_span_name in test_cases: + mock_tracer.reset_mock() + trace_call_llm(invocation_context, event_id, llm_request, llm_response) + mock_tracer.start_as_current_span.assert_called_once_with( + expected_span_name + )
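
For reference, a minimal standalone sketch (not part of the patch) of the token-count reconstruction that _fix_usage_metadata applies above. The helper name below is illustrative, and only the google.genai types already used in this change are assumed:

from typing import Optional

from google.genai import types


def reconstruct_candidates_count(
    usage: Optional[types.GenerateContentResponseUsageMetadata],
) -> Optional[types.GenerateContentResponseUsageMetadata]:
  """Illustrative helper: derives candidates_token_count as total - prompt."""
  if usage is None:
    return usage
  total = usage.total_token_count
  prompt = usage.prompt_token_count
  if total is None or prompt is None:
    return usage
  if usage.candidates_token_count is not None:
    return usage
  calculated = total - prompt
  if calculated <= 0:
    return usage
  # The real _fix_usage_metadata also copies the remaining optional fields
  # (cached_content_token_count, thoughts_token_count, etc.); omitted here
  # for brevity.
  return types.GenerateContentResponseUsageMetadata(
      total_token_count=total,
      prompt_token_count=prompt,
      candidates_token_count=calculated,
  )


# Example: total=100, prompt=60 -> candidates_token_count is filled in as 40.
fixed = reconstruct_candidates_count(
    types.GenerateContentResponseUsageMetadata(
        total_token_count=100, prompt_token_count=60
    )
)
assert fixed is not None and fixed.candidates_token_count == 40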
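
Likewise, a hedged sketch of the per-event span pattern used by the live branch of trace_call_llm above. It assumes only opentelemetry-api, which telemetry.py already imports; the function name is illustrative and not part of the change:

from opentelemetry import trace

tracer = trace.get_tracer('gcp.vertex.agent')


def record_live_llm_event(
    event_id: str, input_tokens: int, output_tokens: int
) -> None:
  # One span per live event, named with the first 8 characters of the event
  # id, mirroring the span naming in trace_call_llm's live branch.
  span_name = f'llm_call_live_event [{event_id[:8]}]'
  with tracer.start_as_current_span(span_name) as span:
    span.set_attribute('gen_ai.system', 'gcp.vertex.agent')
    span.set_attribute('gen_ai.usage.input_tokens', input_tokens)
    span.set_attribute('gen_ai.usage.output_tokens', output_tokens)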