diff --git a/apps/desktop/src/components/settings/ai/shared/index.tsx b/apps/desktop/src/components/settings/ai/shared/index.tsx index 30b5f61aa2..b9b9c156bc 100644 --- a/apps/desktop/src/components/settings/ai/shared/index.tsx +++ b/apps/desktop/src/components/settings/ai/shared/index.tsx @@ -118,6 +118,7 @@ export function NonHyprProviderCard({ type: providerType, base_url: config.baseUrl ?? "", api_key: "", + custom_headers: "", } satisfies AIProvider), listeners: { onChange: ({ formApi }) => { @@ -186,18 +187,26 @@ export function NonHyprProviderCard({ )} )} - {!showBaseUrl && config.baseUrl && ( -
- - Advanced - -
+
+ + Advanced + +
+ {!showBaseUrl && config.baseUrl && ( {(field) => } -
-
- )} + )} + + {(field) => ( + field.handleChange(v)} + /> + )} + +
+
@@ -259,6 +268,96 @@ function useProvider(id: string) { return [data, setProvider] as const; } +function parseHeaders(value: string): Array<{ key: string; value: string }> { + if (!value.trim()) return []; + try { + const parsed = JSON.parse(value) as Record; + return Object.entries(parsed).map(([k, v]) => ({ key: k, value: v })); + } catch { + return []; + } +} + +function serializeHeaders( + headers: Array<{ key: string; value: string }>, +): string { + const filtered = headers.filter((h) => h.key.trim()); + if (filtered.length === 0) return ""; + const obj: Record = {}; + for (const h of filtered) { + obj[h.key] = h.value; + } + return JSON.stringify(obj); +} + +function CustomHeadersField({ + value, + onChange, +}: { + value: string; + onChange: (v: string) => void; +}) { + const headers = parseHeaders(value); + + const update = (newHeaders: Array<{ key: string; value: string }>) => { + onChange(serializeHeaders(newHeaders)); + }; + + const addHeader = () => { + update([...headers, { key: "", value: "" }]); + }; + + const removeHeader = (index: number) => { + update(headers.filter((_, i) => i !== index)); + }; + + const updateHeader = (index: number, field: "key" | "value", val: string) => { + const updated = headers.map((h, i) => + i === index ? { ...h, [field]: val } : h, + ); + update(updated); + }; + + return ( +
+ + {headers.map((header, index) => ( +
+ + updateHeader(index, "key", e.target.value)} + /> + + + updateHeader(index, "value", e.target.value)} + /> + + +
+ ))} + +
+ ); +} + function FormField({ field, label, diff --git a/apps/desktop/src/hooks/useLLMConnection.ts b/apps/desktop/src/hooks/useLLMConnection.ts index 58c2d527a2..e082b5380b 100644 --- a/apps/desktop/src/hooks/useLLMConnection.ts +++ b/apps/desktop/src/hooks/useLLMConnection.ts @@ -30,6 +30,7 @@ type LLMConnectionInfo = { modelId: string; baseUrl: string; apiKey: string; + customHeaders: Record; }; export type LLMConnectionStatus = @@ -142,6 +143,7 @@ const resolveLLMConnection = (params: { providerDefinition.baseUrl?.trim() || ""; const apiKey = providerConfig?.api_key?.trim() || ""; + const customHeaders = parseCustomHeaders(providerConfig?.custom_headers); const context: ProviderEligibilityContext = { isAuthenticated: !!session, @@ -188,13 +190,14 @@ const resolveLLMConnection = (params: { modelId, baseUrl: baseUrl ?? new URL("/llm", env.VITE_AI_URL).toString(), apiKey: session.access_token, + customHeaders, }, status: { status: "success", providerId, isHosted: true }, }; } return { - conn: { providerId, modelId, baseUrl, apiKey }, + conn: { providerId, modelId, baseUrl, apiKey, customHeaders }, status: { status: "success", providerId, isHosted: false }, }; }; @@ -226,13 +229,26 @@ const wrapWithThinkingMiddleware = ( }); }; +function parseCustomHeaders(raw: string | undefined): Record { + if (!raw?.trim()) return {}; + try { + return JSON.parse(raw) as Record; + } catch { + return {}; + } +} + const createLanguageModel = (conn: LLMConnectionInfo): LanguageModelV3 => { + const h = + Object.keys(conn.customHeaders).length > 0 ? conn.customHeaders : undefined; + switch (conn.providerId) { case "hyprnote": { const provider = createOpenRouter({ fetch: tracedFetch, baseURL: conn.baseUrl, apiKey: conn.apiKey, + headers: h, }); return wrapWithThinkingMiddleware(provider.chat(conn.modelId)); } @@ -244,6 +260,7 @@ const createLanguageModel = (conn: LLMConnectionInfo): LanguageModelV3 => { headers: { "anthropic-version": "2023-06-01", "anthropic-dangerous-direct-browser-access": "true", + ...conn.customHeaders, }, }); return wrapWithThinkingMiddleware(provider(conn.modelId)); @@ -254,6 +271,7 @@ const createLanguageModel = (conn: LLMConnectionInfo): LanguageModelV3 => { fetch: tauriFetch, baseURL: conn.baseUrl, apiKey: conn.apiKey, + headers: h, }); return wrapWithThinkingMiddleware(provider(conn.modelId)); } @@ -262,6 +280,7 @@ const createLanguageModel = (conn: LLMConnectionInfo): LanguageModelV3 => { const provider = createOpenRouter({ fetch: tauriFetch, apiKey: conn.apiKey, + headers: h, }); return wrapWithThinkingMiddleware(provider.chat(conn.modelId)); } @@ -271,6 +290,7 @@ const createLanguageModel = (conn: LLMConnectionInfo): LanguageModelV3 => { fetch: tauriFetch, baseURL: conn.baseUrl, apiKey: conn.apiKey, + headers: h, }); return wrapWithThinkingMiddleware(provider(conn.modelId)); } @@ -280,6 +300,9 @@ const createLanguageModel = (conn: LLMConnectionInfo): LanguageModelV3 => { const ollamaFetch: typeof fetch = async (input, init) => { const headers = new Headers(init?.headers); headers.set("Origin", ollamaOrigin); + for (const [k, v] of Object.entries(conn.customHeaders)) { + headers.set(k, v); + } return tauriFetch(input as RequestInfo | URL, { ...init, headers, @@ -298,6 +321,7 @@ const createLanguageModel = (conn: LLMConnectionInfo): LanguageModelV3 => { fetch: tauriFetch, name: conn.providerId, baseURL: conn.baseUrl, + headers: h, }; if (conn.apiKey) { config.apiKey = conn.apiKey; diff --git a/apps/desktop/src/hooks/useRunBatch.ts b/apps/desktop/src/hooks/useRunBatch.ts index 5e6dd0277d..125607cd99 100644 --- a/apps/desktop/src/hooks/useRunBatch.ts +++ b/apps/desktop/src/hooks/useRunBatch.ts @@ -164,6 +164,16 @@ export const useRunBatch = (sessionId: string) => { ]); }); + const customHeadersObj = (() => { + try { + return conn.customHeaders + ? JSON.parse(conn.customHeaders) + : undefined; + } catch { + return undefined; + } + })(); + const params: BatchParams = { session_id: sessionId, provider, @@ -173,6 +183,9 @@ export const useRunBatch = (sessionId: string) => { api_key: options?.apiKey ?? conn.apiKey, keywords: options?.keywords ?? keywords ?? [], languages: options?.languages ?? languages ?? [], + ...(customHeadersObj + ? { custom_headers: customHeadersObj as Record } + : {}), }; await runBatch(params, { handlePersist: persist, sessionId }); diff --git a/apps/desktop/src/hooks/useSTTConnection.ts b/apps/desktop/src/hooks/useSTTConnection.ts index d493df8f4e..fb1a5f0d15 100644 --- a/apps/desktop/src/hooks/useSTTConnection.ts +++ b/apps/desktop/src/hooks/useSTTConnection.ts @@ -73,6 +73,7 @@ export const useSTTConnection = () => { model: current_stt_model, baseUrl: server.url, apiKey: "", + customHeaders: undefined as string | undefined, }, }; } @@ -86,6 +87,7 @@ export const useSTTConnection = () => { const baseUrl = providerConfig?.base_url?.trim(); const apiKey = providerConfig?.api_key?.trim(); + const customHeadersRaw = providerConfig?.custom_headers?.trim(); const connection = useMemo(() => { if (!current_stt_provider || !current_stt_model) { @@ -106,6 +108,7 @@ export const useSTTConnection = () => { model: current_stt_model, baseUrl: baseUrl ?? new URL("/stt", env.VITE_AI_URL).toString(), apiKey: auth.session.access_token, + customHeaders: customHeadersRaw, }; } @@ -118,6 +121,7 @@ export const useSTTConnection = () => { model: current_stt_model, baseUrl, apiKey, + customHeaders: customHeadersRaw, }; }, [ current_stt_provider, diff --git a/apps/desktop/src/hooks/useStartListening.ts b/apps/desktop/src/hooks/useStartListening.ts index c3a347eaa3..d1dd531c71 100644 --- a/apps/desktop/src/hooks/useStartListening.ts +++ b/apps/desktop/src/hooks/useStartListening.ts @@ -119,6 +119,14 @@ export function useStartListening(sessionId: string) { }); }; + const customHeadersObj = (() => { + try { + return conn.customHeaders ? JSON.parse(conn.customHeaders) : undefined; + } catch { + return undefined; + } + })(); + start( { session_id: sessionId, @@ -129,6 +137,9 @@ export function useStartListening(sessionId: string) { base_url: conn.baseUrl, api_key: conn.apiKey, keywords, + ...(customHeadersObj + ? { custom_headers: customHeadersObj as Record } + : {}), }, { handlePersist, diff --git a/apps/desktop/src/store/tinybase/persister/settings/transform.ts b/apps/desktop/src/store/tinybase/persister/settings/transform.ts index 9ee109ecdc..8ffcc0e1c4 100644 --- a/apps/desktop/src/store/tinybase/persister/settings/transform.ts +++ b/apps/desktop/src/store/tinybase/persister/settings/transform.ts @@ -3,8 +3,17 @@ import type { Content } from "tinybase/with-schemas"; import type { Schemas, Store } from "../../store/settings"; import { SETTINGS_MAPPING } from "../../store/settings"; -type ProviderData = { base_url: string; api_key: string }; -type ProviderRow = { type: "llm" | "stt"; base_url: string; api_key: string }; +type ProviderData = { + base_url: string; + api_key: string; + custom_headers?: string; +}; +type ProviderRow = { + type: "llm" | "stt"; + base_url: string; + api_key: string; + custom_headers: string; +}; const JSON_ARRAY_FIELDS = new Set([ "spoken_languages", @@ -106,7 +115,10 @@ function settingsToProviderRows( type: providerType, base_url: data.base_url ?? "", api_key: data.api_key ?? "", - }; + ...(data.custom_headers + ? { custom_headers: data.custom_headers } + : {}), + } as ProviderRow; } } } @@ -143,9 +155,13 @@ function providerRowsToSettings(rows: Record): { }; for (const [rowId, row] of Object.entries(rows)) { - const { type, base_url, api_key } = row; + const { type, base_url, api_key, custom_headers } = row; if (type === "llm" || type === "stt") { - result[type][rowId] = { base_url, api_key }; + result[type][rowId] = { + base_url, + api_key, + ...(custom_headers ? { custom_headers } : {}), + }; } } diff --git a/apps/desktop/src/store/tinybase/store/settings.ts b/apps/desktop/src/store/tinybase/store/settings.ts index 14e8f3f834..05bc3fbd70 100644 --- a/apps/desktop/src/store/tinybase/store/settings.ts +++ b/apps/desktop/src/store/tinybase/store/settings.ts @@ -80,6 +80,7 @@ export const SETTINGS_MAPPING = { type: { type: "string" }, base_url: { type: "string" }, api_key: { type: "string" }, + custom_headers: { type: "string" }, }, }, }, @@ -167,6 +168,7 @@ export const StoreComponent = () => { select("type"); select("base_url"); select("api_key"); + select("custom_headers"); where((getCell) => getCell("type") === "llm"); }, ) @@ -177,6 +179,7 @@ export const StoreComponent = () => { select("type"); select("base_url"); select("api_key"); + select("custom_headers"); where((getCell) => getCell("type") === "stt"); }, ), diff --git a/packages/store/src/zod.ts b/packages/store/src/zod.ts index 75097c0c4c..4895ee319e 100644 --- a/packages/store/src/zod.ts +++ b/packages/store/src/zod.ts @@ -225,6 +225,7 @@ export const aiProviderSchema = z type: z.enum(["stt", "llm"]), base_url: z.url().min(1), api_key: z.string(), + custom_headers: z.string().optional(), }) .refine( (data) => !data.base_url.startsWith("https:") || data.api_key.length > 0, diff --git a/plugins/listener/js/bindings.gen.ts b/plugins/listener/js/bindings.gen.ts index a37b6d3364..3448fcbee6 100644 --- a/plugins/listener/js/bindings.gen.ts +++ b/plugins/listener/js/bindings.gen.ts @@ -112,7 +112,7 @@ sessionProgressEvent: "plugin:listener:session-progress-event" export type SessionDataEvent = { type: "audio_amplitude"; session_id: string; mic: number; speaker: number } | { type: "mic_muted"; session_id: string; value: boolean } | { type: "stream_response"; session_id: string; response: StreamResponse } export type SessionErrorEvent = { type: "audio_error"; session_id: string; error: string; device: string | null; is_fatal: boolean } | { type: "connection_error"; session_id: string; error: string } export type SessionLifecycleEvent = { type: "inactive"; session_id: string; error: string | null } | { type: "active"; session_id: string } | { type: "finalizing"; session_id: string } -export type SessionParams = { session_id: string; languages: string[]; onboarding: boolean; record_enabled: boolean; model: string; base_url: string; api_key: string; keywords: string[] } +export type SessionParams = { session_id: string; languages: string[]; onboarding: boolean; record_enabled: boolean; model: string; base_url: string; api_key: string; keywords: string[]; custom_headers?: Record } export type SessionProgressEvent = { type: "audio_initializing"; session_id: string } | { type: "audio_ready"; session_id: string; device: string | null } | { type: "connecting"; session_id: string } | { type: "connected"; session_id: string; adapter: string } export type State = "active" | "inactive" | "finalizing" export type StreamAlternatives = { transcript: string; words: StreamWord[]; confidence: number; languages?: string[] } diff --git a/plugins/listener/src/actors/listener/adapters.rs b/plugins/listener/src/actors/listener/adapters.rs index e548f0f27a..c3d0b517f1 100644 --- a/plugins/listener/src/actors/listener/adapters.rs +++ b/plugins/listener/src/actors/listener/adapters.rs @@ -135,14 +135,18 @@ async fn spawn_rx_task_single_with_adapter( let (tx, rx) = tokio::sync::mpsc::channel::>(32); - let client = owhisper_client::ListenClient::builder() + let mut builder = owhisper_client::ListenClient::builder() .adapter::() .api_base(args.base_url.clone()) .api_key(args.api_key.clone()) .params(build_listen_params(&args)) - .extra_header(DEVICE_FINGERPRINT_HEADER, hypr_host::fingerprint()) - .build_single() - .await; + .extra_header(DEVICE_FINGERPRINT_HEADER, hypr_host::fingerprint()); + + for (name, value) in &args.custom_headers { + builder = builder.extra_header(name, value); + } + + let client = builder.build_single().await; let outbound = tokio_stream::wrappers::ReceiverStream::new(rx); @@ -207,14 +211,18 @@ async fn spawn_rx_task_dual_with_adapter( let (tx, rx) = tokio::sync::mpsc::channel::>(32); - let client = owhisper_client::ListenClient::builder() + let mut builder = owhisper_client::ListenClient::builder() .adapter::() .api_base(args.base_url.clone()) .api_key(args.api_key.clone()) .params(build_listen_params(&args)) - .extra_header(DEVICE_FINGERPRINT_HEADER, hypr_host::fingerprint()) - .build_dual() - .await; + .extra_header(DEVICE_FINGERPRINT_HEADER, hypr_host::fingerprint()); + + for (name, value) in &args.custom_headers { + builder = builder.extra_header(name, value); + } + + let client = builder.build_dual().await; let outbound = tokio_stream::wrappers::ReceiverStream::new(rx); diff --git a/plugins/listener/src/actors/listener/mod.rs b/plugins/listener/src/actors/listener/mod.rs index e47a390e78..80cf46a23a 100644 --- a/plugins/listener/src/actors/listener/mod.rs +++ b/plugins/listener/src/actors/listener/mod.rs @@ -39,6 +39,7 @@ pub struct ListenerArgs { pub base_url: String, pub api_key: String, pub keywords: Vec, + pub custom_headers: std::collections::HashMap, pub mode: crate::actors::ChannelMode, pub session_started_at: Instant, pub session_started_at_unix: SystemTime, diff --git a/plugins/listener/src/actors/session.rs b/plugins/listener/src/actors/session.rs index dac82518d6..c40403daa1 100644 --- a/plugins/listener/src/actors/session.rs +++ b/plugins/listener/src/actors/session.rs @@ -23,6 +23,8 @@ pub struct SessionParams { pub base_url: String, pub api_key: String, pub keywords: Vec, + #[serde(default)] + pub custom_headers: std::collections::HashMap, } #[derive(Clone)] @@ -110,6 +112,7 @@ pub async fn spawn_session_supervisor( base_url: ctx.params.base_url.clone(), api_key: ctx.params.api_key.clone(), keywords: ctx.params.keywords.clone(), + custom_headers: ctx.params.custom_headers.clone(), mode, session_started_at: ctx.started_at_instant, session_started_at_unix: ctx.started_at_system, diff --git a/plugins/listener2/js/bindings.gen.ts b/plugins/listener2/js/bindings.gen.ts index a7d3d873b1..e922fe776c 100644 --- a/plugins/listener2/js/bindings.gen.ts +++ b/plugins/listener2/js/bindings.gen.ts @@ -74,7 +74,7 @@ batchEvent: "plugin:listener2:batch-event" export type BatchAlternatives = { transcript: string; confidence: number; words?: BatchWord[] } export type BatchChannel = { alternatives: BatchAlternatives[] } export type BatchEvent = { type: "batchStarted"; session_id: string } | { type: "batchResponse"; session_id: string; response: BatchResponse } | { type: "batchProgress"; session_id: string; response: StreamResponse; percentage: number } | { type: "batchFailed"; session_id: string; error: string } -export type BatchParams = { session_id: string; provider: BatchProvider; file_path: string; model?: string | null; base_url: string; api_key: string; languages?: string[]; keywords?: string[] } +export type BatchParams = { session_id: string; provider: BatchProvider; file_path: string; model?: string | null; base_url: string; api_key: string; languages?: string[]; keywords?: string[]; custom_headers?: Record } export type BatchProvider = "deepgram" | "soniox" | "assemblyai" | "am" export type BatchResponse = { metadata: JsonValue; results: BatchResults } export type BatchResults = { channels: BatchChannel[] } diff --git a/plugins/listener2/src/batch.rs b/plugins/listener2/src/batch.rs index b02837a481..1936a133eb 100644 --- a/plugins/listener2/src/batch.rs +++ b/plugins/listener2/src/batch.rs @@ -78,6 +78,7 @@ pub struct BatchArgs { pub listen_params: owhisper_interface::ListenParams, pub start_notifier: BatchStartNotifier, pub session_id: String, + pub custom_headers: std::collections::HashMap, } pub struct BatchState { @@ -437,14 +438,18 @@ async fn spawn_batch_task_with_adapter( sample_rate: metadata.sample_rate, ..args.listen_params.clone() }; - let client = owhisper_client::ListenClient::builder() + let mut builder = owhisper_client::ListenClient::builder() .adapter::() .api_base(args.base_url) .api_key(args.api_key) .params(listen_params) - .extra_header(DEVICE_FINGERPRINT_HEADER, hypr_host::fingerprint()) - .build_with_channels(channel_count) - .await; + .extra_header(DEVICE_FINGERPRINT_HEADER, hypr_host::fingerprint()); + + for (name, value) in &args.custom_headers { + builder = builder.extra_header(name, value); + } + + let client = builder.build_with_channels(channel_count).await; let chunk_count = chunked_audio.chunks.len(); let chunk_interval = stream_config.chunk_interval(); diff --git a/plugins/listener2/src/ext.rs b/plugins/listener2/src/ext.rs index 7d5cd05bc0..8537dcc68c 100644 --- a/plugins/listener2/src/ext.rs +++ b/plugins/listener2/src/ext.rs @@ -34,6 +34,8 @@ pub struct BatchParams { pub languages: Vec, #[serde(default)] pub keywords: Vec, + #[serde(default)] + pub custom_headers: std::collections::HashMap, } pub struct Listener2<'a, R: tauri::Runtime, M: tauri::Manager> { @@ -225,6 +227,7 @@ async fn run_batch_am( listen_params: listen_params.clone(), start_notifier: start_notifier.clone(), session_id: params.session_id.clone(), + custom_headers: params.custom_headers.clone(), }; match spawn_batch_actor(args).await {