Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions apps/desktop/src/components/settings/ai/stt/configure.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,12 @@ function HyprProviderCard({
<HyprProviderLocalRow
model="am-parakeet-v2"
displayName="Parakeet v2"
description="English only. Works best for English."
description="Optimized for English. Best accuracy for English conversations."
/>
<HyprProviderLocalRow
model="am-parakeet-v3"
displayName="Parakeet v3"
description="English and European languages."
description="Better for European languages. Supports multilingual conversations."
/>
<HyprProviderLocalRow
model="am-whisper-large-v3"
Expand Down
99 changes: 70 additions & 29 deletions apps/desktop/src/components/settings/ai/stt/select.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { useForm } from "@tanstack/react-form";
import { useQueries, useQuery } from "@tanstack/react-query";
import { arch } from "@tauri-apps/plugin-os";
import { useCallback } from "react";

import type { AIProviderStorage } from "@hypr/store";
import { Input } from "@hypr/ui/components/ui/input";
Expand All @@ -15,6 +16,7 @@ import { cn } from "@hypr/utils";

import { useBillingAccess } from "../../../../billing";
import { useConfigValues } from "../../../../config/use-config";
import { useValidateSttModel } from "../../../../hooks/useValidateSttModel";
import * as settings from "../../../../store/tinybase/settings";
import {
getProviderSelectionBlockers,
Expand Down Expand Up @@ -50,10 +52,22 @@ export function SelectProviderAndModel() {
settings.STORE_ID,
);

const getValidatedModel = () => {
if (!current_stt_provider || !current_stt_model) return "";

const providerModels =
configuredProviders[current_stt_provider as ProviderId]?.models ?? [];
const isModelValid = providerModels.some(
(model) => model.id === current_stt_model && model.isDownloaded,
);

return isModelValid ? current_stt_model : "";
};

const form = useForm({
defaultValues: {
provider: current_stt_provider || "",
model: current_stt_model || "",
model: getValidatedModel(),
},
listeners: {
onChange: ({ formApi }) => {
Expand All @@ -73,6 +87,17 @@ export function SelectProviderAndModel() {
},
});

const handleClearModel = useCallback(() => {
handleSelectModel("");
form.setFieldValue("model", "");
}, [handleSelectModel, form]);

useValidateSttModel(
current_stt_provider,
current_stt_model,
handleClearModel,
);

return (
<div className="flex flex-col gap-3">
<h3 className="text-md font-semibold">Model being used</h3>
Expand Down Expand Up @@ -170,34 +195,46 @@ export function SelectProviderAndModel() {
}

const allModels = configuredProviders?.[providerId]?.models ?? [];
const models = allModels.filter((model) => {
if (model.id === "cloud") {
return true;
}
if (model.id.startsWith("Quantized")) {
return model.isDownloaded;
}
return true;
});

const modelsToShow =
providerId === "hyprnote"
? allModels
: allModels.filter((model) => {
if (model.id === "cloud") {
return true;
}
return model.isDownloaded;
});

return (
<div className="flex-[3] min-w-0">
<Select
value={field.state.value}
onValueChange={(value) => field.handleChange(value)}
disabled={models.length === 0}
disabled={modelsToShow.length === 0}
>
<SelectTrigger className="bg-white shadow-none focus:ring-0">
<SelectValue placeholder="Select a model" />
</SelectTrigger>
<SelectContent>
{models.map((model) => (
<SelectContent className="w-full">
{modelsToShow.map((model) => (
<SelectItem
key={model.id}
value={model.id}
disabled={!model.isDownloaded}
className="group"
>
{displayModelId(model.id)}
<div className="flex items-center justify-between w-full">
<span>{displayModelId(model.id)}</span>
{!model.isDownloaded &&
providerId === "hyprnote" && (
<span className="text-xs text-neutral-500 ml-auto opacity-0 group-hover:opacity-100 transition-opacity">
{model.id === "cloud"
? "Start trial"
: "Download model"}
</span>
)}
</div>
</SelectItem>
))}
</SelectContent>
Expand Down Expand Up @@ -276,17 +313,7 @@ function useConfiguredMapping(): Record<
}

if (provider.id === "hyprnote") {
const models = [
{ id: "cloud", isDownloaded: billing.isPro },
{
id: "QuantizedTinyEn",
isDownloaded: tinyEn.data ?? false,
},
{
id: "QuantizedSmallEn",
isDownloaded: smallEn.data ?? false,
},
];
const models = [{ id: "cloud", isDownloaded: billing.isPro }];

if (isAppleSilicon) {
models.push(
Expand All @@ -298,13 +325,27 @@ function useConfiguredMapping(): Record<
id: "am-parakeet-v3",
isDownloaded: p3.data ?? false,
},
{
id: "am-whisper-large-v3",
isDownloaded: whisperLargeV3.data ?? false,
},
);
}

if (isAppleSilicon) {
models.push({
id: "am-whisper-large-v3",
isDownloaded: whisperLargeV3.data ?? false,
});
}

models.push(
{
id: "QuantizedTinyEn",
isDownloaded: tinyEn.data ?? false,
},
{
id: "QuantizedSmallEn",
isDownloaded: smallEn.data ?? false,
},
);

return [
provider.id,
{
Expand Down
56 changes: 56 additions & 0 deletions apps/desktop/src/hooks/useValidateSttModel.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import { useQuery } from "@tanstack/react-query";
import { useEffect } from "react";

import {
commands as localSttCommands,
type SupportedSttModel,
} from "@hypr/plugin-local-stt";

const SUPPORTED_LOCAL_MODELS: SupportedSttModel[] = [
"am-parakeet-v2",
"am-parakeet-v3",
"am-whisper-large-v3",
"QuantizedTinyEn",
"QuantizedSmallEn",
];

export function useValidateSttModel(
provider: string | undefined,
model: string | undefined,
onClearModel: () => void,
) {
const isLocalModel = provider === "hyprnote" && model && model !== "cloud";

const { data: isDownloaded } = useQuery({
queryKey: ["stt-model-downloaded", model, isLocalModel],
queryFn: async () => {
if (!isLocalModel || !model) return true;

if (SUPPORTED_LOCAL_MODELS.includes(model as SupportedSttModel)) {
try {
const result = await localSttCommands.isModelDownloaded(
model as SupportedSttModel,
);
return result.status === "ok" && result.data;
} catch (error) {
console.error("Error checking model download status:", error);
return false;
}
}

return true;
},
enabled: !!isLocalModel,
refetchInterval: 2000,
staleTime: 500,
});

useEffect(() => {
if (isLocalModel && isDownloaded === false) {
console.log(`Clearing invalid STT model selection: ${model}`);
onClearModel();
}
}, [isLocalModel, isDownloaded, model, onClearModel]);

return { isModelValid: !isLocalModel || isDownloaded !== false };
}
Loading