
Commit ef97795

fix: enable token usage tracking for streaming LLM calls (#1264)
* fix: enable token usage tracking for streaming LLM calls

  Set stream_usage=True by default in model kwargs to ensure token usage metadata is included during streaming operations. This allows the LoggingCallbackHandler to properly track and report token statistics for streaming LLM calls. Without this parameter, streaming responses don't include usage_metadata, causing token usage tracking to fail during streaming operations and affecting accurate usage reporting and monitoring. Fixes token usage tracking when using streaming with LangChain chat models.

* feat(llmrails): enable stream_usage only for supported engines

* test: add tests for stream_usage and token tracking

* feat: add constant for stream usage supported llm engines

* test: add integration tests for streaming

  Add integration tests to verify token usage tracking with streaming and non-streaming LLMs, including multiple calls and unsupported providers. Update FakeLLM and TestChat to simulate stream_usage and token usage behavior for supported engines.

* always pass stream_usage when streaming

* chore(deps): bump langchain-openai to >=0.1.0
1 parent 0d3ddfc commit ef97795
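For context, a minimal sketch (not part of this commit) of the behavior the fix relies on: with langchain-openai, passing stream_usage=True makes the stream emit a final chunk that carries usage_metadata, which is what the LoggingCallbackHandler needs in order to report token statistics. The model name and prompt below are arbitrary placeholders, and an OpenAI API key is assumed to be configured.

# Minimal sketch, assuming langchain-openai is installed and an API key is set.
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini", stream_usage=True)

final = None
for chunk in llm.stream("Hello!"):
    # chunks merge via "+"; with stream_usage=True the last chunk carries
    # usage_metadata, so the merged message ends up with token counts
    final = chunk if final is None else final + chunk

print(final.usage_metadata)  # expected shape: {"input_tokens": ..., "output_tokens": ..., "total_tokens": ...}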

File tree

7 files changed: +699, -8 lines changed


nemoguardrails/rails/llm/llmrails.py

Lines changed: 5 additions & 0 deletions
@@ -367,6 +367,11 @@ def _prepare_model_kwargs(self, model_config):
         if api_key:
             kwargs["api_key"] = api_key
 
+        # enable streaming token usage when streaming is enabled
+        # providers that don't support this parameter will simply ignore it
+        if self.config.streaming:
+            kwargs["stream_usage"] = True
+
         return kwargs
 
     def _configure_main_llm_streaming(
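A hedged sketch of the downstream effect: the prepared kwargs are forwarded to the provider class by init_llm_model, whose internals are not part of this diff, so for an OpenAI main model the new branch is roughly equivalent to constructing the chat model as below. The api_key value is a placeholder, and per the commit message, providers that do not recognize stream_usage are expected to ignore it.

# Rough equivalent of what the new branch produces for an OpenAI main model
# when streaming is enabled; an illustration, not the actual init_llm_model code.
from langchain_openai import ChatOpenAI

kwargs = {"api_key": "sk-placeholder", "stream_usage": True}  # as built by _prepare_model_kwargs
llm = ChatOpenAI(model="gpt-4", **kwargs)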

poetry.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ opentelemetry-sdk = { version = ">=1.27.0,<2.0.0", optional = true }
 aiofiles = { version = ">=24.1.0", optional = true }
 
 # openai
-langchain-openai = { version = ">=0.0.5", optional = true }
+langchain-openai = { version = ">=0.1.0", optional = true }
 
 # eval
 tqdm = { version = ">=4.65,<5.0", optional = true }
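A quick way to confirm an existing environment satisfies the raised constraint (a hedged sketch; it assumes the optional langchain-openai dependency is already installed and that the packaging library is available):

# Hedged check: verify the installed langchain-openai meets the new ">=0.1.0" floor.
from importlib.metadata import version
from packaging.version import Version

assert Version(version("langchain-openai")) >= Version("0.1.0"), "upgrade langchain-openai"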

tests/test_callbacks.py

Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from uuid import uuid4

import pytest
from langchain.schema import Generation, LLMResult
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration

from nemoguardrails.context import explain_info_var, llm_call_info_var, llm_stats_var
from nemoguardrails.logging.callbacks import LoggingCallbackHandler
from nemoguardrails.logging.explain import ExplainInfo, LLMCallInfo
from nemoguardrails.logging.stats import LLMStats


@pytest.mark.asyncio
async def test_token_usage_tracking_with_usage_metadata():
    """Test that token usage is tracked when usage_metadata is available (stream_usage=True scenario)."""

    llm_call_info = LLMCallInfo()
    llm_call_info_var.set(llm_call_info)

    llm_stats = LLMStats()
    llm_stats_var.set(llm_stats)

    explain_info = ExplainInfo()
    explain_info_var.set(explain_info)

    handler = LoggingCallbackHandler()

    # simulate the LLM response with usage metadata (as would happen with stream_usage=True)
    ai_message = AIMessage(
        content="Hello! How can I help you?",
        usage_metadata={"input_tokens": 10, "output_tokens": 6, "total_tokens": 16},
    )

    chat_generation = ChatGeneration(message=ai_message)
    llm_result = LLMResult(generations=[[chat_generation]])

    # call the on_llm_end method
    await handler.on_llm_end(llm_result, run_id=uuid4())

    assert llm_call_info.total_tokens == 16
    assert llm_call_info.prompt_tokens == 10
    assert llm_call_info.completion_tokens == 6

    assert llm_stats.get_stat("total_tokens") == 16
    assert llm_stats.get_stat("total_prompt_tokens") == 10
    assert llm_stats.get_stat("total_completion_tokens") == 6


@pytest.mark.asyncio
async def test_token_usage_tracking_with_llm_output_fallback():
    """Test token usage tracking with legacy llm_output format."""

    llm_call_info = LLMCallInfo()
    llm_call_info_var.set(llm_call_info)

    llm_stats = LLMStats()
    llm_stats_var.set(llm_stats)

    explain_info = ExplainInfo()
    explain_info_var.set(explain_info)

    handler = LoggingCallbackHandler()

    # simulate LLM response with token usage in llm_output (fallback scenario)
    generation = Generation(text="Fallback response")
    llm_result = LLMResult(
        generations=[[generation]],
        llm_output={
            "token_usage": {
                "total_tokens": 20,
                "prompt_tokens": 12,
                "completion_tokens": 8,
            }
        },
    )

    await handler.on_llm_end(llm_result, run_id=uuid4())

    assert llm_call_info.total_tokens == 20
    assert llm_call_info.prompt_tokens == 12
    assert llm_call_info.completion_tokens == 8

    assert llm_stats.get_stat("total_tokens") == 20
    assert llm_stats.get_stat("total_prompt_tokens") == 12
    assert llm_stats.get_stat("total_completion_tokens") == 8


@pytest.mark.asyncio
async def test_no_token_usage_tracking_without_metadata():
    """Test that no token usage is tracked when metadata is not available."""

    llm_call_info = LLMCallInfo()
    llm_call_info_var.set(llm_call_info)

    llm_stats = LLMStats()
    llm_stats_var.set(llm_stats)

    explain_info = ExplainInfo()
    explain_info_var.set(explain_info)

    handler = LoggingCallbackHandler()

    # simulate LLM response without usage metadata (stream_usage=False scenario)
    ai_message = AIMessage(content="Hello! How can I help you?")
    chat_generation = ChatGeneration(message=ai_message)
    llm_result = LLMResult(generations=[[chat_generation]])

    await handler.on_llm_end(llm_result, run_id=uuid4())

    assert llm_call_info.total_tokens is None or llm_call_info.total_tokens == 0
    assert llm_call_info.prompt_tokens is None or llm_call_info.prompt_tokens == 0
    assert (
        llm_call_info.completion_tokens is None or llm_call_info.completion_tokens == 0
    )


@pytest.mark.asyncio
async def test_multiple_generations_token_accumulation():
    """Test that token usage accumulates across multiple generations."""

    llm_call_info = LLMCallInfo()
    llm_call_info_var.set(llm_call_info)

    llm_stats = LLMStats()
    llm_stats_var.set(llm_stats)

    explain_info = ExplainInfo()
    explain_info_var.set(explain_info)

    handler = LoggingCallbackHandler()

    ai_message1 = AIMessage(
        content="First response",
        usage_metadata={"input_tokens": 5, "output_tokens": 3, "total_tokens": 8},
    )

    ai_message2 = AIMessage(
        content="Second response",
        usage_metadata={"input_tokens": 7, "output_tokens": 4, "total_tokens": 11},
    )

    chat_generation1 = ChatGeneration(message=ai_message1)
    chat_generation2 = ChatGeneration(message=ai_message2)
    llm_result = LLMResult(generations=[[chat_generation1, chat_generation2]])

    await handler.on_llm_end(llm_result, run_id=uuid4())

    assert llm_call_info.total_tokens == 19  # 8 + 11
    assert llm_call_info.prompt_tokens == 12  # 5 + 7
    assert llm_call_info.completion_tokens == 7  # 3 + 4

    assert llm_stats.get_stat("total_tokens") == 19
    assert llm_stats.get_stat("total_prompt_tokens") == 12
    assert llm_stats.get_stat("total_completion_tokens") == 7
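These tests exercise two paths in LoggingCallbackHandler: the modern usage_metadata attached to AIMessage and the legacy token_usage dict in llm_output. As an illustration only (the handler's real implementation is not shown in this commit), the aggregation the tests assert could look roughly like this sketch:

# Illustrative aggregation sketch, mirroring what the tests above assert;
# an assumption, not the actual LoggingCallbackHandler code.
from langchain_core.outputs import LLMResult


def summed_usage(result: LLMResult) -> dict:
    totals = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
    for generation in (g for gens in result.generations for g in gens):
        # chat generations carry an AIMessage with optional usage_metadata
        message = getattr(generation, "message", None)
        usage = getattr(message, "usage_metadata", None) or {}
        totals["prompt_tokens"] += usage.get("input_tokens", 0)
        totals["completion_tokens"] += usage.get("output_tokens", 0)
        totals["total_tokens"] += usage.get("total_tokens", 0)
    if not any(totals.values()) and result.llm_output:
        # legacy fallback: some providers report usage via llm_output
        totals.update(result.llm_output.get("token_usage", {}))
    return totals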

tests/test_llmrails.py

Lines changed: 87 additions & 0 deletions
@@ -1068,3 +1068,90 @@ def __init__(self):
 
     assert kwargs["api_key"] == "direct-key"
     assert kwargs["temperature"] == 0.3
+
+
+@pytest.mark.asyncio
+@patch("nemoguardrails.rails.llm.llmrails.init_llm_model")
+async def test_stream_usage_enabled_for_streaming_supported_providers(
+    mock_init_llm_model,
+):
+    """Test that stream_usage=True is set when streaming is enabled for supported providers."""
+    config = RailsConfig.from_content(
+        config={
+            "models": [
+                {
+                    "type": "main",
+                    "engine": "openai",
+                    "model": "gpt-4",
+                }
+            ],
+            "streaming": True,
+        }
+    )
+
+    LLMRails(config=config)
+
+    mock_init_llm_model.assert_called_once()
+    call_args = mock_init_llm_model.call_args
+    kwargs = call_args.kwargs.get("kwargs", {})
+
+    assert kwargs.get("stream_usage") is True
+
+
+@pytest.mark.asyncio
+@patch("nemoguardrails.rails.llm.llmrails.init_llm_model")
+async def test_stream_usage_not_set_without_streaming(mock_init_llm_model):
+    """Test that stream_usage is not set when streaming is disabled."""
+    config = RailsConfig.from_content(
+        config={
+            "models": [
+                {
+                    "type": "main",
+                    "engine": "openai",
+                    "model": "gpt-4",
+                }
+            ],
+            "streaming": False,
+        }
+    )
+
+    LLMRails(config=config)
+
+    mock_init_llm_model.assert_called_once()
+    call_args = mock_init_llm_model.call_args
+    kwargs = call_args.kwargs.get("kwargs", {})
+
+    assert "stream_usage" not in kwargs
+
+
+@pytest.mark.asyncio
+@patch("nemoguardrails.rails.llm.llmrails.init_llm_model")
+async def test_stream_usage_enabled_for_all_providers_when_streaming(
+    mock_init_llm_model,
+):
+    """Test that stream_usage is passed to ALL providers when streaming is enabled.
+
+    With the new design, stream_usage=True is passed to ALL providers when
+    streaming is enabled. Providers that don't support it will simply ignore it.
+    """
+    config = RailsConfig.from_content(
+        config={
+            "models": [
+                {
+                    "type": "main",
+                    "engine": "unsupported",
+                    "model": "whatever",
+                }
+            ],
+            "streaming": True,
+        }
+    )
+
+    LLMRails(config=config)
+
+    mock_init_llm_model.assert_called_once()
+    call_args = mock_init_llm_model.call_args
+    kwargs = call_args.kwargs.get("kwargs", {})
+
+    # stream_usage should be set for all providers when streaming is enabled
+    assert kwargs.get("stream_usage") is True
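From the user's side, the end result of the change is that token stats become available for streaming configurations. A hedged end-to-end sketch follows; it assumes an OpenAI API key is configured, and the explain fields mirror the attribute names asserted in tests/test_callbacks.py above rather than anything introduced by this commit.

# End-to-end sketch (an assumption, not part of this commit): with streaming
# enabled in the config, per-call token stats should now be populated.
from nemoguardrails import LLMRails, RailsConfig

config = RailsConfig.from_content(
    config={
        "models": [{"type": "main", "engine": "openai", "model": "gpt-4"}],
        "streaming": True,
    }
)
rails = LLMRails(config)

rails.generate(messages=[{"role": "user", "content": "Hi!"}])
info = rails.explain()
for call in info.llm_calls:
    print(call.total_tokens, call.prompt_tokens, call.completion_tokens)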
