2424 llm_call_info_var ,
2525 streaming_handler_var ,
2626)
27- from nemoguardrails .llm .filters import remove_reasoning_traces
27+ from nemoguardrails .llm .filters import extract_and_strip_trace
2828from nemoguardrails .llm .taskmanager import LLMTaskManager
2929from nemoguardrails .llm .types import Task
3030from nemoguardrails .logging .explain import LLMCallInfo
@@ -38,8 +38,8 @@ def test_remove_reasoning_traces_basic(self):
3838 """Test basic removal of reasoning traces."""
3939 input_text = "This is a <thinking>\n Some reasoning here\n More reasoning\n </thinking> response."
4040 expected = "This is a response."
41- result = remove_reasoning_traces (input_text , "<thinking>" , "</thinking>" )
42- assert result == expected
41+ result = extract_and_strip_trace (input_text , "<thinking>" , "</thinking>" )
42+ assert result . text == expected
4343
4444 def test_remove_reasoning_traces_multiline (self ):
4545 """Test removal of multiline reasoning traces."""
@@ -52,40 +52,40 @@ def test_remove_reasoning_traces_multiline(self):
5252 </thinking> response after thinking.
5353 """
5454 expected = "\n Here is my response after thinking.\n "
55- result = remove_reasoning_traces (input_text , "<thinking>" , "</thinking>" )
56- assert result == expected
55+ result = extract_and_strip_trace (input_text , "<thinking>" , "</thinking>" )
56+ assert result . text == expected
5757
5858 def test_remove_reasoning_traces_multiple_sections (self ):
5959 """Test removal of multiple reasoning trace sections."""
6060 input_text = "Start <thinking>Reasoning 1</thinking> middle <thinking>Reasoning 2</thinking> end."
6161 # Note: The current implementation removes all content between the first start and last end token
6262 # So the expected result is "Start end." not "Start middle end."
6363 expected = "Start end."
64- result = remove_reasoning_traces (input_text , "<thinking>" , "</thinking>" )
65- assert result == expected
64+ result = extract_and_strip_trace (input_text , "<thinking>" , "</thinking>" )
65+ assert result . text == expected
6666
6767 def test_remove_reasoning_traces_nested (self ):
6868 """Test handling of nested reasoning trace markers (should be handled correctly)."""
6969 input_text = (
7070 "Begin <thinking>Outer <thinking>Inner</thinking> Outer</thinking> End."
7171 )
7272 expected = "Begin End."
73- result = remove_reasoning_traces (input_text , "<thinking>" , "</thinking>" )
74- assert result == expected
73+ result = extract_and_strip_trace (input_text , "<thinking>" , "</thinking>" )
74+ assert result . text == expected
7575
7676 def test_remove_reasoning_traces_unmatched (self ):
7777 """Test handling of unmatched reasoning trace markers."""
7878 input_text = "Begin <thinking>Unmatched end."
79- result = remove_reasoning_traces (input_text , "<thinking>" , "</thinking>" )
79+ result = extract_and_strip_trace (input_text , "<thinking>" , "</thinking>" )
8080 # We ~hould keep the unmatched tag since it's not a complete section
81- assert result == "Begin <thinking>Unmatched end."
81+ assert result . text == "Begin <thinking>Unmatched end."
8282
8383 @pytest .mark .asyncio
8484 async def test_task_manager_parse_task_output (self ):
8585 """Test that the task manager correctly removes reasoning traces."""
8686 # mock config
8787 config = MagicMock (spec = RailsConfig )
88-
88+ config . guardrail_reasoning_traces = False
8989 # Create a ReasoningModelConfig
9090 reasoning_config = ReasoningModelConfig (
9191 remove_thinking_traces = True ,
@@ -121,12 +121,13 @@ async def test_task_manager_parse_task_output(self):
121121 expected = "This is a final answer."
122122
123123 result = llm_task_manager .parse_task_output (Task .GENERAL , input_text )
124- assert result == expected
124+ assert result . text == expected
125125
126126 @pytest .mark .asyncio
127127 async def test_parse_task_output_without_reasoning_config (self ):
128128 """Test that parse_task_output works without a reasoning config."""
129129 config = MagicMock (spec = RailsConfig )
130+ config .guardrail_reasoning_traces = False
130131
131132 # a Model without reasoning_config
132133 model_config = Model (type = "main" , engine = "test" , model = "test-model" )
@@ -147,18 +148,22 @@ async def test_parse_task_output_without_reasoning_config(self):
147148 input_text = (
148149 "This is a <thinking>Some reasoning here</thinking> final answer."
149150 )
150-
151- # Without a reasoning config, the text should remain unchanged
152151 result = llm_task_manager .parse_task_output (Task .GENERAL , input_text )
153- assert result == input_text
152+ assert result . text == input_text
154153
155154 @pytest .mark .asyncio
156155 async def test_parse_task_output_with_default_reasoning_traces (self ):
157- """Test that parse_task_output works without a reasoning config ."""
156+ """Test that parse_task_output works with default reasoning traces ."""
158157 config = MagicMock (spec = RailsConfig )
158+ config .guardrail_reasoning_traces = False
159159
160- # a Model without reasoning_config
161- model_config = Model (type = "main" , engine = "test" , model = "test-model" )
160+ # Create a Model with default reasoning_config
161+ model_config = Model (
162+ type = "main" ,
163+ engine = "test" ,
164+ model = "test-model" ,
165+ reasoning_config = ReasoningModelConfig (),
166+ )
162167
163168 # Mock the get_prompt and get_task_model functions
164169 with (
@@ -172,42 +177,51 @@ async def test_parse_task_output_with_default_reasoning_traces(self):
172177
173178 llm_task_manager = LLMTaskManager (config )
174179
175- # test parsing without a reasoning config
180+ # test parsing with default reasoning traces
176181 input_text = "This is a <think>Some reasoning here</think> final answer."
177- expected = "This is a final answer."
178-
179- # without a reasoning config, the default start_token and stop_token are used thus the text should change
180182 result = llm_task_manager .parse_task_output (Task .GENERAL , input_text )
181- assert result == expected
183+ assert result . text == "This is a final answer."
182184
183185 @pytest .mark .asyncio
184186 async def test_parse_task_output_with_output_parser (self ):
185- """Test that parse_task_output correctly applies output parsers before returning ."""
187+ """Test that parse_task_output works with an output parser ."""
186188 config = MagicMock (spec = RailsConfig )
189+ config .guardrail_reasoning_traces = False
187190
188- # mock output parser function
189- def mock_parser (text ):
190- return text .upper ()
191+ # Create a Model with reasoning_config
192+ model_config = Model (
193+ type = "main" ,
194+ engine = "test" ,
195+ model = "test-model" ,
196+ reasoning_config = ReasoningModelConfig (
197+ remove_thinking_traces = True ,
198+ start_token = "<thinking>" ,
199+ end_token = "</thinking>" ,
200+ ),
201+ )
191202
192- llm_task_manager = LLMTaskManager ( config )
193- llm_task_manager . output_parsers [ "test_parser" ] = mock_parser
203+ def mock_parser ( text ):
204+ return f"PARSED: { text } "
194205
195- # mock the get_prompt and get_task_model functions
206+ # Mock the get_prompt and get_task_model functions
196207 with (
197208 patch ("nemoguardrails.llm.taskmanager.get_prompt" ) as mock_get_prompt ,
198209 patch (
199210 "nemoguardrails.llm.taskmanager.get_task_model"
200211 ) as mock_get_task_model ,
201212 ):
202- mock_get_prompt .return_value = MagicMock (output_parser = "test_parser " )
203- mock_get_task_model .return_value = None
213+ mock_get_prompt .return_value = MagicMock (output_parser = "mock_parser " )
214+ mock_get_task_model .return_value = model_config
204215
205- # Test with output parser
206- input_text = "this should be uppercase"
207- expected = "THIS SHOULD BE UPPERCASE"
216+ llm_task_manager = LLMTaskManager (config )
217+ llm_task_manager .output_parsers ["mock_parser" ] = mock_parser
208218
219+ # test parsing with an output parser
220+ input_text = (
221+ "This is a <thinking>Some reasoning here</thinking> final answer."
222+ )
209223 result = llm_task_manager .parse_task_output (Task .GENERAL , input_text )
210- assert result == expected
224+ assert result . text == "PARSED: This is a final answer."
211225
212226 @pytest .mark .asyncio
213227 async def test_passthrough_llm_action_removes_reasoning (self ):
0 commit comments