
Commit d49231e

Pouyanpi authored and tgasser-nv committed
fix(llm)!: extract reasoning traces to separate field instead of prepending (#1468)
* feat(llm)!: add reasoning_content field to GenerationResponse

BREAKING CHANGE: Reasoning traces are no longer prepended directly to response content. When using GenerationOptions, reasoning is now available in the separate reasoning_content field. Without GenerationOptions, reasoning is wrapped in <think> tags.

Refactor reasoning trace handling to expose reasoning content as a separate field in GenerationResponse instead of prepending it to the response content.

- Add a reasoning_content field to GenerationResponse for structured access
- Populate reasoning_content in generate_async when GenerationOptions is used
- Wrap reasoning traces in <think> tags for dict/string responses
- Update tests to reflect the new behavior (reasoning is no longer prepended)
- Add comprehensive tests for the new field and behavior changes

This improves API usability by separating reasoning content from the actual response, allowing clients to handle thinking traces independently. Resolves the TODO from PR #1432 about not prepending reasoning traces to the final generation content.
1 parent 26c8264 commit d49231e
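
From the caller's side, the breaking change looks like this. A minimal sketch, assuming a config directory with a reasoning-capable model; the config path, prompts, and printed values are illustrative, while the API calls (RailsConfig.from_path, LLMRails.generate_async, GenerationOptions) are the library's own:

import asyncio

from nemoguardrails import LLMRails, RailsConfig
from nemoguardrails.rails.llm.options import GenerationOptions

async def main():
    # Illustrative config path; assumes a model that emits reasoning traces.
    config = RailsConfig.from_path("./config")
    rails = LLMRails(config)

    # With GenerationOptions: the trace lands in the new reasoning_content
    # field and is no longer prepended to the response content.
    res = await rails.generate_async(
        messages=[{"role": "user", "content": "What is the answer?"}],
        options=GenerationOptions(),
    )
    print(res.reasoning_content)       # e.g. "Let me think step by step..."
    print(res.response[0]["content"])  # the answer only

    # Without GenerationOptions: the trace is wrapped in <think> tags and
    # prepended to the returned message content.
    msg = await rails.generate_async(
        messages=[{"role": "user", "content": "What is the answer?"}]
    )
    print(msg["content"])  # "<think>...</think>\n" followed by the answer

asyncio.run(main())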

File tree

7 files changed: +189 -73 lines

nemoguardrails/rails/llm/llmrails.py

Lines changed: 6 additions & 14 deletions
@@ -1084,7 +1084,7 @@ async def generate_async(

        tool_calls = extract_tool_calls_from_events(new_events)
        llm_metadata = get_and_clear_response_metadata_contextvar()
-
+        reasoning_content = extract_bot_thinking_from_events(new_events)
        # If we have generation options, we prepare a GenerationResponse instance.
        if gen_options:
            # If a prompt was used, we only need to return the content of the message.

@@ -1093,17 +1093,8 @@ async def generate_async(
            else:
                res = GenerationResponse(response=[new_message])

-            if reasoning_trace := extract_bot_thinking_from_events(events):
-                if prompt:
-                    # For prompt mode, response should be a string
-                    if isinstance(res.response, str):
-                        res.response = reasoning_trace + res.response
-                else:
-                    # For message mode, response should be a list
-                    if isinstance(res.response, list) and len(res.response) > 0:
-                        res.response[0]["content"] = (
-                            reasoning_trace + res.response[0]["content"]
-                        )
+            if reasoning_content:
+                res.reasoning_content = reasoning_content

            if tool_calls:
                res.tool_calls = tool_calls

@@ -1238,8 +1229,9 @@ async def generate_async(
        else:
            # If a prompt is used, we only return the content of the message.

-            if reasoning_trace := extract_bot_thinking_from_events(events):
-                new_message["content"] = reasoning_trace + new_message["content"]
+            if reasoning_content:
+                thinking_trace = f"<think>{reasoning_content}</think>\n"
+                new_message["content"] = thinking_trace + new_message["content"]

            if prompt:
                return new_message["content"]
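
Callers on the dict/string path that prefer the old bare content can strip the prefix themselves. A minimal sketch; split_thinking is a hypothetical helper, not part of nemoguardrails, and only the "<think>...</think>\n" prefix format comes from the diff above:

import re

def split_thinking(content: str) -> tuple[str | None, str]:
    # Hypothetical helper: split off the "<think>...</think>\n" prefix that
    # generate_async now prepends when no GenerationOptions are passed.
    match = re.match(r"<think>(.*?)</think>\n", content, flags=re.DOTALL)
    if match:
        return match.group(1), content[match.end():]
    return None, content

reasoning, answer = split_thinking("<think>Step 1: Analyze</think>\nThe answer is 42")
assert reasoning == "Step 1: Analyze"
assert answer == "The answer is 42"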

nemoguardrails/rails/llm/options.py

Lines changed: 4 additions & 0 deletions
@@ -428,6 +428,10 @@ class GenerationResponse(BaseModel):
        default=None,
        description="Tool calls extracted from the LLM response, if any.",
    )
+    reasoning_content: Optional[str] = Field(
+        default=None,
+        description="The reasoning content extracted from the LLM response, if any.",
+    )
    llm_metadata: Optional[dict] = Field(
        default=None,
        description="Metadata from the LLM response (additional_kwargs, response_metadata, usage_metadata, etc.)",

tests/conftest.py

Lines changed: 3 additions & 20 deletions
@@ -17,27 +17,10 @@

import pytest

-from nemoguardrails.context import reasoning_trace_var
+REASONING_TRACE_MOCK_PATH = (
+    "nemoguardrails.actions.llm.generation.get_and_clear_reasoning_trace_contextvar"
+)


def pytest_configure(config):
    patch("prompt_toolkit.PromptSession", autospec=True).start()
-
-
-@pytest.fixture(autouse=True)
-def reset_reasoning_trace():
-    """Reset the reasoning_trace_var before each test.
-
-    This fixture runs automatically for every test (autouse=True) to ensure
-    a clean state for the reasoning trace context variable.
-
-    current Issues with ContextVar approach, not only specific to this case:
-    global State: ContextVar creates global state that's hard to track and manage
-    implicit Flow: The reasoning trace flows through the system in a non-obvious way
-    testing Complexity: It causes test isolation problems that we are trying to avoid using this fixture
-    """
-    # reset the variable before the test
-    reasoning_trace_var.set(None)
-    yield
-    # reset the variable after the test as well (in case the test fails)
-    reasoning_trace_var.set(None)

tests/rails/llm/test_options.py

Lines changed: 55 additions & 0 deletions
@@ -193,3 +193,58 @@ def test_generation_response_model_validation():
    assert isinstance(response.tool_calls, list)
    assert len(response.tool_calls) == 2
    assert response.llm_output["token_usage"]["total_tokens"] == 50
+
+
+def test_generation_response_with_reasoning_content():
+    test_reasoning = "Step 1: Analyze\nStep 2: Respond"
+
+    response = GenerationResponse(
+        response="Final answer", reasoning_content=test_reasoning
+    )
+
+    assert response.reasoning_content == test_reasoning
+    assert response.response == "Final answer"
+
+
+def test_generation_response_reasoning_content_defaults_to_none():
+    response = GenerationResponse(response="Answer")
+
+    assert response.reasoning_content is None
+
+
+def test_generation_response_reasoning_content_can_be_empty_string():
+    response = GenerationResponse(response="Answer", reasoning_content="")
+
+    assert response.reasoning_content == ""
+
+
+def test_generation_response_serialization_with_reasoning_content():
+    test_reasoning = "Thinking process here"
+
+    response = GenerationResponse(response="Response", reasoning_content=test_reasoning)
+
+    response_dict = response.dict()
+    assert "reasoning_content" in response_dict
+    assert response_dict["reasoning_content"] == test_reasoning
+
+    response_json = response.json()
+    assert "reasoning_content" in response_json
+    assert test_reasoning in response_json
+
+
+def test_generation_response_with_all_fields():
+    test_tool_calls = [
+        {"name": "test_func", "args": {}, "id": "call_123", "type": "tool_call"}
+    ]
+    test_reasoning = "Detailed reasoning"
+
+    response = GenerationResponse(
+        response=[{"role": "assistant", "content": "Response"}],
+        tool_calls=test_tool_calls,
+        reasoning_content=test_reasoning,
+        llm_output={"token_usage": {"total_tokens": 100}},
+    )
+
+    assert response.tool_calls == test_tool_calls
+    assert response.reasoning_content == test_reasoning
+    assert response.llm_output["token_usage"]["total_tokens"] == 100

tests/test_bot_thinking_events.py

Lines changed: 12 additions & 30 deletions
@@ -18,16 +18,15 @@
import pytest

from nemoguardrails import RailsConfig
+from tests.conftest import REASONING_TRACE_MOCK_PATH
from tests.utils import TestChat


@pytest.mark.asyncio
async def test_bot_thinking_event_creation_passthrough():
    test_reasoning_trace = "Let me think about this step by step..."

-    with patch(
-        "nemoguardrails.actions.llm.generation.get_and_clear_reasoning_trace_contextvar"
-    ) as mock_get_reasoning:
+    with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
        mock_get_reasoning.return_value = test_reasoning_trace

        config = RailsConfig.from_content(config={"models": [], "passthrough": True})

@@ -46,9 +45,7 @@ async def test_bot_thinking_event_creation_passthrough():
async def test_bot_thinking_event_creation_non_passthrough():
    test_reasoning_trace = "Analyzing the user's request..."

-    with patch(
-        "nemoguardrails.actions.llm.generation.get_and_clear_reasoning_trace_contextvar"
-    ) as mock_get_reasoning:
+    with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
        mock_get_reasoning.return_value = test_reasoning_trace

        config = RailsConfig.from_content(

@@ -84,9 +81,7 @@ async def test_bot_thinking_event_creation_non_passthrough():

@pytest.mark.asyncio
async def test_no_bot_thinking_event_when_no_reasoning_trace():
-    with patch(
-        "nemoguardrails.actions.llm.generation.get_and_clear_reasoning_trace_contextvar"
-    ) as mock_get_reasoning:
+    with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
        mock_get_reasoning.return_value = None

        config = RailsConfig.from_content(config={"models": [], "passthrough": True})

@@ -104,9 +99,7 @@ async def test_no_bot_thinking_event_when_no_reasoning_trace():
async def test_bot_thinking_before_bot_message():
    test_reasoning_trace = "Step 1: Understand the question\nStep 2: Formulate answer"

-    with patch(
-        "nemoguardrails.actions.llm.generation.get_and_clear_reasoning_trace_contextvar"
-    ) as mock_get_reasoning:
+    with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
        mock_get_reasoning.return_value = test_reasoning_trace

        config = RailsConfig.from_content(config={"models": [], "passthrough": True})

@@ -134,9 +127,7 @@ async def test_bot_thinking_before_bot_message():
async def test_bot_thinking_accessible_in_output_rails():
    test_reasoning_trace = "Thinking: This requires careful consideration"

-    with patch(
-        "nemoguardrails.actions.llm.generation.get_and_clear_reasoning_trace_contextvar"
-    ) as mock_get_reasoning:
+    with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
        mock_get_reasoning.return_value = test_reasoning_trace

        config = RailsConfig.from_content(

@@ -171,9 +162,7 @@ async def test_bot_thinking_accessible_in_output_rails():
async def test_bot_thinking_matches_in_output_rails():
    test_reasoning_trace = "Let me analyze: step 1, step 2, step 3"

-    with patch(
-        "nemoguardrails.actions.llm.generation.get_and_clear_reasoning_trace_contextvar"
-    ) as mock_get_reasoning:
+    with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
        mock_get_reasoning.return_value = test_reasoning_trace

        config = RailsConfig.from_content(

@@ -203,9 +192,7 @@ async def test_bot_thinking_matches_in_output_rails():

@pytest.mark.asyncio
async def test_bot_thinking_none_when_no_reasoning():
-    with patch(
-        "nemoguardrails.actions.llm.generation.get_and_clear_reasoning_trace_contextvar"
-    ) as mock_get_reasoning:
+    with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
        mock_get_reasoning.return_value = None

        config = RailsConfig.from_content(

@@ -240,9 +227,7 @@ async def test_bot_thinking_none_when_no_reasoning():
async def test_bot_thinking_usable_in_output_rail_logic():
    test_reasoning_trace = "This contains sensitive information"

-    with patch(
-        "nemoguardrails.actions.llm.generation.get_and_clear_reasoning_trace_contextvar"
-    ) as mock_get_reasoning:
+    with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
        mock_get_reasoning.return_value = test_reasoning_trace

        config = RailsConfig.from_content(

@@ -270,12 +255,9 @@ async def test_bot_thinking_usable_in_output_rail_logic():
        )

        assert isinstance(result.response, list)
-        # TODO(@Pouyanpi): in llmrails.py appending reasoning traces to the final generation might not be desired anymore
-        # should be fixed in a subsequent PR for 0.18.0 release
-        assert (
-            result.response[0]["content"]
-            == test_reasoning_trace + "I'm sorry, I can't respond to that."
-        )
+        assert result.reasoning_content == test_reasoning_trace
+        assert result.response[0]["content"] == "I'm sorry, I can't respond to that."
+        assert test_reasoning_trace not in result.response[0]["content"]


@pytest.mark.asyncio

tests/test_llmrails.py

Lines changed: 105 additions & 0 deletions
@@ -24,6 +24,7 @@
from nemoguardrails.logging.explain import ExplainInfo
from nemoguardrails.rails.llm.config import Model
from nemoguardrails.rails.llm.llmrails import get_action_details_from_flow_id
+from tests.conftest import REASONING_TRACE_MOCK_PATH
from tests.utils import FakeLLM, clean_events, event_sequence_conforms


@@ -1372,3 +1373,107 @@ def test_cache_initialization_with_multiple_models(mock_init_llm_model):
    assert "jailbreak_detection" in model_caches
    assert model_caches["content_safety"].maxsize == 1000
    assert model_caches["jailbreak_detection"].maxsize == 2000
+
+
+@pytest.mark.asyncio
+async def test_generate_async_reasoning_content_field_passthrough():
+    from nemoguardrails.rails.llm.options import GenerationOptions
+
+    test_reasoning_trace = "Let me think about this step by step..."
+
+    with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
+        mock_get_reasoning.return_value = test_reasoning_trace
+
+        config = RailsConfig.from_content(config={"models": []})
+        llm = FakeLLM(responses=["The answer is 42"])
+        llm_rails = LLMRails(config=config, llm=llm)
+
+        result = await llm_rails.generate_async(
+            messages=[{"role": "user", "content": "What is the answer?"}],
+            options=GenerationOptions(),
+        )
+
+        assert result.reasoning_content == test_reasoning_trace
+        assert isinstance(result.response, list)
+        assert result.response[0]["content"] == "The answer is 42"
+
+
+@pytest.mark.asyncio
+async def test_generate_async_reasoning_content_none():
+    from nemoguardrails.rails.llm.options import GenerationOptions
+
+    with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
+        mock_get_reasoning.return_value = None
+
+        config = RailsConfig.from_content(config={"models": []})
+        llm = FakeLLM(responses=["Regular response"])
+        llm_rails = LLMRails(config=config, llm=llm)

+        result = await llm_rails.generate_async(
+            messages=[{"role": "user", "content": "Hello"}],
+            options=GenerationOptions(),
+        )
+
+        assert result.reasoning_content is None
+        assert isinstance(result.response, list)
+        assert result.response[0]["content"] == "Regular response"
+
+
+@pytest.mark.asyncio
+async def test_generate_async_reasoning_not_in_response_content():
+    from nemoguardrails.rails.llm.options import GenerationOptions
+
+    test_reasoning_trace = "Let me analyze this carefully..."
+
+    with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
+        mock_get_reasoning.return_value = test_reasoning_trace
+
+        config = RailsConfig.from_content(config={"models": []})
+        llm = FakeLLM(responses=["The answer is 42"])
+        llm_rails = LLMRails(config=config, llm=llm)
+
+        result = await llm_rails.generate_async(
+            messages=[{"role": "user", "content": "What is the answer?"}],
+            options=GenerationOptions(),
+        )
+
+        assert result.reasoning_content == test_reasoning_trace
+        assert test_reasoning_trace not in result.response[0]["content"]
+        assert result.response[0]["content"] == "The answer is 42"
+
+
+@pytest.mark.asyncio
+async def test_generate_async_reasoning_with_thinking_tags():
+    test_reasoning_trace = "Step 1: Analyze\nStep 2: Respond"
+
+    with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
+        mock_get_reasoning.return_value = test_reasoning_trace
+
+        config = RailsConfig.from_content(config={"models": [], "passthrough": True})
+        llm = FakeLLM(responses=["The answer is 42"])
+        llm_rails = LLMRails(config=config, llm=llm)
+
+        result = await llm_rails.generate_async(
+            messages=[{"role": "user", "content": "What is the answer?"}]
+        )
+
+        expected_prefix = f"<think>{test_reasoning_trace}</think>\n"
+        assert result["content"].startswith(expected_prefix)
+        assert "The answer is 42" in result["content"]
+
+
+@pytest.mark.asyncio
+async def test_generate_async_no_thinking_tags_when_no_reasoning():
+    with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
+        mock_get_reasoning.return_value = None
+
+        config = RailsConfig.from_content(config={"models": []})
+        llm = FakeLLM(responses=["Regular response"])
+        llm_rails = LLMRails(config=config, llm=llm)
+
+        result = await llm_rails.generate_async(
+            messages=[{"role": "user", "content": "Hello"}]
+        )
+
+        assert not result["content"].startswith("<think>")
+        assert result["content"] == "Regular response"
