Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 4 additions & 24 deletions crates/owhisper-client/src/adapter/assemblyai/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use owhisper_interface::batch::{
use serde::{Deserialize, Serialize};

use super::AssemblyAIAdapter;
use crate::adapter::http::ensure_success;
use crate::adapter::{BatchFuture, BatchSttAdapter, ClientWithMiddleware};
use crate::error::Error;
use crate::polling::{PollingConfig, PollingResult, poll_until};
Expand Down Expand Up @@ -134,14 +135,7 @@ impl AssemblyAIAdapter {
.send()
.await?;

let upload_status = upload_response.status();
if !upload_status.is_success() {
return Err(Error::UnexpectedStatus {
status: upload_status,
body: upload_response.text().await.unwrap_or_default(),
});
}

let upload_response = ensure_success(upload_response).await?;
let upload_result: UploadResponse = upload_response.json().await?;

let language_code = params
Expand Down Expand Up @@ -172,14 +166,7 @@ impl AssemblyAIAdapter {
.send()
.await?;

let create_status = create_response.status();
if !create_status.is_success() {
return Err(Error::UnexpectedStatus {
status: create_status,
body: create_response.text().await.unwrap_or_default(),
});
}

let create_response = ensure_success(create_response).await?;
let create_result: TranscriptResponse = create_response.json().await?;
let transcript_id = create_result.id;

Expand All @@ -197,14 +184,7 @@ impl AssemblyAIAdapter {
.send()
.await?;

let poll_status = poll_response.status();
if !poll_status.is_success() {
return Err(Error::UnexpectedStatus {
status: poll_status,
body: poll_response.text().await.unwrap_or_default(),
});
}

let poll_response = ensure_success(poll_response).await?;
let result: TranscriptResponse = poll_response.json().await?;

match result.status.as_str() {
Expand Down
99 changes: 99 additions & 0 deletions crates/owhisper-client/src/adapter/audio.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
use std::path::PathBuf;

use hypr_audio_utils::{Source, f32_to_i16_bytes, resample_audio, source_from_path};

use crate::error::Error;

/// Sample rate handed to `resample_audio` and reported back to callers of
/// `decode_audio_to_linear16` alongside the encoded bytes.
const TARGET_SAMPLE_RATE: u32 = 16000;

pub async fn decode_audio_to_linear16(path: PathBuf) -> Result<(bytes::Bytes, u32), Error> {
tokio::task::spawn_blocking(move || -> Result<(bytes::Bytes, u32), Error> {
let decoder =
source_from_path(&path).map_err(|err| Error::AudioProcessing(err.to_string()))?;

let channels = decoder.channels().max(1);

let samples = resample_audio(decoder, TARGET_SAMPLE_RATE)
.map_err(|err| Error::AudioProcessing(err.to_string()))?;

let samples = mix_to_mono(samples, channels);

if samples.is_empty() {
return Err(Error::AudioProcessing(
"audio file contains no samples".to_string(),
));
}

let bytes = f32_to_i16_bytes(samples.into_iter());

Ok((bytes, TARGET_SAMPLE_RATE))
})
.await?
}

/// Decodes the audio file at `path` and returns only the encoded bytes.
///
/// Thin wrapper around `decode_audio_to_linear16` that discards the sample
/// rate; errors are passed through unchanged.
pub async fn decode_audio_to_bytes(path: PathBuf) -> Result<bytes::Bytes, Error> {
    decode_audio_to_linear16(path)
        .await
        .map(|(pcm, _sample_rate)| pcm)
}

/// Averages interleaved multi-channel samples down to a single channel.
///
/// `samples` is consumed in consecutive frames of `channels` values, and each
/// frame is replaced by the arithmetic mean of its values. If the input length
/// is not a multiple of `channels`, the trailing partial frame is averaged over
/// the values it actually contains.
///
/// A channel count of `0` or `1` returns `samples` unchanged. Treating `0` like
/// mono mirrors the `.max(1)` clamp applied by the decoder in this file and
/// avoids the panic that a zero chunk size would otherwise trigger.
fn mix_to_mono(samples: Vec<f32>, channels: u16) -> Vec<f32> {
    if channels <= 1 {
        return samples;
    }

    // `chunks` never yields an empty slice, so dividing by `frame.len()` is safe,
    // and its exact size hint lets `collect` allocate the output once.
    samples
        .chunks(usize::from(channels))
        .map(|frame| frame.iter().copied().sum::<f32>() / frame.len() as f32)
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Path of the bundled English sample used by the decoding tests.
    fn sample_path() -> PathBuf {
        PathBuf::from(hypr_data::english_1::AUDIO_PATH)
    }

    // Decoding the sample yields a non-empty buffer at the 16 kHz target rate.
    #[tokio::test]
    async fn test_decode_audio_to_linear16() {
        let decoded = decode_audio_to_linear16(sample_path()).await;
        assert!(decoded.is_ok());

        let (pcm, rate) = decoded.unwrap();
        assert!(!pcm.is_empty());
        assert_eq!(rate, 16000);
    }

    // The bytes-only wrapper also yields a non-empty buffer.
    #[tokio::test]
    async fn test_decode_audio_to_bytes() {
        let decoded = decode_audio_to_bytes(sample_path()).await;
        assert!(decoded.is_ok());

        let pcm = decoded.unwrap();
        assert!(!pcm.is_empty());
    }

    // Mono input is returned unchanged.
    #[test]
    fn test_mix_to_mono_single_channel() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let mixed = mix_to_mono(input.clone(), 1);
        assert_eq!(mixed, input);
    }

    // Each stereo frame collapses to the mean of its two samples.
    #[test]
    fn test_mix_to_mono_stereo() {
        let interleaved = vec![1.0, 3.0, 2.0, 4.0];
        let mixed = mix_to_mono(interleaved, 2);
        assert_eq!(mixed, vec![2.0, 3.0]);
    }

    // Empty input produces empty output.
    #[test]
    fn test_mix_to_mono_empty() {
        let nothing: Vec<f32> = Vec::new();
        let mixed = mix_to_mono(nothing, 2);
        assert!(mixed.is_empty());
    }
}
73 changes: 73 additions & 0 deletions crates/owhisper-client/src/adapter/http.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use reqwest::Response;
use serde::de::DeserializeOwned;

use crate::error::Error;

pub async fn ensure_success(response: Response) -> Result<Response, Error> {
let status = response.status();
if status.is_success() {
Ok(response)
} else {
let body = response.text().await.unwrap_or_default();
Err(Error::UnexpectedStatus { status, body })
}
}

pub async fn parse_json_response<T: DeserializeOwned>(
response: Response,
provider: &str,
) -> Result<T, Error> {
let response = ensure_success(response).await?;
let text = response.text().await?;

match serde_json::from_str(&text) {
Ok(v) => Ok(v),
Err(e) => {
tracing::warn!(
error = ?e,
%provider,
body = %text,
"stt_json_parse_failed"
);
Err(Error::AudioProcessing(format!(
"JSON parse error for {}: {}",
provider, e
)))
}
}
}

pub fn parse_provider_json<T: DeserializeOwned>(raw: &str, provider: &str) -> Option<T> {
match serde_json::from_str(raw) {
Ok(v) => Some(v),
Err(e) => {
tracing::warn!(
error = ?e,
%provider,
raw,
"stt_json_parse_failed"
);
None
}
}
}

#[cfg(test)]
mod tests {
    use super::*;

    // Well-formed JSON is parsed and its fields are accessible.
    #[test]
    fn test_parse_provider_json_success() {
        let parsed: Option<serde_json::Value> =
            parse_provider_json(r#"{"key": "value"}"#, "test");
        assert!(parsed.is_some());
        assert_eq!(parsed.unwrap()["key"], "value");
    }

    // Malformed input yields `None` instead of an error.
    #[test]
    fn test_parse_provider_json_failure() {
        let parsed: Option<serde_json::Value> = parse_provider_json("invalid json", "test");
        assert!(parsed.is_none());
    }
}
3 changes: 3 additions & 0 deletions crates/owhisper-client/src/adapter/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
mod argmax;
mod assemblyai;
#[cfg(feature = "argmax")]
pub mod audio;
mod deepgram;
mod deepgram_compat;
mod fireworks;
mod gladia;
pub mod http;
mod openai;
mod owhisper;
pub mod parsing;
Expand Down
128 changes: 127 additions & 1 deletion crates/owhisper-client/src/adapter/parsing.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use owhisper_interface::stream::Word;
use owhisper_interface::stream::{Alternatives, Channel, Metadata, StreamResponse, Word};

pub fn ms_to_secs(ms: u64) -> f64 {
ms as f64 / 1000.0
Expand Down Expand Up @@ -34,6 +34,132 @@ pub fn calculate_time_span<T: HasTimeSpan>(words: &[T]) -> (f64, f64) {
}
}

/// Assembles a `StreamResponse::TranscriptResponse` holding a single alternative.
///
/// `start` and `duration` are taken from `calculate_time_span` over `words`.
/// The alternative's confidence is fixed at `1.0` and the metadata is
/// `Metadata::default()`; every other field is copied from the arguments.
pub fn build_transcript_response(
    transcript: String,
    words: Vec<Word>,
    is_final: bool,
    speech_final: bool,
    from_finalize: bool,
    languages: Vec<String>,
    channel_index: Vec<i32>,
) -> StreamResponse {
    // Measure the span before `words` is moved into the alternative.
    let (start, duration) = calculate_time_span(&words);

    let alternative = Alternatives {
        transcript,
        words,
        confidence: 1.0,
        languages,
    };

    StreamResponse::TranscriptResponse {
        is_final,
        speech_final,
        from_finalize,
        start,
        duration,
        channel: Channel {
            alternatives: vec![alternative],
        },
        metadata: Metadata::default(),
        channel_index,
    }
}

/// Builder for a `StreamResponse::TranscriptResponse` with a single alternative.
///
/// Created with [`TranscriptResponseBuilder::new`], adjusted through the
/// chained setters, and consumed by [`TranscriptResponseBuilder::build`].
pub struct TranscriptResponseBuilder {
    transcript: String,
    words: Vec<Word>,
    is_final: bool,
    speech_final: bool,
    from_finalize: bool,
    languages: Vec<String>,
    channel_index: Vec<i32>,
    // `None` means "derive from `words` at build time".
    start: Option<f64>,
    duration: Option<f64>,
}

impl TranscriptResponseBuilder {
    /// Starts a builder for `transcript`.
    ///
    /// Defaults: no words, no languages, all flags `false`, a channel index of
    /// `[0]`, and no explicit start or duration.
    pub fn new(transcript: impl Into<String>) -> Self {
        Self {
            transcript: transcript.into(),
            words: Vec::new(),
            is_final: false,
            speech_final: false,
            from_finalize: false,
            languages: Vec::new(),
            channel_index: vec![0],
            start: None,
            duration: None,
        }
    }

    /// Replaces the word list.
    pub fn words(self, words: Vec<Word>) -> Self {
        Self { words, ..self }
    }

    /// Sets the `is_final` flag.
    pub fn is_final(self, is_final: bool) -> Self {
        Self { is_final, ..self }
    }

    /// Sets the `speech_final` flag.
    pub fn speech_final(self, speech_final: bool) -> Self {
        Self {
            speech_final,
            ..self
        }
    }

    /// Sets the `from_finalize` flag.
    pub fn from_finalize(self, from_finalize: bool) -> Self {
        Self {
            from_finalize,
            ..self
        }
    }

    /// Replaces the language list.
    pub fn languages(self, languages: Vec<String>) -> Self {
        Self { languages, ..self }
    }

    /// Replaces the channel index.
    pub fn channel_index(self, channel_index: Vec<i32>) -> Self {
        Self {
            channel_index,
            ..self
        }
    }

    /// Overrides the start value that would otherwise be derived from the words.
    pub fn start(self, start: f64) -> Self {
        Self {
            start: Some(start),
            ..self
        }
    }

    /// Overrides the duration that would otherwise be derived from the words.
    pub fn duration(self, duration: f64) -> Self {
        Self {
            duration: Some(duration),
            ..self
        }
    }

    /// Consumes the builder and produces the response.
    ///
    /// An explicit start or duration wins; each missing one falls back to the
    /// value `calculate_time_span` reports for the words. The alternative's
    /// confidence is fixed at `1.0` and the metadata is `Metadata::default()`.
    pub fn build(self) -> StreamResponse {
        let Self {
            transcript,
            words,
            is_final,
            speech_final,
            from_finalize,
            languages,
            channel_index,
            start,
            duration,
        } = self;

        // Measure the span before `words` is moved into the alternative.
        let (span_start, span_duration) = calculate_time_span(&words);

        let alternative = Alternatives {
            transcript,
            words,
            confidence: 1.0,
            languages,
        };

        StreamResponse::TranscriptResponse {
            is_final,
            speech_final,
            from_finalize,
            start: start.unwrap_or(span_start),
            duration: duration.unwrap_or(span_duration),
            channel: Channel {
                alternatives: vec![alternative],
            },
            metadata: Metadata::default(),
            channel_index,
        }
    }
}

pub struct WordBuilder {
word: String,
start: f64,
Expand Down
Loading
Loading