Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Improve knowledge embeddings #472

Merged
merged 1 commit into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,12 @@ const summarizeAction = {
const model = models[runtime.character.settings.model];
const chunkSize = model.settings.maxContextLength - 1000;

const chunks = await splitChunks(formattedMemories, chunkSize, 0);
const chunks = await splitChunks(
formattedMemories,
chunkSize,
"gpt-4o-mini",
0
);

const datestr = new Date().toUTCString().replace(/:/g, "-");

Expand Down
5 changes: 3 additions & 2 deletions packages/client-discord/src/messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -430,13 +430,13 @@ export class MessageManager {
await this.runtime.messageManager.createMemory(memory);
}

let state = (await this.runtime.composeState(userMessage, {
let state = await this.runtime.composeState(userMessage, {
discordClient: this.client,
discordMessage: message,
agentName:
this.runtime.character.name ||
this.client.user?.displayName,
})) as State;
});

if (!canSendMessage(message.channel).canSend) {
return elizaLogger.warn(
Expand Down Expand Up @@ -649,6 +649,7 @@ export class MessageManager {
message: DiscordMessage
): Promise<{ processedContent: string; attachments: Media[] }> {
let processedContent = message.content;

let attachments: Media[] = [];

// Process code blocks in the message content
Expand Down
45 changes: 3 additions & 42 deletions packages/client-github/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,8 @@ import {
AgentRuntime,
Client,
IAgentRuntime,
Content,
Memory,
knowledge,
stringToUuid,
embeddingZeroVector,
splitChunks,
embed,
} from "@ai16z/eliza";
import { validateGithubConfig } from "./enviroment";

Expand Down Expand Up @@ -112,11 +108,8 @@ export class GitHubClient {
relativePath
);

const memory: Memory = {
await knowledge.set(this.runtime, {
id: knowledgeId,
agentId: this.runtime.agentId,
userId: this.runtime.agentId,
roomId: this.runtime.agentId,
content: {
text: content,
hash: contentHash,
Expand All @@ -128,39 +121,7 @@ export class GitHubClient {
owner: this.config.owner,
},
},
embedding: embeddingZeroVector,
};

await this.runtime.documentsManager.createMemory(memory);

// Only split if content exceeds 4000 characters
const fragments =
content.length > 4000
? await splitChunks(content, 2000, 200)
: [content];

for (const fragment of fragments) {
// Skip empty fragments
if (!fragment.trim()) continue;

// Add file path context to the fragment before embedding
const fragmentWithPath = `File: ${relativePath}\n\n${fragment}`;
const embedding = await embed(this.runtime, fragmentWithPath);

await this.runtime.knowledgeManager.createMemory({
// We namespace the knowledge base uuid to avoid id
// collision with the document above.
id: stringToUuid(knowledgeId + fragment),
roomId: this.runtime.agentId,
agentId: this.runtime.agentId,
userId: this.runtime.agentId,
content: {
source: knowledgeId,
text: fragment,
},
embedding,
});
}
});
}
}

Expand Down
26 changes: 15 additions & 11 deletions packages/core/src/generation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -463,34 +463,38 @@ export async function generateShouldRespond({
* Splits content into chunks of specified size with optional overlapping bleed sections
* @param content - The text content to split into chunks
* @param chunkSize - The maximum size of each chunk in tokens
* @param bleed - Number of characters to overlap between chunks (default: 100)
* @param model - The model name to use for tokenization (default: runtime.model)
* @param bleed - Number of characters to overlap between chunks (default: 100)
* @returns Promise resolving to array of text chunks with bleed sections
*/
export async function splitChunks(
content: string,
chunkSize: number,
model: string,
bleed: number = 100
): Promise<string[]> {
const encoding = encoding_for_model("gpt-4o-mini");

const encoding = encoding_for_model(model as TiktokenModel);
const tokens = encoding.encode(content);
const chunks: string[] = [];
const textDecoder = new TextDecoder();

for (let i = 0; i < tokens.length; i += chunkSize) {
const chunk = tokens.slice(i, i + chunkSize);
const decodedChunk = textDecoder.decode(encoding.decode(chunk));
let chunk = tokens.slice(i, i + chunkSize);

// Append bleed characters from the previous chunk
const startBleed = i > 0 ? content.slice(i - bleed, i) : "";
if (i > 0) {
chunk = new Uint32Array([...tokens.slice(i - bleed, i), ...chunk]);
}

// Append bleed characters from the next chunk
const endBleed =
i + chunkSize < tokens.length
? content.slice(i + chunkSize, i + chunkSize + bleed)
: "";
if (i + chunkSize < tokens.length) {
chunk = new Uint32Array([
...chunk,
...tokens.slice(i + chunkSize, i + chunkSize + bleed),
]);
}

chunks.push(startBleed + decodedChunk + endBleed);
chunks.push(textDecoder.decode(encoding.decode(chunk)));
}

return chunks;
Expand Down
1 change: 1 addition & 0 deletions packages/core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ export * from "./parsing.ts";
export * from "./uuid.ts";
export * from "./enviroment.ts";
export * from "./cache.ts";
export { default as knowledge } from "./knowledge.ts";
129 changes: 129 additions & 0 deletions packages/core/src/knowledge.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import { UUID } from "crypto";

import { AgentRuntime } from "./runtime.ts";
import { embed } from "./embedding.ts";
import { Content, ModelClass, type Memory } from "./types.ts";
import { stringToUuid } from "./uuid.ts";
import { embeddingZeroVector } from "./memory.ts";
import { splitChunks } from "./generation.ts";
import { models } from "./models.ts";
import elizaLogger from "./logger.ts";

/**
 * Retrieves knowledge relevant to a message: embeds the normalized query
 * text, finds the top matching knowledge fragments, then returns the full
 * text of the distinct parent documents those fragments came from.
 *
 * @param runtime - The agent runtime providing embedding + memory managers
 * @param message - The incoming message whose content.text is the query
 * @returns Texts of the unique source documents behind the matched fragments
 */
async function get(runtime: AgentRuntime, message: Memory): Promise<string[]> {
    // Normalize the query the same way stored fragments were normalized so
    // embedding similarity compares like with like.
    const processed = preprocess(message.content.text);
    elizaLogger.log(`Querying knowledge for: ${processed}`);
    const embedding = await embed(runtime, processed);
    const fragments = await runtime.knowledgeManager.searchMemoriesByEmbedding(
        embedding,
        {
            roomId: message.agentId,
            agentId: message.agentId,
            count: 3,
            match_threshold: 0.1,
        }
    );

    const uniqueSources = [
        ...new Set(
            fragments.map((memory) => {
                elizaLogger.log(
                    // Fix: log the matched fragment's own similarity score;
                    // previously this read `message.similarity`, i.e. the
                    // query Memory, which carries no similarity.
                    `Matched fragment: ${memory.content.text} with similarity: ${memory.similarity}`
                );
                // Each fragment's `source` is the parent document's id.
                return memory.content.source;
            })
        ),
    ];

    const knowledgeDocuments = await Promise.all(
        uniqueSources.map((source) =>
            runtime.documentsManager.getMemoryById(source as UUID)
        )
    );

    // Drop sources whose parent document no longer exists.
    return knowledgeDocuments
        .filter((memory) => memory !== null)
        .map((memory) => memory.content.text);
}

/**
 * A knowledge entry to be stored via `set`: a stable document id plus its
 * content (text and any metadata carried by the Content type).
 */
export type KnowledgeItem = {
    id: UUID;
    content: Content;
};

async function set(runtime: AgentRuntime, item: KnowledgeItem) {
await runtime.documentsManager.createMemory({
embedding: embeddingZeroVector,
id: item.id,
agentId: runtime.agentId,
roomId: runtime.agentId,
userId: runtime.agentId,
createdAt: Date.now(),
content: item.content,
});

const preprocessed = preprocess(item.content.text);
const fragments = await splitChunks(
preprocessed,
10,
models[runtime.character.modelProvider].model?.[ModelClass.EMBEDDING],
5
);

for (const fragment of fragments) {
const embedding = await embed(runtime, fragment);
await runtime.knowledgeManager.createMemory({
// We namespace the knowledge base uuid to avoid id
// collision with the document above.
id: stringToUuid(item.id + fragment),
roomId: runtime.agentId,
agentId: runtime.agentId,
userId: runtime.agentId,
createdAt: Date.now(),
content: {
source: item.id,
text: fragment,
},
embedding,
});
}
}

/**
 * Normalizes raw document text for embedding: strips markdown, code, HTML
 * and mention markup plus all remaining special characters, collapses
 * whitespace to single spaces, and lowercases the result.
 *
 * @param content - Raw text (markdown/HTML/Discord message content)
 * @returns The normalized text; empty string for empty or non-string input
 */
export function preprocess(content: string): string {
    // Guard: previously a null/undefined content threw on `.replace`.
    if (!content || typeof content !== "string") {
        return "";
    }

    return (
        content
            // Remove fenced code blocks and their content
            .replace(/```[\s\S]*?```/g, "")
            // Remove inline code
            .replace(/`.*?`/g, "")
            // Strip markdown header markers, keeping the header text
            .replace(/#{1,6}\s*(.*)/g, "$1")
            // Remove image links but keep alt text
            .replace(/!\[(.*?)\]\(.*?\)/g, "$1")
            // Remove links but keep text
            .replace(/\[(.*?)\]\(.*?\)/g, "$1")
            // Remove Discord user/role mentions. Fix: this ran AFTER the
            // special-character strip below, which had already consumed the
            // `<@!>` characters, making the pattern dead code. It must run
            // before the generic strips.
            .replace(/<@[!&]?\d+>/g, "")
            // Remove HTML tags
            .replace(/<[^>]*>/g, "")
            // Remove horizontal rules
            .replace(/^\s*[-*_]{3,}\s*$/gm, "")
            // Remove block and line comments
            .replace(/\/\*[\s\S]*?\*\//g, "")
            .replace(/\/\/.*/g, "")
            // Collapse all whitespace (including newlines) to single spaces.
            // The old follow-up `\n{3,}` replacement was dead code — no
            // newlines survive this step — so it has been removed.
            .replace(/\s+/g, " ")
            // Strip all remaining special characters
            .replace(/[^a-zA-Z0-9\s]/g, "")
            .trim()
            .toLowerCase()
    );
}

export default {
get,
set,
process,
};
Loading