diff --git a/nemoguardrails/actions/llm/generation.py b/nemoguardrails/actions/llm/generation.py index ed3fff9bb..2a57e1c26 100644 --- a/nemoguardrails/actions/llm/generation.py +++ b/nemoguardrails/actions/llm/generation.py @@ -50,13 +50,14 @@ generation_options_var, llm_call_info_var, raw_llm_request, + reasoning_trace_var, streaming_handler_var, ) from nemoguardrails.embeddings.index import EmbeddingsIndex, IndexItem from nemoguardrails.kb.kb import KnowledgeBase from nemoguardrails.llm.params import llm_params from nemoguardrails.llm.prompts import get_prompt -from nemoguardrails.llm.taskmanager import LLMTaskManager +from nemoguardrails.llm.taskmanager import LLMTaskManager, ParsedTaskOutput from nemoguardrails.llm.types import Task from nemoguardrails.logging.explain import LLMCallInfo from nemoguardrails.patch_asyncio import check_sync_call_from_async_loop @@ -442,6 +443,7 @@ async def generate_user_intent( result = self.llm_task_manager.parse_task_output( Task.GENERATE_USER_INTENT, output=result ) + result = result.text user_intent = get_first_nonempty_line(result) if user_intent is None: @@ -530,6 +532,11 @@ async def generate_user_intent( text = self.llm_task_manager.parse_task_output( Task.GENERAL, output=text ) + + text = _process_parsed_output( + text, self._include_reasoning_traces() + ) + else: # Initialize the LLMCallInfo object llm_call_info_var.set(LLMCallInfo(task=Task.GENERAL.value)) @@ -565,6 +572,8 @@ async def generate_user_intent( text = self.llm_task_manager.parse_task_output( Task.GENERAL, output=result ) + + text = _process_parsed_output(text, self._include_reasoning_traces()) text = text.strip() if text.startswith('"'): text = text[1:-1] @@ -646,6 +655,7 @@ async def generate_next_step( result = self.llm_task_manager.parse_task_output( Task.GENERATE_NEXT_STEPS, output=result ) + result = result.text # If we don't have multi-step generation enabled, we only look at the first line. if not self.config.enable_multi_step_generation: @@ -900,6 +910,10 @@ async def generate_bot_message( Task.GENERAL, output=result ) + result = _process_parsed_output( + result, self._include_reasoning_traces() + ) + log.info( "--- :: LLM Bot Message Generation passthrough call took %.2f seconds", time() - t0, @@ -963,6 +977,10 @@ async def generate_bot_message( Task.GENERATE_BOT_MESSAGE, output=result ) + result = _process_parsed_output( + result, self._include_reasoning_traces() + ) + # TODO: catch openai.error.InvalidRequestError from exceeding max token length result = get_multiline_response(result) @@ -1055,6 +1073,7 @@ async def generate_value( result = self.llm_task_manager.parse_task_output( Task.GENERATE_VALUE, output=result ) + result = result.text # We only use the first line for now # TODO: support multi-line values? @@ -1266,6 +1285,7 @@ async def generate_intent_steps_message( result = self.llm_task_manager.parse_task_output( Task.GENERATE_INTENT_STEPS_MESSAGE, output=result ) + result = result.text # TODO: Implement logic for generating more complex Colang next steps (multi-step), # not just a single bot intent. 
@@ -1348,6 +1368,7 @@ async def generate_intent_steps_message( result = self.llm_task_manager.parse_task_output( Task.GENERAL, output=result ) + result = _process_parsed_output(result, self._include_reasoning_traces()) text = result.strip() if text.startswith('"'): text = text[1:-1] @@ -1360,6 +1381,10 @@ async def generate_intent_steps_message( events=[new_event_dict("BotMessage", text=text)], ) + def _include_reasoning_traces(self) -> bool: + """Get the configuration value for whether to include reasoning traces in output.""" + return _get_apply_to_reasoning_traces(self.config) + def clean_utterance_content(utterance: str) -> str: """ @@ -1377,3 +1402,27 @@ def clean_utterance_content(utterance: str) -> str: # It should be translated to an actual \n character. utterance = utterance.replace("\\n", "\n") return utterance + + +def _record_reasoning_trace(trace: str) -> None: + """Store the reasoning trace in context for later retrieval.""" + reasoning_trace_var.set(trace) + + +def _assemble_response(text: str, trace: Optional[str], include_reasoning: bool) -> str: + """Combine trace and text if requested, otherwise just return text.""" + return (trace + text) if (trace and include_reasoning) else text + + +def _process_parsed_output( + output: ParsedTaskOutput, include_reasoning_trace: bool +) -> str: + """Record trace, then assemble the final LLM response.""" + if reasoning_trace := output.reasoning_trace: + _record_reasoning_trace(reasoning_trace) + return _assemble_response(output.text, reasoning_trace, include_reasoning_trace) + + +def _get_apply_to_reasoning_traces(config: RailsConfig) -> bool: + """Get the configuration value for whether to include reasoning traces in output.""" + return config.rails.output.apply_to_reasoning_traces diff --git a/nemoguardrails/actions/llm/utils.py b/nemoguardrails/actions/llm/utils.py index 012043c04..e58a1aba5 100644 --- a/nemoguardrails/actions/llm/utils.py +++ b/nemoguardrails/actions/llm/utils.py @@ -24,7 +24,7 @@ from nemoguardrails.colang.v2_x.lang.colang_ast import Flow from nemoguardrails.colang.v2_x.runtime.flows import InternalEvent, InternalEvents -from nemoguardrails.context import llm_call_info_var +from nemoguardrails.context import llm_call_info_var, reasoning_trace_var from nemoguardrails.logging.callbacks import logging_callbacks from nemoguardrails.logging.explain import LLMCallInfo @@ -192,7 +192,7 @@ def get_colang_history( and event["action_name"] == "retrieve_relevant_chunks" ): continue - history += f'execute {event["action_name"]}\n' + history += f"execute {event['action_name']}\n" elif event["type"] == "InternalSystemActionFinished" and not event.get( "is_system_action" ): @@ -577,3 +577,15 @@ def escape_flow_name(name: str) -> str: # removes non-word chars and leading digits in a word result = re.sub(r"\b\d+|[^\w\s]", "", result) return result + + +def get_and_clear_reasoning_trace_contextvar() -> Optional[str]: + """Get the current reasoning trace and clear it from the context. + + Returns: + Optional[str]: The reasoning trace if one exists, None otherwise. 
+ """ + if reasoning_trace := reasoning_trace_var.get(): + reasoning_trace_var.set(None) + return reasoning_trace + return None diff --git a/nemoguardrails/actions/v2_x/generation.py b/nemoguardrails/actions/v2_x/generation.py index 7dd5df8c4..e011379b8 100644 --- a/nemoguardrails/actions/v2_x/generation.py +++ b/nemoguardrails/actions/v2_x/generation.py @@ -197,7 +197,7 @@ async def _collect_user_intent_and_examples( # We add these in reverse order so the most relevant is towards the end. for result in reversed(results): - examples += f"user action: user said \"{result.text}\"\nuser intent: {result.meta['intent']}\n\n" + examples += f'user action: user said "{result.text}"\nuser intent: {result.meta["intent"]}\n\n' if result.meta["intent"] not in potential_user_intents: potential_user_intents.append(result.meta["intent"]) @@ -302,6 +302,8 @@ async def generate_user_intent( Task.GENERATE_USER_INTENT_FROM_USER_ACTION, output=result ) + result = result.text + user_intent = get_first_nonempty_line(result) # GTP-4o often adds 'user intent: ' in front if user_intent and ":" in user_intent: @@ -378,6 +380,8 @@ async def generate_user_intent_and_bot_action( Task.GENERATE_USER_INTENT_AND_BOT_ACTION_FROM_USER_ACTION, output=result ) + result = result.text + user_intent = get_first_nonempty_line(result) if user_intent and ":" in user_intent: @@ -458,6 +462,8 @@ async def passthrough_llm_action( text = self.llm_task_manager.parse_task_output(Task.GENERAL, output=text) + text = result.text + return text @action(name="CheckValidFlowExistsAction", is_system_action=True) @@ -541,6 +547,8 @@ async def generate_flow_from_instructions( task=Task.GENERATE_FLOW_FROM_INSTRUCTIONS, output=result ) + result = result.text + # TODO: why this is not part of a filter or output_parser? # lines = _remove_leading_empty_lines(result).split("\n") @@ -613,6 +621,8 @@ async def generate_flow_from_name( task=Task.GENERATE_FLOW_FROM_NAME, output=result ) + result = result.text + lines = _remove_leading_empty_lines(result).split("\n") if lines[0].startswith("flow"): @@ -680,6 +690,8 @@ async def generate_flow_continuation( task=Task.GENERATE_FLOW_CONTINUATION, output=result ) + result = result.text + lines = _remove_leading_empty_lines(result).split("\n") if len(lines) == 0 or (len(lines) == 1 and lines[0] == ""): @@ -806,6 +818,8 @@ async def generate_value( Task.GENERATE_VALUE_FROM_INSTRUCTION, output=result ) + result = result.text + # We only use the first line for now # TODO: support multi-line values? value = result.strip().split("\n")[0] @@ -913,6 +927,8 @@ async def generate_flow( Task.GENERATE_FLOW_CONTINUATION_FROM_NLD, output=result ) + result = result.text + result = _remove_leading_empty_lines(result) lines = result.split("\n") if "codeblock" in lines[0]: diff --git a/nemoguardrails/context.py b/nemoguardrails/context.py index 9f1d2fb6a..e66f1a0d5 100644 --- a/nemoguardrails/context.py +++ b/nemoguardrails/context.py @@ -14,6 +14,7 @@ # limitations under the License. import contextvars +from typing import Optional streaming_handler_var = contextvars.ContextVar("streaming_handler", default=None) @@ -32,3 +33,7 @@ # The raw LLM request that comes from the user. # This is used in passthrough mode. 
raw_llm_request = contextvars.ContextVar("raw_llm_request", default=None) + +reasoning_trace_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar( + "reasoning_trace", default=None +) diff --git a/nemoguardrails/library/content_safety/actions.py b/nemoguardrails/library/content_safety/actions.py index 7c1cbf406..90022cbc4 100644 --- a/nemoguardrails/library/content_safety/actions.py +++ b/nemoguardrails/library/content_safety/actions.py @@ -80,6 +80,7 @@ async def content_safety_check_input( result = await llm_call(llm, check_input_prompt, stop=stop) result = llm_task_manager.parse_task_output(task, output=result) + result = result.text try: is_safe, violated_policies = result @@ -162,6 +163,8 @@ async def content_safety_check_output( result = llm_task_manager.parse_task_output(task, output=result) + result = result.text + try: is_safe, violated_policies = result except TypeError: diff --git a/nemoguardrails/library/self_check/facts/actions.py b/nemoguardrails/library/self_check/facts/actions.py index 0f9c6d7ec..fb75ef72d 100644 --- a/nemoguardrails/library/self_check/facts/actions.py +++ b/nemoguardrails/library/self_check/facts/actions.py @@ -82,6 +82,7 @@ async def self_check_facts( task, output=response, forced_output_parser="is_content_safe" ) + result = result.text is_not_safe, _ = result result = float(not is_not_safe) diff --git a/nemoguardrails/library/self_check/input_check/actions.py b/nemoguardrails/library/self_check/input_check/actions.py index 1bf778129..8005f0724 100644 --- a/nemoguardrails/library/self_check/input_check/actions.py +++ b/nemoguardrails/library/self_check/input_check/actions.py @@ -83,6 +83,7 @@ async def self_check_input( task, output=response, forced_output_parser="is_content_safe" ) + result = result.text is_safe, _ = result if not is_safe: diff --git a/nemoguardrails/library/self_check/output_check/actions.py b/nemoguardrails/library/self_check/output_check/actions.py index c8ed9ca6c..8bbcdf42e 100644 --- a/nemoguardrails/library/self_check/output_check/actions.py +++ b/nemoguardrails/library/self_check/output_check/actions.py @@ -87,6 +87,7 @@ async def self_check_output( task, output=response, forced_output_parser="is_content_safe" ) + result = result.text is_safe, _ = result return is_safe diff --git a/nemoguardrails/llm/filters.py b/nemoguardrails/llm/filters.py index abf988c77..a0d80bb5d 100644 --- a/nemoguardrails/llm/filters.py +++ b/nemoguardrails/llm/filters.py @@ -15,7 +15,8 @@ import re import textwrap -from typing import List +from dataclasses import dataclass +from typing import List, Optional, Tuple from nemoguardrails.actions.llm.utils import ( get_colang_history, @@ -23,6 +24,16 @@ ) +@dataclass +class ReasoningExtractionResult: + """ + Holds cleaned response text and optional chain-of-thought reasoning trace extracted from LLM output. + """ + + text: str + reasoning_trace: Optional[str] = None + + def colang(events: List[dict]) -> str: """Filter that turns an array of events into a colang history.""" return get_colang_history(events) @@ -439,24 +450,98 @@ def conversation_to_events(conversation: List) -> List[dict]: return events -def remove_reasoning_traces(response: str, start_token: str, end_token: str) -> str: - """Removes the text between the first occurrence of the start token and the - last occurrence of the last token, if these tokens exist in the response. 
+def _find_token_positions_for_removal( + response: str, start_token: Optional[str], end_token: Optional[str] +) -> Tuple[int, int]: + """Helper function to find token positions specifically for text removal. + + This is useful, for example, to remove reasoning traces from a reasoning LLM response. + + This is optimized for the removal use case: + 1. Uses find() for first start token + 2. Uses rfind() for last end token + 3. Sets start_index to 0 if start token is missing - This utility function is useful to strip reasoning traces from reasoning LLMs - that encode the reasoning traces between specific tokens. + Args: + response(str): The text to search in + start_token(str): The token marking the start of text to remove + end_token(str): The token marking the end of text to remove + + Returns: + A tuple of (start_index, end_index) marking the span to remove; + both indices are -1 if start_token and end_token are not provided. """ - if start_token and end_token: - start_index = response.find(start_token) - # If the start index is missing, this is probably a continuation of a bot message - # started in the prompt. - if start_index == -1: - start_index = 0 - end_index = response.rfind(end_token) - if end_index == -1: - return response - - if start_index != -1 and end_index != -1 and start_index < end_index: - return response[:start_index] + response[end_index + len(end_token) :] - - return response + if not start_token or not end_token: + return -1, -1 + + start_index = response.find(start_token) + # if the start index is missing, this is probably a continuation of a bot message + # started in the prompt. + if start_index == -1: + start_index = 0 + + end_index = response.rfind(end_token) + + return start_index, end_index + + +def find_reasoning_tokens_position( + response: str, start_token: Optional[str], end_token: Optional[str] +) -> Tuple[int, int]: + """Finds the positions of the first start token and the last end token. + + This is intended to find the outermost boundaries of potential + reasoning sections, typically for removal. + + Args: + response(str): The text to search in. + start_token(Optional[str]): The token marking the start of reasoning. + end_token(Optional[str]): The token marking the end of reasoning. + + Returns: + A tuple (start_index, end_index). + - start_index: Position of the first `start_token`, or 0 if not found. + - end_index: Position of the last `end_token`, or -1 if not found. + """ + + return _find_token_positions_for_removal(response, start_token, end_token) + + +def extract_and_strip_trace( + response: str, start_token: str, end_token: str +) -> ReasoningExtractionResult: + """Extracts and removes reasoning traces from the given text. + + This function identifies reasoning traces in the text that are marked + by specific start and end tokens. It extracts these traces, removes + them from the original text, and returns both the cleaned text and + the extracted reasoning trace. + + Args: + response (str): The text to process. + start_token (str): The token marking the start of a reasoning trace. + end_token (str): The token marking the end of a reasoning trace. + + Returns: + ReasoningExtractionResult: An object containing the cleaned text + without reasoning traces and the extracted reasoning trace, if any. 
+ """ + + start_index, end_index = find_reasoning_tokens_position( + response, start_token, end_token + ) + # handles invalid/empty tokens returned as (-1, -1) + if start_index == -1 and end_index == -1: + return ReasoningExtractionResult(text=response, reasoning_trace=None) + # end token is missing + if end_index == -1: + return ReasoningExtractionResult(text=response, reasoning_trace=None) + # extrace if tokens are present and start < end + if start_index < end_index: + reasoning_trace = response[start_index : end_index + len(end_token)] + cleaned_text = response[:start_index] + response[end_index + len(end_token) :] + return ReasoningExtractionResult( + text=cleaned_text, reasoning_trace=reasoning_trace + ) + + return ReasoningExtractionResult(text=response, reasoning_trace=None) diff --git a/nemoguardrails/llm/taskmanager.py b/nemoguardrails/llm/taskmanager.py index a9242cded..bb3a1cde2 100644 --- a/nemoguardrails/llm/taskmanager.py +++ b/nemoguardrails/llm/taskmanager.py @@ -16,7 +16,8 @@ import logging import re from ast import literal_eval -from typing import Any, Callable, List, Optional, Union +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union from jinja2 import meta from jinja2.sandbox import SandboxedEnvironment @@ -25,10 +26,10 @@ co_v2, colang, colang_without_identifiers, + extract_and_strip_trace, first_turns, indent, last_turns, - remove_reasoning_traces, remove_text_messages, to_chat_messages, to_intent_messages, @@ -52,13 +53,62 @@ from nemoguardrails.rails.llm.config import MessageTemplate, RailsConfig +def output_has_reasoning_traces(output: str, start_token: str, end_token: str) -> bool: + """Checks if the output string contains both start and end reasoning tokens.""" + return start_token in output and end_token in output + + +@dataclass +class ParsedTaskOutput: + """ + Encapsulates the result of running and parsing an LLM task. + + Attributes: + text (str): The cleaned and parsed output string, representing + the main result of the task. + reasoning_trace (Optional[str]): An optional chain-of-thought + reasoning trace, providing insights into the reasoning + process behind the task output, if available. + """ + + text: str + reasoning_trace: Optional[str] = None + + +def should_remove_reasoning_traces_from_output(config, task): + model = get_task_model(config, task) + + model_config = ( + model + and model.reasoning_config + and model.reasoning_config.remove_thinking_traces + ) + + if config.rails.output.apply_to_reasoning_traces: + return False + else: + return model_config + + +def get_reasoning_token_tags(config, task): + model = get_task_model(config, task) + + if model and model.reasoning_config: + start_token = model.reasoning_config.start_token + end_token = model.reasoning_config.end_token + else: + start_token = None + end_token = None + + return start_token, end_token + + class LLMTaskManager: """Interface for interacting with an LLM in a task-oriented way.""" def __init__(self, config: RailsConfig): # Save the config as we need access to instructions and sample conversations. self.config = config - # Initialize the environment for rendering templates. 
self.env = SandboxedEnvironment() @@ -78,7 +128,7 @@ def __init__(self, config: RailsConfig): self.env.filters["to_chat_messages"] = to_chat_messages self.env.filters["verbose_v1"] = verbose_v1 - self.output_parsers = { + self.output_parsers: Dict[Optional[str], Callable] = { "user_intent": user_intent_parser, "bot_intent": bot_intent_parser, "bot_message": bot_message_parser, @@ -308,36 +358,50 @@ def render_task_prompt( def parse_task_output( self, task: Task, output: str, forced_output_parser: Optional[str] = None - ): - """Parses the output for the provided tasks. - - If an output parser is associated with the prompt, it will be used. - Otherwise, the output is returned as is. + ) -> ParsedTaskOutput: + """Parses the output of a task, optionally extracting reasoning traces. + + Args: + task (Task): The task for which the output is being parsed. + output (str): The output string to be parsed. + forced_output_parser (Optional[str]): An optional parser name to force + + Returns: + ParsedTaskOutput: An object containing the parsed text (which may + include or exclude reasoning traces based on configuration) and + any reasoning trace. """ - prompt = get_prompt(self.config, task) + reasoning_trace: Optional[str] = None - output_parser = None - if forced_output_parser: - output_parser = self.output_parsers.get(forced_output_parser) - elif prompt.output_parser: - output_parser = self.output_parsers.get(prompt.output_parser) - if not output_parser: - logging.info("No output parser found for %s", prompt.output_parser) + # Get the tokens first to check for their presence + start_token, end_token = get_reasoning_token_tags(self.config, task) - model = get_task_model(self.config, task) + # 1. strip and capture reasoning traces if configured and present if ( - model - and model.reasoning_config - and model.reasoning_config.remove_thinking_traces + start_token + and end_token + and output_has_reasoning_traces(output, start_token, end_token) ): - start_token = model.reasoning_config.start_token - end_token = model.reasoning_config.end_token - output = remove_reasoning_traces(output, start_token, end_token) + reasoning_trace_result = extract_and_strip_trace( + output, start_token, end_token + ) + reasoning_trace = reasoning_trace_result.reasoning_trace + + if should_remove_reasoning_traces_from_output(self.config, task): + output = reasoning_trace_result.text + + # 2. delegate to existing parser + prompt = get_prompt(self.config, task) + parser_name = forced_output_parser or prompt.output_parser + parser_fn = self.output_parsers.get(parser_name) - if output_parser: - return output_parser(output) + if parser_fn: + parsed_text = parser_fn(output) else: - return output + logging.info("No output parser found for %s", prompt.output_parser) + parsed_text = output + + return ParsedTaskOutput(text=parsed_text, reasoning_trace=reasoning_trace) def has_output_parser(self, task: Task): prompt = get_prompt(self.config, task) diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index 8d2502eb0..e776e6653 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -453,6 +453,15 @@ class OutputRails(BaseModel): description="Configuration for streaming output rails.", ) + apply_to_reasoning_traces: bool = Field( + default=False, + description=( + "If True, output rails will apply guardrails to both reasoning traces and output response. 
" + "If False, output rails will only apply guardrails to the output response excluding the reasoning traces, " + "thus keeping reasoning traces unaltered." + ), + ) + class RetrievalRails(BaseModel): """Configuration of retrieval rails.""" diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index 3b667fee1..11876bdba 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -31,7 +31,10 @@ from langchain_core.language_models.llms import BaseLLM from nemoguardrails.actions.llm.generation import LLMGenerationActions -from nemoguardrails.actions.llm.utils import get_colang_history +from nemoguardrails.actions.llm.utils import ( + get_and_clear_reasoning_trace_contextvar, + get_colang_history, +) from nemoguardrails.actions.output_mapping import is_output_blocked from nemoguardrails.actions.v2_x.generation import LLMGenerationActionsV2dotx from nemoguardrails.colang import parse_colang_file @@ -48,6 +51,7 @@ generation_options_var, llm_stats_var, raw_llm_request, + reasoning_trace_var, streaming_handler_var, ) from nemoguardrails.embeddings.index import EmbeddingsIndex @@ -838,6 +842,14 @@ async def generate_async( else: res = GenerationResponse(response=[new_message]) + if reasoning_trace := get_and_clear_reasoning_trace_contextvar(): + if prompt: + res.response = reasoning_trace + res.response + else: + res.response[0]["content"] = ( + reasoning_trace + res.response[0]["content"] + ) + if self.config.colang_version == "1.0": # If output variables are specified, we extract their values if options.output_vars: @@ -926,9 +938,14 @@ async def generate_async( input=messages, response=res, adapters=self._log_adapters ) await tracer.export_async() + return res else: # If a prompt is used, we only return the content of the message. + + if reasoning_trace := get_and_clear_reasoning_trace_contextvar(): + new_message["content"] = reasoning_trace + new_message["content"] + if prompt: return new_message["content"] else: diff --git a/tests/conftest.py b/tests/conftest.py index 12fb06e89..34d62ba8d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,6 +17,27 @@ import pytest +from nemoguardrails.context import reasoning_trace_var + def pytest_configure(config): patch("prompt_toolkit.PromptSession", autospec=True).start() + + +@pytest.fixture(autouse=True) +def reset_reasoning_trace(): + """Reset the reasoning_trace_var before each test. + + This fixture runs automatically for every test (autouse=True) to ensure + a clean state for the reasoning trace context variable. 
+ + current Issues with ContextVar approach, not only specific to this case: + global State: ContextVar creates global state that's hard to track and manage + implicit Flow: The reasoning trace flows through the system in a non-obvious way + testing Complexity: It causes test isolation problems that we are trying to avoid using this fixture + """ + # reset the variable before the test + reasoning_trace_var.set(None) + yield + # reset the variable after the test as well (in case the test fails) + reasoning_trace_var.set(None) diff --git a/tests/test_configs/with_custom_llm_prompt_action_v2_x/actions.py b/tests/test_configs/with_custom_llm_prompt_action_v2_x/actions.py index 1ed061a4b..4ba550383 100644 --- a/tests/test_configs/with_custom_llm_prompt_action_v2_x/actions.py +++ b/tests/test_configs/with_custom_llm_prompt_action_v2_x/actions.py @@ -50,6 +50,7 @@ async def custom_llm_request( result = await llm_call(llm, prompt, stop=stop) result = llm_task_manager.parse_task_output(prompt_template_name, output=result) + result = result.text # Any additional parsing of the output value = result.strip().split("\n")[0] diff --git a/tests/test_filters.py b/tests/test_filters.py index ae8bba489..f97b288d2 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -14,13 +14,16 @@ # limitations under the License. import textwrap +from typing import List, Tuple, Union import pytest from nemoguardrails.llm.filters import ( + ReasoningExtractionResult, + extract_and_strip_trace, + find_reasoning_tokens_position, first_turns, last_turns, - remove_reasoning_traces, to_chat_messages, user_assistant_sequence, ) @@ -92,51 +95,228 @@ def test_last_turns(): assert last_turns(colang_history, 2) == colang_history +def _build_test_string(parts: List[Union[str, Tuple[str, int]]]) -> str: + """Builds a test string from a list of parts. + + Each part can be a literal string or a (character, count) tuple. 
+ Example: [("a", 3), "[START]", ("b", 5)] -> "aaa[START]bbbbb" + """ + result = [] + for part in parts: + if isinstance(part, str): + result.append(part) + elif isinstance(part, tuple) and len(part) == 2: + char, count = part + result.append(char * count) + else: + raise TypeError(f"Invalid part type in _build_test_string: {part}") + return "".join(result) + + @pytest.mark.parametrize( "response, start_token, end_token, expected", [ ( - "This is an example [START]hidden reasoning[END] of a response.", + _build_test_string( + [ + ("a", 5), + "[START]", + ("b", 10), + "[END]", + ("c", 5), + ] + ), "[START]", "[END]", - "This is an example of a response.", + (5, 22), # 5 a's + 7 START + 10 b = 22 + ), + # multiple reasoning sections + ( + _build_test_string( + [ + ("a", 3), + "[START]", + ("b", 4), + "[END]", + ("c", 3), + "[START]", + ("d", 4), + "[END]", + ("e", 3), + ] + ), + "[START]", + "[END]", + ( + 3, + 33, + ), + ), + ( + _build_test_string( + [ + ("a", 2), + "[START]", + ("b", 2), + "[START]", + ("c", 2), + "[END]", + ("d", 2), + "[END]", + ("e", 2), + ] + ), + "[START]", + "[END]", + ( + 2, + 27, + ), + ), + ( + _build_test_string([("a", 10)]), + "[START]", + "[END]", + (0, -1), # no tokens found, start_index is 0 ), ( - "This is an example without an end token.", + _build_test_string( + [ + ("a", 5), + "[START]", + ("b", 5), + ] + ), "[START]", "[END]", - "This is an example without an end token.", + (5, -1), # [START] at pos 5, no end token ), ( - "This is an example [START] with a start token but no end token.", + _build_test_string( + [ + ("a", 5), + "[END]", + ("b", 5), + ] + ), + "[START]", + "[END]", + (0, 5), # no start token so 0, end at pos 5 + ), + ( + "", + "[START]", + "[END]", + (0, -1), # empty string, start_index is 0 + ), + ], +) +def test_find_token_positions_for_removal(response, start_token, end_token, expected): + """Test finding token positions for removal. + + Test cases use _build_test_string for clarity and mathematical obviousness. 
+ """ + assert find_reasoning_tokens_position(response, start_token, end_token) == expected + + +@pytest.mark.parametrize( + "response, start_token, end_token, expected_text, expected_trace", + [ + ( + "This is an example [START]hidden reasoning[END] of a response.", "[START]", "[END]", - "This is an example [START] with a start token but no end token.", + "This is an example of a response.", + "[START]hidden reasoning[END]", ), ( - "Before [START]hidden[END] middle [START]extra hidden[END] after.", + "Before [START]first[END] middle [START]second[END] after.", "[START]", "[END]", "Before after.", + "[START]first[END] middle [START]second[END]", ), ( "Text [START] first [START] nested [END] second [END] more text.", "[START]", "[END]", "Text more text.", + "[START] first [START] nested [END] second [END]", + ), + ( + "No tokens here", + "[START]", + "[END]", + "No tokens here", + None, + ), + ( + "Only [START] start token", + "[START]", + "[END]", + "Only [START] start token", + None, + ), + ( + "Only end token [END]", + "[START]", + "[END]", + "", + "Only end token [END]", + ), + ( + "", + "[START]", + "[END]", + "", + None, + ), + # End token before start token (tests the final return path) + ( + "some [END] text [START]", + "[START]", + "[END]", + "some [END] text [START]", + None, + ), + # Original test cases adapted + ( + "[END] Out of order [START] tokens [END] example.", + "[START]", + "[END]", + "[END] Out of order example.", + "[START] tokens [END]", + ), + ( + "[START] nested [START] tokens [END] out of [END] order.", + "[START]", + "[END]", + " order.", + "[START] nested [START] tokens [END] out of [END]", + ), + ( + "[END] [START] [START] example [END] text.", + "[START]", + "[END]", + "[END] text.", + "[START] [START] example [END]", ), ( - "[START]Remove this[END] but keep this.", + "example text.", "[START]", "[END]", - " but keep this.", + "example text.", + None, ), - ("", "[START]", "[END]", ""), ], ) -def test_remove_reasoning_traces(response, start_token, end_token, expected): - """Test removal of text between start and end tokens with multiple cases.""" - assert remove_reasoning_traces(response, start_token, end_token) == expected +def test_extract_and_strip_trace( + response, start_token, end_token, expected_text, expected_trace +): + """Tests the extraction and stripping of reasoning traces.""" + result = extract_and_strip_trace(response, start_token, end_token) + assert result.text == expected_text + assert result.reasoning_trace == expected_trace class TestToChatMessages: diff --git a/tests/test_llmrails_reasoning.py b/tests/test_llmrails_reasoning.py index a5575f423..a016ed77c 100644 --- a/tests/test_llmrails_reasoning.py +++ b/tests/test_llmrails_reasoning.py @@ -69,10 +69,10 @@ def rails_config(): async def test_1(rails_config): llm = FakeLLM( responses=[ - "some text\n express greeting", - "some text\n ask math question", - 'some text\n "The answer is 5"', - 'some text\n "Are you happy with the result?"', + "some redundant CoT text 1\n express greeting", + "some redundant CoT text 2\n ask math question", + 'some redundant CoT text 3\n "The answer is 5"', + 'some important COT text\n "Are you happy with the result?"', ] ) @@ -92,5 +92,5 @@ async def compute(what: Optional[str] = "2 + 3"): bot_message = await llm_rails.generate_async(messages=messages) assert bot_message == { "role": "assistant", - "content": "The answer is 5\nAre you happy with the result?", + "content": "some important COT textThe answer is 5\nAre you happy with the result?", } diff --git 
a/tests/test_llmrails_reasoning_output_rails.py b/tests/test_llmrails_reasoning_output_rails.py new file mode 100644 index 000000000..1b6f7ae13 --- /dev/null +++ b/tests/test_llmrails_reasoning_output_rails.py @@ -0,0 +1,312 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for LLM Rails reasoning output configuration and behavior. + +This module contains tests that verify the behavior of LLM Rails when handling +reasoning traces in the output, including configuration options and guardrail +behavior. +""" + +from typing import Any, Dict, NamedTuple + +import pytest + +from nemoguardrails import RailsConfig +from tests.utils import TestChat + + +class ReasoningTraceTestCase(NamedTuple): + """Test case for reasoning trace configuration. + + Attributes: + description: description of the test case + remove_thinking_traces: Whether to remove thinking traces in the model config + apply_to_reasoning_traces: Whether to apply output rails to reasoning traces + expected_think_tag: Whether the think tag should be present in the response + expected_error_message: Whether the error message should be present in the response + """ + + description: str + remove_thinking_traces: bool + apply_to_reasoning_traces: bool + expected_think_tag: bool + expected_error_message: bool + + +async def check_sensitive_info(context: Dict[str, Any]) -> bool: + """Check if the response contains sensitive information.""" + response = context.get("bot_message", "") + prompt = context.get("user_message", "") + input_text = response or prompt + return "credit card" in input_text.lower() or any( + c.isdigit() for c in input_text if c.isdigit() or c == "-" + ) + + +async def check_think_tag_present(context: Dict[str, Any]) -> bool: + """Check if the think tag is present in the bot's response.""" + response = context.get("bot_message", "") + return "" in response + + +@pytest.fixture +def base_config() -> RailsConfig: + """Creates a base RailsConfig with common test configuration.""" + return RailsConfig.from_content( + colang_content=""" + define flow check think tag + $not_allowed = execute check_think_tag_present + if $not_allowed + bot informs tag not allowed + stop + + define bot informs tag not allowed + "think tag is not allowed it must be removed" + """, + yaml_content=""" + models: + - type: main + engine: fake + model: fake + colang_version: "1.0" + rails: + output: + flows: + - check think tag + """, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "test_case", + [ + ReasoningTraceTestCase( + description="Remove thinking traces and show error when guardrail is enabled", + remove_thinking_traces=True, + apply_to_reasoning_traces=True, + expected_think_tag=True, + expected_error_message=True, + ), + ReasoningTraceTestCase( + description="Preserve thinking traces and hide error when guardrail is disabled", + remove_thinking_traces=True, + 
apply_to_reasoning_traces=False, + expected_think_tag=True, + expected_error_message=False, + ), + ReasoningTraceTestCase( + description="Preserve thinking traces and show error when guardrail is enabled", + remove_thinking_traces=False, + apply_to_reasoning_traces=True, + expected_think_tag=True, + expected_error_message=True, + ), + ReasoningTraceTestCase( + description="Remove thinking traces and show error when both flags are disabled", + remove_thinking_traces=False, + apply_to_reasoning_traces=False, + expected_think_tag=True, + expected_error_message=True, + ), + ], + ids=lambda tc: tc.description, +) +async def test_output_rails_reasoning_traces_configuration( + base_config: RailsConfig, + test_case: ReasoningTraceTestCase, +) -> None: + """Test output rails with different reasoning traces configurations. + + The test verifies the following behaviors based on configuration: + + 1. When remove_thinking_traces=True: + - The model is configured to remove thinking traces + - However, the actual removal depends on apply_to_reasoning_traces + + 2. When apply_to_reasoning_traces=True: + - The output rail will check for and report think tags + - Because we expect the think tag to be present as output rails explicitly requires it + + 3. When apply_to_reasoning_traces=False: + - The output rails will check for think tags + - No error message will be shown because it is not there to get blocked + + """ + base_config.models[ + 0 + ].reasoning_config.remove_thinking_traces = test_case.remove_thinking_traces + base_config.rails.output.apply_to_reasoning_traces = ( + test_case.apply_to_reasoning_traces + ) + + chat = TestChat( + base_config, + llm_completions=[ + " I should think more Your kindness is appreciated" + ], + ) + + chat.app.runtime.register_action(check_think_tag_present) + + messages = [{"role": "user", "content": "you are nice"}] + response = await chat.app.generate_async(messages=messages) + + if test_case.expected_think_tag: + assert ( + "" in response["content"] + ), "Think tag should be present in response" + else: + assert ( + "" not in response["content"] + ), "Think tag should not be present in response" + + if test_case.expected_error_message: + assert ( + "think tag is not allowed" in response["content"] + ), "Error message should be present" + else: + assert ( + "think tag is not allowed" not in response["content"] + ), "Error message should not be present" + + +@pytest.mark.asyncio +async def test_output_rails_preserves_reasoning_traces() -> None: + """Test that output rails preserve reasoning traces when configured to do so.""" + config = RailsConfig.from_content( + colang_content=""" + define flow check sensitive info + $not_allowed = execute check_sensitive_info + if $not_allowed + bot provide sanitized response + stop + define bot provide sanitized response + "I cannot share sensitive information." 
+ """, + yaml_content=""" + models: + - type: main + engine: fake + model: fake + reasoning_config: + remove_thinking_traces: True + colang_version: "1.0" + rails: + output: + flows: + - check sensitive info + apply_to_reasoning_traces: True + """, + ) + + chat = TestChat( + config, + llm_completions=[ + ' I should not share sensitive info \n "Here is my credit card: 1234-5678-9012-3456"', + ], + ) + + chat.app.runtime.register_action(check_sensitive_info) + + messages = [{"role": "user", "content": "What's your credit card number?"}] + response = await chat.app.generate_async(messages=messages) + + assert "" in response["content"], "Reasoning traces should be preserved" + assert ( + "I should not share sensitive info" in response["content"] + ), "Reasoning content should be preserved" + assert ( + "credit card" not in response["content"].lower() + ), "Sensitive information should be removed" + + +@pytest.mark.asyncio +async def test_output_rails_without_reasoning_traces() -> None: + """Test that output rails properly handle responses when reasoning traces are disabled.""" + config = RailsConfig.from_content( + colang_content=""" + define flow check sensitive info + $not_allowed = execute check_sensitive_info + if $not_allowed + bot provide sanitized response + stop + define flow check think tag + $not_allowed = execute check_think_tag_present + if $not_allowed + bot says tag not allowed + stop + + define bot says tag not allowed + " tag is not allowed it must be removed" + + define bot provide sanitized response + "I cannot share sensitive information." + """, + yaml_content=""" + models: + - type: main + engine: fake + model: fake + reasoning_config: + remove_thinking_traces: True + colang_version: "1.0" + rails: + input: + flows: + - check sensitive info + output: + flows: + - check sensitive info + - check think tag + apply_to_reasoning_traces: false + """, + ) + + chat = TestChat( + config, + llm_completions=[ + " I should think more Your credit card number is 1234-5678-9012-3456", + ], + ) + + chat.app.runtime.register_action(check_sensitive_info) + chat.app.runtime.register_action(check_think_tag_present) + + # case 1: Sensitive information is blocked by input rail + messages = [{"role": "user", "content": "What's your credit card number?"}] + response = await chat.app.generate_async(messages=messages) + + assert "" not in response["content"], "Think tag should not be present" + assert ( + "I should not share sensitive info" not in response["content"] + ), "Reasoning content should not be present" + assert ( + response["content"] == "I cannot share sensitive information." + ), "Should return sanitized response" + + # case 2: Think tag is preserved but content is sanitized + messages = [{"role": "user", "content": "Tell me some numbers"}] + response = await chat.app.generate_async(messages=messages) + + assert "" in response["content"], "Think tag should be present" + assert ( + "I should not share sensitive info" not in response["content"] + ), "Reasoning content should not be present" + assert ( + response["content"] + == " I should think more I cannot share sensitive information." + ), "Should preserve think tag but sanitize content" diff --git a/tests/test_reasoning_trace_context.py b/tests/test_reasoning_trace_context.py new file mode 100644 index 000000000..d1c0c6db3 --- /dev/null +++ b/tests/test_reasoning_trace_context.py @@ -0,0 +1,227 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from nemoguardrails import RailsConfig +from nemoguardrails.actions.llm.utils import get_and_clear_reasoning_trace_contextvar +from nemoguardrails.context import reasoning_trace_var +from nemoguardrails.rails.llm.llmrails import GenerationOptions, GenerationResponse +from tests.utils import TestChat + + +def test_get_and_clear_reasoning_trace_contextvar(): + """Test that it correctly gets and clears the trace.""" + reasoning_trace_var.set(" oh COT again ") + + result = get_and_clear_reasoning_trace_contextvar() + + assert result == " oh COT again " + assert reasoning_trace_var.get() is None + + +def test_get_and_clear_reasoning_trace_contextvar_empty(): + """Test that it returns None when no trace exists.""" + reasoning_trace_var.set(None) + + result = get_and_clear_reasoning_trace_contextvar() + + assert result is None + + +@pytest.mark.asyncio +async def test_generate_async_trace_with_messages_and_options(): + """Test generate_async prepends reasoning trace when using generation options and messages.""" + config = RailsConfig.from_content( + colang_content=""" + define user express greeting + "hi" + "hello" + + define bot express greeting + "Hello! How can I assist you today?" + + define flow + user express greeting + bot express greeting + """, + yaml_content=""" + models: [] + rails: + output: + apply_to_reasoning_traces: true + """, + ) + + chat = TestChat( + config, + llm_completions=[ + "user express greeting", + "bot express greeting", + "Hello! How can I assist you today?", + ], + ) + + reasoning_trace_var.set(" yet another COT ") + + options = GenerationOptions() + result = await chat.app.generate_async( + messages=[{"role": "user", "content": "hi"}], options=options + ) + + assert isinstance(result, GenerationResponse) + assert isinstance(result.response, list) + assert len(result.response) == 1 + assert ( + result.response[0]["content"] + == " yet another COT Hello! How can I assist you today?" + ) + assert reasoning_trace_var.get() is None + + +@pytest.mark.asyncio +async def test_generate_async_trace_with_prompt_and_options(): + """Test generate_async prepends reasoning trace using prompt and options""" + config = RailsConfig.from_content( + colang_content=""" + define user express greeting + "hi" + "hello" + + define bot express greeting + "Hello! How can I assist you today?" + + define flow + user express greeting + bot express greeting + """, + yaml_content=""" + models: [] + rails: + output: + apply_to_reasoning_traces: true + """, + ) + + chat = TestChat( + config, + llm_completions=[ + "user express greeting", + "bot express greeting", + "Hello! 
How can I assist you today?", + ], + ) + + reasoning_trace_var.set(" yet another COT ") + + options = GenerationOptions() + result = await chat.app.generate_async(options=options, prompt="test prompt") + + assert isinstance(result, GenerationResponse) + assert isinstance(result.response, str) + assert ( + result.response + == " yet another COT Hello! How can I assist you today?" + ) + assert reasoning_trace_var.get() is None + + +@pytest.mark.asyncio +async def test_generate_async_trace_messages_only(): + """Test generate_async prepends reasoning trace when using only messages.""" + config = RailsConfig.from_content( + colang_content=""" + define user express greeting + "hi" + "hello" + + define bot express greeting + "Hello! How can I assist you today?" + + define flow + user express greeting + bot express greeting + """, + yaml_content=""" + models: [] + rails: + output: + apply_to_reasoning_traces: true + """, + ) + + chat = TestChat( + config, + llm_completions=[ + "user express greeting", + "bot express greeting", + "Hello! How can I assist you today?", + ], + ) + + reasoning_trace_var.set(" yet another COT ") + + result = await chat.app.generate_async(messages=[{"role": "user", "content": "hi"}]) + + assert isinstance(result, dict) + assert result.get("role") == "assistant" + assert ( + result.get("content") + == " yet another COT Hello! How can I assist you today?" + ) + assert reasoning_trace_var.get() is None + + +@pytest.mark.asyncio +async def test_generate_async_trace_with_prompt_only(): + """Test generate_async prepends reasoning trace when using prompt.""" + config = RailsConfig.from_content( + colang_content=""" + define user express greeting + "hi" + "hello" + + define bot express greeting + "Hello! How can I assist you today?" + + define flow + user express greeting + bot express greeting + """, + yaml_content=""" + models: [] + rails: + output: + apply_to_reasoning_traces: true + """, + ) + + chat = TestChat( + config, + llm_completions=[ + "user express greeting", + "bot express greeting", + "Hello! How can I assist you today?", + ], + ) + + reasoning_trace_var.set(" yet another COT ") + + result = await chat.app.generate_async(prompt="hi") + + assert ( + result == " yet another COT Hello! How can I assist you today?" 
+ ) + assert reasoning_trace_var.get() is None diff --git a/tests/test_reasoning_traces.py b/tests/test_reasoning_traces.py index ce4edec1b..ad1d78e5a 100644 --- a/tests/test_reasoning_traces.py +++ b/tests/test_reasoning_traces.py @@ -17,20 +17,32 @@ import pytest -from nemoguardrails.actions.llm.generation import LLMGenerationActions +from nemoguardrails.actions.llm.generation import ( + LLMGenerationActions, + _get_apply_to_reasoning_traces, + _process_parsed_output, +) from nemoguardrails.actions.v2_x.generation import LLMGenerationActionsV2dotx from nemoguardrails.context import ( generation_options_var, llm_call_info_var, streaming_handler_var, ) -from nemoguardrails.llm.filters import remove_reasoning_traces -from nemoguardrails.llm.taskmanager import LLMTaskManager +from nemoguardrails.llm.filters import extract_and_strip_trace +from nemoguardrails.llm.taskmanager import LLMTaskManager, ParsedTaskOutput from nemoguardrails.llm.types import Task from nemoguardrails.logging.explain import LLMCallInfo from nemoguardrails.rails.llm.config import Model, RailsConfig, ReasoningModelConfig +def create_mock_config(): + config = MagicMock(spec=RailsConfig) + config.rails = MagicMock() + config.rails.output = MagicMock() + config.rails.output.apply_to_reasoning_traces = False + return config + + class TestReasoningTraces: """Test the reasoning traces functionality.""" @@ -38,8 +50,8 @@ def test_remove_reasoning_traces_basic(self): """Test basic removal of reasoning traces.""" input_text = "This is a \nSome reasoning here\nMore reasoning\n response." expected = "This is a response." - result = remove_reasoning_traces(input_text, "", "") - assert result == expected + result = extract_and_strip_trace(input_text, "", "") + assert result.text == expected def test_remove_reasoning_traces_multiline(self): """Test removal of multiline reasoning traces.""" @@ -52,8 +64,8 @@ def test_remove_reasoning_traces_multiline(self): response after thinking. """ expected = "\n Here is my response after thinking.\n " - result = remove_reasoning_traces(input_text, "", "") - assert result == expected + result = extract_and_strip_trace(input_text, "", "") + assert result.text == expected def test_remove_reasoning_traces_multiple_sections(self): """Test removal of multiple reasoning trace sections.""" @@ -61,8 +73,8 @@ def test_remove_reasoning_traces_multiple_sections(self): # Note: The current implementation removes all content between the first start and last end token # So the expected result is "Start end." not "Start middle end." expected = "Start end." - result = remove_reasoning_traces(input_text, "", "") - assert result == expected + result = extract_and_strip_trace(input_text, "", "") + assert result.text == expected def test_remove_reasoning_traces_nested(self): """Test handling of nested reasoning trace markers (should be handled correctly).""" @@ -70,22 +82,21 @@ def test_remove_reasoning_traces_nested(self): "Begin Outer Inner Outer End." ) expected = "Begin End." - result = remove_reasoning_traces(input_text, "", "") - assert result == expected + result = extract_and_strip_trace(input_text, "", "") + assert result.text == expected def test_remove_reasoning_traces_unmatched(self): """Test handling of unmatched reasoning trace markers.""" input_text = "Begin Unmatched end." - result = remove_reasoning_traces(input_text, "", "") + result = extract_and_strip_trace(input_text, "", "") # We ~hould keep the unmatched tag since it's not a complete section - assert result == "Begin Unmatched end." 
+ assert result.text == "Begin Unmatched end." @pytest.mark.asyncio async def test_task_manager_parse_task_output(self): """Test that the task manager correctly removes reasoning traces.""" # mock config - config = MagicMock(spec=RailsConfig) - + config = create_mock_config() # Create a ReasoningModelConfig reasoning_config = ReasoningModelConfig( remove_thinking_traces=True, @@ -121,12 +132,13 @@ async def test_task_manager_parse_task_output(self): expected = "This is a final answer." result = llm_task_manager.parse_task_output(Task.GENERAL, input_text) - assert result == expected + assert result.text == expected @pytest.mark.asyncio async def test_parse_task_output_without_reasoning_config(self): """Test that parse_task_output works without a reasoning config.""" - config = MagicMock(spec=RailsConfig) + + config = create_mock_config() # a Model without reasoning_config model_config = Model(type="main", engine="test", model="test-model") @@ -147,18 +159,22 @@ async def test_parse_task_output_without_reasoning_config(self): input_text = ( "This is a Some reasoning here final answer." ) - - # Without a reasoning config, the text should remain unchanged result = llm_task_manager.parse_task_output(Task.GENERAL, input_text) - assert result == input_text + assert result.text == input_text @pytest.mark.asyncio async def test_parse_task_output_with_default_reasoning_traces(self): - """Test that parse_task_output works without a reasoning config.""" - config = MagicMock(spec=RailsConfig) + """Test that parse_task_output works with default reasoning traces.""" - # a Model without reasoning_config - model_config = Model(type="main", engine="test", model="test-model") + config = create_mock_config() + + # Create a Model with default reasoning_config + model_config = Model( + type="main", + engine="test", + model="test-model", + reasoning_config=ReasoningModelConfig(), + ) # Mock the get_prompt and get_task_model functions with ( @@ -172,42 +188,51 @@ async def test_parse_task_output_with_default_reasoning_traces(self): llm_task_manager = LLMTaskManager(config) - # test parsing without a reasoning config + # test parsing with default reasoning traces input_text = "This is a Some reasoning here final answer." - expected = "This is a final answer." - - # without a reasoning config, the default start_token and stop_token are used thus the text should change result = llm_task_manager.parse_task_output(Task.GENERAL, input_text) - assert result == expected + assert result.text == "This is a final answer." 
@pytest.mark.asyncio async def test_parse_task_output_with_output_parser(self): - """Test that parse_task_output correctly applies output parsers before returning.""" - config = MagicMock(spec=RailsConfig) + """Test that parse_task_output works with an output parser.""" - # mock output parser function - def mock_parser(text): - return text.upper() + config = create_mock_config() - llm_task_manager = LLMTaskManager(config) - llm_task_manager.output_parsers["test_parser"] = mock_parser + # Create a Model with reasoning_config + model_config = Model( + type="main", + engine="test", + model="test-model", + reasoning_config=ReasoningModelConfig( + remove_thinking_traces=True, + start_token="", + end_token="", + ), + ) - # mock the get_prompt and get_task_model functions + def mock_parser(text): + return f"PARSED: {text}" + + # Mock the get_prompt and get_task_model functions with ( patch("nemoguardrails.llm.taskmanager.get_prompt") as mock_get_prompt, patch( "nemoguardrails.llm.taskmanager.get_task_model" ) as mock_get_task_model, ): - mock_get_prompt.return_value = MagicMock(output_parser="test_parser") - mock_get_task_model.return_value = None + mock_get_prompt.return_value = MagicMock(output_parser="mock_parser") + mock_get_task_model.return_value = model_config - # Test with output parser - input_text = "this should be uppercase" - expected = "THIS SHOULD BE UPPERCASE" + llm_task_manager = LLMTaskManager(config) + llm_task_manager.output_parsers["mock_parser"] = mock_parser + # test parsing with an output parser + input_text = ( + "This is a Some reasoning here final answer." + ) result = llm_task_manager.parse_task_output(Task.GENERAL, input_text) - assert result == expected + assert result.text == "PARSED: This is a final answer." @pytest.mark.asyncio async def test_passthrough_llm_action_removes_reasoning(self): @@ -344,3 +369,47 @@ def __init__(self, events): ) assert mock_result.events[0]["text"] == "This is a final answer." 
+ + +class TestProcessParsedOutput: + """Test the _process_parsed_output function.""" + + def test_process_parsed_output_with_reasoning_trace(self): + """Test processing output with reasoning trace when guardrail is enabled.""" + result = ParsedTaskOutput( + text="final answer", + reasoning_trace="some reasoning", + ) + output = _process_parsed_output(result, include_reasoning_trace=True) + assert output == "some reasoningfinal answer" + + def test_process_parsed_output_with_reasoning_trace_disabled(self): + """Test processing output with reasoning trace when guardrail is disabled.""" + result = ParsedTaskOutput( + text="final answer", + reasoning_trace="some reasoning", + ) + output = _process_parsed_output(result, include_reasoning_trace=False) + assert output == "final answer" + + def test_process_parsed_output_without_reasoning_trace(self): + """Test processing output without reasoning trace.""" + result = ParsedTaskOutput(text="final answer", reasoning_trace=None) + output = _process_parsed_output(result, include_reasoning_trace=True) + assert output == "final answer" + + +class TestGuardrailReasoningTraces: + """Test the guardrail reasoning traces configuration.""" + + def test_get_apply_to_reasoning_traces_enabled(self): + """Test getting guardrail reasoning traces when enabled.""" + config = create_mock_config() + config.rails.output.apply_to_reasoning_traces = True + assert _get_apply_to_reasoning_traces(config) is True + + def test_get_apply_to_reasoning_traces_disabled(self): + """Test getting guardrail reasoning traces when disabled.""" + config = create_mock_config() + config.rails.output.apply_to_reasoning_traces = False + assert _get_apply_to_reasoning_traces(config) is False
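The behavioral core of the patch is the new `extract_and_strip_trace` filter: unlike the old `remove_reasoning_traces`, it returns the captured trace alongside the cleaned text, and `parse_task_output` now wraps both in a `ParsedTaskOutput`, which is why every caller above gains a `.text` unwrap. A minimal sketch of the filter in isolation; the example strings and the `<think>`/`</think>` tokens here are invented for illustration:

```python
from nemoguardrails.llm.filters import extract_and_strip_trace

raw = "<think>the user greeted me, greet back</think>Hello there!"
parsed = extract_and_strip_trace(raw, start_token="<think>", end_token="</think>")

# The cleaned text and the captured trace come back separately, so callers
# can decide what to forward, what to log, and what to re-attach later.
assert parsed.text == "Hello there!"
assert parsed.reasoning_trace == "<think>the user greeted me, greet back</think>"

# If the end token never appears, nothing is treated as a trace and the
# response is returned untouched.
unfinished = extract_and_strip_trace("<think>still thinking", "<think>", "</think>")
assert unfinished.text == "<think>still thinking"
assert unfinished.reasoning_trace is None
```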
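End to end, the behavior is driven by two configuration pieces: the per-model `reasoning_config` (which tokens delimit the trace and whether to strip it) and the new `rails.output.apply_to_reasoning_traces` flag introduced in `OutputRails`. A minimal config sketch, assuming placeholder engine and model names and leaving the actual output flows to the deployment:

```python
from nemoguardrails import RailsConfig

config = RailsConfig.from_content(
    yaml_content="""
    models:
      - type: main
        engine: fake            # placeholder: use your actual engine
        model: fake             # placeholder: use your actual reasoning model
        reasoning_config:
          remove_thinking_traces: true
          start_token: "<think>"
          end_token: "</think>"
    rails:
      output:
        # output rail flows for the deployment go here, e.g. a custom
        # "check think tag" flow as in the tests above
        apply_to_reasoning_traces: true
    """
)

# Per the OutputRails field description: with the flag set to true, the output
# rails are applied to the reasoning trace together with the response, and the
# trace is kept rather than stripped from the bot message.
```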