use tokenization service to trim tokens
tcm390 committed Dec 31, 2024
1 parent 1d08e17 commit 5687fe8
Showing 1 changed file with 82 additions and 53 deletions.
135 changes: 82 additions & 53 deletions packages/core/src/generation.ts
@@ -34,6 +34,7 @@ import {
     ServiceType,
     SearchResponse,
     ActionResponse,
+    ITokenizationService,
 } from "./types.ts";
 import { fal } from "@fal-ai/client";
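
The ITokenizationService interface lives in ./types.ts and is not shown in this diff. A minimal sketch of the shape implied by the call sites below (member name and signature are inferred from usage, not confirmed by the source):

    // Hypothetical sketch -- inferred from the three call sites in this
    // commit: trimTokens is awaited and takes (text, token budget, model).
    interface ITokenizationService {
        trimTokens(
            context: string, // prompt text to trim
            maxTokens: number, // budget from settings.maxInputTokens
            model: string // model name, used to select a tokenizer
        ): Promise<string>;
    }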

@@ -171,7 +172,15 @@ export async function generateText({
     elizaLogger.debug(
         `Trimming context to max length of ${max_context_length} tokens.`
     );
-    context = await trimTokens(context, max_context_length, "gpt-4o");
+    const tokenizationService = runtime.getService<ITokenizationService>(
+        ServiceType.TOKENIZATION
+    );
+
+    context = await tokenizationService.trimTokens(
+        context,
+        max_context_length,
+        model
+    );

     let response: string;

@@ -905,9 +914,19 @@ export async function generateMessageResponse({
     context: string;
     modelClass: string;
 }): Promise<Content> {
-    const max_context_length =
-        models[runtime.modelProvider].settings.maxInputTokens;
-    context = trimTokens(context, max_context_length, "gpt-4o");
+    const provider = runtime.modelProvider;
+    const model = models[provider].model[modelClass];
+    const max_context_length = models[provider].settings.maxInputTokens;
+
+    const tokenizationService = runtime.getService<ITokenizationService>(
+        ServiceType.TOKENIZATION
+    );
+
+    context = await tokenizationService.trimTokens(
+        context,
+        max_context_length,
+        model
+    );
     let retryLength = 1000; // exponential backoff
     while (true) {
         try {
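
One caveat with the new pattern: if runtime.getService returns null when no service of the requested type is registered (typical for service registries), neither call site in this commit guards against that before dereferencing. A defensive variant of the same call-site code might look like this (the guard is my addition, not part of the commit):

    const tokenizationService = runtime.getService<ITokenizationService>(
        ServiceType.TOKENIZATION
    );
    // Hypothetical guard: fail loudly here instead of with a TypeError
    // when no tokenization service is registered.
    if (!tokenizationService) {
        throw new Error("TokenizationService is not registered");
    }
    context = await tokenizationService.trimTokens(
        context,
        max_context_length,
        model
    );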
@@ -966,33 +985,35 @@ export const generateImage = async (
     });

     const apiKey =
-        runtime.imageModelProvider === runtime.modelProvider
-            ? runtime.token
-            : (() => {
-                // First try to match the specific provider
-                switch (runtime.imageModelProvider) {
-                    case ModelProviderName.HEURIST:
-                        return runtime.getSetting("HEURIST_API_KEY");
-                    case ModelProviderName.TOGETHER:
-                        return runtime.getSetting("TOGETHER_API_KEY");
-                    case ModelProviderName.FAL:
-                        return runtime.getSetting("FAL_API_KEY");
-                    case ModelProviderName.OPENAI:
-                        return runtime.getSetting("OPENAI_API_KEY");
-                    case ModelProviderName.VENICE:
-                        return runtime.getSetting("VENICE_API_KEY");
-                    case ModelProviderName.LIVEPEER:
-                        return runtime.getSetting("LIVEPEER_GATEWAY_URL");
-                    default:
-                        // If no specific match, try the fallback chain
-                        return (runtime.getSetting("HEURIST_API_KEY") ??
-                            runtime.getSetting("TOGETHER_API_KEY") ??
-                            runtime.getSetting("FAL_API_KEY") ??
-                            runtime.getSetting("OPENAI_API_KEY") ??
-                            runtime.getSetting("VENICE_API_KEY")) ??
-                            runtime.getSetting("LIVEPEER_GATEWAY_URL");
-                }
-            })();
+        runtime.imageModelProvider === runtime.modelProvider
+            ? runtime.token
+            : (() => {
+                  // First try to match the specific provider
+                  switch (runtime.imageModelProvider) {
+                      case ModelProviderName.HEURIST:
+                          return runtime.getSetting("HEURIST_API_KEY");
+                      case ModelProviderName.TOGETHER:
+                          return runtime.getSetting("TOGETHER_API_KEY");
+                      case ModelProviderName.FAL:
+                          return runtime.getSetting("FAL_API_KEY");
+                      case ModelProviderName.OPENAI:
+                          return runtime.getSetting("OPENAI_API_KEY");
+                      case ModelProviderName.VENICE:
+                          return runtime.getSetting("VENICE_API_KEY");
+                      case ModelProviderName.LIVEPEER:
+                          return runtime.getSetting("LIVEPEER_GATEWAY_URL");
+                      default:
+                          // If no specific match, try the fallback chain
+                          return (
+                              runtime.getSetting("HEURIST_API_KEY") ??
+                              runtime.getSetting("TOGETHER_API_KEY") ??
+                              runtime.getSetting("FAL_API_KEY") ??
+                              runtime.getSetting("OPENAI_API_KEY") ??
+                              runtime.getSetting("VENICE_API_KEY") ??
+                              runtime.getSetting("LIVEPEER_GATEWAY_URL")
+                          );
+                  }
+              })();
     try {
         if (runtime.imageModelProvider === ModelProviderName.HEURIST) {
             const response = await fetch(
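
Aside from indentation, the only change in the default branch is regrouping the nullish-coalescing chain; the two forms are equivalent because ?? returns the first operand that is neither null nor undefined, left to right, so (a ?? b) ?? c and a ?? (b ?? c) pick the same value. A quick illustration with hypothetical values:

    const a = undefined;
    const b = "together-key";
    const c = "fal-key";
    console.log((a ?? b) ?? c); // "together-key"
    console.log(a ?? (b ?? c)); // "together-key"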
@@ -1182,28 +1203,31 @@ export const generateImage = async (
             });

             return { success: true, data: base64s };
-
         } else if (runtime.imageModelProvider === ModelProviderName.LIVEPEER) {
             if (!apiKey) {
                 throw new Error("Livepeer Gateway is not defined");
             }
             try {
                 const baseUrl = new URL(apiKey);
-                if (!baseUrl.protocol.startsWith('http')) {
+                if (!baseUrl.protocol.startsWith("http")) {
                     throw new Error("Invalid Livepeer Gateway URL protocol");
                 }
-                const response = await fetch(`${baseUrl.toString()}text-to-image`, {
-                    method: "POST",
-                    headers: {
-                        "Content-Type": "application/json"
-                    },
-                    body: JSON.stringify({
-                        model_id: data.modelId || "ByteDance/SDXL-Lightning",
-                        prompt: data.prompt,
-                        width: data.width || 1024,
-                        height: data.height || 1024
-                    })
-                });
+                const response = await fetch(
+                    `${baseUrl.toString()}text-to-image`,
+                    {
+                        method: "POST",
+                        headers: {
+                            "Content-Type": "application/json",
+                        },
+                        body: JSON.stringify({
+                            model_id:
+                                data.modelId || "ByteDance/SDXL-Lightning",
+                            prompt: data.prompt,
+                            width: data.width || 1024,
+                            height: data.height || 1024,
+                        }),
+                    }
+                );
                 const result = await response.json();
                 if (!result.images?.length) {
                     throw new Error("No images generated");
@@ -1225,19 +1249,19 @@ export const generateImage = async (
                         }
                         const blob = await imageResponse.blob();
                         const arrayBuffer = await blob.arrayBuffer();
-                        const base64 = Buffer.from(arrayBuffer).toString("base64");
+                        const base64 =
+                            Buffer.from(arrayBuffer).toString("base64");
                         return `data:image/jpeg;base64,${base64}`;
                     })
                 );
                 return {
                     success: true,
-                    data: base64Images
+                    data: base64Images,
                 };
             } catch (error) {
                 console.error(error);
                 return { success: false, error: error };
             }
-
         } else {
             let targetSize = `${data.width}x${data.height}`;
             if (
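
Isolated from the runtime, the Livepeer branch above boils down to a single HTTP POST. A standalone sketch of the same request (the gateway URL is a placeholder; the endpoint path and payload fields mirror the diff):

    // Hypothetical gateway URL; the real one comes from the
    // LIVEPEER_GATEWAY_URL setting via runtime.getSetting.
    async function livepeerTextToImage(prompt: string): Promise<unknown> {
        const gateway = "https://my-livepeer-gateway.example/";
        const res = await fetch(`${gateway}text-to-image`, {
            method: "POST",
            headers: { "Content-Type": "application/json" },
            body: JSON.stringify({
                model_id: "ByteDance/SDXL-Lightning", // default from the diff
                prompt,
                width: 1024,
                height: 1024,
            }),
        });
        // Per the result.images?.length check above, the response is
        // expected to carry an images array.
        return res.json();
    }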
@@ -1383,10 +1407,7 @@ export const generateObject = async ({
     }

     const provider = runtime.modelProvider;
-    const model = models[provider].model[modelClass] as TiktokenModel;
-    if (!model) {
-        throw new Error(`Unsupported model class: ${modelClass}`);
-    }
+    const model = models[provider].model[modelClass];
     const temperature = models[provider].settings.temperature;
     const frequency_penalty = models[provider].settings.frequency_penalty;
     const presence_penalty = models[provider].settings.presence_penalty;
Expand All @@ -1395,7 +1416,15 @@ export const generateObject = async ({
const apiKey = runtime.token;

try {
context = trimTokens(context, max_context_length, model);
const tokenizationService = runtime.getService<ITokenizationService>(
ServiceType.TOKENIZATION
);

context = await tokenizationService.trimTokens(
context,
max_context_length,
model
);

const modelOptions: ModelSettings = {
prompt: context,
Expand Down
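
For reference, the removed "as TiktokenModel" cast in generateObject hints that trimming was previously tiktoken-based. A sketch of what the service's trimTokens might do internally, assuming a js-tiktoken backend (a guess at the implementation, not code from the service package):

    import { encodingForModel, type TiktokenModel } from "js-tiktoken";

    async function trimTokens(
        context: string,
        maxTokens: number,
        model: string
    ): Promise<string> {
        const encoding = encodingForModel(model as TiktokenModel);
        const tokens = encoding.encode(context);
        if (tokens.length <= maxTokens) return context;
        // Keep the tail of the prompt so the most recent context survives.
        return encoding.decode(tokens.slice(-maxTokens));
    }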
