apps/desktop/src/components/welcome-modal/custom-endpoint-view.tsx (605 additions, 0 deletions)

Large diffs are not rendered by default.
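Since this new file's diff is collapsed, the sketch below is an editor's guess at the props surface the component exposes, inferred only from how index.tsx renders it later in this diff; the names and exact types are assumptions, and the real definitions live in custom-endpoint-view.tsx itself.

// Hypothetical props sketch for CustomEndpointView, inferred from its usage in index.tsx.
// UseFormReturn comes from react-hook-form; ConfigureEndpointConfig from the settings AI shared module.
import { UseFormReturn } from "react-hook-form";
import { ConfigureEndpointConfig } from "../settings/components/ai/shared";

interface CustomEndpointViewProps {
  onContinue: () => void;
  configureCustomEndpoint: (config: ConfigureEndpointConfig) => Promise<void>;
  openaiForm: UseFormReturn<{ api_key: string; model: string }>;
  geminiForm: UseFormReturn<{ api_key: string; model: string }>;
  openrouterForm: UseFormReturn<{ api_key: string; model: string }>;
  customForm: UseFormReturn<{ api_base: string; api_key?: string; model: string }>;
}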

@@ -19,6 +19,7 @@ interface ModelDownloadProgress {

interface DownloadProgressViewProps {
selectedSttModel: SupportedModel;
llmSelection: "hyprllm" | "byom" | null;
onContinue: () => void;
}

@@ -86,6 +87,7 @@ const ModelProgressCard = ({

export const DownloadProgressView = ({
selectedSttModel,
llmSelection,
onContinue,
}: DownloadProgressViewProps) => {
const [sttDownload, setSttDownload] = useState<ModelDownloadProgress>({
@@ -107,7 +109,11 @@ export const DownloadProgressView = ({
useEffect(() => {
localSttCommands.downloadModel(selectedSttModel, sttDownload.channel);

localLlmCommands.downloadModel("HyprLLM", llmDownload.channel);
if (llmSelection === "hyprllm") {
localLlmCommands.downloadModel("HyprLLM", llmDownload.channel);
} else {
setLlmDownload(prev => ({ ...prev, completed: true }));
}

sttDownload.channel.onmessage = (progress) => {
if (progress < 0) {
@@ -122,19 +128,21 @@ }));
}));
};

llmDownload.channel.onmessage = (progress) => {
if (progress < 0) {
setLlmDownload(prev => ({ ...prev, error: true }));
return;
}
if (llmSelection === "hyprllm") {
llmDownload.channel.onmessage = (progress) => {
if (progress < 0) {
setLlmDownload(prev => ({ ...prev, error: true }));
return;
}

setLlmDownload(prev => ({
...prev,
progress: Math.max(prev.progress, progress),
completed: progress >= 100,
}));
};
}, [selectedSttModel, sttDownload.channel, llmDownload.channel]);
setLlmDownload(prev => ({
...prev,
progress: Math.max(prev.progress, progress),
completed: progress >= 100,
}));
};
}
}, [selectedSttModel, sttDownload.channel, llmDownload.channel, llmSelection]);

const bothCompleted = sttDownload.completed && llmDownload.completed;
const hasErrors = sttDownload.error || llmDownload.error;
@@ -174,7 +182,7 @@ export const DownloadProgressView = ({
};

const handleLlmCompletion = async () => {
if (llmDownload.completed) {
if (llmDownload.completed && llmSelection === "hyprllm") {
try {
await localLlmCommands.setCurrentModel("HyprLLM");
await localLlmCommands.startServer();
@@ -186,7 +194,7 @@ export const DownloadProgressView = ({

handleSttCompletion();
handleLlmCompletion();
}, [sttDownload.completed, llmDownload.completed, selectedSttModel]);
}, [sttDownload.completed, llmDownload.completed, selectedSttModel, llmSelection]);

const sttMetadata = sttModelMetadata[selectedSttModel];

@@ -233,12 +241,14 @@ export const DownloadProgressView = ({
size={sttMetadata?.size || "250MB"}
/>

<ModelProgressCard
title="Language Model"
icon={BrainIcon}
download={llmDownload}
size="1.1GB"
/>
{llmSelection === "hyprllm" && (
<ModelProgressCard
title="Language Model"
icon={BrainIcon}
download={llmDownload}
size="1.1GB"
/>
)}
</div>

<PushableButton
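The hunks above gate the local LLM download on the onboarding choice: the HyprLLM model is only downloaded, tracked, and rendered when the user picked "hyprllm"; otherwise the LLM slot is marked completed up front so only the STT download blocks the continue button. A compact restatement of that gating as a pure helper, for illustration only (not code from the PR):

// Editor's sketch: the readiness rule the diff above introduces.
type LlmSelection = "hyprllm" | "byom" | null;

function isReadyToContinue(
  sttCompleted: boolean,
  llmCompleted: boolean,
  llmSelection: LlmSelection,
): boolean {
  // Only a user who chose the bundled HyprLLM has an LLM download to wait for.
  const llmReady = llmSelection === "hyprllm" ? llmCompleted : true;
  return sttCompleted && llmReady;
}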
apps/desktop/src/components/welcome-modal/index.tsx (184 additions, 4 deletions)
@@ -1,6 +1,7 @@
import { useMutation, useQueryClient } from "@tanstack/react-query";
import { useNavigate } from "@tanstack/react-router";
import { message } from "@tauri-apps/plugin-dialog";
import { ArrowLeft } from "lucide-react";
import { useEffect, useState } from "react";

import { showLlmModelDownloadToast, showSttModelDownloadToast } from "@/components/toast/shared";
@@ -10,15 +11,23 @@ import { commands as localSttCommands, SupportedModel } from "@hypr/plugin-local
import { commands as sfxCommands } from "@hypr/plugin-sfx";
import { Modal, ModalBody } from "@hypr/ui/components/ui/modal";
import { Particles } from "@hypr/ui/components/ui/particles";
import { ConfigureEndpointConfig } from "../settings/components/ai/shared";

import { zodResolver } from "@hookform/resolvers/zod";
import { commands as connectorCommands } from "@hypr/plugin-connector";
import { commands as dbCommands } from "@hypr/plugin-db";
import { commands as localLlmCommands } from "@hypr/plugin-local-llm";
import { useForm } from "react-hook-form";
import { z } from "zod";
import { AudioPermissionsView } from "./audio-permissions-view";
// import { CalendarPermissionsView } from "./calendar-permissions-view";
import { useHypr } from "@/contexts";
import { commands as analyticsCommands } from "@hypr/plugin-analytics";
import { Trans } from "@lingui/react/macro";
import { CustomEndpointView } from "./custom-endpoint-view";
import { DownloadProgressView } from "./download-progress-view";
import { LanguageSelectionView } from "./language-selection-view";
import { LLMSelectionView } from "./llm-selection-view";
import { ModelSelectionView } from "./model-selection-view";
import { WelcomeView } from "./welcome-view";

@@ -27,6 +36,28 @@ interface WelcomeModalProps {
onClose: () => void;
}

// Form schemas
const openaiSchema = z.object({
api_key: z.string().min(1, "API key is required").startsWith("sk-", "OpenAI API key must start with 'sk-'"),
model: z.string().min(1, "Model selection is required"),
});

const geminiSchema = z.object({
api_key: z.string().min(1, "API key is required").startsWith("AIza", "Gemini API key must start with 'AIza'"),
model: z.string().min(1, "Model selection is required"),
});

const openrouterSchema = z.object({
api_key: z.string().min(1, "API key is required").startsWith("sk-", "OpenRouter API key must start with 'sk-'"),
model: z.string().min(1, "Model selection is required"),
});

const customSchema = z.object({
api_base: z.string().url("Must be a valid URL"),
api_key: z.string().optional(),
model: z.string().min(1, "Model is required"),
});

export function WelcomeModal({ isOpen, onClose }: WelcomeModalProps) {
const navigate = useNavigate();
const queryClient = useQueryClient();
@@ -37,15 +68,98 @@ export function WelcomeModal({ isOpen, onClose }: WelcomeModalProps) {
| "model-selection"
| "download-progress"
| "audio-permissions"
| "llm-selection"
| "custom-endpoint"
| "language-selection"
>("welcome");
const [selectedSttModel, setSelectedSttModel] = useState<SupportedModel>("QuantizedSmall");
const [wentThroughDownloads, setWentThroughDownloads] = useState(false);
const [llmSelection, setLlmSelection] = useState<"hyprllm" | "byom" | null>(null);
const [cameFromLlmSelection, setCameFromLlmSelection] = useState(false);

const selectSTTModel = useMutation({
mutationFn: (model: SupportedModel) => localSttCommands.setCurrentModel(model),
});

const openaiForm = useForm<{ api_key: string; model: string }>({
resolver: zodResolver(openaiSchema),
mode: "onChange",
defaultValues: {
api_key: "",
model: "",
},
});

const geminiForm = useForm<{ api_key: string; model: string }>({
resolver: zodResolver(geminiSchema),
mode: "onChange",
defaultValues: {
api_key: "",
model: "",
},
});

const openrouterForm = useForm<{ api_key: string; model: string }>({
resolver: zodResolver(openrouterSchema),
mode: "onChange",
defaultValues: {
api_key: "",
model: "",
},
});

const customForm = useForm<{ api_base: string; api_key?: string; model: string }>({
resolver: zodResolver(customSchema),
mode: "onChange",
defaultValues: {
api_base: "",
api_key: "",
model: "",
},
});

const configureCustomEndpoint = async (config: ConfigureEndpointConfig) => {
const finalApiBase = config.provider === "openai"
? "https://api.openai.com/v1"
: config.provider === "gemini"
? "https://generativelanguage.googleapis.com/v1beta/openai"
: config.provider === "openrouter"
? "https://openrouter.ai/api/v1"
: config.api_base;

try {
await connectorCommands.setCustomLlmEnabled(true);

await connectorCommands.setProviderSource(config.provider);

await connectorCommands.setCustomLlmModel(config.model);

await connectorCommands.setCustomLlmConnection({
api_base: finalApiBase,
api_key: config.api_key || null,
});

if (config.provider === "openai" && config.api_key) {
await connectorCommands.setOpenaiApiKey(config.api_key);
await connectorCommands.setOpenaiModel(config.model);
} else if (config.provider === "gemini" && config.api_key) {
await connectorCommands.setGeminiApiKey(config.api_key);
await connectorCommands.setGeminiModel(config.model);
} else if (config.provider === "openrouter" && config.api_key) {
await connectorCommands.setOpenrouterApiKey(config.api_key);
await connectorCommands.setOpenrouterModel(config.model);
} else if (config.provider === "others") {
await connectorCommands.setOthersApiBase(config.api_base);
if (config.api_key) {
await connectorCommands.setOthersApiKey(config.api_key);
}
await connectorCommands.setOthersModel(config.model);
}
} catch (error) {
console.error("Failed to configure custom endpoint:", error);
}
};

useEffect(() => {
let cleanup: (() => void) | undefined;
let unlisten: (() => void) | undefined;
@@ -117,6 +231,15 @@ export function WelcomeModal({ isOpen, onClose }: WelcomeModalProps) {
}
}, [currentStep, userId]);

useEffect(() => {
if (currentStep === "llm-selection" && userId) {
analyticsCommands.event({
event: "onboarding_reached_llm_selection",
distinct_id: userId,
});
}
}, [currentStep, userId]);

useEffect(() => {
if (currentStep === "language-selection" && userId) {
analyticsCommands.event({
@@ -143,6 +266,22 @@ export function WelcomeModal({ isOpen, onClose }: WelcomeModalProps) {
};

const handleAudioPermissionsContinue = () => {
setCurrentStep("llm-selection");
};

const handleLLMSelectionContinue = (selection: "hyprllm" | "byom") => {
setLlmSelection(selection);
if (selection === "hyprllm") {
setCameFromLlmSelection(true);
setCurrentStep("model-selection");
} else {
setCameFromLlmSelection(false);
setCurrentStep("custom-endpoint");
}
};

const handleCustomEndpointContinue = () => {
setCameFromLlmSelection(false);
setCurrentStep("model-selection");
};

@@ -167,19 +306,22 @@ export function WelcomeModal({ isOpen, onClose }: WelcomeModalProps) {
useEffect(() => {
if (!isOpen && wentThroughDownloads) {
localSttCommands.startServer();

localLlmCommands.startServer();

const checkAndShowToasts = async () => {
try {
const sttModelExists = await localSttCommands.isModelDownloaded(selectedSttModel as SupportedModel);
const llmModelExists = await localLlmCommands.isModelDownloaded("HyprLLM");

if (!sttModelExists) {
showSttModelDownloadToast(selectedSttModel, undefined, queryClient);
}

if (!llmModelExists) {
showLlmModelDownloadToast("HyprLLM", undefined, queryClient);
if (llmSelection === "hyprllm") {
const llmModelExists = await localLlmCommands.isModelDownloaded("HyprLLM");
if (!llmModelExists) {
showLlmModelDownloadToast("HyprLLM", undefined, queryClient);
}
}
} catch (error) {
console.error("Error checking model download status:", error);
Expand All @@ -188,7 +330,7 @@ export function WelcomeModal({ isOpen, onClose }: WelcomeModalProps) {

checkAndShowToasts();
}
}, [isOpen, wentThroughDownloads, selectedSttModel, queryClient]);
}, [isOpen, wentThroughDownloads, selectedSttModel, llmSelection, queryClient]);

return (
<Modal
@@ -199,6 +341,28 @@ export function WelcomeModal({ isOpen, onClose }: WelcomeModalProps) {
preventClose
>
<ModalBody className="relative p-0 flex flex-col items-center justify-center overflow-hidden">
{/* Back button for custom-endpoint */}
{currentStep === "custom-endpoint" && (
<button
onClick={() => setCurrentStep("llm-selection")}
className="absolute top-6 left-6 z-20 flex items-center gap-2 text-sm text-neutral-600 hover:text-neutral-800 transition-colors bg-white/80 backdrop-blur-sm rounded-lg px-3 py-2 hover:bg-white/90"
>
<ArrowLeft className="w-4 h-4" />
<Trans>Back</Trans>
</button>
)}

{/* Back button for model-selection (only when coming from llm-selection) */}
{currentStep === "model-selection" && cameFromLlmSelection && (
<button
onClick={() => setCurrentStep("llm-selection")}
className="absolute top-6 left-6 z-20 flex items-center gap-2 text-sm text-neutral-600 hover:text-neutral-800 transition-colors bg-white/80 backdrop-blur-sm rounded-lg px-3 py-2 hover:bg-white/90"
>
<ArrowLeft className="w-4 h-4" />
<Trans>Back</Trans>
</button>
)}

<div className="z-10">
{currentStep === "welcome" && (
<WelcomeView
@@ -214,6 +378,7 @@ export function WelcomeModal({ isOpen, onClose }: WelcomeModalProps) {
{currentStep === "download-progress" && (
<DownloadProgressView
selectedSttModel={selectedSttModel}
llmSelection={llmSelection}
onContinue={handleDownloadProgressContinue}
/>
)}
@@ -222,6 +387,21 @@ export function WelcomeModal({ isOpen, onClose }: WelcomeModalProps) {
onContinue={handleAudioPermissionsContinue}
/>
)}
{currentStep === "llm-selection" && (
<LLMSelectionView
onContinue={handleLLMSelectionContinue}
/>
)}
{currentStep === "custom-endpoint" && (
<CustomEndpointView
onContinue={handleCustomEndpointContinue}
configureCustomEndpoint={configureCustomEndpoint}
openaiForm={openaiForm}
geminiForm={geminiForm}
openrouterForm={openrouterForm}
customForm={customForm}
/>
)}
{currentStep === "language-selection" && (
<LanguageSelectionView
onContinue={handleLanguageSelectionContinue}
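As with CustomEndpointView, the LLMSelectionView file is not shown in this diff. The sketch below infers its props, plus the rough shape of ConfigureEndpointConfig, purely from the call sites above (handleLLMSelectionContinue and configureCustomEndpoint); treat every name here as an assumption, since the real definitions live in llm-selection-view.tsx and ../settings/components/ai/shared.

// Editor's sketch, inferred from usage in index.tsx above; not code from the PR.
interface LLMSelectionViewProps {
  onContinue: (selection: "hyprllm" | "byom") => void;
}

// configureCustomEndpoint branches on provider and reads api_base, api_key, and model,
// so the imported ConfigureEndpointConfig presumably looks roughly like this:
interface ConfigureEndpointConfigSketch {
  provider: "openai" | "gemini" | "openrouter" | "others";
  api_base: string; // only consulted when provider is "others"
  api_key?: string;
  model: string;
}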