20 changes: 6 additions & 14 deletions nemoguardrails/rails/llm/llmrails.py
@@ -1084,7 +1084,7 @@ async def generate_async(

tool_calls = extract_tool_calls_from_events(new_events)
llm_metadata = get_and_clear_response_metadata_contextvar()

reasoning_content = extract_bot_thinking_from_events(new_events)
# If we have generation options, we prepare a GenerationResponse instance.
if gen_options:
# If a prompt was used, we only need to return the content of the message.
@@ -1093,17 +1093,8 @@
else:
res = GenerationResponse(response=[new_message])

if reasoning_trace := extract_bot_thinking_from_events(events):
if prompt:
# For prompt mode, response should be a string
if isinstance(res.response, str):
res.response = reasoning_trace + res.response
else:
# For message mode, response should be a list
if isinstance(res.response, list) and len(res.response) > 0:
res.response[0]["content"] = (
reasoning_trace + res.response[0]["content"]
)
if reasoning_content:
res.reasoning_content = reasoning_content

if tool_calls:
res.tool_calls = tool_calls
@@ -1238,8 +1229,9 @@ async def generate_async(
else:
# If a prompt is used, we only return the content of the message.

if reasoning_trace := extract_bot_thinking_from_events(events):
new_message["content"] = reasoning_trace + new_message["content"]
if reasoning_content:
thinking_trace = f"<think>{reasoning_content}</think>\n"
new_message["content"] = thinking_trace + new_message["content"]

if prompt:
return new_message["content"]
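With this change, generate_async stops prepending the reasoning trace to the response content when generation options are used and surfaces it on the new reasoning_content field instead. A minimal consumption sketch, assuming a configured guardrails directory (the ./config path is hypothetical; the API calls mirror the tests in this PR):

    import asyncio

    from nemoguardrails import LLMRails, RailsConfig
    from nemoguardrails.rails.llm.options import GenerationOptions

    async def main():
        # Hypothetical config location; any valid rails config works here.
        config = RailsConfig.from_path("./config")
        rails = LLMRails(config)

        result = await rails.generate_async(
            messages=[{"role": "user", "content": "What is the answer?"}],
            options=GenerationOptions(),
        )

        # The trace now lives on its own field instead of the message content.
        if result.reasoning_content:
            print("Reasoning:", result.reasoning_content)
        print("Answer:", result.response[0]["content"])

    asyncio.run(main())

Without options, the dict-style return path keeps the trace inline, wrapped in <think> tags, as the hunk above shows.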
4 changes: 4 additions & 0 deletions nemoguardrails/rails/llm/options.py
@@ -428,6 +428,10 @@ class GenerationResponse(BaseModel):
default=None,
description="Tool calls extracted from the LLM response, if any.",
)
reasoning_content: Optional[str] = Field(
default=None,
description="The reasoning content extracted from the LLM response, if any.",
)
llm_metadata: Optional[dict] = Field(
default=None,
description="Metadata from the LLM response (additional_kwargs, response_metadata, usage_metadata, etc.)",
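The new field serializes alongside the existing ones and defaults to None; a short sketch of the expected Pydantic behavior, mirroring the serialization tests added below:

    from nemoguardrails.rails.llm.options import GenerationResponse

    resp = GenerationResponse(response="Final answer", reasoning_content="Step 1: think")
    assert resp.dict()["reasoning_content"] == "Step 1: think"  # present in dict output
    assert "reasoning_content" in resp.json()  # and in JSON output
    assert GenerationResponse(response="Answer").reasoning_content is None  # default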
23 changes: 3 additions & 20 deletions tests/conftest.py
@@ -17,27 +17,10 @@

import pytest

from nemoguardrails.context import reasoning_trace_var
REASONING_TRACE_MOCK_PATH = (
"nemoguardrails.actions.llm.generation.get_and_clear_reasoning_trace_contextvar"
)


def pytest_configure(config):
patch("prompt_toolkit.PromptSession", autospec=True).start()


@pytest.fixture(autouse=True)
def reset_reasoning_trace():
"""Reset the reasoning_trace_var before each test.

This fixture runs automatically for every test (autouse=True) to ensure
a clean state for the reasoning trace context variable.

current Issues with ContextVar approach, not only specific to this case:
global State: ContextVar creates global state that's hard to track and manage
implicit Flow: The reasoning trace flows through the system in a non-obvious way
testing Complexity: It causes test isolation problems that we are trying to avoid using this fixture
"""
# reset the variable before the test
reasoning_trace_var.set(None)
yield
# reset the variable after the test as well (in case the test fails)
reasoning_trace_var.set(None)
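With the autouse fixture removed, tests that need a reasoning trace patch the accessor directly via the shared REASONING_TRACE_MOCK_PATH constant; a minimal sketch of the pattern the updated tests follow:

    from unittest.mock import patch

    from tests.conftest import REASONING_TRACE_MOCK_PATH

    def test_with_reasoning_trace():
        with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
            mock_get_reasoning.return_value = "Step 1: think"
            # ... exercise the code under test; the patched accessor supplies the trace ...

Patching the accessor per test scopes the trace to the with-block, avoiding the global contextvar state the old fixture had to reset around every test.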
55 changes: 55 additions & 0 deletions tests/rails/llm/test_options.py
@@ -193,3 +193,58 @@ def test_generation_response_model_validation():
assert isinstance(response.tool_calls, list)
assert len(response.tool_calls) == 2
assert response.llm_output["token_usage"]["total_tokens"] == 50


def test_generation_response_with_reasoning_content():
test_reasoning = "Step 1: Analyze\nStep 2: Respond"

response = GenerationResponse(
response="Final answer", reasoning_content=test_reasoning
)

assert response.reasoning_content == test_reasoning
assert response.response == "Final answer"


def test_generation_response_reasoning_content_defaults_to_none():
response = GenerationResponse(response="Answer")

assert response.reasoning_content is None


def test_generation_response_reasoning_content_can_be_empty_string():
response = GenerationResponse(response="Answer", reasoning_content="")

assert response.reasoning_content == ""


def test_generation_response_serialization_with_reasoning_content():
test_reasoning = "Thinking process here"

response = GenerationResponse(response="Response", reasoning_content=test_reasoning)

response_dict = response.dict()
assert "reasoning_content" in response_dict
assert response_dict["reasoning_content"] == test_reasoning

response_json = response.json()
assert "reasoning_content" in response_json
assert test_reasoning in response_json


def test_generation_response_with_all_fields():
test_tool_calls = [
{"name": "test_func", "args": {}, "id": "call_123", "type": "tool_call"}
]
test_reasoning = "Detailed reasoning"

response = GenerationResponse(
response=[{"role": "assistant", "content": "Response"}],
tool_calls=test_tool_calls,
reasoning_content=test_reasoning,
llm_output={"token_usage": {"total_tokens": 100}},
)

assert response.tool_calls == test_tool_calls
assert response.reasoning_content == test_reasoning
assert response.llm_output["token_usage"]["total_tokens"] == 100
42 changes: 12 additions & 30 deletions tests/test_bot_thinking_events.py
@@ -18,16 +18,15 @@
import pytest

from nemoguardrails import RailsConfig
from tests.conftest import REASONING_TRACE_MOCK_PATH
from tests.utils import TestChat


@pytest.mark.asyncio
async def test_bot_thinking_event_creation_passthrough():
test_reasoning_trace = "Let me think about this step by step..."

with patch(
"nemoguardrails.actions.llm.generation.get_and_clear_reasoning_trace_contextvar"
) as mock_get_reasoning:
with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
mock_get_reasoning.return_value = test_reasoning_trace

config = RailsConfig.from_content(config={"models": [], "passthrough": True})
@@ -46,9 +45,7 @@ async def test_bot_thinking_event_creation_passthrough():
async def test_bot_thinking_event_creation_non_passthrough():
test_reasoning_trace = "Analyzing the user's request..."

with patch(
"nemoguardrails.actions.llm.generation.get_and_clear_reasoning_trace_contextvar"
) as mock_get_reasoning:
with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
mock_get_reasoning.return_value = test_reasoning_trace

config = RailsConfig.from_content(
@@ -84,9 +81,7 @@ async def test_bot_thinking_event_creation_non_passthrough():

@pytest.mark.asyncio
async def test_no_bot_thinking_event_when_no_reasoning_trace():
with patch(
"nemoguardrails.actions.llm.generation.get_and_clear_reasoning_trace_contextvar"
) as mock_get_reasoning:
with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
mock_get_reasoning.return_value = None

config = RailsConfig.from_content(config={"models": [], "passthrough": True})
@@ -104,9 +99,7 @@ async def test_no_bot_thinking_event_when_no_reasoning_trace():
async def test_bot_thinking_before_bot_message():
test_reasoning_trace = "Step 1: Understand the question\nStep 2: Formulate answer"

with patch(
"nemoguardrails.actions.llm.generation.get_and_clear_reasoning_trace_contextvar"
) as mock_get_reasoning:
with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
mock_get_reasoning.return_value = test_reasoning_trace

config = RailsConfig.from_content(config={"models": [], "passthrough": True})
@@ -134,9 +127,7 @@ async def test_bot_thinking_before_bot_message():
async def test_bot_thinking_accessible_in_output_rails():
test_reasoning_trace = "Thinking: This requires careful consideration"

with patch(
"nemoguardrails.actions.llm.generation.get_and_clear_reasoning_trace_contextvar"
) as mock_get_reasoning:
with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
mock_get_reasoning.return_value = test_reasoning_trace

config = RailsConfig.from_content(
@@ -171,9 +162,7 @@ async def test_bot_thinking_accessible_in_output_rails():
async def test_bot_thinking_matches_in_output_rails():
test_reasoning_trace = "Let me analyze: step 1, step 2, step 3"

with patch(
"nemoguardrails.actions.llm.generation.get_and_clear_reasoning_trace_contextvar"
) as mock_get_reasoning:
with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
mock_get_reasoning.return_value = test_reasoning_trace

config = RailsConfig.from_content(
@@ -203,9 +192,7 @@ async def test_bot_thinking_matches_in_output_rails():

@pytest.mark.asyncio
async def test_bot_thinking_none_when_no_reasoning():
with patch(
"nemoguardrails.actions.llm.generation.get_and_clear_reasoning_trace_contextvar"
) as mock_get_reasoning:
with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
mock_get_reasoning.return_value = None

config = RailsConfig.from_content(
@@ -240,9 +227,7 @@ async def test_bot_thinking_none_when_no_reasoning():
async def test_bot_thinking_usable_in_output_rail_logic():
test_reasoning_trace = "This contains sensitive information"

with patch(
"nemoguardrails.actions.llm.generation.get_and_clear_reasoning_trace_contextvar"
) as mock_get_reasoning:
with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
mock_get_reasoning.return_value = test_reasoning_trace

config = RailsConfig.from_content(
@@ -270,12 +255,9 @@ async def test_bot_thinking_usable_in_output_rail_logic():
)

assert isinstance(result.response, list)
# TODO(@Pouyanpi): in llmrails.py appending reasoning traces to the final generation might not be desired anymore
# should be fixed in a subsequent PR for 0.18.0 release
assert (
result.response[0]["content"]
== test_reasoning_trace + "I'm sorry, I can't respond to that."
)
assert result.reasoning_content == test_reasoning_trace
assert result.response[0]["content"] == "I'm sorry, I can't respond to that."
assert test_reasoning_trace not in result.response[0]["content"]


@pytest.mark.asyncio
105 changes: 105 additions & 0 deletions tests/test_llmrails.py
@@ -24,6 +24,7 @@
from nemoguardrails.logging.explain import ExplainInfo
from nemoguardrails.rails.llm.config import Model
from nemoguardrails.rails.llm.llmrails import get_action_details_from_flow_id
from tests.conftest import REASONING_TRACE_MOCK_PATH
from tests.utils import FakeLLM, clean_events, event_sequence_conforms


@@ -1372,3 +1373,107 @@ def test_cache_initialization_with_multiple_models(mock_init_llm_model):
assert "jailbreak_detection" in model_caches
assert model_caches["content_safety"].maxsize == 1000
assert model_caches["jailbreak_detection"].maxsize == 2000


@pytest.mark.asyncio
async def test_generate_async_reasoning_content_field_passthrough():
from nemoguardrails.rails.llm.options import GenerationOptions

test_reasoning_trace = "Let me think about this step by step..."

with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
mock_get_reasoning.return_value = test_reasoning_trace

config = RailsConfig.from_content(config={"models": []})
llm = FakeLLM(responses=["The answer is 42"])
llm_rails = LLMRails(config=config, llm=llm)

result = await llm_rails.generate_async(
messages=[{"role": "user", "content": "What is the answer?"}],
options=GenerationOptions(),
)

assert result.reasoning_content == test_reasoning_trace
assert isinstance(result.response, list)
assert result.response[0]["content"] == "The answer is 42"


@pytest.mark.asyncio
async def test_generate_async_reasoning_content_none():
from nemoguardrails.rails.llm.options import GenerationOptions

with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
mock_get_reasoning.return_value = None

config = RailsConfig.from_content(config={"models": []})
llm = FakeLLM(responses=["Regular response"])
llm_rails = LLMRails(config=config, llm=llm)

result = await llm_rails.generate_async(
messages=[{"role": "user", "content": "Hello"}],
options=GenerationOptions(),
)

assert result.reasoning_content is None
assert isinstance(result.response, list)
assert result.response[0]["content"] == "Regular response"


@pytest.mark.asyncio
async def test_generate_async_reasoning_not_in_response_content():
from nemoguardrails.rails.llm.options import GenerationOptions

test_reasoning_trace = "Let me analyze this carefully..."

with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
mock_get_reasoning.return_value = test_reasoning_trace

config = RailsConfig.from_content(config={"models": []})
llm = FakeLLM(responses=["The answer is 42"])
llm_rails = LLMRails(config=config, llm=llm)

result = await llm_rails.generate_async(
messages=[{"role": "user", "content": "What is the answer?"}],
options=GenerationOptions(),
)

assert result.reasoning_content == test_reasoning_trace
assert test_reasoning_trace not in result.response[0]["content"]
assert result.response[0]["content"] == "The answer is 42"


@pytest.mark.asyncio
async def test_generate_async_reasoning_with_thinking_tags():
test_reasoning_trace = "Step 1: Analyze\nStep 2: Respond"

with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
mock_get_reasoning.return_value = test_reasoning_trace

config = RailsConfig.from_content(config={"models": [], "passthrough": True})
llm = FakeLLM(responses=["The answer is 42"])
llm_rails = LLMRails(config=config, llm=llm)

result = await llm_rails.generate_async(
messages=[{"role": "user", "content": "What is the answer?"}]
)

expected_prefix = f"<think>{test_reasoning_trace}</think>\n"
assert result["content"].startswith(expected_prefix)
assert "The answer is 42" in result["content"]


@pytest.mark.asyncio
async def test_generate_async_no_thinking_tags_when_no_reasoning():
with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
mock_get_reasoning.return_value = None

config = RailsConfig.from_content(config={"models": []})
llm = FakeLLM(responses=["Regular response"])
llm_rails = LLMRails(config=config, llm=llm)

result = await llm_rails.generate_async(
messages=[{"role": "user", "content": "Hello"}]
)

assert not result["content"].startswith("<think>")
assert result["content"] == "Regular response"