feat: truncate embedding tokens (run-llama#918)
Co-authored-by: Alex Yang <himself65@outlook.com>
marcusschiesser and himself65 authored Jun 14, 2024
1 parent a51ed8d commit a44e54f
Showing 24 changed files with 231 additions and 83 deletions.
5 changes: 5 additions & 0 deletions .changeset/giant-buses-breathe.md
@@ -0,0 +1,5 @@
---
"llamaindex": patch
---

Truncate text to embed for OpenAI if it exceeds maxTokens
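
For reference, a minimal sketch of what this changeset means for callers, assuming OPENAI_API_KEY is set; the oversized input string is invented purely for illustration:

import { OpenAIEmbedding } from "llamaindex";

async function main() {
  const embedding = new OpenAIEmbedding({ model: "text-embedding-3-small" });

  // a document far longer than the model's 8192-token limit (limit taken from the diff below)
  const veryLongText = "lorem ipsum ".repeat(20_000);

  // previously the OpenAI API rejected such a request with a context-length error;
  // with this patch the input is truncated to maxTokens before the request is sent
  const vector = await embedding.getTextEmbedding(veryLongText);
  console.log(vector.length); // number of embedding dimensions, e.g. 1536
}

main().catch(console.error);
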
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -91,7 +91,7 @@ jobs:
- cloudflare-worker-agent
- nextjs-agent
- nextjs-edge-runtime
- waku-query-engine
# - waku-query-engine
runs-on: ubuntu-latest
name: Build Core Example (${{ matrix.packages }})
steps:
6 changes: 6 additions & 0 deletions packages/core/e2e/fixtures/embeddings/OpenAIEmbedding.ts
@@ -2,10 +2,12 @@ import {
BaseNode,
SimilarityType,
type BaseEmbedding,
type EmbeddingInfo,
type MessageContentDetail,
} from "llamaindex";

export class OpenAIEmbedding implements BaseEmbedding {
embedInfo?: EmbeddingInfo | undefined;
embedBatchSize = 512;

async getQueryEmbedding(query: MessageContentDetail) {
@@ -36,4 +38,8 @@ export class OpenAIEmbedding implements BaseEmbedding {
nodes.forEach((node) => (node.embedding = [0]));
return nodes;
}

truncateMaxTokens(input: string[]): string[] {
return input;
}
}
1 change: 1 addition & 0 deletions packages/core/package.json
@@ -57,6 +57,7 @@
"portkey-ai": "^0.1.16",
"rake-modified": "^1.0.8",
"string-strip-html": "^13.4.8",
"tiktoken": "^1.0.15",
"unpdf": "^0.10.1",
"wikipedia": "^2.1.2",
"wink-nlp": "^2.3.0"
11 changes: 6 additions & 5 deletions packages/core/src/ChatHistory.ts
@@ -1,4 +1,4 @@
import { globalsHelper } from "./GlobalsHelper.js";
import { tokenizers, type Tokenizer } from "@llamaindex/env";
import type { SummaryPrompt } from "./Prompt.js";
import { defaultSummaryPrompt, messagesToHistoryStr } from "./Prompt.js";
import { OpenAI } from "./llm/openai.js";
@@ -70,8 +70,7 @@ export class SummaryChatHistory extends ChatHistory {
* Tokenizer function that converts text to tokens,
* this is used to calculate the number of tokens in a message.
*/
tokenizer: (text: string) => Uint32Array =
globalsHelper.defaultTokenizer.encode;
tokenizer: Tokenizer;
tokensToSummarize: number;
messages: ChatMessage[];
summaryPrompt: SummaryPrompt;
@@ -89,6 +88,7 @@
"LLM maxTokens is not set. Needed so the summarizer ensures the context window size of the LLM.",
);
}
this.tokenizer = init?.tokenizer ?? tokenizers.tokenizer();
this.tokensToSummarize =
this.llm.metadata.contextWindow - this.llm.metadata.maxTokens;
if (this.tokensToSummarize < this.llm.metadata.contextWindow * 0.25) {
@@ -116,7 +116,8 @@
// remove oldest message until the chat history is short enough for the context window
messagesToSummarize.shift();
} while (
this.tokenizer(promptMessages[0].content).length > this.tokensToSummarize
this.tokenizer.encode(promptMessages[0].content).length >
this.tokensToSummarize
);

const response = await this.llm.chat({
@@ -195,7 +196,7 @@
// get tokens of current request messages and the transient messages
const tokens = requestMessages.reduce(
(count, message) =>
count + this.tokenizer(extractText(message.content)).length,
count + this.tokenizer.encode(extractText(message.content)).length,
0,
);
if (tokens > this.tokensToSummarize) {
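
The hunks above replace the deleted GlobalsHelper encode function with the shared Tokenizer object from @llamaindex/env, which pairs encode and decode. A minimal sketch of that interface as it is used in this commit (the default encoding is presumably cl100k_base, matching the Tokenizers.CL100K_BASE constants used later in this diff):

import { tokenizers } from "@llamaindex/env";

const tokenizer = tokenizers.tokenizer();
const tokens = tokenizer.encode("summarize this chat history");
console.log(tokens.length);            // token count compared against tokensToSummarize
console.log(tokenizer.decode(tokens)); // round-trips back to the original string
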
49 changes: 0 additions & 49 deletions packages/core/src/GlobalsHelper.ts

This file was deleted.

10 changes: 5 additions & 5 deletions packages/core/src/PromptHelper.ts
@@ -1,4 +1,4 @@
import { globalsHelper } from "./GlobalsHelper.js";
import { tokenizers, type Tokenizer } from "@llamaindex/env";
import type { SimplePrompt } from "./Prompt.js";
import { SentenceSplitter } from "./TextSplitter.js";
import {
@@ -34,7 +34,7 @@ export class PromptHelper {
numOutput = DEFAULT_NUM_OUTPUTS;
chunkOverlapRatio = DEFAULT_CHUNK_OVERLAP_RATIO;
chunkSizeLimit?: number;
tokenizer: (text: string) => Uint32Array;
tokenizer: Tokenizer;
separator = " ";

// eslint-disable-next-line max-params
@@ -43,14 +43,14 @@
numOutput = DEFAULT_NUM_OUTPUTS,
chunkOverlapRatio = DEFAULT_CHUNK_OVERLAP_RATIO,
chunkSizeLimit?: number,
tokenizer?: (text: string) => Uint32Array,
tokenizer?: Tokenizer,
separator = " ",
) {
this.contextWindow = contextWindow;
this.numOutput = numOutput;
this.chunkOverlapRatio = chunkOverlapRatio;
this.chunkSizeLimit = chunkSizeLimit;
this.tokenizer = tokenizer || globalsHelper.tokenizer();
this.tokenizer = tokenizer ?? tokenizers.tokenizer();
this.separator = separator;
}

@@ -61,7 +61,7 @@
*/
private getAvailableContextSize(prompt: SimplePrompt) {
const emptyPromptText = getEmptyPromptTxt(prompt);
const promptTokens = this.tokenizer(emptyPromptText);
const promptTokens = this.tokenizer.encode(emptyPromptText);
const numPromptTokens = promptTokens.length;

return this.contextWindow - numPromptTokens - this.numOutput;
23 changes: 9 additions & 14 deletions packages/core/src/TextSplitter.ts
@@ -1,6 +1,5 @@
import { EOL } from "@llamaindex/env";
import { EOL, tokenizers, type Tokenizer } from "@llamaindex/env";
// GitHub translated
import { globalsHelper } from "./GlobalsHelper.js";
import { DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE } from "./constants.js";

class TextSplit {
@@ -69,17 +68,15 @@ export class SentenceSplitter {
public chunkSize: number;
public chunkOverlap: number;

private tokenizer: any;
private tokenizerDecoder: any;
private tokenizer: Tokenizer;
private paragraphSeparator: string;
private chunkingTokenizerFn: (text: string) => string[];
private splitLongSentences: boolean;

constructor(options?: {
chunkSize?: number;
chunkOverlap?: number;
tokenizer?: any;
tokenizerDecoder?: any;
tokenizer?: Tokenizer;
paragraphSeparator?: string;
chunkingTokenizerFn?: (text: string) => string[];
splitLongSentences?: boolean;
@@ -88,7 +85,6 @@ export class SentenceSplitter {
chunkSize = DEFAULT_CHUNK_SIZE,
chunkOverlap = DEFAULT_CHUNK_OVERLAP,
tokenizer = null,
tokenizerDecoder = null,
paragraphSeparator = defaultParagraphSeparator,
chunkingTokenizerFn,
splitLongSentences = false,
@@ -102,9 +98,7 @@
this.chunkSize = chunkSize;
this.chunkOverlap = chunkOverlap;

this.tokenizer = tokenizer ?? globalsHelper.tokenizer();
this.tokenizerDecoder =
tokenizerDecoder ?? globalsHelper.tokenizerDecoder();
this.tokenizer = tokenizer ?? tokenizers.tokenizer();

this.paragraphSeparator = paragraphSeparator;
this.chunkingTokenizerFn = chunkingTokenizerFn ?? defaultSentenceTokenizer;
@@ -115,7 +109,8 @@
// get "effective" chunk size by removing the metadata
let effectiveChunkSize;
if (extraInfoStr != undefined) {
const numExtraTokens = this.tokenizer(`${extraInfoStr}\n\n`).length + 1;
const numExtraTokens =
this.tokenizer.encode(`${extraInfoStr}\n\n`).length + 1;
effectiveChunkSize = this.chunkSize - numExtraTokens;
if (effectiveChunkSize <= 0) {
throw new Error(
@@ -190,19 +185,19 @@
if (!this.splitLongSentences) {
return sentenceSplits.map((split) => ({
text: split,
numTokens: this.tokenizer(split).length,
numTokens: this.tokenizer.encode(split).length,
}));
}

const newSplits: SplitRep[] = [];
for (const split of sentenceSplits) {
const splitTokens = this.tokenizer(split);
const splitTokens = this.tokenizer.encode(split);
const splitLen = splitTokens.length;
if (splitLen <= effectiveChunkSize) {
newSplits.push({ text: split, numTokens: splitLen });
} else {
for (let i = 0; i < splitLen; i += effectiveChunkSize) {
const cur_split = this.tokenizerDecoder(
const cur_split = this.tokenizer.decode(
splitTokens.slice(i, i + effectiveChunkSize),
);
newSplits.push({ text: cur_split, numTokens: effectiveChunkSize });
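
Because the splitter now encodes and decodes with the same Tokenizer, the long-sentence branch above can be illustrated in isolation. A rough sketch of that token-window loop, with a deliberately tiny chunk size (values are illustrative only):

import { tokenizers } from "@llamaindex/env";

const t = tokenizers.tokenizer();
const effectiveChunkSize = 16; // tiny, for demonstration; real values come from chunkSize minus metadata tokens
const tokens = t.encode("a very long run-on sentence ".repeat(40));

// slice the token array into fixed-size windows and decode each window back to
// text, mirroring the splitLongSentences branch above
const pieces: string[] = [];
for (let i = 0; i < tokens.length; i += effectiveChunkSize) {
  pieces.push(t.decode(tokens.slice(i, i + effectiveChunkSize)));
}
console.log(pieces.length);
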
23 changes: 20 additions & 3 deletions packages/core/src/embeddings/OpenAIEmbedding.ts
@@ -1,3 +1,4 @@
import { Tokenizers } from "@llamaindex/env";
import type { ClientOptions as OpenAIClientOptions } from "openai";
import type { AzureOpenAIConfig } from "../llm/azure.js";
import {
@@ -12,20 +13,25 @@ import { BaseEmbedding } from "./types.js";
export const ALL_OPENAI_EMBEDDING_MODELS = {
"text-embedding-ada-002": {
dimensions: 1536,
maxTokens: 8191,
maxTokens: 8192,
tokenizer: Tokenizers.CL100K_BASE,
},
"text-embedding-3-small": {
dimensions: 1536,
dimensionOptions: [512, 1536],
maxTokens: 8191,
maxTokens: 8192,
tokenizer: Tokenizers.CL100K_BASE,
},
"text-embedding-3-large": {
dimensions: 3072,
dimensionOptions: [256, 1024, 3072],
maxTokens: 8191,
maxTokens: 8192,
tokenizer: Tokenizers.CL100K_BASE,
},
};

type ModelKeys = keyof typeof ALL_OPENAI_EMBEDDING_MODELS;

export class OpenAIEmbedding extends BaseEmbedding {
/** embeddding model. defaults to "text-embedding-ada-002" */
model: string;
@@ -65,6 +71,14 @@ export class OpenAIEmbedding extends BaseEmbedding {
this.timeout = init?.timeout ?? 60 * 1000; // Default is 60 seconds
this.additionalSessionOptions = init?.additionalSessionOptions;

// find metadata for model
const key = Object.keys(ALL_OPENAI_EMBEDDING_MODELS).find(
(key) => key === this.model,
) as ModelKeys | undefined;
if (key) {
this.embedInfo = ALL_OPENAI_EMBEDDING_MODELS[key];
}

if (init?.azure || shouldUseAzure()) {
const azureConfig = {
...getAzureConfigFromEnv({
@@ -102,6 +116,9 @@
* @param options
*/
private async getOpenAIEmbedding(input: string[]): Promise<number[][]> {
// TODO: ensure this for every sub class by calling it in the base class
input = this.truncateMaxTokens(input);

const { data } = await this.session.openai.embeddings.create({
model: this.model,
dimensions: this.dimensions, // only sent to OpenAI if set by user
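
With the lookup added above, embedInfo is populated for the three known OpenAI models and left undefined otherwise. A short sketch of the effect (metadata values copied from the table in this hunk; the unknown model name is hypothetical):

import { OpenAIEmbedding } from "llamaindex";

const known = new OpenAIEmbedding({ model: "text-embedding-3-large" });
// known.embedInfo now carries the table entry above:
// { dimensions: 3072, dimensionOptions: [256, 1024, 3072], maxTokens: 8192, tokenizer: Tokenizers.CL100K_BASE }

const unknown = new OpenAIEmbedding({ model: "my-custom-embedding-model" }); // hypothetical name
// no entry in ALL_OPENAI_EMBEDDING_MODELS, so unknown.embedInfo stays undefined
// and truncateMaxTokens() passes inputs through unchanged
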
20 changes: 20 additions & 0 deletions packages/core/src/embeddings/tokenizer.ts
@@ -0,0 +1,20 @@
import { Tokenizers, tokenizers } from "@llamaindex/env";

export function truncateMaxTokens(
tokenizer: Tokenizers,
value: string,
maxTokens: number,
): string {
// the maximum number of tokens per one character is 2 (e.g. 爨)
if (value.length * 2 < maxTokens) return value;
const t = tokenizers.tokenizer(tokenizer);
let tokens = t.encode(value);
if (tokens.length > maxTokens) {
// truncate tokens
tokens = tokens.slice(0, maxTokens);
value = t.decode(tokens);
// if we truncate at an UTF-8 boundary (some characters have more than one token), tiktoken returns a � character - remove it
return value.replace("�", "");
}
return value;
}
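
A usage sketch for the new helper, showing both the cheap length short-circuit and the encode/slice/decode path (inputs and token counts are illustrative):

import { Tokenizers } from "@llamaindex/env";
import { truncateMaxTokens } from "./tokenizer.js"; // relative import, as in types.ts below

// fast path: 11 characters * 2 < 8192, so the string is returned without encoding
const short = truncateMaxTokens(Tokenizers.CL100K_BASE, "hello world", 8192);

// slow path: encoded, sliced to 8 tokens, decoded; a trailing "�" left by cutting
// through a multi-token character is stripped
const clipped = truncateMaxTokens(Tokenizers.CL100K_BASE, "word ".repeat(10_000), 8);
console.log(short === "hello world", clipped); // true, roughly eight tokens' worth of text
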
21 changes: 21 additions & 0 deletions packages/core/src/embeddings/types.ts
@@ -1,16 +1,25 @@
import { type Tokenizers } from "@llamaindex/env";
import type { BaseNode } from "../Node.js";
import { MetadataMode } from "../Node.js";
import type { TransformComponent } from "../ingestion/types.js";
import type { MessageContentDetail } from "../llm/types.js";
import { extractSingleText } from "../llm/utils.js";
import { truncateMaxTokens } from "./tokenizer.js";
import { SimilarityType, similarity } from "./utils.js";

const DEFAULT_EMBED_BATCH_SIZE = 10;

type EmbedFunc<T> = (values: T[]) => Promise<Array<number[]>>;

export type EmbeddingInfo = {
dimensions?: number;
maxTokens?: number;
tokenizer?: Tokenizers;
};

export abstract class BaseEmbedding implements TransformComponent {
embedBatchSize = DEFAULT_EMBED_BATCH_SIZE;
embedInfo?: EmbeddingInfo;

similarity(
embedding1: number[],
@@ -77,6 +86,18 @@ export abstract class BaseEmbedding implements TransformComponent {

return nodes;
}

truncateMaxTokens(input: string[]): string[] {
return input.map((s) => {
// truncate to max tokens
if (!(this.embedInfo?.tokenizer && this.embedInfo?.maxTokens)) return s;
return truncateMaxTokens(
this.embedInfo.tokenizer,
s,
this.embedInfo.maxTokens,
);
});
}
}

export async function batchEmbeddings<T>(
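
For third-party embeddings, setting embedInfo is all that is needed to opt into the base-class truncation; leave it unset and truncateMaxTokens() passes strings through unchanged, as the e2e fixture earlier in this commit does. A hedged sketch of a custom subclass, assuming getTextEmbedding is the abstract member to implement:

import { Tokenizers } from "@llamaindex/env";
import { BaseEmbedding, type EmbeddingInfo } from "llamaindex";

class TinyEmbedding extends BaseEmbedding {
  // opt into truncation: both tokenizer and maxTokens must be set
  embedInfo: EmbeddingInfo = { maxTokens: 512, tokenizer: Tokenizers.CL100K_BASE };

  async getTextEmbedding(text: string): Promise<number[]> {
    const [clipped] = this.truncateMaxTokens([text]);
    // stand-in for a real model call; returns a dummy one-dimensional vector
    return [clipped.length];
  }
}
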
1 change: 0 additions & 1 deletion packages/core/src/index.edge.ts
@@ -1,5 +1,4 @@
export * from "./ChatHistory.js";
export * from "./GlobalsHelper.js";
export * from "./Node.js";
export * from "./OutputParser.js";
export * from "./Prompt.js";