diff --git a/nemoguardrails/actions/llm/generation.py b/nemoguardrails/actions/llm/generation.py
index ed3fff9bb..2a57e1c26 100644
--- a/nemoguardrails/actions/llm/generation.py
+++ b/nemoguardrails/actions/llm/generation.py
@@ -50,13 +50,14 @@
generation_options_var,
llm_call_info_var,
raw_llm_request,
+ reasoning_trace_var,
streaming_handler_var,
)
from nemoguardrails.embeddings.index import EmbeddingsIndex, IndexItem
from nemoguardrails.kb.kb import KnowledgeBase
from nemoguardrails.llm.params import llm_params
from nemoguardrails.llm.prompts import get_prompt
-from nemoguardrails.llm.taskmanager import LLMTaskManager
+from nemoguardrails.llm.taskmanager import LLMTaskManager, ParsedTaskOutput
from nemoguardrails.llm.types import Task
from nemoguardrails.logging.explain import LLMCallInfo
from nemoguardrails.patch_asyncio import check_sync_call_from_async_loop
@@ -442,6 +443,7 @@ async def generate_user_intent(
result = self.llm_task_manager.parse_task_output(
Task.GENERATE_USER_INTENT, output=result
)
+ result = result.text
user_intent = get_first_nonempty_line(result)
if user_intent is None:
@@ -530,6 +532,11 @@ async def generate_user_intent(
text = self.llm_task_manager.parse_task_output(
Task.GENERAL, output=text
)
+
+ text = _process_parsed_output(
+ text, self._include_reasoning_traces()
+ )
+
else:
# Initialize the LLMCallInfo object
llm_call_info_var.set(LLMCallInfo(task=Task.GENERAL.value))
@@ -565,6 +572,8 @@ async def generate_user_intent(
text = self.llm_task_manager.parse_task_output(
Task.GENERAL, output=result
)
+
+ text = _process_parsed_output(text, self._include_reasoning_traces())
text = text.strip()
if text.startswith('"'):
text = text[1:-1]
@@ -646,6 +655,7 @@ async def generate_next_step(
result = self.llm_task_manager.parse_task_output(
Task.GENERATE_NEXT_STEPS, output=result
)
+ result = result.text
# If we don't have multi-step generation enabled, we only look at the first line.
if not self.config.enable_multi_step_generation:
@@ -900,6 +910,10 @@ async def generate_bot_message(
Task.GENERAL, output=result
)
+ result = _process_parsed_output(
+ result, self._include_reasoning_traces()
+ )
+
log.info(
"--- :: LLM Bot Message Generation passthrough call took %.2f seconds",
time() - t0,
@@ -963,6 +977,10 @@ async def generate_bot_message(
Task.GENERATE_BOT_MESSAGE, output=result
)
+ result = _process_parsed_output(
+ result, self._include_reasoning_traces()
+ )
+
# TODO: catch openai.error.InvalidRequestError from exceeding max token length
result = get_multiline_response(result)
@@ -1055,6 +1073,7 @@ async def generate_value(
result = self.llm_task_manager.parse_task_output(
Task.GENERATE_VALUE, output=result
)
+ result = result.text
# We only use the first line for now
# TODO: support multi-line values?
@@ -1266,6 +1285,7 @@ async def generate_intent_steps_message(
result = self.llm_task_manager.parse_task_output(
Task.GENERATE_INTENT_STEPS_MESSAGE, output=result
)
+ result = result.text
# TODO: Implement logic for generating more complex Colang next steps (multi-step),
# not just a single bot intent.
@@ -1348,6 +1368,7 @@ async def generate_intent_steps_message(
result = self.llm_task_manager.parse_task_output(
Task.GENERAL, output=result
)
+ result = _process_parsed_output(result, self._include_reasoning_traces())
text = result.strip()
if text.startswith('"'):
text = text[1:-1]
@@ -1360,6 +1381,10 @@ async def generate_intent_steps_message(
events=[new_event_dict("BotMessage", text=text)],
)
+ def _include_reasoning_traces(self) -> bool:
+ """Get the configuration value for whether to include reasoning traces in output."""
+ return _get_apply_to_reasoning_traces(self.config)
+
def clean_utterance_content(utterance: str) -> str:
"""
@@ -1377,3 +1402,27 @@ def clean_utterance_content(utterance: str) -> str:
# It should be translated to an actual \n character.
utterance = utterance.replace("\\n", "\n")
return utterance
+
+
+def _record_reasoning_trace(trace: str) -> None:
+ """Store the reasoning trace in context for later retrieval."""
+ reasoning_trace_var.set(trace)
+
+
+def _assemble_response(text: str, trace: Optional[str], include_reasoning: bool) -> str:
+ """Combine trace and text if requested, otherwise just return text."""
+ return (trace + text) if (trace and include_reasoning) else text
+
+
+def _process_parsed_output(
+ output: ParsedTaskOutput, include_reasoning_trace: bool
+) -> str:
+ """Record trace, then assemble the final LLM response."""
+ if reasoning_trace := output.reasoning_trace:
+ _record_reasoning_trace(reasoning_trace)
+ return _assemble_response(output.text, reasoning_trace, include_reasoning_trace)
+
+
+def _get_apply_to_reasoning_traces(config: RailsConfig) -> bool:
+ """Get the configuration value for whether to include reasoning traces in output."""
+ return config.rails.output.apply_to_reasoning_traces
diff --git a/nemoguardrails/actions/llm/utils.py b/nemoguardrails/actions/llm/utils.py
index 012043c04..e58a1aba5 100644
--- a/nemoguardrails/actions/llm/utils.py
+++ b/nemoguardrails/actions/llm/utils.py
@@ -24,7 +24,7 @@
from nemoguardrails.colang.v2_x.lang.colang_ast import Flow
from nemoguardrails.colang.v2_x.runtime.flows import InternalEvent, InternalEvents
-from nemoguardrails.context import llm_call_info_var
+from nemoguardrails.context import llm_call_info_var, reasoning_trace_var
from nemoguardrails.logging.callbacks import logging_callbacks
from nemoguardrails.logging.explain import LLMCallInfo
@@ -192,7 +192,7 @@ def get_colang_history(
and event["action_name"] == "retrieve_relevant_chunks"
):
continue
- history += f'execute {event["action_name"]}\n'
+ history += f"execute {event['action_name']}\n"
elif event["type"] == "InternalSystemActionFinished" and not event.get(
"is_system_action"
):
@@ -577,3 +577,15 @@ def escape_flow_name(name: str) -> str:
# removes non-word chars and leading digits in a word
result = re.sub(r"\b\d+|[^\w\s]", "", result)
return result
+
+
+def get_and_clear_reasoning_trace_contextvar() -> Optional[str]:
+ """Get the current reasoning trace and clear it from the context.
+
+ Returns:
+ Optional[str]: The reasoning trace if one exists, None otherwise.
+ """
+ if reasoning_trace := reasoning_trace_var.get():
+ reasoning_trace_var.set(None)
+ return reasoning_trace
+ return None
diff --git a/nemoguardrails/actions/v2_x/generation.py b/nemoguardrails/actions/v2_x/generation.py
index 7dd5df8c4..e011379b8 100644
--- a/nemoguardrails/actions/v2_x/generation.py
+++ b/nemoguardrails/actions/v2_x/generation.py
@@ -197,7 +197,7 @@ async def _collect_user_intent_and_examples(
# We add these in reverse order so the most relevant is towards the end.
for result in reversed(results):
- examples += f"user action: user said \"{result.text}\"\nuser intent: {result.meta['intent']}\n\n"
+ examples += f'user action: user said "{result.text}"\nuser intent: {result.meta["intent"]}\n\n'
if result.meta["intent"] not in potential_user_intents:
potential_user_intents.append(result.meta["intent"])
@@ -302,6 +302,8 @@ async def generate_user_intent(
Task.GENERATE_USER_INTENT_FROM_USER_ACTION, output=result
)
+ result = result.text
+
user_intent = get_first_nonempty_line(result)
# GTP-4o often adds 'user intent: ' in front
if user_intent and ":" in user_intent:
@@ -378,6 +380,8 @@ async def generate_user_intent_and_bot_action(
Task.GENERATE_USER_INTENT_AND_BOT_ACTION_FROM_USER_ACTION, output=result
)
+ result = result.text
+
user_intent = get_first_nonempty_line(result)
if user_intent and ":" in user_intent:
@@ -458,6 +462,8 @@ async def passthrough_llm_action(
text = self.llm_task_manager.parse_task_output(Task.GENERAL, output=text)
+ text = text.text
+
return text
@action(name="CheckValidFlowExistsAction", is_system_action=True)
@@ -541,6 +547,8 @@ async def generate_flow_from_instructions(
task=Task.GENERATE_FLOW_FROM_INSTRUCTIONS, output=result
)
+ result = result.text
+
# TODO: why this is not part of a filter or output_parser?
#
lines = _remove_leading_empty_lines(result).split("\n")
@@ -613,6 +621,8 @@ async def generate_flow_from_name(
task=Task.GENERATE_FLOW_FROM_NAME, output=result
)
+ result = result.text
+
lines = _remove_leading_empty_lines(result).split("\n")
if lines[0].startswith("flow"):
@@ -680,6 +690,8 @@ async def generate_flow_continuation(
task=Task.GENERATE_FLOW_CONTINUATION, output=result
)
+ result = result.text
+
lines = _remove_leading_empty_lines(result).split("\n")
if len(lines) == 0 or (len(lines) == 1 and lines[0] == ""):
@@ -806,6 +818,8 @@ async def generate_value(
Task.GENERATE_VALUE_FROM_INSTRUCTION, output=result
)
+ result = result.text
+
# We only use the first line for now
# TODO: support multi-line values?
value = result.strip().split("\n")[0]
@@ -913,6 +927,8 @@ async def generate_flow(
Task.GENERATE_FLOW_CONTINUATION_FROM_NLD, output=result
)
+ result = result.text
+
result = _remove_leading_empty_lines(result)
lines = result.split("\n")
if "codeblock" in lines[0]:
diff --git a/nemoguardrails/context.py b/nemoguardrails/context.py
index 9f1d2fb6a..e66f1a0d5 100644
--- a/nemoguardrails/context.py
+++ b/nemoguardrails/context.py
@@ -14,6 +14,7 @@
# limitations under the License.
import contextvars
+from typing import Optional
streaming_handler_var = contextvars.ContextVar("streaming_handler", default=None)
@@ -32,3 +33,7 @@
# The raw LLM request that comes from the user.
# This is used in passthrough mode.
raw_llm_request = contextvars.ContextVar("raw_llm_request", default=None)
+
+reasoning_trace_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
+ "reasoning_trace", default=None
+)
diff --git a/nemoguardrails/library/content_safety/actions.py b/nemoguardrails/library/content_safety/actions.py
index 7c1cbf406..90022cbc4 100644
--- a/nemoguardrails/library/content_safety/actions.py
+++ b/nemoguardrails/library/content_safety/actions.py
@@ -80,6 +80,7 @@ async def content_safety_check_input(
result = await llm_call(llm, check_input_prompt, stop=stop)
result = llm_task_manager.parse_task_output(task, output=result)
+ result = result.text
try:
is_safe, violated_policies = result
@@ -162,6 +163,8 @@ async def content_safety_check_output(
result = llm_task_manager.parse_task_output(task, output=result)
+ result = result.text
+
try:
is_safe, violated_policies = result
except TypeError:
diff --git a/nemoguardrails/library/self_check/facts/actions.py b/nemoguardrails/library/self_check/facts/actions.py
index 0f9c6d7ec..fb75ef72d 100644
--- a/nemoguardrails/library/self_check/facts/actions.py
+++ b/nemoguardrails/library/self_check/facts/actions.py
@@ -82,6 +82,7 @@ async def self_check_facts(
task, output=response, forced_output_parser="is_content_safe"
)
+ result = result.text
is_not_safe, _ = result
result = float(not is_not_safe)
diff --git a/nemoguardrails/library/self_check/input_check/actions.py b/nemoguardrails/library/self_check/input_check/actions.py
index 1bf778129..8005f0724 100644
--- a/nemoguardrails/library/self_check/input_check/actions.py
+++ b/nemoguardrails/library/self_check/input_check/actions.py
@@ -83,6 +83,7 @@ async def self_check_input(
task, output=response, forced_output_parser="is_content_safe"
)
+ result = result.text
is_safe, _ = result
if not is_safe:
diff --git a/nemoguardrails/library/self_check/output_check/actions.py b/nemoguardrails/library/self_check/output_check/actions.py
index c8ed9ca6c..8bbcdf42e 100644
--- a/nemoguardrails/library/self_check/output_check/actions.py
+++ b/nemoguardrails/library/self_check/output_check/actions.py
@@ -87,6 +87,7 @@ async def self_check_output(
task, output=response, forced_output_parser="is_content_safe"
)
+ result = result.text
is_safe, _ = result
return is_safe
diff --git a/nemoguardrails/llm/filters.py b/nemoguardrails/llm/filters.py
index abf988c77..a0d80bb5d 100644
--- a/nemoguardrails/llm/filters.py
+++ b/nemoguardrails/llm/filters.py
@@ -15,7 +15,8 @@
import re
import textwrap
-from typing import List
+from dataclasses import dataclass
+from typing import List, Optional, Tuple
from nemoguardrails.actions.llm.utils import (
get_colang_history,
@@ -23,6 +24,16 @@
)
+@dataclass
+class ReasoningExtractionResult:
+ """
+ Holds cleaned response text and optional chain-of-thought reasoning trace extracted from LLM output.
+ """
+
+ text: str
+ reasoning_trace: Optional[str] = None
+
+
def colang(events: List[dict]) -> str:
"""Filter that turns an array of events into a colang history."""
return get_colang_history(events)
@@ -439,24 +450,98 @@ def conversation_to_events(conversation: List) -> List[dict]:
return events
-def remove_reasoning_traces(response: str, start_token: str, end_token: str) -> str:
- """Removes the text between the first occurrence of the start token and the
- last occurrence of the last token, if these tokens exist in the response.
+def _find_token_positions_for_removal(
+ response: str, start_token: Optional[str], end_token: Optional[str]
+) -> Tuple[int, int]:
+ """Helper function to find token positions specifically for text removal.
+
+ This is useful, for example, to remove reasoning traces from a reasoning LLM response.
+
+ This is optimized for the removal use case:
+ 1. Uses find() for first start token
+ 2. Uses rfind() for last end token
+ 3. Sets start_index to 0 if start token is missing
- This utility function is useful to strip reasoning traces from reasoning LLMs
- that encode the reasoning traces between specific tokens.
+ Args:
+ response (str): The text to search in.
+ start_token (Optional[str]): The token marking the start of the text to remove.
+ end_token (Optional[str]): The token marking the end of the text to remove.
+
+ Returns:
+ A tuple of (start_index, end_index) marking the span to remove;
+ both indices are -1 if either token is not provided.
"""
- if start_token and end_token:
- start_index = response.find(start_token)
- # If the start index is missing, this is probably a continuation of a bot message
- # started in the prompt.
- if start_index == -1:
- start_index = 0
- end_index = response.rfind(end_token)
- if end_index == -1:
- return response
-
- if start_index != -1 and end_index != -1 and start_index < end_index:
- return response[:start_index] + response[end_index + len(end_token) :]
-
- return response
+ if not start_token or not end_token:
+ return -1, -1
+
+ start_index = response.find(start_token)
+ # if the start index is missing, this is probably a continuation of a bot message
+ # started in the prompt.
+ if start_index == -1:
+ start_index = 0
+
+ end_index = response.rfind(end_token)
+
+ return start_index, end_index
+
+
+def find_reasoning_tokens_position(
+ response: str, start_token: Optional[str], end_token: Optional[str]
+) -> Tuple[int, int]:
+ """Finds the positions of the first start token and the last end token.
+
+ This is intended to find the outermost boundaries of potential
+ reasoning sections, typically for removal.
+
+ Args:
+ response(str): The text to search in.
+ start_token(Optional[str]): The token marking the start of reasoning.
+ end_token(Optional[str]): The token marking the end of reasoning.
+
+ Returns:
+ A tuple (start_index, end_index).
+ - start_index: Position of the first `start_token`, or 0 if not found.
+ - end_index: Position of the last `end_token`, or -1 if not found.
+ """
+
+ return _find_token_positions_for_removal(response, start_token, end_token)
+
+
+def extract_and_strip_trace(
+ response: str, start_token: str, end_token: str
+) -> ReasoningExtractionResult:
+ """Extracts and removes reasoning traces from the given text.
+
+ This function identifies reasoning traces in the text that are marked
+ by specific start and end tokens. It extracts these traces, removes
+ them from the original text, and returns both the cleaned text and
+ the extracted reasoning trace.
+
+ Args:
+ response (str): The text to process.
+ start_token (str): The token marking the start of a reasoning trace.
+ end_token (str): The token marking the end of a reasoning trace.
+
+ Returns:
+ ReasoningExtractionResult: An object containing the cleaned text
+ without reasoning traces and the extracted reasoning trace, if any.
+ """
+
+ start_index, end_index = find_reasoning_tokens_position(
+ response, start_token, end_token
+ )
+ # handles invalid/empty tokens returned as (-1, -1)
+ if start_index == -1 and end_index == -1:
+ return ReasoningExtractionResult(text=response, reasoning_trace=None)
+ # end token is missing
+ if end_index == -1:
+ return ReasoningExtractionResult(text=response, reasoning_trace=None)
+ # extract if tokens are present and start < end
+ if start_index < end_index:
+ reasoning_trace = response[start_index : end_index + len(end_token)]
+ cleaned_text = response[:start_index] + response[end_index + len(end_token) :]
+ return ReasoningExtractionResult(
+ text=cleaned_text, reasoning_trace=reasoning_trace
+ )
+
+ return ReasoningExtractionResult(text=response, reasoning_trace=None)
diff --git a/nemoguardrails/llm/taskmanager.py b/nemoguardrails/llm/taskmanager.py
index a9242cded..bb3a1cde2 100644
--- a/nemoguardrails/llm/taskmanager.py
+++ b/nemoguardrails/llm/taskmanager.py
@@ -16,7 +16,8 @@
import logging
import re
from ast import literal_eval
-from typing import Any, Callable, List, Optional, Union
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Union
from jinja2 import meta
from jinja2.sandbox import SandboxedEnvironment
@@ -25,10 +26,10 @@
co_v2,
colang,
colang_without_identifiers,
+ extract_and_strip_trace,
first_turns,
indent,
last_turns,
- remove_reasoning_traces,
remove_text_messages,
to_chat_messages,
to_intent_messages,
@@ -52,13 +53,62 @@
from nemoguardrails.rails.llm.config import MessageTemplate, RailsConfig
+def output_has_reasoning_traces(output: str, start_token: str, end_token: str) -> bool:
+ """Checks if the output string contains both start and end reasoning tokens."""
+ return start_token in output and end_token in output
+
+
+@dataclass
+class ParsedTaskOutput:
+ """
+ Encapsulates the result of running and parsing an LLM task.
+
+ Attributes:
+ text (str): The cleaned and parsed output string, representing
+ the main result of the task.
+ reasoning_trace (Optional[str]): An optional chain-of-thought
+ reasoning trace, providing insights into the reasoning
+ process behind the task output, if available.
+ """
+
+ text: str
+ reasoning_trace: Optional[str] = None
+
+
+def should_remove_reasoning_traces_from_output(config, task) -> bool:
+ """Returns True if reasoning traces should be stripped from the output for this task."""
+
+ if config.rails.output.apply_to_reasoning_traces:
+ return False
+
+ model = get_task_model(config, task)
+
+ return bool(
+ model
+ and model.reasoning_config
+ and model.reasoning_config.remove_thinking_traces
+ )
+
+
+def get_reasoning_token_tags(config, task):
+ """Returns the (start_token, end_token) pair from the task model's reasoning config, or (None, None)."""
+ model = get_task_model(config, task)
+
+ if model and model.reasoning_config:
+ start_token = model.reasoning_config.start_token
+ end_token = model.reasoning_config.end_token
+ else:
+ start_token = None
+ end_token = None
+
+ return start_token, end_token
+
+
class LLMTaskManager:
"""Interface for interacting with an LLM in a task-oriented way."""
def __init__(self, config: RailsConfig):
# Save the config as we need access to instructions and sample conversations.
self.config = config
-
# Initialize the environment for rendering templates.
self.env = SandboxedEnvironment()
@@ -78,7 +128,7 @@ def __init__(self, config: RailsConfig):
self.env.filters["to_chat_messages"] = to_chat_messages
self.env.filters["verbose_v1"] = verbose_v1
- self.output_parsers = {
+ self.output_parsers: Dict[Optional[str], Callable] = {
"user_intent": user_intent_parser,
"bot_intent": bot_intent_parser,
"bot_message": bot_message_parser,
@@ -308,36 +358,50 @@ def render_task_prompt(
def parse_task_output(
self, task: Task, output: str, forced_output_parser: Optional[str] = None
- ):
- """Parses the output for the provided tasks.
-
- If an output parser is associated with the prompt, it will be used.
- Otherwise, the output is returned as is.
+ ) -> ParsedTaskOutput:
+ """Parses the output of a task, optionally extracting reasoning traces.
+
+ Args:
+ task (Task): The task for which the output is being parsed.
+ output (str): The output string to be parsed.
+ forced_output_parser (Optional[str]): An optional parser name that overrides the parser configured for the task's prompt.
+
+ Returns:
+ ParsedTaskOutput: An object containing the parsed text (which may
+ include or exclude reasoning traces based on configuration) and
+ any reasoning trace.
"""
- prompt = get_prompt(self.config, task)
+ reasoning_trace: Optional[str] = None
- output_parser = None
- if forced_output_parser:
- output_parser = self.output_parsers.get(forced_output_parser)
- elif prompt.output_parser:
- output_parser = self.output_parsers.get(prompt.output_parser)
- if not output_parser:
- logging.info("No output parser found for %s", prompt.output_parser)
+ # Get the tokens first to check for their presence
+ start_token, end_token = get_reasoning_token_tags(self.config, task)
- model = get_task_model(self.config, task)
+ # 1. strip and capture reasoning traces if configured and present
if (
- model
- and model.reasoning_config
- and model.reasoning_config.remove_thinking_traces
+ start_token
+ and end_token
+ and output_has_reasoning_traces(output, start_token, end_token)
):
- start_token = model.reasoning_config.start_token
- end_token = model.reasoning_config.end_token
- output = remove_reasoning_traces(output, start_token, end_token)
+ reasoning_trace_result = extract_and_strip_trace(
+ output, start_token, end_token
+ )
+ reasoning_trace = reasoning_trace_result.reasoning_trace
+
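+ # Strip the trace from the parsed output only when output rails are not meant to apply to it.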
+ if should_remove_reasoning_traces_from_output(self.config, task):
+ output = reasoning_trace_result.text
+
+ # 2. delegate to existing parser
+ prompt = get_prompt(self.config, task)
+ parser_name = forced_output_parser or prompt.output_parser
+ parser_fn = self.output_parsers.get(parser_name)
- if output_parser:
- return output_parser(output)
+ if parser_fn:
+ parsed_text = parser_fn(output)
else:
- return output
+ logging.info("No output parser found for %s", prompt.output_parser)
+ parsed_text = output
+
+ return ParsedTaskOutput(text=parsed_text, reasoning_trace=reasoning_trace)
def has_output_parser(self, task: Task):
prompt = get_prompt(self.config, task)
diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py
index 8d2502eb0..e776e6653 100644
--- a/nemoguardrails/rails/llm/config.py
+++ b/nemoguardrails/rails/llm/config.py
@@ -453,6 +453,15 @@ class OutputRails(BaseModel):
description="Configuration for streaming output rails.",
)
+ apply_to_reasoning_traces: bool = Field(
+ default=False,
+ description=(
+ "If True, output rails will apply guardrails to both reasoning traces and output response. "
+ "If False, output rails will only apply guardrails to the output response excluding the reasoning traces, "
+ "thus keeping reasoning traces unaltered."
+ ),
+ )
+
class RetrievalRails(BaseModel):
"""Configuration of retrieval rails."""
diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py
index 3b667fee1..11876bdba 100644
--- a/nemoguardrails/rails/llm/llmrails.py
+++ b/nemoguardrails/rails/llm/llmrails.py
@@ -31,7 +31,10 @@
from langchain_core.language_models.llms import BaseLLM
from nemoguardrails.actions.llm.generation import LLMGenerationActions
-from nemoguardrails.actions.llm.utils import get_colang_history
+from nemoguardrails.actions.llm.utils import (
+ get_and_clear_reasoning_trace_contextvar,
+ get_colang_history,
+)
from nemoguardrails.actions.output_mapping import is_output_blocked
from nemoguardrails.actions.v2_x.generation import LLMGenerationActionsV2dotx
from nemoguardrails.colang import parse_colang_file
@@ -48,6 +51,7 @@
generation_options_var,
llm_stats_var,
raw_llm_request,
+ reasoning_trace_var,
streaming_handler_var,
)
from nemoguardrails.embeddings.index import EmbeddingsIndex
@@ -838,6 +842,14 @@ async def generate_async(
else:
res = GenerationResponse(response=[new_message])
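+ # Prepend any reasoning trace captured during generation back onto the final response.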
+ if reasoning_trace := get_and_clear_reasoning_trace_contextvar():
+ if prompt:
+ res.response = reasoning_trace + res.response
+ else:
+ res.response[0]["content"] = (
+ reasoning_trace + res.response[0]["content"]
+ )
+
if self.config.colang_version == "1.0":
# If output variables are specified, we extract their values
if options.output_vars:
@@ -926,9 +938,14 @@ async def generate_async(
input=messages, response=res, adapters=self._log_adapters
)
await tracer.export_async()
+
return res
else:
# If a prompt is used, we only return the content of the message.
+
+ if reasoning_trace := get_and_clear_reasoning_trace_contextvar():
+ new_message["content"] = reasoning_trace + new_message["content"]
+
if prompt:
return new_message["content"]
else:
diff --git a/tests/conftest.py b/tests/conftest.py
index 12fb06e89..34d62ba8d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -17,6 +17,27 @@
import pytest
+from nemoguardrails.context import reasoning_trace_var
+
def pytest_configure(config):
patch("prompt_toolkit.PromptSession", autospec=True).start()
+
+
+@pytest.fixture(autouse=True)
+def reset_reasoning_trace():
+ """Reset the reasoning_trace_var before each test.
+
+ This fixture runs automatically for every test (autouse=True) to ensure
+ a clean state for the reasoning trace context variable.
+
+ Known issues with the ContextVar approach (not specific to this case):
+ - Global state: ContextVar creates global state that is hard to track and manage.
+ - Implicit flow: the reasoning trace moves through the system in a non-obvious way.
+ - Testing complexity: it can break test isolation, which this fixture guards against.
+ """
+ # reset the variable before the test
+ reasoning_trace_var.set(None)
+ yield
+ # reset the variable after the test as well (in case the test fails)
+ reasoning_trace_var.set(None)
diff --git a/tests/test_configs/with_custom_llm_prompt_action_v2_x/actions.py b/tests/test_configs/with_custom_llm_prompt_action_v2_x/actions.py
index 1ed061a4b..4ba550383 100644
--- a/tests/test_configs/with_custom_llm_prompt_action_v2_x/actions.py
+++ b/tests/test_configs/with_custom_llm_prompt_action_v2_x/actions.py
@@ -50,6 +50,7 @@ async def custom_llm_request(
result = await llm_call(llm, prompt, stop=stop)
result = llm_task_manager.parse_task_output(prompt_template_name, output=result)
+ result = result.text
# Any additional parsing of the output
value = result.strip().split("\n")[0]
diff --git a/tests/test_filters.py b/tests/test_filters.py
index ae8bba489..f97b288d2 100644
--- a/tests/test_filters.py
+++ b/tests/test_filters.py
@@ -14,13 +14,16 @@
# limitations under the License.
import textwrap
+from typing import List, Tuple, Union
import pytest
from nemoguardrails.llm.filters import (
+ ReasoningExtractionResult,
+ extract_and_strip_trace,
+ find_reasoning_tokens_position,
first_turns,
last_turns,
- remove_reasoning_traces,
to_chat_messages,
user_assistant_sequence,
)
@@ -92,51 +95,228 @@ def test_last_turns():
assert last_turns(colang_history, 2) == colang_history
+def _build_test_string(parts: List[Union[str, Tuple[str, int]]]) -> str:
+ """Builds a test string from a list of parts.
+
+ Each part can be a literal string or a (character, count) tuple.
+ Example: [("a", 3), "[START]", ("b", 5)] -> "aaa[START]bbbbb"
+ """
+ result = []
+ for part in parts:
+ if isinstance(part, str):
+ result.append(part)
+ elif isinstance(part, tuple) and len(part) == 2:
+ char, count = part
+ result.append(char * count)
+ else:
+ raise TypeError(f"Invalid part type in _build_test_string: {part}")
+ return "".join(result)
+
+
@pytest.mark.parametrize(
"response, start_token, end_token, expected",
[
(
- "This is an example [START]hidden reasoning[END] of a response.",
+ _build_test_string(
+ [
+ ("a", 5),
+ "[START]",
+ ("b", 10),
+ "[END]",
+ ("c", 5),
+ ]
+ ),
"[START]",
"[END]",
- "This is an example of a response.",
+ (5, 22), # 5 a's + 7 START + 10 b = 22
+ ),
+ # multiple reasoning sections
+ (
+ _build_test_string(
+ [
+ ("a", 3),
+ "[START]",
+ ("b", 4),
+ "[END]",
+ ("c", 3),
+ "[START]",
+ ("d", 4),
+ "[END]",
+ ("e", 3),
+ ]
+ ),
+ "[START]",
+ "[END]",
+ (
+ 3,
+ 33,
+ ),
+ ),
+ (
+ _build_test_string(
+ [
+ ("a", 2),
+ "[START]",
+ ("b", 2),
+ "[START]",
+ ("c", 2),
+ "[END]",
+ ("d", 2),
+ "[END]",
+ ("e", 2),
+ ]
+ ),
+ "[START]",
+ "[END]",
+ (
+ 2,
+ 27,
+ ),
+ ),
+ (
+ _build_test_string([("a", 10)]),
+ "[START]",
+ "[END]",
+ (0, -1), # no tokens found, start_index is 0
),
(
- "This is an example without an end token.",
+ _build_test_string(
+ [
+ ("a", 5),
+ "[START]",
+ ("b", 5),
+ ]
+ ),
"[START]",
"[END]",
- "This is an example without an end token.",
+ (5, -1), # [START] at pos 5, no end token
),
(
- "This is an example [START] with a start token but no end token.",
+ _build_test_string(
+ [
+ ("a", 5),
+ "[END]",
+ ("b", 5),
+ ]
+ ),
+ "[START]",
+ "[END]",
+ (0, 5), # no start token so 0, end at pos 5
+ ),
+ (
+ "",
+ "[START]",
+ "[END]",
+ (0, -1), # empty string, start_index is 0
+ ),
+ ],
+)
+def test_find_token_positions_for_removal(response, start_token, end_token, expected):
+ """Test finding token positions for removal.
+
+ Test cases use _build_test_string for clarity and mathematical obviousness.
+ """
+ assert find_reasoning_tokens_position(response, start_token, end_token) == expected
+
+
+@pytest.mark.parametrize(
+ "response, start_token, end_token, expected_text, expected_trace",
+ [
+ (
+ "This is an example [START]hidden reasoning[END] of a response.",
"[START]",
"[END]",
- "This is an example [START] with a start token but no end token.",
+ "This is an example of a response.",
+ "[START]hidden reasoning[END]",
),
(
- "Before [START]hidden[END] middle [START]extra hidden[END] after.",
+ "Before [START]first[END] middle [START]second[END] after.",
"[START]",
"[END]",
"Before after.",
+ "[START]first[END] middle [START]second[END]",
),
(
"Text [START] first [START] nested [END] second [END] more text.",
"[START]",
"[END]",
"Text more text.",
+ "[START] first [START] nested [END] second [END]",
+ ),
+ (
+ "No tokens here",
+ "[START]",
+ "[END]",
+ "No tokens here",
+ None,
+ ),
+ (
+ "Only [START] start token",
+ "[START]",
+ "[END]",
+ "Only [START] start token",
+ None,
+ ),
+ (
+ "Only end token [END]",
+ "[START]",
+ "[END]",
+ "",
+ "Only end token [END]",
+ ),
+ (
+ "",
+ "[START]",
+ "[END]",
+ "",
+ None,
+ ),
+ # End token before start token (tests the final return path)
+ (
+ "some [END] text [START]",
+ "[START]",
+ "[END]",
+ "some [END] text [START]",
+ None,
+ ),
+ # Original test cases adapted
+ (
+ "[END] Out of order [START] tokens [END] example.",
+ "[START]",
+ "[END]",
+ "[END] Out of order example.",
+ "[START] tokens [END]",
+ ),
+ (
+ "[START] nested [START] tokens [END] out of [END] order.",
+ "[START]",
+ "[END]",
+ " order.",
+ "[START] nested [START] tokens [END] out of [END]",
+ ),
+ (
+ "[END] [START] [START] example [END] text.",
+ "[START]",
+ "[END]",
+ "[END] text.",
+ "[START] [START] example [END]",
),
(
- "[START]Remove this[END] but keep this.",
+ "example text.",
"[START]",
"[END]",
- " but keep this.",
+ "example text.",
+ None,
),
- ("", "[START]", "[END]", ""),
],
)
-def test_remove_reasoning_traces(response, start_token, end_token, expected):
- """Test removal of text between start and end tokens with multiple cases."""
- assert remove_reasoning_traces(response, start_token, end_token) == expected
+def test_extract_and_strip_trace(
+ response, start_token, end_token, expected_text, expected_trace
+):
+ """Tests the extraction and stripping of reasoning traces."""
+ result = extract_and_strip_trace(response, start_token, end_token)
+ assert result.text == expected_text
+ assert result.reasoning_trace == expected_trace
class TestToChatMessages:
diff --git a/tests/test_llmrails_reasoning.py b/tests/test_llmrails_reasoning.py
index a5575f423..a016ed77c 100644
--- a/tests/test_llmrails_reasoning.py
+++ b/tests/test_llmrails_reasoning.py
@@ -69,10 +69,10 @@ def rails_config():
async def test_1(rails_config):
llm = FakeLLM(
responses=[
- "some text\n express greeting",
- "some text\n ask math question",
- 'some text\n "The answer is 5"',
- 'some text\n "Are you happy with the result?"',
+ "some redundant CoT text 1\n express greeting",
+ "some redundant CoT text 2\n ask math question",
+ 'some redundant CoT text 3\n "The answer is 5"',
+ 'some important COT text\n "Are you happy with the result?"',
]
)
@@ -92,5 +92,5 @@ async def compute(what: Optional[str] = "2 + 3"):
bot_message = await llm_rails.generate_async(messages=messages)
assert bot_message == {
"role": "assistant",
- "content": "The answer is 5\nAre you happy with the result?",
+ "content": "some important COT textThe answer is 5\nAre you happy with the result?",
}
diff --git a/tests/test_llmrails_reasoning_output_rails.py b/tests/test_llmrails_reasoning_output_rails.py
new file mode 100644
index 000000000..1b6f7ae13
--- /dev/null
+++ b/tests/test_llmrails_reasoning_output_rails.py
@@ -0,0 +1,312 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for LLM Rails reasoning output configuration and behavior.
+
+This module contains tests that verify the behavior of LLM Rails when handling
+reasoning traces in the output, including configuration options and guardrail
+behavior.
+"""
+
+from typing import Any, Dict, NamedTuple
+
+import pytest
+
+from nemoguardrails import RailsConfig
+from tests.utils import TestChat
+
+
+class ReasoningTraceTestCase(NamedTuple):
+ """Test case for reasoning trace configuration.
+
+ Attributes:
+ description: description of the test case
+ remove_thinking_traces: Whether to remove thinking traces in the model config
+ apply_to_reasoning_traces: Whether to apply output rails to reasoning traces
+ expected_think_tag: Whether the think tag should be present in the response
+ expected_error_message: Whether the error message should be present in the response
+ """
+
+ description: str
+ remove_thinking_traces: bool
+ apply_to_reasoning_traces: bool
+ expected_think_tag: bool
+ expected_error_message: bool
+
+
+async def check_sensitive_info(context: Dict[str, Any]) -> bool:
+ """Check if the response contains sensitive information."""
+ response = context.get("bot_message", "")
+ prompt = context.get("user_message", "")
+ input_text = response or prompt
+ return "credit card" in input_text.lower() or any(
+ c.isdigit() for c in input_text
+ )
+
+
+async def check_think_tag_present(context: Dict[str, Any]) -> bool:
+ """Check if the think tag is present in the bot's response."""
+ response = context.get("bot_message", "")
+ return "" in response
+
+
+@pytest.fixture
+def base_config() -> RailsConfig:
+ """Creates a base RailsConfig with common test configuration."""
+ return RailsConfig.from_content(
+ colang_content="""
+ define flow check think tag
+ $not_allowed = execute check_think_tag_present
+ if $not_allowed
+ bot informs tag not allowed
+ stop
+
+ define bot informs tag not allowed
+ "think tag is not allowed it must be removed"
+ """,
+ yaml_content="""
+ models:
+ - type: main
+ engine: fake
+ model: fake
+ colang_version: "1.0"
+ rails:
+ output:
+ flows:
+ - check think tag
+ """,
+ )
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "test_case",
+ [
+ ReasoningTraceTestCase(
+ description="Remove thinking traces and show error when guardrail is enabled",
+ remove_thinking_traces=True,
+ apply_to_reasoning_traces=True,
+ expected_think_tag=True,
+ expected_error_message=True,
+ ),
+ ReasoningTraceTestCase(
+ description="Preserve thinking traces and hide error when guardrail is disabled",
+ remove_thinking_traces=True,
+ apply_to_reasoning_traces=False,
+ expected_think_tag=True,
+ expected_error_message=False,
+ ),
+ ReasoningTraceTestCase(
+ description="Preserve thinking traces and show error when guardrail is enabled",
+ remove_thinking_traces=False,
+ apply_to_reasoning_traces=True,
+ expected_think_tag=True,
+ expected_error_message=True,
+ ),
+ ReasoningTraceTestCase(
+ description="Remove thinking traces and show error when both flags are disabled",
+ remove_thinking_traces=False,
+ apply_to_reasoning_traces=False,
+ expected_think_tag=True,
+ expected_error_message=True,
+ ),
+ ],
+ ids=lambda tc: tc.description,
+)
+async def test_output_rails_reasoning_traces_configuration(
+ base_config: RailsConfig,
+ test_case: ReasoningTraceTestCase,
+) -> None:
+ """Test output rails with different reasoning traces configurations.
+
+ The test verifies the following behaviors based on configuration:
+
+ 1. When remove_thinking_traces=True:
+ - The model is configured to remove thinking traces
+ - However, the actual removal depends on apply_to_reasoning_traces
+
+ 2. When apply_to_reasoning_traces=True:
+ - The reasoning trace is kept in the text the output rails inspect,
+ so the think tag is detected and the error message is reported
+
+ 3. When apply_to_reasoning_traces=False (and remove_thinking_traces=True):
+ - The output rails only see the response text, without the reasoning trace
+ - No error message is shown because the think tag is not there to be blocked
+
+ """
+ base_config.models[
+ 0
+ ].reasoning_config.remove_thinking_traces = test_case.remove_thinking_traces
+ base_config.rails.output.apply_to_reasoning_traces = (
+ test_case.apply_to_reasoning_traces
+ )
+
+ chat = TestChat(
+ base_config,
+ llm_completions=[
+ " I should think more Your kindness is appreciated"
+ ],
+ )
+
+ chat.app.runtime.register_action(check_think_tag_present)
+
+ messages = [{"role": "user", "content": "you are nice"}]
+ response = await chat.app.generate_async(messages=messages)
+
+ if test_case.expected_think_tag:
+ assert (
+ "" in response["content"]
+ ), "Think tag should be present in response"
+ else:
+ assert (
+ "" not in response["content"]
+ ), "Think tag should not be present in response"
+
+ if test_case.expected_error_message:
+ assert (
+ "think tag is not allowed" in response["content"]
+ ), "Error message should be present"
+ else:
+ assert (
+ "think tag is not allowed" not in response["content"]
+ ), "Error message should not be present"
+
+
+@pytest.mark.asyncio
+async def test_output_rails_preserves_reasoning_traces() -> None:
+ """Test that output rails preserve reasoning traces when configured to do so."""
+ config = RailsConfig.from_content(
+ colang_content="""
+ define flow check sensitive info
+ $not_allowed = execute check_sensitive_info
+ if $not_allowed
+ bot provide sanitized response
+ stop
+ define bot provide sanitized response
+ "I cannot share sensitive information."
+ """,
+ yaml_content="""
+ models:
+ - type: main
+ engine: fake
+ model: fake
+ reasoning_config:
+ remove_thinking_traces: True
+ colang_version: "1.0"
+ rails:
+ output:
+ flows:
+ - check sensitive info
+ apply_to_reasoning_traces: True
+ """,
+ )
+
+ chat = TestChat(
+ config,
+ llm_completions=[
+ '<think> I should not share sensitive info </think>\n "Here is my credit card: 1234-5678-9012-3456"',
+ ],
+ )
+
+ chat.app.runtime.register_action(check_sensitive_info)
+
+ messages = [{"role": "user", "content": "What's your credit card number?"}]
+ response = await chat.app.generate_async(messages=messages)
+
+ assert "" in response["content"], "Reasoning traces should be preserved"
+ assert (
+ "I should not share sensitive info" in response["content"]
+ ), "Reasoning content should be preserved"
+ assert (
+ "credit card" not in response["content"].lower()
+ ), "Sensitive information should be removed"
+
+
+@pytest.mark.asyncio
+async def test_output_rails_without_reasoning_traces() -> None:
+ """Test that output rails properly handle responses when reasoning traces are disabled."""
+ config = RailsConfig.from_content(
+ colang_content="""
+ define flow check sensitive info
+ $not_allowed = execute check_sensitive_info
+ if $not_allowed
+ bot provide sanitized response
+ stop
+ define flow check think tag
+ $not_allowed = execute check_think_tag_present
+ if $not_allowed
+ bot says tag not allowed
+ stop
+
+ define bot says tag not allowed
+ " tag is not allowed it must be removed"
+
+ define bot provide sanitized response
+ "I cannot share sensitive information."
+ """,
+ yaml_content="""
+ models:
+ - type: main
+ engine: fake
+ model: fake
+ reasoning_config:
+ remove_thinking_traces: True
+ colang_version: "1.0"
+ rails:
+ input:
+ flows:
+ - check sensitive info
+ output:
+ flows:
+ - check sensitive info
+ - check think tag
+ apply_to_reasoning_traces: false
+ """,
+ )
+
+ chat = TestChat(
+ config,
+ llm_completions=[
+ " I should think more Your credit card number is 1234-5678-9012-3456",
+ ],
+ )
+
+ chat.app.runtime.register_action(check_sensitive_info)
+ chat.app.runtime.register_action(check_think_tag_present)
+
+ # case 1: Sensitive information is blocked by input rail
+ messages = [{"role": "user", "content": "What's your credit card number?"}]
+ response = await chat.app.generate_async(messages=messages)
+
+ assert "" not in response["content"], "Think tag should not be present"
+ assert (
+ "I should not share sensitive info" not in response["content"]
+ ), "Reasoning content should not be present"
+ assert (
+ response["content"] == "I cannot share sensitive information."
+ ), "Should return sanitized response"
+
+ # case 2: Think tag is preserved but content is sanitized
+ messages = [{"role": "user", "content": "Tell me some numbers"}]
+ response = await chat.app.generate_async(messages=messages)
+
+ assert "" in response["content"], "Think tag should be present"
+ assert (
+ "I should not share sensitive info" not in response["content"]
+ ), "Reasoning content should not be present"
+ assert (
+ response["content"]
+ == " I should think more I cannot share sensitive information."
+ ), "Should preserve think tag but sanitize content"
diff --git a/tests/test_reasoning_trace_context.py b/tests/test_reasoning_trace_context.py
new file mode 100644
index 000000000..d1c0c6db3
--- /dev/null
+++ b/tests/test_reasoning_trace_context.py
@@ -0,0 +1,227 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from nemoguardrails import RailsConfig
+from nemoguardrails.actions.llm.utils import get_and_clear_reasoning_trace_contextvar
+from nemoguardrails.context import reasoning_trace_var
+from nemoguardrails.rails.llm.llmrails import GenerationOptions, GenerationResponse
+from tests.utils import TestChat
+
+
+def test_get_and_clear_reasoning_trace_contextvar():
+ """Test that it correctly gets and clears the trace."""
+ reasoning_trace_var.set(" oh COT again ")
+
+ result = get_and_clear_reasoning_trace_contextvar()
+
+ assert result == "<think> oh COT again </think>"
+ assert reasoning_trace_var.get() is None
+
+
+def test_get_and_clear_reasoning_trace_contextvar_empty():
+ """Test that it returns None when no trace exists."""
+ reasoning_trace_var.set(None)
+
+ result = get_and_clear_reasoning_trace_contextvar()
+
+ assert result is None
+
+
+@pytest.mark.asyncio
+async def test_generate_async_trace_with_messages_and_options():
+ """Test generate_async prepends reasoning trace when using generation options and messages."""
+ config = RailsConfig.from_content(
+ colang_content="""
+ define user express greeting
+ "hi"
+ "hello"
+
+ define bot express greeting
+ "Hello! How can I assist you today?"
+
+ define flow
+ user express greeting
+ bot express greeting
+ """,
+ yaml_content="""
+ models: []
+ rails:
+ output:
+ apply_to_reasoning_traces: true
+ """,
+ )
+
+ chat = TestChat(
+ config,
+ llm_completions=[
+ "user express greeting",
+ "bot express greeting",
+ "Hello! How can I assist you today?",
+ ],
+ )
+
+ reasoning_trace_var.set(" yet another COT ")
+
+ options = GenerationOptions()
+ result = await chat.app.generate_async(
+ messages=[{"role": "user", "content": "hi"}], options=options
+ )
+
+ assert isinstance(result, GenerationResponse)
+ assert isinstance(result.response, list)
+ assert len(result.response) == 1
+ assert (
+ result.response[0]["content"]
+ == " yet another COT Hello! How can I assist you today?"
+ )
+ assert reasoning_trace_var.get() is None
+
+
+@pytest.mark.asyncio
+async def test_generate_async_trace_with_prompt_and_options():
+ """Test generate_async prepends reasoning trace using prompt and options"""
+ config = RailsConfig.from_content(
+ colang_content="""
+ define user express greeting
+ "hi"
+ "hello"
+
+ define bot express greeting
+ "Hello! How can I assist you today?"
+
+ define flow
+ user express greeting
+ bot express greeting
+ """,
+ yaml_content="""
+ models: []
+ rails:
+ output:
+ apply_to_reasoning_traces: true
+ """,
+ )
+
+ chat = TestChat(
+ config,
+ llm_completions=[
+ "user express greeting",
+ "bot express greeting",
+ "Hello! How can I assist you today?",
+ ],
+ )
+
+ reasoning_trace_var.set(" yet another COT ")
+
+ options = GenerationOptions()
+ result = await chat.app.generate_async(options=options, prompt="test prompt")
+
+ assert isinstance(result, GenerationResponse)
+ assert isinstance(result.response, str)
+ assert (
+ result.response
+ == " yet another COT Hello! How can I assist you today?"
+ )
+ assert reasoning_trace_var.get() is None
+
+
+@pytest.mark.asyncio
+async def test_generate_async_trace_messages_only():
+ """Test generate_async prepends reasoning trace when using only messages."""
+ config = RailsConfig.from_content(
+ colang_content="""
+ define user express greeting
+ "hi"
+ "hello"
+
+ define bot express greeting
+ "Hello! How can I assist you today?"
+
+ define flow
+ user express greeting
+ bot express greeting
+ """,
+ yaml_content="""
+ models: []
+ rails:
+ output:
+ apply_to_reasoning_traces: true
+ """,
+ )
+
+ chat = TestChat(
+ config,
+ llm_completions=[
+ "user express greeting",
+ "bot express greeting",
+ "Hello! How can I assist you today?",
+ ],
+ )
+
+ reasoning_trace_var.set(" yet another COT ")
+
+ result = await chat.app.generate_async(messages=[{"role": "user", "content": "hi"}])
+
+ assert isinstance(result, dict)
+ assert result.get("role") == "assistant"
+ assert (
+ result.get("content")
+ == " yet another COT Hello! How can I assist you today?"
+ )
+ assert reasoning_trace_var.get() is None
+
+
+@pytest.mark.asyncio
+async def test_generate_async_trace_with_prompt_only():
+ """Test generate_async prepends reasoning trace when using prompt."""
+ config = RailsConfig.from_content(
+ colang_content="""
+ define user express greeting
+ "hi"
+ "hello"
+
+ define bot express greeting
+ "Hello! How can I assist you today?"
+
+ define flow
+ user express greeting
+ bot express greeting
+ """,
+ yaml_content="""
+ models: []
+ rails:
+ output:
+ apply_to_reasoning_traces: true
+ """,
+ )
+
+ chat = TestChat(
+ config,
+ llm_completions=[
+ "user express greeting",
+ "bot express greeting",
+ "Hello! How can I assist you today?",
+ ],
+ )
+
+ reasoning_trace_var.set(" yet another COT ")
+
+ result = await chat.app.generate_async(prompt="hi")
+
+ assert (
+ result == " yet another COT Hello! How can I assist you today?"
+ )
+ assert reasoning_trace_var.get() is None
diff --git a/tests/test_reasoning_traces.py b/tests/test_reasoning_traces.py
index ce4edec1b..ad1d78e5a 100644
--- a/tests/test_reasoning_traces.py
+++ b/tests/test_reasoning_traces.py
@@ -17,20 +17,32 @@
import pytest
-from nemoguardrails.actions.llm.generation import LLMGenerationActions
+from nemoguardrails.actions.llm.generation import (
+ LLMGenerationActions,
+ _get_apply_to_reasoning_traces,
+ _process_parsed_output,
+)
from nemoguardrails.actions.v2_x.generation import LLMGenerationActionsV2dotx
from nemoguardrails.context import (
generation_options_var,
llm_call_info_var,
streaming_handler_var,
)
-from nemoguardrails.llm.filters import remove_reasoning_traces
-from nemoguardrails.llm.taskmanager import LLMTaskManager
+from nemoguardrails.llm.filters import extract_and_strip_trace
+from nemoguardrails.llm.taskmanager import LLMTaskManager, ParsedTaskOutput
from nemoguardrails.llm.types import Task
from nemoguardrails.logging.explain import LLMCallInfo
from nemoguardrails.rails.llm.config import Model, RailsConfig, ReasoningModelConfig
+def create_mock_config():
+ config = MagicMock(spec=RailsConfig)
+ config.rails = MagicMock()
+ config.rails.output = MagicMock()
+ config.rails.output.apply_to_reasoning_traces = False
+ return config
+
+
class TestReasoningTraces:
"""Test the reasoning traces functionality."""
@@ -38,8 +50,8 @@ def test_remove_reasoning_traces_basic(self):
"""Test basic removal of reasoning traces."""
input_text = "This is a \nSome reasoning here\nMore reasoning\n response."
expected = "This is a response."
- result = remove_reasoning_traces(input_text, "<think>", "</think>")
- assert result == expected
+ result = extract_and_strip_trace(input_text, "<think>", "</think>")
+ assert result.text == expected
def test_remove_reasoning_traces_multiline(self):
"""Test removal of multiline reasoning traces."""
@@ -52,8 +64,8 @@ def test_remove_reasoning_traces_multiline(self):
response after thinking.
"""
expected = "\n Here is my response after thinking.\n "
- result = remove_reasoning_traces(input_text, "<think>", "</think>")
- assert result == expected
+ result = extract_and_strip_trace(input_text, "<think>", "</think>")
+ assert result.text == expected
def test_remove_reasoning_traces_multiple_sections(self):
"""Test removal of multiple reasoning trace sections."""
@@ -61,8 +73,8 @@ def test_remove_reasoning_traces_multiple_sections(self):
# Note: The current implementation removes all content between the first start and last end token
# So the expected result is "Start end." not "Start middle end."
expected = "Start end."
- result = remove_reasoning_traces(input_text, "<think>", "</think>")
- assert result == expected
+ result = extract_and_strip_trace(input_text, "<think>", "</think>")
+ assert result.text == expected
def test_remove_reasoning_traces_nested(self):
"""Test handling of nested reasoning trace markers (should be handled correctly)."""
@@ -70,22 +82,21 @@ def test_remove_reasoning_traces_nested(self):
"Begin Outer Inner Outer End."
)
expected = "Begin End."
- result = remove_reasoning_traces(input_text, "<think>", "</think>")
- assert result == expected
+ result = extract_and_strip_trace(input_text, "<think>", "</think>")
+ assert result.text == expected
def test_remove_reasoning_traces_unmatched(self):
"""Test handling of unmatched reasoning trace markers."""
input_text = "Begin Unmatched end."
- result = remove_reasoning_traces(input_text, "<think>", "</think>")
+ result = extract_and_strip_trace(input_text, "<think>", "</think>")
        # We should keep the unmatched tag since it's not a complete section
- assert result == "Begin <think> Unmatched end."
+ assert result.text == "Begin <think> Unmatched end."
@pytest.mark.asyncio
async def test_task_manager_parse_task_output(self):
"""Test that the task manager correctly removes reasoning traces."""
# mock config
- config = MagicMock(spec=RailsConfig)
-
+ config = create_mock_config()
# Create a ReasoningModelConfig
reasoning_config = ReasoningModelConfig(
remove_thinking_traces=True,
@@ -121,12 +132,13 @@ async def test_task_manager_parse_task_output(self):
expected = "This is a final answer."
result = llm_task_manager.parse_task_output(Task.GENERAL, input_text)
- assert result == expected
+ assert result.text == expected
@pytest.mark.asyncio
async def test_parse_task_output_without_reasoning_config(self):
"""Test that parse_task_output works without a reasoning config."""
- config = MagicMock(spec=RailsConfig)
+
+ config = create_mock_config()
# a Model without reasoning_config
model_config = Model(type="main", engine="test", model="test-model")
@@ -147,18 +159,22 @@ async def test_parse_task_output_without_reasoning_config(self):
input_text = (
"This is a Some reasoning here final answer."
)
-
- # Without a reasoning config, the text should remain unchanged
result = llm_task_manager.parse_task_output(Task.GENERAL, input_text)
- assert result == input_text
+ assert result.text == input_text
@pytest.mark.asyncio
async def test_parse_task_output_with_default_reasoning_traces(self):
- """Test that parse_task_output works without a reasoning config."""
- config = MagicMock(spec=RailsConfig)
+ """Test that parse_task_output works with default reasoning traces."""
- # a Model without reasoning_config
- model_config = Model(type="main", engine="test", model="test-model")
+ config = create_mock_config()
+
+ # Create a Model with default reasoning_config
+ model_config = Model(
+ type="main",
+ engine="test",
+ model="test-model",
+ reasoning_config=ReasoningModelConfig(),
+ )
# Mock the get_prompt and get_task_model functions
with (
@@ -172,42 +188,51 @@ async def test_parse_task_output_with_default_reasoning_traces(self):
llm_task_manager = LLMTaskManager(config)
- # test parsing without a reasoning config
+ # test parsing with default reasoning traces
input_text = "This is a Some reasoning here final answer."
- expected = "This is a final answer."
-
- # without a reasoning config, the default start_token and stop_token are used thus the text should change
result = llm_task_manager.parse_task_output(Task.GENERAL, input_text)
- assert result == expected
+ assert result.text == "This is a final answer."
@pytest.mark.asyncio
async def test_parse_task_output_with_output_parser(self):
- """Test that parse_task_output correctly applies output parsers before returning."""
- config = MagicMock(spec=RailsConfig)
+ """Test that parse_task_output works with an output parser."""
- # mock output parser function
- def mock_parser(text):
- return text.upper()
+ config = create_mock_config()
- llm_task_manager = LLMTaskManager(config)
- llm_task_manager.output_parsers["test_parser"] = mock_parser
+ # Create a Model with reasoning_config
+ model_config = Model(
+ type="main",
+ engine="test",
+ model="test-model",
+ reasoning_config=ReasoningModelConfig(
+ remove_thinking_traces=True,
+ start_token="",
+ end_token="",
+ ),
+ )
- # mock the get_prompt and get_task_model functions
+ def mock_parser(text):
+ return f"PARSED: {text}"
+
+ # Mock the get_prompt and get_task_model functions
with (
patch("nemoguardrails.llm.taskmanager.get_prompt") as mock_get_prompt,
patch(
"nemoguardrails.llm.taskmanager.get_task_model"
) as mock_get_task_model,
):
- mock_get_prompt.return_value = MagicMock(output_parser="test_parser")
- mock_get_task_model.return_value = None
+ mock_get_prompt.return_value = MagicMock(output_parser="mock_parser")
+ mock_get_task_model.return_value = model_config
- # Test with output parser
- input_text = "this should be uppercase"
- expected = "THIS SHOULD BE UPPERCASE"
+ llm_task_manager = LLMTaskManager(config)
+ llm_task_manager.output_parsers["mock_parser"] = mock_parser
+ # test parsing with an output parser
+ input_text = (
+ "This is a Some reasoning here final answer."
+ )
result = llm_task_manager.parse_task_output(Task.GENERAL, input_text)
- assert result == expected
+ assert result.text == "PARSED: This is a final answer."
@pytest.mark.asyncio
async def test_passthrough_llm_action_removes_reasoning(self):
@@ -344,3 +369,47 @@ def __init__(self, events):
)
assert mock_result.events[0]["text"] == "This is a final answer."
+
+
+class TestProcessParsedOutput:
+ """Test the _process_parsed_output function."""
+
+ def test_process_parsed_output_with_reasoning_trace(self):
+ """Test processing output with reasoning trace when guardrail is enabled."""
+ result = ParsedTaskOutput(
+ text="final answer",
+ reasoning_trace="some reasoning",
+ )
+ output = _process_parsed_output(result, include_reasoning_trace=True)
+ assert output == "some reasoningfinal answer"
+
+ def test_process_parsed_output_with_reasoning_trace_disabled(self):
+ """Test processing output with reasoning trace when guardrail is disabled."""
+ result = ParsedTaskOutput(
+ text="final answer",
+ reasoning_trace="some reasoning",
+ )
+ output = _process_parsed_output(result, include_reasoning_trace=False)
+ assert output == "final answer"
+
+ def test_process_parsed_output_without_reasoning_trace(self):
+ """Test processing output without reasoning trace."""
+ result = ParsedTaskOutput(text="final answer", reasoning_trace=None)
+ output = _process_parsed_output(result, include_reasoning_trace=True)
+ assert output == "final answer"
+
+
+class TestGuardrailReasoningTraces:
+ """Test the guardrail reasoning traces configuration."""
+
+ def test_get_apply_to_reasoning_traces_enabled(self):
+ """Test getting guardrail reasoning traces when enabled."""
+ config = create_mock_config()
+ config.rails.output.apply_to_reasoning_traces = True
+ assert _get_apply_to_reasoning_traces(config) is True
+
+ def test_get_apply_to_reasoning_traces_disabled(self):
+ """Test getting guardrail reasoning traces when disabled."""
+ config = create_mock_config()
+ config.rails.output.apply_to_reasoning_traces = False
+ assert _get_apply_to_reasoning_traces(config) is False