Skip to content

Commit

Permalink
Merge branch 'multisegments-tts' of https://github.com/livekit/agents
Browse files Browse the repository at this point in the history
…into multisegments-tts
  • Loading branch information
jayeshp19 committed Dec 25, 2024
2 parents baae79b + 8ab274e commit 6aa7514
Show file tree
Hide file tree
Showing 9 changed files with 268 additions and 152 deletions.
39 changes: 14 additions & 25 deletions livekit-agents/livekit/agents/tts/stream_adapter.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from __future__ import annotations

import asyncio
from typing import AsyncIterable

from .. import tokenize, utils
from .. import tokenize
from ..types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
from .tts import (
TTS,
Expand Down Expand Up @@ -65,30 +64,27 @@ def __init__(
wrapped_tts: TTS,
sentence_tokenizer: tokenize.SentenceTokenizer,
) -> None:
super().__init__(tts=tts, conn_options=conn_options)
super().__init__(
tts=tts,
conn_options=conn_options,
tokenizer=sentence_tokenizer,
)
self._wrapped_tts = wrapped_tts
self._sent_stream = sentence_tokenizer.stream()

async def _metrics_monitor_task(
self, event_aiter: AsyncIterable[SynthesizedAudio]
) -> None:
pass # do nothing

async def _run(self) -> None:
async def _forward_input():
"""forward input to vad"""
async for data in self._input_ch:
if isinstance(data, self._FlushSentinel):
self._sent_stream.flush()
continue
self._sent_stream.push_text(data)

self._sent_stream.end_input()

async def _run(
self, input_stream: tokenize.WordStream | tokenize.SentenceStream
) -> None:
async def _synthesize():
async for ev in self._sent_stream:
async for ev in input_stream:
last_audio: SynthesizedAudio | None = None
async for audio in self._wrapped_tts.synthesize(ev.token):
async for audio in self._wrapped_tts.synthesize(
ev.token, segment_id=ev.segment_id
):
if last_audio is not None:
self._event_ch.send_nowait(last_audio)

Expand All @@ -98,11 +94,4 @@ async def _synthesize():
last_audio.is_final = True
self._event_ch.send_nowait(last_audio)

tasks = [
asyncio.create_task(_forward_input()),
asyncio.create_task(_synthesize()),
]
try:
await asyncio.gather(*tasks)
finally:
await utils.aio.gracefully_cancel(*tasks)
await _synthesize()
55 changes: 52 additions & 3 deletions livekit-agents/livekit/agents/tts/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from livekit import rtc

from .. import tokenize, utils
from .._exceptions import APIConnectionError, APIError
from ..log import logger
from ..metrics import TTSMetrics
Expand Down Expand Up @@ -75,6 +76,7 @@ def synthesize(
text: str,
*,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
segment_id: str | None = None,
) -> ChunkedStream: ...

def stream(
Expand Down Expand Up @@ -234,9 +236,16 @@ async def __aexit__(
class SynthesizeStream(ABC):
class _FlushSentinel: ...

def __init__(self, *, tts: TTS, conn_options: APIConnectOptions) -> None:
def __init__(
self,
*,
tts: TTS,
conn_options: APIConnectOptions,
tokenizer: tokenize.WordTokenizer | tokenize.SentenceTokenizer,
) -> None:
super().__init__()
self._tts = tts
self._tokenizer = tokenizer
self._conn_options = conn_options
self._input_ch = aio.Chan[Union[str, SynthesizeStream._FlushSentinel]]()
self._event_ch = aio.Chan[SynthesizedAudio]()
Expand All @@ -251,12 +260,52 @@ def __init__(self, *, tts: TTS, conn_options: APIConnectOptions) -> None:
self._mtc_text = ""

@abstractmethod
async def _run(self) -> None: ...
async def _run(
self, input_stream: tokenize.WordStream | tokenize.SentenceStream
) -> None: ...

async def _main_task(self) -> None:
if isinstance(self._tokenizer, tokenize.SentenceTokenizer):
self._segments_ch = utils.aio.Chan[tokenize.SentenceStream]()
elif isinstance(self._tokenizer, tokenize.WordTokenizer):
self._segments_ch = utils.aio.Chan[tokenize.WordStream]()

@utils.log_exceptions(logger=logger)
async def _tokenize_input():
"""tokenize text from the input_ch to words"""
input_stream = None
async for input in self._input_ch:
if isinstance(input, str):
if input_stream is None:
# new segment (after flush for e.g)
input_stream = self._tokenizer.stream()
self._segments_ch.send_nowait(input_stream)

input_stream.push_text(input)
elif isinstance(input, self._FlushSentinel):
if input_stream is not None:
input_stream.end_input()

input_stream = None

self._segments_ch.close()

@utils.log_exceptions(logger=logger)
async def _run_segments():
async for input_stream in self._segments_ch:
await self._run(input_stream)

for i in range(self._conn_options.max_retry + 1):
try:
return await self._run()
tasks = [
asyncio.create_task(_tokenize_input()),
asyncio.create_task(_run_segments()),
]
try:
await asyncio.gather(*tasks)
finally:
await utils.aio.gracefully_cancel(*tasks)
return
except APIError as e:
if self._conn_options.max_retry == 0:
raise
Expand Down
17 changes: 15 additions & 2 deletions livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,14 @@ def synthesize(
text: str,
*,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
segment_id: str | None = None,
) -> "ChunkedStream":
return ChunkedStream(
tts=self, input_text=text, conn_options=conn_options, opts=self._opts
tts=self,
input_text=text,
conn_options=conn_options,
opts=self._opts,
segment_id=segment_id,
)


Expand All @@ -217,14 +222,18 @@ def __init__(
input_text: str,
conn_options: APIConnectOptions,
opts: _TTSOptions,
segment_id: str | None = None,
) -> None:
super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
self._opts = opts
self._segment_id = segment_id
if self._segment_id is None:
self._segment_id = utils.shortuuid()

async def _run(self):
stream_callback = speechsdk.audio.PushAudioOutputStream(
_PushAudioOutputStreamCallback(
self._opts, asyncio.get_running_loop(), self._event_ch
self._opts, asyncio.get_running_loop(), self._event_ch, self._segment_id
)
)
synthesizer = _create_speech_synthesizer(
Expand Down Expand Up @@ -289,12 +298,14 @@ def __init__(
opts: _TTSOptions,
loop: asyncio.AbstractEventLoop,
event_ch: utils.aio.ChanSender[tts.SynthesizedAudio],
segment_id: str,
):
super().__init__()
self._event_ch = event_ch
self._opts = opts
self._loop = loop
self._request_id = utils.shortuuid()
self._segment_id = segment_id

self._bstream = utils.audio.AudioByteStream(
sample_rate=opts.sample_rate, num_channels=1
Expand All @@ -304,6 +315,7 @@ def write(self, audio_buffer: memoryview) -> int:
for frame in self._bstream.write(audio_buffer.tobytes()):
audio = tts.SynthesizedAudio(
request_id=self._request_id,
segment_id=self._segment_id,
frame=frame,
)
with contextlib.suppress(RuntimeError):
Expand All @@ -315,6 +327,7 @@ def close(self) -> None:
for frame in self._bstream.flush():
audio = tts.SynthesizedAudio(
request_id=self._request_id,
segment_id=self._segment_id,
frame=frame,
)
with contextlib.suppress(RuntimeError):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@

NUM_CHANNELS = 1
BUFFERED_WORDS_COUNT = 8
WSS_URL = "wss://api.cartesia.ai/tts/websocket"
BYTES_URL = "https://api.cartesia.ai/tts/bytes"


@dataclass
Expand All @@ -61,6 +63,7 @@ class _TTSOptions:
emotion: list[TTSVoiceEmotion | str] | None
api_key: str
language: str
word_tokenizer: tokenize.WordTokenizer


class TTS(tts.TTS):
Expand All @@ -76,6 +79,9 @@ def __init__(
sample_rate: int = 24000,
api_key: str | None = None,
http_session: aiohttp.ClientSession | None = None,
word_tokenizer: tokenize.WordTokenizer = tokenize.basic.WordTokenizer(
ignore_punctuation=False,
),
) -> None:
"""
Create a new instance of Cartesia TTS.
Expand All @@ -92,6 +98,7 @@ def __init__(
sample_rate (int, optional): The audio sample rate in Hz. Defaults to 24000.
api_key (str, optional): The Cartesia API key. If not provided, it will be read from the CARTESIA_API_KEY environment variable.
http_session (aiohttp.ClientSession | None, optional): An existing aiohttp ClientSession to use. If not provided, a new session will be created.
word_tokenizer (tokenize.WordTokenizer, optional): The word tokenizer to use. Defaults to a basic tokenizer.
"""

super().__init__(
Expand All @@ -113,6 +120,7 @@ def __init__(
speed=speed,
emotion=emotion,
api_key=api_key,
word_tokenizer=word_tokenizer,
)
self._session = http_session

Expand Down Expand Up @@ -207,7 +215,7 @@ async def _run(self) -> None:

try:
async with self._session.post(
"https://api.cartesia.ai/tts/bytes",
BYTES_URL,
headers=headers,
json=json,
timeout=aiohttp.ClientTimeout(
Expand Down Expand Up @@ -251,18 +259,24 @@ def __init__(
opts: _TTSOptions,
session: aiohttp.ClientSession,
):
super().__init__(tts=tts, conn_options=conn_options)
super().__init__(
tts=tts,
conn_options=conn_options,
tokenizer=opts.word_tokenizer,
)
self._opts, self._session = opts, session
self._sent_tokenizer_stream = tokenize.basic.SentenceTokenizer(
min_sentence_len=BUFFERED_WORDS_COUNT
).stream()

async def _run(self) -> None:
async def _run(
self,
input_stream: tokenize.WordStream,
max_retry: int = 3,
) -> None:
request_id = utils.shortuuid()

async def _sentence_stream_task(ws: aiohttp.ClientWebSocketResponse):
nonlocal close_ws
base_pkt = _to_cartesia_options(self._opts)
async for ev in self._sent_tokenizer_stream:
async for ev in input_stream:
token_pkt = base_pkt.copy()
token_pkt["context_id"] = request_id
token_pkt["transcript"] = ev.token + " "
Expand All @@ -273,17 +287,11 @@ async def _sentence_stream_task(ws: aiohttp.ClientWebSocketResponse):
end_pkt["context_id"] = request_id
end_pkt["transcript"] = " "
end_pkt["continue"] = False
close_ws = True
await ws.send_str(json.dumps(end_pkt))

async def _input_task():
async for data in self._input_ch:
if isinstance(data, self._FlushSentinel):
self._sent_tokenizer_stream.flush()
continue
self._sent_tokenizer_stream.push_text(data)
self._sent_tokenizer_stream.end_input()

async def _recv_task(ws: aiohttp.ClientWebSocketResponse):
nonlocal close_ws
audio_bstream = utils.audio.AudioByteStream(
sample_rate=self._opts.sample_rate,
num_channels=NUM_CHANNELS,
Expand Down Expand Up @@ -312,7 +320,9 @@ def _send_last_frame(*, segment_id: str, is_final: bool) -> None:
aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.CLOSING,
):
raise Exception("Cartesia connection closed unexpectedly")
if not close_ws:
raise Exception("Cartesia connection closed unexpectedly")
return

if msg.type != aiohttp.WSMsgType.TEXT:
logger.warning("unexpected Cartesia message type %s", msg.type)
Expand Down Expand Up @@ -340,17 +350,18 @@ def _send_last_frame(*, segment_id: str, is_final: bool) -> None:
else:
logger.error("unexpected Cartesia message %s", data)

url = f"wss://api.cartesia.ai/tts/websocket?api_key={self._opts.api_key}&cartesia_version={API_VERSION}"

ws: aiohttp.ClientWebSocketResponse | None = None

try:
ws = await asyncio.wait_for(
self._session.ws_connect(url), self._conn_options.timeout
self._session.ws_connect(
_to_cartesia_url(api_key=self._opts.api_key, version=API_VERSION)
),
self._conn_options.timeout,
)
close_ws = False

tasks = [
asyncio.create_task(_input_task()),
asyncio.create_task(_sentence_stream_task(ws)),
asyncio.create_task(_recv_task(ws)),
]
Expand All @@ -364,6 +375,10 @@ def _send_last_frame(*, segment_id: str, is_final: bool) -> None:
await ws.close()


def _to_cartesia_url(api_key: str, version: str) -> str:
return f"{WSS_URL}?api_key={api_key}&cartesia_version={version}"


def _to_cartesia_options(opts: _TTSOptions) -> dict[str, Any]:
voice: dict[str, Any] = {}
if isinstance(opts.voice, str):
Expand Down
Loading

0 comments on commit 6aa7514

Please sign in to comment.