diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py
index 3505dad07..a9ed69048 100644
--- a/nemoguardrails/rails/llm/llmrails.py
+++ b/nemoguardrails/rails/llm/llmrails.py
@@ -1084,7 +1084,7 @@ async def generate_async(
 
         tool_calls = extract_tool_calls_from_events(new_events)
         llm_metadata = get_and_clear_response_metadata_contextvar()
-
+        reasoning_content = extract_bot_thinking_from_events(new_events)
         # If we have generation options, we prepare a GenerationResponse instance.
         if gen_options:
             # If a prompt was used, we only need to return the content of the message.
@@ -1093,17 +1093,8 @@
             else:
                 res = GenerationResponse(response=[new_message])
 
-            if reasoning_trace := extract_bot_thinking_from_events(events):
-                if prompt:
-                    # For prompt mode, response should be a string
-                    if isinstance(res.response, str):
-                        res.response = reasoning_trace + res.response
-                else:
-                    # For message mode, response should be a list
-                    if isinstance(res.response, list) and len(res.response) > 0:
-                        res.response[0]["content"] = (
-                            reasoning_trace + res.response[0]["content"]
-                        )
+            if reasoning_content:
+                res.reasoning_content = reasoning_content
 
             if tool_calls:
                 res.tool_calls = tool_calls
@@ -1238,8 +1229,9 @@
         else:
             # If a prompt is used, we only return the content of the message.
-            if reasoning_trace := extract_bot_thinking_from_events(events):
-                new_message["content"] = reasoning_trace + new_message["content"]
+            if reasoning_content:
+                thinking_trace = f"<think>{reasoning_content}</think>\n"
+                new_message["content"] = thinking_trace + new_message["content"]
 
             if prompt:
                 return new_message["content"]
 
diff --git a/nemoguardrails/rails/llm/options.py b/nemoguardrails/rails/llm/options.py
index 21110daae..ca8a7dfa1 100644
--- a/nemoguardrails/rails/llm/options.py
+++ b/nemoguardrails/rails/llm/options.py
@@ -428,6 +428,10 @@ class GenerationResponse(BaseModel):
         default=None,
         description="Tool calls extracted from the LLM response, if any.",
     )
+    reasoning_content: Optional[str] = Field(
+        default=None,
+        description="The reasoning content extracted from the LLM response, if any.",
+    )
     llm_metadata: Optional[dict] = Field(
         default=None,
         description="Metadata from the LLM response (additional_kwargs, response_metadata, usage_metadata, etc.)",
diff --git a/tests/conftest.py b/tests/conftest.py
index ed3220aa6..1dc00134b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -17,27 +17,10 @@
 
 import pytest
 
-from nemoguardrails.context import reasoning_trace_var
+REASONING_TRACE_MOCK_PATH = (
+    "nemoguardrails.actions.llm.generation.get_and_clear_reasoning_trace_contextvar"
+)
 
 
 def pytest_configure(config):
     patch("prompt_toolkit.PromptSession", autospec=True).start()
-
-
-@pytest.fixture(autouse=True)
-def reset_reasoning_trace():
-    """Reset the reasoning_trace_var before each test.
-
-    This fixture runs automatically for every test (autouse=True) to ensure
-    a clean state for the reasoning trace context variable.
-
-    current Issues with ContextVar approach, not only specific to this case:
-    global State: ContextVar creates global state that's hard to track and manage
-    implicit Flow: The reasoning trace flows through the system in a non-obvious way
-    testing Complexity: It causes test isolation problems that we are trying to avoid using this fixture
-    """
-    # reset the variable before the test
-    reasoning_trace_var.set(None)
-    yield
-    # reset the variable after the test as well (in case the test fails)
-    reasoning_trace_var.set(None)
diff --git a/tests/rails/llm/test_options.py b/tests/rails/llm/test_options.py
index 65ed56ab0..2cef2cec3 100644
--- a/tests/rails/llm/test_options.py
+++ b/tests/rails/llm/test_options.py
@@ -193,3 +193,58 @@ def test_generation_response_model_validation():
     assert isinstance(response.tool_calls, list)
     assert len(response.tool_calls) == 2
     assert response.llm_output["token_usage"]["total_tokens"] == 50
+
+
+def test_generation_response_with_reasoning_content():
+    test_reasoning = "Step 1: Analyze\nStep 2: Respond"
+
+    response = GenerationResponse(
+        response="Final answer", reasoning_content=test_reasoning
+    )
+
+    assert response.reasoning_content == test_reasoning
+    assert response.response == "Final answer"
+
+
+def test_generation_response_reasoning_content_defaults_to_none():
+    response = GenerationResponse(response="Answer")
+
+    assert response.reasoning_content is None
+
+
+def test_generation_response_reasoning_content_can_be_empty_string():
+    response = GenerationResponse(response="Answer", reasoning_content="")
+
+    assert response.reasoning_content == ""
+
+
+def test_generation_response_serialization_with_reasoning_content():
+    test_reasoning = "Thinking process here"
+
+    response = GenerationResponse(response="Response", reasoning_content=test_reasoning)
+
+    response_dict = response.dict()
+    assert "reasoning_content" in response_dict
+    assert response_dict["reasoning_content"] == test_reasoning
+
+    response_json = response.json()
+    assert "reasoning_content" in response_json
+    assert test_reasoning in response_json
+
+
+def test_generation_response_with_all_fields():
+    test_tool_calls = [
+        {"name": "test_func", "args": {}, "id": "call_123", "type": "tool_call"}
+    ]
+    test_reasoning = "Detailed reasoning"
+
+    response = GenerationResponse(
+        response=[{"role": "assistant", "content": "Response"}],
+        tool_calls=test_tool_calls,
+        reasoning_content=test_reasoning,
+        llm_output={"token_usage": {"total_tokens": 100}},
+    )
+
+    assert response.tool_calls == test_tool_calls
+    assert response.reasoning_content == test_reasoning
+    assert response.llm_output["token_usage"]["total_tokens"] == 100
diff --git a/tests/test_bot_thinking_events.py b/tests/test_bot_thinking_events.py
index 64934c9ab..a57ba4769 100644
--- a/tests/test_bot_thinking_events.py
+++ b/tests/test_bot_thinking_events.py
@@ -18,6 +18,7 @@
 import pytest
 
 from nemoguardrails import RailsConfig
+from tests.conftest import REASONING_TRACE_MOCK_PATH
 from tests.utils import TestChat
 
 
@@ -25,9 +26,7 @@
 async def test_bot_thinking_event_creation_passthrough():
     test_reasoning_trace = "Let me think about this step by step..."
 
-    with patch(
-        "nemoguardrails.actions.llm.generation.get_and_clear_reasoning_trace_contextvar"
-    ) as mock_get_reasoning:
+    with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
         mock_get_reasoning.return_value = test_reasoning_trace
 
         config = RailsConfig.from_content(config={"models": [], "passthrough": True})
@@ -46,9 +45,7 @@ async def test_bot_thinking_event_creation_passthrough():
 async def test_bot_thinking_event_creation_non_passthrough():
     test_reasoning_trace = "Analyzing the user's request..."
 
-    with patch(
-        "nemoguardrails.actions.llm.generation.get_and_clear_reasoning_trace_contextvar"
-    ) as mock_get_reasoning:
+    with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
         mock_get_reasoning.return_value = test_reasoning_trace
 
         config = RailsConfig.from_content(
@@ -84,9 +81,7 @@
 
 @pytest.mark.asyncio
 async def test_no_bot_thinking_event_when_no_reasoning_trace():
-    with patch(
-        "nemoguardrails.actions.llm.generation.get_and_clear_reasoning_trace_contextvar"
-    ) as mock_get_reasoning:
+    with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
         mock_get_reasoning.return_value = None
 
         config = RailsConfig.from_content(config={"models": [], "passthrough": True})
@@ -104,9 +99,7 @@
 async def test_bot_thinking_before_bot_message():
     test_reasoning_trace = "Step 1: Understand the question\nStep 2: Formulate answer"
 
-    with patch(
-        "nemoguardrails.actions.llm.generation.get_and_clear_reasoning_trace_contextvar"
-    ) as mock_get_reasoning:
+    with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
         mock_get_reasoning.return_value = test_reasoning_trace
 
         config = RailsConfig.from_content(config={"models": [], "passthrough": True})
@@ -134,9 +127,7 @@
 async def test_bot_thinking_accessible_in_output_rails():
     test_reasoning_trace = "Thinking: This requires careful consideration"
 
-    with patch(
-        "nemoguardrails.actions.llm.generation.get_and_clear_reasoning_trace_contextvar"
-    ) as mock_get_reasoning:
+    with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
         mock_get_reasoning.return_value = test_reasoning_trace
 
         config = RailsConfig.from_content(
@@ -171,9 +162,7 @@
 async def test_bot_thinking_matches_in_output_rails():
     test_reasoning_trace = "Let me analyze: step 1, step 2, step 3"
 
-    with patch(
-        "nemoguardrails.actions.llm.generation.get_and_clear_reasoning_trace_contextvar"
-    ) as mock_get_reasoning:
+    with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
         mock_get_reasoning.return_value = test_reasoning_trace
 
         config = RailsConfig.from_content(
@@ -203,9 +192,7 @@
 
 @pytest.mark.asyncio
 async def test_bot_thinking_none_when_no_reasoning():
-    with patch(
-        "nemoguardrails.actions.llm.generation.get_and_clear_reasoning_trace_contextvar"
-    ) as mock_get_reasoning:
+    with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
         mock_get_reasoning.return_value = None
 
         config = RailsConfig.from_content(
@@ -240,9 +227,7 @@
 async def test_bot_thinking_usable_in_output_rail_logic():
     test_reasoning_trace = "This contains sensitive information"
 
-    with patch(
-        "nemoguardrails.actions.llm.generation.get_and_clear_reasoning_trace_contextvar"
-    ) as mock_get_reasoning:
+    with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
         mock_get_reasoning.return_value = test_reasoning_trace
 
         config = RailsConfig.from_content(
@@ -270,12 +255,9 @@
         )
 
         assert isinstance(result.response, list)
-        # TODO(@Pouyanpi): in llmrails.py appending reasoning traces to the final generation might not be desired anymore
-        # should be fixed in a subsequent PR for 0.18.0 release
-        assert (
-            result.response[0]["content"]
-            == test_reasoning_trace + "I'm sorry, I can't respond to that."
-        )
+        assert result.reasoning_content == test_reasoning_trace
+        assert result.response[0]["content"] == "I'm sorry, I can't respond to that."
+        assert test_reasoning_trace not in result.response[0]["content"]
 
 
 @pytest.mark.asyncio
diff --git a/tests/test_llmrails.py b/tests/test_llmrails.py
index 15523bb51..89e7e87cf 100644
--- a/tests/test_llmrails.py
+++ b/tests/test_llmrails.py
@@ -24,6 +24,7 @@
 from nemoguardrails.logging.explain import ExplainInfo
 from nemoguardrails.rails.llm.config import Model
 from nemoguardrails.rails.llm.llmrails import get_action_details_from_flow_id
+from tests.conftest import REASONING_TRACE_MOCK_PATH
 from tests.utils import FakeLLM, clean_events, event_sequence_conforms
 
 
@@ -1372,3 +1373,107 @@ def test_cache_initialization_with_multiple_models(mock_init_llm_model):
     assert "jailbreak_detection" in model_caches
     assert model_caches["content_safety"].maxsize == 1000
     assert model_caches["jailbreak_detection"].maxsize == 2000
+
+
+@pytest.mark.asyncio
+async def test_generate_async_reasoning_content_field_passthrough():
+    from nemoguardrails.rails.llm.options import GenerationOptions
+
+    test_reasoning_trace = "Let me think about this step by step..."
+
+    with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
+        mock_get_reasoning.return_value = test_reasoning_trace
+
+        config = RailsConfig.from_content(config={"models": []})
+        llm = FakeLLM(responses=["The answer is 42"])
+        llm_rails = LLMRails(config=config, llm=llm)
+
+        result = await llm_rails.generate_async(
+            messages=[{"role": "user", "content": "What is the answer?"}],
+            options=GenerationOptions(),
+        )
+
+        assert result.reasoning_content == test_reasoning_trace
+        assert isinstance(result.response, list)
+        assert result.response[0]["content"] == "The answer is 42"
+
+
+@pytest.mark.asyncio
+async def test_generate_async_reasoning_content_none():
+    from nemoguardrails.rails.llm.options import GenerationOptions
+
+    with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
+        mock_get_reasoning.return_value = None
+
+        config = RailsConfig.from_content(config={"models": []})
+        llm = FakeLLM(responses=["Regular response"])
+        llm_rails = LLMRails(config=config, llm=llm)
+
+        result = await llm_rails.generate_async(
+            messages=[{"role": "user", "content": "Hello"}],
+            options=GenerationOptions(),
+        )
+
+        assert result.reasoning_content is None
+        assert isinstance(result.response, list)
+        assert result.response[0]["content"] == "Regular response"
+
+
+@pytest.mark.asyncio
+async def test_generate_async_reasoning_not_in_response_content():
+    from nemoguardrails.rails.llm.options import GenerationOptions
+
+    test_reasoning_trace = "Let me analyze this carefully..."
+
+    with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
+        mock_get_reasoning.return_value = test_reasoning_trace
+
+        config = RailsConfig.from_content(config={"models": []})
+        llm = FakeLLM(responses=["The answer is 42"])
+        llm_rails = LLMRails(config=config, llm=llm)
+
+        result = await llm_rails.generate_async(
+            messages=[{"role": "user", "content": "What is the answer?"}],
+            options=GenerationOptions(),
+        )
+
+        assert result.reasoning_content == test_reasoning_trace
+        assert test_reasoning_trace not in result.response[0]["content"]
+        assert result.response[0]["content"] == "The answer is 42"
+
+
+@pytest.mark.asyncio
+async def test_generate_async_reasoning_with_thinking_tags():
+    test_reasoning_trace = "Step 1: Analyze\nStep 2: Respond"
+
+    with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
+        mock_get_reasoning.return_value = test_reasoning_trace
+
+        config = RailsConfig.from_content(config={"models": [], "passthrough": True})
+        llm = FakeLLM(responses=["The answer is 42"])
+        llm_rails = LLMRails(config=config, llm=llm)
+
+        result = await llm_rails.generate_async(
+            messages=[{"role": "user", "content": "What is the answer?"}]
+        )
+
+        expected_prefix = f"<think>{test_reasoning_trace}</think>\n"
+        assert result["content"].startswith(expected_prefix)
+        assert "The answer is 42" in result["content"]
+
+
+@pytest.mark.asyncio
+async def test_generate_async_no_thinking_tags_when_no_reasoning():
+    with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
+        mock_get_reasoning.return_value = None
+
+        config = RailsConfig.from_content(config={"models": []})
+        llm = FakeLLM(responses=["Regular response"])
+        llm_rails = LLMRails(config=config, llm=llm)
+
+        result = await llm_rails.generate_async(
+            messages=[{"role": "user", "content": "Hello"}]
+        )
+
+        assert not result["content"].startswith("<think>")
+        assert result["content"] == "Regular response"
diff --git a/tests/test_runtime_event_logging.py b/tests/test_runtime_event_logging.py
index f39429648..2aadd4bac 100644
--- a/tests/test_runtime_event_logging.py
+++ b/tests/test_runtime_event_logging.py
@@ -19,6 +19,7 @@
 import pytest
 
 from nemoguardrails import RailsConfig
+from tests.conftest import REASONING_TRACE_MOCK_PATH
 from tests.utils import TestChat
 
 
@@ -28,9 +29,7 @@
 
     caplog.set_level(logging.INFO)
 
-    with patch(
-        "nemoguardrails.actions.llm.generation.get_and_clear_reasoning_trace_contextvar"
-    ) as mock_get_reasoning:
+    with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
         mock_get_reasoning.return_value = test_reasoning_trace
 
         config = RailsConfig.from_content(config={"models": [], "passthrough": True})
@@ -89,9 +88,7 @@ async def test_all_events_logged_when_multiple_events_generated(caplog):
     caplog.set_level(logging.INFO)
 
-    with patch(
-        "nemoguardrails.actions.llm.generation.get_and_clear_reasoning_trace_contextvar"
-    ) as mock_get_reasoning:
+    with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
         mock_get_reasoning.return_value = test_reasoning_trace
 
         config = RailsConfig.from_content(config={"models": [], "passthrough": True})
@@ -116,9 +113,7 @@
 
     caplog.set_level(logging.INFO)
 
-    with patch(
-        "nemoguardrails.actions.llm.generation.get_and_clear_reasoning_trace_contextvar"
-    ) as mock_get_reasoning:
+    with patch(REASONING_TRACE_MOCK_PATH) as mock_get_reasoning:
         mock_get_reasoning.return_value = test_reasoning_trace
 
         config = RailsConfig.from_content(config={"models": [], "passthrough": True})
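
Reviewer note (not part of the diff): a minimal consumer-side sketch of the new `reasoning_content` field introduced in options.py. It assumes an already-working guardrails configuration; the `./config` path and printed labels are illustrative, not part of this change.

```python
# Sketch: with GenerationOptions, the reasoning trace is now returned
# out-of-band on GenerationResponse.reasoning_content instead of being
# prepended to the message text.
import asyncio

from nemoguardrails import LLMRails, RailsConfig
from nemoguardrails.rails.llm.options import GenerationOptions


async def main():
    # Hypothetical config directory; any valid rails config works here.
    config = RailsConfig.from_path("./config")
    rails = LLMRails(config)

    result = await rails.generate_async(
        messages=[{"role": "user", "content": "What is the answer?"}],
        options=GenerationOptions(),
    )

    # The trace is no longer mixed into the response content.
    if result.reasoning_content:
        print("reasoning:", result.reasoning_content)
    print("answer:", result.response[0]["content"])


asyncio.run(main())
```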
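For the options-less path, `generate_async` still returns a plain message dict with the trace prepended to `content`, wrapped in `<think>...</think>` per the llmrails.py hunk above. Below is a sketch of splitting the two apart on the caller side; the exact delimiters are taken from this diff and should be treated as an assumption of the sketch.

```python
# Sketch: separating the prepended reasoning trace from the visible answer
# when calling generate_async without GenerationOptions.
import asyncio

from nemoguardrails import LLMRails, RailsConfig


async def main():
    # Hypothetical config directory; any valid rails config works here.
    rails = LLMRails(RailsConfig.from_path("./config"))

    message = await rails.generate_async(
        messages=[{"role": "user", "content": "Hello"}]
    )

    content = message["content"]
    if content.startswith("<think>"):
        # Split on the first closing delimiter plus the trailing newline
        # that llmrails.py appends after the trace.
        trace, _, answer = content.partition("</think>\n")
        print("reasoning:", trace.removeprefix("<think>"))
        print("answer:", answer)
    else:
        print("answer:", content)


asyncio.run(main())
```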