Skip to content

Commit 92ede8b

Browse files
committed
test: update reasoning trace handling in tests
- Replaced `remove_reasoning_traces` with `extract_and_strip_trace` across all test cases.
- Adjusted assertions to use `result.text` for compatibility with the updated function.
- Added `config.guardrail_reasoning_traces` to relevant tests for better configuration handling.
- Improved test descriptions for clarity and consistency.
fix fix
1 parent 1aca75d commit 92ede8b

File tree

1 file changed

+51
-37
lines changed

1 file changed

+51
-37
lines changed

tests/test_reasoning_traces.py

Lines changed: 51 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
llm_call_info_var,
2525
streaming_handler_var,
2626
)
27-
from nemoguardrails.llm.filters import remove_reasoning_traces
27+
from nemoguardrails.llm.filters import extract_and_strip_trace
2828
from nemoguardrails.llm.taskmanager import LLMTaskManager
2929
from nemoguardrails.llm.types import Task
3030
from nemoguardrails.logging.explain import LLMCallInfo
@@ -38,8 +38,8 @@ def test_remove_reasoning_traces_basic(self):
3838
"""Test basic removal of reasoning traces."""
3939
input_text = "This is a <thinking>\nSome reasoning here\nMore reasoning\n</thinking> response."
4040
expected = "This is a response."
41-
result = remove_reasoning_traces(input_text, "<thinking>", "</thinking>")
42-
assert result == expected
41+
result = extract_and_strip_trace(input_text, "<thinking>", "</thinking>")
42+
assert result.text == expected
4343

4444
def test_remove_reasoning_traces_multiline(self):
4545
"""Test removal of multiline reasoning traces."""
@@ -52,40 +52,40 @@ def test_remove_reasoning_traces_multiline(self):
5252
</thinking> response after thinking.
5353
"""
5454
expected = "\n Here is my response after thinking.\n "
55-
result = remove_reasoning_traces(input_text, "<thinking>", "</thinking>")
56-
assert result == expected
55+
result = extract_and_strip_trace(input_text, "<thinking>", "</thinking>")
56+
assert result.text == expected
5757

5858
def test_remove_reasoning_traces_multiple_sections(self):
5959
"""Test removal of multiple reasoning trace sections."""
6060
input_text = "Start <thinking>Reasoning 1</thinking> middle <thinking>Reasoning 2</thinking> end."
6161
# Note: The current implementation removes all content between the first start and last end token
6262
# So the expected result is "Start end." not "Start middle end."
6363
expected = "Start end."
64-
result = remove_reasoning_traces(input_text, "<thinking>", "</thinking>")
65-
assert result == expected
64+
result = extract_and_strip_trace(input_text, "<thinking>", "</thinking>")
65+
assert result.text == expected
6666

6767
def test_remove_reasoning_traces_nested(self):
6868
"""Test handling of nested reasoning trace markers (should be handled correctly)."""
6969
input_text = (
7070
"Begin <thinking>Outer <thinking>Inner</thinking> Outer</thinking> End."
7171
)
7272
expected = "Begin End."
73-
result = remove_reasoning_traces(input_text, "<thinking>", "</thinking>")
74-
assert result == expected
73+
result = extract_and_strip_trace(input_text, "<thinking>", "</thinking>")
74+
assert result.text == expected
7575

7676
def test_remove_reasoning_traces_unmatched(self):
7777
"""Test handling of unmatched reasoning trace markers."""
7878
input_text = "Begin <thinking>Unmatched end."
79-
result = remove_reasoning_traces(input_text, "<thinking>", "</thinking>")
79+
result = extract_and_strip_trace(input_text, "<thinking>", "</thinking>")
8080
# We should keep the unmatched tag since it's not a complete section
81-
assert result == "Begin <thinking>Unmatched end."
81+
assert result.text == "Begin <thinking>Unmatched end."
8282

8383
@pytest.mark.asyncio
8484
async def test_task_manager_parse_task_output(self):
8585
"""Test that the task manager correctly removes reasoning traces."""
8686
# mock config
8787
config = MagicMock(spec=RailsConfig)
88-
88+
config.guardrail_reasoning_traces = False
8989
# Create a ReasoningModelConfig
9090
reasoning_config = ReasoningModelConfig(
9191
remove_thinking_traces=True,
@@ -121,12 +121,13 @@ async def test_task_manager_parse_task_output(self):
121121
expected = "This is a final answer."
122122

123123
result = llm_task_manager.parse_task_output(Task.GENERAL, input_text)
124-
assert result == expected
124+
assert result.text == expected
125125

126126
@pytest.mark.asyncio
127127
async def test_parse_task_output_without_reasoning_config(self):
128128
"""Test that parse_task_output works without a reasoning config."""
129129
config = MagicMock(spec=RailsConfig)
130+
config.guardrail_reasoning_traces = False
130131

131132
# a Model without reasoning_config
132133
model_config = Model(type="main", engine="test", model="test-model")
@@ -147,18 +148,22 @@ async def test_parse_task_output_without_reasoning_config(self):
147148
input_text = (
148149
"This is a <thinking>Some reasoning here</thinking> final answer."
149150
)
150-
151-
# Without a reasoning config, the text should remain unchanged
152151
result = llm_task_manager.parse_task_output(Task.GENERAL, input_text)
153-
assert result == input_text
152+
assert result.text == input_text
154153

155154
@pytest.mark.asyncio
156155
async def test_parse_task_output_with_default_reasoning_traces(self):
157-
"""Test that parse_task_output works without a reasoning config."""
156+
"""Test that parse_task_output works with default reasoning traces."""
158157
config = MagicMock(spec=RailsConfig)
158+
config.guardrail_reasoning_traces = False
159159

160-
# a Model without reasoning_config
161-
model_config = Model(type="main", engine="test", model="test-model")
160+
# Create a Model with default reasoning_config
161+
model_config = Model(
162+
type="main",
163+
engine="test",
164+
model="test-model",
165+
reasoning_config=ReasoningModelConfig(),
166+
)
162167

163168
# Mock the get_prompt and get_task_model functions
164169
with (
@@ -172,42 +177,51 @@ async def test_parse_task_output_with_default_reasoning_traces(self):
172177

173178
llm_task_manager = LLMTaskManager(config)
174179

175-
# test parsing without a reasoning config
180+
# test parsing with default reasoning traces
176181
input_text = "This is a <think>Some reasoning here</think> final answer."
177-
expected = "This is a final answer."
178-
179-
# without a reasoning config, the default start_token and stop_token are used thus the text should change
180182
result = llm_task_manager.parse_task_output(Task.GENERAL, input_text)
181-
assert result == expected
183+
assert result.text == "This is a final answer."
182184

183185
@pytest.mark.asyncio
184186
async def test_parse_task_output_with_output_parser(self):
185-
"""Test that parse_task_output correctly applies output parsers before returning."""
187+
"""Test that parse_task_output works with an output parser."""
186188
config = MagicMock(spec=RailsConfig)
189+
config.guardrail_reasoning_traces = False
187190

188-
# mock output parser function
189-
def mock_parser(text):
190-
return text.upper()
191+
# Create a Model with reasoning_config
192+
model_config = Model(
193+
type="main",
194+
engine="test",
195+
model="test-model",
196+
reasoning_config=ReasoningModelConfig(
197+
remove_thinking_traces=True,
198+
start_token="<thinking>",
199+
end_token="</thinking>",
200+
),
201+
)
191202

192-
llm_task_manager = LLMTaskManager(config)
193-
llm_task_manager.output_parsers["test_parser"] = mock_parser
203+
def mock_parser(text):
204+
return f"PARSED: {text}"
194205

195-
# mock the get_prompt and get_task_model functions
206+
# Mock the get_prompt and get_task_model functions
196207
with (
197208
patch("nemoguardrails.llm.taskmanager.get_prompt") as mock_get_prompt,
198209
patch(
199210
"nemoguardrails.llm.taskmanager.get_task_model"
200211
) as mock_get_task_model,
201212
):
202-
mock_get_prompt.return_value = MagicMock(output_parser="test_parser")
203-
mock_get_task_model.return_value = None
213+
mock_get_prompt.return_value = MagicMock(output_parser="mock_parser")
214+
mock_get_task_model.return_value = model_config
204215

205-
# Test with output parser
206-
input_text = "this should be uppercase"
207-
expected = "THIS SHOULD BE UPPERCASE"
216+
llm_task_manager = LLMTaskManager(config)
217+
llm_task_manager.output_parsers["mock_parser"] = mock_parser
208218

219+
# test parsing with an output parser
220+
input_text = (
221+
"This is a <thinking>Some reasoning here</thinking> final answer."
222+
)
209223
result = llm_task_manager.parse_task_output(Task.GENERAL, input_text)
210-
assert result == expected
224+
assert result.text == "PARSED: This is a final answer."
211225

212226
@pytest.mark.asyncio
213227
async def test_passthrough_llm_action_removes_reasoning(self):

0 commit comments

Comments
 (0)