diff --git a/core/.env.example b/core/.env.example
index bf4cc7124ad..5d9069fef0c 100644
--- a/core/.env.example
+++ b/core/.env.example
@@ -28,6 +28,20 @@
 XAI_API_KEY=
 XAI_MODEL=
 
+# Set these to use the new OLLAMA provider
+OLLAMA_SERVER_URL= # Leave blank for the default http://localhost:11434
+OLLAMA_MODEL=
+OLLAMA_EMBEDDING_MODEL= # Default: mxbai-embed-large
+# To use custom models for different task sizes, set these
+SMALL_OLLAMA_MODEL= # Default: llama3.2
+MEDIUM_OLLAMA_MODEL= # Default: hermes3
+LARGE_OLLAMA_MODEL= # Default: hermes3:70b
+
+# To keep using the original LOCALLLAMA provider but serve it through Ollama
+LOCAL_LLAMA_PROVIDER= # Leave blank for llama.cpp, or set to OLLAMA
+
+
+
 # For asking Claude stuff
 ANTHROPIC_API_KEY=
 
diff --git a/core/package.json b/core/package.json
index ee3f10e7e61..71ef0070796 100644
--- a/core/package.json
+++ b/core/package.json
@@ -81,6 +81,7 @@
         "@ai-sdk/google-vertex": "^0.0.42",
         "@ai-sdk/groq": "^0.0.3",
         "@ai-sdk/openai": "^0.0.70",
+        "ollama-ai-provider": "^0.16.1",
         "@anthropic-ai/sdk": "^0.30.1",
         "@cliqz/adblocker-playwright": "1.34.0",
         "@coral-xyz/anchor": "^0.30.1",
diff --git a/core/src/core/embedding.ts b/core/src/core/embedding.ts
index 1a5026768d4..1585044100e 100644
--- a/core/src/core/embedding.ts
+++ b/core/src/core/embedding.ts
@@ -10,8 +10,9 @@ export async function embed(runtime: IAgentRuntime, input: string) {
     // get the charcter, and handle by model type
     const model = models[runtime.character.settings.model];
 
-    if (model !== ModelProvider.OPENAI) {
+    if (model !== ModelProvider.OPENAI && model !== ModelProvider.OLLAMA) {
         return await runtime.llamaService.getEmbeddingResponse(input);
+        // Ollama supports the embeddings API, so it falls through to the request below
     }
 
     const embeddingModel = models[runtime.modelProvider].model.embedding;
@@ -26,7 +27,8 @@ export async function embed(runtime: IAgentRuntime, input: string) {
         method: "POST",
         headers: {
             "Content-Type": "application/json",
-            Authorization: `Bearer ${runtime.token}`,
+            //Authorization: `Bearer ${runtime.token}`,
+            ...(runtime.modelProvider !== ModelProvider.OLLAMA && { Authorization: `Bearer ${runtime.token}` }),
         },
         body: JSON.stringify({
             input,
@@ -36,7 +38,8 @@ export async function embed(runtime: IAgentRuntime, input: string) {
     };
     try {
         const response = await fetch(
-            `${runtime.serverUrl}/embeddings`,
+            //`${runtime.serverUrl}/embeddings`,
+            `${runtime.serverUrl}${runtime.modelProvider === ModelProvider.OLLAMA ? "/v1" : ""}/embeddings`,
             requestOptions
         );
 
diff --git a/core/src/core/generation.ts b/core/src/core/generation.ts
index c9eda52e7a7..516ee4675fa 100644
--- a/core/src/core/generation.ts
+++ b/core/src/core/generation.ts
@@ -16,6 +16,8 @@ import { generateText as aiGenerateText } from "ai";
 import { createAnthropic } from "@ai-sdk/anthropic";
 import { prettyConsole } from "../index.ts";
 
+import { createOllama } from "ollama-ai-provider";
+
 /**
  * Send a message to the model for a text generateText - receive a string back and parse how you'd like
  * @param opts - The options for the generateText request.
@@ -190,6 +192,30 @@ export async function generateText({
             break;
         }
 
+        case ModelProvider.OLLAMA: {
+            console.log("Initializing Ollama model.");
+
+            const ollamaProvider = createOllama({
+                baseURL: models[provider].endpoint + "/api",
+            });
+            const ollama = ollamaProvider(model);
+
+            console.log("Using Ollama model:", model);
+
+            const { text: ollamaResponse } = await aiGenerateText({
+                model: ollama,
+                prompt: context,
+                temperature: temperature,
+                maxTokens: max_response_length,
+                frequencyPenalty: frequency_penalty,
+                presencePenalty: presence_penalty,
+            });
+
+            response = ollamaResponse;
+            console.log("Received response from Ollama model.");
+        }
+        break;
+
         default: {
             const errorMessage = `Unsupported provider: ${provider}`;
             prettyConsole.error(errorMessage);
diff --git a/core/src/core/models.ts b/core/src/core/models.ts
index 3c8b275d4ed..be25220ba35 100644
--- a/core/src/core/models.ts
+++ b/core/src/core/models.ts
@@ -10,6 +10,7 @@ type Models = {
     [ModelProvider.GOOGLE]: Model;
     [ModelProvider.CLAUDE_VERTEX]: Model;
     [ModelProvider.REDPILL]: Model;
+    [ModelProvider.OLLAMA]: Model;
     // TODO: add OpenRouter - feel free to do this :)
 };
 
@@ -169,6 +170,23 @@ const models: Models = {
             [ModelClass.EMBEDDING]: "text-embedding-3-small",
         },
     },
+    [ModelProvider.OLLAMA]: {
+        settings: {
+            stop: [],
+            maxInputTokens: 128000,
+            maxOutputTokens: 8192,
+            frequency_penalty: 0.0,
+            presence_penalty: 0.0,
+            temperature: 0.3,
+        },
+        endpoint: process.env.OLLAMA_SERVER_URL || "http://localhost:11434",
+        model: {
+            [ModelClass.SMALL]: process.env.SMALL_OLLAMA_MODEL || process.env.OLLAMA_MODEL || "llama3.2",
+            [ModelClass.MEDIUM]: process.env.MEDIUM_OLLAMA_MODEL || process.env.OLLAMA_MODEL || "hermes3",
+            [ModelClass.LARGE]: process.env.LARGE_OLLAMA_MODEL || process.env.OLLAMA_MODEL || "hermes3:70b",
+            [ModelClass.EMBEDDING]: process.env.OLLAMA_EMBEDDING_MODEL || "mxbai-embed-large",
+        },
+    },
 };
 
 export function getModel(provider: ModelProvider, type: ModelClass) {
diff --git a/core/src/core/types.ts b/core/src/core/types.ts
index b8b4a82d9db..672d1b1f290 100644
--- a/core/src/core/types.ts
+++ b/core/src/core/types.ts
@@ -107,7 +107,8 @@ export enum ModelProvider {
     LLAMALOCAL = "llama_local",
     GOOGLE = "google",
     CLAUDE_VERTEX = "claude_vertex",
-    REDPILL = "redpill"
+    REDPILL = "redpill",
+    OLLAMA = "ollama"
 }
 
 /**
diff --git a/core/src/services/LlamaCppService.ts b/core/src/services/LlamaCppService.ts
new file mode 100644
index 00000000000..e2c956f701f
--- /dev/null
+++ b/core/src/services/LlamaCppService.ts
@@ -0,0 +1,411 @@
+import { fileURLToPath } from "url";
+import path from "path";
+import {
+    GbnfJsonSchema,
+    getLlama,
+    Llama,
+    LlamaContext,
+    LlamaContextSequence,
+    LlamaContextSequenceRepeatPenalty,
+    LlamaJsonSchemaGrammar,
+    LlamaModel,
+    Token,
+} from "node-llama-cpp";
+import fs from "fs";
+import https from "https";
+import si from "systeminformation";
+import { wordsToPunish } from "./wordsToPunish.ts";
+
+const __dirname = path.dirname(fileURLToPath(import.meta.url));
+
+const jsonSchemaGrammar: Readonly<{
+    type: string;
+    properties: {
+        user: {
+            type: string;
+        };
+        content: {
+            type: string;
+        };
+    };
+}> = {
+    type: "object",
+    properties: {
+        user: {
+            type: "string",
+        },
+        content: {
+            type: "string",
+        },
+    },
+};
+
+interface QueuedMessage {
+    context: string;
+    temperature: number;
+    stop: string[];
+    max_tokens: number;
+    frequency_penalty: number;
+    presence_penalty: number;
+    useGrammar: boolean;
+    resolve: (value: any | string | PromiseLike<any | string>) => void;
+    reject: (reason?: any) => void;
+} + +class LlamaCppService { + private static instance: LlamaCppService | null = null; + private llama: Llama | undefined; + private model: LlamaModel | undefined; + private modelPath: string; + private grammar: LlamaJsonSchemaGrammar | undefined; + private ctx: LlamaContext | undefined; + private sequence: LlamaContextSequence | undefined; + private modelUrl: string; + + private messageQueue: QueuedMessage[] = []; + private isProcessing: boolean = false; + private modelInitialized: boolean = false; + + private constructor() { + console.log("Constructing"); + this.llama = undefined; + this.model = undefined; + this.modelUrl = + "https://huggingface.co/NousResearch/Hermes-3-Llama-3.1-8B-GGUF/resolve/main/Hermes-3-Llama-3.1-8B.Q8_0.gguf?download=true"; + const modelName = "model.gguf"; + console.log("modelName", modelName); + this.modelPath = path.join(__dirname, modelName); + try { + this.initializeModel(); + } catch (error) { + console.error("Error initializing model", error); + } + } + + public static getInstance(): LlamaCppService { + if (!LlamaCppService.instance) { + LlamaCppService.instance = new LlamaCppService(); + } + return LlamaCppService.instance; + } + + async initializeModel() { + try { + await this.checkModel(); + + const systemInfo = await si.graphics(); + const hasCUDA = systemInfo.controllers.some((controller) => + controller.vendor.toLowerCase().includes("nvidia") + ); + + if (hasCUDA) { + console.log("**** CUDA detected"); + } else { + console.log( + "**** No CUDA detected - local response will be slow" + ); + } + + this.llama = await getLlama({ + gpu: "cuda", + }); + console.log("Creating grammar"); + const grammar = new LlamaJsonSchemaGrammar( + this.llama, + jsonSchemaGrammar as GbnfJsonSchema + ); + this.grammar = grammar; + console.log("Loading model"); + console.log("this.modelPath", this.modelPath); + + this.model = await this.llama.loadModel({ + modelPath: this.modelPath, + }); + console.log("Model GPU support", this.llama.getGpuDeviceNames()); + console.log("Creating context"); + this.ctx = await this.model.createContext({ contextSize: 8192 }); + this.sequence = this.ctx.getSequence(); + + this.modelInitialized = true; + this.processQueue(); + } catch (error) { + console.error( + "Model initialization failed. Deleting model and retrying...", + error + ); + await this.deleteModel(); + await this.initializeModel(); + } + } + + async checkModel() { + console.log("Checking model"); + if (!fs.existsSync(this.modelPath)) { + console.log("this.modelPath", this.modelPath); + console.log("Model not found. Downloading..."); + + await new Promise((resolve, reject) => { + const file = fs.createWriteStream(this.modelPath); + let downloadedSize = 0; + + const downloadModel = (url: string) => { + https + .get(url, (response) => { + const isRedirect = + response.statusCode >= 300 && + response.statusCode < 400; + if (isRedirect) { + const redirectUrl = response.headers.location; + if (redirectUrl) { + console.log( + "Following redirect to:", + redirectUrl + ); + downloadModel(redirectUrl); + return; + } else { + console.error("Redirect URL not found"); + reject(new Error("Redirect URL not found")); + return; + } + } + + const totalSize = parseInt( + response.headers["content-length"] ?? 
"0", + 10 + ); + + response.on("data", (chunk) => { + downloadedSize += chunk.length; + file.write(chunk); + + // Log progress + const progress = ( + (downloadedSize / totalSize) * + 100 + ).toFixed(2); + process.stdout.write( + `Downloaded ${progress}%\r` + ); + }); + + response.on("end", () => { + file.end(); + console.log("\nModel downloaded successfully."); + resolve(); + }); + }) + .on("error", (err) => { + fs.unlink(this.modelPath, () => {}); // Delete the file async + console.error("Download failed:", err.message); + reject(err); + }); + }; + + downloadModel(this.modelUrl); + + file.on("error", (err) => { + fs.unlink(this.modelPath, () => {}); // Delete the file async + console.error("File write error:", err.message); + reject(err); + }); + }); + } else { + console.log("Model already exists."); + } + } + + async deleteModel() { + if (fs.existsSync(this.modelPath)) { + fs.unlinkSync(this.modelPath); + console.log("Model deleted."); + } + } + + async queueMessageCompletion( + context: string, + temperature: number, + stop: string[], + frequency_penalty: number, + presence_penalty: number, + max_tokens: number + ): Promise { + console.log("Queueing message generateText"); + return new Promise((resolve, reject) => { + this.messageQueue.push({ + context, + temperature, + stop, + frequency_penalty, + presence_penalty, + max_tokens, + useGrammar: true, + resolve, + reject, + }); + this.processQueue(); + }); + } + + async queueTextCompletion( + context: string, + temperature: number, + stop: string[], + frequency_penalty: number, + presence_penalty: number, + max_tokens: number + ): Promise { + return new Promise((resolve, reject) => { + this.messageQueue.push({ + context, + temperature, + stop, + frequency_penalty, + presence_penalty, + max_tokens, + useGrammar: false, + resolve, + reject, + }); + this.processQueue(); + }); + } + + private async processQueue() { + if ( + this.isProcessing || + this.messageQueue.length === 0 || + !this.modelInitialized + ) { + return; + } + + this.isProcessing = true; + + while (this.messageQueue.length > 0) { + const message = this.messageQueue.shift(); + if (message) { + try { + console.log("Processing message"); + const response = await this.getCompletionResponse( + message.context, + message.temperature, + message.stop, + message.frequency_penalty, + message.presence_penalty, + message.max_tokens, + message.useGrammar + ); + message.resolve(response); + } catch (error) { + message.reject(error); + } + } + } + + this.isProcessing = false; + } + + private async getCompletionResponse( + context: string, + temperature: number, + stop: string[], + frequency_penalty: number, + presence_penalty: number, + max_tokens: number, + useGrammar: boolean + ): Promise { + if (!this.sequence) { + throw new Error("Model not initialized."); + } + + const tokens = this.model!.tokenize(context); + + // tokenize the words to punish + const wordsToPunishTokens = wordsToPunish + .map((word) => this.model!.tokenize(word)) + .flat(); + + const repeatPenalty: LlamaContextSequenceRepeatPenalty = { + punishTokens: () => wordsToPunishTokens, + penalty: 1.2, + frequencyPenalty: frequency_penalty, + presencePenalty: presence_penalty, + }; + + const responseTokens: Token[] = []; + console.log("Evaluating tokens"); + for await (const token of this.sequence.evaluate(tokens, { + temperature: Number(temperature), + repeatPenalty: repeatPenalty, + grammarEvaluationState: useGrammar ? 
this.grammar : undefined, + yieldEogToken: false, + })) { + const current = this.model.detokenize([...responseTokens, token]); + if ([...stop].some((s) => current.includes(s))) { + console.log("Stop sequence found"); + break; + } + + responseTokens.push(token); + process.stdout.write(this.model!.detokenize([token])); + if (useGrammar) { + if (current.replaceAll("\n", "").includes("}```")) { + console.log("JSON block found"); + break; + } + } + if (responseTokens.length > max_tokens) { + console.log("Max tokens reached"); + break; + } + } + + const response = this.model!.detokenize(responseTokens); + + if (!response) { + throw new Error("Response is undefined"); + } + + if (useGrammar) { + // extract everything between ```json and ``` + let jsonString = response.match(/```json(.*?)```/s)?.[1].trim(); + if (!jsonString) { + // try parsing response as JSON + try { + jsonString = JSON.stringify(JSON.parse(response)); + console.log("parsedResponse", jsonString); + } catch { + throw new Error("JSON string not found"); + } + } + try { + const parsedResponse = JSON.parse(jsonString); + if (!parsedResponse) { + throw new Error("Parsed response is undefined"); + } + console.log("AI: " + parsedResponse.content); + await this.sequence.clearHistory(); + return parsedResponse; + } catch (error) { + console.error("Error parsing JSON:", error); + } + } else { + console.log("AI: " + response); + await this.sequence.clearHistory(); + return response; + } + } + + async getEmbeddingResponse(input: string): Promise { + if (!this.model) { + throw new Error("Model not initialized. Call initialize() first."); + } + + const embeddingContext = await this.model.createEmbeddingContext(); + const embedding = await embeddingContext.getEmbeddingFor(input); + return embedding?.vector ? 
[...embedding.vector] : undefined; + } +} + +export default LlamaCppService; diff --git a/core/src/services/OllamaService.ts b/core/src/services/OllamaService.ts new file mode 100644 index 00000000000..524c3471343 --- /dev/null +++ b/core/src/services/OllamaService.ts @@ -0,0 +1,217 @@ +import { OpenAI } from 'openai'; +import * as dotenv from 'dotenv'; +import { debuglog } from 'util'; + +// Create debug logger +const debug = debuglog('LLAMA'); + +process.on('uncaughtException', (err) => { + debug('Uncaught Exception:', err); + process.exit(1); +}); + +process.on('unhandledRejection', (reason, promise) => { + debug('Unhandled Rejection at:', promise, 'reason:', reason); +}); + +interface QueuedMessage { + context: string; + temperature: number; + stop: string[]; + max_tokens: number; + frequency_penalty: number; + presence_penalty: number; + useGrammar: boolean; + resolve: (value: any | string | PromiseLike) => void; + reject: (reason?: any) => void; +} + +class OllamaService { + private static instance: OllamaService | null = null; + private openai: OpenAI; + private modelName: string; + private embeddingModelName: string = process.env.OLLAMA_EMBEDDING_MODEL || 'nomic-embed-text'; + private messageQueue: QueuedMessage[] = []; + private isProcessing: boolean = false; + + private constructor() { + debug('Constructing OllamaService'); + dotenv.config(); + this.modelName = process.env.OLLAMA_MODEL || 'llama3.2'; + this.openai = new OpenAI({ + baseURL: process.env.OLLAMA_SERVER_URL || 'http://localhost:11434/v1', + apiKey: 'ollama', + dangerouslyAllowBrowser: true + }); + debug(`Using model: ${this.modelName}`); + debug('OpenAI client initialized'); + } + + public static getInstance(): OllamaService { + debug('Getting OllamaService instance'); + if (!OllamaService.instance) { + debug('Creating new instance'); + OllamaService.instance = new OllamaService(); + } + return OllamaService.instance; + } + + // Adding initializeModel method to satisfy IOllamaService interface + public async initializeModel(): Promise { + debug('Initializing model...'); + try { + // Placeholder for model setup if needed + debug(`Model ${this.modelName} initialized successfully.`); + } catch (error) { + debug('Error during model initialization:', error); + throw error; + } + } + + async queueMessageCompletion( + context: string, + temperature: number, + stop: string[], + frequency_penalty: number, + presence_penalty: number, + max_tokens: number + ): Promise { + debug('Queueing message completion'); + return new Promise((resolve, reject) => { + this.messageQueue.push({ + context, + temperature, + stop, + frequency_penalty, + presence_penalty, + max_tokens, + useGrammar: true, + resolve, + reject, + }); + this.processQueue(); + }); + } + + async queueTextCompletion( + context: string, + temperature: number, + stop: string[], + frequency_penalty: number, + presence_penalty: number, + max_tokens: number + ): Promise { + debug('Queueing text completion'); + return new Promise((resolve, reject) => { + this.messageQueue.push({ + context, + temperature, + stop, + frequency_penalty, + presence_penalty, + max_tokens, + useGrammar: false, + resolve, + reject, + }); + this.processQueue(); + }); + } + + private async processQueue() { + debug(`Processing queue: ${this.messageQueue.length} items`); + if (this.isProcessing || this.messageQueue.length === 0) { + return; + } + + this.isProcessing = true; + + while (this.messageQueue.length > 0) { + const message = this.messageQueue.shift(); + if (message) { + try { + const response = 
await this.getCompletionResponse( + message.context, + message.temperature, + message.stop, + message.frequency_penalty, + message.presence_penalty, + message.max_tokens, + message.useGrammar + ); + message.resolve(response); + } catch (error) { + debug('Queue processing error:', error); + message.reject(error); + } + } + } + + this.isProcessing = false; + } + + private async getCompletionResponse( + context: string, + temperature: number, + stop: string[], + frequency_penalty: number, + presence_penalty: number, + max_tokens: number, + useGrammar: boolean + ): Promise { + debug('Getting completion response'); + try { + const completion = await this.openai.chat.completions.create({ + model: this.modelName, + messages: [{ role: 'user', content: context }], + temperature, + max_tokens, + stop, + frequency_penalty, + presence_penalty, + }); + + const response = completion.choices[0].message.content; + + if (useGrammar && response) { + try { + let jsonResponse = JSON.parse(response); + return jsonResponse; + } catch { + const jsonMatch = response.match(/```json\s*([\s\S]*?)\s*```/); + if (jsonMatch) { + try { + return JSON.parse(jsonMatch[1]); + } catch { + throw new Error("Failed to parse JSON from response"); + } + } + throw new Error("No valid JSON found in response"); + } + } + + return response || ''; + } catch (error) { + debug('Completion error:', error); + throw error; + } + } + + async getEmbeddingResponse(input: string): Promise { + debug('Getting embedding response'); + try { + const embeddingResponse = await this.openai.embeddings.create({ + model: this.embeddingModelName, + input, + }); + + return embeddingResponse.data[0].embedding; + } catch (error) { + debug('Embedding error:', error); + return undefined; + } + } +} + +debug('OllamaService module loaded'); +export default OllamaService; diff --git a/core/src/services/llama.ts b/core/src/services/llama.ts index fb781b9352e..6a3b2f52615 100644 --- a/core/src/services/llama.ts +++ b/core/src/services/llama.ts @@ -1,411 +1,64 @@ -import { fileURLToPath } from "url"; -import path from "path"; -import { - GbnfJsonSchema, - getLlama, - Llama, - LlamaContext, - LlamaContextSequence, - LlamaContextSequenceRepeatPenalty, - LlamaJsonSchemaGrammar, - LlamaModel, - Token, -} from "node-llama-cpp"; -import fs from "fs"; -import https from "https"; -import si from "systeminformation"; -import { wordsToPunish } from "./wordsToPunish.ts"; +import { ILlamaService } from '../core/types.ts'; ///ILlamaService'; +import LlamaCppService from './LlamaCppService.ts'; +import OllamaService from './OllamaService.ts'; +import * as dotenv from 'dotenv'; -const __dirname = path.dirname(fileURLToPath(import.meta.url)); +dotenv.config(); -const jsonSchemaGrammar: Readonly<{ - type: string; - properties: { - user: { - type: string; - }; - content: { - type: string; - }; - }; -}> = { - type: "object", - properties: { - user: { - type: "string", - }, - content: { - type: "string", - }, - }, -}; - -interface QueuedMessage { - context: string; - temperature: number; - stop: string[]; - max_tokens: number; - frequency_penalty: number; - presence_penalty: number; - useGrammar: boolean; - resolve: (value: any | string | PromiseLike) => void; - reject: (reason?: any) => void; -} - -class LlamaService { +class LlamaService implements ILlamaService { private static instance: LlamaService | null = null; - private llama: Llama | undefined; - private model: LlamaModel | undefined; - private modelPath: string; - private grammar: LlamaJsonSchemaGrammar | undefined; - private 
ctx: LlamaContext | undefined; - private sequence: LlamaContextSequence | undefined; - private modelUrl: string; - - private messageQueue: QueuedMessage[] = []; - private isProcessing: boolean = false; - private modelInitialized: boolean = false; - + private delegate: ILlamaService; + private constructor() { - console.log("Constructing"); - this.llama = undefined; - this.model = undefined; - this.modelUrl = - "https://huggingface.co/NousResearch/Hermes-3-Llama-3.1-8B-GGUF/resolve/main/Hermes-3-Llama-3.1-8B.Q8_0.gguf?download=true"; - const modelName = "model.gguf"; - console.log("modelName", modelName); - this.modelPath = path.join(__dirname, modelName); - try { - this.initializeModel(); - } catch (error) { - console.error("Error initializing model", error); - } + const provider = process.env.LOCAL_LLAMA_PROVIDER; + console.log("provider: ", provider) + if (provider === 'OLLAMA') { + this.delegate = OllamaService.getInstance(); + } else { + this.delegate = LlamaCppService.getInstance(); + } } - + public static getInstance(): LlamaService { - if (!LlamaService.instance) { - LlamaService.instance = new LlamaService(); - } - return LlamaService.instance; - } - - async initializeModel() { - try { - await this.checkModel(); - - const systemInfo = await si.graphics(); - const hasCUDA = systemInfo.controllers.some((controller) => - controller.vendor.toLowerCase().includes("nvidia") - ); - - if (hasCUDA) { - console.log("**** CUDA detected"); - } else { - console.log( - "**** No CUDA detected - local response will be slow" - ); - } - - this.llama = await getLlama({ - gpu: "cuda", - }); - console.log("Creating grammar"); - const grammar = new LlamaJsonSchemaGrammar( - this.llama, - jsonSchemaGrammar as GbnfJsonSchema - ); - this.grammar = grammar; - console.log("Loading model"); - console.log("this.modelPath", this.modelPath); - - this.model = await this.llama.loadModel({ - modelPath: this.modelPath, - }); - console.log("Model GPU support", this.llama.getGpuDeviceNames()); - console.log("Creating context"); - this.ctx = await this.model.createContext({ contextSize: 8192 }); - this.sequence = this.ctx.getSequence(); - - this.modelInitialized = true; - this.processQueue(); - } catch (error) { - console.error( - "Model initialization failed. Deleting model and retrying...", - error - ); - await this.deleteModel(); - await this.initializeModel(); - } + if (!LlamaService.instance) { + LlamaService.instance = new LlamaService(); + } + return LlamaService.instance; } - - async checkModel() { - console.log("Checking model"); - if (!fs.existsSync(this.modelPath)) { - console.log("this.modelPath", this.modelPath); - console.log("Model not found. Downloading..."); - - await new Promise((resolve, reject) => { - const file = fs.createWriteStream(this.modelPath); - let downloadedSize = 0; - - const downloadModel = (url: string) => { - https - .get(url, (response) => { - const isRedirect = - response.statusCode >= 300 && - response.statusCode < 400; - if (isRedirect) { - const redirectUrl = response.headers.location; - if (redirectUrl) { - console.log( - "Following redirect to:", - redirectUrl - ); - downloadModel(redirectUrl); - return; - } else { - console.error("Redirect URL not found"); - reject(new Error("Redirect URL not found")); - return; - } - } - - const totalSize = parseInt( - response.headers["content-length"] ?? 
"0", - 10 - ); - - response.on("data", (chunk) => { - downloadedSize += chunk.length; - file.write(chunk); - - // Log progress - const progress = ( - (downloadedSize / totalSize) * - 100 - ).toFixed(2); - process.stdout.write( - `Downloaded ${progress}%\r` - ); - }); - - response.on("end", () => { - file.end(); - console.log("\nModel downloaded successfully."); - resolve(); - }); - }) - .on("error", (err) => { - fs.unlink(this.modelPath, () => {}); // Delete the file async - console.error("Download failed:", err.message); - reject(err); - }); - }; - - downloadModel(this.modelUrl); - - file.on("error", (err) => { - fs.unlink(this.modelPath, () => {}); // Delete the file async - console.error("File write error:", err.message); - reject(err); - }); - }); - } else { - console.log("Model already exists."); - } + + async initializeModel(): Promise { + return this.delegate.initializeModel(); } - - async deleteModel() { - if (fs.existsSync(this.modelPath)) { - fs.unlinkSync(this.modelPath); - console.log("Model deleted."); - } - } - + async queueMessageCompletion( - context: string, - temperature: number, - stop: string[], - frequency_penalty: number, - presence_penalty: number, - max_tokens: number + context: string, + temperature: number, + stop: string[], + frequency_penalty: number, + presence_penalty: number, + max_tokens: number ): Promise { - console.log("Queueing message generateText"); - return new Promise((resolve, reject) => { - this.messageQueue.push({ - context, - temperature, - stop, - frequency_penalty, - presence_penalty, - max_tokens, - useGrammar: true, - resolve, - reject, - }); - this.processQueue(); - }); + return this.delegate.queueMessageCompletion( + context, temperature, stop, frequency_penalty, presence_penalty, max_tokens + ); } - + async queueTextCompletion( - context: string, - temperature: number, - stop: string[], - frequency_penalty: number, - presence_penalty: number, - max_tokens: number + context: string, + temperature: number, + stop: string[], + frequency_penalty: number, + presence_penalty: number, + max_tokens: number ): Promise { - return new Promise((resolve, reject) => { - this.messageQueue.push({ - context, - temperature, - stop, - frequency_penalty, - presence_penalty, - max_tokens, - useGrammar: false, - resolve, - reject, - }); - this.processQueue(); - }); + return this.delegate.queueTextCompletion( + context, temperature, stop, frequency_penalty, presence_penalty, max_tokens + ); } - - private async processQueue() { - if ( - this.isProcessing || - this.messageQueue.length === 0 || - !this.modelInitialized - ) { - return; - } - - this.isProcessing = true; - - while (this.messageQueue.length > 0) { - const message = this.messageQueue.shift(); - if (message) { - try { - console.log("Processing message"); - const response = await this.getCompletionResponse( - message.context, - message.temperature, - message.stop, - message.frequency_penalty, - message.presence_penalty, - message.max_tokens, - message.useGrammar - ); - message.resolve(response); - } catch (error) { - message.reject(error); - } - } - } - - this.isProcessing = false; - } - - private async getCompletionResponse( - context: string, - temperature: number, - stop: string[], - frequency_penalty: number, - presence_penalty: number, - max_tokens: number, - useGrammar: boolean - ): Promise { - if (!this.sequence) { - throw new Error("Model not initialized."); - } - - const tokens = this.model!.tokenize(context); - - // tokenize the words to punish - const wordsToPunishTokens = wordsToPunish - 
.map((word) => this.model!.tokenize(word)) - .flat(); - - const repeatPenalty: LlamaContextSequenceRepeatPenalty = { - punishTokens: () => wordsToPunishTokens, - penalty: 1.2, - frequencyPenalty: frequency_penalty, - presencePenalty: presence_penalty, - }; - - const responseTokens: Token[] = []; - console.log("Evaluating tokens"); - for await (const token of this.sequence.evaluate(tokens, { - temperature: Number(temperature), - repeatPenalty: repeatPenalty, - grammarEvaluationState: useGrammar ? this.grammar : undefined, - yieldEogToken: false, - })) { - const current = this.model.detokenize([...responseTokens, token]); - if ([...stop].some((s) => current.includes(s))) { - console.log("Stop sequence found"); - break; - } - - responseTokens.push(token); - process.stdout.write(this.model!.detokenize([token])); - if (useGrammar) { - if (current.replaceAll("\n", "").includes("}```")) { - console.log("JSON block found"); - break; - } - } - if (responseTokens.length > max_tokens) { - console.log("Max tokens reached"); - break; - } - } - - const response = this.model!.detokenize(responseTokens); - - if (!response) { - throw new Error("Response is undefined"); - } - - if (useGrammar) { - // extract everything between ```json and ``` - let jsonString = response.match(/```json(.*?)```/s)?.[1].trim(); - if (!jsonString) { - // try parsing response as JSON - try { - jsonString = JSON.stringify(JSON.parse(response)); - console.log("parsedResponse", jsonString); - } catch { - throw new Error("JSON string not found"); - } - } - try { - const parsedResponse = JSON.parse(jsonString); - if (!parsedResponse) { - throw new Error("Parsed response is undefined"); - } - console.log("AI: " + parsedResponse.content); - await this.sequence.clearHistory(); - return parsedResponse; - } catch (error) { - console.error("Error parsing JSON:", error); - } - } else { - console.log("AI: " + response); - await this.sequence.clearHistory(); - return response; - } - } - + async getEmbeddingResponse(input: string): Promise { - if (!this.model) { - throw new Error("Model not initialized. Call initialize() first."); - } - - const embeddingContext = await this.model.createEmbeddingContext(); - const embedding = await embeddingContext.getEmbeddingFor(input); - return embedding?.vector ? [...embedding.vector] : undefined; + return this.delegate.getEmbeddingResponse(input); } -} - -export default LlamaService; + } + + export default LlamaService; \ No newline at end of file
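
For reference, the embedding routing introduced in `embedding.ts` can be read as the following standalone sketch: Ollama is reached through its OpenAI-compatible `/v1/embeddings` endpoint and needs no bearer token, while OpenAI-style hosted providers keep the `Authorization` header. The `EmbeddingRuntime` shape and the `fetchEmbedding` helper are simplified stand-ins for illustration, not types from this codebase.

```ts
// Illustrative sketch of the embeddings routing added in embedding.ts above.
// The runtime shape here is a simplified stand-in, not the project's IAgentRuntime.
type Provider = "openai" | "ollama";

interface EmbeddingRuntime {
    modelProvider: Provider;
    serverUrl: string; // e.g. "https://api.openai.com/v1" or "http://localhost:11434"
    token?: string; // not needed for Ollama
}

async function fetchEmbedding(
    runtime: EmbeddingRuntime,
    model: string,
    input: string
): Promise<number[]> {
    // Ollama serves an OpenAI-compatible API under /v1, so only the path
    // and the auth header differ between the two providers.
    const base =
        runtime.modelProvider === "ollama"
            ? `${runtime.serverUrl}/v1`
            : runtime.serverUrl;

    const response = await fetch(`${base}/embeddings`, {
        method: "POST",
        headers: {
            "Content-Type": "application/json",
            // Attach the bearer token only for providers that require it.
            ...(runtime.modelProvider !== "ollama" && {
                Authorization: `Bearer ${runtime.token}`,
            }),
        },
        body: JSON.stringify({ input, model }),
    });

    if (!response.ok) {
        throw new Error(`Embedding request failed: ${response.status}`);
    }

    const data = (await response.json()) as {
        data: Array<{ embedding: number[] }>;
    };
    return data.data[0].embedding;
}
```

Calling `fetchEmbedding({ modelProvider: "ollama", serverUrl: "http://localhost:11434" }, "mxbai-embed-large", "hello")` would POST to `http://localhost:11434/v1/embeddings` with no `Authorization` header.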
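
The new `ModelProvider.OLLAMA` branch in `generation.ts` drives `ollama-ai-provider` through the AI SDK's `generateText`. Here is a minimal sketch of that call path, assuming a local Ollama instance on the default port with a pulled `hermes3` tag (both are assumptions, not values from the diff).

```ts
// Minimal sketch of the ollama-ai-provider call path used by the new OLLAMA
// case in generateText. Base URL and model tag are assumed local defaults.
import { generateText } from "ai";
import { createOllama } from "ollama-ai-provider";

// Ollama's native API lives under /api; models.ts supplies the host portion.
const ollamaProvider = createOllama({
    baseURL: "http://localhost:11434/api",
});
const model = ollamaProvider("hermes3"); // any locally pulled model tag

const { text } = await generateText({
    model,
    prompt: "Explain in one sentence what a context window is.",
    temperature: 0.3,
    maxTokens: 256,
    frequencyPenalty: 0.0,
    presencePenalty: 0.0,
});

console.log(text);
```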
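
With `llama.ts` refactored into a thin facade, existing callers keep using `LlamaService` while the backend is picked from `LOCAL_LLAMA_PROVIDER` when the singleton is first constructed. A hypothetical usage sketch follows; the import path (relative to `core/src`) and the environment values are assumptions for illustration.

```ts
// Hypothetical usage of the delegating LlamaService facade; not part of the diff.
import LlamaService from "./services/llama.ts";

async function demo(): Promise<void> {
    // Choose the backend before the singleton is constructed:
    // "OLLAMA" routes to OllamaService, anything else falls back to llama.cpp.
    process.env.LOCAL_LLAMA_PROVIDER = "OLLAMA";
    process.env.OLLAMA_MODEL = "hermes3"; // any locally pulled model tag

    const llama = LlamaService.getInstance();
    await llama.initializeModel(); // currently a no-op for the Ollama backend

    const text = await llama.queueTextCompletion(
        "Write one sentence about local inference.",
        0.3, // temperature
        ["\n\n"], // stop sequences
        0.0, // frequency_penalty
        0.0, // presence_penalty
        128 // max_tokens
    );
    console.log(text);

    const embedding = await llama.getEmbeddingResponse("local inference");
    console.log("embedding dimensions:", embedding?.length);
}

demo().catch(console.error);
```

Because both backends implement the same ILlamaService surface, existing call sites need no changes when the provider is switched.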