Skip to content

Commit 22eb7e5

Browse files
GWealecopybara-github
authored andcommitted
feat: Add support for parsing inline JSON tool calls in LiteLLM responses
Close #1968 Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 831911719
1 parent 2efc184 commit 22eb7e5

File tree

4 files changed

+511
-53
lines changed

4 files changed

+511
-53
lines changed
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from . import agent
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import datetime
18+
import json
19+
import re
20+
from typing import Any
21+
from zoneinfo import ZoneInfo
22+
from zoneinfo import ZoneInfoNotFoundError
23+
24+
from google.adk.agents.llm_agent import Agent
25+
from google.adk.models.lite_llm import LiteLlm
26+
from google.adk.models.lite_llm import LiteLLMClient
27+
28+
29+
class InlineJsonToolClient(LiteLLMClient):
  """LiteLLM client that emits inline JSON tool calls for testing.

  The first completion returns assistant text with a JSON tool call embedded
  inline; once a tool result message appears in the history, a plain final
  answer summarizing that result is returned instead.
  """

  async def acompletion(self, model, messages, tools, **kwargs):
    """Returns a canned completion based on the conversation so far."""
    del tools, kwargs  # Only needed for API parity.

    tool_message = _find_last_role(messages, role="tool")
    if tool_message:
      tool_summary = _coerce_to_text(tool_message.get("content"))
      return self._summary_completion(model, tool_summary)

    timezone = _extract_timezone(messages) or "Asia/Taipei"
    return self._inline_call_completion(model, timezone)

  @staticmethod
  def _summary_completion(model, tool_summary):
    """Final answer that echoes the tool output back to the user."""
    return {
        "id": "mock-inline-tool-final-response",
        "model": model,
        "choices": [{
            "message": {
                "role": "assistant",
                "content": (
                    f"The instrumentation tool responded with: {tool_summary}"
                ),
            },
            "finish_reason": "stop",
        }],
        "usage": {
            "prompt_tokens": 60,
            "completion_tokens": 12,
            "total_tokens": 72,
        },
    }

  @staticmethod
  def _inline_call_completion(model, timezone):
    """Assistant message whose text content embeds a JSON tool call."""
    inline_call = json.dumps(
        {
            "name": "get_current_time",
            "arguments": {"timezone_str": timezone},
        },
        separators=(",", ":"),
    )
    return {
        "id": "mock-inline-tool-call",
        "model": model,
        "choices": [{
            "message": {
                "role": "assistant",
                "content": (
                    f"{inline_call}\nLet me double-check the clock for you."
                ),
            },
            "finish_reason": "tool_calls",
        }],
        "usage": {
            "prompt_tokens": 45,
            "completion_tokens": 15,
            "total_tokens": 60,
        },
    }
84+
85+
86+
def _find_last_role(
87+
messages: list[dict[str, Any]], role: str
88+
) -> dict[str, Any]:
89+
"""Returns the last message with the given role."""
90+
for message in reversed(messages):
91+
if message.get("role") == role:
92+
return message
93+
return {}
94+
95+
96+
def _coerce_to_text(content: Any) -> str:
97+
"""Best-effort conversion from OpenAI message content to text."""
98+
if isinstance(content, str):
99+
return content
100+
if isinstance(content, dict):
101+
return _coerce_to_text(content.get("text"))
102+
if isinstance(content, list):
103+
texts = []
104+
for part in content:
105+
if isinstance(part, dict):
106+
texts.append(part.get("text") or "")
107+
elif isinstance(part, str):
108+
texts.append(part)
109+
return " ".join(text for text in texts if text)
110+
return ""
111+
112+
113+
# Matches IANA timezone keys such as "Asia/Taipei", and — unlike a plain
# Area/Location pair — also multi-segment, hyphenated, and digit-bearing keys
# like "America/Argentina/Buenos_Aires", "America/Port-au-Prince", "Etc/GMT+8".
_TIMEZONE_PATTERN = re.compile(r"([A-Za-z]+(?:/[A-Za-z0-9_+\-]+)+)")


def _extract_timezone(messages: list[dict[str, Any]]) -> str | None:
  """Extracts an IANA timezone string from the last user message.

  Args:
    messages: OpenAI-style chat messages.

  Returns:
    The first Area/Location token found in the newest user message, or a
    well-known zone inferred from a city name, or None when nothing matches.
  """
  user_message = _find_last_role(messages, role="user")
  text = _coerce_to_text(user_message.get("content"))
  if not text:
    return None
  match = _TIMEZONE_PATTERN.search(text)
  if match:
    return match.group(1)
  # Heuristic city-name fallback for the demo prompts; checked in order.
  lowered = text.lower()
  city_zones = {
      "taipei": "Asia/Taipei",
      "new york": "America/New_York",
      "london": "Europe/London",
      "tokyo": "Asia/Tokyo",
  }
  for city, zone in city_zones.items():
    if city in lowered:
      return zone
  return None
135+
136+
137+
def get_current_time(timezone_str: str) -> dict[str, str]:
  """Returns mock current time for the provided timezone.

  Args:
    timezone_str: IANA timezone key, e.g. "Asia/Taipei".

  Returns:
    A dict with "status" ("success" or "error") and a human-readable "report".
  """
  try:
    tz = ZoneInfo(timezone_str)
  except (ZoneInfoNotFoundError, ValueError) as exc:
    # ZoneInfo raises ValueError (not ZoneInfoNotFoundError) for syntactically
    # invalid keys such as absolute paths or keys containing "..", so catch
    # both to keep the tool's error contract instead of crashing.
    return {
        "status": "error",
        "report": f"Unable to parse timezone '{timezone_str}': {exc}",
    }
  now = datetime.datetime.now(tz)
  return {
      "status": "success",
      "report": (
          f"The current time in {timezone_str} is"
          f" {now.strftime('%Y-%m-%d %H:%M:%S %Z')}."
      ),
  }
154+
155+
156+
# LiteLlm wrapper that routes completions through the mock client above
# instead of a real provider, so the sample runs fully offline.
_mock_model = LiteLlm(
    model="mock/inline-json-tool-calls",
    llm_client=InlineJsonToolClient(),
)

# Root agent wired to the mock model; exercises the inline JSON tool-call
# parsing path with get_current_time as its only tool.
root_agent = Agent(
    name="litellm_inline_tool_tester",
    model=_mock_model,
    description=(
        "Demonstrates LiteLLM inline JSON tool-call parsing without an external"
        " VLLM deployment."
    ),
    instruction=(
        "You are a deterministic clock assistant. Always call the"
        " get_current_time tool before answering user questions. After the tool"
        " responds, summarize what it returned."
    ),
    tools=[get_current_time],
)

src/google/adk/models/lite_llm.py

Lines changed: 134 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from typing import Tuple
3232
from typing import TypedDict
3333
from typing import Union
34+
import uuid
3435
import warnings
3536

3637
from google.genai import types
@@ -64,6 +65,7 @@
6465
_NEW_LINE = "\n"
6566
_EXCLUDED_PART_FIELD = {"inline_data": {"data"}}
6667
_LITELLM_STRUCTURED_TYPES = {"json_object", "json_schema"}
68+
_JSON_DECODER = json.JSONDecoder()
6769

6870
# Mapping of LiteLLM finish_reason strings to FinishReason enum values
6971
# Note: tool_calls/function_call map to STOP because:
@@ -431,6 +433,118 @@ def _get_content(
431433
return content_objects
432434

433435

436+
def _build_tool_call_from_json_dict(
437+
candidate: Any, *, index: int
438+
) -> Optional[ChatCompletionMessageToolCall]:
439+
"""Creates a tool call object from JSON content embedded in text."""
440+
441+
if not isinstance(candidate, dict):
442+
return None
443+
444+
name = candidate.get("name")
445+
args = candidate.get("arguments")
446+
if not isinstance(name, str) or args is None:
447+
return None
448+
449+
if isinstance(args, str):
450+
arguments_payload = args
451+
else:
452+
try:
453+
arguments_payload = json.dumps(args, ensure_ascii=False)
454+
except (TypeError, ValueError):
455+
arguments_payload = _safe_json_serialize(args)
456+
457+
call_id = candidate.get("id") or f"adk_tool_call_{uuid.uuid4().hex}"
458+
call_index = candidate.get("index")
459+
if isinstance(call_index, int):
460+
index = call_index
461+
462+
function = Function(
463+
name=name,
464+
arguments=arguments_payload,
465+
)
466+
# Some LiteLLM types carry an `index` field only in streaming contexts,
467+
# so guard the assignment to stay compatible with older versions.
468+
if hasattr(function, "index"):
469+
function.index = index # type: ignore[attr-defined]
470+
471+
tool_call = ChatCompletionMessageToolCall(
472+
type="function",
473+
id=str(call_id),
474+
function=function,
475+
)
476+
# Same reasoning as above: not every ChatCompletionMessageToolCall exposes it.
477+
if hasattr(tool_call, "index"):
478+
tool_call.index = index # type: ignore[attr-defined]
479+
480+
return tool_call
481+
482+
483+
def _parse_tool_calls_from_text(
    text_block: str,
) -> tuple[list[ChatCompletionMessageToolCall], Optional[str]]:
  """Extracts inline JSON tool calls from LiteLLM text responses.

  Scans `text_block` left to right, treating every "{" as a potential start
  of a JSON object. Objects that decode and pass
  _build_tool_call_from_json_dict become tool calls; everything else —
  including valid JSON that is not a tool call — is preserved as remainder
  text.

  Args:
    text_block: Raw assistant text that may embed JSON tool-call objects.

  Returns:
    A (tool_calls, remainder) tuple, where `remainder` is the stripped
    leftover text or None when nothing non-whitespace is left.
  """

  tool_calls = []
  if not text_block:
    return tool_calls, None

  # Pieces of text that are not consumed as tool calls, reassembled at the end.
  remainder_segments = []
  cursor = 0
  text_length = len(text_block)

  while cursor < text_length:
    brace_index = text_block.find("{", cursor)
    if brace_index == -1:
      # No further candidate objects; keep the tail verbatim.
      remainder_segments.append(text_block[cursor:])
      break

    # Text preceding the brace can never be part of a tool call.
    remainder_segments.append(text_block[cursor:brace_index])
    try:
      candidate, end = _JSON_DECODER.raw_decode(text_block, brace_index)
    except json.JSONDecodeError:
      # Not valid JSON at this brace: keep the single brace character and
      # resume scanning one character later.
      remainder_segments.append(text_block[brace_index])
      cursor = brace_index + 1
      continue

    tool_call = _build_tool_call_from_json_dict(
        candidate, index=len(tool_calls)
    )
    if tool_call:
      tool_calls.append(tool_call)
    else:
      # Decoded fine but not a tool call — keep the exact decoded span.
      remainder_segments.append(text_block[brace_index:end])
    # Jump past the decoded object in both cases so its interior braces are
    # never re-scanned.
    cursor = end

  remainder = "".join(segment for segment in remainder_segments if segment)
  remainder = remainder.strip()

  return tool_calls, remainder or None
523+
524+
525+
def _split_message_content_and_tool_calls(
526+
message: Message,
527+
) -> tuple[Optional[OpenAIMessageContent], list[ChatCompletionMessageToolCall]]:
528+
"""Returns message content and tool calls, parsing inline JSON when needed."""
529+
530+
existing_tool_calls = message.get("tool_calls") or []
531+
normalized_tool_calls = (
532+
list(existing_tool_calls) if existing_tool_calls else []
533+
)
534+
content = message.get("content")
535+
536+
# LiteLLM responses either provide structured tool_calls or inline JSON, not
537+
# both. When tool_calls are present we trust them and skip the fallback parser.
538+
if normalized_tool_calls or not isinstance(content, str):
539+
return content, normalized_tool_calls
540+
541+
fallback_tool_calls, remainder = _parse_tool_calls_from_text(content)
542+
if fallback_tool_calls:
543+
return remainder, fallback_tool_calls
544+
545+
return content, []
546+
547+
434548
def _to_litellm_role(role: Optional[str]) -> Literal["user", "assistant"]:
435549
"""Converts a types.Content role to a litellm role.
436550
@@ -584,15 +698,24 @@ def _model_response_to_chunk(
584698
if message is None and response["choices"][0].get("delta", None):
585699
message = response["choices"][0]["delta"]
586700

587-
if message.get("content", None):
588-
yield TextChunk(text=message.get("content")), finish_reason
701+
message_content: Optional[OpenAIMessageContent] = None
702+
tool_calls: list[ChatCompletionMessageToolCall] = []
703+
if message is not None:
704+
(
705+
message_content,
706+
tool_calls,
707+
) = _split_message_content_and_tool_calls(message)
589708

590-
if message.get("tool_calls", None):
591-
for tool_call in message.get("tool_calls"):
709+
if message_content:
710+
yield TextChunk(text=message_content), finish_reason
711+
712+
if tool_calls:
713+
for idx, tool_call in enumerate(tool_calls):
592714
# aggregate tool_call
593715
if tool_call.type == "function":
594716
func_name = tool_call.function.name
595717
func_args = tool_call.function.arguments
718+
func_index = getattr(tool_call, "index", idx)
596719

597720
# Ignore empty chunks that don't carry any information.
598721
if not func_name and not func_args:
@@ -602,12 +725,10 @@ def _model_response_to_chunk(
602725
id=tool_call.id,
603726
name=func_name,
604727
args=func_args,
605-
index=tool_call.index,
728+
index=func_index,
606729
), finish_reason
607730

608-
if finish_reason and not (
609-
message.get("content", None) or message.get("tool_calls", None)
610-
):
731+
if finish_reason and not (message_content or tool_calls):
611732
yield None, finish_reason
612733

613734
if not message:
@@ -687,11 +808,12 @@ def _message_to_generate_content_response(
687808
"""
688809

689810
parts = []
690-
if message.get("content", None):
691-
parts.append(types.Part.from_text(text=message.get("content")))
811+
message_content, tool_calls = _split_message_content_and_tool_calls(message)
812+
if isinstance(message_content, str) and message_content:
813+
parts.append(types.Part.from_text(text=message_content))
692814

693-
if message.get("tool_calls", None):
694-
for tool_call in message.get("tool_calls"):
815+
if tool_calls:
816+
for tool_call in tool_calls:
695817
if tool_call.type == "function":
696818
part = types.Part.from_function_call(
697819
name=tool_call.function.name,

0 commit comments

Comments
 (0)