From 5687fe8b02c0472e7cdeab706fb7e15e1b8d0d69 Mon Sep 17 00:00:00 2001
From: Ting Chien Meng
Date: Tue, 31 Dec 2024 07:13:35 -0500
Subject: [PATCH] use tokenization service to trim tokens

---
 packages/core/src/generation.ts | 135 +++++++++++++++++++-------------
 1 file changed, 82 insertions(+), 53 deletions(-)

diff --git a/packages/core/src/generation.ts b/packages/core/src/generation.ts
index 67ed1b664a..53cef21bc3 100644
--- a/packages/core/src/generation.ts
+++ b/packages/core/src/generation.ts
@@ -34,6 +34,7 @@ import {
     ServiceType,
     SearchResponse,
     ActionResponse,
+    ITokenizationService,
 } from "./types.ts";
 
 import { fal } from "@fal-ai/client";
@@ -171,7 +172,15 @@ export async function generateText({
     elizaLogger.debug(
         `Trimming context to max length of ${max_context_length} tokens.`
     );
-    context = await trimTokens(context, max_context_length, "gpt-4o");
+    const tokenizationService = runtime.getService<ITokenizationService>(
+        ServiceType.TOKENIZATION
+    );
+
+    context = await tokenizationService.trimTokens(
+        context,
+        max_context_length,
+        model
+    );
 
     let response: string;
 
@@ -905,9 +914,19 @@ export async function generateMessageResponse({
     context: string;
     modelClass: string;
 }): Promise<Content> {
-    const max_context_length =
-        models[runtime.modelProvider].settings.maxInputTokens;
-    context = trimTokens(context, max_context_length, "gpt-4o");
+    const provider = runtime.modelProvider;
+    const model = models[provider].model[modelClass];
+    const max_context_length = models[provider].settings.maxInputTokens;
+
+    const tokenizationService = runtime.getService<ITokenizationService>(
+        ServiceType.TOKENIZATION
+    );
+
+    context = await tokenizationService.trimTokens(
+        context,
+        max_context_length,
+        model
+    );
     let retryLength = 1000; // exponential backoff
     while (true) {
         try {
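Note: the hunks above resolve a TOKENIZATION service from the runtime, but the service implementation itself is not part of this diff; only the ITokenizationService import is visible. As a rough sketch of what such a service's trimTokens might look like, assuming a js-tiktoken backend and the same keep-the-tail truncation the old trimTokens helper performed (the class name and internals below are illustrative, not the code this patch depends on):

```ts
import { encodingForModel, TiktokenModel } from "js-tiktoken";

// Illustrative sketch only; the real service behind ITokenizationService
// lives outside this patch.
class TokenizationService {
    async trimTokens(
        context: string,
        maxTokens: number,
        model: string
    ): Promise<string> {
        // encodingForModel expects a known tiktoken model name; a real
        // implementation would likely fall back to a default encoding
        // for non-OpenAI models.
        const encoding = encodingForModel(model as TiktokenModel);
        const tokens = encoding.encode(context);
        if (tokens.length <= maxTokens) {
            return context;
        }
        // Drop tokens from the front so the most recent context survives.
        return encoding.decode(tokens.slice(-maxTokens));
    }
}
```

The behavioral change in these hunks is the third argument: token counts were previously computed against a hard-coded "gpt-4o" tokenizer, and are now computed against the provider's actual model for the requested modelClass.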
@@ -966,33 +985,35 @@ export const generateImage = async (
     });
 
     const apiKey =
-        runtime.imageModelProvider === runtime.modelProvider
-            ? runtime.token
-            : (() => {
-                  // First try to match the specific provider
-                  switch (runtime.imageModelProvider) {
-                      case ModelProviderName.HEURIST:
-                          return runtime.getSetting("HEURIST_API_KEY");
-                      case ModelProviderName.TOGETHER:
-                          return runtime.getSetting("TOGETHER_API_KEY");
-                      case ModelProviderName.FAL:
-                          return runtime.getSetting("FAL_API_KEY");
-                      case ModelProviderName.OPENAI:
-                          return runtime.getSetting("OPENAI_API_KEY");
-                      case ModelProviderName.VENICE:
-                          return runtime.getSetting("VENICE_API_KEY");
-                      case ModelProviderName.LIVEPEER:
-                          return runtime.getSetting("LIVEPEER_GATEWAY_URL");
-                      default:
-                          // If no specific match, try the fallback chain
-                          return (runtime.getSetting("HEURIST_API_KEY") ??
-                              runtime.getSetting("TOGETHER_API_KEY") ??
-                              runtime.getSetting("FAL_API_KEY") ??
-                              runtime.getSetting("OPENAI_API_KEY") ??
-                              runtime.getSetting("VENICE_API_KEY"))??
-                              runtime.getSetting("LIVEPEER_GATEWAY_URL");
-                  }
-              })();
+        runtime.imageModelProvider === runtime.modelProvider
+            ? runtime.token
+            : (() => {
+                  // First try to match the specific provider
+                  switch (runtime.imageModelProvider) {
+                      case ModelProviderName.HEURIST:
+                          return runtime.getSetting("HEURIST_API_KEY");
+                      case ModelProviderName.TOGETHER:
+                          return runtime.getSetting("TOGETHER_API_KEY");
+                      case ModelProviderName.FAL:
+                          return runtime.getSetting("FAL_API_KEY");
+                      case ModelProviderName.OPENAI:
+                          return runtime.getSetting("OPENAI_API_KEY");
+                      case ModelProviderName.VENICE:
+                          return runtime.getSetting("VENICE_API_KEY");
+                      case ModelProviderName.LIVEPEER:
+                          return runtime.getSetting("LIVEPEER_GATEWAY_URL");
+                      default:
+                          // If no specific match, try the fallback chain
+                          return (
+                              runtime.getSetting("HEURIST_API_KEY") ??
+                              runtime.getSetting("TOGETHER_API_KEY") ??
+                              runtime.getSetting("FAL_API_KEY") ??
+                              runtime.getSetting("OPENAI_API_KEY") ??
+                              runtime.getSetting("VENICE_API_KEY") ??
+                              runtime.getSetting("LIVEPEER_GATEWAY_URL")
+                          );
+                  }
+              })();
     try {
         if (runtime.imageModelProvider === ModelProviderName.HEURIST) {
             const response = await fetch(
@@ -1182,28 +1203,31 @@ export const generateImage = async (
         });
 
         return { success: true, data: base64s };
-
     } else if (runtime.imageModelProvider === ModelProviderName.LIVEPEER) {
         if (!apiKey) {
             throw new Error("Livepeer Gateway is not defined");
         }
         try {
             const baseUrl = new URL(apiKey);
-            if (!baseUrl.protocol.startsWith('http')) {
+            if (!baseUrl.protocol.startsWith("http")) {
                 throw new Error("Invalid Livepeer Gateway URL protocol");
             }
-            const response = await fetch(`${baseUrl.toString()}text-to-image`, {
-                method: "POST",
-                headers: {
-                    "Content-Type": "application/json"
-                },
-                body: JSON.stringify({
-                    model_id: data.modelId || "ByteDance/SDXL-Lightning",
-                    prompt: data.prompt,
-                    width: data.width || 1024,
-                    height: data.height || 1024
-                })
-            });
+            const response = await fetch(
+                `${baseUrl.toString()}text-to-image`,
+                {
+                    method: "POST",
+                    headers: {
+                        "Content-Type": "application/json",
+                    },
+                    body: JSON.stringify({
+                        model_id:
+                            data.modelId || "ByteDance/SDXL-Lightning",
+                        prompt: data.prompt,
+                        width: data.width || 1024,
+                        height: data.height || 1024,
+                    }),
+                }
+            );
             const result = await response.json();
             if (!result.images?.length) {
                 throw new Error("No images generated");
@@ -1225,19 +1249,19 @@ export const generateImage = async (
                     }
                     const blob = await imageResponse.blob();
                     const arrayBuffer = await blob.arrayBuffer();
-                    const base64 = Buffer.from(arrayBuffer).toString("base64");
+                    const base64 =
+                        Buffer.from(arrayBuffer).toString("base64");
                     return `data:image/jpeg;base64,${base64}`;
                 })
             );
             return {
                 success: true,
-                data: base64Images
+                data: base64Images,
             };
         } catch (error) {
             console.error(error);
             return { success: false, error: error };
         }
-
     } else {
         let targetSize = `${data.width}x${data.height}`;
         if (
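The apiKey and Livepeer hunks above are formatting cleanups (including untangling the misplaced closing parenthesis in the ?? fallback chain); behavior is unchanged. For reference, here is the request the reformatted Livepeer branch issues, as a standalone sketch. The gateway URL is a placeholder assumption; at runtime it comes from the LIVEPEER_GATEWAY_URL setting.

```ts
// Placeholder gateway address for illustration; generateImage reads the
// real one from LIVEPEER_GATEWAY_URL.
const gateway = new URL("http://localhost:8935/");

const response = await fetch(`${gateway.toString()}text-to-image`, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({
        model_id: "ByteDance/SDXL-Lightning", // the branch's default model
        prompt: "an astronaut riding a horse",
        width: 1024,
        height: 1024,
    }),
});
const result = await response.json();
// The surrounding code then fetches each result.images[i].url and returns
// the images as base64 data URIs, as the hunk just above shows.
```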
@@ -1383,10 +1407,7 @@ export const generateObject = async ({
     }
 
     const provider = runtime.modelProvider;
-    const model = models[provider].model[modelClass] as TiktokenModel;
-    if (!model) {
-        throw new Error(`Unsupported model class: ${modelClass}`);
-    }
+    const model = models[provider].model[modelClass];
     const temperature = models[provider].settings.temperature;
     const frequency_penalty = models[provider].settings.frequency_penalty;
     const presence_penalty = models[provider].settings.presence_penalty;
@@ -1393,7 +1414,15 @@ export const generateObject = async ({
     const max_context_length = models[provider].settings.maxInputTokens;
     const apiKey = runtime.token;
     try {
-        context = trimTokens(context, max_context_length, model);
+        const tokenizationService = runtime.getService<ITokenizationService>(
+            ServiceType.TOKENIZATION
+        );
+
+        context = await tokenizationService.trimTokens(
+            context,
+            max_context_length,
+            model
+        );
 
         const modelOptions: ModelSettings = {
             prompt: context,
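One caveat with the new call pattern: runtime.getService typically returns null when no service is registered for the requested ServiceType, and every call site in this patch dereferences the result unconditionally. A defensive variant, sketched below; the guard and error message are illustrative additions, not part of the patch.

```ts
const tokenizationService = runtime.getService<ITokenizationService>(
    ServiceType.TOKENIZATION
);
if (!tokenizationService) {
    // Fail with a descriptive error instead of a TypeError on .trimTokens.
    throw new Error("TOKENIZATION service is not registered on the runtime");
}
context = await tokenizationService.trimTokens(
    context,
    max_context_length,
    model
);
```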