Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not retry on already streamed tokens #1239

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/dirty-jeans-burn.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"livekit-agents": patch
---

Do not retry on already streamed tokens
51 changes: 50 additions & 1 deletion livekit-agents/livekit/agents/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from livekit import rtc
from livekit.agents._exceptions import APIConnectionError, APIError
from typing_extensions import override

from .. import utils
from ..log import logger
Expand Down Expand Up @@ -66,6 +67,41 @@ class ToolChoice:


TEvent = TypeVar("TEvent")
T = TypeVar("T")


class TrackedChannel(aio.Chan[T]):
    """
    A channel that counts how many values were successfully sent and received.

    We use this to handle errors/timeouts from the LLM: once someone has
    actually received a value from the LLM we don't want to retry, because the
    audio stream will have already consumed the chat chunk.

    Counters are only incremented *after* the underlying operation completes.
    Incrementing before the call would count a consumer that is merely waiting
    on an empty channel (or an operation that raised or was cancelled) as a
    delivered value, which would wrongly suppress retries.

    NOTE(review): if the base ``Chan.send``/``Chan.recv`` delegate to
    ``self.send_nowait``/``self.recv_nowait``, a single operation is counted
    twice. Callers that only test ``> 0`` are unaffected - confirm against the
    ``aio.Chan`` implementation before relying on exact counts.
    """

    def __init__(self) -> None:
        super().__init__()
        # Number of values successfully handed to the channel.
        self.num_sends: int = 0
        # Number of values successfully taken out of the channel.
        self.num_receives: int = 0

    @override
    def send_nowait(self, value: T) -> None:
        """Send without waiting; counted only if the base call does not raise."""
        super().send_nowait(value)
        self.num_sends += 1

    @override
    async def send(self, value: T) -> None:
        """Send, waiting if needed; counted only once the send has completed."""
        await super().send(value)
        self.num_sends += 1

    @override
    async def recv(self) -> T:
        """Receive, waiting if needed; counted only once a value is obtained."""
        value = await super().recv()
        self.num_receives += 1
        return value

    @override
    def recv_nowait(self) -> T:
        """Receive without waiting; counted only if a value is returned."""
        value = super().recv_nowait()
        self.num_receives += 1
        return value


class LLM(
Expand Down Expand Up @@ -128,7 +164,7 @@ def __init__(
self._fnc_ctx = fnc_ctx
self._conn_options = conn_options

self._event_ch = aio.Chan[ChatChunk]()
self._event_ch = TrackedChannel[ChatChunk]()
self._event_aiter, monitor_aiter = aio.itertools.tee(self._event_ch, 2)
self._metrics_task = asyncio.create_task(
self._metrics_monitor_task(monitor_aiter), name="LLM._metrics_task"
Expand All @@ -148,6 +184,19 @@ async def _main_task(self) -> None:
try:
return await self._run()
except APIError as e:
if self._event_ch.num_receives > 0:
logger.warning(
"LLM already sent a value that was used, not retrying",
exc_info=e,
extra={
"llm": self._llm._label,
"attempt": i + 1,
"num_sends": self._event_ch.num_sends,
"num_receives": self._event_ch.num_receives,
},
)
raise

if self._conn_options.max_retry == 0:
raise
elif i == self._conn_options.max_retry:
Expand Down
Loading