Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
jayeshp19 committed Dec 25, 2024
1 parent baae79b commit 47818f1
Show file tree
Hide file tree
Showing 10 changed files with 278 additions and 164 deletions.
38 changes: 32 additions & 6 deletions livekit-agents/livekit/agents/tts/stream_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ def __init__(
) -> None:
super().__init__(tts=tts, conn_options=conn_options)
self._wrapped_tts = wrapped_tts
self._sent_stream = sentence_tokenizer.stream()
self._sent_stream = sentence_tokenizer
self._segments_ch = utils.aio.Chan[tokenize.SentenceStream]()

async def _metrics_monitor_task(
self, event_aiter: AsyncIterable[SynthesizedAudio]
Expand All @@ -85,10 +86,35 @@ async def _forward_input():

self._sent_stream.end_input()

async def _synthesize():
async for ev in self._sent_stream:
async def _tokenize_input():
"""tokenize text"""
input_stream = None
async for input in self._input_ch:
if isinstance(input, str):
if input_stream is None:
# new segment (after flush for e.g)
input_stream = self._sent_stream.stream()
self._segments_ch.send_nowait(input_stream)

input_stream.push_text(input)
elif isinstance(input, self._FlushSentinel):
if input_stream is not None:
input_stream.end_input()

input_stream = None
self._segments_ch.close()

async def _run_segments():
async for input_stream in self._segments_ch:
await _synthesize(input_stream)

async def _synthesize(input_stream):
async for ev in input_stream:
print("ev: ", ev)
last_audio: SynthesizedAudio | None = None
async for audio in self._wrapped_tts.synthesize(ev.token):
async for audio in self._wrapped_tts.synthesize(
ev.token, segment_id=ev.segment_id
):
if last_audio is not None:
self._event_ch.send_nowait(last_audio)

Expand All @@ -99,8 +125,8 @@ async def _synthesize():
self._event_ch.send_nowait(last_audio)

tasks = [
asyncio.create_task(_forward_input()),
asyncio.create_task(_synthesize()),
asyncio.create_task(_tokenize_input()),
asyncio.create_task(_run_segments()),
]
try:
await asyncio.gather(*tasks)
Expand Down
1 change: 1 addition & 0 deletions livekit-agents/livekit/agents/tts/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def synthesize(
text: str,
*,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
segment_id: str = "",
) -> ChunkedStream: ...

def stream(
Expand Down
15 changes: 13 additions & 2 deletions livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,14 @@ def synthesize(
text: str,
*,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
segment_id: str = "",
) -> "ChunkedStream":
return ChunkedStream(
tts=self, input_text=text, conn_options=conn_options, opts=self._opts
tts=self,
input_text=text,
conn_options=conn_options,
opts=self._opts,
segment_id=segment_id,
)


Expand All @@ -217,14 +222,16 @@ def __init__(
input_text: str,
conn_options: APIConnectOptions,
opts: _TTSOptions,
segment_id: str = "",
) -> None:
super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
self._opts = opts
self._segment_id = segment_id

async def _run(self):
stream_callback = speechsdk.audio.PushAudioOutputStream(
_PushAudioOutputStreamCallback(
self._opts, asyncio.get_running_loop(), self._event_ch
self._opts, asyncio.get_running_loop(), self._event_ch, self._segment_id
)
)
synthesizer = _create_speech_synthesizer(
Expand Down Expand Up @@ -289,12 +296,14 @@ def __init__(
opts: _TTSOptions,
loop: asyncio.AbstractEventLoop,
event_ch: utils.aio.ChanSender[tts.SynthesizedAudio],
segment_id: str,
):
super().__init__()
self._event_ch = event_ch
self._opts = opts
self._loop = loop
self._request_id = utils.shortuuid()
self._segment_id = segment_id

self._bstream = utils.audio.AudioByteStream(
sample_rate=opts.sample_rate, num_channels=1
Expand All @@ -305,6 +314,7 @@ def write(self, audio_buffer: memoryview) -> int:
audio = tts.SynthesizedAudio(
request_id=self._request_id,
frame=frame,
segment_id=self._segment_id,
)
with contextlib.suppress(RuntimeError):
self._loop.call_soon_threadsafe(self._event_ch.send_nowait, audio)
Expand All @@ -316,6 +326,7 @@ def close(self) -> None:
audio = tts.SynthesizedAudio(
request_id=self._request_id,
frame=frame,
segment_id=self._segment_id,
)
with contextlib.suppress(RuntimeError):
self._loop.call_soon_threadsafe(self._event_ch.send_nowait, audio)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,35 +253,81 @@ def __init__(
):
super().__init__(tts=tts, conn_options=conn_options)
self._opts, self._session = opts, session
self._sent_tokenizer_stream = tokenize.basic.SentenceTokenizer(
self._sent_tokenizer = tokenize.basic.SentenceTokenizer(
min_sentence_len=BUFFERED_WORDS_COUNT
).stream()
)

async def _run(self) -> None:
self._closing_input = False
self._segments_ch = utils.aio.Chan[tokenize.SentenceStream]()

@utils.log_exceptions(logger=logger)
async def _tokenize_input():
"""tokenize text from the input_ch to words"""
input_stream = None
async for input in self._input_ch:
if isinstance(input, str):
if input_stream is None:
# new segment (after flush for e.g)
input_stream = self._sent_tokenizer.stream()
self._segments_ch.send_nowait(input_stream)

input_stream.push_text(input)
elif isinstance(input, self._FlushSentinel):
if input_stream is not None:
input_stream.end_input()

input_stream = None
self._segments_ch.close()

@utils.log_exceptions(logger=logger)
async def _run_segments(ws: aiohttp.ClientWebSocketResponse):
async for input_stream in self._segments_ch:
await self._run_ws(input_stream, ws)
self._closing_input = True

url = f"wss://api.cartesia.ai/tts/websocket?api_key={self._opts.api_key}&cartesia_version={API_VERSION}"

ws: aiohttp.ClientWebSocketResponse | None = None

try:
ws = await asyncio.wait_for(
self._session.ws_connect(url), self._conn_options.timeout
)
tasks = [
asyncio.create_task(_tokenize_input()),
asyncio.create_task(_run_segments(ws)),
]
try:
await asyncio.gather(*tasks)
finally:
await utils.aio.gracefully_cancel(*tasks)
finally:
if ws is not None:
await ws.close()

async def _run_ws(
self,
input_stream: tokenize.SentenceStream,
ws: aiohttp.ClientWebSocketResponse,
) -> None:
request_id = utils.shortuuid()

async def _sentence_stream_task(ws: aiohttp.ClientWebSocketResponse):
base_pkt = _to_cartesia_options(self._opts)
async for ev in self._sent_tokenizer_stream:
async for ev in input_stream:
token_pkt = base_pkt.copy()
token_pkt["context_id"] = request_id
token_pkt["transcript"] = ev.token + " "
token_pkt["continue"] = True
await ws.send_str(json.dumps(token_pkt))

end_pkt = base_pkt.copy()
end_pkt["context_id"] = request_id
end_pkt["transcript"] = " "
end_pkt["continue"] = False
await ws.send_str(json.dumps(end_pkt))

async def _input_task():
async for data in self._input_ch:
if isinstance(data, self._FlushSentinel):
self._sent_tokenizer_stream.flush()
continue
self._sent_tokenizer_stream.push_text(data)
self._sent_tokenizer_stream.end_input()
if self._closing_input:
end_pkt = base_pkt.copy()
end_pkt["context_id"] = request_id
end_pkt["transcript"] = " "
end_pkt["continue"] = False
await ws.send_str(json.dumps(end_pkt))

async def _recv_task(ws: aiohttp.ClientWebSocketResponse):
audio_bstream = utils.audio.AudioByteStream(
Expand Down Expand Up @@ -312,7 +358,7 @@ def _send_last_frame(*, segment_id: str, is_final: bool) -> None:
aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.CLOSING,
):
raise Exception("Cartesia connection closed unexpectedly")
raise APIStatusError("Cartesia connection closed unexpectedly")

if msg.type != aiohttp.WSMsgType.TEXT:
logger.warning("unexpected Cartesia message type %s", msg.type)
Expand All @@ -334,34 +380,19 @@ def _send_last_frame(*, segment_id: str, is_final: bool) -> None:
_send_last_frame(segment_id=segment_id, is_final=True)

if segment_id == request_id:
# we're not going to receive more frames, close the connection
await ws.close()
break
else:
logger.error("unexpected Cartesia message %s", data)

url = f"wss://api.cartesia.ai/tts/websocket?api_key={self._opts.api_key}&cartesia_version={API_VERSION}"

ws: aiohttp.ClientWebSocketResponse | None = None
tasks = [
asyncio.create_task(_sentence_stream_task(ws)),
asyncio.create_task(_recv_task(ws)),
]

try:
ws = await asyncio.wait_for(
self._session.ws_connect(url), self._conn_options.timeout
)

tasks = [
asyncio.create_task(_input_task()),
asyncio.create_task(_sentence_stream_task(ws)),
asyncio.create_task(_recv_task(ws)),
]

try:
await asyncio.gather(*tasks)
finally:
await utils.aio.gracefully_cancel(*tasks)
await asyncio.gather(*tasks)
finally:
if ws is not None:
await ws.close()
await utils.aio.gracefully_cancel(*tasks)


def _to_cartesia_options(opts: _TTSOptions) -> dict[str, Any]:
Expand Down
Loading

0 comments on commit 47818f1

Please sign in to comment.