diff --git a/tycode-core/src/voice/stt/aws_transcribe.rs b/tycode-core/src/voice/stt/aws_transcribe.rs index 6422e36..076298f 100644 --- a/tycode-core/src/voice/stt/aws_transcribe.rs +++ b/tycode-core/src/voice/stt/aws_transcribe.rs @@ -14,7 +14,7 @@ use aws_sdk_transcribestreaming::{ use tokio::sync::mpsc; use super::provider::{AudioSink, SpeechToText, TranscriptionStream}; -use super::types::{Speaker, TranscriptionChunk}; +use super::types::{Speaker, TranscriptionChunk, TranscriptionError}; use crate::voice::audio::AudioProfile; /// Configuration for AWS Transcribe streaming @@ -86,7 +86,7 @@ impl SpeechToText for AwsTranscribe { } async fn start(&self) -> Result<(AudioSink, TranscriptionStream)> { - let (result_tx, result_rx) = mpsc::channel::(100); + let (result_tx, result_rx) = mpsc::channel::>(100); let (audio_tx, mut audio_rx) = mpsc::channel::>(100); let language_code = Self::parse_language_code(&self.config.language_code); @@ -116,20 +116,36 @@ impl SpeechToText for AwsTranscribe { let output = match response { Ok(output) => output, Err(e) => { - tracing::error!("Failed to start AWS Transcribe stream: {e:?}"); + let error = TranscriptionError::StartupFailed { + message: format!("{e:?}"), + }; + let _ = result_tx.send(Err(error)).await; return; } }; let mut transcript_stream = output.transcript_result_stream; - while let Ok(Some(event)) = transcript_stream.recv().await { - let TranscriptResultStream::TranscriptEvent(transcript_event) = event else { - continue; - }; - - for chunk in extract_chunks(transcript_event) { - if result_tx.send(chunk).await.is_err() { + loop { + match transcript_stream.recv().await { + Ok(Some(event)) => { + let TranscriptResultStream::TranscriptEvent(transcript_event) = event + else { + continue; + }; + + for chunk in extract_chunks(transcript_event) { + if result_tx.send(Ok(chunk)).await.is_err() { + return; + } + } + } + Ok(None) => break, + Err(e) => { + let error = TranscriptionError::StreamError { + message: format!("{e:?}"), + }; + let _ = result_tx.send(Err(error)).await; return; } } diff --git a/tycode-core/src/voice/stt/provider.rs b/tycode-core/src/voice/stt/provider.rs index aae83fb..f025336 100644 --- a/tycode-core/src/voice/stt/provider.rs +++ b/tycode-core/src/voice/stt/provider.rs @@ -2,7 +2,7 @@ use anyhow::{Context, Result}; use async_trait::async_trait; use tokio::sync::mpsc; -use super::types::TranscriptionChunk; +use super::types::{TranscriptionChunk, TranscriptionError}; use crate::voice::audio::AudioProfile; /// Trait for speech-to-text providers @@ -41,17 +41,17 @@ impl AudioSink { /// Handle for receiving transcription results pub struct TranscriptionStream { - receiver: mpsc::Receiver, + receiver: mpsc::Receiver>, } impl TranscriptionStream { - pub fn new(receiver: mpsc::Receiver) -> Self { + pub fn new(receiver: mpsc::Receiver>) -> Self { Self { receiver } } - /// Receive the next transcription chunk + /// Receive the next transcription result /// Returns None when the stream ends - pub async fn recv(&mut self) -> Option { + pub async fn recv(&mut self) -> Option> { self.receiver.recv().await } } diff --git a/tycode-core/src/voice/stt/types.rs b/tycode-core/src/voice/stt/types.rs index 5b7620e..f448ed0 100644 --- a/tycode-core/src/voice/stt/types.rs +++ b/tycode-core/src/voice/stt/types.rs @@ -1,4 +1,25 @@ use serde::{Deserialize, Serialize}; +use std::fmt; + +/// Errors that can occur during transcription +#[derive(Debug, Clone)] +pub enum TranscriptionError { + /// AWS Transcribe failed to start streaming + StartupFailed { message: String }, + /// Stream error during transcription + StreamError { message: String }, +} + +impl fmt::Display for TranscriptionError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::StartupFailed { message } => write!(f, "Transcription startup failed: {message}"), + Self::StreamError { message } => write!(f, "Transcription stream error: {message}"), + } + } +} + +impl std::error::Error for TranscriptionError {} /// A chunk of transcribed text #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/tycode-core/tests/voice.rs b/tycode-core/tests/voice.rs index 6ea83f2..e88a346 100644 --- a/tycode-core/tests/voice.rs +++ b/tycode-core/tests/voice.rs @@ -133,7 +133,7 @@ async fn test_aws_transcribe_from_file() { while tokio::time::Instant::now() < deadline { match tokio::time::timeout(tokio::time::Duration::from_secs(5), transcriptions.recv()).await { - Ok(Some(chunk)) => { + Ok(Some(Ok(chunk))) => { println!( "Received: {} (partial: {}, speaker: {:?})", chunk.text, chunk.is_partial, chunk.speaker @@ -142,6 +142,10 @@ async fn test_aws_transcribe_from_file() { results.push(chunk.text); } } + Ok(Some(Err(e))) => { + println!("Transcription error: {}", e); + break; + } Ok(None) => break, Err(_) => break, } @@ -235,7 +239,7 @@ async fn test_live_microphone() { } transcription = transcriptions.recv() => { match transcription { - Some(chunk) => { + Some(Ok(chunk)) => { transcriptions_received += 1; if chunk.is_partial { print!("\r[partial] {}", chunk.text); @@ -245,6 +249,10 @@ async fn test_live_microphone() { println!("\n[final] {}", chunk.text); } } + Some(Err(e)) => { + println!("[error] Transcription error: {}", e); + break; + } None => { println!("[debug] Transcription stream ended"); break;