Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions livekit-agents/livekit/agents/inference/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@ class FallbackModel(TypedDict, total=False):
Extra fields are passed through to the provider.

Example:
>>> FallbackModel(name="deepgram/nova-3", extra_kwargs={"keywords": ["livekit"]})
>>> FallbackModel(model="deepgram/nova-3", extra_kwargs={"keywords": ["livekit"]})
"""

name: Required[str]
model: Required[str]
"""Model name (e.g. "deepgram/nova-3", "assemblyai/universal-streaming", "cartesia/ink-whisper")."""

extra_kwargs: dict[str, Any]
Expand All @@ -103,7 +103,7 @@ def _normalize_fallback(
def _make_fallback(model: FallbackModelType) -> FallbackModel:
if isinstance(model, str):
name, _ = _parse_model_string(model)
return FallbackModel(name=name)
return FallbackModel(model=name)
return model

if isinstance(fallback, list):
Expand Down Expand Up @@ -521,7 +521,8 @@ async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:

if self._opts.fallback:
models = [
{"name": m.get("name"), "extra": m.get("extra_kwargs")} for m in self._opts.fallback
{"model": m.get("model"), "extra": m.get("extra_kwargs")}
for m in self._opts.fallback
]
params["fallback"] = {"models": models}

Expand Down
14 changes: 9 additions & 5 deletions livekit-agents/livekit/agents/inference/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ class FallbackModel(TypedDict):
Extra fields are passed through to the provider.

Example:
>>> FallbackModel(name="cartesia/sonic", voice="")
>>> FallbackModel(model="cartesia/sonic", voice="")
"""

name: str
model: str
"""Model name (e.g. "cartesia/sonic", "elevenlabs/eleven_flash_v2", "rime/arcana")."""

voice: str
Expand All @@ -87,8 +87,8 @@ def _normalize_fallback(
) -> list[FallbackModel]:
def _make_fallback(model: FallbackModelType) -> FallbackModel:
if isinstance(model, str):
name, voice = _parse_model_string(model)
return FallbackModel(name=name, voice=voice if voice else "")
model_name, voice = _parse_model_string(model)
return FallbackModel(model=model_name, voice=voice if voice else "")
return model

if isinstance(fallback, list):
Expand Down Expand Up @@ -382,7 +382,11 @@ async def _connect_ws(self, timeout: float) -> aiohttp.ClientWebSocketResponse:
params["language"] = self._opts.language
if self._opts.fallback:
models = [
{"name": m.get("name"), "voice": m.get("voice"), "extra": m.get("extra_kwargs", {})}
{
"model": m.get("model"),
"voice": m.get("voice"),
"extra": m.get("extra_kwargs", {}),
}
for m in self._opts.fallback
]
params["fallback"] = {"models": models}
Expand Down
68 changes: 34 additions & 34 deletions tests/test_inference_stt_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,79 +83,79 @@ class TestNormalizeFallback:
def test_single_string_model(self):
"""Single string model becomes a list with one FallbackModel."""
result = _normalize_fallback("deepgram/nova-3")
assert result == [{"name": "deepgram/nova-3"}]
assert result == [{"model": "deepgram/nova-3"}]

def test_single_fallback_model_dict(self):
"""Single FallbackModel dict becomes a list with that dict."""
fallback = FallbackModel(name="deepgram/nova-3")
fallback = FallbackModel(model="deepgram/nova-3")
result = _normalize_fallback(fallback)
assert result == [{"name": "deepgram/nova-3"}]
assert result == [{"model": "deepgram/nova-3"}]

def test_list_of_string_models(self):
"""List of string models becomes list of FallbackModels."""
result = _normalize_fallback(["deepgram/nova-3", "cartesia/ink-whisper"])
assert result == [
{"name": "deepgram/nova-3"},
{"name": "cartesia/ink-whisper"},
{"model": "deepgram/nova-3"},
{"model": "cartesia/ink-whisper"},
]

def test_list_of_fallback_model_dicts(self):
"""List of FallbackModel dicts is preserved."""
fallbacks = [
FallbackModel(name="deepgram/nova-3"),
FallbackModel(name="assemblyai"),
FallbackModel(model="deepgram/nova-3"),
FallbackModel(model="assemblyai"),
]
result = _normalize_fallback(fallbacks)
assert result == [
{"name": "deepgram/nova-3"},
{"name": "assemblyai"},
{"model": "deepgram/nova-3"},
{"model": "assemblyai"},
]

def test_mixed_list_strings_and_dicts(self):
"""Mixed list of strings and FallbackModel dicts."""
fallbacks = [
"deepgram/nova-3",
FallbackModel(name="cartesia/ink-whisper"),
FallbackModel(model="cartesia/ink-whisper"),
"assemblyai",
]
result = _normalize_fallback(fallbacks)
assert result == [
{"name": "deepgram/nova-3"},
{"name": "cartesia/ink-whisper"},
{"name": "assemblyai"},
{"model": "deepgram/nova-3"},
{"model": "cartesia/ink-whisper"},
{"model": "assemblyai"},
]

def test_string_with_language_suffix_discards_language(self):
"""Language suffix in string model is discarded."""
result = _normalize_fallback("deepgram/nova-3:en")
assert result == [{"name": "deepgram/nova-3"}]
assert result == [{"model": "deepgram/nova-3"}]

def test_fallback_model_with_extra_kwargs(self):
"""FallbackModel with extra_kwargs is preserved."""
fallback = FallbackModel(
name="deepgram/nova-3",
model="deepgram/nova-3",
extra_kwargs={"keywords": [("livekit", 1.5)], "punctuate": True},
)
result = _normalize_fallback(fallback)
assert result == [
{
"name": "deepgram/nova-3",
"model": "deepgram/nova-3",
"extra_kwargs": {"keywords": [("livekit", 1.5)], "punctuate": True},
}
]

def test_list_with_extra_kwargs_preserved(self):
"""List with FallbackModels containing extra_kwargs."""
fallbacks = [
FallbackModel(name="deepgram/nova-3", extra_kwargs={"punctuate": True}),
FallbackModel(model="deepgram/nova-3", extra_kwargs={"punctuate": True}),
"cartesia/ink-whisper",
FallbackModel(name="assemblyai", extra_kwargs={"format_turns": True}),
FallbackModel(model="assemblyai", extra_kwargs={"format_turns": True}),
]
result = _normalize_fallback(fallbacks)
assert result == [
{"name": "deepgram/nova-3", "extra_kwargs": {"punctuate": True}},
{"name": "cartesia/ink-whisper"},
{"name": "assemblyai", "extra_kwargs": {"format_turns": True}},
{"model": "deepgram/nova-3", "extra_kwargs": {"punctuate": True}},
{"model": "cartesia/ink-whisper"},
{"model": "assemblyai", "extra_kwargs": {"format_turns": True}},
]

def test_empty_list(self):
Expand All @@ -166,7 +166,7 @@ def test_empty_list(self):
def test_multiple_colons_in_model_string(self):
"""Multiple colons in model string - splits on last, discards language."""
result = _normalize_fallback("some:model:part:fr")
assert result == [{"name": "some:model:part"}]
assert result == [{"model": "some:model:part"}]


class TestSTTConstructorFallbackAndConnectOptions:
Expand All @@ -180,32 +180,32 @@ def test_fallback_not_given(self):
def test_fallback_single_string(self):
"""Single string fallback is normalized to list of FallbackModel."""
stt = _make_stt(fallback="cartesia/ink-whisper")
assert stt._opts.fallback == [{"name": "cartesia/ink-whisper"}]
assert stt._opts.fallback == [{"model": "cartesia/ink-whisper"}]

def test_fallback_list_of_strings(self):
"""List of string fallbacks is normalized."""
stt = _make_stt(fallback=["deepgram/nova-3", "assemblyai"])
assert stt._opts.fallback == [
{"name": "deepgram/nova-3"},
{"name": "assemblyai"},
{"model": "deepgram/nova-3"},
{"model": "assemblyai"},
]

def test_fallback_single_fallback_model(self):
"""Single FallbackModel is normalized to list."""
stt = _make_stt(fallback=FallbackModel(name="deepgram/nova-3"))
assert stt._opts.fallback == [{"name": "deepgram/nova-3"}]
stt = _make_stt(fallback=FallbackModel(model="deepgram/nova-3"))
assert stt._opts.fallback == [{"model": "deepgram/nova-3"}]

def test_fallback_with_extra_kwargs(self):
"""FallbackModel with extra_kwargs is preserved in _opts."""
stt = _make_stt(
fallback=FallbackModel(
name="deepgram/nova-3",
model="deepgram/nova-3",
extra_kwargs={"punctuate": True, "keywords": [("livekit", 1.5)]},
)
)
assert stt._opts.fallback == [
{
"name": "deepgram/nova-3",
"model": "deepgram/nova-3",
"extra_kwargs": {"punctuate": True, "keywords": [("livekit", 1.5)]},
}
]
Expand All @@ -215,20 +215,20 @@ def test_fallback_mixed_list(self):
stt = _make_stt(
fallback=[
"deepgram/nova-3",
FallbackModel(name="cartesia", extra_kwargs={"min_volume": 0.5}),
FallbackModel(model="cartesia", extra_kwargs={"min_volume": 0.5}),
"assemblyai",
]
)
assert stt._opts.fallback == [
{"name": "deepgram/nova-3"},
{"name": "cartesia", "extra_kwargs": {"min_volume": 0.5}},
{"name": "assemblyai"},
{"model": "deepgram/nova-3"},
{"model": "cartesia", "extra_kwargs": {"min_volume": 0.5}},
{"model": "assemblyai"},
]

def test_fallback_string_with_language_discarded(self):
"""Language suffix in fallback string is discarded."""
stt = _make_stt(fallback="deepgram/nova-3:en")
assert stt._opts.fallback == [{"name": "deepgram/nova-3"}]
assert stt._opts.fallback == [{"model": "deepgram/nova-3"}]

def test_connect_options_not_given_uses_default(self):
"""When connect_options is not provided, uses DEFAULT_API_CONNECT_OPTIONS."""
Expand Down
50 changes: 25 additions & 25 deletions tests/test_inference_tts_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,72 +72,72 @@ class TestNormalizeFallback:
def test_single_string_model(self):
"""Single string model becomes a list with one FallbackModel."""
result = _normalize_fallback("cartesia/sonic")
assert result == [{"name": "cartesia/sonic", "voice": ""}]
assert result == [{"model": "cartesia/sonic", "voice": ""}]

def test_single_string_model_with_voice(self):
"""Single string model with voice suffix extracts voice."""
result = _normalize_fallback("cartesia/sonic:my-voice")
assert result == [{"name": "cartesia/sonic", "voice": "my-voice"}]
assert result == [{"model": "cartesia/sonic", "voice": "my-voice"}]

def test_single_fallback_model_dict(self):
"""Single FallbackModel dict becomes a list with that dict."""
fallback = FallbackModel(name="cartesia/sonic", voice="narrator")
fallback = FallbackModel(model="cartesia/sonic", voice="narrator")
result = _normalize_fallback(fallback)
assert result == [{"name": "cartesia/sonic", "voice": "narrator"}]
assert result == [{"model": "cartesia/sonic", "voice": "narrator"}]

def test_list_of_string_models(self):
"""List of string models becomes list of FallbackModels."""
result = _normalize_fallback(["cartesia/sonic", "elevenlabs/eleven_flash_v2"])
assert result == [
{"name": "cartesia/sonic", "voice": ""},
{"name": "elevenlabs/eleven_flash_v2", "voice": ""},
{"model": "cartesia/sonic", "voice": ""},
{"model": "elevenlabs/eleven_flash_v2", "voice": ""},
]

def test_list_of_string_models_with_voices(self):
"""List of string models with voice suffixes."""
result = _normalize_fallback(["cartesia/sonic:voice1", "elevenlabs:voice2"])
assert result == [
{"name": "cartesia/sonic", "voice": "voice1"},
{"name": "elevenlabs", "voice": "voice2"},
{"model": "cartesia/sonic", "voice": "voice1"},
{"model": "elevenlabs", "voice": "voice2"},
]

def test_list_of_fallback_model_dicts(self):
"""List of FallbackModel dicts is preserved."""
fallbacks = [
FallbackModel(name="cartesia/sonic", voice="narrator"),
FallbackModel(name="elevenlabs", voice=""),
FallbackModel(model="cartesia/sonic", voice="narrator"),
FallbackModel(model="elevenlabs", voice=""),
]
result = _normalize_fallback(fallbacks)
assert result == [
{"name": "cartesia/sonic", "voice": "narrator"},
{"name": "elevenlabs", "voice": ""},
{"model": "cartesia/sonic", "voice": "narrator"},
{"model": "elevenlabs", "voice": ""},
]

def test_mixed_list_strings_and_dicts(self):
"""Mixed list of strings and FallbackModel dicts."""
fallbacks = [
"cartesia/sonic:voice1",
FallbackModel(name="elevenlabs/eleven_flash_v2", voice="custom"),
FallbackModel(model="elevenlabs/eleven_flash_v2", voice="custom"),
"rime/mist",
]
result = _normalize_fallback(fallbacks)
assert result == [
{"name": "cartesia/sonic", "voice": "voice1"},
{"name": "elevenlabs/eleven_flash_v2", "voice": "custom"},
{"name": "rime/mist", "voice": ""},
{"model": "cartesia/sonic", "voice": "voice1"},
{"model": "elevenlabs/eleven_flash_v2", "voice": "custom"},
{"model": "rime/mist", "voice": ""},
]

def test_fallback_model_with_extra_kwargs(self):
"""FallbackModel with extra_kwargs is preserved."""
fallback = FallbackModel(
name="cartesia/sonic",
model="cartesia/sonic",
voice="narrator",
extra_kwargs={"duration": 30.0, "speed": "fast"},
)
result = _normalize_fallback(fallback)
assert result == [
{
"name": "cartesia/sonic",
"model": "cartesia/sonic",
"voice": "narrator",
"extra_kwargs": {"duration": 30.0, "speed": "fast"},
}
Expand All @@ -146,15 +146,15 @@ def test_fallback_model_with_extra_kwargs(self):
def test_list_with_extra_kwargs_preserved(self):
"""List with FallbackModels containing extra_kwargs."""
fallbacks = [
FallbackModel(name="cartesia/sonic", voice="v1", extra_kwargs={"speed": "slow"}),
FallbackModel(model="cartesia/sonic", voice="v1", extra_kwargs={"speed": "slow"}),
"elevenlabs:voice2",
FallbackModel(name="rime/mist", voice="", extra_kwargs={"custom": True}),
FallbackModel(model="rime/mist", voice="", extra_kwargs={"custom": True}),
]
result = _normalize_fallback(fallbacks)
assert result == [
{"name": "cartesia/sonic", "voice": "v1", "extra_kwargs": {"speed": "slow"}},
{"name": "elevenlabs", "voice": "voice2"},
{"name": "rime/mist", "voice": "", "extra_kwargs": {"custom": True}},
{"model": "cartesia/sonic", "voice": "v1", "extra_kwargs": {"speed": "slow"}},
{"model": "elevenlabs", "voice": "voice2"},
{"model": "rime/mist", "voice": "", "extra_kwargs": {"custom": True}},
]

def test_empty_list(self):
Expand All @@ -164,6 +164,6 @@ def test_empty_list(self):

def test_fallback_model_with_none_voice(self):
"""FallbackModel with explicit None voice."""
fallback = FallbackModel(name="cartesia/sonic", voice="")
fallback = FallbackModel(model="cartesia/sonic", voice="")
result = _normalize_fallback(fallback)
assert result == [{"name": "cartesia/sonic", "voice": ""}]
assert result == [{"model": "cartesia/sonic", "voice": ""}]
Loading