Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,15 @@ def __init__(
language_model_name: NotGivenOr[str] = NOT_GIVEN,
credentials: NotGivenOr[Credentials] = NOT_GIVEN,
):
super().__init__(capabilities=stt.STTCapabilities(streaming=True, interim_results=True))
# Enable diarization capability if show_speaker_label is enabled
diarization_enabled = is_given(show_speaker_label) and show_speaker_label
super().__init__(
capabilities=stt.STTCapabilities(
streaming=True,
interim_results=True,
diarization=diarization_enabled,
)
)

if not _AWS_SDK_AVAILABLE:
raise ImportError(
Expand Down Expand Up @@ -317,13 +325,20 @@ def _process_transcript_event(self, transcript_event: TranscriptEvent) -> None:

def _streaming_recognize_response_to_speech_data(self, resp: Result) -> stt.SpeechData:
    """Convert one streaming transcription result into ``stt.SpeechData``.

    Args:
        resp: A single result from the transcript event. Its first
            alternative (if any) supplies the transcript text and the
            per-item confidence / speaker information.

    Returns:
        ``stt.SpeechData`` where:
        - ``language`` is ``resp.language_code``, falling back to the
          configured ``self._opts.language``;
        - ``start_time`` / ``end_time`` default to ``0.0`` when missing;
        - ``text`` is the first alternative's transcript, or ``""`` when
          there are no alternatives;
        - ``confidence`` is the first item's confidence, or ``0.0``;
        - ``speaker_id`` is the first speaker label found among the items
          (e.g. "spk_0", "spk_1"), or ``None`` when no item is labelled.
    """
    alternatives = resp.alternatives
    confidence = 0.0
    speaker_id = None

    if alternatives and (items := alternatives[0].items):
        confidence = items[0].confidence or 0.0
        # The leading item may carry no speaker label even when later items
        # do, so use the first labelled item instead of only items[0].
        # When items[0] is labelled this yields exactly the same value.
        speaker_id = next((item.speaker for item in items if item.speaker), None)

    return stt.SpeechData(
        language=resp.language_code or self._opts.language,
        start_time=resp.start_time if resp.start_time is not None else 0.0,
        end_time=resp.end_time if resp.end_time is not None else 0.0,
        text=alternatives[0].transcript if alternatives else "",
        confidence=confidence,
        speaker_id=speaker_id,
    )
Loading