Merge pull request #1605 from elizaOS/tcm-trimTokens
fix: Support for Non-OpenAI Models in Token Trimming
shakkernerd authored Jan 3, 2025
2 parents be15f56 + 616ca1f commit bf6ef96
Showing 13 changed files with 656 additions and 510 deletions.
4 changes: 4 additions & 0 deletions .env.example
@@ -368,3 +368,7 @@ CRONOSZKEVM_PRIVATE_KEY=

# Fuel Ecosystem (FuelVM)
FUEL_WALLET_PRIVATE_KEY=

# Tokenizer Settings
TOKENIZER_MODEL= # Specify the tokenizer model to be used.
TOKENIZER_TYPE= # Options: tiktoken (for OpenAI models) or auto (AutoTokenizer from Hugging Face for non-OpenAI models). Default: tiktoken.
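For a deployment that serves a non-OpenAI model, the trimmer can be pointed at the matching Hugging Face tokenizer. A minimal sketch of such a configuration (the model id below is illustrative, not part of this commit):

```env
# Hypothetical example: use the Hugging Face AutoTokenizer for a Llama-style model
TOKENIZER_TYPE=auto
TOKENIZER_MODEL=NousResearch/Meta-Llama-3-8B-Instruct
```

Leaving both variables unset keeps the previous behaviour: tiktoken with the "gpt-4o" encoding.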
13 changes: 6 additions & 7 deletions packages/client-discord/src/actions/chat_with_attachments.ts
@@ -191,17 +191,16 @@ const summarizeAction = {

state.attachmentsWithText = attachmentsWithText;
state.objective = objective;

const template = await trimTokens(
summarizationTemplate,
chunkSize + 500,
runtime
);
const context = composeContext({
state,
// make sure it fits, we can pad the tokens a bit
// Get the model's tokenizer based on the current model being used
template: trimTokens(
summarizationTemplate,
chunkSize + 500,
(model.model[ModelClass.SMALL] ||
"gpt-4o-mini") as TiktokenModel // Use the same model as generation; Fallback if no SMALL model configured
),
template,
});

const summary = await generateText({
11 changes: 6 additions & 5 deletions packages/client-discord/src/actions/summarize_conversation.ts
@@ -261,14 +261,15 @@ const summarizeAction = {
const chunk = chunks[i];
state.currentSummary = currentSummary;
state.currentChunk = chunk;
const template = await trimTokens(
summarizationTemplate,
chunkSize + 500,
runtime
);
const context = composeContext({
state,
// make sure it fits, we can pad the tokens a bit
template: trimTokens(
summarizationTemplate,
chunkSize + 500,
"gpt-4o-mini"
),
template,
});

const summary = await generateText({
2 changes: 1 addition & 1 deletion packages/client-discord/src/attachments.ts
@@ -19,7 +19,7 @@ async function generateSummary(
text: string
): Promise<{ title: string; description: string }> {
// make sure text is under 128k characters
text = trimTokens(text, 100000, "gpt-4o-mini"); // TODO: clean this up
text = await trimTokens(text, 100000, runtime);

const prompt = `Please generate a concise summary for the following text:
2 changes: 1 addition & 1 deletion packages/client-discord/src/utils.ts
@@ -47,7 +47,7 @@ export async function generateSummary(
text: string
): Promise<{ title: string; description: string }> {
// make sure text is under 128k characters
text = trimTokens(text, 100000, "gpt-4o-mini"); // TODO: clean this up
text = await trimTokens(text, 100000, runtime);

const prompt = `Please generate a concise summary for the following text:
11 changes: 6 additions & 5 deletions packages/client-slack/src/actions/chat_with_attachments.ts
@@ -200,13 +200,14 @@ const summarizeAction: Action = {
currentState.attachmentsWithText = attachmentsWithText;
currentState.objective = objective;

const template = await trimTokens(
summarizationTemplate,
chunkSize + 500,
runtime
);
const context = composeContext({
state: currentState,
template: trimTokens(
summarizationTemplate,
chunkSize + 500,
"gpt-4o-mini"
),
template,
});

const summary = await generateText({
12 changes: 7 additions & 5 deletions packages/client-slack/src/actions/summarize_conversation.ts
@@ -279,13 +279,15 @@ const summarizeAction: Action = {
currentState.currentSummary = currentSummary;
currentState.currentChunk = chunk;

const template = await trimTokens(
summarizationTemplate,
chunkSize + 500,
runtime
);

const context = composeContext({
state: currentState,
template: trimTokens(
summarizationTemplate,
chunkSize + 500,
"gpt-4o-mini"
),
template,
});

const summary = await generateText({
2 changes: 1 addition & 1 deletion packages/client-slack/src/attachments.ts
@@ -21,7 +21,7 @@ async function generateSummary(
runtime: IAgentRuntime,
text: string
): Promise<{ title: string; description: string }> {
text = trimTokens(text, 100000, "gpt-4o-mini");
text = await trimTokens(text, 100000, runtime);

const prompt = `Please generate a concise summary for the following text:
169 changes: 120 additions & 49 deletions packages/core/src/generation.ts
@@ -14,6 +14,7 @@ import { Buffer } from "buffer";
import { createOllama } from "ollama-ai-provider";
import OpenAI from "openai";
import { encodingForModel, TiktokenModel } from "js-tiktoken";
import { AutoTokenizer } from "@huggingface/transformers";
import Together from "together-ai";
import { ZodSchema } from "zod";
import { elizaLogger } from "./index.ts";
@@ -37,13 +38,122 @@ import {
SearchResponse,
ActionResponse,
TelemetrySettings,
TokenizerType,
} from "./types.ts";
import { fal } from "@fal-ai/client";
import { tavily } from "@tavily/core";

type Tool = CoreTool<any, any>;
type StepResult = AIStepResult<any>;

/**
* Trims the provided text context to a specified token limit using a tokenizer model and type.
*
* The function dynamically determines the truncation method based on the tokenizer settings
* provided by the runtime. If no tokenizer settings are defined, it defaults to using the
* TikToken truncation method with the "gpt-4o" model.
*
* @async
* @function trimTokens
* @param {string} context - The text to be tokenized and trimmed.
* @param {number} maxTokens - The maximum number of tokens allowed after truncation.
* @param {IAgentRuntime} runtime - The runtime interface providing tokenizer settings.
*
* @returns {Promise<string>} A promise that resolves to the trimmed text.
*
* @throws {Error} Throws an error if maxTokens is not a positive number.
*
* @example
* const trimmedText = await trimTokens("This is an example text", 50, runtime);
* console.log(trimmedText); // Output will be a truncated version of the input text.
*/
export async function trimTokens(
context: string,
maxTokens: number,
runtime: IAgentRuntime
) {
if (!context) return "";
if (maxTokens <= 0) throw new Error("maxTokens must be positive");

const tokenizerModel = runtime.getSetting("TOKENIZER_MODEL");
const tokenizerType = runtime.getSetting("TOKENIZER_TYPE");

if (!tokenizerModel || !tokenizerType) {
// Default to TikToken truncation using the "gpt-4o" model if tokenizer settings are not defined
return truncateTiktoken("gpt-4o", context, maxTokens);
}

// Choose the truncation method based on tokenizer type
if (tokenizerType === TokenizerType.Auto) {
return truncateAuto(tokenizerModel, context, maxTokens);
}

if (tokenizerType === TokenizerType.TikToken) {
return truncateTiktoken(
tokenizerModel as TiktokenModel,
context,
maxTokens
);
}

elizaLogger.warn(`Unsupported tokenizer type: ${tokenizerType}`);
return truncateTiktoken("gpt-4o", context, maxTokens);
}

async function truncateAuto(
modelPath: string,
context: string,
maxTokens: number
) {
try {
const tokenizer = await AutoTokenizer.from_pretrained(modelPath);
const tokens = tokenizer.encode(context);

// If already within limits, return unchanged
if (tokens.length <= maxTokens) {
return context;
}

// Keep the most recent tokens by slicing from the end
const truncatedTokens = tokens.slice(-maxTokens);

// Decode back to text - the Hugging Face tokenizer's decode() returns a string directly
return tokenizer.decode(truncatedTokens);
} catch (error) {
elizaLogger.error("Error in trimTokens:", error);
// Return truncated string if tokenization fails
return context.slice(-maxTokens * 4); // Rough estimate of 4 chars per token
}
}

async function truncateTiktoken(
model: TiktokenModel,
context: string,
maxTokens: number
) {
try {
const encoding = encodingForModel(model);

// Encode the text into tokens
const tokens = encoding.encode(context);

// If already within limits, return unchanged
if (tokens.length <= maxTokens) {
return context;
}

// Keep the most recent tokens by slicing from the end
const truncatedTokens = tokens.slice(-maxTokens);

// Decode back to text - js-tiktoken decode() returns a string directly
return encoding.decode(truncatedTokens);
} catch (error) {
elizaLogger.error("Error in trimTokens:", error);
// Return truncated string if tokenization fails
return context.slice(-maxTokens * 4); // Rough estimate of 4 chars per token
}
}

/**
* Send a message to the model for text generation - receive a string back and parse how you'd like
* @param opts - The options for the generateText request.
@@ -187,7 +297,8 @@ export async function generateText({
elizaLogger.debug(
`Trimming context to max length of ${max_context_length} tokens.`
);
context = trimTokens(context, max_context_length, "gpt-4o");

context = await trimTokens(context, max_context_length, runtime);

let response: string;

@@ -653,45 +764,6 @@ export async function generateText({
}
}

/**
* Truncate the context to the maximum length allowed by the model.
* @param context The text to truncate
* @param maxTokens Maximum number of tokens to keep
* @param model The tokenizer model to use
* @returns The truncated text
*/
export function trimTokens(
context: string,
maxTokens: number,
model: TiktokenModel
): string {
if (!context) return "";
if (maxTokens <= 0) throw new Error("maxTokens must be positive");

// Get the tokenizer for the model
const encoding = encodingForModel(model);

try {
// Encode the text into tokens
const tokens = encoding.encode(context);

// If already within limits, return unchanged
if (tokens.length <= maxTokens) {
return context;
}

// Keep the most recent tokens by slicing from the end
const truncatedTokens = tokens.slice(-maxTokens);

// Decode back to text - js-tiktoken decode() returns a string directly
return encoding.decode(truncatedTokens);
} catch (error) {
console.error("Error in trimTokens:", error);
// Return truncated string if tokenization fails
return context.slice(-maxTokens * 4); // Rough estimate of 4 chars per token
}
}

/**
* Sends a message to the model to determine if it should respond to the given context.
* @param opts - The options for the generateText request
@@ -973,9 +1045,10 @@ export async function generateMessageResponse({
context: string;
modelClass: string;
}): Promise<Content> {
const max_context_length =
models[runtime.modelProvider].settings.maxInputTokens;
context = trimTokens(context, max_context_length, "gpt-4o");
const provider = runtime.modelProvider;
const max_context_length = models[provider].settings.maxInputTokens;

context = await trimTokens(context, max_context_length, runtime);
let retryLength = 1000; // exponential backoff
while (true) {
try {
@@ -1443,20 +1516,18 @@ export const generateObject = async ({
}

const provider = runtime.modelProvider;
const model = models[provider].model[modelClass] as TiktokenModel;
if (!model) {
throw new Error(`Unsupported model class: ${modelClass}`);
}
const model = models[provider].model[modelClass];
const temperature = models[provider].settings.temperature;
const frequency_penalty = models[provider].settings.frequency_penalty;
const presence_penalty = models[provider].settings.presence_penalty;
const max_context_length = models[provider].settings.maxInputTokens;
const max_response_length = models[provider].settings.maxOutputTokens;
const experimental_telemetry = models[provider].settings.experimental_telemetry;
const experimental_telemetry =
models[provider].settings.experimental_telemetry;
const apiKey = runtime.token;

try {
context = trimTokens(context, max_context_length, model);
context = await trimTokens(context, max_context_length, runtime);

const modelOptions: ModelSettings = {
prompt: context,
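With this change, callers pass the runtime instead of a hard-coded TiktokenModel and must await the result. A minimal sketch of the new call pattern, mirroring the summarize actions above (the import path and helper name are illustrative, not taken from this diff):

```typescript
import {
    composeContext,
    generateText,
    trimTokens,
    ModelClass,
    type IAgentRuntime,
    type State,
} from "@elizaos/core"; // illustrative import path; adjust to whatever package exports these symbols in your tree

// Hypothetical helper showing the updated trimTokens call pattern.
async function summarizeChunk(
    runtime: IAgentRuntime,
    state: State,
    template: string,
    chunkSize: number
): Promise<string> {
    // trimTokens is now async and resolves the tokenizer from the runtime's
    // TOKENIZER_TYPE / TOKENIZER_MODEL settings instead of a hard-coded TiktokenModel.
    const trimmedTemplate = await trimTokens(template, chunkSize + 500, runtime);

    const context = composeContext({ state, template: trimmedTemplate });

    return generateText({ runtime, context, modelClass: ModelClass.SMALL });
}
```

Passing the runtime keeps tokenizer selection in one place rather than threading a TiktokenModel through every call site.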
6 changes: 5 additions & 1 deletion packages/core/src/types.ts
@@ -687,7 +687,6 @@ export type Character = {
/** Image model provider to use, if different from modelProvider */
imageModelProvider?: ModelProviderName;


/** Image Vision model provider to use, if different from modelProvider */
imageVisionModelProvider?: ModelProviderName;

Expand Down Expand Up @@ -1319,6 +1318,11 @@ export interface ISlackService extends Service {
client: any;
}

export enum TokenizerType {
Auto = "auto",
TikToken = "tiktoken",
}

export enum TranscriptionProvider {
OpenAI = "openai",
Deepgram = "deepgram",
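The enum values match the accepted TOKENIZER_TYPE strings, so a raw setting can be compared against them directly. A small sketch of that lookup (the import path and helper are illustrative):

```typescript
import { TokenizerType, type IAgentRuntime } from "@elizaos/core"; // illustrative import path

// Mirrors the fallback trimTokens applies when TOKENIZER_TYPE is unset or unrecognized.
function resolveTokenizerType(runtime: IAgentRuntime): TokenizerType {
    const raw = runtime.getSetting("TOKENIZER_TYPE");
    return raw === TokenizerType.Auto ? TokenizerType.Auto : TokenizerType.TikToken;
}
```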
2 changes: 2 additions & 0 deletions packages/core/tsup.config.ts
@@ -19,5 +19,7 @@ export default defineConfig({
"https",
// Add other modules you want to externalize
"@tavily/core",
"onnxruntime-node",
"sharp",
],
});
2 changes: 1 addition & 1 deletion packages/plugin-node/src/services/browser.ts
@@ -13,7 +13,7 @@ async function generateSummary(
text: string
): Promise<{ title: string; description: string }> {
// make sure text is under 128k characters
text = trimTokens(text, 100000, "gpt-4o-mini"); // TODO: clean this up
text = await trimTokens(text, 100000, runtime);

const prompt = `Please generate a concise summary for the following text:
