diff --git a/apps/desktop2/src/chat/tools.ts b/apps/desktop2/src/chat/tools.ts index 131791d9a6..40eb6802b7 100644 --- a/apps/desktop2/src/chat/tools.ts +++ b/apps/desktop2/src/chat/tools.ts @@ -1,24 +1,42 @@ -import { tool } from "ai"; import { z } from "zod"; -export const searchSessionsTool = tool({ - description: "Search for sessions", - inputSchema: z.object({ - query: z.string().describe("The query to search for"), - }), - execute: async () => { - return { results: [] }; - }, -}); +import { searchFiltersSchema } from "../contexts/search/engine/types"; +import type { SearchFilters, SearchHit } from "../contexts/search/engine/types"; -export const tools = { - search_sessions: searchSessionsTool, -}; +export interface ToolDependencies { + search: (query: string, filters?: SearchFilters | null) => Promise; +} + +export const toolFactories = { + search_sessions: (deps: ToolDependencies) => ({ + description: ` + Search for sessions (meeting notes) using query and filters. + Returns relevant sessions with their content. + `.trim(), + parameters: z.object({ + query: z.string().describe("The search query to find relevant sessions"), + filters: searchFiltersSchema.optional().describe("Optional filters for the search query"), + }), + execute: async (params: { query: string; filters?: SearchFilters }) => { + const hits = await deps.search(params.query, params.filters || null); + + const results = hits.slice(0, 5).map((hit) => ({ + id: hit.document.id, + title: hit.document.title, + content: hit.document.content.slice(0, 500), + score: hit.score, + created_at: hit.document.created_at, + })); + + return { results }; + }, + }), +} as const; export type Tools = { - [K in keyof typeof tools]: { - input: Parameters>[0]; - output: Awaited>>; + [K in keyof typeof toolFactories]: { + input: Parameters["execute"]>[0]; + output: Awaited["execute"]>>; }; }; diff --git a/apps/desktop2/src/chat/transport.ts b/apps/desktop2/src/chat/transport.ts index 318e27bc4c..48488ae79e 100644 --- a/apps/desktop2/src/chat/transport.ts +++ b/apps/desktop2/src/chat/transport.ts @@ -2,15 +2,19 @@ import { createOpenAICompatible } from "@ai-sdk/openai-compatible"; import type { ChatRequestOptions, ChatTransport, UIMessageChunk } from "ai"; import { convertToModelMessages, smoothStream, stepCountIs, streamText } from "ai"; -import { searchSessionsTool } from "./tools"; +import { ToolRegistry } from "../contexts/tool"; import type { HyprUIMessage } from "./types"; +const modelName = "google/gemini-2.5-flash-lite"; const provider = createOpenAICompatible({ name: "openrouter", baseURL: "https://openrouter.ai/api/v1", + apiKey: "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", }); export class CustomChatTransport implements ChatTransport { + constructor(private registry: ToolRegistry) {} + async sendMessages( options: & { @@ -21,8 +25,8 @@ export class CustomChatTransport implements ChatTransport { & { trigger: "submit-message" | "regenerate-message"; messageId: string | undefined } & ChatRequestOptions, ): Promise> { - const model = provider.chatModel("openai/gpt-5-mini"); - const tools = this.getTools(); + const model = provider.chatModel(modelName); + const tools = this.registry.getForTransport(); const result = streamText({ model, @@ -50,10 +54,4 @@ export class CustomChatTransport implements ChatTransport { async reconnectToStream(): Promise | null> { return null; } - - private getTools(): Parameters[0]["tools"] { - return { - search_sessions: searchSessionsTool, - }; - } } diff --git a/apps/desktop2/src/components/chat/body.tsx b/apps/desktop2/src/components/chat/body.tsx index 9f0f3c2273..c90047417b 100644 --- a/apps/desktop2/src/components/chat/body.tsx +++ b/apps/desktop2/src/components/chat/body.tsx @@ -1,11 +1,13 @@ import type { ChatStatus } from "ai"; -import { Loader2, MessageCircle, RotateCcw, X } from "lucide-react"; +import { MessageCircle } from "lucide-react"; import { useEffect, useRef } from "react"; import { cn } from "@hypr/ui/lib/utils"; import type { HyprUIMessage } from "../../chat/types"; import { useShell } from "../../contexts/shell"; -import { ChatBodyMessage } from "./message"; +import { ErrorMessage } from "./message/error"; +import { LoadingMessage } from "./message/loading"; +import { NormalMessage } from "./message/normal"; import { hasRenderableContent } from "./shared"; export function ChatBody({ @@ -95,7 +97,7 @@ function ChatBodyNonEmpty({ return (
{messages.map((message, index) => ( - ); } - -function LoadingMessage({ onCancelAndRetry }: { onCancelAndRetry?: () => void }) { - return ( -
-
-
- - Thinking... -
- {onCancelAndRetry && ( - - )} -
-
- ); -} - -function ErrorMessage({ error, onRetry }: { error: Error; onRetry?: () => void }) { - return ( -
-
-

{error.message}

- {onRetry && ( - - )} -
-
- ); -} diff --git a/apps/desktop2/src/components/chat/message/error.tsx b/apps/desktop2/src/components/chat/message/error.tsx new file mode 100644 index 0000000000..9296f1556a --- /dev/null +++ b/apps/desktop2/src/components/chat/message/error.tsx @@ -0,0 +1,21 @@ +import { RotateCcw } from "lucide-react"; + +import { ActionButton, MessageBubble, MessageContainer } from "./shared"; + +export function ErrorMessage({ error, onRetry }: { error: Error; onRetry?: () => void }) { + return ( + + +

{error.message}

+ {onRetry && ( + + )} +
+
+ ); +} diff --git a/apps/desktop2/src/components/chat/message/loading.tsx b/apps/desktop2/src/components/chat/message/loading.tsx new file mode 100644 index 0000000000..9832c864d1 --- /dev/null +++ b/apps/desktop2/src/components/chat/message/loading.tsx @@ -0,0 +1,24 @@ +import { Loader2, X } from "lucide-react"; + +import { ActionButton, MessageBubble, MessageContainer } from "./shared"; + +export function LoadingMessage({ onCancelAndRetry }: { onCancelAndRetry?: () => void }) { + return ( + + +
+ + Thinking... +
+ {onCancelAndRetry && ( + + )} +
+
+ ); +} diff --git a/apps/desktop2/src/components/chat/message/index.tsx b/apps/desktop2/src/components/chat/message/normal.tsx similarity index 69% rename from apps/desktop2/src/components/chat/message/index.tsx rename to apps/desktop2/src/components/chat/message/normal.tsx index 77d8be7a9c..f72fb72ef8 100644 --- a/apps/desktop2/src/components/chat/message/index.tsx +++ b/apps/desktop2/src/components/chat/message/normal.tsx @@ -2,15 +2,14 @@ import { formatDistanceToNow } from "date-fns"; import { BrainIcon, RotateCcw } from "lucide-react"; import { Streamdown } from "streamdown"; -import { cn } from "@hypr/ui/lib/utils"; import type { ToolPartType } from "../../../chat/tools"; import type { HyprUIMessage } from "../../../chat/types"; import { hasRenderableContent } from "../shared"; -import { Disclosure } from "./shared"; +import { ActionButton, Disclosure, MessageBubble, MessageContainer } from "./shared"; import { Tool } from "./tool"; import type { Part } from "./types"; -export function ChatBodyMessage({ message, handleReload }: { message: HyprUIMessage; handleReload?: () => void }) { +export function NormalMessage({ message, handleReload }: { message: HyprUIMessage; handleReload?: () => void }) { const isUser = message.role === "user"; const shouldShowTimestamp = message.metadata?.createdAt @@ -22,45 +21,29 @@ export function ChatBodyMessage({ message, handleReload }: { message: HyprUIMess } return ( -
+
-
{message.parts.map((part, i) => )} {!isUser && handleReload && ( - + variant="default" + icon={RotateCcw} + label="Reload message" + /> )} -
+ {shouldShowTimestamp && message.metadata?.createdAt && (
{formatDistanceToNow(message.metadata.createdAt, { addSuffix: true })}
)}
-
+ ); } diff --git a/apps/desktop2/src/components/chat/message/shared.tsx b/apps/desktop2/src/components/chat/message/shared.tsx index f21a0e0367..5d9fed28a4 100644 --- a/apps/desktop2/src/components/chat/message/shared.tsx +++ b/apps/desktop2/src/components/chat/message/shared.tsx @@ -3,6 +3,85 @@ import { type ReactNode } from "react"; import { cn } from "@hypr/ui/lib/utils"; +export function MessageContainer({ + align = "start", + children, +}: { + align?: "start" | "end"; + children: ReactNode; +}) { + return ( +
+ {children} +
+ ); +} + +export function MessageBubble({ + variant = "assistant", + withActionButton, + children, +}: { + variant?: "user" | "assistant" | "error" | "loading"; + withActionButton?: boolean; + children: ReactNode; +}) { + return ( +
+ {children} +
+ ); +} + +export function ActionButton({ + onClick, + variant = "default", + icon: Icon, + label, +}: { + onClick: () => void; + variant?: "default" | "error"; + icon: React.ComponentType<{ className?: string }>; + label: string; +}) { + return ( + + ); +} + export function Disclosure( { icon, diff --git a/apps/desktop2/src/components/chat/message/tool/search.tsx b/apps/desktop2/src/components/chat/message/tool/search.tsx index 7f377205d5..cf940f54a6 100644 --- a/apps/desktop2/src/components/chat/message/tool/search.tsx +++ b/apps/desktop2/src/components/chat/message/tool/search.tsx @@ -1,5 +1,16 @@ import { SearchIcon } from "lucide-react"; +import { useCallback } from "react"; +import { Card, CardContent } from "@hypr/ui/components/ui/card"; +import { + Carousel, + CarouselContent, + CarouselItem, + CarouselNext, + CarouselPrevious, +} from "@hypr/ui/components/ui/carousel"; +import * as persisted from "../../../../store/tinybase/persisted"; +import { useTabs } from "../../../../store/zustand/tabs"; import { Disclosure } from "../shared"; import { ToolRenderer } from "../types"; @@ -36,14 +47,39 @@ const getTitle = (part: Part) => { return "Search"; }; -const RenderContent = ({ part }: { part: Part }) => { +function RenderContent({ part }: { part: Part }) { if (part.state === "output-available" && part.output && "results" in part.output) { const { results } = part.output; + if (!results || results.length === 0) { + return ( +
+ No results found +
+ ); + } + return ( -
-            {JSON.stringify(results, null, 2)}
-      
+
+ + + {results.map((result: any, index: number) => ( + + + + + + + + ))} + + + + +
); } @@ -52,4 +88,33 @@ const RenderContent = ({ part }: { part: Part }) => { } return null; -}; +} + +function RenderSession({ sessionId }: { sessionId: string }) { + const session = persisted.UI.useRow("sessions", sessionId, persisted.STORE_ID); + const { openNew } = useTabs(); + + const handleClick = useCallback(() => { + openNew({ + type: "sessions", + id: sessionId, + active: true, + state: { editor: "raw" }, + }); + }, [openNew, sessionId]); + + if (!session) { + return
Session unavailable
; + } + + return ( +
+ + {session.title || "Untitled"} + + + {session.enhanced_md ?? session.raw_md} + +
+ ); +} diff --git a/apps/desktop2/src/components/chat/session.tsx b/apps/desktop2/src/components/chat/session.tsx index 7ba345bd91..3498561b27 100644 --- a/apps/desktop2/src/components/chat/session.tsx +++ b/apps/desktop2/src/components/chat/session.tsx @@ -4,6 +4,7 @@ import { type ReactNode, useEffect, useMemo, useRef } from "react"; import { CustomChatTransport } from "../../chat/transport"; import type { HyprUIMessage } from "../../chat/types"; +import { useToolRegistry } from "../../contexts/tool"; import * as internal from "../../store/tinybase/internal"; import * as persisted from "../../store/tinybase/persisted"; import { id } from "../../utils"; @@ -26,7 +27,8 @@ export function ChatSession({ chatGroupId, children, }: ChatSessionProps) { - const transport = useMemo(() => new CustomChatTransport(), []); + const registry = useToolRegistry(); + const transport = useMemo(() => new CustomChatTransport(registry), [registry]); const store = persisted.UI.useStore(persisted.STORE_ID); const { user_id } = internal.UI.useValues(internal.STORE_ID); diff --git a/apps/desktop2/src/components/main/body/search.tsx b/apps/desktop2/src/components/main/body/search.tsx index 669d8770f0..68664f36f3 100644 --- a/apps/desktop2/src/components/main/body/search.tsx +++ b/apps/desktop2/src/components/main/body/search.tsx @@ -3,7 +3,7 @@ import { useRef } from "react"; import { useHotkeys } from "react-hotkeys-hook"; import { cn } from "@hypr/ui/lib/utils"; -import { useSearch } from "../../../contexts/search"; +import { useSearch } from "../../../contexts/search/ui"; export function Search() { const { query, setQuery, isSearching, isIndexing, onFocus, onBlur } = useSearch(); diff --git a/apps/desktop2/src/components/main/sidebar/index.tsx b/apps/desktop2/src/components/main/sidebar/index.tsx index e0afbc4b0c..0a26c8e4e7 100644 --- a/apps/desktop2/src/components/main/sidebar/index.tsx +++ b/apps/desktop2/src/components/main/sidebar/index.tsx @@ -1,7 +1,7 @@ import { clsx } from "clsx"; import { PanelLeftCloseIcon } from "lucide-react"; -import { useSearch } from "../../../contexts/search"; +import { useSearch } from "../../../contexts/search/ui"; import { useShell } from "../../../contexts/shell"; import { ProfileSection } from "./profile"; import { SearchResults } from "./search"; diff --git a/apps/desktop2/src/components/main/sidebar/search/group.tsx b/apps/desktop2/src/components/main/sidebar/search/group.tsx index 6b0c896889..5e48503162 100644 --- a/apps/desktop2/src/components/main/sidebar/search/group.tsx +++ b/apps/desktop2/src/components/main/sidebar/search/group.tsx @@ -2,7 +2,7 @@ import { ChevronDownIcon } from "lucide-react"; import { useState } from "react"; import { cn } from "@hypr/ui/lib/utils"; -import { type SearchGroup } from "../../../../contexts/search"; +import { type SearchGroup } from "../../../../contexts/search/ui"; import { SearchResultItem } from "./item"; const ITEMS_PER_PAGE = 3; diff --git a/apps/desktop2/src/components/main/sidebar/search/index.tsx b/apps/desktop2/src/components/main/sidebar/search/index.tsx index aeb168736a..5a8d5f2396 100644 --- a/apps/desktop2/src/components/main/sidebar/search/index.tsx +++ b/apps/desktop2/src/components/main/sidebar/search/index.tsx @@ -1,7 +1,7 @@ import { SearchXIcon } from "lucide-react"; import { cn } from "@hypr/ui/lib/utils"; -import { type GroupedSearchResults, useSearch } from "../../../../contexts/search"; +import { type GroupedSearchResults, useSearch } from "../../../../contexts/search/ui"; import { SearchResultGroup } from "./group"; export function SearchResults() { diff --git a/apps/desktop2/src/components/main/sidebar/search/item.tsx b/apps/desktop2/src/components/main/sidebar/search/item.tsx index 10b0193044..1fad37b594 100644 --- a/apps/desktop2/src/components/main/sidebar/search/item.tsx +++ b/apps/desktop2/src/components/main/sidebar/search/item.tsx @@ -2,7 +2,7 @@ import DOMPurify from "dompurify"; import { useCallback, useMemo } from "react"; import { cn } from "@hypr/ui/lib/utils"; -import { type SearchResult } from "../../../../contexts/search"; +import { type SearchResult } from "../../../../contexts/search/ui"; import * as persisted from "../../../../store/tinybase/persisted"; import { Tab, useTabs } from "../../../../store/zustand/tabs"; import { getInitials } from "../../body/contacts/shared"; @@ -33,18 +33,11 @@ export function SearchResultItem({ result }: { result: SearchResult }) { } function HumanSearchResultItem({ result, onClick }: { result: SearchResult; onClick: () => void }) { - const organization = persisted.UI.useRow("organizations", result.org_id, persisted.STORE_ID); - const sanitizedTitle = useMemo( () => DOMPurify.sanitize(result.titleHighlighted, { ALLOWED_TAGS: ["mark"], ALLOWED_ATTR: [] }), [result.titleHighlighted], ); - const jobTitle = result.metadata?.job_title as string | undefined; - const orgName = organization?.name; - - const subtitle = [jobTitle, orgName].filter(Boolean).join(", "); - return (
); diff --git a/apps/desktop2/src/contexts/search/engine/content.ts b/apps/desktop2/src/contexts/search/engine/content.ts new file mode 100644 index 0000000000..3a732715e5 --- /dev/null +++ b/apps/desktop2/src/contexts/search/engine/content.ts @@ -0,0 +1,13 @@ +import { flattenTranscript, mergeContent } from "./utils"; + +export function createSessionSearchableContent(row: Record): string { + return mergeContent([ + row.raw_md, + row.enhanced_md, + flattenTranscript(row.transcript), + ]); +} + +export function createHumanSearchableContent(row: Record): string { + return mergeContent([row.email, row.job_title, row.linkedin_username]); +} diff --git a/apps/desktop2/src/contexts/search/engine/filters.ts b/apps/desktop2/src/contexts/search/engine/filters.ts new file mode 100644 index 0000000000..b3a3d0a4e6 --- /dev/null +++ b/apps/desktop2/src/contexts/search/engine/filters.ts @@ -0,0 +1,29 @@ +import type { SearchFilters } from "./types"; + +export function buildOramaFilters(filters: SearchFilters | null): Record | undefined { + if (!filters || !filters.created_at) { + return undefined; + } + + const createdAtConditions: Record = {}; + + if (filters.created_at.gte !== undefined) { + createdAtConditions.gte = filters.created_at.gte; + } + if (filters.created_at.lte !== undefined) { + createdAtConditions.lte = filters.created_at.lte; + } + if (filters.created_at.gt !== undefined) { + createdAtConditions.gt = filters.created_at.gt; + } + if (filters.created_at.lt !== undefined) { + createdAtConditions.lt = filters.created_at.lt; + } + if (filters.created_at.eq !== undefined) { + createdAtConditions.eq = filters.created_at.eq; + } + + return Object.keys(createdAtConditions).length > 0 + ? { created_at: createdAtConditions } + : undefined; +} diff --git a/apps/desktop2/src/contexts/search/engine/index.tsx b/apps/desktop2/src/contexts/search/engine/index.tsx new file mode 100644 index 0000000000..130f8817a3 --- /dev/null +++ b/apps/desktop2/src/contexts/search/engine/index.tsx @@ -0,0 +1,131 @@ +import { createContext, useCallback, useContext, useEffect, useRef, useState } from "react"; + +import { create, search as oramaSearch } from "@orama/orama"; +import { pluginQPS } from "@orama/plugin-qps"; + +import { type Store as PersistedStore } from "../../../store/tinybase/persisted"; +import { buildOramaFilters } from "./filters"; +import { indexHumans, indexOrganizations, indexSessions } from "./indexing"; +import { createHumanListener, createOrganizationListener, createSessionListener } from "./listeners"; +import type { Index, SearchFilters, SearchHit } from "./types"; +import { SEARCH_SCHEMA } from "./types"; +import { normalizeQuery } from "./utils"; + +export type { SearchDocument, SearchEntityType, SearchFilters, SearchHit } from "./types"; + +const SearchEngineContext = createContext< + { + search: (query: string, filters?: SearchFilters | null) => Promise; + isIndexing: boolean; + } | null +>(null); + +export function SearchEngineProvider({ children, store }: { children: React.ReactNode; store?: PersistedStore }) { + const [isIndexing, setIsIndexing] = useState(true); + const oramaInstance = useRef(null); + const listenerIds = useRef([]); + + useEffect(() => { + if (!store) { + return; + } + + const initializeIndex = async () => { + setIsIndexing(true); + + try { + const db = create({ + schema: SEARCH_SCHEMA, + plugins: [pluginQPS()], + }); + + indexSessions(db, store); + indexHumans(db, store); + indexOrganizations(db, store); + + oramaInstance.current = db; + + const listener1 = store.addRowListener( + "sessions", + null, + createSessionListener(oramaInstance.current), + ); + const listener2 = store.addRowListener( + "humans", + null, + createHumanListener(oramaInstance.current), + ); + const listener3 = store.addRowListener( + "organizations", + null, + createOrganizationListener(oramaInstance.current), + ); + + listenerIds.current = [listener1, listener2, listener3]; + } catch (error) { + console.error("Failed to create search index:", error); + } finally { + setIsIndexing(false); + } + }; + + void initializeIndex(); + + return () => { + listenerIds.current.forEach((id) => { + store.delListener(id); + }); + listenerIds.current = []; + }; + }, [store]); + + const search = useCallback( + async (query: string, filters: SearchFilters | null = null): Promise => { + const normalizedQuery = normalizeQuery(query); + + if (normalizedQuery.length < 1) { + return []; + } + + if (!oramaInstance.current) { + return []; + } + + try { + const whereClause = buildOramaFilters(filters); + + const searchResults = await oramaSearch(oramaInstance.current, { + term: normalizedQuery, + boost: { + title: 3, + content: 1, + }, + limit: 100, + tolerance: 1, + ...(whereClause && { where: whereClause }), + }); + + return searchResults.hits as SearchHit[]; + } catch (error) { + console.error("Search failed:", error); + return []; + } + }, + [], + ); + + const value = { + search, + isIndexing, + }; + + return {children}; +} + +export function useSearchEngine() { + const context = useContext(SearchEngineContext); + if (!context) { + throw new Error("useSearchEngine must be used within SearchEngineProvider"); + } + return context; +} diff --git a/apps/desktop2/src/contexts/search/engine/indexing.ts b/apps/desktop2/src/contexts/search/engine/indexing.ts new file mode 100644 index 0000000000..d9e732600f --- /dev/null +++ b/apps/desktop2/src/contexts/search/engine/indexing.ts @@ -0,0 +1,74 @@ +import { insert } from "@orama/orama"; + +import { type Store as PersistedStore } from "../../../store/tinybase/persisted"; +import { createHumanSearchableContent, createSessionSearchableContent } from "./content"; +import type { Index } from "./types"; +import { collectCells, toNumber, toTrimmedString } from "./utils"; + +export function indexSessions(db: Index, store: PersistedStore): void { + const fields = [ + "user_id", + "created_at", + "folder_id", + "event_id", + "title", + "raw_md", + "enhanced_md", + "transcript", + ]; + + store.forEachRow("sessions", (rowId: string, _forEachCell) => { + const row = collectCells(store, "sessions", rowId, fields); + const title = toTrimmedString(row.title) || "Untitled"; + + void insert(db, { + id: rowId, + type: "session", + title, + content: createSessionSearchableContent(row), + created_at: toNumber(row.created_at), + }); + }); +} + +export function indexHumans(db: Index, store: PersistedStore): void { + const fields = [ + "name", + "email", + "org_id", + "job_title", + "linkedin_username", + "is_user", + "created_at", + ]; + + store.forEachRow("humans", (rowId: string, _forEachCell) => { + const row = collectCells(store, "humans", rowId, fields); + const title = toTrimmedString(row.name) || "Unknown"; + + void insert(db, { + id: rowId, + type: "human", + title, + content: createHumanSearchableContent(row), + created_at: toNumber(row.created_at), + }); + }); +} + +export function indexOrganizations(db: Index, store: PersistedStore): void { + const fields = ["name", "created_at"]; + + store.forEachRow("organizations", (rowId: string, _forEachCell) => { + const row = collectCells(store, "organizations", rowId, fields); + const title = toTrimmedString(row.name) || "Unknown Organization"; + + void insert(db, { + id: rowId, + type: "organization", + title, + content: "", + created_at: toNumber(row.created_at), + }); + }); +} diff --git a/apps/desktop2/src/contexts/search/engine/listeners.ts b/apps/desktop2/src/contexts/search/engine/listeners.ts new file mode 100644 index 0000000000..dc805470ea --- /dev/null +++ b/apps/desktop2/src/contexts/search/engine/listeners.ts @@ -0,0 +1,102 @@ +import { remove, type TypedDocument, update } from "@orama/orama"; +import { RowListener } from "tinybase/with-schemas"; + +import { Schemas } from "../../../store/tinybase/persisted"; +import { type Store as PersistedStore } from "../../../store/tinybase/persisted"; +import { createHumanSearchableContent, createSessionSearchableContent } from "./content"; +import type { Index } from "./types"; +import { collectCells, toNumber, toTrimmedString } from "./utils"; + +export function createSessionListener(index: Index): RowListener { + return (store, _, rowId) => { + try { + const rowExists = store.getRow("sessions", rowId); + + if (!rowExists) { + void remove(index, rowId); + } else { + const fields = [ + "user_id", + "created_at", + "title", + "raw_md", + "enhanced_md", + "transcript", + ]; + const row = collectCells(store, "sessions", rowId, fields); + const title = toTrimmedString(row.title) || "Untitled"; + + const data: TypedDocument = { + id: rowId, + type: "session", + title, + content: createSessionSearchableContent(row), + created_at: toNumber(row.created_at), + }; + + update(index, rowId, data); + } + } catch (error) { + console.error("Failed to update session in search index:", error); + } + }; +} + +export function createHumanListener(index: Index): RowListener { + return (store, _, rowId) => { + try { + const rowExists = store.getRow("humans", rowId); + + if (!rowExists) { + void remove(index, rowId); + } else { + const fields = [ + "name", + "email", + "created_at", + ]; + const row = collectCells(store, "humans", rowId, fields); + const title = toTrimmedString(row.name) || "Unknown"; + + const data: TypedDocument = { + id: rowId, + type: "human", + title, + content: createHumanSearchableContent(row), + created_at: toNumber(row.created_at), + }; + update(index, rowId, data); + } + } catch (error) { + console.error("Failed to update human in search index:", error); + } + }; +} + +export function createOrganizationListener(index: Index): RowListener { + return (store, _, rowId) => { + try { + const rowExists = store.getRow("organizations", rowId); + + if (!rowExists) { + remove(index, rowId); + } else { + const fields = ["name", "created_at"]; + const row = collectCells(store, "organizations", rowId, fields); + const title = toTrimmedString(row.name) || "Unknown Organization"; + + const data: TypedDocument = { + id: rowId, + type: "organization", + title, + content: "", + created_at: toNumber(row.created_at), + }; + + update(index, rowId, data); + } + } catch (error) { + console.error("Failed to update organization in search index:", error); + } + }; +} diff --git a/apps/desktop2/src/contexts/search/engine/types.ts b/apps/desktop2/src/contexts/search/engine/types.ts new file mode 100644 index 0000000000..ae021d0ba3 --- /dev/null +++ b/apps/desktop2/src/contexts/search/engine/types.ts @@ -0,0 +1,54 @@ +import { Orama } from "@orama/orama"; +import { z } from "zod"; + +const searchEntityTypeSchema = z.enum(["session", "human", "organization"]); +export type SearchEntityType = z.infer; + +export const searchDocumentSchema = z.object({ + id: z.string(), + type: searchEntityTypeSchema, + title: z.string(), + content: z.string(), + created_at: z.number(), +}); + +export type SearchDocument = z.infer; + +export const SEARCH_SCHEMA = { + id: "string", + type: "enum", + title: "string", + content: "string", + created_at: "number", +} as const satisfies InferOramaSchema; + +export type Index = Orama; + +const numberFilterSchema = z.object({ + gte: z.number().optional(), + lte: z.number().optional(), + gt: z.number().optional(), + lt: z.number().optional(), + eq: z.number().optional(), +}).optional(); + +export const searchFiltersSchema = z.object({ + created_at: numberFilterSchema, +}); + +export type SearchFilters = z.infer; + +export type SearchHit = { + score: number; + document: SearchDocument; +}; + +type InferOramaField = T extends z.ZodString ? "string" + : T extends z.ZodNumber ? "number" + : T extends z.ZodBoolean ? "boolean" + : T extends z.ZodEnum ? "enum" + : never; + +type InferOramaSchema> = { + [K in keyof T["shape"]]: InferOramaField; +}; diff --git a/apps/desktop2/src/contexts/search/engine/utils.ts b/apps/desktop2/src/contexts/search/engine/utils.ts new file mode 100644 index 0000000000..86b20d2a6a --- /dev/null +++ b/apps/desktop2/src/contexts/search/engine/utils.ts @@ -0,0 +1,113 @@ +const SPACE_REGEX = /\s+/g; + +export function safeParseJSON(value: unknown): unknown { + if (typeof value !== "string") { + return value; + } + + try { + return JSON.parse(value); + } catch { + return value; + } +} + +export function normalizeQuery(query: string): string { + return query.trim().replace(SPACE_REGEX, " "); +} + +export function toTrimmedString(value: unknown): string { + if (typeof value === "string") { + return value.trim(); + } + + return ""; +} + +export function toNumber(value: unknown): number { + if (typeof value === "number") { + return value; + } + if (typeof value === "string") { + const parsed = Number(value); + return isNaN(parsed) ? 0 : parsed; + } + return 0; +} + +export function toString(value: unknown): string { + if (typeof value === "string" && value.length > 0) { + return value; + } + return ""; +} + +export function toBoolean(value: unknown): boolean { + if (typeof value === "boolean") { + return value; + } + return false; +} + +export function mergeContent(parts: unknown[]): string { + return parts + .map(toTrimmedString) + .filter(Boolean) + .join(" "); +} + +export function flattenTranscript(transcript: unknown): string { + if (transcript == null) { + return ""; + } + + const parsed = safeParseJSON(transcript); + + if (typeof parsed === "string") { + return parsed; + } + + if (Array.isArray(parsed)) { + return mergeContent( + parsed.map((segment) => { + if (!segment) { + return ""; + } + + if (typeof segment === "string") { + return segment; + } + + if (typeof segment === "object") { + const record = segment as Record; + const preferred = record.text ?? record.content; + if (typeof preferred === "string") { + return preferred; + } + + return flattenTranscript(Object.values(record)); + } + + return ""; + }), + ); + } + + if (typeof parsed === "object" && parsed !== null) { + return mergeContent(Object.values(parsed).map((value) => flattenTranscript(value))); + } + + return ""; +} + +export function collectCells( + persistedStore: any, + table: string, + rowId: string, + fields: string[], +): Record { + return fields.reduce>((acc, field) => { + acc[field] = persistedStore.getCell(table, rowId, field); + return acc; + }, {}); +} diff --git a/apps/desktop2/src/contexts/search/index.tsx b/apps/desktop2/src/contexts/search/index.tsx deleted file mode 100644 index 042ffa73b0..0000000000 --- a/apps/desktop2/src/contexts/search/index.tsx +++ /dev/null @@ -1,587 +0,0 @@ -import { useRouteContext } from "@tanstack/react-router"; -import { createContext, useCallback, useContext, useEffect, useMemo, useRef, useState } from "react"; - -import { Highlight } from "@orama/highlight"; -import { create, insert, Orama, search as oramaSearch } from "@orama/orama"; -import { pluginQPS } from "@orama/plugin-qps"; - -export type SearchEntityType = "session" | "human" | "organization"; - -export interface SearchFilters { - created_at?: { - gte?: number; - lte?: number; - gt?: number; - lt?: number; - eq?: number; - }; -} - -export interface SearchResult { - id: string; - type: SearchEntityType; - title: string; - titleHighlighted: string; - content: string; - contentHighlighted: string; - created_at: number; - folder_id: string; - event_id: string; - org_id: string; - is_user: boolean; - metadata: Record; - score: number; -} - -export interface SearchGroup { - key: string; - type: SearchEntityType; - title: string; - results: SearchResult[]; - totalCount: number; - topScore: number; -} - -export interface GroupedSearchResults { - groups: SearchGroup[]; - totalResults: number; - maxScore: number; -} - -interface SearchContextValue { - query: string; - setQuery: (query: string) => void; - filters: SearchFilters | null; - setFilters: (filters: SearchFilters | null) => void; - results: GroupedSearchResults | null; - isSearching: boolean; - isFocused: boolean; - isIndexing: boolean; - onFocus: () => void; - onBlur: () => void; -} - -interface SearchDocument { - id: string; - type: SearchEntityType; - title: string; - content: string; - created_at: number; - folder_id: string; - event_id: string; - org_id: string; - is_user: boolean; - metadata: string; -} - -interface SearchHit { - score: number; - document: SearchDocument; -} - -type SerializableObject = Record; - -const SCORE_PERCENTILE_THRESHOLD = 0.1; -const SPACE_REGEX = /\s+/g; - -const GROUP_TITLES: Record = { - session: "Sessions", - human: "People", - organization: "Organizations", -}; - -function safeParseJSON(value: unknown): unknown { - if (typeof value !== "string") { - return value; - } - - try { - return JSON.parse(value); - } catch { - return value; - } -} - -function normalizeQuery(query: string): string { - return query.trim().replace(SPACE_REGEX, " "); -} - -function toTrimmedString(value: unknown): string { - if (typeof value === "string") { - return value.trim(); - } - - return ""; -} - -function mergeContent(parts: unknown[]): string { - return parts - .map(toTrimmedString) - .filter(Boolean) - .join(" "); -} - -function parseMetadata(metadata: unknown): SerializableObject { - if (typeof metadata !== "string" || metadata.length === 0) { - return {}; - } - - const parsed = safeParseJSON(metadata); - if (typeof parsed === "object" && parsed !== null) { - return parsed as SerializableObject; - } - - return {}; -} - -function flattenTranscript(transcript: unknown): string { - if (transcript == null) { - return ""; - } - - const parsed = safeParseJSON(transcript); - - if (typeof parsed === "string") { - return parsed; - } - - if (Array.isArray(parsed)) { - return mergeContent( - parsed.map((segment) => { - if (!segment) { - return ""; - } - - if (typeof segment === "string") { - return segment; - } - - if (typeof segment === "object") { - const record = segment as Record; - const preferred = record.text ?? record.content; - if (typeof preferred === "string") { - return preferred; - } - - return flattenTranscript(Object.values(record)); - } - - return ""; - }), - ); - } - - if (typeof parsed === "object" && parsed !== null) { - return mergeContent(Object.values(parsed).map((value) => flattenTranscript(value))); - } - - return ""; -} - -function collectCells( - persistedStore: any, - table: string, - rowId: string, - fields: string[], -): Record { - return fields.reduce>((acc, field) => { - acc[field] = persistedStore.getCell(table, rowId, field); - return acc; - }, {}); -} - -function createSessionSearchableContent(row: Record): string { - return mergeContent([ - row.raw_md, - row.enhanced_md, - flattenTranscript(row.transcript), - ]); -} - -function createHumanSearchableContent(row: Record): string { - return mergeContent([row.email, row.job_title, row.linkedin_username]); -} - -function toNumber(value: unknown): number { - if (typeof value === "number") { - return value; - } - if (typeof value === "string") { - const parsed = Number(value); - return isNaN(parsed) ? 0 : parsed; - } - return 0; -} - -function toString(value: unknown): string { - if (typeof value === "string" && value.length > 0) { - return value; - } - return ""; -} - -function toBoolean(value: unknown): boolean { - if (typeof value === "boolean") { - return value; - } - return false; -} - -function indexSessions(db: Orama, persistedStore: any): void { - const fields = [ - "user_id", - "created_at", - "folder_id", - "event_id", - "title", - "raw_md", - "enhanced_md", - "transcript", - ]; - - persistedStore.forEachRow("sessions", (rowId: string) => { - const row = collectCells(persistedStore, "sessions", rowId, fields); - const title = toTrimmedString(row.title) || "Untitled"; - - void insert(db, { - id: rowId, - type: "session", - title, - content: createSessionSearchableContent(row), - created_at: toNumber(row.created_at), - folder_id: toString(row.folder_id), - event_id: toString(row.event_id), - org_id: "", - is_user: false, - metadata: JSON.stringify({}), - }); - }); -} - -function indexHumans(db: Orama, persistedStore: any): void { - const fields = [ - "name", - "email", - "org_id", - "job_title", - "linkedin_username", - "is_user", - "created_at", - ]; - - persistedStore.forEachRow("humans", (rowId: string) => { - const row = collectCells(persistedStore, "humans", rowId, fields); - const title = toTrimmedString(row.name) || "Unknown"; - - void insert(db, { - id: rowId, - type: "human", - title, - content: createHumanSearchableContent(row), - created_at: toNumber(row.created_at), - folder_id: "", - event_id: "", - org_id: toString(row.org_id), - is_user: toBoolean(row.is_user), - metadata: JSON.stringify({ - email: row.email, - job_title: row.job_title, - }), - }); - }); -} - -function indexOrganizations(db: Orama, persistedStore: any): void { - const fields = ["name", "created_at"]; - - persistedStore.forEachRow("organizations", (rowId: string) => { - const row = collectCells(persistedStore, "organizations", rowId, fields); - const title = toTrimmedString(row.name) || "Unknown Organization"; - - void insert(db, { - id: rowId, - type: "organization", - title, - content: "", - created_at: toNumber(row.created_at), - folder_id: "", - event_id: "", - org_id: "", - is_user: false, - metadata: JSON.stringify({}), - }); - }); -} - -function buildOramaFilters(filters: SearchFilters | null): Record | undefined { - if (!filters || !filters.created_at) { - return undefined; - } - - const createdAtConditions: Record = {}; - - if (filters.created_at.gte !== undefined) { - createdAtConditions.gte = filters.created_at.gte; - } - if (filters.created_at.lte !== undefined) { - createdAtConditions.lte = filters.created_at.lte; - } - if (filters.created_at.gt !== undefined) { - createdAtConditions.gt = filters.created_at.gt; - } - if (filters.created_at.lt !== undefined) { - createdAtConditions.lt = filters.created_at.lt; - } - if (filters.created_at.eq !== undefined) { - createdAtConditions.eq = filters.created_at.eq; - } - - return Object.keys(createdAtConditions).length > 0 - ? { created_at: createdAtConditions } - : undefined; -} - -function calculateDynamicThreshold(scores: number[]): number { - if (scores.length === 0) { - return 0; - } - - const sortedScores = [...scores].sort((a, b) => b - a); - const thresholdIndex = Math.floor(sortedScores.length * SCORE_PERCENTILE_THRESHOLD); - - return sortedScores[Math.min(thresholdIndex, sortedScores.length - 1)] || 0; -} - -function createSearchResult(hit: SearchHit, query: string): SearchResult { - const highlighter = new Highlight(); - const titleHighlighted = highlighter.highlight(hit.document.title, query); - const contentHighlighted = highlighter.highlight(hit.document.content, query); - - return { - id: hit.document.id, - type: hit.document.type, - title: hit.document.title, - titleHighlighted: titleHighlighted.HTML, - content: hit.document.content, - contentHighlighted: contentHighlighted.HTML, - created_at: hit.document.created_at, - folder_id: hit.document.folder_id, - event_id: hit.document.event_id, - org_id: hit.document.org_id, - is_user: hit.document.is_user, - metadata: parseMetadata(hit.document.metadata), - score: hit.score, - }; -} - -function sortResultsByScore(a: SearchResult, b: SearchResult): number { - return b.score - a.score; -} - -function toGroup( - type: SearchEntityType, - results: SearchResult[], -): SearchGroup { - const topScore = results[0]?.score || 0; - - return { - key: type, - type, - title: GROUP_TITLES[type], - results, - totalCount: results.length, - topScore, - }; -} - -function groupSearchResults( - hits: SearchHit[], - query: string, -): GroupedSearchResults { - if (hits.length === 0) { - return { - groups: [], - totalResults: 0, - maxScore: 0, - }; - } - - const scores = hits.map((hit) => hit.score); - const maxScore = Math.max(...scores); - const threshold = calculateDynamicThreshold(scores); - - const grouped = hits.reduce>((acc, hit) => { - if (hit.score < threshold) { - return acc; - } - - const key = hit.document.type; - const list = acc.get(key) ?? []; - list.push(createSearchResult(hit, query)); - acc.set(key, list); - return acc; - }, new Map()); - - const groups = Array.from(grouped.entries()) - .map(([type, results]) => toGroup(type, results.sort(sortResultsByScore))) - .sort((a, b) => b.topScore - a.topScore); - - const totalResults = groups.reduce((count, group) => count + group.totalCount, 0); - - return { - groups, - totalResults, - maxScore, - }; -} - -const SearchContext = createContext(null); - -export function SearchProvider({ children }: { children: React.ReactNode }) { - const { persistedStore } = useRouteContext({ from: "__root__" }); - - const [query, setQuery] = useState(""); - const [filters, setFilters] = useState(null); - const [isSearching, setIsSearching] = useState(false); - const [isFocused, setIsFocused] = useState(false); - const [isIndexing, setIsIndexing] = useState(false); - const [searchHits, setSearchHits] = useState([]); - const [searchQuery, setSearchQuery] = useState(""); - - const oramaInstance = useRef | null>(null); - - const resetSearchState = useCallback(() => { - setSearchHits([]); - setSearchQuery(""); - }, []); - - const createIndex = useCallback(async () => { - if (!persistedStore || isIndexing) { - return; - } - - setIsIndexing(true); - - try { - const db = create({ - schema: { - id: "string", - type: "enum", - title: "string", - content: "string", - created_at: "number", - folder_id: "string", - event_id: "string", - org_id: "string", - is_user: "boolean", - metadata: "string", - } as const, - plugins: [pluginQPS()], - }); - - indexSessions(db, persistedStore); - indexHumans(db, persistedStore); - indexOrganizations(db, persistedStore); - - oramaInstance.current = db; - } catch (error) { - console.error("Failed to create search index:", error); - } finally { - setIsIndexing(false); - } - }, [persistedStore, isIndexing]); - - const performSearch = useCallback( - async (searchQueryInput: string, searchFilters: SearchFilters | null) => { - const normalizedQuery = normalizeQuery(searchQueryInput); - - if (!oramaInstance.current || normalizedQuery.length < 1) { - resetSearchState(); - setIsSearching(false); - return; - } - - setIsSearching(true); - - try { - const whereClause = buildOramaFilters(searchFilters); - - const searchResults = await oramaSearch(oramaInstance.current, { - term: normalizedQuery, - boost: { - title: 3, - content: 1, - }, - limit: 100, - tolerance: 1, - ...(whereClause && { where: whereClause }), - }); - - const hits = searchResults.hits as unknown as SearchHit[]; - setSearchHits(hits); - setSearchQuery(normalizedQuery); - } catch (error) { - console.error("Search failed:", error); - resetSearchState(); - } finally { - setIsSearching(false); - } - }, - [resetSearchState], - ); - - useEffect(() => { - const normalizedQuery = normalizeQuery(query); - - if (normalizedQuery.length < 1) { - resetSearchState(); - setIsSearching(false); - } else { - void performSearch(normalizedQuery, filters); - } - }, [query, filters, performSearch, resetSearchState]); - - const onFocus = useCallback(() => { - setIsFocused(true); - if (!oramaInstance.current) { - void createIndex(); - } - }, [createIndex]); - - const onBlur = useCallback(() => { - setIsFocused(false); - }, []); - - const results = useMemo(() => { - if (searchHits.length === 0 || !searchQuery) { - return null; - } - return groupSearchResults(searchHits, searchQuery); - }, [searchHits, searchQuery]); - - const value = useMemo( - () => ({ - query, - setQuery, - filters, - setFilters, - results, - isSearching, - isFocused, - isIndexing, - onFocus, - onBlur, - }), - [query, filters, results, isSearching, isFocused, isIndexing, onFocus, onBlur], - ); - - return {children}; -} - -export function useSearch() { - const context = useContext(SearchContext); - if (!context) { - throw new Error("useSearch must be used within SearchProvider"); - } - return context; -} diff --git a/apps/desktop2/src/contexts/search/ui.tsx b/apps/desktop2/src/contexts/search/ui.tsx new file mode 100644 index 0000000000..372c6ce185 --- /dev/null +++ b/apps/desktop2/src/contexts/search/ui.tsx @@ -0,0 +1,230 @@ +import { Highlight } from "@orama/highlight"; +import { createContext, useCallback, useContext, useEffect, useMemo, useState } from "react"; + +import type { SearchDocument, SearchEntityType, SearchFilters, SearchHit } from "./engine"; +import { useSearchEngine } from "./engine"; + +export type { SearchEntityType, SearchFilters } from "./engine"; + +export type SearchResult = SearchDocument & { + titleHighlighted: string; + contentHighlighted: string; + score: number; +}; + +export interface SearchGroup { + key: string; + type: SearchEntityType; + title: string; + results: SearchResult[]; + totalCount: number; + topScore: number; +} + +export interface GroupedSearchResults { + groups: SearchGroup[]; + totalResults: number; + maxScore: number; +} + +interface SearchUIContextValue { + query: string; + setQuery: (query: string) => void; + filters: SearchFilters | null; + setFilters: (filters: SearchFilters | null) => void; + results: GroupedSearchResults | null; + isSearching: boolean; + isFocused: boolean; + isIndexing: boolean; + onFocus: () => void; + onBlur: () => void; +} + +const SCORE_PERCENTILE_THRESHOLD = 0.1; + +const GROUP_TITLES: Record = { + session: "Sessions", + human: "People", + organization: "Organizations", +}; + +function calculateDynamicThreshold(scores: number[]): number { + if (scores.length === 0) { + return 0; + } + + const sortedScores = [...scores].sort((a, b) => b - a); + const thresholdIndex = Math.floor(sortedScores.length * SCORE_PERCENTILE_THRESHOLD); + + return sortedScores[Math.min(thresholdIndex, sortedScores.length - 1)] || 0; +} + +function createSearchResult(hit: SearchHit, query: string): SearchResult { + const highlighter = new Highlight(); + const titleHighlighted = highlighter.highlight(hit.document.title, query); + const contentHighlighted = highlighter.highlight(hit.document.content, query); + + return { + id: hit.document.id, + type: hit.document.type, + title: hit.document.title, + titleHighlighted: titleHighlighted.HTML, + content: hit.document.content, + contentHighlighted: contentHighlighted.HTML, + created_at: hit.document.created_at, + score: hit.score, + }; +} + +function sortResultsByScore(a: SearchResult, b: SearchResult): number { + return b.score - a.score; +} + +function toGroup( + type: SearchEntityType, + results: SearchResult[], +): SearchGroup { + const topScore = results[0]?.score || 0; + + return { + key: type, + type, + title: GROUP_TITLES[type], + results, + totalCount: results.length, + topScore, + }; +} + +function groupSearchResults( + hits: SearchHit[], + query: string, +): GroupedSearchResults { + if (hits.length === 0) { + return { + groups: [], + totalResults: 0, + maxScore: 0, + }; + } + + const scores = hits.map((hit) => hit.score); + const maxScore = Math.max(...scores); + const threshold = calculateDynamicThreshold(scores); + + const grouped = hits.reduce>((acc, hit) => { + if (hit.score < threshold) { + return acc; + } + + const key = hit.document.type; + const list = acc.get(key) ?? []; + list.push(createSearchResult(hit, query)); + acc.set(key, list); + return acc; + }, new Map()); + + const groups = Array.from(grouped.entries()) + .map(([type, results]) => toGroup(type, results.sort(sortResultsByScore))) + .sort((a, b) => b.topScore - a.topScore); + + const totalResults = groups.reduce((count, group) => count + group.totalCount, 0); + + return { + groups, + totalResults, + maxScore, + }; +} + +const SearchUIContext = createContext(null); + +export function SearchUIProvider({ children }: { children: React.ReactNode }) { + const { search, isIndexing } = useSearchEngine(); + + const [query, setQuery] = useState(""); + const [filters, setFilters] = useState(null); + const [isSearching, setIsSearching] = useState(false); + const [isFocused, setIsFocused] = useState(false); + const [searchHits, setSearchHits] = useState([]); + const [searchQuery, setSearchQuery] = useState(""); + + const resetSearchState = useCallback(() => { + setSearchHits([]); + setSearchQuery(""); + }, []); + + const performSearch = useCallback( + async (searchQueryInput: string, searchFilters: SearchFilters | null) => { + if (searchQueryInput.trim().length < 1) { + resetSearchState(); + setIsSearching(false); + return; + } + + setIsSearching(true); + + try { + const hits = await search(searchQueryInput, searchFilters); + setSearchHits(hits); + setSearchQuery(searchQueryInput.trim()); + } catch (error) { + console.error("Search failed:", error); + resetSearchState(); + } finally { + setIsSearching(false); + } + }, + [search, resetSearchState], + ); + + useEffect(() => { + if (query.trim().length < 1) { + resetSearchState(); + setIsSearching(false); + } else { + void performSearch(query, filters); + } + }, [query, filters, performSearch, resetSearchState]); + + const onFocus = useCallback(() => { + setIsFocused(true); + }, []); + + const onBlur = useCallback(() => { + setIsFocused(false); + }, []); + + const results = useMemo(() => { + if (searchHits.length === 0 || !searchQuery) { + return null; + } + return groupSearchResults(searchHits, searchQuery); + }, [searchHits, searchQuery]); + + const value = useMemo( + () => ({ + query, + setQuery, + filters, + setFilters, + results, + isSearching, + isFocused, + isIndexing, + onFocus, + onBlur, + }), + [query, filters, results, isSearching, isFocused, isIndexing, onFocus, onBlur], + ); + + return {children}; +} + +export function useSearch() { + const context = useContext(SearchUIContext); + if (!context) { + throw new Error("useSearch must be used within SearchUIProvider"); + } + return context; +} diff --git a/apps/desktop2/src/contexts/tool.tsx b/apps/desktop2/src/contexts/tool.tsx new file mode 100644 index 0000000000..876920b9ca --- /dev/null +++ b/apps/desktop2/src/contexts/tool.tsx @@ -0,0 +1,75 @@ +import { createContext, useCallback, useContext, useMemo, useRef } from "react"; + +export interface ToolRegistry { + getForTransport(): Record; + invoke(key: string, input: any): Promise; + getTool(key: string): any; + getAllTools(): Array<{ key: string; tool: any }>; + register(key: string, tool: any): void; +} + +interface ToolRegistryContextValue { + registry: ToolRegistry; +} + +const ToolRegistryContext = createContext(null); + +export function ToolRegistryProvider({ children }: { children: React.ReactNode }) { + const toolsRef = useRef>(new Map()); + + const register = useCallback((key: string, tool: any) => { + toolsRef.current.set(key, tool); + }, []); + + const getForTransport = useCallback((): Record => { + const tools: Record = {}; + toolsRef.current.forEach((tool, key) => { + tools[key] = tool; + }); + return tools; + }, []); + + const invoke = useCallback(async (key: string, input: any): Promise => { + const tool = toolsRef.current.get(key); + if (!tool) { + throw new Error(`Tool "${key}" not found in registry`); + } + + if (!tool.execute) { + throw new Error(`Tool "${key}" does not have an execute function`); + } + + return await tool.execute(input); + }, []); + + const getTool = useCallback((key: string): any => { + return toolsRef.current.get(key); + }, []); + + const getAllTools = useCallback((): Array<{ key: string; tool: any }> => { + return Array.from(toolsRef.current.entries()).map(([key, tool]) => ({ key, tool })); + }, []); + + const registry: ToolRegistry = useMemo( + () => ({ + getForTransport, + invoke, + getTool, + getAllTools, + register, + }), + [getForTransport, invoke, getTool, getAllTools, register], + ); + + const value = useMemo(() => ({ registry }), [registry]); + + return {children}; +} + +export function useToolRegistry(): ToolRegistry { + const context = useContext(ToolRegistryContext); + if (!context) { + throw new Error("useToolRegistry must be used within ToolRegistryProvider"); + } + return context.registry; +} diff --git a/apps/desktop2/src/routes/app/main/_layout.tsx b/apps/desktop2/src/routes/app/main/_layout.tsx index 5f23ad6c67..e2e7d9a0d7 100644 --- a/apps/desktop2/src/routes/app/main/_layout.tsx +++ b/apps/desktop2/src/routes/app/main/_layout.tsx @@ -1,8 +1,13 @@ import { createFileRoute, Outlet, useRouteContext } from "@tanstack/react-router"; import { useEffect } from "react"; -import { SearchProvider } from "../../../contexts/search"; +import { toolFactories } from "../../../chat/tools"; +import { useSearchEngine } from "../../../contexts/search/engine"; +import { SearchEngineProvider } from "../../../contexts/search/engine"; +import { SearchUIProvider } from "../../../contexts/search/ui"; import { ShellProvider } from "../../../contexts/shell"; +import { useToolRegistry } from "../../../contexts/tool"; +import { ToolRegistryProvider } from "../../../contexts/tool"; import { useTabs } from "../../../store/zustand/tabs"; import { id } from "../../../utils"; @@ -11,16 +16,38 @@ export const Route = createFileRoute("/app/main/_layout")({ }); function Component() { + const { persistedStore } = useRouteContext({ from: "__root__" }); + return ( - - - - + + + + + + + + + ); } +function ToolRegistration() { + const registry = useToolRegistry(); + const { search } = useSearchEngine(); + + useEffect(() => { + const deps = { search }; + + Object.entries(toolFactories).forEach(([key, factory]) => { + registry.register(key, factory(deps)); + }); + }, [registry, search]); + + return null; +} + // TOOD function NotSureAboutThis() { const { persistedStore, internalStore } = useRouteContext({ from: "__root__" });