Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@
"meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
"mistralai/Devstral-Small-2505",
"Qwen/Qwen3-14B-FP8",
"Qwen/Qwen3-14B"
"Qwen/Qwen3-14B",
]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import aiohttp
from pydantic import BaseModel

from livekit import rtc
from livekit.agents import (
APIConnectionError,
APIConnectOptions,
Expand All @@ -37,7 +38,7 @@
utils,
)
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, NOT_GIVEN, NotGivenOr
from livekit.agents.utils import AudioBuffer, rtc
from livekit.agents.utils import AudioBuffer
from livekit.agents.utils.misc import is_given

from .log import logger
Expand Down Expand Up @@ -199,10 +200,10 @@ async def _recognize_impl(
self,
buffer: AudioBuffer,
*,
language: NotGivenOr[str] = NOT_GIVEN,
language: str | None = None,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> stt.SpeechEvent:
language = self._opts.language if isinstance(language, type(NOT_GIVEN)) else language
language = language if language is not None else self._opts.language
wav_bytes = rtc.combine_audio_frames(buffer).to_wav_bytes()

audio_b64 = base64.b64encode(wav_bytes).decode("utf-8")
Expand Down Expand Up @@ -233,13 +234,15 @@ async def _recognize_impl(
)

response_json = await res.json()
timestamps = response_json.get("timestamps", [])
transcription = response_json.get("transcription", [])

detected_language = response_json["info"]["language"]

start_time = response_json["timestamps"][0][0]
end_time = response_json["timestamps"][-1][1]
start_time = timestamps[0][0] if timestamps else 0.0
end_time = timestamps[-1][1] if timestamps else 0.0
request_id = response_json.get("request_id", "")
text = "".join(response_json["transcription"])
text = "".join(transcription)

alternatives = [
stt.SpeechData(
Expand Down Expand Up @@ -275,10 +278,9 @@ def stream(
) -> "SpeechStream":
"""Create a streaming transcription session."""
opts_language = language if is_given(language) else self._opts.language
opts_model = model if is_given(model) else self._model

# Create options for the stream
stream_opts = SimplismartSTTOptions(language=opts_language, model=opts_model)
stream_opts = SimplismartSTTOptions(language=opts_language)

# Create a fresh session for this stream to avoid conflicts
stream_session = aiohttp.ClientSession()
Expand Down Expand Up @@ -316,7 +318,7 @@ def __init__(
self._api_key = api_key
self._session = http_session
self._reconnect_event = asyncio.Event()
self._request_id = id(self)
self._request_id = str(id(self))
self.ws_url = stt._base_url

async def _run(self) -> None:
Expand Down Expand Up @@ -355,7 +357,12 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
return

# this will trigger a reconnection, see the _run loop
raise APIStatusError(message="simplismart connection closed unexpectedly")
raise APIStatusError(
message="simplismart connection closed unexpectedly",
status_code=-1,
request_id=self._request_id,
body=None,
)

if msg.type != aiohttp.WSMsgType.BINARY:
logger.warning("unexpected simplismart message type %s", msg.type)
Expand Down Expand Up @@ -414,6 +421,11 @@ async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
raise APIConnectionError("failed to connect to simplismart") from e
return ws

async def aclose(self) -> None:
    """Close the stream, then close the HTTP session stored on it.

    The base-class ``aclose()`` runs first. Afterwards the session held in
    ``self._session`` is closed if it is set and not already closed. The
    session cleanup sits in a ``finally`` clause so it still runs when the
    base-class shutdown raises or is cancelled; otherwise the session
    would be left open.
    """
    try:
        await super().aclose()
    finally:
        # NOTE(review): assumes this stream owns ``self._session`` (it appears
        # to be created per stream by the caller) - confirm before relying on
        # this to close a session that is shared with other objects.
        if self._session and not self._session.closed:
            await self._session.close()

async def _send_initial_config(self, ws: aiohttp.ClientWebSocketResponse) -> None:
"""Send initial configuration message with language for Simplismart models."""
try:
Expand All @@ -438,7 +450,7 @@ def _handle_transcript_data(self, data: str) -> None:

try:
# Create usage event with proper metrics extraction
metrics = {}
metrics: dict[str, float] = {}
request_data = {
"original_id": request_id,
"processing_latency": metrics.get("processing_latency", 0.0),
Expand All @@ -454,7 +466,7 @@ def _handle_transcript_data(self, data: str) -> None:

# Create speech data
speech_data = stt.SpeechData(
language=self._opts.language,
language=self._opts.language or "en",
text=transcript_text,
)

Expand Down