diff --git a/livekit-agents/livekit/agents/tts/stream_adapter.py b/livekit-agents/livekit/agents/tts/stream_adapter.py index fbb25df5d..6ee070263 100644 --- a/livekit-agents/livekit/agents/tts/stream_adapter.py +++ b/livekit-agents/livekit/agents/tts/stream_adapter.py @@ -67,7 +67,8 @@ def __init__( ) -> None: super().__init__(tts=tts, conn_options=conn_options) self._wrapped_tts = wrapped_tts - self._sent_stream = sentence_tokenizer.stream() + self._sent_stream = sentence_tokenizer + self._segments_ch = utils.aio.Chan[tokenize.SentenceStream]() async def _metrics_monitor_task( self, event_aiter: AsyncIterable[SynthesizedAudio] @@ -85,10 +86,35 @@ async def _forward_input(): self._sent_stream.end_input() - async def _synthesize(): - async for ev in self._sent_stream: + async def _tokenize_input(): + """tokenize text""" + input_stream = None + async for input in self._input_ch: + if isinstance(input, str): + if input_stream is None: + # new segment (after flush for e.g) + input_stream = self._sent_stream.stream() + self._segments_ch.send_nowait(input_stream) + + input_stream.push_text(input) + elif isinstance(input, self._FlushSentinel): + if input_stream is not None: + input_stream.end_input() + + input_stream = None + self._segments_ch.close() + + async def _run_segments(): + async for input_stream in self._segments_ch: + await _synthesize(input_stream) + + async def _synthesize(input_stream): + async for ev in input_stream: + print("ev: ", ev) last_audio: SynthesizedAudio | None = None - async for audio in self._wrapped_tts.synthesize(ev.token): + async for audio in self._wrapped_tts.synthesize( + ev.token, segment_id=ev.segment_id + ): if last_audio is not None: self._event_ch.send_nowait(last_audio) @@ -99,8 +125,8 @@ async def _synthesize(): self._event_ch.send_nowait(last_audio) tasks = [ - asyncio.create_task(_forward_input()), - asyncio.create_task(_synthesize()), + asyncio.create_task(_tokenize_input()), + asyncio.create_task(_run_segments()), ] try: await asyncio.gather(*tasks) diff --git a/livekit-agents/livekit/agents/tts/tts.py b/livekit-agents/livekit/agents/tts/tts.py index e641bf39d..8c7cd752d 100644 --- a/livekit-agents/livekit/agents/tts/tts.py +++ b/livekit-agents/livekit/agents/tts/tts.py @@ -75,6 +75,7 @@ def synthesize( text: str, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, + segment_id: str = "", ) -> ChunkedStream: ... def stream( diff --git a/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/tts.py b/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/tts.py index 155d2c091..81b48f9a0 100644 --- a/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/tts.py +++ b/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/tts.py @@ -203,9 +203,14 @@ def synthesize( text: str, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, + segment_id: str = "", ) -> "ChunkedStream": return ChunkedStream( - tts=self, input_text=text, conn_options=conn_options, opts=self._opts + tts=self, + input_text=text, + conn_options=conn_options, + opts=self._opts, + segment_id=segment_id, ) @@ -217,14 +222,16 @@ def __init__( input_text: str, conn_options: APIConnectOptions, opts: _TTSOptions, + segment_id: str = "", ) -> None: super().__init__(tts=tts, input_text=input_text, conn_options=conn_options) self._opts = opts + self._segment_id = segment_id async def _run(self): stream_callback = speechsdk.audio.PushAudioOutputStream( _PushAudioOutputStreamCallback( - self._opts, asyncio.get_running_loop(), self._event_ch + self._opts, asyncio.get_running_loop(), self._event_ch, self._segment_id ) ) synthesizer = _create_speech_synthesizer( @@ -289,12 +296,14 @@ def __init__( opts: _TTSOptions, loop: asyncio.AbstractEventLoop, event_ch: utils.aio.ChanSender[tts.SynthesizedAudio], + segment_id: str, ): super().__init__() self._event_ch = event_ch self._opts = opts self._loop = loop self._request_id = utils.shortuuid() + self._segment_id = segment_id self._bstream = utils.audio.AudioByteStream( sample_rate=opts.sample_rate, num_channels=1 @@ -305,6 +314,7 @@ def write(self, audio_buffer: memoryview) -> int: audio = tts.SynthesizedAudio( request_id=self._request_id, frame=frame, + segment_id=self._segment_id, ) with contextlib.suppress(RuntimeError): self._loop.call_soon_threadsafe(self._event_ch.send_nowait, audio) @@ -316,6 +326,7 @@ def close(self) -> None: audio = tts.SynthesizedAudio( request_id=self._request_id, frame=frame, + segment_id=self._segment_id, ) with contextlib.suppress(RuntimeError): self._loop.call_soon_threadsafe(self._event_ch.send_nowait, audio) diff --git a/livekit-plugins/livekit-plugins-cartesia/livekit/plugins/cartesia/tts.py b/livekit-plugins/livekit-plugins-cartesia/livekit/plugins/cartesia/tts.py index dd76473c7..4e5a5219f 100644 --- a/livekit-plugins/livekit-plugins-cartesia/livekit/plugins/cartesia/tts.py +++ b/livekit-plugins/livekit-plugins-cartesia/livekit/plugins/cartesia/tts.py @@ -253,35 +253,81 @@ def __init__( ): super().__init__(tts=tts, conn_options=conn_options) self._opts, self._session = opts, session - self._sent_tokenizer_stream = tokenize.basic.SentenceTokenizer( + self._sent_tokenizer = tokenize.basic.SentenceTokenizer( min_sentence_len=BUFFERED_WORDS_COUNT - ).stream() + ) async def _run(self) -> None: + self._closing_input = False + self._segments_ch = utils.aio.Chan[tokenize.SentenceStream]() + + @utils.log_exceptions(logger=logger) + async def _tokenize_input(): + """tokenize text from the input_ch to words""" + input_stream = None + async for input in self._input_ch: + if isinstance(input, str): + if input_stream is None: + # new segment (after flush for e.g) + input_stream = self._sent_tokenizer.stream() + self._segments_ch.send_nowait(input_stream) + + input_stream.push_text(input) + elif isinstance(input, self._FlushSentinel): + if input_stream is not None: + input_stream.end_input() + + input_stream = None + self._segments_ch.close() + + @utils.log_exceptions(logger=logger) + async def _run_segments(ws: aiohttp.ClientWebSocketResponse): + async for input_stream in self._segments_ch: + await self._run_ws(input_stream, ws) + self._closing_input = True + + url = f"wss://api.cartesia.ai/tts/websocket?api_key={self._opts.api_key}&cartesia_version={API_VERSION}" + + ws: aiohttp.ClientWebSocketResponse | None = None + + try: + ws = await asyncio.wait_for( + self._session.ws_connect(url), self._conn_options.timeout + ) + tasks = [ + asyncio.create_task(_tokenize_input()), + asyncio.create_task(_run_segments(ws)), + ] + try: + await asyncio.gather(*tasks) + finally: + await utils.aio.gracefully_cancel(*tasks) + finally: + if ws is not None: + await ws.close() + + async def _run_ws( + self, + input_stream: tokenize.SentenceStream, + ws: aiohttp.ClientWebSocketResponse, + ) -> None: request_id = utils.shortuuid() async def _sentence_stream_task(ws: aiohttp.ClientWebSocketResponse): base_pkt = _to_cartesia_options(self._opts) - async for ev in self._sent_tokenizer_stream: + async for ev in input_stream: token_pkt = base_pkt.copy() token_pkt["context_id"] = request_id token_pkt["transcript"] = ev.token + " " token_pkt["continue"] = True await ws.send_str(json.dumps(token_pkt)) - end_pkt = base_pkt.copy() - end_pkt["context_id"] = request_id - end_pkt["transcript"] = " " - end_pkt["continue"] = False - await ws.send_str(json.dumps(end_pkt)) - - async def _input_task(): - async for data in self._input_ch: - if isinstance(data, self._FlushSentinel): - self._sent_tokenizer_stream.flush() - continue - self._sent_tokenizer_stream.push_text(data) - self._sent_tokenizer_stream.end_input() + if self._closing_input: + end_pkt = base_pkt.copy() + end_pkt["context_id"] = request_id + end_pkt["transcript"] = " " + end_pkt["continue"] = False + await ws.send_str(json.dumps(end_pkt)) async def _recv_task(ws: aiohttp.ClientWebSocketResponse): audio_bstream = utils.audio.AudioByteStream( @@ -312,7 +358,7 @@ def _send_last_frame(*, segment_id: str, is_final: bool) -> None: aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING, ): - raise Exception("Cartesia connection closed unexpectedly") + raise APIStatusError("Cartesia connection closed unexpectedly") if msg.type != aiohttp.WSMsgType.TEXT: logger.warning("unexpected Cartesia message type %s", msg.type) @@ -334,34 +380,19 @@ def _send_last_frame(*, segment_id: str, is_final: bool) -> None: _send_last_frame(segment_id=segment_id, is_final=True) if segment_id == request_id: - # we're not going to receive more frames, close the connection - await ws.close() break else: logger.error("unexpected Cartesia message %s", data) - url = f"wss://api.cartesia.ai/tts/websocket?api_key={self._opts.api_key}&cartesia_version={API_VERSION}" - - ws: aiohttp.ClientWebSocketResponse | None = None + tasks = [ + asyncio.create_task(_sentence_stream_task(ws)), + asyncio.create_task(_recv_task(ws)), + ] try: - ws = await asyncio.wait_for( - self._session.ws_connect(url), self._conn_options.timeout - ) - - tasks = [ - asyncio.create_task(_input_task()), - asyncio.create_task(_sentence_stream_task(ws)), - asyncio.create_task(_recv_task(ws)), - ] - - try: - await asyncio.gather(*tasks) - finally: - await utils.aio.gracefully_cancel(*tasks) + await asyncio.gather(*tasks) finally: - if ws is not None: - await ws.close() + await utils.aio.gracefully_cancel(*tasks) def _to_cartesia_options(opts: _TTSOptions) -> dict[str, Any]: diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/tts.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/tts.py index 56d7405a7..5634c149f 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/tts.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/tts.py @@ -249,49 +249,115 @@ def update_options( self._reconnect_event.set() async def _run(self) -> None: - closing_ws = False - request_id = utils.shortuuid() - segment_id = utils.shortuuid() - audio_bstream = utils.audio.AudioByteStream( - sample_rate=self._opts.sample_rate, - num_channels=NUM_CHANNELS, - ) + self._closing_input = False @utils.log_exceptions(logger=logger) async def _tokenize_input(): # Converts incoming text into WordStreams and sends them into _segments_ch - word_stream = None + input_stream = None async for input in self._input_ch: if isinstance(input, str): - if word_stream is None: - word_stream = self._opts.word_tokenizer.stream() - self._segments_ch.send_nowait(word_stream) - word_stream.push_text(input) + if input_stream is None: + input_stream = self._opts.word_tokenizer.stream() + self._segments_ch.send_nowait(input_stream) + input_stream.push_text(input) elif isinstance(input, self._FlushSentinel): - if word_stream: - word_stream.end_input() - word_stream = None + if input_stream: + input_stream.end_input() + input_stream = None self._segments_ch.close() @utils.log_exceptions(logger=logger) async def _run_segments(ws: aiohttp.ClientWebSocketResponse): - nonlocal closing_ws - async for word_stream in self._segments_ch: - async for word in word_stream: - speak_msg = {"type": "Speak", "text": f"{word.token} "} - await ws.send_str(json.dumps(speak_msg)) + async for input_stream in self._segments_ch: + segment_id = utils.shortuuid() + await self._run_ws(input_stream, segment_id, ws) + self._closing_input = True + + async def _connection_timeout(): + # Deepgram has a 60-minute timeout period for websocket connections + await asyncio.sleep(3300) + logger.warning( + "Deepgram TTS maximum connection time reached. Reconnecting..." + ) + self._reconnect_event.set() + + ws: aiohttp.ClientWebSocketResponse | None = None + while True: + try: + config = { + "encoding": self._opts.encoding, + "model": self._opts.model, + "sample_rate": self._opts.sample_rate, + } + ws = await asyncio.wait_for( + self._session.ws_connect( + _to_deepgram_url(config, self._base_url, websocket=True), + headers={"Authorization": f"Token {self._api_key}"}, + ), + self._conn_options.timeout, + ) + + tasks = [ + asyncio.create_task(_tokenize_input()), + asyncio.create_task(_run_segments(ws)), + ] + wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait()) + connection_timeout_task = asyncio.create_task(_connection_timeout()) + try: + done, pending = await asyncio.wait( + [ + asyncio.gather(*tasks), + wait_reconnect_task, + connection_timeout_task, + ], + return_when=asyncio.FIRST_COMPLETED, + ) # type: ignore + + # propagate exceptions from completed tasks + for task in done: + if task != wait_reconnect_task: + task.result() + + if wait_reconnect_task not in done: + break + self._reconnect_event.clear() + finally: + await utils.aio.gracefully_cancel( + *tasks, wait_reconnect_task, connection_timeout_task + ) + finally: + if ws is not None and not ws.closed: + await ws.close() - # Always flush after a segment - flush_msg = {"type": "Flush"} - await ws.send_str(json.dumps(flush_msg)) + async def _run_ws( + self, + input_stream: tokenize.WordStream, + segment_id: str, + ws: aiohttp.ClientWebSocketResponse, + ) -> None: + request_id = utils.shortuuid() + + async def send_task( + ws: aiohttp.ClientWebSocketResponse, + flush_after_words: int = 30, + ): + async for word in input_stream: + speak_msg = {"type": "Speak", "text": f"{word.token} "} + await ws.send_str(json.dumps(speak_msg)) - # after all segments, close - close_msg = {"type": "Close"} - closing_ws = True - await ws.send_str(json.dumps(close_msg)) + flush_msg = {"type": "Flush"} + await ws.send_str(json.dumps(flush_msg)) + if self._closing_input: + close_msg = {"type": "Close"} + await ws.send_str(json.dumps(close_msg)) async def recv_task(ws: aiohttp.ClientWebSocketResponse): last_frame: rtc.AudioFrame | None = None + audio_bstream = utils.audio.AudioByteStream( + sample_rate=self._opts.sample_rate, + num_channels=NUM_CHANNELS, + ) def _send_last_frame(*, segment_id: str, is_final: bool) -> None: nonlocal last_frame @@ -313,8 +379,8 @@ def _send_last_frame(*, segment_id: str, is_final: bool) -> None: aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING, ): - if not closing_ws: - raise Exception( + if not self._closing_input: + raise APIStatusError( "Deepgram websocket connection closed unexpectedly" ) return @@ -332,6 +398,7 @@ def _send_last_frame(*, segment_id: str, is_final: bool) -> None: _send_last_frame(segment_id=segment_id, is_final=False) last_frame = frame _send_last_frame(segment_id=segment_id, is_final=True) + break elif mtype == "Warning": logger.warning("Deepgram warning: %s", resp.get("warn_msg")) elif mtype == "Metadata": @@ -339,67 +406,15 @@ def _send_last_frame(*, segment_id: str, is_final: bool) -> None: else: logger.debug("Unknown message type: %s", resp) - async def _connection_timeout(): - # Deepgram has a 60-minute timeout period for websocket connections - await asyncio.sleep(3300) - logger.warning( - "Deepgram TTS maximum connection time reached. Reconnecting..." - ) - self._reconnect_event.set() - - ws: aiohttp.ClientWebSocketResponse | None = None - while True: - try: - config = { - "encoding": self._opts.encoding, - "model": self._opts.model, - "sample_rate": self._opts.sample_rate, - } - ws = await asyncio.wait_for( - self._session.ws_connect( - _to_deepgram_url(config, self._base_url, websocket=True), - headers={"Authorization": f"Token {self._api_key}"}, - ), - self._conn_options.timeout, - ) - closing_ws = False - - tasks = [ - asyncio.create_task(_tokenize_input()), - asyncio.create_task(_run_segments(ws)), - asyncio.create_task(recv_task(ws)), - ] - wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait()) - connection_timeout_task = asyncio.create_task(_connection_timeout()) - - try: - done, _ = await asyncio.wait( - [ - asyncio.gather(*tasks), - wait_reconnect_task, - connection_timeout_task, - ], - return_when=asyncio.FIRST_COMPLETED, - ) # type: ignore - if wait_reconnect_task not in done: - break - self._reconnect_event.clear() - finally: - await utils.aio.gracefully_cancel( - *tasks, wait_reconnect_task, connection_timeout_task - ) + tasks = [ + asyncio.create_task(send_task(ws)), + asyncio.create_task(recv_task(ws)), + ] - except asyncio.TimeoutError as e: - raise APITimeoutError() from e - except aiohttp.ClientResponseError as e: - raise APIStatusError( - message=e.message, status_code=e.status, request_id=None, body=None - ) from e - except Exception as e: - raise APIConnectionError() from e - finally: - if ws is not None and not ws.closed: - await ws.close() + try: + await asyncio.gather(*tasks) + finally: + await utils.aio.gracefully_cancel(*tasks) def _to_deepgram_url( diff --git a/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/tts.py b/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/tts.py index 0c5490707..be94885a7 100644 --- a/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/tts.py +++ b/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/tts.py @@ -469,7 +469,7 @@ def _send_last_frame(*, segment_id: str, is_final: bool) -> None: aiohttp.WSMsgType.CLOSING, ): if not eos_sent: - raise Exception( + raise APIStatusError( "11labs connection closed unexpectedly, not all tokens have been consumed" ) return diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/tts.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/tts.py index 3ab494a66..57b44d916 100644 --- a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/tts.py +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/tts.py @@ -161,6 +161,7 @@ def synthesize( text: str, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, + segment_id: str = "", ) -> "ChunkedStream": return ChunkedStream( tts=self, @@ -168,6 +169,7 @@ def synthesize( conn_options=conn_options, opts=self._opts, client=self._ensure_client(), + segment_id=segment_id, ) @@ -180,9 +182,11 @@ def __init__( conn_options: APIConnectOptions, opts: _TTSOptions, client: texttospeech.TextToSpeechAsyncClient, + segment_id: str = "", ) -> None: super().__init__(tts=tts, input_text=input_text, conn_options=conn_options) self._opts, self._client = opts, client + self._segment_id = segment_id async def _run(self) -> None: request_id = utils.shortuuid() @@ -204,18 +208,27 @@ async def _run(self) -> None: for frame in decoder.decode_chunk(response.audio_content): for frame in bstream.write(frame.data.tobytes()): self._event_ch.send_nowait( - tts.SynthesizedAudio(request_id=request_id, frame=frame) + tts.SynthesizedAudio( + request_id=request_id, + segment_id=self._segment_id, + frame=frame, + ) ) for frame in bstream.flush(): self._event_ch.send_nowait( - tts.SynthesizedAudio(request_id=request_id, frame=frame) + tts.SynthesizedAudio( + request_id=request_id, + segment_id=self._segment_id, + frame=frame, + ) ) else: data = response.audio_content[44:] # skip WAV header self._event_ch.send_nowait( tts.SynthesizedAudio( request_id=request_id, + segment_id=self._segment_id, frame=rtc.AudioFrame( data=data, sample_rate=self._opts.audio_config.sample_rate_hertz, diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/tts.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/tts.py index ce7741eb8..878b1135a 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/tts.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/tts.py @@ -143,6 +143,7 @@ def synthesize( text: str, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, + segment_id: str = "", ) -> "ChunkedStream": return ChunkedStream( tts=self, @@ -150,6 +151,7 @@ def synthesize( conn_options=conn_options, opts=self._opts, client=self._client, + segment_id=segment_id, ) @@ -162,10 +164,12 @@ def __init__( conn_options: APIConnectOptions, opts: _TTSOptions, client: openai.AsyncClient, + segment_id: str = "", ) -> None: super().__init__(tts=tts, input_text=input_text, conn_options=conn_options) self._client = client self._opts = opts + self._segment_id = segment_id async def _run(self): oai_stream = self._client.audio.speech.with_streaming_response.create( @@ -191,6 +195,7 @@ async def _run(self): tts.SynthesizedAudio( frame=frame, request_id=request_id, + segment_id=self._segment_id, ) ) @@ -199,6 +204,7 @@ async def _run(self): tts.SynthesizedAudio( frame=frame, request_id=request_id, + segment_id=self._segment_id, ) ) diff --git a/livekit-plugins/livekit-plugins-playai/livekit/plugins/playai/tts.py b/livekit-plugins/livekit-plugins-playai/livekit/plugins/playai/tts.py index 464f3f418..ae33bfafc 100644 --- a/livekit-plugins/livekit-plugins-playai/livekit/plugins/playai/tts.py +++ b/livekit-plugins/livekit-plugins-playai/livekit/plugins/playai/tts.py @@ -207,23 +207,23 @@ def __init__( self._config = self._opts.tts_options self._segments_ch = utils.aio.Chan[tokenize.WordStream]() self._mp3_decoder = utils.codecs.Mp3StreamDecoder() + self._segment_id = utils.shortuuid() async def _run(self) -> None: request_id = utils.shortuuid() - segment_id = utils.shortuuid() bstream = utils.audio.AudioByteStream( sample_rate=self._config.sample_rate, num_channels=NUM_CHANNELS, ) last_frame: rtc.AudioFrame | None = None - def _send_last_frame(*, segment_id: str, is_final: bool) -> None: + def _send_last_frame(*, is_final: bool) -> None: nonlocal last_frame if last_frame is not None: self._event_ch.send_nowait( tts.SynthesizedAudio( request_id=request_id, - segment_id=segment_id, + segment_id=self._segment_id, frame=last_frame, is_final=is_final, ) @@ -240,13 +240,13 @@ def _send_last_frame(*, segment_id: str, is_final: bool) -> None: ): for frame in self._mp3_decoder.decode_chunk(chunk): for frame in bstream.write(frame.data.tobytes()): - _send_last_frame(segment_id=segment_id, is_final=False) + _send_last_frame(is_final=False) last_frame = frame for frame in bstream.flush(): - _send_last_frame(segment_id=segment_id, is_final=False) + _send_last_frame(is_final=False) last_frame = frame - _send_last_frame(segment_id=segment_id, is_final=True) + _send_last_frame(is_final=True) except Exception as e: raise APIConnectionError() from e finally: @@ -273,6 +273,7 @@ async def _tokenize_input(self): async def _create_text_stream(self): async def text_stream(): async for word_stream in self._segments_ch: + self._segment_id = utils.shortuuid() async for word in word_stream: yield word.token diff --git a/tests/test_tts.py b/tests/test_tts.py index 91f8035b5..c59ca7b6c 100644 --- a/tests/test_tts.py +++ b/tests/test_tts.py @@ -109,27 +109,36 @@ async def test_stream(tts_factory): synthesize_transcript = make_test_synthesize() - pattern = [1, 2, 4] - text = synthesize_transcript - chunks = [] - pattern_iter = iter(pattern * (len(text) // sum(pattern) + 1)) - - for chunk_size in pattern_iter: - if not text: - break - chunks.append(text[:chunk_size]) - text = text[chunk_size:] + # Split the transcript into two segments + text_segments = [ + synthesize_transcript[: len(synthesize_transcript) // 2], + synthesize_transcript[len(synthesize_transcript) // 2 :], + ] stream = tts.stream() segments = set() - # for i in range(2): # TODO(theomonnom): we should test 2 segments - for chunk in chunks: - stream.push_text(chunk) + for i in range(2): # Testing 2 segments + text = text_segments[i] + + # Generate chunks for the current segment + pattern = [1, 2, 4] + chunks = [] + text_remaining = text + pattern_iter = iter(pattern * (len(text) // sum(pattern) + 1)) + + for chunk_size in pattern_iter: + if not text_remaining: + break + chunks.append(text_remaining[:chunk_size]) + text_remaining = text_remaining[chunk_size:] + + for chunk in chunks: + stream.push_text(chunk) - stream.flush() - # if i == 1: - stream.end_input() + stream.flush() + if i == 1: + stream.end_input() frames = [] is_final = False @@ -140,11 +149,12 @@ async def test_stream(tts_factory): assert is_final, "final audio should be marked as final" - await _assert_valid_synthesized_audio( - frames, tts, synthesize_transcript, WER_THRESHOLD - ) + # Combine the segments for expected text + expected_text = "".join(text_segments) + + await _assert_valid_synthesized_audio(frames, tts, expected_text, WER_THRESHOLD) - # assert len(segments) == 2 + assert len(segments) == 2, "should have 2 segments" await stream.aclose()