Skip to content

Commit 54c06e0

Browse files
Pouyanpitgasser-nv
authored andcommitted
feat(bot-thinking): add reasoning trace extraction from llm calls (#1431)
* feat: add reasoning trace extraction from llm calls * test(reasoning-trace): add comprehensive tests for additional_kwargs extraction
1 parent a0af002 commit 54c06e0

File tree

3 files changed

+328
-12
lines changed

3 files changed

+328
-12
lines changed

nemoguardrails/actions/llm/utils.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ async def llm_call(
110110
generation_llm, prompt, all_callbacks
111111
)
112112

113+
_store_reasoning_traces(response)
113114
_store_tool_calls(response)
114115
_store_response_metadata(response)
115116
return _extract_content(response)
@@ -172,6 +173,18 @@ def _convert_messages_to_langchain_format(prompt: List[dict]) -> List:
172173
return dicts_to_messages(prompt)
173174

174175

176+
def _store_reasoning_traces(response) -> None:
177+
if hasattr(response, "additional_kwargs"):
178+
additional_kwargs = response.additional_kwargs
179+
if (
180+
isinstance(additional_kwargs, dict)
181+
and "reasoning_content" in additional_kwargs
182+
):
183+
reasoning_content = additional_kwargs["reasoning_content"]
184+
if reasoning_content:
185+
reasoning_trace_var.set(reasoning_content)
186+
187+
175188
def _store_tool_calls(response) -> None:
176189
"""Extract and store tool calls from response in context."""
177190
tool_calls = getattr(response, "tool_calls", None)
@@ -192,15 +205,6 @@ def _store_response_metadata(response) -> None:
192205
metadata[field_name] = getattr(response, field_name)
193206
llm_response_metadata_var.set(metadata)
194207

195-
if hasattr(response, "additional_kwargs"):
196-
additional_kwargs = response.additional_kwargs
197-
if (
198-
isinstance(additional_kwargs, dict)
199-
and "reasoning_content" in additional_kwargs
200-
):
201-
reasoning_content = additional_kwargs["reasoning_content"]
202-
if reasoning_content:
203-
reasoning_trace_var.set(reasoning_content)
204208
else:
205209
llm_response_metadata_var.set(None)
206210

@@ -704,6 +708,12 @@ def extract_tool_calls_from_events(events: list) -> Optional[list]:
704708
return None
705709

706710

711+
def extract_bot_thinking_from_events(events: list):
712+
for event in events:
713+
if event.get("type") == "BotThinking":
714+
return event.get("content")
715+
716+
707717
def get_and_clear_response_metadata_contextvar() -> Optional[dict]:
708718
"""Get the current response metadata and clear it from the context.
709719

nemoguardrails/rails/llm/llmrails.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@
4343

4444
from nemoguardrails.actions.llm.generation import LLMGenerationActions
4545
from nemoguardrails.actions.llm.utils import (
46+
extract_bot_thinking_from_events,
4647
extract_tool_calls_from_events,
47-
get_and_clear_reasoning_trace_contextvar,
4848
get_and_clear_response_metadata_contextvar,
4949
get_colang_history,
5050
)
@@ -1037,7 +1037,7 @@ async def generate_async(
10371037
else:
10381038
res = GenerationResponse(response=[new_message])
10391039

1040-
if reasoning_trace := get_and_clear_reasoning_trace_contextvar():
1040+
if reasoning_trace := extract_bot_thinking_from_events(events):
10411041
if prompt:
10421042
# For prompt mode, response should be a string
10431043
if isinstance(res.response, str):
@@ -1182,7 +1182,7 @@ async def generate_async(
11821182
else:
11831183
# If a prompt is used, we only return the content of the message.
11841184

1185-
if reasoning_trace := get_and_clear_reasoning_trace_contextvar():
1185+
if reasoning_trace := extract_bot_thinking_from_events(events):
11861186
new_message["content"] = reasoning_trace + new_message["content"]
11871187

11881188
if prompt:
Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from unittest.mock import AsyncMock
17+
18+
import pytest
19+
from langchain_core.messages import AIMessage
20+
21+
from nemoguardrails.actions.llm.utils import _store_reasoning_traces
22+
from nemoguardrails.context import reasoning_trace_var
23+
24+
25+
class TestStoreReasoningTracesUnit:
26+
def test_store_reasoning_traces_with_valid_reasoning_content(self):
27+
test_reasoning = "Step 1: Analyze the question\nStep 2: Formulate response"
28+
29+
response = AIMessage(
30+
content="The answer is 42",
31+
additional_kwargs={"reasoning_content": test_reasoning},
32+
)
33+
34+
_store_reasoning_traces(response)
35+
36+
stored_trace = reasoning_trace_var.get()
37+
assert stored_trace == test_reasoning
38+
39+
reasoning_trace_var.set(None)
40+
41+
def test_store_reasoning_traces_with_empty_reasoning_content(self):
42+
response = AIMessage(
43+
content="Response", additional_kwargs={"reasoning_content": ""}
44+
)
45+
46+
reasoning_trace_var.set(None)
47+
_store_reasoning_traces(response)
48+
49+
stored_trace = reasoning_trace_var.get()
50+
assert stored_trace is None
51+
52+
reasoning_trace_var.set(None)
53+
54+
def test_store_reasoning_traces_with_none_reasoning_content(self):
55+
response = AIMessage(
56+
content="Response", additional_kwargs={"reasoning_content": None}
57+
)
58+
59+
reasoning_trace_var.set(None)
60+
_store_reasoning_traces(response)
61+
62+
stored_trace = reasoning_trace_var.get()
63+
assert stored_trace is None
64+
65+
reasoning_trace_var.set(None)
66+
67+
def test_store_reasoning_traces_without_reasoning_content_key(self):
68+
response = AIMessage(
69+
content="Response", additional_kwargs={"other_key": "other_value"}
70+
)
71+
72+
reasoning_trace_var.set(None)
73+
_store_reasoning_traces(response)
74+
75+
stored_trace = reasoning_trace_var.get()
76+
assert stored_trace is None
77+
78+
reasoning_trace_var.set(None)
79+
80+
def test_store_reasoning_traces_with_empty_additional_kwargs(self):
81+
response = AIMessage(content="Response", additional_kwargs={})
82+
83+
reasoning_trace_var.set(None)
84+
_store_reasoning_traces(response)
85+
86+
stored_trace = reasoning_trace_var.get()
87+
assert stored_trace is None
88+
89+
reasoning_trace_var.set(None)
90+
91+
def test_store_reasoning_traces_without_additional_kwargs_attribute(self):
92+
class SimpleResponse:
93+
def __init__(self, content):
94+
self.content = content
95+
96+
response = SimpleResponse("Response")
97+
98+
reasoning_trace_var.set(None)
99+
_store_reasoning_traces(response)
100+
101+
stored_trace = reasoning_trace_var.get()
102+
assert stored_trace is None
103+
104+
reasoning_trace_var.set(None)
105+
106+
def test_store_reasoning_traces_with_non_dict_additional_kwargs(self):
107+
class ResponseWithInvalidKwargs:
108+
def __init__(self):
109+
self.content = "Response"
110+
self.additional_kwargs = "not_a_dict"
111+
112+
response = ResponseWithInvalidKwargs()
113+
114+
reasoning_trace_var.set(None)
115+
_store_reasoning_traces(response)
116+
117+
stored_trace = reasoning_trace_var.get()
118+
assert stored_trace is None
119+
120+
reasoning_trace_var.set(None)
121+
122+
def test_store_reasoning_traces_overwrites_previous_trace(self):
123+
initial_trace = "Initial reasoning"
124+
new_trace = "New reasoning"
125+
126+
reasoning_trace_var.set(initial_trace)
127+
128+
response = AIMessage(
129+
content="Response", additional_kwargs={"reasoning_content": new_trace}
130+
)
131+
132+
_store_reasoning_traces(response)
133+
134+
stored_trace = reasoning_trace_var.get()
135+
assert stored_trace == new_trace
136+
assert stored_trace != initial_trace
137+
138+
reasoning_trace_var.set(None)
139+
140+
def test_store_reasoning_traces_with_multiline_content(self):
141+
multiline_reasoning = """Thought process:
142+
1. First, understand the user's intent
143+
2. Second, check available data
144+
3. Third, formulate a response
145+
4. Finally, validate the response"""
146+
147+
response = AIMessage(
148+
content="Response",
149+
additional_kwargs={"reasoning_content": multiline_reasoning},
150+
)
151+
152+
_store_reasoning_traces(response)
153+
154+
stored_trace = reasoning_trace_var.get()
155+
assert stored_trace == multiline_reasoning
156+
157+
reasoning_trace_var.set(None)
158+
159+
def test_store_reasoning_traces_with_special_characters(self):
160+
special_reasoning = "Thinking: Let's analyze this <step> with \"quotes\" and 'apostrophes' & symbols!"
161+
162+
response = AIMessage(
163+
content="Response",
164+
additional_kwargs={"reasoning_content": special_reasoning},
165+
)
166+
167+
_store_reasoning_traces(response)
168+
169+
stored_trace = reasoning_trace_var.get()
170+
assert stored_trace == special_reasoning
171+
172+
reasoning_trace_var.set(None)
173+
174+
175+
class TestReasoningTraceIntegration:
176+
@pytest.mark.asyncio
177+
async def test_llm_call_extracts_reasoning_from_additional_kwargs(self):
178+
test_reasoning = "Let me think about this carefully..."
179+
180+
mock_llm = AsyncMock()
181+
mock_response = AIMessage(
182+
content="The answer is 42",
183+
additional_kwargs={"reasoning_content": test_reasoning},
184+
)
185+
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
186+
187+
from nemoguardrails.actions.llm.utils import llm_call
188+
189+
reasoning_trace_var.set(None)
190+
result = await llm_call(mock_llm, "What is the answer?")
191+
192+
assert result == "The answer is 42"
193+
stored_trace = reasoning_trace_var.get()
194+
assert stored_trace == test_reasoning
195+
196+
reasoning_trace_var.set(None)
197+
198+
@pytest.mark.asyncio
199+
async def test_llm_call_handles_missing_reasoning_content(self):
200+
mock_llm = AsyncMock()
201+
mock_response = AIMessage(content="Regular response", additional_kwargs={})
202+
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
203+
204+
from nemoguardrails.actions.llm.utils import llm_call
205+
206+
reasoning_trace_var.set(None)
207+
result = await llm_call(mock_llm, "Hello")
208+
209+
assert result == "Regular response"
210+
stored_trace = reasoning_trace_var.get()
211+
assert stored_trace is None
212+
213+
reasoning_trace_var.set(None)
214+
215+
@pytest.mark.asyncio
216+
async def test_llm_call_with_message_list_extracts_reasoning(self):
217+
test_reasoning = "Analyzing the conversation context..."
218+
219+
mock_llm = AsyncMock()
220+
mock_response = AIMessage(
221+
content="Here's my response",
222+
additional_kwargs={"reasoning_content": test_reasoning},
223+
)
224+
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
225+
226+
from nemoguardrails.actions.llm.utils import llm_call
227+
228+
messages = [
229+
{"role": "user", "content": "Hello"},
230+
{"role": "assistant", "content": "Hi there"},
231+
]
232+
233+
reasoning_trace_var.set(None)
234+
result = await llm_call(mock_llm, messages)
235+
236+
assert result == "Here's my response"
237+
stored_trace = reasoning_trace_var.get()
238+
assert stored_trace == test_reasoning
239+
240+
reasoning_trace_var.set(None)
241+
242+
@pytest.mark.asyncio
243+
async def test_multiple_llm_calls_preserve_separate_reasoning_traces(self):
244+
first_reasoning = "First analysis"
245+
second_reasoning = "Second analysis"
246+
247+
mock_llm = AsyncMock()
248+
call_count = 0
249+
250+
async def mock_ainvoke(*args, **kwargs):
251+
nonlocal call_count
252+
call_count += 1
253+
if call_count == 1:
254+
return AIMessage(
255+
content="First response",
256+
additional_kwargs={"reasoning_content": first_reasoning},
257+
)
258+
else:
259+
return AIMessage(
260+
content="Second response",
261+
additional_kwargs={"reasoning_content": second_reasoning},
262+
)
263+
264+
mock_llm.ainvoke = mock_ainvoke
265+
266+
from nemoguardrails.actions.llm.utils import llm_call
267+
268+
reasoning_trace_var.set(None)
269+
result1 = await llm_call(mock_llm, "First query")
270+
trace1 = reasoning_trace_var.get()
271+
272+
reasoning_trace_var.set(None)
273+
result2 = await llm_call(mock_llm, "Second query")
274+
trace2 = reasoning_trace_var.get()
275+
276+
assert trace1 == first_reasoning
277+
assert trace2 == second_reasoning
278+
279+
reasoning_trace_var.set(None)
280+
281+
@pytest.mark.asyncio
282+
async def test_reasoning_content_with_other_additional_kwargs(self):
283+
test_reasoning = "Complex reasoning process"
284+
285+
mock_llm = AsyncMock()
286+
mock_response = AIMessage(
287+
content="Response",
288+
additional_kwargs={
289+
"reasoning_content": test_reasoning,
290+
"model": "test-model",
291+
"finish_reason": "stop",
292+
"other_metadata": {"key": "value"},
293+
},
294+
)
295+
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
296+
297+
from nemoguardrails.actions.llm.utils import llm_call
298+
299+
reasoning_trace_var.set(None)
300+
result = await llm_call(mock_llm, "Query")
301+
302+
assert result == "Response"
303+
stored_trace = reasoning_trace_var.get()
304+
assert stored_trace == test_reasoning
305+
306+
reasoning_trace_var.set(None)

0 commit comments

Comments
 (0)