diff --git a/pydantic_evals/pydantic_evals/evaluators/llm_as_a_judge.py b/pydantic_evals/pydantic_evals/evaluators/llm_as_a_judge.py
index 99eee601cb..0f1b0ac101 100644
--- a/pydantic_evals/pydantic_evals/evaluators/llm_as_a_judge.py
+++ b/pydantic_evals/pydantic_evals/evaluators/llm_as_a_judge.py
@@ -1,5 +1,6 @@
from __future__ import annotations
+from collections.abc import Sequence
from textwrap import dedent
from typing import Any
@@ -7,6 +8,7 @@
from pydantic_core import to_json
from pydantic_ai import Agent, models
+from pydantic_ai.messages import MultiModalContentTypes, UserContent
from pydantic_ai.settings import ModelSettings
__all__ = (
@@ -62,16 +64,7 @@ async def judge_output(
If the model is not specified, a default model is used. The default model starts as 'openai:gpt-4o',
but this can be changed using the `set_default_judge_model` function.
"""
-    user_prompt = dedent(
-        f"""
-        <Output>
-        {_stringify(output)}
-        </Output>
-        <Rubric>
-        {rubric}
-        </Rubric>
-        """
-    )
+ user_prompt = _build_prompt(output=output, rubric=rubric)
return (
await _judge_output_agent.run(user_prompt, model=model or _default_model, model_settings=model_settings)
).output
@@ -112,19 +105,8 @@ async def judge_input_output(
If the model is not specified, a default model is used. The default model starts as 'openai:gpt-4o',
but this can be changed using the `set_default_judge_model` function.
"""
-    user_prompt = dedent(
-        f"""
-        <Input>
-        {_stringify(inputs)}
-        </Input>
-        <Output>
-        {_stringify(output)}
-        </Output>
-        <Rubric>
-        {rubric}
-        </Rubric>
-        """
-    )
+ user_prompt = _build_prompt(inputs=inputs, output=output, rubric=rubric)
+
return (
await _judge_input_output_agent.run(user_prompt, model=model or _default_model, model_settings=model_settings)
).output
@@ -168,22 +150,7 @@ async def judge_input_output_expected(
If the model is not specified, a default model is used. The default model starts as 'openai:gpt-4o',
but this can be changed using the `set_default_judge_model` function.
"""
-    user_prompt = dedent(
-        f"""
-        <Input>
-        {_stringify(inputs)}
-        </Input>
-        <ExpectedOutput>
-        {_stringify(expected_output)}
-        </ExpectedOutput>
-        <Output>
-        {_stringify(output)}
-        </Output>
-        <Rubric>
-        {rubric}
-        </Rubric>
-        """
-    )
+ user_prompt = _build_prompt(inputs=inputs, output=output, rubric=rubric, expected_output=expected_output)
return (
await _judge_input_output_expected_agent.run(
@@ -227,19 +194,7 @@ async def judge_output_expected(
If the model is not specified, a default model is used. The default model starts as 'openai:gpt-4o',
but this can be changed using the `set_default_judge_model` function.
"""
-    user_prompt = dedent(
-        f"""
-        <ExpectedOutput>
-        {_stringify(expected_output)}
-        </ExpectedOutput>
-        <Output>
-        {_stringify(output)}
-        </Output>
-        <Rubric>
-        {rubric}
-        </Rubric>
-        """
-    )
+ user_prompt = _build_prompt(output=output, rubric=rubric, expected_output=expected_output)
return (
await _judge_output_expected_agent.run(
user_prompt, model=model or _default_model, model_settings=model_settings
@@ -265,3 +220,41 @@ def _stringify(value: Any) -> str:
return to_json(value).decode()
except Exception:
return repr(value)
+
+
+def _build_prompt(
+ output: Any,
+ rubric: str,
+ inputs: Any | None = None,
+ expected_output: Any | None = None,
+) -> str | Sequence[str | UserContent]:
+    """Build a judge prompt from the output, rubric, and (optionally) the inputs and expected output."""
+ sections: list[str | UserContent] = []
+
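+    # When inputs are provided, wrap them in an <Input> section. A plain string is inlined;
+    # sequences and multimodal values are kept as separate prompt parts.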
+    if inputs is not None:
+        if isinstance(inputs, str):
+            sections.append(f'<Input>\n{inputs}\n</Input>')
+        else:
+            sections.append('<Input>\n')
+            if isinstance(inputs, Sequence):
+                for item in inputs:  # type: ignore
+                    if isinstance(item, (str, MultiModalContentTypes)):
+                        sections.append(item)
+                    else:
+                        sections.append(_stringify(item))
+            elif isinstance(inputs, MultiModalContentTypes):
+                sections.append(inputs)
+            else:
+                sections.append(_stringify(inputs))
+            sections.append('</Input>')
+
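+    # The output and rubric sections are always included; the expected output section only when provided.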
+    sections.append(f'<Output>\n{_stringify(output)}\n</Output>')
+    sections.append(f'<Rubric>\n{rubric}\n</Rubric>')
+
+ if expected_output is not None:
+        sections.append(f'<ExpectedOutput>\n{_stringify(expected_output)}\n</ExpectedOutput>')
+
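+    # Join into a single string when the prompt is text-only; otherwise return the parts as a
+    # sequence of UserContent so binary inputs reach the model unchanged.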
+ if inputs is None or isinstance(inputs, str):
+ return '\n\n'.join(sections) # type: ignore[arg-type]
+ else:
+ return sections
diff --git a/tests/evals/test_llm_as_a_judge.py b/tests/evals/test_llm_as_a_judge.py
index 4e18c5b13d..404c1f81a8 100644
--- a/tests/evals/test_llm_as_a_judge.py
+++ b/tests/evals/test_llm_as_a_judge.py
@@ -1,9 +1,10 @@
from __future__ import annotations as _annotations
import pytest
+from inline_snapshot import snapshot
from pytest_mock import MockerFixture
-from ..conftest import try_import
+from ..conftest import BinaryContent, try_import
with try_import() as imports_successful:
from pydantic_ai.settings import ModelSettings
@@ -141,6 +142,54 @@ async def test_judge_input_output_mock(mocker: MockerFixture):
    assert '<Rubric>\nOutput contains input\n</Rubric>' in call_args[0]
+@pytest.mark.anyio
+async def test_judge_input_output_binary_content_list_mock(mocker: MockerFixture, image_content: BinaryContent):
+    """Test judge_input_output with a list of BinaryContent inputs and a mocked agent."""
+ # Mock the agent run method
+ mock_result = mocker.MagicMock()
+ mock_result.output = GradingOutput(reason='Test passed', pass_=True, score=1.0)
+ mock_run = mocker.patch('pydantic_ai.Agent.run', return_value=mock_result)
+
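+    # A list of BinaryContent inputs should be forwarded to the agent as individual prompt parts.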
+ result = await judge_input_output([image_content, image_content], 'Hello world', 'Output contains input')
+ assert isinstance(result, GradingOutput)
+ assert result.reason == 'Test passed'
+ assert result.pass_ is True
+ assert result.score == 1.0
+
+ # Verify the agent was called with correct prompt
+ mock_run.assert_called_once()
+ raw_prompt = mock_run.call_args[0][0]
+
+ # 1) It must be a list
+ assert isinstance(raw_prompt, list), 'Expected prompt to be a list when passing binary'
+
+    # 2) The BinaryContent passed in should be one of the elements of the prompt list
+ assert image_content in raw_prompt, 'Expected the exact BinaryContent instance to be in the prompt list'
+
+
+@pytest.mark.anyio
+async def test_judge_input_output_binary_content_mock(mocker: MockerFixture, image_content: BinaryContent):
+    """Test judge_input_output with a single BinaryContent input and a mocked agent."""
+ # Mock the agent run method
+ mock_result = mocker.MagicMock()
+ mock_result.output = GradingOutput(reason='Test passed', pass_=True, score=1.0)
+ mock_run = mocker.patch('pydantic_ai.Agent.run', return_value=mock_result)
+
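+    # A single BinaryContent input (not wrapped in a list) should also be forwarded as a prompt part.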
+ result = await judge_input_output(image_content, 'Hello world', 'Output contains input')
+ assert isinstance(result, GradingOutput)
+ assert result.reason == 'Test passed'
+ assert result.pass_ is True
+ assert result.score == 1.0
+
+ # Verify the agent was called with correct prompt
+ mock_run.assert_called_once()
+ raw_prompt = mock_run.call_args[0][0]
+
+ # 1) It must be a list
+ assert isinstance(raw_prompt, list), 'Expected prompt to be a list when passing binary'
+
+    # 2) The BinaryContent passed in should be one of the elements of the prompt list
+ assert image_content in raw_prompt, 'Expected the exact BinaryContent instance to be in the prompt list'
+
+
@pytest.mark.anyio
async def test_judge_input_output_with_model_settings_mock(mocker: MockerFixture):
"""Test judge_input_output function with model_settings and mocked agent."""
@@ -172,7 +221,7 @@ async def test_judge_input_output_with_model_settings_mock(mocker: MockerFixture
@pytest.mark.anyio
-async def test_judge_input_output_expected_mock(mocker: MockerFixture):
+async def test_judge_input_output_expected_mock(mocker: MockerFixture, image_content: BinaryContent):
"""Test judge_input_output_expected function with mocked agent."""
# Mock the agent run method
mock_result = mocker.MagicMock()
@@ -187,16 +236,29 @@ async def test_judge_input_output_expected_mock(mocker: MockerFixture):
assert result.score == 1.0
# Verify the agent was called with correct prompt
- mock_run.assert_called_once()
call_args = mock_run.call_args[0]
    assert '<Input>\nHello\n</Input>' in call_args[0]
    assert '<ExpectedOutput>\nHello\n</ExpectedOutput>' in call_args[0]
    assert '<Output>\nHello world\n</Output>' in call_args[0]
    assert '<Rubric>\nOutput contains input\n</Rubric>' in call_args[0]
+ result = await judge_input_output_expected(image_content, 'Hello world', 'Hello', 'Output contains input')
+ assert isinstance(result, GradingOutput)
+ assert result.reason == 'Test passed'
+ assert result.pass_ is True
+ assert result.score == 1.0
+
+ call_args = mock_run.call_args[0]
+ assert image_content in call_args[0]
+    assert '<ExpectedOutput>\nHello\n</ExpectedOutput>' in call_args[0]
+    assert '<Output>\nHello world\n</Output>' in call_args[0]
+    assert '<Rubric>\nOutput contains input\n</Rubric>' in call_args[0]
+
@pytest.mark.anyio
-async def test_judge_input_output_expected_with_model_settings_mock(mocker: MockerFixture):
+async def test_judge_input_output_expected_with_model_settings_mock(
+ mocker: MockerFixture, image_content: BinaryContent
+):
"""Test judge_input_output_expected function with model_settings and mocked agent."""
mock_result = mocker.MagicMock()
mock_result.output = GradingOutput(reason='Test passed with settings', pass_=True, score=1.0)
@@ -216,7 +278,6 @@ async def test_judge_input_output_expected_with_model_settings_mock(mocker: Mock
assert result.pass_ is True
assert result.score == 1.0
- mock_run.assert_called_once()
call_args, call_kwargs = mock_run.call_args
    assert '<Input>\nHello settings\n</Input>' in call_args[0]
    assert '<ExpectedOutput>\nHello\n</ExpectedOutput>' in call_args[0]
@@ -226,6 +287,108 @@ async def test_judge_input_output_expected_with_model_settings_mock(mocker: Mock
# Check if 'model' kwarg is passed, its value will be the default model or None
assert 'model' in call_kwargs
+ result = await judge_input_output_expected(
+ image_content,
+ 'Hello world with settings',
+ 'Hello',
+ 'Output contains input with settings',
+ model_settings=test_model_settings,
+ )
+
+ assert isinstance(result, GradingOutput)
+ assert result.reason == 'Test passed with settings'
+ assert result.pass_ is True
+ assert result.score == 1.0
+
+ call_args, call_kwargs = mock_run.call_args
+ assert image_content in call_args[0]
+    assert '<ExpectedOutput>\nHello\n</ExpectedOutput>' in call_args[0]
+    assert '<Output>\nHello world with settings\n</Output>' in call_args[0]
+    assert '<Rubric>\nOutput contains input with settings\n</Rubric>' in call_args[0]
+ assert call_kwargs['model_settings'] == test_model_settings
+ # Check if 'model' kwarg is passed, its value will be the default model or None
+ assert 'model' in call_kwargs
+
+ result = await judge_input_output_expected(
+ 123,
+ 'Hello world with settings',
+ 'Hello',
+ 'Output contains input with settings',
+ model_settings=test_model_settings,
+ )
+
+ assert isinstance(result, GradingOutput)
+ assert result.reason == 'Test passed with settings'
+ assert result.pass_ is True
+ assert result.score == 1.0
+
+ call_args, call_kwargs = mock_run.call_args
+
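+    # With an int input (neither a string nor multimodal content), the prompt is still built as a
+    # list of parts, with the input JSON-stringified.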
+ assert call_args == snapshot(
+ (
+ [
+                '<Input>\n',
+                '123',
+                '</Input>',
+                """\
+<Output>
+Hello world with settings
+</Output>\
+""",
+                """\
+<Rubric>
+Output contains input with settings
+</Rubric>\
+""",
+                """\
+<ExpectedOutput>
+Hello
+</ExpectedOutput>\
+""",
+ ],
+ )
+ )
+
+ result = await judge_input_output_expected(
+ [123],
+ 'Hello world with settings',
+ 'Hello',
+ 'Output contains input with settings',
+ model_settings=test_model_settings,
+ )
+
+ assert isinstance(result, GradingOutput)
+ assert result.reason == 'Test passed with settings'
+ assert result.pass_ is True
+ assert result.score == 1.0
+
+ call_args, call_kwargs = mock_run.call_args
+
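+    # A single-element list input produces the same prompt parts as the bare int above.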
+ assert call_args == snapshot(
+ (
+ [
+                '<Input>\n',
+                '123',
+                '</Input>',
+                """\
+<Output>
+Hello world with settings
+</Output>\
+""",
+                """\
+<Rubric>
+Output contains input with settings
+</Rubric>\
+""",
+                """\
+<ExpectedOutput>
+Hello
+</ExpectedOutput>\
+""",
+ ],
+ )
+ )
+
@pytest.mark.anyio
async def test_judge_output_expected_mock(mocker: MockerFixture):
@@ -243,7 +406,6 @@ async def test_judge_output_expected_mock(mocker: MockerFixture):
assert result.score == 1.0
# Verify the agent was called with correct prompt
- mock_run.assert_called_once()
call_args = mock_run.call_args[0]
    assert '<Input>' not in call_args[0]
    assert '<ExpectedOutput>\nHello\n</ExpectedOutput>' in call_args[0]
@@ -252,7 +414,7 @@ async def test_judge_output_expected_mock(mocker: MockerFixture):
@pytest.mark.anyio
-async def test_judge_output_expected_with_model_settings_mock(mocker: MockerFixture):
+async def test_judge_output_expected_with_model_settings_mock(mocker: MockerFixture, image_content: BinaryContent):
"""Test judge_output_expected function with model_settings and mocked agent."""
mock_result = mocker.MagicMock()
mock_result.output = GradingOutput(reason='Test passed with settings', pass_=True, score=1.0)
@@ -280,3 +442,23 @@ async def test_judge_output_expected_with_model_settings_mock(mocker: MockerFixt
assert call_kwargs['model_settings'] == test_model_settings
# Check if 'model' kwarg is passed, its value will be the default model or None
assert 'model' in call_kwargs
+
+ result = await judge_output_expected(
+ image_content,
+ 'Hello',
+ 'Output contains input with settings',
+ model_settings=test_model_settings,
+ )
+ assert isinstance(result, GradingOutput)
+ assert result.reason == 'Test passed with settings'
+ assert result.pass_ is True
+ assert result.score == 1.0
+
+ call_args, call_kwargs = mock_run.call_args
+    assert '<Input>' not in call_args[0]
+    assert '<ExpectedOutput>\nHello\n</ExpectedOutput>' in call_args[0]
+ assert '