diff --git a/ui/desktop/src/components/ClientSpeechRecorder.tsx b/ui/desktop/src/components/ClientSpeechRecorder.tsx
index b2e81cfc6..716a67251 100644
--- a/ui/desktop/src/components/ClientSpeechRecorder.tsx
+++ b/ui/desktop/src/components/ClientSpeechRecorder.tsx
@@ -3,6 +3,8 @@ import { Button } from './ui/button';
 import { Mic, Square } from 'lucide-react';
 import WaveSurfer from 'wavesurfer.js';
 import RecordPlugin from 'wavesurfer.js/dist/plugins/record.esm.js';
+// Import the worker directly
+import SpeechWorker from './speech/worker?worker';

 // Constants for audio processing
 const SAMPLE_RATE = 16000;
@@ -91,9 +93,7 @@ export const AudioWaveform = React.forwardRef<

     // Initialize Web Worker for speech recognition
     console.log('Initializing Web Worker');
-    const worker = new Worker(new URL('../components/speech/worker.ts', import.meta.url), {
-      type: 'module'
-    });
+    const worker = new SpeechWorker();

     worker.onmessage = (event) => {
       const { type, message } = event.data;
@@ -165,7 +165,7 @@ export const AudioWaveform = React.forwardRef<
     // Load audio worklet
     console.log('Loading audio worklet');
     await audioContext.audioWorklet.addModule(
-      new URL('../components/speech/processor.js', import.meta.url)
+      new URL('./speech/processor.js', import.meta.url)
     );

     // Create audio processor
@@ -267,4 +267,4 @@ export function ClientSpeechRecorder({onTranscription, containerClassName}: {
   );
-}
+}
\ No newline at end of file
diff --git a/ui/desktop/src/components/speech/worker.ts b/ui/desktop/src/components/speech/worker.ts
index 0a2cb4210..cb551ab4e 100644
--- a/ui/desktop/src/components/speech/worker.ts
+++ b/ui/desktop/src/components/speech/worker.ts
@@ -1,3 +1,4 @@
+// worker-wrapper.ts
 import {AutoModel, Tensor, pipeline} from "@huggingface/transformers";

 // Constants for audio processing
@@ -7,77 +8,8 @@ const MAX_BUFFER_DURATION = 30;
 const SPEECH_PAD_SAMPLES = 1600;
 const MAX_NUM_PREV_BUFFERS = 4;

-// Check for WebGPU support
-async function supportsWebGPU() {
-  if (!navigator.gpu) return false;
-  try {
-    const adapter = await navigator.gpu.requestAdapter();
-    if (!adapter) return false;
-    const device = await adapter.requestDevice();
-    return !!device;
-  } catch {
-    return false;
-  }
-}
-
-console.log('Worker: Starting initialization');
-const device = (await supportsWebGPU()) ? "webgpu" : "wasm";
-console.log('Worker: Using device:', device);
-window.postMessage({type: "info", message: `Using device: "${device}"`});
-window.postMessage({
-  type: "info",
-  message: "Loading models...",
-  duration: "until_next",
-});
-
-// Load VAD model
-console.log('Worker: Loading VAD model');
-const silero_vad = await AutoModel.from_pretrained(
-  "onnx-community/silero-vad",
-  {
-    config: {model_type: "custom"},
-    dtype: "fp32",
-  },
-).catch((error) => {
-  console.error('Worker: Failed to load VAD model:', error);
-  window.postMessage({error});
-  throw error;
-});
-console.log('Worker: VAD model loaded');
-
-// Configure model based on device
-const DEVICE_DTYPE_CONFIGS = {
-  webgpu: {
-    encoder_model: "fp32",
-    decoder_model_merged: "q4",
-  },
-  wasm: {
-    encoder_model: "fp32",
-    decoder_model_merged: "q8",
-  },
-};
-
-// Initialize transcriber
-console.log('Worker: Loading transcriber model');
-const transcriber = await pipeline(
-  "automatic-speech-recognition",
-  "onnx-community/moonshine-base-ONNX",
-  {
-    device,
-    dtype: DEVICE_DTYPE_CONFIGS[device],
-  },
-).catch((error) => {
-  console.error('Worker: Failed to load transcriber model:', error);
-  window.postMessage({error});
-  throw error;
-});
-console.log('Worker: Transcriber model loaded');
-
-// Warm up the model
-console.log('Worker: Warming up models');
-await transcriber(new Float32Array(SAMPLE_RATE));
-console.log('Worker: Models warmed up');
-window.postMessage({type: "status", status: "ready", message: "Ready!"});
+// Initialize worker context ("self" is the worker global scope; "window" is not defined inside a worker)
+const ctx = self as unknown as Worker;

 // Chain promises for inference
 let inferenceChain = Promise.resolve();
@@ -87,10 +19,78 @@ const BUFFER = new Float32Array(MAX_BUFFER_DURATION * SAMPLE_RATE);
 let bufferPointer = 0;

 // VAD state
-const sr = new Tensor("int64", [SAMPLE_RATE], []);
-let state = new Tensor("float32", new Float32Array(2 * 1 * 128), [2, 1, 128]);
-let isRecording = false;
+let state: Tensor;
+let silero_vad: any;
+let transcriber: any;
+let sr: Tensor;
 let isActive = false;
+let prevBuffers: Float32Array[] = [];
+
+async function initializeModels() {
+  console.log('Worker: Starting initialization');
+  // Since we can't detect WebGPU in worker, default to wasm
+  const device = "wasm";
+  console.log('Worker: Using device:', device);
+  ctx.postMessage({type: "info", message: `Using device: "${device}"`});
+  ctx.postMessage({
+    type: "info",
+    message: "Loading models...",
+    duration: "until_next",
+  });
+
+  // Load VAD model
+  console.log('Worker: Loading VAD model');
+  silero_vad = await AutoModel.from_pretrained(
+    "onnx-community/silero-vad",
+    {
+      config: {model_type: "custom"},
+      dtype: "fp32",
+    },
+  ).catch((error) => {
+    console.error('Worker: Failed to load VAD model:', error);
+    ctx.postMessage({error});
+    throw error;
+  });
+  console.log('Worker: VAD model loaded');
+
+  // Configure model based on device
+  const DEVICE_DTYPE_CONFIGS = {
+    webgpu: {
+      encoder_model: "fp32",
+      decoder_model_merged: "q4",
+    },
+    wasm: {
+      encoder_model: "fp32",
+      decoder_model_merged: "q8",
+    },
+  };
+
+  // Initialize transcriber
+  console.log('Worker: Loading transcriber model');
+  transcriber = await pipeline(
+    "automatic-speech-recognition",
+    "onnx-community/moonshine-base-ONNX",
+    {
+      device,
+      dtype: DEVICE_DTYPE_CONFIGS[device],
+    },
+  ).catch((error) => {
+    console.error('Worker: Failed to load transcriber model:', error);
+    ctx.postMessage({error});
+    throw error;
+  });
+  console.log('Worker: Transcriber model loaded');
+
+  // Initialize VAD state
+  sr = new Tensor("int64", [SAMPLE_RATE], []);
+  state = new Tensor("float32", new Float32Array(2 * 1 * 128), [2, 1, 128]);
+
+  // Warm up the model
+  console.log('Worker: Warming up models');
+  await transcriber(new Float32Array(SAMPLE_RATE));
+  console.log('Worker: Models warmed up');
+  ctx.postMessage({type: "status", status: "ready", message: "Ready!"});
+}

 /**
  * Voice Activity Detection
@@ -118,11 +118,9 @@ const transcribe = async (buffer: Float32Array, data: any) => {
     transcriber(buffer),
   ));
   console.log('Transcribe: Result:', text);
-  window.postMessage({type: "output", buffer, message: text, ...data});
+  ctx.postMessage({type: "output", buffer, message: text, ...data});
 };

-let prevBuffers: Float32Array[] = [];
-
 const reset = (offset = 0) => {
   console.log('Reset: Resetting buffer with offset:', offset);
   BUFFER.fill(0, offset);
@@ -157,8 +155,14 @@ const dispatchForTranscription = (overflow?: Float32Array) => {
   reset(overflowLength);
 };

-// Handle incoming audio data
-window.onmessage = async (event) => {
+// Initialize models
+initializeModels().catch(error => {
+  console.error('Worker: Failed to initialize:', error);
+  ctx.postMessage({type: "error", error});
+});
+
+// Handle incoming messages in worker context
+ctx.onmessage = async (event) => {
   const {buffer, command} = event.data;

   if (command === 'stop' && isActive) {
diff --git a/ui/desktop/vite.config.mts b/ui/desktop/vite.config.mts
index fb6c807a1..5bb48d1e6 100644
--- a/ui/desktop/vite.config.mts
+++ b/ui/desktop/vite.config.mts
@@ -6,6 +6,7 @@ import { resolve } from 'path';
 export default defineConfig({
   plugins: [react()],
   build: {
+    target: 'esnext',
     rollupOptions: {
       input: {
         main: resolve(__dirname, 'src/main.ts'),
@@ -13,4 +14,17 @@ export default defineConfig({
       },
     },
   },
+  worker: {
+    format: 'es',
+    plugins: () => [react()],
+    rollupOptions: {
+      output: {
+        format: 'es',
+        chunkFileNames: '[name]-[hash].js',
+      }
+    }
+  },
+  optimizeDeps: {
+    exclude: ['@huggingface/transformers']
+  }
 });
\ No newline at end of file
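For review context, a minimal sketch (not part of the diff) of how the renderer side is expected to consume the Vite-bundled worker after this change. It assumes the message shapes posted by `worker.ts` above (`info`, `status`, `output`, `error`) and the `?worker` import added to `ClientSpeechRecorder.tsx`; the `WorkerMessage` union and the handler body are illustrative only.

```ts
// Sketch only: wiring the module worker to the message protocol used above.
import SpeechWorker from './speech/worker?worker';

// Illustrative union of the messages worker.ts posts back to the renderer.
type WorkerMessage =
  | { type: 'info'; message: string; duration?: 'until_next' }
  | { type: 'status'; status: 'ready'; message: string }
  | { type: 'output'; message: string; buffer: Float32Array }
  | { type: 'error'; error: unknown };

const worker = new SpeechWorker();

worker.onmessage = (event: MessageEvent<WorkerMessage>) => {
  switch (event.data.type) {
    case 'status':
      // Models are loaded and warmed up; recording can start.
      console.log(event.data.message);
      break;
    case 'output':
      // Transcribed text for one detected speech segment.
      console.log('Transcription:', event.data.message);
      break;
    case 'error':
      console.error('Worker error:', event.data.error);
      break;
    default:
      console.log(event.data);
  }
};

// Audio frames from the AudioWorklet are forwarded as Float32Array buffers;
// a 'stop' command flushes any in-flight speech for a final transcription.
worker.postMessage({ buffer: new Float32Array(512) });
worker.postMessage({ command: 'stop' });
```

The `worker.format: 'es'` and `build.target: 'esnext'` additions in `vite.config.mts` are what allow this worker to be emitted as an ES module chunk, which the `?worker` import relies on.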