
Commit 2c7b365

Merge pull request #489 from raptorsun/lcore-411
LCORE-411: add token usage metrics
2 parents: 3c0c98b + 0c5e610

6 files changed: +128 −4 lines changed


src/app/endpoints/query.py

Lines changed: 10 additions & 1 deletion
@@ -24,6 +24,7 @@
 from configuration import configuration
 from app.database import get_session
 import metrics
+from metrics.utils import update_llm_token_count_from_turn
 import constants
 from authorization.middleware import authorize
 from models.config import Action
@@ -220,6 +221,7 @@ async def query_endpoint_handler(
             query_request,
             token,
             mcp_headers=mcp_headers,
+            provider_id=provider_id,
         )
         # Update metrics for the LLM call
         metrics.llm_calls_total.labels(provider_id, model_id).inc()
@@ -389,12 +391,14 @@ def is_input_shield(shield: Shield) -> bool:
     return _is_inout_shield(shield) or not is_output_shield(shield)
 
 
-async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branches
+async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branches,too-many-arguments
     client: AsyncLlamaStackClient,
     model_id: str,
     query_request: QueryRequest,
     token: str,
     mcp_headers: dict[str, dict[str, str]] | None = None,
+    *,
+    provider_id: str = "",
 ) -> tuple[TurnSummary, str]:
     """
     Retrieve response from LLMs and agents.
@@ -413,6 +417,7 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
 
     Parameters:
         model_id (str): The identifier of the LLM model to use.
+        provider_id (str): The identifier of the LLM provider to use.
         query_request (QueryRequest): The user's query and associated metadata.
         token (str): The authentication token for authorization.
         mcp_headers (dict[str, dict[str, str]], optional): Headers for multi-component processing.
@@ -512,6 +517,10 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
         tool_calls=[],
     )
 
+    # Update token count metrics for the LLM call
+    model_label = model_id.split("/", 1)[1] if "/" in model_id else model_id
+    update_llm_token_count_from_turn(response, model_label, provider_id, system_prompt)
+
     # Check for validation errors in the response
     steps = response.steps or []
     for step in steps:
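Note on the `model_label` computed above: `model_id` may arrive as a combined `provider/model` identifier, and only the part after the first `/` is used as the Prometheus model label. A minimal sketch of that split (the example identifiers below are illustrative, not taken from this commit):

```python
def to_model_label(model_id: str) -> str:
    """Same split as in retrieve_response: drop the provider prefix, if any."""
    return model_id.split("/", 1)[1] if "/" in model_id else model_id

# Illustrative inputs only:
assert to_model_label("openai/gpt-4o-mini") == "gpt-4o-mini"  # prefixed identifier
assert to_model_label("gpt-4o-mini") == "gpt-4o-mini"         # already bare
```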

src/app/endpoints/streaming_query.py

Lines changed: 8 additions & 0 deletions
@@ -26,6 +26,7 @@
 from client import AsyncLlamaStackClientHolder
 from configuration import configuration
 import metrics
+from metrics.utils import update_llm_token_count_from_turn
 from models.config import Action
 from models.requests import QueryRequest
 from models.database.conversations import UserConversation
@@ -623,6 +624,13 @@ async def response_generator(
                 summary.llm_response = interleaved_content_as_str(
                     p.turn.output_message.content
                 )
+                system_prompt = get_system_prompt(query_request, configuration)
+                try:
+                    update_llm_token_count_from_turn(
+                        p.turn, model_id, provider_id, system_prompt
+                    )
+                except Exception:  # pylint: disable=broad-except
+                    logger.exception("Failed to update token usage metrics")
             elif p.event_type == "step_complete":
                 if p.step_details.step_type == "tool_execution":
                     summary.append_tool_calls_from_llama(p.step_details)
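In the streaming path the token-count update is wrapped in a broad `try/except` so that a tokenizer or formatting error cannot break the SSE stream already being sent to the client. A minimal sketch of that best-effort pattern; the helper name `record_token_usage_safely` is hypothetical and not part of this commit:

```python
import logging

from metrics.utils import update_llm_token_count_from_turn

logger = logging.getLogger(__name__)


def record_token_usage_safely(turn, model_id, provider_id, system_prompt) -> None:
    """Hypothetical wrapper mirroring the inline try/except in response_generator."""
    try:
        update_llm_token_count_from_turn(turn, model_id, provider_id, system_prompt)
    except Exception:  # pylint: disable=broad-except
        # Metrics are best-effort: log and continue so the response stream is not interrupted.
        logger.exception("Failed to update token usage metrics")
```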

src/metrics/utils.py

Lines changed: 29 additions & 2 deletions
@@ -1,9 +1,16 @@
 """Utility functions for metrics handling."""
 
-from configuration import configuration
+from typing import cast
+
+from llama_stack.models.llama.datatypes import RawMessage
+from llama_stack.models.llama.llama3.chat_format import ChatFormat
+from llama_stack.models.llama.llama3.tokenizer import Tokenizer
+from llama_stack_client.types.agents.turn import Turn
+
+import metrics
 from client import AsyncLlamaStackClientHolder
+from configuration import configuration
 from log import get_logger
-import metrics
 from utils.common import run_once_async
 
 logger = get_logger(__name__)
@@ -48,3 +55,23 @@ async def setup_model_metrics() -> None:
         default_model_value,
     )
     logger.info("Model metrics setup complete")
+
+
+def update_llm_token_count_from_turn(
+    turn: Turn, model: str, provider: str, system_prompt: str = ""
+) -> None:
+    """Update the LLM calls metrics from a turn."""
+    tokenizer = Tokenizer.get_instance()
+    formatter = ChatFormat(tokenizer)
+
+    raw_message = cast(RawMessage, turn.output_message)
+    encoded_output = formatter.encode_dialog_prompt([raw_message])
+    token_count = len(encoded_output.tokens) if encoded_output.tokens else 0
+    metrics.llm_token_received_total.labels(provider, model).inc(token_count)
+
+    input_messages = [RawMessage(role="user", content=system_prompt)] + cast(
+        list[RawMessage], turn.input_messages
+    )
+    encoded_input = formatter.encode_dialog_prompt(input_messages)
+    token_count = len(encoded_input.tokens) if encoded_input.tokens else 0
+    metrics.llm_token_sent_total.labels(provider, model).inc(token_count)
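`update_llm_token_count_from_turn` increments `metrics.llm_token_sent_total` and `metrics.llm_token_received_total`, which are defined elsewhere in the `metrics` package and are not shown in this diff. A rough sketch of how such counters could be declared with `prometheus_client`, assuming the `(provider, model)` label order used above; the metric names and help strings here are placeholders:

```python
from prometheus_client import Counter

# Placeholder definitions; the real counters live in the metrics package and may
# use different names, help text, or a dedicated registry.
llm_token_sent_total = Counter(
    "ls_llm_token_sent_total",
    "Tokens sent to the LLM (system prompt plus input messages)",
    ["provider", "model"],
)
llm_token_received_total = Counter(
    "ls_llm_token_received_total",
    "Tokens received from the LLM (output message)",
    ["provider", "model"],
)
```

With counters of this shape, `labels(provider, model).inc(token_count)` adds the measured token count to the per-provider, per-model series.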

tests/unit/app/endpoints/test_query.py

Lines changed: 23 additions & 0 deletions
@@ -47,6 +47,14 @@ def dummy_request() -> Request:
     return req
 
 
+def mock_metrics(mocker):
+    """Helper function to mock metrics operations for query endpoints."""
+    mocker.patch(
+        "app.endpoints.query.update_llm_token_count_from_turn",
+        return_value=None,
+    )
+
+
 def mock_database_operations(mocker):
     """Helper function to mock database operations for query endpoints."""
     mocker.patch(
@@ -444,6 +452,7 @@ async def test_retrieve_response_no_returned_message(prepare_agent_mocks, mocker
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"
@@ -475,6 +484,7 @@ async def test_retrieve_response_message_without_content(prepare_agent_mocks, mo
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"
@@ -507,6 +517,7 @@ async def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"
@@ -545,6 +556,7 @@ async def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocke
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"
@@ -594,6 +606,7 @@ def __repr__(self):
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"
@@ -646,6 +659,7 @@ def __repr__(self):
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"
@@ -700,6 +714,7 @@ def __repr__(self):
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"
@@ -756,6 +771,7 @@ async def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?", attachments=attachments)
     model_id = "fake_model_id"
@@ -810,6 +826,7 @@ async def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocke
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?", attachments=attachments)
     model_id = "fake_model_id"
@@ -865,6 +882,7 @@ async def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker):
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"
@@ -934,6 +952,7 @@ async def test_retrieve_response_with_mcp_servers_empty_token(
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"
@@ -995,6 +1014,7 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers(
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"
@@ -1091,6 +1111,7 @@ async def test_retrieve_response_shield_violation(prepare_agent_mocks, mocker):
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
 
@@ -1327,6 +1348,7 @@ async def test_retrieve_response_no_tools_bypasses_mcp_and_rag(
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?", no_tools=True)
     model_id = "fake_model_id"
@@ -1377,6 +1399,7 @@ async def test_retrieve_response_no_tools_false_preserves_functionality(
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?", no_tools=False)
     model_id = "fake_model_id"

tests/unit/app/endpoints/test_streaming_query.py

Lines changed: 10 additions & 0 deletions
@@ -59,6 +59,14 @@ def mock_database_operations(mocker):
     mocker.patch("app.endpoints.streaming_query.persist_user_conversation_details")
 
 
+def mock_metrics(mocker):
+    """Helper function to mock metrics operations for streaming query endpoints."""
+    mocker.patch(
+        "app.endpoints.streaming_query.update_llm_token_count_from_turn",
+        return_value=None,
+    )
+
+
 SAMPLE_KNOWLEDGE_SEARCH_RESULTS = [
     """knowledge_search tool found 2 chunks:
 BEGIN of knowledge_search tool results.
@@ -347,12 +355,14 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False)
 @pytest.mark.asyncio
 async def test_streaming_query_endpoint_handler(mocker):
     """Test the streaming query endpoint handler with transcript storage disabled."""
+    mock_metrics(mocker)
     await _test_streaming_query_endpoint_handler(mocker, store_transcript=False)
 
 
 @pytest.mark.asyncio
 async def test_streaming_query_endpoint_handler_store_transcript(mocker):
     """Test the streaming query endpoint handler with transcript storage enabled."""
+    mock_metrics(mocker)
     await _test_streaming_query_endpoint_handler(mocker, store_transcript=True)
 

tests/unit/metrics/test_utis.py

Lines changed: 48 additions & 1 deletion
@@ -1,6 +1,6 @@
 """Unit tests for functions defined in metrics/utils.py"""
 
-from metrics.utils import setup_model_metrics
+from metrics.utils import setup_model_metrics, update_llm_token_count_from_turn
 
 
 async def test_setup_model_metrics(mocker):
@@ -74,3 +74,50 @@ async def test_setup_model_metrics(mocker):
         ],
         any_order=False,  # Order matters here
     )
+
+
+def test_update_llm_token_count_from_turn(mocker):
+    """Test the update_llm_token_count_from_turn function."""
+    mocker.patch("metrics.utils.Tokenizer.get_instance")
+    mock_formatter_class = mocker.patch("metrics.utils.ChatFormat")
+    mock_formatter = mocker.Mock()
+    mock_formatter_class.return_value = mock_formatter
+
+    mock_received_metric = mocker.patch(
+        "metrics.utils.metrics.llm_token_received_total"
+    )
+    mock_sent_metric = mocker.patch("metrics.utils.metrics.llm_token_sent_total")
+
+    mock_turn = mocker.Mock()
+    # turn.output_message should satisfy the type RawMessage
+    mock_turn.output_message = {"role": "assistant", "content": "test response"}
+    # turn.input_messages should satisfy the type list[RawMessage]
+    mock_turn.input_messages = [{"role": "user", "content": "test input"}]
+
+    # Mock the encoded results with tokens
+    mock_encoded_output = mocker.Mock()
+    mock_encoded_output.tokens = ["token1", "token2", "token3"]  # 3 tokens
+    mock_encoded_input = mocker.Mock()
+    mock_encoded_input.tokens = ["token1", "token2"]  # 2 tokens
+    mock_formatter.encode_dialog_prompt.side_effect = [
+        mock_encoded_output,
+        mock_encoded_input,
+    ]
+
+    test_model = "test_model"
+    test_provider = "test_provider"
+    test_system_prompt = "test system prompt"
+
+    update_llm_token_count_from_turn(
+        mock_turn, test_model, test_provider, test_system_prompt
+    )
+
+    # Verify that llm_token_received_total.labels() was called with correct metrics
+    mock_received_metric.labels.assert_called_once_with(test_provider, test_model)
+    mock_received_metric.labels().inc.assert_called_once_with(
+        3
+    )  # token count from output
+
+    # Verify that llm_token_sent_total.labels() was called with correct metrics
+    mock_sent_metric.labels.assert_called_once_with(test_provider, test_model)
+    mock_sent_metric.labels().inc.assert_called_once_with(2)  # token count from input
