Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 78 additions & 6 deletions livekit-agents/livekit/agents/inference/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Any, Literal, TypedDict, Union, overload

import aiohttp
from typing_extensions import Required

from livekit import rtc

Expand Down Expand Up @@ -69,6 +70,48 @@ class AssemblyaiOptions(TypedDict, total=False):
STTLanguages = Literal["multi", "en", "de", "es", "fr", "ja", "pt", "zh", "hi"]


class FallbackModel(TypedDict, total=False):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you add some docs for this?

"""A fallback model with optional extra configuration.

Extra fields are passed through to the provider.

Example:
>>> FallbackModel(name="deepgram/nova-3", extra_kwargs={"keywords": ["livekit"]})
"""

name: Required[str]
"""Model name (e.g. "deepgram/nova-3", "assemblyai/universal-streaming", "cartesia/ink-whisper")."""

extra_kwargs: dict[str, Any]
"""Extra configuration for the model."""


FallbackModelType = Union[FallbackModel, str]


def _parse_model_string(model: str) -> tuple[str, NotGivenOr[str]]:
    """Split a ``"model:language"`` string at its last colon.

    Args:
        model (str): Model string, optionally suffixed with ``:language``.

    Returns:
        tuple[str, NotGivenOr[str]]: The model name and the language suffix.
        The language is ``NOT_GIVEN`` when the string contains no colon.
    """
    name, colon, suffix = model.rpartition(":")
    if not colon:
        # No separator present: the whole string is the model name.
        return model, NOT_GIVEN
    return name, suffix


def _normalize_fallback(
    fallback: list[FallbackModelType] | FallbackModelType,
) -> list[FallbackModel]:
    """Coerce a fallback spec into a list of ``FallbackModel`` entries.

    Args:
        fallback: A single model string or ``FallbackModel``, or a list of them.

    Returns:
        list[FallbackModel]: One entry per input item. String items are parsed
        with ``_parse_model_string`` and only the model name is kept;
        non-string items are returned unchanged.
    """
    entries = fallback if isinstance(fallback, list) else [fallback]

    normalized: list[FallbackModel] = []
    for entry in entries:
        if isinstance(entry, str):
            # The language part of the parsed string is not used here.
            model_name = _parse_model_string(entry)[0]
            normalized.append(FallbackModel(name=model_name))
        else:
            normalized.append(entry)
    return normalized


STTModels = Union[
DeepgramModels,
CartesiaModels,
Expand All @@ -77,6 +120,7 @@ class AssemblyaiOptions(TypedDict, total=False):
]
STTEncoding = Literal["pcm_s16le"]


DEFAULT_ENCODING: STTEncoding = "pcm_s16le"
DEFAULT_SAMPLE_RATE: int = 16000
DEFAULT_BASE_URL = "https://agent-gateway.livekit.cloud/v1"
Expand All @@ -92,6 +136,8 @@ class STTOptions:
api_key: str
api_secret: str
extra_kwargs: dict[str, Any]
fallback: NotGivenOr[list[FallbackModel]]
conn_options: NotGivenOr[APIConnectOptions]


class STT(stt.STT):
Expand All @@ -108,6 +154,8 @@ def __init__(
api_secret: NotGivenOr[str] = NOT_GIVEN,
http_session: aiohttp.ClientSession | None = None,
extra_kwargs: NotGivenOr[CartesiaOptions] = NOT_GIVEN,
fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN,
conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN,
) -> None: ...

@overload
Expand All @@ -123,6 +171,8 @@ def __init__(
api_secret: NotGivenOr[str] = NOT_GIVEN,
http_session: aiohttp.ClientSession | None = None,
extra_kwargs: NotGivenOr[DeepgramOptions] = NOT_GIVEN,
fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN,
conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN,
) -> None: ...

@overload
Expand All @@ -138,6 +188,8 @@ def __init__(
api_secret: NotGivenOr[str] = NOT_GIVEN,
http_session: aiohttp.ClientSession | None = None,
extra_kwargs: NotGivenOr[AssemblyaiOptions] = NOT_GIVEN,
fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN,
conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN,
) -> None: ...

@overload
Expand All @@ -153,6 +205,8 @@ def __init__(
api_secret: NotGivenOr[str] = NOT_GIVEN,
http_session: aiohttp.ClientSession | None = None,
extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN,
conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN,
) -> None: ...

def __init__(
Expand All @@ -169,6 +223,8 @@ def __init__(
extra_kwargs: NotGivenOr[
dict[str, Any] | CartesiaOptions | DeepgramOptions | AssemblyaiOptions
] = NOT_GIVEN,
fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN,
conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN,
) -> None:
"""Livekit Cloud Inference STT

Expand All @@ -182,6 +238,9 @@ def __init__(
api_secret (str, optional): LIVEKIT_API_SECRET, if not provided, read from environment variable.
http_session (aiohttp.ClientSession, optional): HTTP session to use.
extra_kwargs (dict, optional): Extra kwargs to pass to the STT model.
fallback (FallbackModelType, optional): Fallback models - a single model name or FallbackModel,
or a list of model names and/or FallbackModel instances.
conn_options (APIConnectOptions, optional): Connection options for request attempts.
"""
super().__init__(
capabilities=stt.STTCapabilities(streaming=True, interim_results=True),
Expand Down Expand Up @@ -212,6 +271,9 @@ def __init__(
raise ValueError(
"api_secret is required, either as argument or set LIVEKIT_API_SECRET environmental variable"
)
fallback_models: NotGivenOr[list[FallbackModel]] = NOT_GIVEN
if is_given(fallback):
fallback_models = _normalize_fallback(fallback) # type: ignore[arg-type]

self._opts = STTOptions(
model=model,
Expand All @@ -222,6 +284,8 @@ def __init__(
api_key=lk_api_key,
api_secret=lk_api_secret,
extra_kwargs=dict(extra_kwargs) if is_given(extra_kwargs) else {},
fallback=fallback_models,
conn_options=conn_options if is_given(conn_options) else DEFAULT_API_CONNECT_OPTIONS,
)

self._session = http_session
Expand All @@ -237,12 +301,8 @@ def from_model_string(cls, model: str) -> STT:
Returns:
STT: STT instance
"""

language: NotGivenOr[str] = NOT_GIVEN
if (idx := model.rfind(":")) != -1:
language = model[idx + 1 :]
model = model[:idx]
return cls(model, language=language)
model_name, language = _parse_model_string(model)
return cls(model=model_name, language=language)

@property
def model(self) -> str:
Expand Down Expand Up @@ -459,6 +519,18 @@ async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
if self._opts.language:
params["settings"]["language"] = self._opts.language

if self._opts.fallback:
models = [
{"name": m.get("name"), "extra": m.get("extra_kwargs")} for m in self._opts.fallback
]
params["fallback"] = {"models": models}

if self._opts.conn_options:
params["connection"] = {
"timeout": self._opts.conn_options.timeout,
"retries": self._opts.conn_options.max_retry,
}

base_url = self._opts.base_url
if base_url.startswith(("http://", "https://")):
base_url = base_url.replace("http", "ws", 1)
Expand Down
100 changes: 92 additions & 8 deletions livekit-agents/livekit/agents/inference/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Any, Literal, TypedDict, Union, overload

import aiohttp
from typing_extensions import NotRequired

from .. import tokenize, tts, utils
from .._exceptions import APIConnectionError, APIError, APIStatusError, APITimeoutError
Expand Down Expand Up @@ -42,6 +43,59 @@
"inworld/inworld-tts-1",
]

TTSModels = Union[CartesiaModels, ElevenlabsModels, RimeModels, InworldModels]


def _parse_model_string(model: str) -> tuple[str, str | None]:
"""Parse a model string into a model and voice
Args:
model (str): Model string to parse
Returns:
tuple[str, str | None]: Model and voice (voice is None if not specified)
"""
voice: str | None = None
if (idx := model.rfind(":")) != -1:
voice = model[idx + 1 :]
model = model[:idx]
return model, voice


class FallbackModel(TypedDict):
"""A fallback model with optional extra configuration.

Extra fields are passed through to the provider.

Example:
>>> FallbackModel(name="cartesia/sonic", voice="")
"""

name: str
"""Model name (e.g. "cartesia/sonic", "elevenlabs/eleven_flash_v2", "rime/arcana")."""

voice: str
"""Voice to use for the model."""

extra_kwargs: NotRequired[dict[str, Any]]
"""Extra configuration for the model."""


FallbackModelType = Union[FallbackModel, str]


def _normalize_fallback(
    fallback: list[FallbackModelType] | FallbackModelType,
) -> list[FallbackModel]:
    """Coerce a fallback spec into a list of ``FallbackModel`` entries.

    Args:
        fallback: A single model string or ``FallbackModel``, or a list of them.

    Returns:
        list[FallbackModel]: One entry per input item. String items are parsed
        with ``_parse_model_string`` into a name and voice, with the voice set
        to ``""`` when the parsed voice is empty or None; non-string items are
        returned unchanged.
    """
    entries = fallback if isinstance(fallback, list) else [fallback]

    normalized: list[FallbackModel] = []
    for entry in entries:
        if isinstance(entry, str):
            model_name, model_voice = _parse_model_string(entry)
            # The "voice" key is always set, so a missing voice becomes "".
            normalized.append(FallbackModel(name=model_name, voice=model_voice or ""))
        else:
            normalized.append(entry)
    return normalized


class CartesiaOptions(TypedDict, total=False):
duration: float # max duration of audio in seconds
Expand All @@ -61,8 +115,6 @@ class InworldOptions(TypedDict, total=False):
pass


TTSModels = Union[CartesiaModels, ElevenlabsModels, RimeModels, InworldModels]

TTSEncoding = Literal["pcm_s16le"]

DEFAULT_ENCODING: TTSEncoding = "pcm_s16le"
Expand All @@ -81,6 +133,8 @@ class _TTSOptions:
api_key: str
api_secret: str
extra_kwargs: dict[str, Any]
fallback: NotGivenOr[list[FallbackModel]]
conn_options: NotGivenOr[APIConnectOptions]


class TTS(tts.TTS):
Expand All @@ -98,6 +152,8 @@ def __init__(
api_secret: NotGivenOr[str] = NOT_GIVEN,
http_session: aiohttp.ClientSession | None = None,
extra_kwargs: NotGivenOr[CartesiaOptions] = NOT_GIVEN,
fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN,
conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN,
) -> None:
pass

Expand All @@ -115,6 +171,8 @@ def __init__(
api_secret: NotGivenOr[str] = NOT_GIVEN,
http_session: aiohttp.ClientSession | None = None,
extra_kwargs: NotGivenOr[ElevenlabsOptions] = NOT_GIVEN,
fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN,
conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN,
) -> None:
pass

Expand All @@ -132,6 +190,8 @@ def __init__(
api_secret: NotGivenOr[str] = NOT_GIVEN,
http_session: aiohttp.ClientSession | None = None,
extra_kwargs: NotGivenOr[RimeOptions] = NOT_GIVEN,
fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN,
conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN,
) -> None:
pass

Expand All @@ -149,6 +209,8 @@ def __init__(
api_secret: NotGivenOr[str] = NOT_GIVEN,
http_session: aiohttp.ClientSession | None = None,
extra_kwargs: NotGivenOr[InworldOptions] = NOT_GIVEN,
fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN,
conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN,
) -> None:
pass

Expand All @@ -166,6 +228,8 @@ def __init__(
api_secret: NotGivenOr[str] = NOT_GIVEN,
http_session: aiohttp.ClientSession | None = None,
extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN,
conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN,
) -> None:
pass

Expand All @@ -184,6 +248,8 @@ def __init__(
extra_kwargs: NotGivenOr[
dict[str, Any] | CartesiaOptions | ElevenlabsOptions | RimeOptions | InworldOptions
] = NOT_GIVEN,
fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN,
conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN,
) -> None:
"""Livekit Cloud Inference TTS

Expand All @@ -198,6 +264,9 @@ def __init__(
api_secret (str, optional): LIVEKIT_API_SECRET, if not provided, read from environment variable.
http_session (aiohttp.ClientSession, optional): HTTP session to use.
extra_kwargs (dict, optional): Extra kwargs to pass to the TTS model.
fallback (FallbackModelType, optional): Fallback models - a single model name or FallbackModel,
or a list of model names and/or FallbackModel instances.
conn_options (APIConnectOptions, optional): Connection options for request attempts.
"""
sample_rate = sample_rate if is_given(sample_rate) else DEFAULT_SAMPLE_RATE
super().__init__(
Expand Down Expand Up @@ -232,6 +301,10 @@ def __init__(
"api_secret is required, either as argument or set LIVEKIT_API_SECRET environmental variable"
)

fallback_models: NotGivenOr[list[FallbackModel]] = NOT_GIVEN
if is_given(fallback):
fallback_models = _normalize_fallback(fallback) # type: ignore[arg-type]

self._opts = _TTSOptions(
model=model,
voice=voice,
Expand All @@ -242,6 +315,8 @@ def __init__(
api_key=lk_api_key,
api_secret=lk_api_secret,
extra_kwargs=dict(extra_kwargs) if is_given(extra_kwargs) else {},
fallback=fallback_models,
conn_options=conn_options if is_given(conn_options) else DEFAULT_API_CONNECT_OPTIONS,
)
self._session = http_session
self._pool = utils.ConnectionPool[aiohttp.ClientWebSocketResponse](
Expand All @@ -262,11 +337,8 @@ def from_model_string(cls, model: str) -> TTS:
Returns:
TTS: TTS instance
"""
voice: NotGivenOr[str] = NOT_GIVEN
if (idx := model.rfind(":")) != -1:
voice = model[idx + 1 :]
model = model[:idx]
return cls(model, voice=voice)
model, voice = _parse_model_string(model)
return cls(model=model, voice=voice if voice else NOT_GIVEN)

@property
def model(self) -> str:
Expand Down Expand Up @@ -295,7 +367,7 @@ async def _connect_ws(self, timeout: float) -> aiohttp.ClientWebSocketResponse:
raise APIStatusError("LiveKit TTS quota exceeded", status_code=e.status) from e
raise APIConnectionError("failed to connect to LiveKit TTS") from e

params = {
params: dict[str, Any] = {
"type": "session.create",
"sample_rate": str(self._opts.sample_rate),
"encoding": self._opts.encoding,
Expand All @@ -308,6 +380,18 @@ async def _connect_ws(self, timeout: float) -> aiohttp.ClientWebSocketResponse:
params["model"] = self._opts.model
if self._opts.language:
params["language"] = self._opts.language
if self._opts.fallback:
models = [
{"name": m.get("name"), "voice": m.get("voice"), "extra": m.get("extra_kwargs", {})}
for m in self._opts.fallback
]
params["fallback"] = {"models": models}

if self._opts.conn_options:
params["connection"] = {
"timeout": self._opts.conn_options.timeout,
"retries": self._opts.conn_options.max_retry,
}

try:
await ws.send_str(json.dumps(params))
Expand Down
Loading