Skip to content

Commit 894d8c6

Browse files
GWeale and copybara-github authored and committed
fix: Refactor LiteLLM response schema formatting for different models
The `_to_litellm_response_format` function now adapts the output format based on the provided model. Gemini models continue to use the "response_schema" key, while OpenAI-compatible models (including Azure OpenAI and Anthropic) now use the "json_schema" key as per LiteLLM's documentation for JSON mode. The schema name is also included in the "json_schema" format. Close #3713 Close #3890 Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 843326850
1 parent 99f893a commit 894d8c6

File tree

2 files changed

+247
-24
lines changed

2 files changed

+247
-24
lines changed

src/google/adk/models/lite_llm.py

Lines changed: 56 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -982,8 +982,20 @@ def _message_to_generate_content_response(
982982

983983
def _to_litellm_response_format(
984984
response_schema: types.SchemaUnion,
985-
) -> Optional[Dict[str, Any]]:
986-
"""Converts ADK response schema objects into LiteLLM-compatible payloads."""
985+
model: str,
986+
) -> dict[str, Any] | None:
987+
"""Converts ADK response schema objects into LiteLLM-compatible payloads.
988+
989+
Args:
990+
response_schema: The response schema to convert.
991+
model: The model string to determine the appropriate format. Gemini models
992+
use 'response_schema' key, while OpenAI-compatible models use
993+
'json_schema' key.
994+
995+
Returns:
996+
A dictionary with the appropriate response format for LiteLLM.
997+
"""
998+
schema_name = "response"
987999

9881000
if isinstance(response_schema, dict):
9891001
schema_type = response_schema.get("type")
@@ -993,33 +1005,63 @@ def _to_litellm_response_format(
9931005
):
9941006
return response_schema
9951007
schema_dict = dict(response_schema)
1008+
if "title" in schema_dict:
1009+
schema_name = str(schema_dict["title"])
9961010
elif isinstance(response_schema, type) and issubclass(
9971011
response_schema, BaseModel
9981012
):
9991013
schema_dict = response_schema.model_json_schema()
1014+
schema_name = response_schema.__name__
10001015
elif isinstance(response_schema, BaseModel):
10011016
if isinstance(response_schema, types.Schema):
10021017
# GenAI Schema instances already represent JSON schema definitions.
10031018
schema_dict = response_schema.model_dump(exclude_none=True, mode="json")
1019+
if "title" in schema_dict:
1020+
schema_name = str(schema_dict["title"])
10041021
else:
10051022
schema_dict = response_schema.__class__.model_json_schema()
1023+
schema_name = response_schema.__class__.__name__
10061024
elif hasattr(response_schema, "model_dump"):
10071025
schema_dict = response_schema.model_dump(exclude_none=True, mode="json")
1026+
schema_name = response_schema.__class__.__name__
10081027
else:
10091028
logger.warning(
10101029
"Unsupported response_schema type %s for LiteLLM structured outputs.",
10111030
type(response_schema),
10121031
)
10131032
return None
10141033

1034+
# Gemini models use a special response format with 'response_schema' key
1035+
if _is_litellm_gemini_model(model):
1036+
return {
1037+
"type": "json_object",
1038+
"response_schema": schema_dict,
1039+
}
1040+
1041+
# OpenAI-compatible format (default) per LiteLLM docs:
1042+
# https://docs.litellm.ai/docs/completion/json_mode
1043+
if (
1044+
isinstance(schema_dict, dict)
1045+
and schema_dict.get("type") == "object"
1046+
and "additionalProperties" not in schema_dict
1047+
):
1048+
# OpenAI structured outputs require explicit additionalProperties: false.
1049+
schema_dict = dict(schema_dict)
1050+
schema_dict["additionalProperties"] = False
1051+
10151052
return {
1016-
"type": "json_object",
1017-
"response_schema": schema_dict,
1053+
"type": "json_schema",
1054+
"json_schema": {
1055+
"name": schema_name,
1056+
"strict": True,
1057+
"schema": schema_dict,
1058+
},
10181059
}
10191060

10201061

10211062
async def _get_completion_inputs(
10221063
llm_request: LlmRequest,
1064+
model: str,
10231065
) -> Tuple[
10241066
List[Message],
10251067
Optional[List[Dict]],
@@ -1030,13 +1072,14 @@ async def _get_completion_inputs(
10301072
10311073
Args:
10321074
llm_request: The LlmRequest to convert.
1075+
model: The model string to use for determining provider-specific behavior.
10331076
10341077
Returns:
10351078
The litellm inputs (message list, tool dictionary, response format and
10361079
generation params).
10371080
"""
10381081
# Determine provider for file handling
1039-
provider = _get_provider_from_model(llm_request.model or "")
1082+
provider = _get_provider_from_model(model)
10401083

10411084
# 1. Construct messages
10421085
messages: List[Message] = []
@@ -1071,14 +1114,15 @@ async def _get_completion_inputs(
10711114
]
10721115

10731116
# 3. Handle response format
1074-
response_format: Optional[Dict[str, Any]] = None
1117+
response_format: dict[str, Any] | None = None
10751118
if llm_request.config and llm_request.config.response_schema:
10761119
response_format = _to_litellm_response_format(
1077-
llm_request.config.response_schema
1120+
llm_request.config.response_schema,
1121+
model=model,
10781122
)
10791123

10801124
# 4. Extract generation parameters
1081-
generation_params: Optional[Dict] = None
1125+
generation_params: dict | None = None
10821126
if llm_request.config:
10831127
config_dict = llm_request.config.model_dump(exclude_none=True)
10841128
# Generate LiteLlm parameters here,
@@ -1190,9 +1234,7 @@ def _is_litellm_gemini_model(model_string: str) -> bool:
11901234
Returns:
11911235
True if it's a Gemini model accessed via LiteLLM, False otherwise
11921236
"""
1193-
# Matches "gemini/gemini-*" (Google AI Studio) or "vertex_ai/gemini-*" (Vertex AI).
1194-
pattern = r"^(gemini|vertex_ai)/gemini-"
1195-
return bool(re.match(pattern, model_string))
1237+
return model_string.startswith(("gemini/gemini-", "vertex_ai/gemini-"))
11961238

11971239

11981240
def _extract_gemini_model_from_litellm(litellm_model: str) -> str:
@@ -1308,16 +1350,17 @@ async def generate_content_async(
13081350
_append_fallback_user_content_if_missing(llm_request)
13091351
logger.debug(_build_request_log(llm_request))
13101352

1353+
model = llm_request.model or self.model
13111354
messages, tools, response_format, generation_params = (
1312-
await _get_completion_inputs(llm_request)
1355+
await _get_completion_inputs(llm_request, model)
13131356
)
13141357

13151358
if "functions" in self._additional_args:
13161359
# LiteLLM does not support both tools and functions together.
13171360
tools = None
13181361

13191362
completion_args = {
1320-
"model": llm_request.model or self.model,
1363+
"model": model,
13211364
"messages": messages,
13221365
"tools": tools,
13231366
"response_format": response_format,

0 commit comments

Comments (0)