Skip to content

Commit 80e3466

Browse files
committed
test(reasoning-trace): add comprehensive tests for additional_kwargs extraction
1 parent 8271959 commit 80e3466

File tree

1 file changed

+365
-0
lines changed

1 file changed

+365
-0
lines changed
Lines changed: 365 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,365 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from unittest.mock import AsyncMock, patch
17+
18+
import pytest
19+
from langchain_core.messages import AIMessage
20+
21+
from nemoguardrails import RailsConfig
22+
from nemoguardrails.actions.llm.utils import _store_reasoning_traces
23+
from nemoguardrails.context import reasoning_trace_var
24+
from tests.utils import TestChat
25+
26+
27+
class TestStoreReasoningTracesUnit:
28+
def test_store_reasoning_traces_with_valid_reasoning_content(self):
29+
test_reasoning = "Step 1: Analyze the question\nStep 2: Formulate response"
30+
31+
response = AIMessage(
32+
content="The answer is 42",
33+
additional_kwargs={"reasoning_content": test_reasoning},
34+
)
35+
36+
_store_reasoning_traces(response)
37+
38+
stored_trace = reasoning_trace_var.get()
39+
assert stored_trace == test_reasoning
40+
41+
reasoning_trace_var.set(None)
42+
43+
def test_store_reasoning_traces_with_empty_reasoning_content(self):
44+
response = AIMessage(
45+
content="Response", additional_kwargs={"reasoning_content": ""}
46+
)
47+
48+
reasoning_trace_var.set(None)
49+
_store_reasoning_traces(response)
50+
51+
stored_trace = reasoning_trace_var.get()
52+
assert stored_trace is None
53+
54+
reasoning_trace_var.set(None)
55+
56+
def test_store_reasoning_traces_with_none_reasoning_content(self):
57+
response = AIMessage(
58+
content="Response", additional_kwargs={"reasoning_content": None}
59+
)
60+
61+
reasoning_trace_var.set(None)
62+
_store_reasoning_traces(response)
63+
64+
stored_trace = reasoning_trace_var.get()
65+
assert stored_trace is None
66+
67+
reasoning_trace_var.set(None)
68+
69+
def test_store_reasoning_traces_without_reasoning_content_key(self):
70+
response = AIMessage(
71+
content="Response", additional_kwargs={"other_key": "other_value"}
72+
)
73+
74+
reasoning_trace_var.set(None)
75+
_store_reasoning_traces(response)
76+
77+
stored_trace = reasoning_trace_var.get()
78+
assert stored_trace is None
79+
80+
reasoning_trace_var.set(None)
81+
82+
def test_store_reasoning_traces_with_empty_additional_kwargs(self):
83+
response = AIMessage(content="Response", additional_kwargs={})
84+
85+
reasoning_trace_var.set(None)
86+
_store_reasoning_traces(response)
87+
88+
stored_trace = reasoning_trace_var.get()
89+
assert stored_trace is None
90+
91+
reasoning_trace_var.set(None)
92+
93+
def test_store_reasoning_traces_without_additional_kwargs_attribute(self):
94+
class SimpleResponse:
95+
def __init__(self, content):
96+
self.content = content
97+
98+
response = SimpleResponse("Response")
99+
100+
reasoning_trace_var.set(None)
101+
_store_reasoning_traces(response)
102+
103+
stored_trace = reasoning_trace_var.get()
104+
assert stored_trace is None
105+
106+
reasoning_trace_var.set(None)
107+
108+
def test_store_reasoning_traces_with_non_dict_additional_kwargs(self):
109+
class ResponseWithInvalidKwargs:
110+
def __init__(self):
111+
self.content = "Response"
112+
self.additional_kwargs = "not_a_dict"
113+
114+
response = ResponseWithInvalidKwargs()
115+
116+
reasoning_trace_var.set(None)
117+
_store_reasoning_traces(response)
118+
119+
stored_trace = reasoning_trace_var.get()
120+
assert stored_trace is None
121+
122+
reasoning_trace_var.set(None)
123+
124+
def test_store_reasoning_traces_overwrites_previous_trace(self):
125+
initial_trace = "Initial reasoning"
126+
new_trace = "New reasoning"
127+
128+
reasoning_trace_var.set(initial_trace)
129+
130+
response = AIMessage(
131+
content="Response", additional_kwargs={"reasoning_content": new_trace}
132+
)
133+
134+
_store_reasoning_traces(response)
135+
136+
stored_trace = reasoning_trace_var.get()
137+
assert stored_trace == new_trace
138+
assert stored_trace != initial_trace
139+
140+
reasoning_trace_var.set(None)
141+
142+
def test_store_reasoning_traces_with_multiline_content(self):
143+
multiline_reasoning = """Thought process:
144+
1. First, understand the user's intent
145+
2. Second, check available data
146+
3. Third, formulate a response
147+
4. Finally, validate the response"""
148+
149+
response = AIMessage(
150+
content="Response",
151+
additional_kwargs={"reasoning_content": multiline_reasoning},
152+
)
153+
154+
_store_reasoning_traces(response)
155+
156+
stored_trace = reasoning_trace_var.get()
157+
assert stored_trace == multiline_reasoning
158+
159+
reasoning_trace_var.set(None)
160+
161+
def test_store_reasoning_traces_with_special_characters(self):
162+
special_reasoning = "Thinking: Let's analyze this <step> with \"quotes\" and 'apostrophes' & symbols!"
163+
164+
response = AIMessage(
165+
content="Response",
166+
additional_kwargs={"reasoning_content": special_reasoning},
167+
)
168+
169+
_store_reasoning_traces(response)
170+
171+
stored_trace = reasoning_trace_var.get()
172+
assert stored_trace == special_reasoning
173+
174+
reasoning_trace_var.set(None)
175+
176+
177+
class TestReasoningTraceIntegration:
178+
@pytest.mark.asyncio
179+
async def test_llm_call_extracts_reasoning_from_additional_kwargs(self):
180+
test_reasoning = "Let me think about this carefully..."
181+
182+
mock_llm = AsyncMock()
183+
mock_response = AIMessage(
184+
content="The answer is 42",
185+
additional_kwargs={"reasoning_content": test_reasoning},
186+
)
187+
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
188+
189+
from nemoguardrails.actions.llm.utils import llm_call
190+
191+
reasoning_trace_var.set(None)
192+
result = await llm_call(mock_llm, "What is the answer?")
193+
194+
assert result == "The answer is 42"
195+
stored_trace = reasoning_trace_var.get()
196+
assert stored_trace == test_reasoning
197+
198+
reasoning_trace_var.set(None)
199+
200+
@pytest.mark.asyncio
201+
async def test_llm_call_handles_missing_reasoning_content(self):
202+
mock_llm = AsyncMock()
203+
mock_response = AIMessage(content="Regular response", additional_kwargs={})
204+
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
205+
206+
from nemoguardrails.actions.llm.utils import llm_call
207+
208+
reasoning_trace_var.set(None)
209+
result = await llm_call(mock_llm, "Hello")
210+
211+
assert result == "Regular response"
212+
stored_trace = reasoning_trace_var.get()
213+
assert stored_trace is None
214+
215+
reasoning_trace_var.set(None)
216+
217+
@pytest.mark.asyncio
218+
async def test_llm_call_with_message_list_extracts_reasoning(self):
219+
test_reasoning = "Analyzing the conversation context..."
220+
221+
mock_llm = AsyncMock()
222+
mock_response = AIMessage(
223+
content="Here's my response",
224+
additional_kwargs={"reasoning_content": test_reasoning},
225+
)
226+
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
227+
228+
from nemoguardrails.actions.llm.utils import llm_call
229+
230+
messages = [
231+
{"role": "user", "content": "Hello"},
232+
{"role": "assistant", "content": "Hi there"},
233+
]
234+
235+
reasoning_trace_var.set(None)
236+
result = await llm_call(mock_llm, messages)
237+
238+
assert result == "Here's my response"
239+
stored_trace = reasoning_trace_var.get()
240+
assert stored_trace == test_reasoning
241+
242+
reasoning_trace_var.set(None)
243+
244+
@pytest.mark.asyncio
245+
async def test_bot_thinking_event_creation_from_real_llm_response(self):
246+
test_reasoning = "Carefully considering the implications..."
247+
248+
with patch("langchain_core.language_models.llms.LLM.ainvoke") as mock_ainvoke:
249+
mock_ainvoke.return_value = AIMessage(
250+
content="The answer is 42",
251+
additional_kwargs={"reasoning_content": test_reasoning},
252+
)
253+
254+
config = RailsConfig.from_content(
255+
config={"models": [], "passthrough": True}
256+
)
257+
chat = TestChat(config, llm_completions=["The answer is 42"])
258+
259+
events = await chat.app.generate_events_async(
260+
[{"type": "UserMessage", "text": "What is the answer?"}]
261+
)
262+
263+
bot_thinking_events = [e for e in events if e["type"] == "BotThinking"]
264+
assert len(bot_thinking_events) == 1
265+
assert bot_thinking_events[0]["content"] == test_reasoning
266+
267+
@pytest.mark.asyncio
268+
async def test_bot_thinking_preserved_across_rails_with_real_response(self):
269+
test_reasoning = "Step-by-step analysis of the request"
270+
271+
with patch("langchain_core.language_models.llms.LLM.ainvoke") as mock_ainvoke:
272+
mock_ainvoke.return_value = AIMessage(
273+
content="Response text",
274+
additional_kwargs={"reasoning_content": test_reasoning},
275+
)
276+
277+
config = RailsConfig.from_content(
278+
colang_content="""
279+
define flow capture_reasoning
280+
$captured = $bot_thinking
281+
""",
282+
yaml_content="""
283+
models: []
284+
passthrough: true
285+
rails:
286+
output:
287+
flows:
288+
- capture_reasoning
289+
""",
290+
)
291+
292+
chat = TestChat(config, llm_completions=["Response text"])
293+
294+
result = await chat.app.generate_async(
295+
messages=[{"role": "user", "content": "test"}],
296+
options={"output_vars": True},
297+
)
298+
299+
assert result.output_data["captured"] == test_reasoning
300+
301+
@pytest.mark.asyncio
302+
async def test_multiple_llm_calls_preserve_separate_reasoning_traces(self):
303+
first_reasoning = "First analysis"
304+
second_reasoning = "Second analysis"
305+
306+
mock_llm = AsyncMock()
307+
call_count = 0
308+
309+
async def mock_ainvoke(*args, **kwargs):
310+
nonlocal call_count
311+
call_count += 1
312+
if call_count == 1:
313+
return AIMessage(
314+
content="First response",
315+
additional_kwargs={"reasoning_content": first_reasoning},
316+
)
317+
else:
318+
return AIMessage(
319+
content="Second response",
320+
additional_kwargs={"reasoning_content": second_reasoning},
321+
)
322+
323+
mock_llm.ainvoke = mock_ainvoke
324+
325+
from nemoguardrails.actions.llm.utils import llm_call
326+
327+
reasoning_trace_var.set(None)
328+
result1 = await llm_call(mock_llm, "First query")
329+
trace1 = reasoning_trace_var.get()
330+
331+
reasoning_trace_var.set(None)
332+
result2 = await llm_call(mock_llm, "Second query")
333+
trace2 = reasoning_trace_var.get()
334+
335+
assert trace1 == first_reasoning
336+
assert trace2 == second_reasoning
337+
338+
reasoning_trace_var.set(None)
339+
340+
@pytest.mark.asyncio
341+
async def test_reasoning_content_with_other_additional_kwargs(self):
342+
test_reasoning = "Complex reasoning process"
343+
344+
mock_llm = AsyncMock()
345+
mock_response = AIMessage(
346+
content="Response",
347+
additional_kwargs={
348+
"reasoning_content": test_reasoning,
349+
"model": "test-model",
350+
"finish_reason": "stop",
351+
"other_metadata": {"key": "value"},
352+
},
353+
)
354+
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
355+
356+
from nemoguardrails.actions.llm.utils import llm_call
357+
358+
reasoning_trace_var.set(None)
359+
result = await llm_call(mock_llm, "Query")
360+
361+
assert result == "Response"
362+
stored_trace = reasoning_trace_var.get()
363+
assert stored_trace == test_reasoning
364+
365+
reasoning_trace_var.set(None)

0 commit comments

Comments
 (0)