Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -525,9 +525,9 @@ def _process_create_args(
if self.model_info["json_output"] is False and json_output is True:
raise ValueError("Model does not support JSON output.")

if create_args.get("model", "unknown").startswith("gemini-"):
# Gemini models accept only one system message(else, it will read only the last one)
# So, merge system messages into one
if not self.model_info.get("multiple_system_messages", False):
# Some models accept only one system message(or, it will read only the last one)
# So, merge system messages into one (if multiple and continuous)
system_message_content = ""
_messages: List[LLMMessage] = []
_first_system_message_idx = -1
Expand All @@ -540,7 +540,9 @@ def _process_create_args(
elif _last_system_message_idx + 1 != idx:
# That case, system message is not continuous
# Merge system messages only contiues system messages
raise ValueError("Multiple and Not continuous system messages are not supported")
raise ValueError(
"Multiple and Not continuous system messages are not supported if model_info['multiple_system_messages'] is False"
)
system_message_content += message.content + "\n"
_last_system_message_idx = idx
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2096,7 +2096,7 @@ async def test_muliple_system_message(model: str, openai_client: OpenAIChatCompl


@pytest.mark.asyncio
async def test_system_message_merge_for_gemini_models() -> None:
async def test_system_message_merge_with_continuous_system_messages_models() -> None:
"""Tests that system messages are merged correctly for Gemini models."""
# Create a mock client
mock_client = MagicMock()
Expand All @@ -2109,6 +2109,7 @@ async def test_system_message_merge_for_gemini_models() -> None:
"json_output": False,
"family": "unknown",
"structured_output": False,
"multiple_system_messages": False,
},
)

Expand Down Expand Up @@ -2157,6 +2158,7 @@ async def test_system_message_merge_with_non_continuous_messages() -> None:
"json_output": False,
"family": "unknown",
"structured_output": False,
"multiple_system_messages": False,
},
)

Expand All @@ -2180,7 +2182,7 @@ async def test_system_message_merge_with_non_continuous_messages() -> None:


@pytest.mark.asyncio
async def test_system_message_not_merged_for_non_gemini_models() -> None:
async def test_system_message_not_merged_for_multiple_system_messages_true() -> None:
"""Tests that system messages aren't modified for non-Gemini models."""
# Create a mock client
mock_client = MagicMock()
Expand All @@ -2193,6 +2195,7 @@ async def test_system_message_not_merged_for_non_gemini_models() -> None:
"json_output": False,
"family": "unknown",
"structured_output": False,
"multiple_system_messages": True,
},
)

Expand Down
Loading