Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 88 additions & 37 deletions livekit-agents/livekit/agents/tts/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import datetime
import os
import threading
import time
from abc import ABC, abstractmethod
from collections.abc import AsyncIterable, AsyncIterator
Expand Down Expand Up @@ -278,17 +279,22 @@ async def _main_task(self) -> None:
return
except APIError as e:
retry_interval = self._conn_options._interval_for_retry(i)
if self._conn_options.max_retry == 0 or self._conn_options.max_retry == i:
should_retry = (
e.retryable
and not output_emitter.has_pushed_audio()
and self._conn_options.max_retry > 0
and i < self._conn_options.max_retry
)
if not should_retry:
self._emit_error(e, recoverable=False)
raise
else:
self._emit_error(e, recoverable=True)
logger.warning(
f"failed to synthesize speech, retrying in {retry_interval}s",
exc_info=e,
extra={"tts": self._tts._label, "attempt": i + 1, "streamed": False},
)

self._emit_error(e, recoverable=True)
logger.warning(
f"failed to synthesize speech, retrying in {retry_interval}s",
exc_info=e,
extra={"tts": self._tts._label, "attempt": i + 1, "streamed": False},
)
await asyncio.sleep(retry_interval)
# Reset the flag when retrying
self._current_attempt_has_error = False
Expand Down Expand Up @@ -353,6 +359,9 @@ def __init__(self, *, tts: TTS, conn_options: APIConnectOptions) -> None:
self._tts = tts
self._conn_options = conn_options
self._input_ch = aio.Chan[Union[str, SynthesizeStream._FlushSentinel]]()
self._input_lock = threading.Lock()
self._replay_events: list[str | SynthesizeStream._FlushSentinel] = []
self._input_ended = False
self._event_ch = aio.Chan[SynthesizedAudio]()
self._tee = aio.itertools.tee(self._event_ch, 2)
self._event_aiter, self._monitor_aiter = self._tee
Expand Down Expand Up @@ -388,6 +397,10 @@ async def _main_task(self) -> None:
)

for i in range(self._conn_options.max_retry + 1):
if i > 0:
# Retry runs `_run` again on the same stream instance. Most streaming TTS
# implementations consume `self._input_ch`, so we need to replay buffered input.
self._reset_for_retry()
output_emitter = AudioEmitter(label=self._tts.label, dst_ch=self._event_ch)
try:
with tracer.start_as_current_span("tts_request_run") as attempt_span:
Expand Down Expand Up @@ -416,17 +429,22 @@ async def _main_task(self) -> None:
return
except APIError as e:
retry_interval = self._conn_options._interval_for_retry(i)
if self._conn_options.max_retry == 0 or self._conn_options.max_retry == i:
should_retry = (
e.retryable
and not output_emitter.has_pushed_audio()
and self._conn_options.max_retry > 0
and i < self._conn_options.max_retry
)
if not should_retry:
self._emit_error(e, recoverable=False)
raise
else:
self._emit_error(e, recoverable=True)
logger.warning(
f"failed to synthesize speech, retrying in {retry_interval}s",
exc_info=e,
extra={"tts": self._tts._label, "attempt": i + 1, "streamed": True},
)

self._emit_error(e, recoverable=True)
logger.warning(
f"failed to synthesize speech, retrying in {retry_interval}s",
exc_info=e,
extra={"tts": self._tts._label, "attempt": i + 1, "streamed": True},
)
await asyncio.sleep(retry_interval)
# Reset the flag when retrying
self._current_attempt_has_error = False
Expand Down Expand Up @@ -509,51 +527,81 @@ def _emit_metrics() -> None:

def push_text(self, token: str) -> None:
"""Push some text to be synthesized"""
if not token or self._input_ch.closed:
if not token:
return
with self._input_lock:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

push_text is a sync method, any reason a lock is needed? I don't think we are going to push text from different threads, right?

if self._input_ch.closed:
return

self._pushed_text += token
if not self._mtc_text:
if self._num_segments >= 1:
logger.warning(
"SynthesizeStream: handling multiple segments in a single instance is "
"deprecated. Please create a new SynthesizeStream instance for each segment. "
"Most TTS plugins now use pooled WebSocket connections via ConnectionPool."
)
return

if self._metrics_task is None:
self._metrics_task = asyncio.create_task(
self._metrics_monitor_task(self._monitor_aiter), name="TTS._metrics_task"
)
self._num_segments += 1

if not self._mtc_text:
if self._num_segments >= 1:
logger.warning(
"SynthesizeStream: handling multiple segments in a single instance is "
"deprecated. Please create a new SynthesizeStream instance for each segment. "
"Most TTS plugins now use pooled WebSocket connections via ConnectionPool."
)
return
self._pushed_text += token
self._replay_events.append(token)

self._num_segments += 1
if self._metrics_task is None:
self._metrics_task = asyncio.create_task(
self._metrics_monitor_task(self._monitor_aiter), name="TTS._metrics_task"
)

self._mtc_text += token
self._input_ch.send_nowait(token)
self._mtc_text += token
self._input_ch.send_nowait(token)

def flush(self) -> None:
"""Mark the end of the current segment"""
with self._input_lock:
self._flush_locked()

def _flush_locked(self) -> None:
if self._input_ch.closed:
return

if self._mtc_text:
self._mtc_pending_texts.append(self._mtc_text)
self._mtc_text = ""

self._input_ch.send_nowait(self._FlushSentinel())
sentinel = self._FlushSentinel()
self._replay_events.append(sentinel)
self._input_ch.send_nowait(sentinel)

def end_input(self) -> None:
"""Mark the end of input, no more text will be pushed"""
self.flush()
self._input_ch.close()
with self._input_lock:
self._flush_locked()
self._input_ended = True
self._input_ch.close()

def _reset_for_retry(self) -> None:
# Reset per-attempt timing used for metrics; without this, retries can produce incorrect
# durations/TTFB because `_mark_started` only sets the first time.
self._started_time = 0
with self._input_lock:
old_ch = self._input_ch
ch = aio.Chan[Union[str, SynthesizeStream._FlushSentinel]]()
for ev in self._replay_events:
ch.send_nowait(ev)

if self._input_ended:
ch.close()

self._input_ch = ch
if not old_ch.closed:
old_ch.close()

async def aclose(self) -> None:
"""Close ths stream immediately"""
await aio.cancel_and_wait(self._task)
self._event_ch.close()
self._input_ch.close()
with self._input_lock:
self._input_ch.close()

if self._metrics_task is not None:
await self._metrics_task
Expand Down Expand Up @@ -626,6 +674,9 @@ def pushed_duration(self, idx: int = -1) -> float:
else 0.0
)

def has_pushed_audio(self) -> bool:
return any(d > 0.0 for d in self._audio_durations)

@property
def num_segments(self) -> int:
return self._num_segments
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,10 @@ def __init__(self, *, tts: TTS, conn_options: APIConnectOptions):
super().__init__(tts=tts, conn_options=conn_options)
self._tts: TTS = tts
self._opts = replace(tts._opts)
self._segments_ch = utils.aio.Chan[tokenize.WordStream]()

async def _run(self, output_emitter: tts.AudioEmitter) -> None:
request_id = utils.shortuuid()
segments_ch = utils.aio.Chan[tokenize.WordStream]()
output_emitter.initialize(
request_id=request_id,
sample_rate=self._opts.sample_rate,
Expand All @@ -270,17 +270,17 @@ async def _tokenize_input() -> None:
if isinstance(input, str):
if word_stream is None:
word_stream = self._opts.word_tokenizer.stream()
self._segments_ch.send_nowait(word_stream)
segments_ch.send_nowait(word_stream)
word_stream.push_text(input)
elif isinstance(input, self._FlushSentinel):
if word_stream:
word_stream.end_input()
word_stream = None

self._segments_ch.close()
segments_ch.close()

async def _run_segments() -> None:
async for word_stream in self._segments_ch:
async for word_stream in segments_ch:
await self._run_ws(word_stream, output_emitter)

tasks = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ async def current_connection(self) -> _Connection:
if (
self._current_connection
and self._current_connection.is_current
and not self._current_connection._closed
and self._current_connection.is_open
):
return self._current_connection

Expand Down Expand Up @@ -363,17 +363,20 @@ def __init__(self, *, tts: TTS, conn_options: APIConnectOptions):
self._tts: TTS = tts
self._opts = replace(tts._opts)
self._context_id = utils.shortuuid()
self._sent_tokenizer_stream = self._opts.word_tokenizer.stream()
self._text_buffer = ""
self._start_times_ms: list[int] = []
self._durations_ms: list[int] = []
self._connection: _Connection | None = None

async def aclose(self) -> None:
await self._sent_tokenizer_stream.aclose()
await super().aclose()

async def _run(self, output_emitter: tts.AudioEmitter) -> None:
self._context_id = utils.shortuuid()
self._text_buffer = ""
self._start_times_ms.clear()
self._durations_ms.clear()

output_emitter.initialize(
request_id=self._context_id,
sample_rate=self._opts.sample_rate,
Expand All @@ -383,6 +386,8 @@ async def _run(self, output_emitter: tts.AudioEmitter) -> None:
)
output_emitter.start_segment(segment_id=self._context_id)

sent_tokenizer_stream = self._opts.word_tokenizer.stream()

connection: _Connection
try:
connection = await asyncio.wait_for(
Expand All @@ -399,10 +404,10 @@ async def _run(self, output_emitter: tts.AudioEmitter) -> None:
async def _input_task() -> None:
async for data in self._input_ch:
if isinstance(data, self._FlushSentinel):
self._sent_tokenizer_stream.flush()
sent_tokenizer_stream.flush()
continue
self._sent_tokenizer_stream.push_text(data)
self._sent_tokenizer_stream.end_input()
sent_tokenizer_stream.push_text(data)
sent_tokenizer_stream.end_input()

async def _sentence_stream_task() -> None:
flush_on_chunk = (
Expand All @@ -411,7 +416,7 @@ async def _sentence_stream_task() -> None:
and self._opts.auto_mode
)
xml_content: list[str] = []
async for data in self._sent_tokenizer_stream:
async for data in sent_tokenizer_stream:
text = data.token
# send xml tags fully formed
xml_start_tokens = ["<phoneme", "<break"]
Expand Down Expand Up @@ -455,12 +460,13 @@ async def _sentence_stream_task() -> None:
except asyncio.TimeoutError as e:
raise APITimeoutError() from e
except Exception as e:
if isinstance(e, APIStatusError):
if isinstance(e, (APIStatusError, APIError, APIConnectionError, APITimeoutError)):
raise e
raise APIStatusError("Could not synthesize") from e
finally:
output_emitter.end_segment()
await utils.aio.gracefully_cancel(input_t, stream_t)
await sent_tokenizer_stream.aclose()


@dataclass
Expand Down Expand Up @@ -504,6 +510,8 @@ class _StreamData:
stream: SynthesizeStream
waiter: asyncio.Future[None]
timeout_timer: asyncio.TimerHandle | None = None
received_audio: bool = False
sent_text: bool = False


class _Connection:
Expand Down Expand Up @@ -531,6 +539,10 @@ def voice_id(self) -> str:
def is_current(self) -> bool:
return self._is_current

@property
def is_open(self) -> bool:
return self._ws is not None and not self._ws.closed and not self._closed

def mark_non_current(self) -> None:
"""Mark this connection as no longer current - it will shut down when drained"""
self._is_current = False
Expand Down Expand Up @@ -560,6 +572,10 @@ def send_content(self, content: _SynthesizeContent) -> None:
"""Send synthesis content to the connection"""
if self._closed or not self._ws or self._ws.closed:
raise APIConnectionError("WebSocket connection is closed")
if content.text.strip():
ctx = self._context_data.get(content.context_id)
if ctx:
ctx.sent_text = True
self._input_queue.send_nowait(content)

def close_context(self, context_id: str) -> None:
Expand Down Expand Up @@ -709,10 +725,25 @@ async def _recv_loop(self) -> None:
if data.get("audio"):
b64data = base64.b64decode(data["audio"])
emitter.push(b64data)
ctx.received_audio = True
if ctx.timeout_timer:
ctx.timeout_timer.cancel()

if data.get("isFinal"):
if not ctx.received_audio and ctx.sent_text and not ctx.waiter.done():
# ElevenLabs sometimes returns `isFinal` with an empty `audio` payload.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a valid case that elevenlabs returns final without audio, like when the pushed text is empty?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, empty/whitespace input returns isFinal with audio: null. I added a sent_text guard so we only error when real text was sent

# Empty input can return isFinal without audio, so only treat
# it as a retryable failure when we actually sent text.
ctx.waiter.set_exception(
APIError("11labs stream ended without audio", retryable=True)
)
self.mark_non_current()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why mark_non_current is needed when one of the generation failed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It just ensures the next attempt gets a fresh connection

self._cleanup_context(context_id)
if not self._is_current and not self._active_contexts:
logger.debug("no active contexts, shutting down connection")
break
continue

if stream is not None:
timed_words, _ = _to_timed_words(
stream._text_buffer,
Expand Down
Loading