Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 26 additions & 10 deletions tycode-core/src/voice/stt/aws_transcribe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use aws_sdk_transcribestreaming::{
use tokio::sync::mpsc;

use super::provider::{AudioSink, SpeechToText, TranscriptionStream};
use super::types::{Speaker, TranscriptionChunk};
use super::types::{Speaker, TranscriptionChunk, TranscriptionError};
use crate::voice::audio::AudioProfile;

/// Configuration for AWS Transcribe streaming
Expand Down Expand Up @@ -86,7 +86,7 @@ impl SpeechToText for AwsTranscribe {
}

async fn start(&self) -> Result<(AudioSink, TranscriptionStream)> {
let (result_tx, result_rx) = mpsc::channel::<TranscriptionChunk>(100);
let (result_tx, result_rx) = mpsc::channel::<Result<TranscriptionChunk, TranscriptionError>>(100);
let (audio_tx, mut audio_rx) = mpsc::channel::<Vec<u8>>(100);

let language_code = Self::parse_language_code(&self.config.language_code);
Expand Down Expand Up @@ -116,20 +116,36 @@ impl SpeechToText for AwsTranscribe {
let output = match response {
Ok(output) => output,
Err(e) => {
tracing::error!("Failed to start AWS Transcribe stream: {e:?}");
let error = TranscriptionError::StartupFailed {
message: format!("{e:?}"),
};
let _ = result_tx.send(Err(error)).await;
return;
}
};

let mut transcript_stream = output.transcript_result_stream;

while let Ok(Some(event)) = transcript_stream.recv().await {
let TranscriptResultStream::TranscriptEvent(transcript_event) = event else {
continue;
};

for chunk in extract_chunks(transcript_event) {
if result_tx.send(chunk).await.is_err() {
loop {
match transcript_stream.recv().await {
Ok(Some(event)) => {
let TranscriptResultStream::TranscriptEvent(transcript_event) = event
else {
continue;
};

for chunk in extract_chunks(transcript_event) {
if result_tx.send(Ok(chunk)).await.is_err() {
return;
}
}
}
Ok(None) => break,
Err(e) => {
let error = TranscriptionError::StreamError {
message: format!("{e:?}"),
};
let _ = result_tx.send(Err(error)).await;
return;
}
}
Expand Down
10 changes: 5 additions & 5 deletions tycode-core/src/voice/stt/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use anyhow::{Context, Result};
use async_trait::async_trait;
use tokio::sync::mpsc;

use super::types::TranscriptionChunk;
use super::types::{TranscriptionChunk, TranscriptionError};
use crate::voice::audio::AudioProfile;

/// Trait for speech-to-text providers
Expand Down Expand Up @@ -41,17 +41,17 @@ impl AudioSink {

/// Handle for receiving transcription results
pub struct TranscriptionStream {
receiver: mpsc::Receiver<TranscriptionChunk>,
receiver: mpsc::Receiver<Result<TranscriptionChunk, TranscriptionError>>,
}

impl TranscriptionStream {
pub fn new(receiver: mpsc::Receiver<TranscriptionChunk>) -> Self {
pub fn new(receiver: mpsc::Receiver<Result<TranscriptionChunk, TranscriptionError>>) -> Self {
Self { receiver }
}

/// Receive the next transcription chunk
/// Receive the next transcription result
/// Returns None when the stream ends
pub async fn recv(&mut self) -> Option<TranscriptionChunk> {
pub async fn recv(&mut self) -> Option<Result<TranscriptionChunk, TranscriptionError>> {
self.receiver.recv().await
}
}
21 changes: 21 additions & 0 deletions tycode-core/src/voice/stt/types.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,25 @@
use serde::{Deserialize, Serialize};
use std::fmt;

/// Errors that can occur during transcription.
///
/// Every variant carries a free-form `message` describing the underlying
/// failure. Two errors compare equal when they are the same variant and
/// hold the same `message`, which lets callers and tests use `==` and
/// `assert_eq!` on received errors.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TranscriptionError {
    /// AWS Transcribe failed to start streaming.
    StartupFailed { message: String },
    /// The stream reported an error while transcription was in progress.
    StreamError { message: String },
}

impl fmt::Display for TranscriptionError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::StartupFailed { message } => write!(f, "Transcription startup failed: {message}"),
Self::StreamError { message } => write!(f, "Transcription stream error: {message}"),
}
}
}

// Marks `TranscriptionError` as a standard error type. The body is empty, so
// every `Error` method keeps its default behavior (e.g. `source()` yields
// `None`); the required `Debug` and `Display` impls are provided above.
impl std::error::Error for TranscriptionError {}

/// A chunk of transcribed text
#[derive(Debug, Clone, Serialize, Deserialize)]
Expand Down
12 changes: 10 additions & 2 deletions tycode-core/tests/voice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ async fn test_aws_transcribe_from_file() {
while tokio::time::Instant::now() < deadline {
match tokio::time::timeout(tokio::time::Duration::from_secs(5), transcriptions.recv()).await
{
Ok(Some(chunk)) => {
Ok(Some(Ok(chunk))) => {
println!(
"Received: {} (partial: {}, speaker: {:?})",
chunk.text, chunk.is_partial, chunk.speaker
Expand All @@ -142,6 +142,10 @@ async fn test_aws_transcribe_from_file() {
results.push(chunk.text);
}
}
Ok(Some(Err(e))) => {
println!("Transcription error: {}", e);
break;
}
Ok(None) => break,
Err(_) => break,
}
Expand Down Expand Up @@ -235,7 +239,7 @@ async fn test_live_microphone() {
}
transcription = transcriptions.recv() => {
match transcription {
Some(chunk) => {
Some(Ok(chunk)) => {
transcriptions_received += 1;
if chunk.is_partial {
print!("\r[partial] {}", chunk.text);
Expand All @@ -245,6 +249,10 @@ async fn test_live_microphone() {
println!("\n[final] {}", chunk.text);
}
}
Some(Err(e)) => {
println!("[error] Transcription error: {}", e);
break;
}
None => {
println!("[debug] Transcription stream ended");
break;
Expand Down