fix: Support for Non-OpenAI Models in Token Trimming #1605

Merged · 36 commits · Jan 3, 2025

Commits
a1fe889
add tokenization service
tcm390 Dec 31, 2024
0eaae1f
add tokenization service
tcm390 Dec 31, 2024
1d08e17
add tokenization service
tcm390 Dec 31, 2024
5687fe8
use tokenization service to trim tokens
tcm390 Dec 31, 2024
ddb1be5
remove trimtokens function
tcm390 Dec 31, 2024
12c7ad1
remove trimtokens test
tcm390 Dec 31, 2024
c3d188a
use tokenization service to trim token
tcm390 Dec 31, 2024
92077ff
use tokenization service to trim token
tcm390 Dec 31, 2024
9402599
use tokenization service to trim token
tcm390 Dec 31, 2024
b74194c
use tokenization service to trim token
tcm390 Dec 31, 2024
fc638e8
use tokenization service to trim token
tcm390 Dec 31, 2024
af97657
tokenizer settings
tcm390 Dec 31, 2024
61a55c7
Merge branch 'develop' into tcm-trimTokens
tcm390 Dec 31, 2024
d9f56a9
Merge branch 'develop' into tcm-trimTokens
monilpat Dec 31, 2024
b465438
Merge branch 'develop' into tcm-trimTokens
lalalune Jan 1, 2025
f31e9c4
Merge branch 'develop' into tcm-trimTokens
odilitime Jan 2, 2025
34acac4
Merge branch 'develop' into tcm-trimTokens
tcm390 Jan 3, 2025
9ee9f0a
Merge branch 'develop' into tcm-trimTokens
odilitime Jan 3, 2025
c9f0b0a
Merge branch 'develop' into tcm-trimTokens
odilitime Jan 3, 2025
edcac59
chore: pnpm lock file
shakkernerd Jan 3, 2025
7688f1c
Default to TikToken truncation using the gpt-4o-mini model if tokeniz…
tcm390 Jan 3, 2025
7079353
Remove 'model' parameter from trimTokens function; allow tokenizer mo…
tcm390 Jan 3, 2025
e97a948
Remove model parameter
tcm390 Jan 3, 2025
70d90ab
use 4o as default model
tcm390 Jan 3, 2025
12a1acc
Merge branch 'develop' into tcm-trimTokens
tcm390 Jan 3, 2025
be5319a
use elizaLogger
tcm390 Jan 3, 2025
e2dbeb3
move trimTokens to core
tcm390 Jan 3, 2025
5555f85
clean code
tcm390 Jan 3, 2025
3bd3622
restore test
tcm390 Jan 3, 2025
7c45e2a
clean code
tcm390 Jan 3, 2025
2365b54
remove tokenizer service
tcm390 Jan 3, 2025
930f91e
fall back if unsupported type
tcm390 Jan 3, 2025
8afb612
Merge branch 'develop' into tcm-trimTokens
tcm390 Jan 3, 2025
94caa7e
Move encoding into try block to handle potential errors during model …
tcm390 Jan 3, 2025
226cd65
feat: add JsDoc to trimTokens function
shakkernerd Jan 3, 2025
616ca1f
feat: add validation to trimTokens
shakkernerd Jan 3, 2025
4 changes: 4 additions & 0 deletions .env.example
@@ -368,3 +368,7 @@ CRONOSZKEVM_PRIVATE_KEY=

# Fuel Ecosystem (FuelVM)
FUEL_WALLET_PRIVATE_KEY=

# Tokenizer Settings
TOKENIZER_MODEL= # Specify the tokenizer model to be used.
TOKENIZER_TYPE= # Options: tiktoken (for OpenAI models) or auto (AutoTokenizer from Hugging Face for non-OpenAI models). Default: tiktoken.
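
These two settings drive tokenizer selection in the reworked trimTokens (packages/core/src/generation.ts below). A minimal sketch of how they are expected to flow through, using a stub runtime; the import path and the TOKENIZER_MODEL value are placeholders for illustration, not part of this PR:

import { trimTokens, type IAgentRuntime } from "@elizaos/core"; // import path assumed

// Settings mirroring the .env entries above; the model id is a placeholder for
// any Hugging Face repo that ships tokenizer files.
const settings: Record<string, string> = {
    TOKENIZER_TYPE: "auto", // use Hugging Face AutoTokenizer
    TOKENIZER_MODEL: "your-org/your-model",
};

const stubRuntime = {
    getSetting: (key: string) => settings[key] ?? null,
} as unknown as IAgentRuntime;

const longPrompt = "lorem ipsum dolor sit amet ".repeat(500);

// Ask for at most 128 tokens. trimTokens falls back to a character-based cut if
// the tokenizer cannot be loaded, and to js-tiktoken's gpt-4o encoding when the
// two settings are unset.
const trimmed = await trimTokens(longPrompt, 128, stubRuntime);
console.log(trimmed.length < longPrompt.length); // true
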
13 changes: 6 additions & 7 deletions packages/client-discord/src/actions/chat_with_attachments.ts
@@ -191,17 +191,16 @@ const summarizeAction = {

state.attachmentsWithText = attachmentsWithText;
state.objective = objective;

const template = await trimTokens(
summarizationTemplate,
chunkSize + 500,
runtime
);
const context = composeContext({
state,
// make sure it fits, we can pad the tokens a bit
// Get the model's tokenizer based on the current model being used
template: trimTokens(
summarizationTemplate,
chunkSize + 500,
(model.model[ModelClass.SMALL] ||
"gpt-4o-mini") as TiktokenModel // Use the same model as generation; Fallback if no SMALL model configured
),
template,
});

const summary = await generateText({
11 changes: 6 additions & 5 deletions packages/client-discord/src/actions/summarize_conversation.ts
@@ -261,14 +261,15 @@ const summarizeAction = {
const chunk = chunks[i];
state.currentSummary = currentSummary;
state.currentChunk = chunk;
const template = await trimTokens(
summarizationTemplate,
chunkSize + 500,
runtime
);
const context = composeContext({
state,
// make sure it fits, we can pad the tokens a bit
template: trimTokens(
summarizationTemplate,
chunkSize + 500,
"gpt-4o-mini"
),
template,
});

const summary = await generateText({
2 changes: 1 addition & 1 deletion packages/client-discord/src/attachments.ts
@@ -19,7 +19,7 @@ async function generateSummary(
text: string
): Promise<{ title: string; description: string }> {
// make sure text is under 128k characters
text = trimTokens(text, 100000, "gpt-4o-mini"); // TODO: clean this up
text = await trimTokens(text, 100000, runtime);

const prompt = `Please generate a concise summary for the following text:
2 changes: 1 addition & 1 deletion packages/client-discord/src/utils.ts
@@ -47,7 +47,7 @@ export async function generateSummary(
text: string
): Promise<{ title: string; description: string }> {
// make sure text is under 128k characters
text = trimTokens(text, 100000, "gpt-4o-mini"); // TODO: clean this up
text = await trimTokens(text, 100000, runtime);

const prompt = `Please generate a concise summary for the following text:
11 changes: 6 additions & 5 deletions packages/client-slack/src/actions/chat_with_attachments.ts
@@ -200,13 +200,14 @@ const summarizeAction: Action = {
currentState.attachmentsWithText = attachmentsWithText;
currentState.objective = objective;

const template = await trimTokens(
summarizationTemplate,
chunkSize + 500,
runtime
);
const context = composeContext({
state: currentState,
template: trimTokens(
summarizationTemplate,
chunkSize + 500,
"gpt-4o-mini"
),
template,
});

const summary = await generateText({
12 changes: 7 additions & 5 deletions packages/client-slack/src/actions/summarize_conversation.ts
@@ -279,13 +279,15 @@ const summarizeAction: Action = {
currentState.currentSummary = currentSummary;
currentState.currentChunk = chunk;

const template = await trimTokens(
summarizationTemplate,
chunkSize + 500,
runtime
);

const context = composeContext({
state: currentState,
template: trimTokens(
summarizationTemplate,
chunkSize + 500,
"gpt-4o-mini"
),
template,
});

const summary = await generateText({
2 changes: 1 addition & 1 deletion packages/client-slack/src/attachments.ts
@@ -21,7 +21,7 @@ async function generateSummary(
runtime: IAgentRuntime,
text: string
): Promise<{ title: string; description: string }> {
text = trimTokens(text, 100000, "gpt-4o-mini");
text = await trimTokens(text, 100000, runtime);

const prompt = `Please generate a concise summary for the following text:
169 changes: 120 additions & 49 deletions packages/core/src/generation.ts
@@ -14,6 +14,7 @@ import { Buffer } from "buffer";
import { createOllama } from "ollama-ai-provider";
import OpenAI from "openai";
import { encodingForModel, TiktokenModel } from "js-tiktoken";
import { AutoTokenizer } from "@huggingface/transformers";
import Together from "together-ai";
import { ZodSchema } from "zod";
import { elizaLogger } from "./index.ts";
@@ -37,13 +38,122 @@ import {
SearchResponse,
ActionResponse,
TelemetrySettings,
TokenizerType,
} from "./types.ts";
import { fal } from "@fal-ai/client";
import { tavily } from "@tavily/core";

type Tool = CoreTool<any, any>;
type StepResult = AIStepResult<any>;

/**
* Trims the provided text context to a specified token limit using a tokenizer model and type.
*
* The function dynamically determines the truncation method based on the tokenizer settings
* provided by the runtime. If no tokenizer settings are defined, it defaults to using the
* TikToken truncation method with the "gpt-4o" model.
*
* @async
* @function trimTokens
* @param {string} context - The text to be tokenized and trimmed.
* @param {number} maxTokens - The maximum number of tokens allowed after truncation.
* @param {IAgentRuntime} runtime - The runtime interface providing tokenizer settings.
*
* @returns {Promise<string>} A promise that resolves to the trimmed text.
*
* @throws {Error} Throws an error if the runtime settings are invalid or missing required fields.
*
* @example
* const trimmedText = await trimTokens("This is an example text", 50, runtime);
* console.log(trimmedText); // Output will be a truncated version of the input text.
*/
export async function trimTokens(
context: string,
maxTokens: number,
runtime: IAgentRuntime
) {
if (!context) return "";
if (maxTokens <= 0) throw new Error("maxTokens must be positive");

const tokenizerModel = runtime.getSetting("TOKENIZER_MODEL");
const tokenizerType = runtime.getSetting("TOKENIZER_TYPE");

if (!tokenizerModel || !tokenizerType) {
// Default to TikToken truncation using the "gpt-4o" model if tokenizer settings are not defined
return truncateTiktoken("gpt-4o", context, maxTokens);
}

// Choose the truncation method based on tokenizer type
if (tokenizerType === TokenizerType.Auto) {
return truncateAuto(tokenizerModel, context, maxTokens);
}

if (tokenizerType === TokenizerType.TikToken) {
return truncateTiktoken(
tokenizerModel as TiktokenModel,
context,
maxTokens
);
}

elizaLogger.warn(`Unsupported tokenizer type: ${tokenizerType}`);
return truncateTiktoken("gpt-4o", context, maxTokens);
}

async function truncateAuto(
modelPath: string,
context: string,
maxTokens: number
) {
try {
const tokenizer = await AutoTokenizer.from_pretrained(modelPath);
const tokens = tokenizer.encode(context);

// If already within limits, return unchanged
if (tokens.length <= maxTokens) {
return context;
}

// Keep the most recent tokens by slicing from the end
const truncatedTokens = tokens.slice(-maxTokens);

// Decode back to text using the Hugging Face tokenizer's decode()
return tokenizer.decode(truncatedTokens);
} catch (error) {
elizaLogger.error("Error in trimTokens:", error);
// Return truncated string if tokenization fails
return context.slice(-maxTokens * 4); // Rough estimate of 4 chars per token
}
}

async function truncateTiktoken(
model: TiktokenModel,
context: string,
maxTokens: number
) {
try {
const encoding = encodingForModel(model);

// Encode the text into tokens
const tokens = encoding.encode(context);

// If already within limits, return unchanged
if (tokens.length <= maxTokens) {
return context;
}

// Keep the most recent tokens by slicing from the end
const truncatedTokens = tokens.slice(-maxTokens);

// Decode back to text - js-tiktoken decode() returns a string directly
return encoding.decode(truncatedTokens);
} catch (error) {
elizaLogger.error("Error in trimTokens:", error);
// Return truncated string if tokenization fails
return context.slice(-maxTokens * 4); // Rough estimate of 4 chars per token
}
}

/**
* Send a message to the model for a text generateText - receive a string back and parse how you'd like
* @param opts - The options for the generateText request.
@@ -187,7 +297,8 @@ export async function generateText({
elizaLogger.debug(
`Trimming context to max length of ${max_context_length} tokens.`
);
context = trimTokens(context, max_context_length, "gpt-4o");

context = await trimTokens(context, max_context_length, runtime);

let response: string;

@@ -653,45 +764,6 @@ export async function generateText({
}
}

/**
* Truncate the context to the maximum length allowed by the model.
* @param context The text to truncate
* @param maxTokens Maximum number of tokens to keep
* @param model The tokenizer model to use
* @returns The truncated text
*/
export function trimTokens(
context: string,
maxTokens: number,
model: TiktokenModel
): string {
if (!context) return "";
if (maxTokens <= 0) throw new Error("maxTokens must be positive");

// Get the tokenizer for the model
const encoding = encodingForModel(model);

try {
// Encode the text into tokens
const tokens = encoding.encode(context);

// If already within limits, return unchanged
if (tokens.length <= maxTokens) {
return context;
}

// Keep the most recent tokens by slicing from the end
const truncatedTokens = tokens.slice(-maxTokens);

// Decode back to text - js-tiktoken decode() returns a string directly
return encoding.decode(truncatedTokens);
} catch (error) {
console.error("Error in trimTokens:", error);
// Return truncated string if tokenization fails
return context.slice(-maxTokens * 4); // Rough estimate of 4 chars per token
}
}

/**
* Sends a message to the model to determine if it should respond to the given context.
* @param opts - The options for the generateText request
@@ -973,9 +1045,10 @@ export async function generateMessageResponse({
context: string;
modelClass: string;
}): Promise<Content> {
const max_context_length =
models[runtime.modelProvider].settings.maxInputTokens;
context = trimTokens(context, max_context_length, "gpt-4o");
const provider = runtime.modelProvider;
const max_context_length = models[provider].settings.maxInputTokens;

context = await trimTokens(context, max_context_length, runtime);
let retryLength = 1000; // exponential backoff
while (true) {
try {
@@ -1443,20 +1516,18 @@ export const generateObject = async ({
}

const provider = runtime.modelProvider;
const model = models[provider].model[modelClass] as TiktokenModel;
if (!model) {
throw new Error(`Unsupported model class: ${modelClass}`);
}
const model = models[provider].model[modelClass];
const temperature = models[provider].settings.temperature;
const frequency_penalty = models[provider].settings.frequency_penalty;
const presence_penalty = models[provider].settings.presence_penalty;
const max_context_length = models[provider].settings.maxInputTokens;
const max_response_length = models[provider].settings.maxOutputTokens;
const experimental_telemetry = models[provider].settings.experimental_telemetry;
const experimental_telemetry =
models[provider].settings.experimental_telemetry;
const apiKey = runtime.token;

try {
context = trimTokens(context, max_context_length, model);
context = await trimTokens(context, max_context_length, runtime);

const modelOptions: ModelSettings = {
prompt: context,
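
For downstream callers (the client-discord, client-slack, and plugin-node changes in this PR), the migration is mechanical: drop the TiktokenModel argument, pass the runtime, and await the result. A sketch under those assumptions; buildContext is a hypothetical helper, not code from this PR:

import { trimTokens, type IAgentRuntime } from "@elizaos/core"; // import path assumed

// Before (removed by this PR): synchronous, hard-coded tokenizer model
//   context = trimTokens(context, maxInputTokens, "gpt-4o");
// After: async, tokenizer resolved from TOKENIZER_TYPE / TOKENIZER_MODEL,
// defaulting to js-tiktoken's gpt-4o encoding when those settings are unset.
async function buildContext(
    runtime: IAgentRuntime,
    template: string,
    maxInputTokens: number
): Promise<string> {
    return trimTokens(template, maxInputTokens, runtime);
}
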
6 changes: 5 additions & 1 deletion packages/core/src/types.ts
@@ -687,7 +687,6 @@ export type Character = {
/** Image model provider to use, if different from modelProvider */
imageModelProvider?: ModelProviderName;


/** Image Vision model provider to use, if different from modelProvider */
imageVisionModelProvider?: ModelProviderName;

@@ -1319,6 +1318,11 @@ export interface ISlackService extends Service {
client: any;
}

export enum TokenizerType {
Auto = "auto",
TikToken = "tiktoken",
}

export enum TranscriptionProvider {
OpenAI = "openai",
Deepgram = "deepgram",
2 changes: 2 additions & 0 deletions packages/core/tsup.config.ts
@@ -19,5 +19,7 @@ export default defineConfig({
"https",
// Add other modules you want to externalize
"@tavily/core",
"onnxruntime-node",
"sharp",
],
});
2 changes: 1 addition & 1 deletion packages/plugin-node/src/services/browser.ts
@@ -13,7 +13,7 @@ async function generateSummary(
text: string
): Promise<{ title: string; description: string }> {
// make sure text is under 128k characters
text = trimTokens(text, 100000, "gpt-4o-mini"); // TODO: clean this up
text = await trimTokens(text, 100000, runtime);

const prompt = `Please generate a concise summary for the following text: