diff --git a/livekit-agents/livekit/agents/tts/stream_adapter.py b/livekit-agents/livekit/agents/tts/stream_adapter.py
index fbb25df5d..d000aac8f 100644
--- a/livekit-agents/livekit/agents/tts/stream_adapter.py
+++ b/livekit-agents/livekit/agents/tts/stream_adapter.py
@@ -1,9 +1,8 @@
 from __future__ import annotations
 
-import asyncio
 from typing import AsyncIterable
 
-from .. import tokenize, utils
+from .. import tokenize
 from ..types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
 from .tts import (
     TTS,
@@ -65,30 +64,27 @@ def __init__(
         wrapped_tts: TTS,
         sentence_tokenizer: tokenize.SentenceTokenizer,
     ) -> None:
-        super().__init__(tts=tts, conn_options=conn_options)
+        super().__init__(
+            tts=tts,
+            conn_options=conn_options,
+            tokenizer=sentence_tokenizer,
+        )
         self._wrapped_tts = wrapped_tts
-        self._sent_stream = sentence_tokenizer.stream()
 
     async def _metrics_monitor_task(
         self, event_aiter: AsyncIterable[SynthesizedAudio]
     ) -> None:
         pass  # do nothing
 
-    async def _run(self) -> None:
-        async def _forward_input():
-            """forward input to vad"""
-            async for data in self._input_ch:
-                if isinstance(data, self._FlushSentinel):
-                    self._sent_stream.flush()
-                    continue
-                self._sent_stream.push_text(data)
-
-            self._sent_stream.end_input()
-
+    async def _run(
+        self, input_stream: tokenize.WordStream | tokenize.SentenceStream
+    ) -> None:
         async def _synthesize():
-            async for ev in self._sent_stream:
+            async for ev in input_stream:
                 last_audio: SynthesizedAudio | None = None
-                async for audio in self._wrapped_tts.synthesize(ev.token):
+                async for audio in self._wrapped_tts.synthesize(
+                    ev.token, segment_id=ev.segment_id
+                ):
                     if last_audio is not None:
                         self._event_ch.send_nowait(last_audio)
@@ -98,11 +94,4 @@ async def _synthesize():
                 last_audio.is_final = True
                 self._event_ch.send_nowait(last_audio)
 
-        tasks = [
-            asyncio.create_task(_forward_input()),
-            asyncio.create_task(_synthesize()),
-        ]
-        try:
-            await asyncio.gather(*tasks)
-        finally:
-            await utils.aio.gracefully_cancel(*tasks)
+        await _synthesize()
diff --git a/livekit-agents/livekit/agents/tts/tts.py b/livekit-agents/livekit/agents/tts/tts.py
index e641bf39d..fd44ff3e4 100644
--- a/livekit-agents/livekit/agents/tts/tts.py
+++ b/livekit-agents/livekit/agents/tts/tts.py
@@ -9,6 +9,7 @@
 
 from livekit import rtc
 
+from .. import tokenize, utils
 from .._exceptions import APIConnectionError, APIError
 from ..log import logger
 from ..metrics import TTSMetrics
@@ -75,6 +76,7 @@ def synthesize(
         text: str,
         *,
         conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
+        segment_id: str | None = None,
     ) -> ChunkedStream: ...
 
     def stream(
@@ -234,9 +236,16 @@ async def __aexit__(
 class SynthesizeStream(ABC):
     class _FlushSentinel: ...
 
-    def __init__(self, *, tts: TTS, conn_options: APIConnectOptions) -> None:
+    def __init__(
+        self,
+        *,
+        tts: TTS,
+        conn_options: APIConnectOptions,
+        tokenizer: tokenize.WordTokenizer | tokenize.SentenceTokenizer,
+    ) -> None:
         super().__init__()
         self._tts = tts
+        self._tokenizer = tokenizer
         self._conn_options = conn_options
         self._input_ch = aio.Chan[Union[str, SynthesizeStream._FlushSentinel]]()
         self._event_ch = aio.Chan[SynthesizedAudio]()
@@ -251,12 +260,52 @@ def __init__(self, *, tts: TTS, conn_options: APIConnectOptions) -> None:
         self._mtc_text = ""
 
     @abstractmethod
-    async def _run(self) -> None: ...
+    async def _run(
+        self, input_stream: tokenize.WordStream | tokenize.SentenceStream
+    ) -> None: ...
 
     async def _main_task(self) -> None:
+        if isinstance(self._tokenizer, tokenize.SentenceTokenizer):
+            self._segments_ch = utils.aio.Chan[tokenize.SentenceStream]()
+        elif isinstance(self._tokenizer, tokenize.WordTokenizer):
+            self._segments_ch = utils.aio.Chan[tokenize.WordStream]()
+
+        @utils.log_exceptions(logger=logger)
+        async def _tokenize_input():
+            """tokenize text from the input_ch into segment streams"""
+            input_stream = None
+            async for input in self._input_ch:
+                if isinstance(input, str):
+                    if input_stream is None:
+                        # new segment (e.g. after a flush)
+                        input_stream = self._tokenizer.stream()
+                        self._segments_ch.send_nowait(input_stream)
+
+                    input_stream.push_text(input)
+                elif isinstance(input, self._FlushSentinel):
+                    if input_stream is not None:
+                        input_stream.end_input()
+
+                    input_stream = None
+
+            self._segments_ch.close()
+
+        @utils.log_exceptions(logger=logger)
+        async def _run_segments():
+            async for input_stream in self._segments_ch:
+                await self._run(input_stream)
+
         for i in range(self._conn_options.max_retry + 1):
             try:
-                return await self._run()
+                tasks = [
+                    asyncio.create_task(_tokenize_input()),
+                    asyncio.create_task(_run_segments()),
+                ]
+                try:
+                    await asyncio.gather(*tasks)
+                finally:
+                    await utils.aio.gracefully_cancel(*tasks)
+                return
             except APIError as e:
                 if self._conn_options.max_retry == 0:
                     raise
diff --git a/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/tts.py b/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/tts.py
index 155d2c091..9446c789b 100644
--- a/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/tts.py
+++ b/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/tts.py
@@ -203,9 +203,14 @@ def synthesize(
         text: str,
         *,
         conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
+        segment_id: str | None = None,
     ) -> "ChunkedStream":
         return ChunkedStream(
-            tts=self, input_text=text, conn_options=conn_options, opts=self._opts
+            tts=self,
+            input_text=text,
+            conn_options=conn_options,
+            opts=self._opts,
+            segment_id=segment_id,
         )
 
 
@@ -217,14 +222,18 @@ def __init__(
         input_text: str,
         conn_options: APIConnectOptions,
         opts: _TTSOptions,
+        segment_id: str | None = None,
     ) -> None:
         super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
         self._opts = opts
+        self._segment_id = segment_id
+        if self._segment_id is None:
+            self._segment_id = utils.shortuuid()
 
     async def _run(self):
         stream_callback = speechsdk.audio.PushAudioOutputStream(
             _PushAudioOutputStreamCallback(
-                self._opts, asyncio.get_running_loop(), self._event_ch
+                self._opts, asyncio.get_running_loop(), self._event_ch, self._segment_id
             )
         )
         synthesizer = _create_speech_synthesizer(
@@ -289,12 +298,14 @@ def __init__(
         opts: _TTSOptions,
         loop: asyncio.AbstractEventLoop,
         event_ch: utils.aio.ChanSender[tts.SynthesizedAudio],
+        segment_id: str,
     ):
         super().__init__()
         self._event_ch = event_ch
         self._opts = opts
         self._loop = loop
         self._request_id = utils.shortuuid()
+        self._segment_id = segment_id
 
         self._bstream = utils.audio.AudioByteStream(
             sample_rate=opts.sample_rate, num_channels=1
@@ -304,6 +315,7 @@ def write(self, audio_buffer: memoryview) -> int:
         for frame in self._bstream.write(audio_buffer.tobytes()):
             audio = tts.SynthesizedAudio(
                 request_id=self._request_id,
+                segment_id=self._segment_id,
                 frame=frame,
             )
             with contextlib.suppress(RuntimeError):
@@ -315,6 +327,7 @@ def close(self) -> None:
         for frame in self._bstream.flush():
             audio = tts.SynthesizedAudio(
                 request_id=self._request_id,
+                segment_id=self._segment_id,
                 frame=frame,
             )
             with contextlib.suppress(RuntimeError):
diff --git a/livekit-plugins/livekit-plugins-cartesia/livekit/plugins/cartesia/tts.py b/livekit-plugins/livekit-plugins-cartesia/livekit/plugins/cartesia/tts.py
index dd76473c7..305dfdd0c 100644
--- a/livekit-plugins/livekit-plugins-cartesia/livekit/plugins/cartesia/tts.py
+++ b/livekit-plugins/livekit-plugins-cartesia/livekit/plugins/cartesia/tts.py
@@ -49,6 +49,8 @@
 
 NUM_CHANNELS = 1
 BUFFERED_WORDS_COUNT = 8
+WSS_URL = "wss://api.cartesia.ai/tts/websocket"
+BYTES_URL = "https://api.cartesia.ai/tts/bytes"
 
 
 @dataclass
@@ -61,6 +63,7 @@ class _TTSOptions:
     emotion: list[TTSVoiceEmotion | str] | None
     api_key: str
     language: str
+    word_tokenizer: tokenize.WordTokenizer
 
 
 class TTS(tts.TTS):
@@ -76,6 +79,9 @@ def __init__(
         sample_rate: int = 24000,
         api_key: str | None = None,
         http_session: aiohttp.ClientSession | None = None,
+        word_tokenizer: tokenize.WordTokenizer = tokenize.basic.WordTokenizer(
+            ignore_punctuation=False,
+        ),
     ) -> None:
         """
         Create a new instance of Cartesia TTS.
@@ -92,6 +98,7 @@ def __init__(
             sample_rate (int, optional): The audio sample rate in Hz. Defaults to 24000.
             api_key (str, optional): The Cartesia API key. If not provided, it will be read from the CARTESIA_API_KEY environment variable.
             http_session (aiohttp.ClientSession | None, optional): An existing aiohttp ClientSession to use. If not provided, a new session will be created.
+            word_tokenizer (tokenize.WordTokenizer, optional): The word tokenizer to use. Defaults to a basic tokenizer.
         """
 
         super().__init__(
@@ -113,6 +120,7 @@ def __init__(
             speed=speed,
             emotion=emotion,
             api_key=api_key,
+            word_tokenizer=word_tokenizer,
         )
         self._session = http_session
@@ -207,7 +215,7 @@ async def _run(self) -> None:
 
         try:
             async with self._session.post(
-                "https://api.cartesia.ai/tts/bytes",
+                BYTES_URL,
                 headers=headers,
                 json=json,
                 timeout=aiohttp.ClientTimeout(
@@ -251,18 +259,24 @@ def __init__(
         opts: _TTSOptions,
         session: aiohttp.ClientSession,
     ):
-        super().__init__(tts=tts, conn_options=conn_options)
+        super().__init__(
+            tts=tts,
+            conn_options=conn_options,
+            tokenizer=opts.word_tokenizer,
+        )
         self._opts, self._session = opts, session
-        self._sent_tokenizer_stream = tokenize.basic.SentenceTokenizer(
-            min_sentence_len=BUFFERED_WORDS_COUNT
-        ).stream()
 
-    async def _run(self) -> None:
+    async def _run(
+        self,
+        input_stream: tokenize.WordStream,
+        max_retry: int = 3,
+    ) -> None:
         request_id = utils.shortuuid()
 
         async def _sentence_stream_task(ws: aiohttp.ClientWebSocketResponse):
+            nonlocal close_ws
             base_pkt = _to_cartesia_options(self._opts)
-            async for ev in self._sent_tokenizer_stream:
+            async for ev in input_stream:
                 token_pkt = base_pkt.copy()
                 token_pkt["context_id"] = request_id
                 token_pkt["transcript"] = ev.token + " "
@@ -273,17 +287,11 @@ async def _sentence_stream_task(ws: aiohttp.ClientWebSocketResponse):
             end_pkt["context_id"] = request_id
             end_pkt["transcript"] = " "
             end_pkt["continue"] = False
+            close_ws = True
             await ws.send_str(json.dumps(end_pkt))
 
-        async def _input_task():
-            async for data in self._input_ch:
-                if isinstance(data, self._FlushSentinel):
-                    self._sent_tokenizer_stream.flush()
-                    continue
-                self._sent_tokenizer_stream.push_text(data)
-            self._sent_tokenizer_stream.end_input()
-
         async def _recv_task(ws: aiohttp.ClientWebSocketResponse):
+            nonlocal close_ws
             audio_bstream = utils.audio.AudioByteStream(
                 sample_rate=self._opts.sample_rate,
                 num_channels=NUM_CHANNELS,
             )
@@ -312,7 +320,9 @@ def _send_last_frame(*, segment_id: str, is_final: bool) -> None:
                     aiohttp.WSMsgType.CLOSE,
                     aiohttp.WSMsgType.CLOSING,
                 ):
-                    raise Exception("Cartesia connection closed unexpectedly")
+                    if not close_ws:
+                        raise Exception("Cartesia connection closed unexpectedly")
+                    return
 
                 if msg.type != aiohttp.WSMsgType.TEXT:
                     logger.warning("unexpected Cartesia message type %s", msg.type)
@@ -340,17 +350,18 @@ def _send_last_frame(*, segment_id: str, is_final: bool) -> None:
                 else:
                     logger.error("unexpected Cartesia message %s", data)
 
-        url = f"wss://api.cartesia.ai/tts/websocket?api_key={self._opts.api_key}&cartesia_version={API_VERSION}"
-
         ws: aiohttp.ClientWebSocketResponse | None = None
 
         try:
             ws = await asyncio.wait_for(
-                self._session.ws_connect(url), self._conn_options.timeout
+                self._session.ws_connect(
+                    _to_cartesia_url(api_key=self._opts.api_key, version=API_VERSION)
+                ),
+                self._conn_options.timeout,
             )
 
+            close_ws = False
             tasks = [
-                asyncio.create_task(_input_task()),
                 asyncio.create_task(_sentence_stream_task(ws)),
                 asyncio.create_task(_recv_task(ws)),
             ]
@@ -364,6 +375,10 @@ def _send_last_frame(*, segment_id: str, is_final: bool) -> None:
             await ws.close()
 
 
+def _to_cartesia_url(api_key: str, version: str) -> str:
+    return f"{WSS_URL}?api_key={api_key}&cartesia_version={version}"
+
+
 def _to_cartesia_options(opts: _TTSOptions) -> dict[str, Any]:
     voice: dict[str, Any] = {}
     if isinstance(opts.voice, str):
diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/tts.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/tts.py
index 56d7405a7..9fae03f57 100644
--- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/tts.py
+++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/tts.py
@@ -153,7 +153,11 @@ def __init__(
         conn_options: APIConnectOptions,
         session: aiohttp.ClientSession,
     ) -> None:
-        super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
+        super().__init__(
+            tts=tts,
+            input_text=input_text,
+            conn_options=conn_options,
+        )
         self._opts = opts
         self._session = session
         self._base_url = base_url
@@ -227,7 +231,11 @@ def __init__(
         opts: _TTSOptions,
         session: aiohttp.ClientSession,
     ):
-        super().__init__(tts=tts, conn_options=conn_options)
+        super().__init__(
+            tts=tts,
+            conn_options=conn_options,
+            tokenizer=opts.word_tokenizer,
+        )
         self._opts = opts
         self._session = session
         self._base_url = base_url
@@ -248,40 +256,33 @@ def update_options(
 
         self._reconnect_event.set()
 
-    async def _run(self) -> None:
-        closing_ws = False
+    async def _run(
+        self,
+        input_stream: tokenize.WordStream | tokenize.SentenceStream,
+        max_retry: int = 3,
+    ):
         request_id = utils.shortuuid()
-        segment_id = utils.shortuuid()
-        audio_bstream = utils.audio.AudioByteStream(
-            sample_rate=self._opts.sample_rate,
-            num_channels=NUM_CHANNELS,
-        )
-
-        @utils.log_exceptions(logger=logger)
-        async def _tokenize_input():
-            # Converts incoming text into WordStreams and sends them into _segments_ch
-            word_stream = None
-            async for input in self._input_ch:
-                if isinstance(input, str):
-                    if word_stream is None:
-                        word_stream = self._opts.word_tokenizer.stream()
-                        self._segments_ch.send_nowait(word_stream)
-                    word_stream.push_text(input)
-                elif isinstance(input, self._FlushSentinel):
-                    if word_stream:
-                        word_stream.end_input()
-                    word_stream = None
-            self._segments_ch.close()
-
-        @utils.log_exceptions(logger=logger)
-        async def _run_segments(ws: aiohttp.ClientWebSocketResponse):
-            nonlocal closing_ws
-            async for word_stream in self._segments_ch:
-                async for word in word_stream:
-                    speak_msg = {"type": "Speak", "text": f"{word.token} "}
-                    await ws.send_str(json.dumps(speak_msg))
-
-            # Always flush after a segment
+        segment_id = None
+
+        async def send_task(
+            ws: aiohttp.ClientWebSocketResponse,
+            flush_after_words: int = 30,
+        ):
+            nonlocal closing_ws, segment_id
+            word_count = 0
+            async for word in input_stream:
+                segment_id = word.segment_id
+                speak_msg = {"type": "Speak", "text": f"{word.token} "}
+                await ws.send_str(json.dumps(speak_msg))
+                word_count += 1
+
+                if word_count >= flush_after_words:
+                    flush_msg = {"type": "Flush"}
+                    await ws.send_str(json.dumps(flush_msg))
+                    word_count = 0
+
+            # flush remaining words
+            if word_count > 0:
                 flush_msg = {"type": "Flush"}
                 await ws.send_str(json.dumps(flush_msg))
@@ -292,6 +293,11 @@ async def _run_segments(ws: aiohttp.ClientWebSocketResponse):
 
         async def recv_task(ws: aiohttp.ClientWebSocketResponse):
             last_frame: rtc.AudioFrame | None = None
+            nonlocal closing_ws, segment_id
+            audio_bstream = utils.audio.AudioByteStream(
+                sample_rate=self._opts.sample_rate,
+                num_channels=1,
+            )
 
             def _send_last_frame(*, segment_id: str, is_final: bool) -> None:
                 nonlocal last_frame
@@ -365,8 +371,7 @@ async def _connection_timeout():
             closing_ws = False
 
             tasks = [
-                asyncio.create_task(_tokenize_input()),
-                asyncio.create_task(_run_segments(ws)),
+                asyncio.create_task(send_task(ws)),
                 asyncio.create_task(recv_task(ws)),
             ]
             wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait())
diff --git a/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/tts.py b/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/tts.py
index 0c5490707..8ae840ac2 100644
--- a/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/tts.py
+++ b/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/tts.py
@@ -318,50 +318,17 @@ def __init__(
         conn_options: APIConnectOptions,
         opts: _TTSOptions,
     ):
-        super().__init__(tts=tts, conn_options=conn_options)
+        super().__init__(
+            tts=tts,
+            conn_options=conn_options,
+            tokenizer=opts.word_tokenizer,
+        )
         self._opts, self._session = opts, session
         self._mp3_decoder = utils.codecs.Mp3StreamDecoder()
 
-    async def _run(self) -> None:
-        self._segments_ch = utils.aio.Chan[tokenize.WordStream]()
-
-        @utils.log_exceptions(logger=logger)
-        async def _tokenize_input():
-            """tokenize text from the input_ch to words"""
-            word_stream = None
-            async for input in self._input_ch:
-                if isinstance(input, str):
-                    if word_stream is None:
-                        # new segment (after flush for e.g)
-                        word_stream = self._opts.word_tokenizer.stream()
-                        self._segments_ch.send_nowait(word_stream)
-
-                    word_stream.push_text(input)
-                elif isinstance(input, self._FlushSentinel):
-                    if word_stream is not None:
-                        word_stream.end_input()
-
-                    word_stream = None
-
-            self._segments_ch.close()
-
-        @utils.log_exceptions(logger=logger)
-        async def _run():
-            async for word_stream in self._segments_ch:
-                await self._run_ws(word_stream)
-
-        tasks = [
-            asyncio.create_task(_tokenize_input()),
-            asyncio.create_task(_run()),
-        ]
-        try:
-            await asyncio.gather(*tasks)
-        finally:
-            await utils.aio.gracefully_cancel(*tasks)
-
-    async def _run_ws(
+    async def _run(
         self,
-        word_stream: tokenize.WordStream,
+        input_stream: tokenize.WordStream,
         max_retry: int = 3,
     ) -> None:
         ws_conn: aiohttp.ClientWebSocketResponse | None = None
@@ -406,7 +373,7 @@ async def send_task():
             nonlocal eos_sent
 
             xml_content = []
-            async for data in word_stream:
+            async for data in input_stream:
                 text = data.token
 
                 # send the xml phoneme in one go
diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/tts.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/tts.py
index 3ab494a66..94b46f408 100644
--- a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/tts.py
+++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/tts.py
@@ -161,6 +161,7 @@ def synthesize(
         text: str,
         *,
         conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
+        segment_id: str | None = None,
     ) -> "ChunkedStream":
         return ChunkedStream(
             tts=self,
@@ -168,6 +169,7 @@ def synthesize(
             conn_options=conn_options,
             opts=self._opts,
             client=self._ensure_client(),
+            segment_id=segment_id,
         )
 
 
@@ -180,9 +182,13 @@ def __init__(
         conn_options: APIConnectOptions,
         opts: _TTSOptions,
         client: texttospeech.TextToSpeechAsyncClient,
+        segment_id: str | None = None,
     ) -> None:
         super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
         self._opts, self._client = opts, client
+        self._segment_id = segment_id
+        if self._segment_id is None:
+            self._segment_id = utils.shortuuid()
 
     async def _run(self) -> None:
         request_id = utils.shortuuid()
@@ -204,18 +210,27 @@ async def _run(self) -> None:
                 for frame in decoder.decode_chunk(response.audio_content):
                     for frame in bstream.write(frame.data.tobytes()):
                         self._event_ch.send_nowait(
-                            tts.SynthesizedAudio(request_id=request_id, frame=frame)
+                            tts.SynthesizedAudio(
+                                request_id=request_id,
+                                segment_id=self._segment_id,
+                                frame=frame,
+                            )
                         )
 
                 for frame in bstream.flush():
                     self._event_ch.send_nowait(
-                        tts.SynthesizedAudio(request_id=request_id, frame=frame)
+                        tts.SynthesizedAudio(
+                            request_id=request_id,
+                            segment_id=self._segment_id,
+                            frame=frame,
+                        )
                     )
             else:
                 data = response.audio_content[44:]  # skip WAV header
                 self._event_ch.send_nowait(
                     tts.SynthesizedAudio(
                         request_id=request_id,
+                        segment_id=self._segment_id,
                         frame=rtc.AudioFrame(
                             data=data,
                             sample_rate=self._opts.audio_config.sample_rate_hertz,
diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/tts.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/tts.py
index ce7741eb8..23675a026 100644
--- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/tts.py
+++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/tts.py
@@ -143,6 +143,7 @@ def synthesize(
         text: str,
         *,
         conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
+        segment_id: str | None = None,
     ) -> "ChunkedStream":
         return ChunkedStream(
             tts=self,
@@ -150,6 +151,7 @@ def synthesize(
             conn_options=conn_options,
             opts=self._opts,
             client=self._client,
+            segment_id=segment_id,
         )
 
 
@@ -162,10 +164,14 @@ def __init__(
         conn_options: APIConnectOptions,
         opts: _TTSOptions,
         client: openai.AsyncClient,
+        segment_id: str | None = None,
     ) -> None:
         super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
         self._client = client
         self._opts = opts
+        self._segment_id = segment_id
+        if self._segment_id is None:
+            self._segment_id = utils.shortuuid()
 
     async def _run(self):
         oai_stream = self._client.audio.speech.with_streaming_response.create(
@@ -178,6 +184,7 @@ async def _run(self):
         )
 
         request_id = utils.shortuuid()
+
         audio_bstream = utils.audio.AudioByteStream(
             sample_rate=OPENAI_TTS_SAMPLE_RATE,
             num_channels=OPENAI_TTS_CHANNELS,
@@ -191,6 +198,7 @@ async def _run(self):
                     tts.SynthesizedAudio(
                         frame=frame,
                         request_id=request_id,
+                        segment_id=self._segment_id,
                     )
                 )
 
@@ -199,6 +207,7 @@ async def _run(self):
                 tts.SynthesizedAudio(
                     frame=frame,
                     request_id=request_id,
+                    segment_id=self._segment_id,
                 )
             )
diff --git a/tests/test_tts.py b/tests/test_tts.py
index 91f8035b5..cca1407eb 100644
--- a/tests/test_tts.py
+++ b/tests/test_tts.py
@@ -79,6 +79,11 @@ async def test_synthesize(tts_factory):
         lambda: elevenlabs.TTS(encoding="pcm_44100"), id="elevenlabs.pcm_44100"
     ),
     pytest.param(lambda: cartesia.TTS(), id="cartesia"),
+    pytest.param(lambda: deepgram.TTS(), id="deepgram"),
+]
+
+
+STREAM_ADAPTER_TTS: list[Callable[[], tts.TTS]] = [
     pytest.param(
         lambda: agents.tts.StreamAdapter(
             tts=openai.TTS(), sentence_tokenizer=STREAM_SENT_TOKENIZER
@@ -97,11 +102,50 @@ async def test_synthesize(tts_factory):
         ),
         id="azure.stream",
     ),
-    pytest.param(lambda: deepgram.TTS(), id="deepgram"),
-    pytest.param(lambda: playai.TTS(), id="playai"),
 ]
 
 
+@pytest.mark.usefixtures("job_process")
+@pytest.mark.parametrize("tts_factory", STREAM_ADAPTER_TTS)
+async def test_stream_adapter(tts_factory):
+    tts: agents.tts.TTS = tts_factory()
+
+    synthesize_transcript = make_test_synthesize()
+
+    # Split the transcript into two segments
+    text_segments = [
+        synthesize_transcript[: len(synthesize_transcript) // 2],
+        synthesize_transcript[len(synthesize_transcript) // 2 :],
+    ]
+
+    stream = tts.stream()
+
+    segments = set()
+    for i in range(2):  # Testing 2 segments
+        text = text_segments[i]
+        stream.push_text(text)
+        stream.flush()
+        if i == 1:
+            stream.end_input()
+
+    frames = []
+    is_final = False
+    async for audio in stream:
+        is_final = audio.is_final
+        segments.add(audio.segment_id)
+        frames.append(audio.frame)
+
+    assert is_final, "final audio should be marked as final"
+
+    # Combine the segments for expected text
+    expected_text = "".join(text_segments)
+
+    await _assert_valid_synthesized_audio(frames, tts, expected_text, WER_THRESHOLD)
+
+    assert len(segments) == 2, "should have 2 segments"
+    await stream.aclose()
+
+
 @pytest.mark.usefixtures("job_process")
 @pytest.mark.parametrize("tts_factory", STREAM_TTS)
 async def test_stream(tts_factory):
@@ -109,27 +153,36 @@ async def test_stream(tts_factory):
 
     synthesize_transcript = make_test_synthesize()
 
-    pattern = [1, 2, 4]
-    text = synthesize_transcript
-    chunks = []
-    pattern_iter = iter(pattern * (len(text) // sum(pattern) + 1))
-
-    for chunk_size in pattern_iter:
-        if not text:
-            break
-        chunks.append(text[:chunk_size])
-        text = text[chunk_size:]
+    # Split the transcript into two segments
+    text_segments = [
+        synthesize_transcript[: len(synthesize_transcript) // 2],
+        synthesize_transcript[len(synthesize_transcript) // 2 :],
+    ]
 
     stream = tts.stream()
 
     segments = set()
-    # for i in range(2): # TODO(theomonnom): we should test 2 segments
-    for chunk in chunks:
-        stream.push_text(chunk)
+    for i in range(2):  # Testing 2 segments
+        text = text_segments[i]
+
+        # Generate chunks for the current segment
+        pattern = [1, 2, 4]
+        chunks = []
+        text_remaining = text
+        pattern_iter = iter(pattern * (len(text) // sum(pattern) + 1))
+
+        for chunk_size in pattern_iter:
+            if not text_remaining:
+                break
+            chunks.append(text_remaining[:chunk_size])
+            text_remaining = text_remaining[chunk_size:]
+
+        for chunk in chunks:
+            stream.push_text(chunk)
 
-    stream.flush()
-    # if i == 1:
-    stream.end_input()
+        stream.flush()
+        if i == 1:
+            stream.end_input()
 
     frames = []
     is_final = False
@@ -140,11 +193,12 @@ async def test_stream(tts_factory):
 
     assert is_final, "final audio should be marked as final"
 
-    await _assert_valid_synthesized_audio(
-        frames, tts, synthesize_transcript, WER_THRESHOLD
-    )
+    # Combine the segments for expected text
+    expected_text = "".join(text_segments)
+
+    await _assert_valid_synthesized_audio(frames, tts, expected_text, WER_THRESHOLD)
 
-    # assert len(segments) == 2
+    assert len(segments) == 2, "should have 2 segments"
 
     await stream.aclose()
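
Reviewer note (not part of the patch): the heart of this change is that `SynthesizeStream` now owns input tokenization. The base class takes a word- or sentence-level `tokenizer`, splits pushed text into one token stream per segment (a new segment begins after each flush), and calls the plugin's `_run()` once per segment stream, which is why the per-plugin `_tokenize_input` loops in the elevenlabs and deepgram plugins are deleted here. `TTS.synthesize()` additionally accepts an optional `segment_id` so chunked plugins can tag their frames; when omitted, the `ChunkedStream` implementations fall back to `utils.shortuuid()`. The sketch below illustrates the new plugin contract; it is a minimal illustration, not code from this patch — `MyStream` and `_synthesize_frame` are invented names, and the silent frame stands in for a real provider call.

```python
from livekit import rtc
from livekit.agents import tokenize, tts, utils
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions


class MyStream(tts.SynthesizeStream):
    def __init__(
        self,
        *,
        tts: tts.TTS,
        word_tokenizer: tokenize.WordTokenizer,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> None:
        # the base class owns tokenization now: it turns pushed text into
        # per-segment token streams and calls _run() once per stream
        super().__init__(tts=tts, conn_options=conn_options, tokenizer=word_tokenizer)

    async def _run(self, input_stream: tokenize.WordStream) -> None:
        request_id = utils.shortuuid()
        async for word in input_stream:
            # each token already carries the id of the segment it belongs to
            frame = await self._synthesize_frame(word.token)
            self._event_ch.send_nowait(
                tts.SynthesizedAudio(
                    request_id=request_id,
                    segment_id=word.segment_id,
                    frame=frame,
                )
            )

    async def _synthesize_frame(self, text: str) -> rtc.AudioFrame:
        # hypothetical stand-in for a real provider call: emits 10 ms of
        # silence per token so the sketch stays self-contained
        samples = 240  # 10 ms at 24 kHz
        return rtc.AudioFrame(
            data=b"\x00\x00" * samples,
            sample_rate=24000,
            num_channels=1,
            samples_per_channel=samples,
        )
```

Because segment boundaries are now established centrally and every `SynthesizedAudio` carries a `segment_id`, the updated `test_stream` and `test_stream_adapter` tests can push two segments across a flush and assert `len(segments) == 2` for every streaming plugin, not only the ones that previously implemented their own segmentation.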