diff --git a/packages/core/src/runtime.ts b/packages/core/src/runtime.ts index 646dc2b0ed..59145d5ec6 100644 --- a/packages/core/src/runtime.ts +++ b/packages/core/src/runtime.ts @@ -349,9 +349,9 @@ export class AgentRuntime implements IAgentRuntime { this.imageVisionModelProvider = this.character.imageVisionModelProvider ?? this.modelProvider; - elizaLogger.info("Selected model provider:", this.modelProvider); + // elizaLogger.info("Selected model provider:", this.modelProvider); duplicated log ln: 343 elizaLogger.info( - "Selected image model provider:", + "Selected image vision model provider:", this.imageVisionModelProvider ); diff --git a/packages/plugin-node/src/services/image.ts b/packages/plugin-node/src/services/image.ts index 56a59c9056..adffe10cb3 100644 --- a/packages/plugin-node/src/services/image.ts +++ b/packages/plugin-node/src/services/image.ts @@ -189,6 +189,51 @@ class OpenAIImageProvider implements ImageProvider { } } + +class GroqImageProvider implements ImageProvider { + constructor(private runtime: IAgentRuntime) {} + + async initialize(): Promise {} + + async describeImage( + imageData: Buffer, + mimeType: string + ): Promise<{ title: string; description: string }> { + const imageUrl = convertToBase64DataUrl(imageData, mimeType); + + const content = [ + { type: "text", text: IMAGE_DESCRIPTION_PROMPT }, + { type: "image_url", image_url: { url: imageUrl } }, + ]; + + const endpoint = + this.runtime.imageVisionModelProvider === ModelProviderName.GROQ + ? getEndpoint(this.runtime.imageVisionModelProvider) + : "https://api.groq.com/openai/v1/"; + + const response = await fetch(endpoint + "/chat/completions", { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${this.runtime.getSetting("GROQ_API_KEY")}`, + }, + body: JSON.stringify({ + model: /*this.runtime.imageVisionModelName ||*/ "llama-3.2-90b-vision-preview", + messages: [{ role: "user", content }], + max_tokens: 1024, + }), + }); + + if (!response.ok) { + await handleApiError(response, "Groq"); + } + + const data = await response.json(); + return parseImageResponse(data.choices[0].message.content); + } +} + + class GoogleImageProvider implements ImageProvider { constructor(private runtime: IAgentRuntime) {} @@ -280,6 +325,12 @@ export class ImageDescriptionService ) { this.provider = new OpenAIImageProvider(this.runtime); elizaLogger.debug("Using openai for vision model"); + } else if ( + this.runtime.imageVisionModelProvider === + ModelProviderName.GROQ + ) { + this.provider = new GroqImageProvider(this.runtime); + elizaLogger.debug("Using Groq for vision model"); } else { elizaLogger.error( `Unsupported image vision model provider: ${this.runtime.imageVisionModelProvider}` @@ -291,6 +342,9 @@ export class ImageDescriptionService } else if (model === models[ModelProviderName.GOOGLE]) { this.provider = new GoogleImageProvider(this.runtime); elizaLogger.debug("Using google for vision model"); + } else if (model === models[ModelProviderName.GROQ]) { + this.provider = new GroqImageProvider(this.runtime); + elizaLogger.debug("Using groq for vision model"); } else { elizaLogger.debug("Using default openai for vision model"); this.provider = new OpenAIImageProvider(this.runtime);