diff --git a/apps/desktop/src/utils/tag-generation.ts b/apps/desktop/src/utils/tag-generation.ts index 4d136f6342..ed02b93cf0 100644 --- a/apps/desktop/src/utils/tag-generation.ts +++ b/apps/desktop/src/utils/tag-generation.ts @@ -6,63 +6,68 @@ import { commands as templateCommands } from "@hypr/plugin-template"; import { generateText, localProviderName, modelProvider } from "@hypr/utils/ai"; export async function generateTagsForSession(sessionId: string): Promise { - const { type: connectionType } = await connectorCommands.getLlmConnection(); + try { + const { type: connectionType } = await connectorCommands.getLlmConnection(); - const config = await dbCommands.getConfig(); - const session = await dbCommands.getSession({ id: sessionId }); - if (!session) { - throw new Error("Session not found"); - } + const config = await dbCommands.getConfig(); + const session = await dbCommands.getSession({ id: sessionId }); + if (!session) { + throw new Error("Session not found"); + } - const historicalTags = await dbCommands.listAllTags(); - const currentTags = await dbCommands.listSessionTags(sessionId); + const historicalTags = await dbCommands.listAllTags(); + const currentTags = await dbCommands.listSessionTags(sessionId); - const extractHashtags = (text: string): string[] => { - const hashtagRegex = /#(\w+)/g; - return Array.from(text.matchAll(hashtagRegex), match => match[1]); - }; + const extractHashtags = (text: string): string[] => { + const hashtagRegex = /#(\w+)/g; + return Array.from(text.matchAll(hashtagRegex), match => match[1]); + }; - const existingHashtags = extractHashtags(session.raw_memo_html); + const existingHashtags = extractHashtags(session.raw_memo_html); - const systemPrompt = await templateCommands.render( - "suggest_tags.system", - { config, type: connectionType }, - ); + const systemPrompt = await templateCommands.render( + "suggest_tags.system", + { config, type: connectionType }, + ); - const userPrompt = await templateCommands.render( - "suggest_tags.user", - { - title: session.title, - content: session.raw_memo_html, - existing_hashtags: existingHashtags, - formal_tags: currentTags.map(t => t.name), - historical_tags: historicalTags.slice(0, 20).map(t => t.name), - }, - ); + const userPrompt = await templateCommands.render( + "suggest_tags.user", + { + title: session.title, + content: session.raw_memo_html, + existing_hashtags: existingHashtags, + formal_tags: currentTags.map(t => t.name), + historical_tags: historicalTags.slice(0, 20).map(t => t.name), + }, + ); - const provider = await modelProvider(); - const model = provider.languageModel("defaultModel"); + const provider = await modelProvider(); + const model = provider.languageModel("defaultModel"); - const result = await generateText({ - model, - messages: [ - { role: "system", content: systemPrompt }, - { role: "user", content: userPrompt }, - ], - providerOptions: { - [localProviderName]: { - metadata: { - grammar: "tags", + const result = await generateText({ + model, + messages: [ + { role: "system", content: systemPrompt }, + { role: "user", content: userPrompt }, + ], + providerOptions: { + [localProviderName]: { + metadata: { + grammar: "tags", + }, }, }, - }, - }); + }); - const schema = z.preprocess( - (val) => (typeof val === "string" ? JSON.parse(val) : val), - z.array(z.string().min(1)).min(1).max(5), - ); + const schema = z.preprocess( + (val) => (typeof val === "string" ? JSON.parse(val) : val), + z.array(z.string().min(1)).min(1).max(5), + ); - const parsed = schema.safeParse(result.text); - return parsed.success ? parsed.data : []; + const parsed = schema.safeParse(result.text); + return parsed.success ? parsed.data : []; + } catch (error) { + console.error("Tag generation failed:", error); + return []; + } } diff --git a/crates/gbnf/assets/tags.gbnf b/crates/gbnf/assets/tags.gbnf index 6638c24bdb..b17423cc1c 100644 --- a/crates/gbnf/assets/tags.gbnf +++ b/crates/gbnf/assets/tags.gbnf @@ -1,3 +1,3 @@ -root ::= "[" "'" word "'" ("," ws "'" word "'")* "]" +root ::= "[" "\"" word "\"" ("," ws "\"" word "\"")* "]" word ::= [a-zA-Z0-9_-]+ ws ::= " "* \ No newline at end of file diff --git a/crates/gbnf/src/lib.rs b/crates/gbnf/src/lib.rs index d93e80ceb7..a9f8a3c622 100644 --- a/crates/gbnf/src/lib.rs +++ b/crates/gbnf/src/lib.rs @@ -47,8 +47,8 @@ mod tests { let gbnf = gbnf_validator::Validator::new().unwrap(); for (input, expected) in vec![ - ("['meeting', 'summary']", true), - ("['meeting', 'summary', '']", false), + ("[\"meeting\", \"summary\"]", true), + ("[\"meeting\", \"summary\", \"\"]", false), ] { let result = gbnf.validate(TAGS, input).unwrap(); assert_eq!(result, expected, "failed: {}", input);