diff --git a/src/ragas/llms/base.py b/src/ragas/llms/base.py
index 8e39cb582..74d9b8cd7 100644
--- a/src/ragas/llms/base.py
+++ b/src/ragas/llms/base.py
@@ -149,6 +149,7 @@ def __init__(
         is_finished_parser: t.Optional[t.Callable[[LLMResult], bool]] = None,
         cache: t.Optional[CacheInterface] = None,
         bypass_temperature: bool = False,
+        bypass_n: bool = False,
     ):
         super().__init__(cache=cache)
         self.langchain_llm = langchain_llm
@@ -158,6 +159,8 @@ def __init__(
         self.is_finished_parser = is_finished_parser
         # Certain LLMs (e.g., OpenAI o1 series) do not support temperature
         self.bypass_temperature = bypass_temperature
+        # Certain reasoning LLMs (e.g., OpenAI o1 series) do not support the n parameter for multiple completions
+        self.bypass_n = bypass_n
 
     def is_finished(self, response: LLMResult) -> bool:
         """
@@ -225,7 +228,7 @@ def generate_text(
             old_temperature = self.langchain_llm.temperature  # type: ignore
             self.langchain_llm.temperature = temperature  # type: ignore
 
-        if is_multiple_completion_supported(self.langchain_llm):
+        if is_multiple_completion_supported(self.langchain_llm) and not self.bypass_n:
             result = self.langchain_llm.generate_prompt(
                 prompts=[prompt],
                 n=n,
@@ -278,7 +281,7 @@ async def agenerate_text(
             self.langchain_llm.temperature = temperature  # type: ignore
 
         # handle n
-        if hasattr(self.langchain_llm, "n"):
+        if hasattr(self.langchain_llm, "n") and not self.bypass_n:
             self.langchain_llm.n = n  # type: ignore
         result = await self.langchain_llm.agenerate_prompt(
             prompts=[prompt],
diff --git a/tests/unit/llms/test_llm.py b/tests/unit/llms/test_llm.py
index 059bfe875..0002b0714 100644
--- a/tests/unit/llms/test_llm.py
+++ b/tests/unit/llms/test_llm.py
@@ -1,13 +1,13 @@
 from __future__ import annotations
 
 import typing as t
+from unittest.mock import MagicMock, patch
 
+import pytest
 from langchain_core.outputs import Generation, LLMResult
+from langchain_core.prompt_values import PromptValue
 
-from ragas.llms.base import BaseRagasLLM
-
-if t.TYPE_CHECKING:
-    from langchain_core.prompt_values import PromptValue
+from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper
 
 
 class FakeTestLLM(BaseRagasLLM):
@@ -38,3 +38,200 @@ async def agenerate_text(
 
     def is_finished(self, response: LLMResult) -> bool:
         return True
+
+
+class MockLangchainLLM:
+    """Mock Langchain LLM for testing bypass_n functionality."""
+
+    def __init__(self):
+        self.n = None  # This makes hasattr(self.langchain_llm, "n") return True
+        self.temperature = None
+        self.model_name = "mock-model"
+
+    def generate_prompt(self, prompts, n=None, stop=None, callbacks=None):
+        # Track if n was passed to the method
+        self._n_passed = n
+        # Simulate the behavior where if n is passed, we return n generations per prompt.
+        # If n is not passed, we return one generation per prompt.
+        num_prompts = len(prompts)
+        if n is not None:
+            # If n is specified, return n generations for each prompt
+            generations = [
+                [Generation(text="test response")] * n for _ in range(num_prompts)
+            ]
+        else:
+            # If n is not specified, return one generation per prompt
+            generations = [
+                [Generation(text="test response")] for _ in range(num_prompts)
+            ]
+        return LLMResult(generations=generations)
+
+    async def agenerate_prompt(self, prompts, n=None, stop=None, callbacks=None):
+        # Track if n was passed to the method
+        self._n_passed = n
+        # If n is not passed as a parameter but self.n is set, use self.n
+        if n is None and hasattr(self, "n") and self.n is not None:
+            n = self.n
+        # Simulate the behavior where if n is passed, we return n generations per prompt.
+        # If n is not passed, we return one generation per prompt.
+        num_prompts = len(prompts)
+        if n is not None:
+            # If n is specified, return n generations for each prompt
+            generations = [
+                [Generation(text="test response")] * n for _ in range(num_prompts)
+            ]
+        else:
+            # If n is not specified, return one generation per prompt
+            generations = [
+                [Generation(text="test response")] for _ in range(num_prompts)
+            ]
+        return LLMResult(generations=generations)
+
+
+def create_mock_prompt():
+    """Create a mock prompt for testing."""
+    prompt = MagicMock(spec=PromptValue)
+    prompt.to_string.return_value = "test prompt"
+    return prompt
+
+
+class TestLangchainLLMWrapperBypassN:
+    """Test bypass_n functionality in LangchainLLMWrapper."""
+
+    def test_bypass_n_true_sync_does_not_pass_n(self):
+        """Test that when bypass_n=True, n is not passed to the underlying LLM in the sync method."""
+        mock_llm = MockLangchainLLM()
+        # Mock is_multiple_completion_supported to return True for this test
+        with patch(
+            "ragas.llms.base.is_multiple_completion_supported", return_value=True
+        ):
+            wrapper = LangchainLLMWrapper(langchain_llm=mock_llm, bypass_n=True)
+            prompt = create_mock_prompt()
+
+            # Call generate_text with n=3
+            result = wrapper.generate_text(prompt, n=3)
+
+            # Verify that n was not passed to the underlying LLM
+            assert mock_llm._n_passed is None
+            # When bypass_n=True, the wrapper should duplicate prompts instead of passing n.
+            # The result should still have 3 generations (created by duplicating prompts).
+            assert len(result.generations[0]) == 3
+
+    def test_bypass_n_false_sync_passes_n(self):
+        """Test that when bypass_n=False (default), n is passed to the underlying LLM in the sync method."""
+        mock_llm = MockLangchainLLM()
+        # Mock is_multiple_completion_supported to return True for this test
+        with patch(
+            "ragas.llms.base.is_multiple_completion_supported", return_value=True
+        ):
+            wrapper = LangchainLLMWrapper(langchain_llm=mock_llm, bypass_n=False)
+            prompt = create_mock_prompt()
+
+            # Call generate_text with n=3
+            result = wrapper.generate_text(prompt, n=3)
+
+            # Verify that n was passed to the underlying LLM
+            assert mock_llm._n_passed == 3
+            # Result should have 3 generations
+            assert len(result.generations[0]) == 3
+
+    @pytest.mark.asyncio
+    async def test_bypass_n_true_async_does_not_pass_n(self):
+        """Test that when bypass_n=True, n is not passed to the underlying LLM in the async method."""
+        mock_llm = MockLangchainLLM()
+        wrapper = LangchainLLMWrapper(langchain_llm=mock_llm, bypass_n=True)
+        prompt = create_mock_prompt()
+
+        # Call agenerate_text with n=3
+        result = await wrapper.agenerate_text(prompt, n=3)
+
+        # Verify that n was not passed to the underlying LLM
+        assert mock_llm._n_passed is None
+        # When bypass_n=True, the wrapper should duplicate prompts instead of passing n.
+        # The result should still have 3 generations (created by duplicating prompts).
+        assert len(result.generations[0]) == 3
+
+    @pytest.mark.asyncio
+    async def test_bypass_n_false_async_passes_n(self):
+        """Test that when bypass_n=False (default), n is passed to the underlying LLM in the async method."""
+        mock_llm = MockLangchainLLM()
+        wrapper = LangchainLLMWrapper(langchain_llm=mock_llm, bypass_n=False)
+        prompt = create_mock_prompt()
+
+        # Call agenerate_text with n=3
+        result = await wrapper.agenerate_text(prompt, n=3)
+
+        # Verify that n was passed to the underlying LLM (via the n attribute)
+        assert mock_llm.n == 3
+        # Result should have 3 generations
+        assert len(result.generations[0]) == 3
+
+    def test_default_bypass_n_behavior(self):
+        """Test that the default behavior (bypass_n=False) remains unchanged."""
+        mock_llm = MockLangchainLLM()
+        # Mock is_multiple_completion_supported to return True for this test
+        with patch(
+            "ragas.llms.base.is_multiple_completion_supported", return_value=True
+        ):
+            # Create the wrapper without explicitly setting bypass_n (should default to False)
+            wrapper = LangchainLLMWrapper(langchain_llm=mock_llm)
+            prompt = create_mock_prompt()
+
+            # Call generate_text with n=2
+            result = wrapper.generate_text(prompt, n=2)
+
+            # Verify that n was passed to the underlying LLM (default behavior)
+            assert mock_llm._n_passed == 2
+            assert len(result.generations[0]) == 2
+
+    @pytest.mark.asyncio
+    async def test_default_bypass_n_behavior_async(self):
+        """Test that the default behavior (bypass_n=False) remains unchanged in the async method."""
+        mock_llm = MockLangchainLLM()
+        # Create the wrapper without explicitly setting bypass_n (should default to False)
+        wrapper = LangchainLLMWrapper(langchain_llm=mock_llm)
+        prompt = create_mock_prompt()
+
+        # Call agenerate_text with n=2
+        result = await wrapper.agenerate_text(prompt, n=2)
+
+        # Verify that n was passed to the underlying LLM (default behavior)
+        assert mock_llm.n == 2
+        assert len(result.generations[0]) == 2
+
+    def test_bypass_n_true_with_multiple_completion_supported(self):
+        """Test bypass_n=True with an LLM that supports multiple completions."""
+        # Create a mock LLM that would normally support multiple completions
+        mock_llm = MockLangchainLLM()
+        # Mock is_multiple_completion_supported to return True for this test
+        with patch(
+            "ragas.llms.base.is_multiple_completion_supported", return_value=True
+        ):
+            wrapper = LangchainLLMWrapper(langchain_llm=mock_llm, bypass_n=True)
+            prompt = create_mock_prompt()
+
+            # Call generate_text with n=3
+            result = wrapper.generate_text(prompt, n=3)
+
+            # Verify that n was not passed to the underlying LLM due to bypass_n=True
+            assert mock_llm._n_passed is None
+            # Result should still have 3 generations (created by duplicating prompts)
+            assert len(result.generations[0]) == 3
+
+    @pytest.mark.asyncio
+    async def test_bypass_n_true_with_multiple_completion_supported_async(self):
+        """Test bypass_n=True with an LLM that supports multiple completions in the async method."""
+        mock_llm = MockLangchainLLM()
+        with patch(
+            "ragas.llms.base.is_multiple_completion_supported", return_value=True
+        ):
+            wrapper = LangchainLLMWrapper(langchain_llm=mock_llm, bypass_n=True)
+            prompt = create_mock_prompt()
+
+            # Call agenerate_text with n=3
+            result = await wrapper.agenerate_text(prompt, n=3)
+
+            # Verify that n was not passed to the underlying LLM due to bypass_n=True
+            assert mock_llm._n_passed is None
+            # Result should still have 3 generations
+            assert len(result.generations[0]) == 3
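
Usage note: a minimal sketch (not part of the diff) of how the new bypass_n flag could be combined with the existing bypass_temperature flag when wrapping an OpenAI reasoning model. The langchain_openai import and the model name are illustrative assumptions, not something this change prescribes.

# Hypothetical usage sketch: o1-style reasoning models reject both temperature and n.
from langchain_openai import ChatOpenAI

from ragas.llms.base import LangchainLLMWrapper

reasoning_llm = ChatOpenAI(model="o1-mini")  # assumed model name, for illustration only
wrapper = LangchainLLMWrapper(
    langchain_llm=reasoning_llm,
    bypass_temperature=True,  # do not set temperature on the underlying LLM
    bypass_n=True,  # do not pass n; the wrapper duplicates prompts instead
)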