From 60b8aafe1d85315a78904de856ffdaf0a48e4329 Mon Sep 17 00:00:00 2001 From: Olasunkanmi Oyinlola Date: Sun, 9 Feb 2025 11:28:42 +0800 Subject: [PATCH 1/2] feat(gemini): Integrate tool selection and execution This commit introduces the following changes: - Updates the IToolConfig interface to align with expected tool behavior. - Adds database connection functionality to extension activation. - Implements tool retrieval for Gemini and enables tool execution during content generation. - Initializes the CodeBuddyToolProvider and ContextRetriever during Gemini LLM instantiation. - Modifies tool base class to support configuration. - Introduces a vector database search tool and factory. - Adds a tool provider and factory to register tools. - Adds database search tool and functionality. - Adds tool execution capability to the Gemini model. --- src/application/interfaces/agent.interface.ts | 5 +-- src/extension.ts | 2 +- src/llms/gemini/gemini.ts | 39 ++++++++-------- src/providers/tool.ts | 25 ++++++----- src/services/context-retriever.ts | 8 ++++ src/tools/base.ts | 6 +-- src/tools/database/search-tool.ts | 26 +++++------ src/utils/prompt.ts | 44 +------------------ 8 files changed, 57 insertions(+), 98 deletions(-) diff --git a/src/application/interfaces/agent.interface.ts b/src/application/interfaces/agent.interface.ts index 2b417b6..f7b19dc 100644 --- a/src/application/interfaces/agent.interface.ts +++ b/src/application/interfaces/agent.interface.ts @@ -28,8 +28,5 @@ export interface ICodeBuddyToolConfig { } export interface IToolConfig extends ICodeBuddyToolConfig { - createInstance: ( - config: ICodeBuddyToolConfig, - retriever?: any, - ) => CodeBuddyTool; + createInstance: (config: ICodeBuddyToolConfig, retriever?: any) => any; } diff --git a/src/extension.ts b/src/extension.ts index 92ab3db..d74964f 100644 --- a/src/extension.ts +++ b/src/extension.ts @@ -52,7 +52,7 @@ let agentEventEmmitter: EventEmitter; export async function activate(context: vscode.ExtensionContext) { try { Memory.getInstance(); - // await connectDB(); + await connectDB(); // const x = CodeRepository.getInstance(); // const apiKey = getGeminiAPIKey(); // const embeddingService = new EmbeddingService(apiKey); diff --git a/src/llms/gemini/gemini.ts b/src/llms/gemini/gemini.ts index 17d0013..c553f3b 100644 --- a/src/llms/gemini/gemini.ts +++ b/src/llms/gemini/gemini.ts @@ -5,17 +5,16 @@ import { GenerateContentResult, GenerativeModel, GoogleGenerativeAI, - Tool, } from "@google/generative-ai"; import * as vscode from "vscode"; import { Orchestrator } from "../../agents/orchestrator"; -import { ProcessInputResult } from "../../application/interfaces/agent.interface"; +import { COMMON } from "../../application/constant"; +import { Memory } from "../../memory/base"; +import { CodeBuddyToolProvider } from "../../providers/tool"; import { createPrompt } from "../../utils/prompt"; import { BaseLLM } from "../base"; import { GeminiModelResponseType, ILlmConfig } from "../interface"; import { IMessageInput, Message } from "../message"; -import { Memory } from "../../memory/base"; -import { COMMON } from "../../application/constant"; export class GeminiLLM extends BaseLLM @@ -33,6 +32,7 @@ export class GeminiLLM this.generativeAi = new GoogleGenerativeAI(this.config.apiKey); this.response = undefined; this.orchestrator = Orchestrator.getInstance(); + CodeBuddyToolProvider.initialize(); } static getInstance(config: ILlmConfig) { @@ -85,15 +85,13 @@ export class GeminiLLM } } - async generateContent( - userInput: string, - ): Promise> { + async generateContent(userInput: string): Promise { try { await this.buildChatHistory(userInput); const prompt = createPrompt(userInput); const contents = Memory.get(COMMON.GEMINI_CHAT_HISTORY); - const tools: Tool[] = []; - const model = this.getModel({ tools, systemInstruction: prompt }); + const tools: any = this.getTools(); + const model = this.getModel({ systemInstruction: prompt, tools }); const generateContentResponse: GenerateContentResult = await model.generateContent({ contents, @@ -103,18 +101,12 @@ export class GeminiLLM }); this.response = generateContentResponse; const { text, usageMetadata } = generateContentResponse.response; - const parsedResponse = this.orchestrator.parseResponse(text()); - const extractedQueries = parsedResponse.queries; - const extractedThought = parsedResponse.thought; const tokenCount = usageMetadata?.totalTokenCount ?? 0; - const result = { - queries: extractedQueries, - tokens: tokenCount, - prompt: userInput, - thought: extractedThought, - }; - this.orchestrator.publish("onQuery", JSON.stringify(result)); - return result; + this.orchestrator.publish( + "onQuery", + JSON.stringify("making function call"), + ); + return this.response; } catch (error: any) { this.orchestrator.publish("onError", error); vscode.window.showErrorMessage("Error processing user query"); @@ -169,6 +161,13 @@ export class GeminiLLM Memory.set(COMMON.GEMINI_CHAT_HISTORY, chatHistory); } + getTools() { + const tools = CodeBuddyToolProvider.getTools(); + return { + functionDeclarations: tools.map((t) => t.config()), + }; + } + public createSnapShot(data?: any): GeminiModelResponseType { return { ...this.response, ...data }; } diff --git a/src/providers/tool.ts b/src/providers/tool.ts index 8a68483..a4cfca3 100644 --- a/src/providers/tool.ts +++ b/src/providers/tool.ts @@ -3,23 +3,24 @@ import { ContextRetriever } from "../services/context-retriever"; import { CodeBuddyTool } from "../tools/base"; import { TOOL_CONFIGS } from "../tools/database/search-tool"; -type Retriever = Pick; - export class ToolFactory { private readonly tools: Map = new Map(); - constructor(private readonly contextRetriever: Retriever) { + private readonly contextRetriever: ContextRetriever; + constructor() { + this.contextRetriever = ContextRetriever.initialize(); for (const [name, { tool, useContextRetriever }] of Object.entries( TOOL_CONFIGS, )) { + const toolConfig = tool.prototype.config(); this.register({ - ...tool.prototype.config, + ...toolConfig, name, createInstance: useContextRetriever ? (_, contextRetriever) => { if (!contextRetriever) { throw new Error(`Context retriever is needed for ${name}`); } - return new tool(contextRetriever); + return new tool(this.contextRetriever); } : () => new tool(), }); @@ -31,9 +32,11 @@ export class ToolFactory { } getInstances(): CodeBuddyTool[] { - return Array.from(Object.values(this.tools)).map((tool) => + const y = this.tools.entries(); + const x = Array.from(y).map(([_, tool]) => tool.createInstance(tool, this.contextRetriever), ); + return x; } } @@ -42,15 +45,13 @@ export class CodeBuddyToolProvider { private static instance: CodeBuddyToolProvider | undefined; - private constructor(contextRetriever: Retriever) { - this.factory = new ToolFactory(contextRetriever); + private constructor() { + this.factory = new ToolFactory(); } - public static initialize(contextRetriever: Retriever) { + public static initialize() { if (!CodeBuddyToolProvider.instance) { - CodeBuddyToolProvider.instance = new CodeBuddyToolProvider( - contextRetriever, - ); + CodeBuddyToolProvider.instance = new CodeBuddyToolProvider(); } } diff --git a/src/services/context-retriever.ts b/src/services/context-retriever.ts index c9e7fec..566a861 100644 --- a/src/services/context-retriever.ts +++ b/src/services/context-retriever.ts @@ -9,6 +9,7 @@ export class ContextRetriever { private readonly embeddingService: EmbeddingService; private static readonly SEARCH_RESULT_COUNT = 2; private readonly logger: Logger; + private static instance: ContextRetriever; constructor() { this.codeRepository = CodeRepository.getInstance(); const geminiApiKey = getGeminiAPIKey(); @@ -16,6 +17,13 @@ export class ContextRetriever { this.logger = new Logger("ContextRetriever"); } + static initialize() { + if (!ContextRetriever.instance) { + ContextRetriever.instance = new ContextRetriever(); + } + return ContextRetriever.instance; + } + async retrieveContext(input: string): Promise { try { const embedding = await this.embeddingService.generateEmbedding(input); diff --git a/src/tools/base.ts b/src/tools/base.ts index 169b0d3..4f8d7b4 100644 --- a/src/tools/base.ts +++ b/src/tools/base.ts @@ -1,7 +1,7 @@ -import { ICodeBuddyToolConfig } from "../application/interfaces/agent.interface"; - export abstract class CodeBuddyTool { - constructor(public readonly config: ICodeBuddyToolConfig) {} + constructor() {} abstract execute(query: string): any; + + abstract config(): any; } diff --git a/src/tools/database/search-tool.ts b/src/tools/database/search-tool.ts index be1f8c1..f5fe4a3 100644 --- a/src/tools/database/search-tool.ts +++ b/src/tools/database/search-tool.ts @@ -1,15 +1,15 @@ -import { ContextRetriever } from "../../services/context-retriever"; -import { CodeBuddyTool } from "../base"; import { SchemaType } from "@google/generative-ai"; +import { ContextRetriever } from "../../services/context-retriever"; + +class SearchTool { + constructor(private readonly contextRetriever?: ContextRetriever) {} -class SearchTool extends CodeBuddyTool { - constructor( - private readonly contextRetriever?: Pick< - ContextRetriever, - "retrieveContext" - >, - ) { - super({ + public async execute(query: string) { + return await this.contextRetriever?.retrieveContext(query); + } + + config() { + return { name: "search_vector_db", description: "Perform a similarity search in the vector database based on user input", @@ -25,11 +25,7 @@ class SearchTool extends CodeBuddyTool { example: ["How was authentication implemented within this codebase"], required: ["query"], }, - }); - } - - public async execute(query: string) { - return await this.contextRetriever?.retrieveContext(query); + }; } } diff --git a/src/utils/prompt.ts b/src/utils/prompt.ts index 0175a49..18ddc48 100644 --- a/src/utils/prompt.ts +++ b/src/utils/prompt.ts @@ -1,4 +1,4 @@ -export const createPrompt = (query: string, thought?: string) => { +export const createPrompt = (query: string) => { return `You are an expert Information Retrieval Assistant. Transform user queries into precise keyword combinations with strategic reasoning and appropriate search operators. You have access these tools: @@ -13,49 +13,7 @@ export const createPrompt = (query: string, thought?: string) => { - Looking for external documentation - Checking latest features or updates - Finding general information - - When searching through the codebase, ALWAYS use the search_vector_db function first. - Only fall back to web search if you need external information. - -Core Rules: -1. Generate search queries that directly include appropriate operators -2. Keep base keywords minimal: 2-4 words preferred -3. Use exact match quotes for specific phrases that must stay together -4. Apply + operator for critical terms that must appear -5. Use - operator to exclude irrelevant or ambiguous terms -6. Add appropriate filters (filetype:, site:, lang:, loc:) when context suggests -7. Split queries only when necessary for distinctly different aspects -8. Preserve crucial qualifiers while removing fluff words -9. Make the query resistant to SEO manipulation - -Available Operators: -- "phrase" : exact match for phrases -- +term : must include term -- -term : exclude term -- filetype:pdf/doc : specific file type -- site:example.com : limit to specific site -- lang:xx : language filter (ISO 639-1 code) -- loc:xx : location filter (ISO 3166-1 code) -- intitle:term : term must be in title -- inbody:term : term must be in body text - -Examples with Strategic Reasoning: - -Input Query: Where is authentication handled in this codebase? - -Thought: This is a code structure and architecture query. The user is likely trying to understand how authentication is implemented within a specific codebase. User likely wants to call the vector database search function to retrieve relevant code snippets or file information - -Queries: [ "authentication middleware", "JWT implementation, "passport authentication setup" ] - -Input Query: Latest AWS Lambda features for serverless applications -Thought: This is a product research query focused on recent updates. User wants current information about specific technology features, likely for implementation purposes. User likely wants to call the websearch function to get the latest information. -Queries: [ - "aws lambda features site:aws.amazon.com intitle:2024", - "lambda serverless best practices +new -legacy" -] -Note Queries should always be in an array. Even if it is just one Query Now, process this query: Input Query: ${query} -Intention: ${thought} `; }; From 237f566381a4bd4f805daa3673028e53708f3d7a Mon Sep 17 00:00:00 2001 From: Olasunkanmi Oyinlola Date: Sun, 9 Feb 2025 11:29:02 +0800 Subject: [PATCH 2/2] Refactor(tool): Simplify tool instance creation - Refactors the getInstances method in ToolFactory to use Array.from(this.tools.values()) for cleaner code and better readability. --- src/providers/tool.ts | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/providers/tool.ts b/src/providers/tool.ts index a4cfca3..e8f1348 100644 --- a/src/providers/tool.ts +++ b/src/providers/tool.ts @@ -8,9 +8,7 @@ export class ToolFactory { private readonly contextRetriever: ContextRetriever; constructor() { this.contextRetriever = ContextRetriever.initialize(); - for (const [name, { tool, useContextRetriever }] of Object.entries( - TOOL_CONFIGS, - )) { + for (const [name, { tool, useContextRetriever }] of Object.entries(TOOL_CONFIGS)) { const toolConfig = tool.prototype.config(); this.register({ ...toolConfig, @@ -32,11 +30,7 @@ export class ToolFactory { } getInstances(): CodeBuddyTool[] { - const y = this.tools.entries(); - const x = Array.from(y).map(([_, tool]) => - tool.createInstance(tool, this.contextRetriever), - ); - return x; + return Array.from(this.tools.values()).map((tool) => tool.createInstance(tool, this.contextRetriever)); } }