diff --git a/Cargo.lock b/Cargo.lock index 3ef057677f..2e3c4537e9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -966,6 +966,7 @@ dependencies = [ "bytes", "cidre", "cpal", + "dasp", "data", "ebur128", "futures-channel", @@ -7837,9 +7838,9 @@ dependencies = [ [[package]] name = "libsql" -version = "0.9.13" +version = "0.9.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b187535f3bad97145ec605403f7e2a4ee5e716b4912263314f79a76e39718d4" +checksum = "1d445da25d61b9413dae38d481799cfe5054502fb849c18f7a4a1ffeff39ef19" dependencies = [ "anyhow", "async-stream", @@ -7877,9 +7878,9 @@ dependencies = [ [[package]] name = "libsql-ffi" -version = "0.9.13" +version = "0.9.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f33c522e2fa888cf0dda209fd9007992d333f96078da6e610c0393df07b44918" +checksum = "c5a6c4c1d1ff03ed18f10ce9d1ee6b3820b0ef77e5656ccbfdb02388c025937d" dependencies = [ "bindgen 0.66.1", "cc", @@ -7889,9 +7890,9 @@ dependencies = [ [[package]] name = "libsql-hrana" -version = "0.9.13" +version = "0.9.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6487c3017fb9847f65ca76d98bc54c2cdec0e658bc95ee06768b984b7ed9d8b" +checksum = "8a9d38212a209cbecb16dabd681afaa3cd4498c5ccef46d838ad7162aed9a86c" dependencies = [ "base64 0.21.7", "bytes", @@ -7901,9 +7902,9 @@ dependencies = [ [[package]] name = "libsql-rusqlite" -version = "0.9.13" +version = "0.9.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "812cf096e36358f2dd3743196e8fcfe7c47abf674db59e303df659b41cdb78d6" +checksum = "310b02070aa2098e6706ccade6b6bbb3c696b207bf23cd20e9635c4a2dbabde0" dependencies = [ "bitflags 2.9.1", "fallible-iterator 0.2.0", @@ -7933,9 +7934,9 @@ dependencies = [ [[package]] name = "libsql-sys" -version = "0.9.13" +version = "0.9.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b25ab72a80d81356e0061396c1315c42e2beb5fabb1288d90293b17e45aac559" +checksum = "bec1127725da9cfb9fc473b281ee09c1a8a3e103785e663b2e8391f11f0c119f" dependencies = [ "bytes", "libsql-ffi", @@ -7947,9 +7948,9 @@ dependencies = [ [[package]] name = "libsql_replication" -version = "0.9.13" +version = "0.9.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "393efb39d69fa576fb144217a51a0b65c7a13cbde57eb2b58137326dd39fa877" +checksum = "ff72eb531eb84fa38d6683f5af31668ea59fd52ac1b461515e82bcf4b21ba98f" dependencies = [ "aes 0.8.4", "async-stream", @@ -16113,8 +16114,7 @@ name = "vad" version = "0.1.0" dependencies = [ "data", - "ndarray", - "ort", + "onnx", "serde", "thiserror 2.0.12", ] diff --git a/Cargo.toml b/Cargo.toml index 5df53927a2..4f975af306 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -200,7 +200,7 @@ kalosm-sound = { git = "https://github.com/floneum/floneum", rev = "52967ae", de kalosm-streams = { git = "https://github.com/floneum/floneum", rev = "52967ae" } deepgram = { version = "0.6.8", default-features = false } -libsql = "0.9.8" +libsql = "0.9.17" block2 = "0.6" objc2 = "0.6" diff --git a/apps/desktop/src-tauri/src/lib.rs b/apps/desktop/src-tauri/src/lib.rs index 17b58b2472..520d7aaa70 100644 --- a/apps/desktop/src-tauri/src/lib.rs +++ b/apps/desktop/src-tauri/src/lib.rs @@ -17,8 +17,10 @@ pub async fn main() { tauri::async_runtime::set(tokio::runtime::Handle::current()); { - let env_filter = - EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); + let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| { + EnvFilter::new("info") + .add_directive("ort::logging=error".parse().unwrap()) + }); tracing_subscriber::Registry::default() .with(fmt::layer()) diff --git a/apps/desktop/src/components/editor-area/note-header/listen-button.tsx b/apps/desktop/src/components/editor-area/note-header/listen-button.tsx index 4c69f5fede..b1fbe9d0eb 100644 --- a/apps/desktop/src/components/editor-area/note-header/listen-button.tsx +++ b/apps/desktop/src/components/editor-area/note-header/listen-button.tsx @@ -1,6 +1,16 @@ import { Trans } from "@lingui/react/macro"; import { useMutation, useQuery } from "@tanstack/react-query"; -import { MicIcon, MicOffIcon, PauseIcon, PlayIcon, StopCircleIcon, Volume2Icon, VolumeOffIcon } from "lucide-react"; +import { + Check, + ChevronDown, + MicIcon, + MicOffIcon, + Pause, + PlayIcon, + Square, + Volume2Icon, + VolumeOffIcon, +} from "lucide-react"; import { useEffect, useState } from "react"; import SoundIndicator from "@/components/sound-indicator"; @@ -319,16 +329,14 @@ function RecordingControls({ return ( <> -
- + toggleMicMuted.mutate()} - type="mic" + onToggleMuted={() => toggleMicMuted.mutate()} /> - toggleSpeakerMuted.mutate()} - type="speaker" />
@@ -361,7 +369,7 @@ function RecordingControls({ onClick={onPause} className="w-full" > - + Pause @@ -377,35 +385,139 @@ function RecordingControls({ ); } -function AudioControlButton({ - type, +function MicrophoneSelector({ + isMuted, + onToggleMuted, + disabled, +}: { + isMuted?: boolean; + onToggleMuted: () => void; + disabled?: boolean; +}) { + const [isOpen, setIsOpen] = useState(false); + const [selectedDevice, setSelectedDevice] = useState(""); + + const { data: devices = [], isLoading } = useQuery({ + queryKey: ["microphone-devices"], + queryFn: () => listenerCommands.listMicrophoneDevices(), + refetchOnWindowFocus: false, + }); + + useEffect(() => { + if (!selectedDevice && devices.length > 0) { + setSelectedDevice(devices[0]); + } + }, [devices, selectedDevice]); + + const Icon = isMuted ? MicOffIcon : MicIcon; + + return ( +
+ +
+ + + + + +
+ + +
+
+ Microphone +
+ + {isLoading + ? ( +
+
+

Loading devices...

+
+ ) + : devices.length === 0 + ? ( +
+

No microphones found

+
+ ) + : ( +
+ {devices.map((device) => { + const isSelected = device === selectedDevice; + return ( + + ); + })} +
+ )} +
+
+
+
+ ); +} + +function SpeakerButton({ isMuted, onClick, disabled, }: { - type: "mic" | "speaker"; isMuted?: boolean; onClick: () => void; disabled?: boolean; }) { - const Icon = type === "mic" - ? isMuted - ? MicOffIcon - : MicIcon - : isMuted - ? VolumeOffIcon - : Volume2Icon; + const Icon = isMuted ? VolumeOffIcon : Volume2Icon; return ( - +
+ +
); } diff --git a/apps/desktop/src/locales/en/messages.po b/apps/desktop/src/locales/en/messages.po index cdcb672395..b64ee657a3 100644 --- a/apps/desktop/src/locales/en/messages.po +++ b/apps/desktop/src/locales/en/messages.po @@ -256,8 +256,8 @@ msgstr "(Beta) Upcoming meeting notifications" #. placeholder {0}: disabled ? "Wait..." : isHovered ? "Resume" : "Ended" #: src/components/settings/views/templates.tsx:194 #: src/components/settings/components/wer-modal.tsx:116 -#: src/components/editor-area/note-header/listen-button.tsx:179 -#: src/components/editor-area/note-header/listen-button.tsx:218 +#: src/components/editor-area/note-header/listen-button.tsx:189 +#: src/components/editor-area/note-header/listen-button.tsx:228 msgid "{0}" msgstr "{0}" @@ -870,7 +870,7 @@ msgstr "No speech-to-text models available or failed to load." #~ msgid "No Template" #~ msgstr "No Template" -#: src/components/editor-area/note-header/listen-button.tsx:342 +#: src/components/editor-area/note-header/listen-button.tsx:350 msgid "No Template (Default)" msgstr "No Template (Default)" @@ -939,7 +939,7 @@ msgstr "Optional for participant suggestions" msgid "Owner" msgstr "Owner" -#: src/components/editor-area/note-header/listen-button.tsx:365 +#: src/components/editor-area/note-header/listen-button.tsx:373 msgid "Pause" msgstr "Pause" @@ -951,7 +951,7 @@ msgstr "people" msgid "Performance difference between languages" msgstr "Performance difference between languages" -#: src/components/editor-area/note-header/listen-button.tsx:198 +#: src/components/editor-area/note-header/listen-button.tsx:208 msgid "Play video" msgstr "Play video" @@ -995,7 +995,7 @@ msgstr "Required to transcribe other people's voice during meetings" msgid "Required to transcribe your voice during meetings" msgstr "Required to transcribe your voice during meetings" -#: src/components/editor-area/note-header/listen-button.tsx:107 +#: src/components/editor-area/note-header/listen-button.tsx:117 msgid "Resume" msgstr "Resume" @@ -1093,11 +1093,11 @@ msgstr "Start Annual Plan" msgid "Start Monthly Plan" msgstr "Start Monthly Plan" -#: src/components/editor-area/note-header/listen-button.tsx:154 +#: src/components/editor-area/note-header/listen-button.tsx:164 msgid "Start recording" msgstr "Start recording" -#: src/components/editor-area/note-header/listen-button.tsx:373 +#: src/components/editor-area/note-header/listen-button.tsx:381 msgid "Stop" msgstr "Stop" diff --git a/apps/desktop/src/locales/ko/messages.po b/apps/desktop/src/locales/ko/messages.po index 9b7286dbb1..8af8746781 100644 --- a/apps/desktop/src/locales/ko/messages.po +++ b/apps/desktop/src/locales/ko/messages.po @@ -256,8 +256,8 @@ msgstr "" #. placeholder {0}: disabled ? "Wait..." : isHovered ? "Resume" : "Ended" #: src/components/settings/views/templates.tsx:194 #: src/components/settings/components/wer-modal.tsx:116 -#: src/components/editor-area/note-header/listen-button.tsx:179 -#: src/components/editor-area/note-header/listen-button.tsx:218 +#: src/components/editor-area/note-header/listen-button.tsx:189 +#: src/components/editor-area/note-header/listen-button.tsx:228 msgid "{0}" msgstr "" @@ -870,7 +870,7 @@ msgstr "" #~ msgid "No Template" #~ msgstr "" -#: src/components/editor-area/note-header/listen-button.tsx:342 +#: src/components/editor-area/note-header/listen-button.tsx:350 msgid "No Template (Default)" msgstr "" @@ -939,7 +939,7 @@ msgstr "" msgid "Owner" msgstr "" -#: src/components/editor-area/note-header/listen-button.tsx:365 +#: src/components/editor-area/note-header/listen-button.tsx:373 msgid "Pause" msgstr "" @@ -951,7 +951,7 @@ msgstr "" msgid "Performance difference between languages" msgstr "" -#: src/components/editor-area/note-header/listen-button.tsx:198 +#: src/components/editor-area/note-header/listen-button.tsx:208 msgid "Play video" msgstr "" @@ -995,7 +995,7 @@ msgstr "" msgid "Required to transcribe your voice during meetings" msgstr "" -#: src/components/editor-area/note-header/listen-button.tsx:107 +#: src/components/editor-area/note-header/listen-button.tsx:117 msgid "Resume" msgstr "" @@ -1093,11 +1093,11 @@ msgstr "" msgid "Start Monthly Plan" msgstr "" -#: src/components/editor-area/note-header/listen-button.tsx:154 +#: src/components/editor-area/note-header/listen-button.tsx:164 msgid "Start recording" msgstr "" -#: src/components/editor-area/note-header/listen-button.tsx:373 +#: src/components/editor-area/note-header/listen-button.tsx:381 msgid "Stop" msgstr "" diff --git a/crates/audio/Cargo.toml b/crates/audio/Cargo.toml index bfe6a8e05f..abde9554ef 100644 --- a/crates/audio/Cargo.toml +++ b/crates/audio/Cargo.toml @@ -17,6 +17,7 @@ futures-util = { workspace = true } tokio = { workspace = true, features = ["rt", "macros"] } cpal = { workspace = true } +dasp = { workspace = true } rodio = { workspace = true } ebur128 = "0.1.10" diff --git a/crates/audio/src/errors.rs b/crates/audio/src/errors.rs index 43779ae96b..1683603c45 100644 --- a/crates/audio/src/errors.rs +++ b/crates/audio/src/errors.rs @@ -1,2 +1,7 @@ #[derive(thiserror::Error, Debug)] -pub enum Error {} +pub enum Error { + #[error("No input device found")] + NoInputDevice, + #[error(transparent)] + DefaultStreamConfigError(#[from] cpal::DefaultStreamConfigError), +} diff --git a/crates/audio/src/lib.rs b/crates/audio/src/lib.rs index 5f7aeec1f3..7c2b1397f8 100644 --- a/crates/audio/src/lib.rs +++ b/crates/audio/src/lib.rs @@ -73,7 +73,7 @@ impl AudioInput { pub fn from_mic() -> Self { Self { source: AudioSource::RealtimeMic, - mic: Some(MicInput::default()), + mic: Some(MicInput::new().unwrap()), speaker: None, data: None, } diff --git a/crates/audio/src/mic.rs b/crates/audio/src/mic.rs index f32164b1ac..c836ff1cf5 100644 --- a/crates/audio/src/mic.rs +++ b/crates/audio/src/mic.rs @@ -1,23 +1,393 @@ -pub use kalosm_sound::{MicInput, MicStream}; +use cpal::{ + traits::{DeviceTrait, HostTrait, StreamTrait}, + Device, SizedSample, +}; +use dasp::sample::ToSample; +use futures_channel::mpsc; +use futures_util::{Stream, StreamExt}; +use std::pin::Pin; +use std::sync::Arc; +use tokio::sync::{Mutex, RwLock}; + +use crate::AsyncSource; + +/// Information about an audio input device +#[derive(Debug, Clone)] +pub struct AudioDeviceInfo { + pub name: String, + pub index: usize, +} + +/// A microphone input with runtime device selection. +pub struct MicInput { + host: cpal::Host, + current_device: Arc>, + current_config: Arc>, + stream_manager: Arc>, +} + +struct StreamManager { + switch_tx: Option>, +} + +enum DeviceSwitchCommand { + SwitchDevice(Device, cpal::SupportedStreamConfig), +} + +#[derive(Debug, thiserror::Error)] +pub enum MicInputError { + #[error("No input device available")] + NoInputDevice, + #[error("Failed to get device config: {0}")] + ConfigError(#[from] cpal::DefaultStreamConfigError), + #[error("Device error: {0}")] + DeviceError(String), + #[error("Stream error: {0}")] + StreamError(#[from] cpal::BuildStreamError), + #[error("Play stream error: {0}")] + PlayStreamError(#[from] cpal::PlayStreamError), +} + +impl MicInput { + /// Create a new MicInput with the default input device + pub fn new() -> Result { + let host = cpal::default_host(); + let device = host + .default_input_device() + .ok_or(MicInputError::NoInputDevice)?; + let config = device.default_input_config()?; + + Ok(Self { + host, + current_device: Arc::new(RwLock::new(device)), + current_config: Arc::new(RwLock::new(config)), + stream_manager: Arc::new(Mutex::new(StreamManager { switch_tx: None })), + }) + } + + /// Create a MicInput with a specific device by index + pub fn with_device(device_index: usize) -> Result { + let host = cpal::default_host(); + let device = host + .input_devices() + .map_err(|e| MicInputError::DeviceError(e.to_string()))? + .nth(device_index) + .ok_or(MicInputError::DeviceError( + "Device index out of range".to_string(), + ))?; + let config = device.default_input_config()?; + + Ok(Self { + host, + current_device: Arc::new(RwLock::new(device)), + current_config: Arc::new(RwLock::new(config)), + stream_manager: Arc::new(Mutex::new(StreamManager { switch_tx: None })), + }) + } + + /// Get a list of available input devices + pub fn list_input_devices(&self) -> Vec { + match self.host.input_devices() { + Ok(devices) => devices + .enumerate() + .filter_map(|(index, device)| { + device + .name() + .ok() + .map(|name| AudioDeviceInfo { name, index }) + }) + .collect(), + Err(_) => Vec::new(), + } + } + + /// Get the currently selected device name + pub async fn current_device_name(&self) -> Result { + let device_guard = self.current_device.read().await; + device_guard + .name() + .map_err(|e| MicInputError::DeviceError(e.to_string())) + } + + /// Switch to a different input device by index + pub async fn switch_device(&self, device_index: usize) -> Result<(), MicInputError> { + let devices: Vec<_> = self + .host + .input_devices() + .map_err(|e| MicInputError::DeviceError(e.to_string()))? + .collect(); + + let device = devices + .into_iter() + .nth(device_index) + .ok_or(MicInputError::DeviceError( + "Device index out of range".to_string(), + ))?; + + let config = device.default_input_config()?; + + // Update the current device and config + { + let mut device_guard = self.current_device.write().await; + *device_guard = device.clone(); + } + { + let mut config_guard = self.current_config.write().await; + *config_guard = config.clone(); + } + + // Send switch command if there's an active stream + let manager = self.stream_manager.lock().await; + if let Some(tx) = &manager.switch_tx { + tx.send(DeviceSwitchCommand::SwitchDevice(device, config)) + .map_err(|_| { + MicInputError::DeviceError("Failed to send switch command".to_string()) + })?; + } + + Ok(()) + } + + /// Creates a new stream of audio data from the microphone (synchronous). + pub fn stream(&self) -> MicStream { + // Use bounded channel to prevent unbounded memory growth + let (tx, rx) = mpsc::channel::>(64); + let (switch_tx, switch_rx) = std::sync::mpsc::channel::(); + let (shutdown_tx, shutdown_rx) = std::sync::mpsc::channel::<()>(); + + // Clone current device and config synchronously using try_read + let (device, config) = { + let device_guard = self + .current_device + .try_read() + .expect("Failed to read device"); + let config_guard = self + .current_config + .try_read() + .expect("Failed to read config"); + (device_guard.clone(), config_guard.clone()) + }; + let config_clone = config.clone(); + + // Store the switch channel sender asynchronously + let stream_manager = self.stream_manager.clone(); + tokio::spawn(async move { + let mut manager = stream_manager.lock().await; + manager.switch_tx = Some(switch_tx); + }); + + // Spawn the CPAL handler thread + std::thread::spawn(move || { + cpal_stream_thread(device, config, tx, switch_rx, shutdown_rx); + }); + + let receiver = rx.map(futures_util::stream::iter).flatten(); + MicStream { + config: config_clone, + receiver: Box::pin(receiver), + shutdown_tx: Some(shutdown_tx), + } + } +} + +fn cpal_stream_thread( + initial_device: Device, + initial_config: cpal::SupportedStreamConfig, + audio_tx: mpsc::Sender>, + switch_rx: std::sync::mpsc::Receiver, + shutdown_rx: std::sync::mpsc::Receiver<()>, +) { + let mut current_stream: Option> = None; + let mut current_device = initial_device; + let mut current_config = initial_config; + + loop { + // Start stream if we don't have one + if current_stream.is_none() { + match start_stream(¤t_device, ¤t_config, audio_tx.clone()) { + Ok(stream) => { + current_stream = Some(stream); + tracing::info!("Audio stream started: {:?}", current_device.name()); + } + Err(e) => { + tracing::error!("Failed to start audio stream: {}", e); + std::thread::sleep(std::time::Duration::from_secs(1)); + continue; + } + } + } + + // Check for commands with timeout + match switch_rx.recv_timeout(std::time::Duration::from_millis(10)) { + Ok(DeviceSwitchCommand::SwitchDevice(new_device, new_config)) => { + tracing::info!("Switching audio device to: {:?}", new_device.name()); + + // Stop current stream + current_stream = None; + + // Small delay to ensure clean switch + std::thread::sleep(std::time::Duration::from_millis(50)); + + // Update device and config + current_device = new_device; + current_config = new_config; + } + Err(std::sync::mpsc::RecvTimeoutError::Timeout) => { + // Check if we should shutdown + match shutdown_rx.try_recv() { + Ok(_) => break, + Err(std::sync::mpsc::TryRecvError::Empty) => continue, + Err(std::sync::mpsc::TryRecvError::Disconnected) => break, + } + } + Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => { + // Channel closed + break; + } + } + } + + // Cleanup + drop(current_stream); + tracing::info!("Audio stream thread shutting down"); +} + +fn start_stream( + device: &Device, + config: &cpal::SupportedStreamConfig, + tx: mpsc::Sender>, +) -> Result, MicInputError> { + fn build_stream + SizedSample>( + device: &cpal::Device, + config: &cpal::SupportedStreamConfig, + mut tx: mpsc::Sender>, + ) -> Result { + let channels = config.channels() as usize; + device.build_input_stream::( + &config.config(), + move |data: &[S], _: &_| { + let samples: Vec = data + .iter() + .step_by(channels) + .map(|&x| x.to_sample()) + .collect(); + + // Try to send, but don't block or panic if receiver is gone + match tx.try_send(samples) { + Ok(_) => {} + Err(e) => { + if e.is_full() { + tracing::warn!("Audio buffer full, dropping samples"); + } + // If disconnected, the stream will be cleaned up + } + } + }, + |err| { + tracing::error!("Audio stream error: {}", err); + }, + None, + ) + } + + let stream: Box = match config.sample_format() { + cpal::SampleFormat::I8 => Box::new(build_stream::(device, config, tx)?), + cpal::SampleFormat::I16 => Box::new(build_stream::(device, config, tx)?), + cpal::SampleFormat::I32 => Box::new(build_stream::(device, config, tx)?), + cpal::SampleFormat::F32 => Box::new(build_stream::(device, config, tx)?), + sample_format => { + return Err(MicInputError::DeviceError(format!( + "Unsupported sample format '{}'", + sample_format + ))); + } + }; + + stream.play()?; + Ok(stream) +} + +/// A stream of audio data from the microphone. +pub struct MicStream { + config: cpal::SupportedStreamConfig, + receiver: Pin + Send + Sync>>, + shutdown_tx: Option>, +} + +impl Drop for MicStream { + fn drop(&mut self) { + // Signal shutdown to the background thread + if let Some(tx) = self.shutdown_tx.take() { + let _ = tx.send(()); + } + } +} + +impl Stream for MicStream { + type Item = f32; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.receiver.as_mut().poll_next_unpin(cx) + } +} + +impl AsyncSource for MicStream { + fn as_stream(&mut self) -> impl Stream + '_ { + self + } + + fn sample_rate(&self) -> u32 { + self.config.sample_rate().0 + } +} + +// Add rodio compatibility methods if needed +impl MicStream { + /// Read all samples currently in the buffer (for compatibility) + pub fn read_all(&mut self) -> rodio::buffer::SamplesBuffer { + let mut samples = Vec::new(); + let mut cx = std::task::Context::from_waker(futures_util::task::noop_waker_ref()); + + // Drain all available samples + while let std::task::Poll::Ready(Some(sample)) = self.receiver.poll_next_unpin(&mut cx) { + samples.push(sample); + } + + rodio::buffer::SamplesBuffer::new( + self.config.channels(), + self.config.sample_rate().0, + samples, + ) + } +} #[cfg(test)] mod tests { use super::*; - use futures_util::StreamExt; - #[tokio::test] - async fn test_mic() { - let mic = MicInput::default(); - let mut stream = mic.stream(); + #[test] + fn assert_mic_stream_send_sync() { + fn assert_sync() {} + assert_sync::(); + fn assert_send() {} + assert_send::(); + } - let mut buffer = Vec::new(); - while let Some(sample) = stream.next().await { - buffer.push(sample); - if buffer.len() > 6000 { - break; + #[test] + fn test_mic_input_creation() { + // This test might fail on systems without audio devices + match MicInput::new() { + Ok(mic) => { + let devices = mic.list_input_devices(); + assert!(!devices.is_empty(), "Should have at least one input device"); } + Err(MicInputError::NoInputDevice) => { + // Expected on systems without mics + } + Err(e) => panic!("Unexpected error: {}", e), } - - assert!(buffer.iter().any(|x| *x != 0.0)); } } diff --git a/crates/onnx/src/lib.rs b/crates/onnx/src/lib.rs index e45235a384..858d7333fe 100644 --- a/crates/onnx/src/lib.rs +++ b/crates/onnx/src/lib.rs @@ -1,4 +1,5 @@ use ort::{ + logging::LogLevel, session::{builder::GraphOptimizationLevel, Session}, Result, }; @@ -8,6 +9,7 @@ pub use ort; pub fn load_model(bytes: &[u8]) -> Result { let session = Session::builder()? + .with_log_level(LogLevel::Error)? .with_intra_threads(1)? .with_inter_threads(1)? .with_optimization_level(GraphOptimizationLevel::Level3)? diff --git a/crates/vad/Cargo.toml b/crates/vad/Cargo.toml index a114c4a8ed..cbf3d65c08 100644 --- a/crates/vad/Cargo.toml +++ b/crates/vad/Cargo.toml @@ -5,14 +5,13 @@ edition = "2021" [features] default = [] -load-dynamic = ["ort/load-dynamic"] +load-dynamic = ["hypr-onnx/load-dynamic"] [dependencies] +hypr-onnx = { workspace = true } + serde = { workspace = true } thiserror = { workspace = true } -ndarray = "0.16" -ort = { version = "=2.0.0-rc.10", features = ["ndarray"] } - [dev-dependencies] hypr-data = { workspace = true } diff --git a/crates/vad/src/error.rs b/crates/vad/src/error.rs index b02763d7ef..99dddd9e1c 100644 --- a/crates/vad/src/error.rs +++ b/crates/vad/src/error.rs @@ -3,9 +3,9 @@ use serde::{ser::Serializer, Serialize}; #[derive(Debug, thiserror::Error)] pub enum Error { #[error(transparent)] - OrtError(#[from] ort::Error), + OrtError(#[from] hypr_onnx::ort::Error), #[error(transparent)] - ShapeError(#[from] ndarray::ShapeError), + ShapeError(#[from] hypr_onnx::ndarray::ShapeError), #[error("Invalid or missing output from model")] InvalidOutput, } diff --git a/crates/vad/src/lib.rs b/crates/vad/src/lib.rs index 82b422b1f7..bee391dc3f 100644 --- a/crates/vad/src/lib.rs +++ b/crates/vad/src/lib.rs @@ -1,10 +1,12 @@ mod error; pub use error::*; -use ndarray::{Array1, Array2, Array3, ArrayBase, Ix1, Ix3, OwnedRepr}; -use ort::{ - session::{builder::GraphOptimizationLevel, Session}, - value::TensorRef, +use hypr_onnx::{ + ndarray::{Array1, Array2, Array3, ArrayBase, Ix1, Ix3, OwnedRepr}, + ort::{ + session::{builder::GraphOptimizationLevel, Session}, + value::TensorRef, + }, }; const MODEL_BYTES: &[u8] = @@ -48,7 +50,7 @@ impl Vad { let samples = audio_chunk.len(); let audio_tensor = Array2::from_shape_vec((1, samples), audio_chunk.to_vec())?; - let mut result = self.session.run(ort::inputs![ + let mut result = self.session.run(hypr_onnx::ort::inputs![ TensorRef::from_array_view(audio_tensor.view())?, TensorRef::from_array_view(self.sample_rate_tensor.view())?, TensorRef::from_array_view(self.h_tensor.view())?,