diff --git a/Cargo.lock b/Cargo.lock index b1c76218a2..5f1e86a8f6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1611,6 +1611,7 @@ dependencies = [ "axum-core 0.5.2", "bytes", "cookie", + "form_urlencoded", "futures-util", "http 1.3.1", "http-body 1.0.1", @@ -1619,6 +1620,8 @@ dependencies = [ "pin-project-lite", "rustversion", "serde", + "serde_html_form", + "serde_path_to_error", "tower 0.5.2", "tower-layer", "tower-service", @@ -7813,7 +7816,9 @@ version = "0.1.0" dependencies = [ "codes-iso-639", "deepgram", + "schemars 0.8.22", "serde", + "specta", "thiserror 2.0.12", "whisper", ] @@ -12532,6 +12537,19 @@ dependencies = [ "syn 2.0.104", ] +[[package]] +name = "serde_html_form" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d2de91cf02bbc07cde38891769ccd5d4f073d22a40683aa4bc7a95781aaa2c4" +dependencies = [ + "form_urlencoded", + "indexmap 2.10.0", + "itoa", + "ryu", + "serde", +] + [[package]] name = "serde_json" version = "1.0.140" @@ -14334,6 +14352,7 @@ version = "0.1.0" dependencies = [ "audio-utils", "axum 0.8.4", + "axum-extra", "chunker", "data", "dirs 6.0.0", diff --git a/Cargo.toml b/Cargo.toml index 613b1826a3..b3ab5ec1b9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -149,7 +149,8 @@ inventory = "0.3.20" serial_test = "3" testcontainers-modules = "0.11.5" -axum = "0.8.1" +axum = "0.8.4" +axum-extra = "0.10.1" tower = "0.5.2" tower-http = "0.6.2" diff --git a/apps/app/server/src/native/listen/realtime.rs b/apps/app/server/src/native/listen/realtime.rs index 86eeed588e..ef37be76bd 100644 --- a/apps/app/server/src/native/listen/realtime.rs +++ b/apps/app/server/src/native/listen/realtime.rs @@ -29,7 +29,10 @@ async fn websocket(socket: WebSocket, state: STTState, params: ListenParams) { let (mut ws_sender, ws_receiver) = socket.split(); - let mut stt = state.realtime_stt.for_language(params.language).await; + let mut stt = state + .realtime_stt + .for_language(params.languages.first().unwrap().clone()) + .await; let input_stream = futures_util::stream::try_unfold(ws_receiver, |mut ws_receiver| async move { diff --git a/apps/app/server/src/native/listen/recorded.rs b/apps/app/server/src/native/listen/recorded.rs index 1344771e8d..6db2596267 100644 --- a/apps/app/server/src/native/listen/recorded.rs +++ b/apps/app/server/src/native/listen/recorded.rs @@ -13,7 +13,10 @@ pub async fn handler( Query(params): Query, State(state): State, ) -> impl IntoResponse { - let stt = state.recorded_stt.for_language(params.language).await; + let stt = state + .recorded_stt + .for_language(params.languages.first().unwrap().clone()) + .await; let input = RecordedSpeech::File("TODO".into()); let result = stt.transcribe(input).await.unwrap(); diff --git a/apps/desktop/src/components/settings/views/general.tsx b/apps/desktop/src/components/settings/views/general.tsx index 0867edff60..5d1bb7f93a 100644 --- a/apps/desktop/src/components/settings/views/general.tsx +++ b/apps/desktop/src/components/settings/views/general.tsx @@ -2,6 +2,7 @@ import { zodResolver } from "@hookform/resolvers/zod"; import { LANGUAGES_ISO_639_1 } from "@huggingface/languages"; import { Trans, useLingui } from "@lingui/react/macro"; import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; +import { Plus, X } from "lucide-react"; import { useEffect } from "react"; import { useForm } from "react-hook-form"; import { z } from "zod"; @@ -9,6 +10,9 @@ import { z } from "zod"; import { showModelSelectToast } from "@/components/toast/model-select"; import { commands } from "@/types"; import { commands as dbCommands, type ConfigGeneral } from "@hypr/plugin-db"; +import { Badge } from "@hypr/ui/components/ui/badge"; +import { Button } from "@hypr/ui/components/ui/button"; +import { Command, CommandEmpty, CommandGroup, CommandInput, CommandItem } from "@hypr/ui/components/ui/command"; import { Form, FormControl, @@ -19,7 +23,8 @@ import { FormMessage, } from "@hypr/ui/components/ui/form"; import { Input } from "@hypr/ui/components/ui/input"; -import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@hypr/ui/components/ui/select"; +import { Popover, PopoverContent, PopoverTrigger } from "@hypr/ui/components/ui/popover"; +import { Select, SelectTrigger, SelectValue } from "@hypr/ui/components/ui/select"; import { Switch } from "@hypr/ui/components/ui/switch"; type ISO_639_1_CODE = keyof typeof LANGUAGES_ISO_639_1; @@ -71,6 +76,7 @@ const SUPPORTED_LANGUAGES: ISO_639_1_CODE[] = [ const schema = z.object({ autostart: z.boolean().optional(), displayLanguage: z.enum(SUPPORTED_LANGUAGES as [string, ...string[]]), + spokenLanguages: z.array(z.enum(SUPPORTED_LANGUAGES as [string, ...string[]])).min(1), telemetryConsent: z.boolean().optional(), jargons: z.string(), saveRecordings: z.boolean().optional(), @@ -95,6 +101,7 @@ export default function General() { defaultValues: { autostart: false, displayLanguage: "en", + spokenLanguages: ["en"], telemetryConsent: true, jargons: "", saveRecordings: true, @@ -106,6 +113,7 @@ export default function General() { form.reset({ autostart: config.data.general.autostart ?? false, displayLanguage: config.data.general.display_language ?? "en", + spokenLanguages: config.data.general.spoken_languages ?? ["en"], telemetryConsent: config.data.general.telemetry_consent ?? true, jargons: (config.data.general.jargons ?? []).join(", "), saveRecordings: config.data.general.save_recordings ?? true, @@ -123,6 +131,7 @@ export default function General() { const nextGeneral: ConfigGeneral = { autostart: v.autostart ?? false, display_language: v.displayLanguage, + spoken_languages: v.spokenLanguages, telemetry_consent: v.telemetryConsent ?? true, jargons: v.jargons.split(",").map((jargon) => jargon.trim()).filter(Boolean), save_recordings: v.saveRecordings ?? true, @@ -163,7 +172,7 @@ export default function General() { return (
- +
- Language + Display language - Choose your preferred language of use + Primary language for the interface
- {/* FormMessage is usually displayed below the control, might need separate handling if it must be in-row and an error occurs */} - {/* */} + + )} + /> + + ( + +
+ + Spoken languages + + + Select languages you speak for better transcription + +
+ +
+
+ {field.value.map((langCode) => ( + + {LANGUAGES_ISO_639_1[langCode as ISO_639_1_CODE]?.name || langCode} + + + ))} +
+ + + + + + + + No language found. + + {SUPPORTED_LANGUAGES.filter( + (lang) => !field.value.includes(lang), + ).map((lang) => ( + { + if (!field.value.includes(lang)) { + field.onChange([...field.value, lang]); + mutation.mutate(form.getValues()); + } + }} + > + {LANGUAGES_ISO_639_1[lang].name} + + ))} + + + + +
+
+
)} /> @@ -258,13 +341,13 @@ export default function General() { name="jargons" render={({ field }) => ( -
+
- Jargons + Custom Vocabulary - You can make Hyprnote takes these words into account when transcribing + Add specific terms or jargon for improved transcription accuracy
@@ -273,9 +356,9 @@ export default function General() { {...field} onBlur={() => mutation.mutate(form.getValues())} placeholder={t({ - id: "Type jargons (e.g., Blitz Meeting, PaC Squad)", + id: "Type terms separated by commas (e.g., Blitz Meeting, PaC Squad)", })} - className="focus-visible:ring-0 focus-visible:ring-offset-0" + className="focus-visible:ring-1 focus-visible:ring-ring" /> diff --git a/apps/desktop/src/locales/en/messages.po b/apps/desktop/src/locales/en/messages.po index 03b55fa809..20523a2348 100644 --- a/apps/desktop/src/locales/en/messages.po +++ b/apps/desktop/src/locales/en/messages.po @@ -14,9 +14,9 @@ msgstr "" "Plural-Forms: \n" #. js-lingui-explicit-id -#: src/components/settings/views/general.tsx:275 -msgid "Type jargons (e.g., Blitz Meeting, PaC Squad)" -msgstr "Type jargons (e.g., Blitz Meeting, PaC Squad)" +#: src/components/settings/views/general.tsx:358 +msgid "Type terms separated by commas (e.g., Blitz Meeting, PaC Squad)" +msgstr "Type terms separated by commas (e.g., Blitz Meeting, PaC Squad)" #. js-lingui-explicit-id #: ../../packages/utils/src/datetime.ts:22 @@ -208,6 +208,11 @@ msgstr "1 day later" msgid "{weeks} weeks later" msgstr "{weeks} weeks later" +#. js-lingui-explicit-id +#: src/components/settings/views/general.tsx:275 +#~ msgid "Type jargons (e.g., Blitz Meeting, PaC Squad)" +#~ msgstr "Type jargons (e.g., Blitz Meeting, PaC Squad)" + #. js-lingui-explicit-id #: ../../packages/utils/src/datetime.ts:168 #~ msgid "just now" @@ -315,6 +320,10 @@ msgstr "Add members" msgid "Add participant" msgstr "Add participant" +#: src/components/settings/views/general.tsx:349 +msgid "Add specific terms or jargon for improved transcription accuracy" +msgstr "Add specific terms or jargon for improved transcription accuracy" + #: src/components/settings/views/team.tsx:142 #: src/components/settings/views/team.tsx:229 msgid "Admin" @@ -436,13 +445,13 @@ msgstr "Chat with meeting notes" #~ msgid "Choose the language you want to use for the speech-to-text model and language model" #~ msgstr "Choose the language you want to use for the speech-to-text model and language model" -#: src/components/settings/views/general.tsx:177 +#: src/components/settings/views/general.tsx:186 msgid "Choose whether to save your recordings locally." msgstr "Choose whether to save your recordings locally." #: src/components/settings/views/general.tsx:230 -msgid "Choose your preferred language of use" -msgstr "Choose your preferred language of use" +#~ msgid "Choose your preferred language of use" +#~ msgstr "Choose your preferred language of use" #: src/components/settings/components/wer-modal.tsx:126 msgid "Close" @@ -539,6 +548,10 @@ msgstr "Current Plan" msgid "Custom Endpoint" msgstr "Custom Endpoint" +#: src/components/settings/views/general.tsx:346 +msgid "Custom Vocabulary" +msgstr "Custom Vocabulary" + #: src/components/settings/components/ai/llm-view.tsx:149 #~ msgid "Default (llama-3.2-3b-q4)" #~ msgstr "Default (llama-3.2-3b-q4)" @@ -565,6 +578,10 @@ msgstr "Description" #~ msgid "Did you get consent from everyone in the meeting?" #~ msgstr "Did you get consent from everyone in the meeting?" +#: src/components/settings/views/general.tsx:236 +msgid "Display language" +msgstr "Display language" + #. placeholder {0}: metadata?.size && `(${metadata.size})` #: src/components/settings/components/ai/stt-view.tsx:237 msgid "Download {0}" @@ -692,7 +709,7 @@ msgstr "Get Started" msgid "Grant Access" msgstr "Grant Access" -#: src/components/settings/views/general.tsx:203 +#: src/components/settings/views/general.tsx:212 msgid "Help us improve Hyprnote by sharing anonymous usage data" msgstr "Help us improve Hyprnote by sharing anonymous usage data" @@ -741,8 +758,8 @@ msgid "Invite members" msgstr "Invite members" #: src/components/settings/views/general.tsx:263 -msgid "Jargons" -msgstr "Jargons" +#~ msgid "Jargons" +#~ msgstr "Jargons" #: src/components/settings/views/profile.tsx:139 msgid "Job title" @@ -761,8 +778,8 @@ msgstr "Key decisions" #~ msgstr "Lab" #: src/components/settings/views/general.tsx:227 -msgid "Language" -msgstr "Language" +#~ msgid "Language" +#~ msgstr "Language" #: src/components/settings/views/billing.tsx:200 msgid "Learn more about our pricing plans" @@ -971,6 +988,10 @@ msgstr "Performance difference between languages" msgid "Play video" msgstr "Play video" +#: src/components/settings/views/general.tsx:239 +msgid "Primary language for the interface" +msgstr "Primary language for the interface" + #: src/components/settings/views/billing.tsx:33 msgid "Pro" msgstr "Pro" @@ -1024,7 +1045,7 @@ msgstr "Role" #~ msgid "Save and close" #~ msgstr "Save and close" -#: src/components/settings/views/general.tsx:174 +#: src/components/settings/views/general.tsx:183 msgid "Save recordings" msgstr "Save recordings" @@ -1069,6 +1090,10 @@ msgstr "Select a transcribing model" msgid "Select Calendars" msgstr "Select Calendars" +#: src/components/settings/views/general.tsx:268 +msgid "Select languages you speak for better transcription" +msgstr "Select languages you speak for better transcription" + #: src/components/settings/components/ai/llm-view.tsx:253 #~ msgid "Select or enter the model name required by your endpoint." #~ msgstr "Select or enter the model name required by your endpoint." @@ -1081,7 +1106,7 @@ msgstr "Send invite" msgid "Settings" msgstr "Settings" -#: src/components/settings/views/general.tsx:200 +#: src/components/settings/views/general.tsx:209 msgid "Share usage data" msgstr "Share usage data" @@ -1101,6 +1126,10 @@ msgstr "Single sign-on for all users" #~ msgid "Sound" #~ msgstr "Sound" +#: src/components/settings/views/general.tsx:265 +msgid "Spoken languages" +msgstr "Spoken languages" + #: src/components/settings/views/billing.tsx:76 msgid "Start Annual Plan" msgstr "Start Annual Plan" @@ -1311,8 +1340,8 @@ msgstr "Works offline" #~ msgstr "Yes, activate speaker" #: src/components/settings/views/general.tsx:266 -msgid "You can make Hyprnote takes these words into account when transcribing" -msgstr "You can make Hyprnote takes these words into account when transcribing" +#~ msgid "You can make Hyprnote takes these words into account when transcribing" +#~ msgstr "You can make Hyprnote takes these words into account when transcribing" #: src/components/settings/views/integrations.tsx:207 msgid "Your API key for Obsidian local-rest-api plugin." diff --git a/apps/desktop/src/locales/ko/messages.po b/apps/desktop/src/locales/ko/messages.po index b6cfb26f95..234668df1c 100644 --- a/apps/desktop/src/locales/ko/messages.po +++ b/apps/desktop/src/locales/ko/messages.po @@ -14,8 +14,8 @@ msgstr "" "Plural-Forms: \n" #. js-lingui-explicit-id -#: src/components/settings/views/general.tsx:275 -msgid "Type jargons (e.g., Blitz Meeting, PaC Squad)" +#: src/components/settings/views/general.tsx:358 +msgid "Type terms separated by commas (e.g., Blitz Meeting, PaC Squad)" msgstr "" #. js-lingui-explicit-id @@ -208,6 +208,11 @@ msgstr "" msgid "{weeks} weeks later" msgstr "" +#. js-lingui-explicit-id +#: src/components/settings/views/general.tsx:275 +#~ msgid "Type jargons (e.g., Blitz Meeting, PaC Squad)" +#~ msgstr "" + #. js-lingui-explicit-id #: ../../packages/utils/src/datetime.ts:168 #~ msgid "just now" @@ -315,6 +320,10 @@ msgstr "" msgid "Add participant" msgstr "" +#: src/components/settings/views/general.tsx:349 +msgid "Add specific terms or jargon for improved transcription accuracy" +msgstr "" + #: src/components/settings/views/team.tsx:142 #: src/components/settings/views/team.tsx:229 msgid "Admin" @@ -436,13 +445,13 @@ msgstr "" #~ msgid "Choose the language you want to use for the speech-to-text model and language model" #~ msgstr "" -#: src/components/settings/views/general.tsx:177 +#: src/components/settings/views/general.tsx:186 msgid "Choose whether to save your recordings locally." msgstr "" #: src/components/settings/views/general.tsx:230 -msgid "Choose your preferred language of use" -msgstr "" +#~ msgid "Choose your preferred language of use" +#~ msgstr "" #: src/components/settings/components/wer-modal.tsx:126 msgid "Close" @@ -539,6 +548,10 @@ msgstr "" msgid "Custom Endpoint" msgstr "" +#: src/components/settings/views/general.tsx:346 +msgid "Custom Vocabulary" +msgstr "" + #: src/components/settings/components/ai/llm-view.tsx:149 #~ msgid "Default (llama-3.2-3b-q4)" #~ msgstr "" @@ -565,6 +578,10 @@ msgstr "" #~ msgid "Did you get consent from everyone in the meeting?" #~ msgstr "" +#: src/components/settings/views/general.tsx:236 +msgid "Display language" +msgstr "" + #. placeholder {0}: metadata?.size && `(${metadata.size})` #: src/components/settings/components/ai/stt-view.tsx:237 msgid "Download {0}" @@ -692,7 +709,7 @@ msgstr "" msgid "Grant Access" msgstr "" -#: src/components/settings/views/general.tsx:203 +#: src/components/settings/views/general.tsx:212 msgid "Help us improve Hyprnote by sharing anonymous usage data" msgstr "" @@ -741,8 +758,8 @@ msgid "Invite members" msgstr "" #: src/components/settings/views/general.tsx:263 -msgid "Jargons" -msgstr "" +#~ msgid "Jargons" +#~ msgstr "" #: src/components/settings/views/profile.tsx:139 msgid "Job title" @@ -761,8 +778,8 @@ msgstr "" #~ msgstr "" #: src/components/settings/views/general.tsx:227 -msgid "Language" -msgstr "" +#~ msgid "Language" +#~ msgstr "" #: src/components/settings/views/billing.tsx:200 msgid "Learn more about our pricing plans" @@ -971,6 +988,10 @@ msgstr "" msgid "Play video" msgstr "" +#: src/components/settings/views/general.tsx:239 +msgid "Primary language for the interface" +msgstr "" + #: src/components/settings/views/billing.tsx:33 msgid "Pro" msgstr "" @@ -1024,7 +1045,7 @@ msgstr "" #~ msgid "Save and close" #~ msgstr "" -#: src/components/settings/views/general.tsx:174 +#: src/components/settings/views/general.tsx:183 msgid "Save recordings" msgstr "" @@ -1069,6 +1090,10 @@ msgstr "" msgid "Select Calendars" msgstr "" +#: src/components/settings/views/general.tsx:268 +msgid "Select languages you speak for better transcription" +msgstr "" + #: src/components/settings/components/ai/llm-view.tsx:253 #~ msgid "Select or enter the model name required by your endpoint." #~ msgstr "" @@ -1081,7 +1106,7 @@ msgstr "" msgid "Settings" msgstr "" -#: src/components/settings/views/general.tsx:200 +#: src/components/settings/views/general.tsx:209 msgid "Share usage data" msgstr "" @@ -1101,6 +1126,10 @@ msgstr "" #~ msgid "Sound" #~ msgstr "" +#: src/components/settings/views/general.tsx:265 +msgid "Spoken languages" +msgstr "" + #: src/components/settings/views/billing.tsx:76 msgid "Start Annual Plan" msgstr "" @@ -1311,8 +1340,8 @@ msgstr "" #~ msgstr "" #: src/components/settings/views/general.tsx:266 -msgid "You can make Hyprnote takes these words into account when transcribing" -msgstr "" +#~ msgid "You can make Hyprnote takes these words into account when transcribing" +#~ msgstr "" #: src/components/settings/views/integrations.tsx:207 msgid "Your API key for Obsidian local-rest-api plugin." diff --git a/crates/db-user/src/config_types.rs b/crates/db-user/src/config_types.rs index b8b85977eb..adf6ccbddc 100644 --- a/crates/db-user/src/config_types.rs +++ b/crates/db-user/src/config_types.rs @@ -1,6 +1,3 @@ -use serde::Deserialize; -use std::str::FromStr; - use crate::user_common_derives; user_common_derives! { @@ -39,8 +36,9 @@ user_common_derives! { pub autostart: bool, #[specta(type = String)] #[schemars(with = "String", regex(pattern = "^[a-zA-Z]{2}$"))] - #[serde(serialize_with = "serialize_language", deserialize_with = "deserialize_language")] pub display_language: hypr_language::Language, + #[specta(type = Vec)] + pub spoken_languages: Vec, pub jargons: Vec, pub telemetry_consent: bool, pub save_recordings: Option, @@ -53,6 +51,7 @@ impl Default for ConfigGeneral { Self { autostart: false, display_language: hypr_language::ISO639::En.into(), + spoken_languages: vec![hypr_language::ISO639::En.into()], jargons: vec![], telemetry_consent: true, save_recordings: Some(true), @@ -97,19 +96,3 @@ impl Default for ConfigAI { } } } - -fn serialize_language( - lang: &hypr_language::Language, - serializer: S, -) -> Result { - let code = lang.iso639().code(); - serializer.serialize_str(code) -} - -fn deserialize_language<'de, D: serde::Deserializer<'de>>( - deserializer: D, -) -> Result { - let str = String::deserialize(deserializer)?; - let iso639 = hypr_language::ISO639::from_str(&str).map_err(serde::de::Error::custom)?; - Ok(iso639.into()) -} diff --git a/crates/language/Cargo.toml b/crates/language/Cargo.toml index 87505e9d7c..890d7835ce 100644 --- a/crates/language/Cargo.toml +++ b/crates/language/Cargo.toml @@ -14,5 +14,7 @@ codes-iso-639 = { workspace = true } deepgram = { workspace = true, optional = true, features = ["listen"] } hypr-whisper = { workspace = true, optional = true } +schemars = { workspace = true } serde = { workspace = true } +specta = { workspace = true, features = ["derive"] } thiserror = { workspace = true } diff --git a/crates/language/src/lib.rs b/crates/language/src/lib.rs index b1ef8cc180..c4d4437f42 100644 --- a/crates/language/src/lib.rs +++ b/crates/language/src/lib.rs @@ -1,10 +1,14 @@ mod error; pub use error::*; +use std::str::FromStr; + pub use codes_iso_639::part_1::LanguageCode as ISO639; -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, PartialEq, specta::Type, schemars::JsonSchema)] pub struct Language { + #[specta(type = String)] + #[schemars(with = "String", regex(pattern = "^[a-zA-Z]{2}$"))] iso639: ISO639, } @@ -338,3 +342,23 @@ impl Language { } } } + +impl serde::Serialize for Language { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(self.iso639().code()) + } +} + +impl<'de> serde::Deserialize<'de> for Language { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let code = String::deserialize(deserializer)?; + let iso639 = ISO639::from_str(&code).map_err(serde::de::Error::custom)?; + Ok(iso639.into()) + } +} diff --git a/crates/whisper-local/src/model.rs b/crates/whisper-local/src/model.rs index 15528e7ab3..06779163af 100644 --- a/crates/whisper-local/src/model.rs +++ b/crates/whisper-local/src/model.rs @@ -17,7 +17,7 @@ lazy_static! { #[derive(Default)] pub struct WhisperBuilder { model_path: Option, - language: Option, + languages: Option>, static_prompt: Option, dynamic_prompt: Option, } @@ -28,8 +28,8 @@ impl WhisperBuilder { self } - pub fn language(mut self, language: Language) -> Self { - self.language = Some(language); + pub fn languages(mut self, languages: Vec) -> Self { + self.languages = Some(languages); self } @@ -63,7 +63,7 @@ impl WhisperBuilder { let token_beg = ctx.token_beg(); Whisper { - language: self.language, + languages: self.languages.unwrap_or_default(), static_prompt: self.static_prompt.unwrap_or_default(), dynamic_prompt: self.dynamic_prompt.unwrap_or_default(), state, @@ -84,7 +84,7 @@ impl WhisperBuilder { } pub struct Whisper { - language: Option, + languages: Vec, static_prompt: String, dynamic_prompt: String, state: WhisperState, @@ -98,6 +98,9 @@ impl Whisper { } pub fn transcribe(&mut self, audio: &[f32]) -> Result, super::Error> { + let token_beg = self.token_beg; + let language = self.get_language(audio)?; + let params = { let mut p = FullParams::new(SamplingStrategy::Greedy { best_of: 1 }); @@ -107,12 +110,14 @@ impl Whisper { tracing::info!(initial_prompt = ?initial_prompt, "transcribe"); - p.set_translate(false); - p.set_language(self.language.as_ref().map(|l| l.as_ref())); + p.set_translate(true); + p.set_detect_language(false); + p.set_language(language.as_deref()); + p.set_initial_prompt(&initial_prompt); unsafe { - Self::suppress_beg(&mut p, &self.token_beg); + Self::suppress_beg(&mut p, &token_beg); } p.set_no_timestamps(true); @@ -122,7 +127,6 @@ impl Whisper { p.set_temperature(0.0); p.set_temperature_inc(0.2); - p.set_detect_language(false); p.set_single_segment(true); p.set_suppress_blank(true); p.set_suppress_nst(true); @@ -173,6 +177,40 @@ impl Whisper { Ok(segments) } + fn get_language(&mut self, audio: &[f32]) -> Result, super::Error> { + if self.languages.len() == 0 { + return Ok(None); + } + + if self.languages.len() == 1 { + let lang = &self.languages[0]; + return Ok(Some(lang.to_string())); + } + + let lang_str = { + self.state.pcm_to_mel(audio, 1)?; + let (_lang_id, lang_probs) = self.state.lang_detect(0, 1)?; + + let mut best_lang = None; + let mut best_prob = f32::NEG_INFINITY; + + for lang in &self.languages { + let lang_id = lang.whisper_index(); + if lang_id < lang_probs.len() { + let prob = lang_probs[lang_id]; + if prob > best_prob { + best_prob = prob; + best_lang = Some(lang.as_ref().to_string()); + } + } + } + + best_lang + }; + + Ok(lang_str) + } + fn filter_segments(segments: Vec) -> Vec { segments .into_iter() diff --git a/crates/whisper/src/lib.rs b/crates/whisper/src/lib.rs index 7e144329ac..32d24992fa 100644 --- a/crates/whisper/src/lib.rs +++ b/crates/whisper/src/lib.rs @@ -1,5 +1,6 @@ // https://github.com/openai/whisper/blob/ba3f3cd/whisper/tokenizer.py#L10-L128 -#[derive(strum::EnumString, strum::Display, strum::AsRefStr)] +#[repr(u8)] +#[derive(Debug, Copy, Clone, strum::EnumString, strum::Display, strum::AsRefStr)] pub enum Language { #[strum(serialize = "en")] En, @@ -202,3 +203,9 @@ pub enum Language { #[strum(serialize = "yue")] Yue, } + +impl Language { + pub fn whisper_index(self) -> usize { + self as usize + } +} diff --git a/plugins/db/js/bindings.gen.ts b/plugins/db/js/bindings.gen.ts index beb9c0f78d..d5a28dac2a 100644 --- a/plugins/db/js/bindings.gen.ts +++ b/plugins/db/js/bindings.gen.ts @@ -151,7 +151,7 @@ export type ChatMessage = { id: string; group_id: string; created_at: string; ro export type ChatMessageRole = "User" | "Assistant" export type Config = { id: string; user_id: string; general: ConfigGeneral; notification: ConfigNotification; ai: ConfigAI } export type ConfigAI = { api_base: string | null; api_key: string | null; ai_specificity: number | null } -export type ConfigGeneral = { autostart: boolean; display_language: string; jargons: string[]; telemetry_consent: boolean; save_recordings: boolean | null; selected_template_id: string | null } +export type ConfigGeneral = { autostart: boolean; display_language: string; spoken_languages: string[]; jargons: string[]; telemetry_consent: boolean; save_recordings: boolean | null; selected_template_id: string | null } export type ConfigNotification = { before: boolean; auto: boolean; ignoredPlatforms: string[] | null } export type Event = { id: string; user_id: string; tracking_id: string; calendar_id: string | null; name: string; note: string; start_date: string; end_date: string; google_event_url: string | null } export type GetSessionFilter = { id: string } | { calendarEventId: string } | { tagId: string } diff --git a/plugins/listener-interface/src/lib.rs b/plugins/listener-interface/src/lib.rs index f9bc623ff0..94d0c4cecf 100644 --- a/plugins/listener-interface/src/lib.rs +++ b/plugins/listener-interface/src/lib.rs @@ -86,10 +86,7 @@ common_derives! { #[derive(Default)] pub struct ListenParams { pub audio_mode: AudioMode, - #[specta(type = String)] - #[schemars(with = "String")] - #[serde(serialize_with = "serialize_language", deserialize_with = "deserialize_language")] - pub language: hypr_language::Language, + pub languages: Vec, pub static_prompt: String, pub dynamic_prompt: String, } @@ -121,22 +118,3 @@ pub struct DiarizationChunk { pub speaker: i32, pub confidence: Option, } - -use serde::Deserialize; -use std::str::FromStr; - -fn serialize_language( - lang: &hypr_language::Language, - serializer: S, -) -> Result { - let code = lang.iso639().code(); - serializer.serialize_str(code) -} - -fn deserialize_language<'de, D: serde::Deserializer<'de>>( - deserializer: D, -) -> Result { - let str = String::deserialize(deserializer)?; - let iso639 = hypr_language::ISO639::from_str(&str).map_err(serde::de::Error::custom)?; - Ok(iso639.into()) -} diff --git a/plugins/listener/src/client.rs b/plugins/listener/src/client.rs index 7909bae451..2ea373f2e4 100644 --- a/plugins/listener/src/client.rs +++ b/plugins/listener/src/client.rs @@ -37,14 +37,19 @@ impl ListenClientBuilder { ..self.params.clone().unwrap_or_default() }; - let language = params.language.code(); - url.set_path("/api/desktop/listen/realtime"); - url.query_pairs_mut() - .append_pair("language", language) - .append_pair("static_prompt", ¶ms.static_prompt) - .append_pair("dynamic_prompt", ¶ms.dynamic_prompt) - .append_pair("audio_mode", params.audio_mode.as_ref()); + + { + let mut query_pairs = url.query_pairs_mut(); + + for lang in ¶ms.languages { + query_pairs.append_pair("languages", lang.iso639().code()); + } + query_pairs + .append_pair("static_prompt", ¶ms.static_prompt) + .append_pair("dynamic_prompt", ¶ms.dynamic_prompt) + .append_pair("audio_mode", params.audio_mode.as_ref()); + } let host = url.host_str().unwrap(); @@ -189,7 +194,7 @@ mod tests { .api_base("http://127.0.0.1:1234") .api_key("".to_string()) .params(hypr_listener_interface::ListenParams { - language: hypr_language::ISO639::En.into(), + languages: vec![hypr_language::ISO639::En.into()], ..Default::default() }) .build_single(); diff --git a/plugins/listener/src/fsm.rs b/plugins/listener/src/fsm.rs index efd14a93c4..8bc4311084 100644 --- a/plugins/listener/src/fsm.rs +++ b/plugins/listener/src/fsm.rs @@ -205,21 +205,21 @@ impl Session { let user_id = self.app.db_user_id().await?.unwrap(); self.session_id = Some(session_id.clone()); - let (record, language, jargons) = { + let (record, languages, jargons) = { let config = self.app.db_get_config(&user_id).await?; let record = config .as_ref() .is_none_or(|c| c.general.save_recordings.unwrap_or(true)); - let language = config.as_ref().map_or_else( - || hypr_language::ISO639::En.into(), - |c| c.general.display_language.clone(), + let languages = config.as_ref().map_or_else( + || vec![hypr_language::ISO639::En.into()], + |c| c.general.spoken_languages.clone(), ); let jargons = config.map_or_else(Vec::new, |c| c.general.jargons); - (record, language, jargons) + (record, languages, jargons) }; let session = self @@ -241,7 +241,7 @@ impl Session { self.speaker_muted_rx = Some(speaker_muted_rx_main.clone()); self.session_state_tx = Some(session_state_tx); - let listen_client = setup_listen_client(&self.app, language, jargons).await?; + let listen_client = setup_listen_client(&self.app, languages, jargons).await?; let mic_sample_stream = { let mut input = match &self.mic_device_name { @@ -541,7 +541,7 @@ impl Session { async fn setup_listen_client( app: &tauri::AppHandle, - language: hypr_language::Language, + languages: Vec, _jargons: Vec, ) -> Result { let api_base = { @@ -557,7 +557,7 @@ async fn setup_listen_client( .unwrap_or_default() }; - tracing::info!(api_base = ?api_base, api_key = ?api_key, language = ?language, "listen_client"); + tracing::info!(api_base = ?api_base, api_key = ?api_key, languages = ?languages, "listen_client"); // let static_prompt = format!( // "{} / {}:", @@ -572,7 +572,7 @@ async fn setup_listen_client( .api_base(api_base) .api_key(api_key) .params(hypr_listener_interface::ListenParams { - language, + languages, static_prompt, ..Default::default() }) diff --git a/plugins/local-stt/Cargo.toml b/plugins/local-stt/Cargo.toml index a2bd3ce0b3..b836656ccd 100644 --- a/plugins/local-stt/Cargo.toml +++ b/plugins/local-stt/Cargo.toml @@ -28,7 +28,7 @@ tauri-plugin = { workspace = true, features = ["build"] } [dev-dependencies] hypr-data = { workspace = true } -hypr-language = { workspace = true } +hypr-language = { workspace = true, features = ["whisper"] } tauri-plugin-listener = { workspace = true } tauri-plugin-store = { workspace = true } @@ -68,6 +68,7 @@ thiserror = { workspace = true } rodio = { workspace = true, features = ["symphonia", "symphonia-all"] } axum = { workspace = true, features = ["ws", "multipart"] } +axum-extra = { workspace = true, features = ["query"] } tower-http = { workspace = true, features = ["cors", "trace"] } futures-util = { workspace = true } diff --git a/plugins/local-stt/src/ext.rs b/plugins/local-stt/src/ext.rs index ce1200d9c3..688e2e53e3 100644 --- a/plugins/local-stt/src/ext.rs +++ b/plugins/local-stt/src/ext.rs @@ -220,6 +220,7 @@ impl> LocalSttPluginExt for T { let mut model = hypr_whisper_local::Whisper::builder() .model_path(model_path.as_ref().to_str().unwrap()) + .languages(vec![]) .static_prompt("") .dynamic_prompt("") .build(); diff --git a/plugins/local-stt/src/lib.rs b/plugins/local-stt/src/lib.rs index d109c66850..493969c21a 100644 --- a/plugins/local-stt/src/lib.rs +++ b/plugins/local-stt/src/lib.rs @@ -133,7 +133,7 @@ mod test { .api_base(api_base) .api_key("NONE") .params(hypr_listener_interface::ListenParams { - language: hypr_language::ISO639::En.into(), + languages: vec![hypr_language::ISO639::En.into()], ..Default::default() }) .build_single(); diff --git a/plugins/local-stt/src/server.rs b/plugins/local-stt/src/server.rs index 2e6126d82d..a2a715055b 100644 --- a/plugins/local-stt/src/server.rs +++ b/plugins/local-stt/src/server.rs @@ -6,13 +6,14 @@ use std::{ use axum::{ extract::{ ws::{Message, WebSocket, WebSocketUpgrade}, - Query, State as AxumState, + State as AxumState, }, http::StatusCode, response::IntoResponse, routing::get, Router, }; +use axum_extra::extract::Query; use futures_util::{SinkExt, StreamExt}; use tower_http::cors::{self, CorsLayer}; @@ -124,8 +125,15 @@ async fn websocket_with_model( let model_cache_dir = state.model_cache_dir.clone(); let model_path = model_cache_dir.join(model_type.file_name()); + let languages: Vec = params + .languages + .into_iter() + .filter_map(|lang| lang.try_into().ok()) + .collect(); + let model = hypr_whisper_local::Whisper::builder() .model_path(model_path.to_str().unwrap()) + .languages(languages) .static_prompt(¶ms.static_prompt) .dynamic_prompt(¶ms.dynamic_prompt) .build();