Skip to content

Commit cc96fcc

Browse files
longcwakshaym1shra
authored and committed
gemini realtime: support NON_BLOCKING tool behavior (livekit#3482)
1 parent 3707029 commit cc96fcc

File tree

2 files changed

+72
-14
lines changed

2 files changed

+72
-14
lines changed

livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010
from dataclasses import dataclass, field
1111
from typing import Literal
1212

13-
from google import genai
14-
from google.genai import types
13+
from google.genai import Client as GenAIClient, types
1514
from google.genai.live import AsyncSession
1615
from livekit import rtc
1716
from livekit.agents import APIConnectionError, llm, utils
@@ -76,6 +75,8 @@ class _RealtimeOptions:
7675
context_window_compression: NotGivenOr[types.ContextWindowCompressionConfig] = NOT_GIVEN
7776
api_version: NotGivenOr[str] = NOT_GIVEN
7877
gemini_tools: NotGivenOr[list[_LLMTool]] = NOT_GIVEN
78+
tool_behavior: NotGivenOr[types.Behavior] = NOT_GIVEN
79+
tool_response_scheduling: NotGivenOr[types.FunctionResponseScheduling] = NOT_GIVEN
7980

8081

8182
@dataclass
@@ -136,6 +137,8 @@ def __init__(
136137
proactivity: NotGivenOr[bool] = NOT_GIVEN,
137138
realtime_input_config: NotGivenOr[types.RealtimeInputConfig] = NOT_GIVEN,
138139
context_window_compression: NotGivenOr[types.ContextWindowCompressionConfig] = NOT_GIVEN,
140+
tool_behavior: NotGivenOr[types.Behavior] = NOT_GIVEN,
141+
tool_response_scheduling: NotGivenOr[types.FunctionResponseScheduling] = NOT_GIVEN,
139142
api_version: NotGivenOr[str] = NOT_GIVEN,
140143
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
141144
http_options: NotGivenOr[types.HttpOptions] = NOT_GIVEN,
@@ -174,6 +177,8 @@ def __init__(
174177
proactivity (bool, optional): Whether to enable proactive audio. Defaults to False.
175178
realtime_input_config (RealtimeInputConfig, optional): The configuration for realtime input. Defaults to None.
176179
context_window_compression (ContextWindowCompressionConfig, optional): The configuration for context window compression. Defaults to None.
180+
tool_behavior (Behavior, optional): The behavior for tool call. Default behavior is BLOCK in Gemini Realtime API.
181+
tool_response_scheduling (FunctionResponseScheduling, optional): The scheduling for tool response. Default scheduling is WHEN_IDLE.
177182
conn_options (APIConnectOptions, optional): The configuration for the API connection. Defaults to DEFAULT_API_CONNECT_OPTIONS.
178183
_gemini_tools (list[LLMTool], optional): Gemini-specific tools to use for the session. This parameter is experimental and may change.
179184
@@ -265,6 +270,7 @@ def __init__(
265270
context_window_compression=context_window_compression,
266271
api_version=api_version,
267272
gemini_tools=_gemini_tools,
273+
tool_behavior=tool_behavior,
268274
conn_options=conn_options,
269275
http_options=http_options,
270276
)
@@ -281,6 +287,8 @@ def update_options(
281287
*,
282288
voice: NotGivenOr[str] = NOT_GIVEN,
283289
temperature: NotGivenOr[float] = NOT_GIVEN,
290+
tool_behavior: NotGivenOr[types.Behavior] = NOT_GIVEN,
291+
tool_response_scheduling: NotGivenOr[types.FunctionResponseScheduling] = NOT_GIVEN,
284292
) -> None:
285293
"""
286294
Update the options for the RealtimeModel.
@@ -296,10 +304,18 @@ def update_options(
296304
if is_given(temperature):
297305
self._opts.temperature = temperature
298306

307+
if is_given(tool_behavior):
308+
self._opts.tool_behavior = tool_behavior
309+
310+
if is_given(tool_response_scheduling):
311+
self._opts.tool_response_scheduling = tool_response_scheduling
312+
299313
for sess in self._sessions:
300314
sess.update_options(
301315
voice=self._opts.voice,
302316
temperature=self._opts.temperature,
317+
tool_behavior=self._opts.tool_behavior,
318+
tool_response_scheduling=self._opts.tool_response_scheduling,
303319
)
304320

305321
async def aclose(self) -> None:
@@ -337,7 +353,7 @@ def __init__(self, realtime_model: RealtimeModel) -> None:
337353
if api_version:
338354
http_options.api_version = api_version
339355

340-
self._client = genai.Client(
356+
self._client = GenAIClient(
341357
api_key=self._opts.api_key,
342358
vertexai=self._opts.vertexai,
343359
project=self._opts.project,
@@ -381,6 +397,8 @@ def update_options(
381397
voice: NotGivenOr[str] = NOT_GIVEN,
382398
temperature: NotGivenOr[float] = NOT_GIVEN,
383399
tool_choice: NotGivenOr[llm.ToolChoice | None] = NOT_GIVEN,
400+
tool_behavior: NotGivenOr[types.Behavior] = NOT_GIVEN,
401+
tool_response_scheduling: NotGivenOr[types.FunctionResponseScheduling] = NOT_GIVEN,
384402
) -> None:
385403
should_restart = False
386404
if is_given(voice) and self._opts.voice != voice:
@@ -391,6 +409,20 @@ def update_options(
391409
self._opts.temperature = temperature if is_given(temperature) else NOT_GIVEN
392410
should_restart = True
393411

412+
if is_given(tool_behavior) and self._opts.tool_behavior != tool_behavior:
413+
self._opts.tool_behavior = tool_behavior
414+
should_restart = True
415+
416+
if (
417+
is_given(tool_response_scheduling)
418+
and self._opts.tool_response_scheduling != tool_response_scheduling
419+
):
420+
self._opts.tool_response_scheduling = tool_response_scheduling
421+
# no need to restart
422+
423+
if is_given(tool_choice):
424+
logger.warning("tool_choice is not supported by the Google Realtime API.")
425+
394426
if should_restart:
395427
self._mark_restart_needed()
396428

@@ -422,7 +454,11 @@ async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None:
422454
).to_provider_format(format="google", inject_dummy_user_message=False)
423455
# we are not generating, and do not need to inject
424456
turns = [types.Content.model_validate(turn) for turn in turns_dict]
425-
tool_results = get_tool_results_for_realtime(append_ctx, vertexai=self._opts.vertexai)
457+
tool_results = get_tool_results_for_realtime(
458+
append_ctx,
459+
vertexai=self._opts.vertexai,
460+
tool_response_scheduling=self._opts.tool_response_scheduling,
461+
)
426462
if turns:
427463
self._send_client_event(types.LiveClientContent(turns=turns, turn_complete=False))
428464
if tool_results:
@@ -434,7 +470,7 @@ async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None:
434470

435471
async def update_tools(self, tools: list[llm.FunctionTool | llm.RawFunctionTool]) -> None:
436472
new_declarations: list[types.FunctionDeclaration] = to_fnc_ctx(
437-
tools, use_parameters_json_schema=False
473+
tools, use_parameters_json_schema=False, tool_behavior=self._opts.tool_behavior
438474
)
439475
current_tool_names = {f.name for f in self._gemini_declarations}
440476
new_tool_names = {f.name for f in new_declarations}

livekit-plugins/livekit-plugins-google/livekit/plugins/google/utils.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
is_function_tool,
1717
is_raw_function_tool,
1818
)
19+
from livekit.agents.types import NOT_GIVEN, NotGivenOr
20+
from livekit.agents.utils import is_given
1921

2022
from .log import logger
2123
from .tools import _LLMTool
@@ -24,7 +26,10 @@
2426

2527

2628
def to_fnc_ctx(
27-
fncs: list[FunctionTool | RawFunctionTool], *, use_parameters_json_schema: bool = True
29+
fncs: list[FunctionTool | RawFunctionTool],
30+
*,
31+
use_parameters_json_schema: bool = True,
32+
tool_behavior: NotGivenOr[types.Behavior] = NOT_GIVEN,
2833
) -> list[types.FunctionDeclaration]:
2934
tools: list[types.FunctionDeclaration] = []
3035
for fnc in fncs:
@@ -43,10 +48,14 @@ def to_fnc_ctx(
4348
info.raw_schema.get("parameters", {})
4449
)
4550
)
51+
52+
if is_given(tool_behavior):
53+
fnc_kwargs["behavior"] = tool_behavior
54+
4655
tools.append(types.FunctionDeclaration(**fnc_kwargs))
4756

4857
elif is_function_tool(fnc):
49-
tools.append(_build_gemini_fnc(fnc))
58+
tools.append(_build_gemini_fnc(fnc, tool_behavior=tool_behavior))
5059

5160
return tools
5261

@@ -88,14 +97,20 @@ def create_tools_config(
8897

8998

9099
def get_tool_results_for_realtime(
91-
chat_ctx: llm.ChatContext, *, vertexai: bool = False
100+
chat_ctx: llm.ChatContext,
101+
*,
102+
vertexai: bool = False,
103+
tool_response_scheduling: NotGivenOr[types.FunctionResponseScheduling] = NOT_GIVEN,
92104
) -> types.LiveClientToolResponse | None:
93105
function_responses: list[types.FunctionResponse] = []
94106
for msg in chat_ctx.items:
95107
if msg.type == "function_call_output":
96108
res = types.FunctionResponse(
97109
name=msg.name,
98110
response={"output": msg.output},
111+
scheduling=tool_response_scheduling
112+
if is_given(tool_response_scheduling)
113+
else types.FunctionResponseScheduling.WHEN_IDLE,
99114
)
100115
if not vertexai:
101116
# vertexai does not support id in FunctionResponse
@@ -109,14 +124,21 @@ def get_tool_results_for_realtime(
109124
)
110125

111126

112-
def _build_gemini_fnc(function_tool: FunctionTool) -> types.FunctionDeclaration:
127+
def _build_gemini_fnc(
128+
function_tool: FunctionTool, *, tool_behavior: NotGivenOr[types.Behavior] = NOT_GIVEN
129+
) -> types.FunctionDeclaration:
113130
fnc = llm.utils.build_legacy_openai_schema(function_tool, internally_tagged=True)
114131
json_schema = _GeminiJsonSchema(fnc["parameters"]).simplify()
115-
return types.FunctionDeclaration(
116-
name=fnc["name"],
117-
description=fnc["description"],
118-
parameters=types.Schema.model_validate(json_schema) if json_schema else None,
119-
)
132+
133+
kwargs = {
134+
"name": fnc["name"],
135+
"description": fnc["description"],
136+
"parameters": types.Schema.model_validate(json_schema) if json_schema else None,
137+
}
138+
if is_given(tool_behavior):
139+
kwargs["behavior"] = tool_behavior
140+
141+
return types.FunctionDeclaration(**kwargs)
120142

121143

122144
def to_response_format(response_format: type | dict) -> types.SchemaUnion:

0 commit comments

Comments (0)