src/ragas/llms/base.py (5 additions, 2 deletions)
@@ -149,6 +149,7 @@ def __init__(
is_finished_parser: t.Optional[t.Callable[[LLMResult], bool]] = None,
cache: t.Optional[CacheInterface] = None,
bypass_temperature: bool = False,
+ bypass_n: bool = False,
):
super().__init__(cache=cache)
self.langchain_llm = langchain_llm
@@ -158,6 +159,8 @@ def __init__(
self.is_finished_parser = is_finished_parser
# Certain LLMs (e.g., OpenAI o1 series) do not support temperature
self.bypass_temperature = bypass_temperature
+ # Certain reasoning LLMs (e.g., OpenAI o1 series) do not support the n parameter for multiple completions
+ self.bypass_n = bypass_n

def is_finished(self, response: LLMResult) -> bool:
"""
@@ -225,7 +228,7 @@ def generate_text(
old_temperature = self.langchain_llm.temperature # type: ignore
self.langchain_llm.temperature = temperature # type: ignore

- if is_multiple_completion_supported(self.langchain_llm):
+ if is_multiple_completion_supported(self.langchain_llm) and not self.bypass_n:
result = self.langchain_llm.generate_prompt(
prompts=[prompt],
n=n,
@@ -278,7 +281,7 @@ async def agenerate_text(
self.langchain_llm.temperature = temperature # type: ignore

# handle n
- if hasattr(self.langchain_llm, "n"):
+ if hasattr(self.langchain_llm, "n") and not self.bypass_n:
self.langchain_llm.n = n # type: ignore
result = await self.langchain_llm.agenerate_prompt(
prompts=[prompt],
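For reviewers, a minimal usage sketch of the new flag. The model name is illustrative only and not part of this PR; it assumes langchain-openai is installed.

```python
from langchain_openai import ChatOpenAI

from ragas.llms.base import LangchainLLMWrapper

# Hypothetical reasoning model that rejects both temperature overrides and n.
llm = ChatOpenAI(model="o1-mini", temperature=1)
wrapper = LangchainLLMWrapper(
    langchain_llm=llm,
    bypass_temperature=True,  # skip setting temperature on the underlying LLM
    bypass_n=True,  # skip passing n; the wrapper duplicates prompts instead
)
```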
tests/unit/llms/test_llm.py (201 additions, 4 deletions)
@@ -1,13 +1,13 @@
from __future__ import annotations

import typing as t
+ from unittest.mock import MagicMock, patch

import pytest
from langchain_core.outputs import Generation, LLMResult
+ from langchain_core.prompt_values import PromptValue

- from ragas.llms.base import BaseRagasLLM

- if t.TYPE_CHECKING:
-     from langchain_core.prompt_values import PromptValue
+ from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper


class FakeTestLLM(BaseRagasLLM):
@@ -38,3 +38,200 @@ async def agenerate_text(

def is_finished(self, response: LLMResult) -> bool:
return True


class MockLangchainLLM:
"""Mock Langchain LLM for testing bypass_n functionality."""

def __init__(self):
self.n = None # This makes hasattr(self.langchain_llm, "n") return True
self.temperature = None
self.model_name = "mock-model"

def generate_prompt(self, prompts, n=None, stop=None, callbacks=None):
# Track if n was passed to the method
self._n_passed = n
# Simulate the behavior where if n is passed, we return n generations per prompt
# If n is not passed, we return one generation per prompt
num_prompts = len(prompts)
if n is not None:
# If n is specified, return n generations for each prompt
generations = [
[Generation(text="test response")] * n for _ in range(num_prompts)
]
else:
# If n is not specified, return one generation per prompt
generations = [
[Generation(text="test response")] for _ in range(num_prompts)
]
return LLMResult(generations=generations)

async def agenerate_prompt(self, prompts, n=None, stop=None, callbacks=None):
# Track if n was passed to the method
self._n_passed = n
# If n is not passed as parameter but self.n is set, use self.n
if n is None and hasattr(self, "n") and self.n is not None:
n = self.n
# Simulate the behavior where if n is passed, we return n generations per prompt
# If n is not passed, we return one generation per prompt
num_prompts = len(prompts)
if n is not None:
# If n is specified, return n generations for each prompt
generations = [
[Generation(text="test response")] * n for _ in range(num_prompts)
]
else:
# If n is not specified, return one generation per prompt
generations = [
[Generation(text="test response")] for _ in range(num_prompts)
]
return LLMResult(generations=generations)


def create_mock_prompt():
"""Create a mock prompt for testing."""
prompt = MagicMock(spec=PromptValue)
prompt.to_string.return_value = "test prompt"
return prompt


class TestLangchainLLMWrapperBypassN:
"""Test bypass_n functionality in LangchainLLMWrapper."""

def test_bypass_n_true_sync_does_not_pass_n(self):
"""Test that when bypass_n=True, n is not passed to underlying LLM in sync method."""
mock_llm = MockLangchainLLM()
# Mock is_multiple_completion_supported to return True for this test
with patch(
"ragas.llms.base.is_multiple_completion_supported", return_value=True
):
wrapper = LangchainLLMWrapper(langchain_llm=mock_llm, bypass_n=True)
prompt = create_mock_prompt()

# Call generate_text with n=3
result = wrapper.generate_text(prompt, n=3)

# Verify that n was not passed to the underlying LLM
assert mock_llm._n_passed is None
# When bypass_n=True, the wrapper should duplicate prompts instead of passing n
# The result should still have 3 generations (created by duplicating prompts)
assert len(result.generations[0]) == 3

def test_bypass_n_false_sync_passes_n(self):
"""Test that when bypass_n=False (default), n is passed to underlying LLM in sync method."""
mock_llm = MockLangchainLLM()
# Mock is_multiple_completion_supported to return True for this test
with patch(
"ragas.llms.base.is_multiple_completion_supported", return_value=True
):
wrapper = LangchainLLMWrapper(langchain_llm=mock_llm, bypass_n=False)
prompt = create_mock_prompt()

# Call generate_text with n=3
result = wrapper.generate_text(prompt, n=3)

# Verify that n was passed to the underlying LLM
assert mock_llm._n_passed == 3
# Result should have 3 generations
assert len(result.generations[0]) == 3

@pytest.mark.asyncio
async def test_bypass_n_true_async_does_not_pass_n(self):
"""Test that when bypass_n=True, n is not passed to underlying LLM in async method."""
mock_llm = MockLangchainLLM()
wrapper = LangchainLLMWrapper(langchain_llm=mock_llm, bypass_n=True)
prompt = create_mock_prompt()

# Call agenerate_text with n=3
result = await wrapper.agenerate_text(prompt, n=3)

# Verify that n was not passed to the underlying LLM
assert mock_llm._n_passed is None
# When bypass_n=True, the wrapper should duplicate prompts instead of passing n
# The result should still have 3 generations (created by duplicating prompts)
assert len(result.generations[0]) == 3

@pytest.mark.asyncio
async def test_bypass_n_false_async_passes_n(self):
"""Test that when bypass_n=False (default), n is passed to underlying LLM in async method."""
mock_llm = MockLangchainLLM()
wrapper = LangchainLLMWrapper(langchain_llm=mock_llm, bypass_n=False)
prompt = create_mock_prompt()

# Call agenerate_text with n=3
result = await wrapper.agenerate_text(prompt, n=3)

# Verify that n was passed to the underlying LLM (via n attribute)
assert mock_llm.n == 3
# Result should have 3 generations
assert len(result.generations[0]) == 3

def test_default_bypass_n_behavior(self):
"""Test that default behavior (bypass_n=False) remains unchanged."""
mock_llm = MockLangchainLLM()
# Mock is_multiple_completion_supported to return True for this test
with patch(
"ragas.llms.base.is_multiple_completion_supported", return_value=True
):
# Create wrapper without explicitly setting bypass_n (should default to False)
wrapper = LangchainLLMWrapper(langchain_llm=mock_llm)
prompt = create_mock_prompt()

# Call generate_text with n=2
result = wrapper.generate_text(prompt, n=2)

# Verify that n was passed to the underlying LLM (default behavior)
assert mock_llm._n_passed == 2
assert len(result.generations[0]) == 2

@pytest.mark.asyncio
async def test_default_bypass_n_behavior_async(self):
"""Test that default behavior (bypass_n=False) remains unchanged in async method."""
mock_llm = MockLangchainLLM()
# Create wrapper without explicitly setting bypass_n (should default to False)
wrapper = LangchainLLMWrapper(langchain_llm=mock_llm)
prompt = create_mock_prompt()

# Call agenerate_text with n=2
result = await wrapper.agenerate_text(prompt, n=2)

# Verify that n was passed to the underlying LLM (default behavior)
assert mock_llm.n == 2
assert len(result.generations[0]) == 2

def test_bypass_n_true_with_multiple_completion_supported(self):
"""Test bypass_n=True with LLM that supports multiple completions."""
# Create a mock LLM that would normally support multiple completions
mock_llm = MockLangchainLLM()
# Mock the is_multiple_completion_supported to return True for this test
with patch(
"ragas.llms.base.is_multiple_completion_supported", return_value=True
):
wrapper = LangchainLLMWrapper(langchain_llm=mock_llm, bypass_n=True)
prompt = create_mock_prompt()

# Call generate_text with n=3
result = wrapper.generate_text(prompt, n=3)

# Verify that n was not passed to the underlying LLM due to bypass_n=True
assert mock_llm._n_passed is None
# Result should still have 3 generations (created by duplicating prompts)
assert len(result.generations[0]) == 3

@pytest.mark.asyncio
async def test_bypass_n_true_with_multiple_completion_supported_async(self):
"""Test bypass_n=True with LLM that supports multiple completions in async method."""
mock_llm = MockLangchainLLM()
with patch(
"ragas.llms.base.is_multiple_completion_supported", return_value=True
):
wrapper = LangchainLLMWrapper(langchain_llm=mock_llm, bypass_n=True)
prompt = create_mock_prompt()

# Call agenerate_text with n=3
result = await wrapper.agenerate_text(prompt, n=3)

# Verify that n was not passed to the underlying LLM due to bypass_n=True
assert mock_llm._n_passed is None
# Result should still have 3 generations
assert len(result.generations[0]) == 3
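For context, the assertions above lean on the wrapper's existing fallback: when n cannot be forwarded, the prompt is duplicated n times and the first generation of each result is regrouped under a single prompt. A rough sketch of that pattern (not the verbatim ragas implementation):

```python
from langchain_core.outputs import LLMResult


def generate_n_by_duplication(llm, prompt, n: int) -> LLMResult:
    # Ask for one completion per duplicated prompt instead of passing n.
    result = llm.generate_prompt(prompts=[prompt] * n)
    # Regroup: keep the first generation of each prompt as a single list of n.
    result.generations = [[gens[0] for gens in result.generations]]
    return result
```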