Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: embedding search for non-openai models #660

Merged
merged 2 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions packages/adapter-sqlite/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ export * from "./sqliteTables.ts";
export * from "./sqlite_vec.ts";

import { DatabaseAdapter, IDatabaseCacheAdapter } from "@ai16z/eliza";
import { embeddingZeroVector } from "@ai16z/eliza";
import {
Account,
Actor,
Expand Down Expand Up @@ -222,7 +221,7 @@ export class SqliteDatabaseAdapter
memory.id ?? v4(),
tableName,
content,
new Float32Array(memory.embedding ?? embeddingZeroVector), // Store as Float32Array
new Float32Array(memory.embedding!), // Store as Float32Array
memory.userId,
memory.roomId,
memory.agentId,
Expand Down
6 changes: 3 additions & 3 deletions packages/client-discord/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { embeddingZeroVector } from "@ai16z/eliza";
import { getEmbeddingZeroVector } from "@ai16z/eliza";
import { Character, Client as ElizaClient, IAgentRuntime } from "@ai16z/eliza";
import { stringToUuid } from "@ai16z/eliza";
import { elizaLogger } from "@ai16z/eliza";
Expand Down Expand Up @@ -189,7 +189,7 @@ export class DiscordClient extends EventEmitter {
},
roomId,
createdAt: Date.now(),
embedding: embeddingZeroVector,
embedding: getEmbeddingZeroVector(this.runtime),
});
}

Expand Down Expand Up @@ -259,7 +259,7 @@ export class DiscordClient extends EventEmitter {
},
roomId,
createdAt: Date.now(),
embedding: embeddingZeroVector,
embedding: getEmbeddingZeroVector(this.runtime),
});
} catch (error) {
console.error("Error creating reaction removal message:", error);
Expand Down
5 changes: 2 additions & 3 deletions packages/client-discord/src/messages.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import { composeContext } from "@ai16z/eliza";
import { generateMessageResponse, generateShouldRespond } from "@ai16z/eliza";
import { embeddingZeroVector } from "@ai16z/eliza";
import {
Content,
HandlerCallback,
Expand All @@ -15,7 +14,7 @@ import {
State,
UUID,
} from "@ai16z/eliza";
import { stringToUuid } from "@ai16z/eliza";
import { stringToUuid, getEmbeddingZeroVector } from "@ai16z/eliza";
import {
ChannelType,
Client,
Expand Down Expand Up @@ -268,7 +267,7 @@ export class MessageManager {
url: m.url,
},
roomId,
embedding: embeddingZeroVector,
embedding: getEmbeddingZeroVector(this.runtime),
createdAt: m.createdTimestamp,
};
memories.push(memory);
Expand Down
28 changes: 13 additions & 15 deletions packages/client-discord/src/voice.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import {
UUID,
composeContext,
elizaLogger,
embeddingZeroVector,
getEmbeddingZeroVector,
generateMessageResponse,
stringToUuid,
generateShouldRespond,
Expand Down Expand Up @@ -500,17 +500,13 @@ export class VoiceManager extends EventEmitter {
}
};

new AudioMonitor(
audioStream,
10000000,
async (buffer) => {
if (!buffer) {
console.error("Received empty buffer");
return;
}
await processBuffer(buffer);
new AudioMonitor(audioStream, 10000000, async (buffer) => {
if (!buffer) {
console.error("Received empty buffer");
return;
}
);
await processBuffer(buffer);
});
}

private async processTranscription(
Expand All @@ -534,7 +530,7 @@ export class VoiceManager extends EventEmitter {

const transcriptionText = await this.runtime
.getService<ITranscriptionService>(ServiceType.TRANSCRIPTION)
.transcribe(wavBuffer);
.transcribe(wavBuffer);

function isValidTranscription(text: string): boolean {
if (!text || text.includes("[BLANK_AUDIO]")) return false;
Expand Down Expand Up @@ -614,7 +610,7 @@ export class VoiceManager extends EventEmitter {
},
userId: userIdUUID,
roomId,
embedding: embeddingZeroVector,
embedding: getEmbeddingZeroVector(this.runtime),
createdAt: Date.now(),
};

Expand Down Expand Up @@ -674,7 +670,7 @@ export class VoiceManager extends EventEmitter {
inReplyTo: memory.id,
},
roomId,
embedding: embeddingZeroVector,
embedding: getEmbeddingZeroVector(this.runtime),
};

if (responseMemory.content.text?.trim()) {
Expand All @@ -684,7 +680,9 @@ export class VoiceManager extends EventEmitter {
state = await this.runtime.updateRecentMessageState(state);

const responseStream = await this.runtime
.getService<ISpeechService>(ServiceType.SPEECH_GENERATION)
.getService<ISpeechService>(
ServiceType.SPEECH_GENERATION
)
.generate(this.runtime, content.text);

if (responseStream) {
Expand Down
3 changes: 1 addition & 2 deletions packages/client-farcaster/src/actions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ export async function sendCast({
cast,
memory: createCastMemory({
roomId,
agentId: runtime.agentId,
userId: runtime.agentId,
runtime,
cast,
}),
}));
Expand Down
3 changes: 1 addition & 2 deletions packages/client-farcaster/src/interactions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,8 @@ export class FarcasterInteractionManager {
if (!castMemory) {
await this.runtime.messageManager.createMemory(
createCastMemory({
agentId: this.runtime.agentId,
roomId: memory.roomId,
userId: memory.userId,
runtime: this.runtime,
cast,
})
);
Expand Down
21 changes: 9 additions & 12 deletions packages/client-farcaster/src/memory.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { isCastAddMessage } from "@farcaster/hub-nodejs";
import {
elizaLogger,
embeddingZeroVector,
getEmbeddingZeroVector,
IAgentRuntime,
stringToUuid,
type Memory,
Expand All @@ -14,29 +14,27 @@ import { FarcasterClient } from "./client";

export function createCastMemory({
roomId,
agentId,
userId,
runtime,
cast,
}: {
roomId: UUID;
agentId: UUID;
userId: UUID;
runtime: IAgentRuntime;
cast: Cast;
}): Memory {
const inReplyTo = cast.message.data.castAddBody.parentCastId
? castUuid({
hash: toHex(cast.message.data.castAddBody.parentCastId.hash),
agentId,
agentId: runtime.agentId,
})
: undefined;

return {
id: castUuid({
hash: cast.id,
agentId,
agentId: runtime.agentId,
}),
agentId,
userId,
agentId: runtime.agentId,
userId: runtime.agentId,
content: {
text: cast.text,
source: "farcaster",
Expand All @@ -45,7 +43,7 @@ export function createCastMemory({
hash: cast.id,
},
roomId,
embedding: embeddingZeroVector,
embedding: getEmbeddingZeroVector(runtime),
createdAt: cast.message.data.timestamp * 1000,
};
}
Expand Down Expand Up @@ -93,8 +91,7 @@ export async function buildConversationThread({
await runtime.messageManager.createMemory(
createCastMemory({
roomId,
agentId: runtime.agentId,
userId,
runtime,
cast: currentCast,
})
);
Expand Down
3 changes: 1 addition & 2 deletions packages/client-farcaster/src/post.ts
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,7 @@ export class FarcasterPostManager {
await this.runtime.messageManager.createMemory(
createCastMemory({
roomId,
userId: this.runtime.agentId,
agentId: this.runtime.agentId,
runtime: this.runtime,
cast,
})
);
Expand Down
6 changes: 3 additions & 3 deletions packages/client-telegram/src/messageManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { Message } from "@telegraf/types";
import { Context, Telegraf } from "telegraf";

import { composeContext, elizaLogger, ServiceType } from "@ai16z/eliza";
import { embeddingZeroVector } from "@ai16z/eliza";
import { getEmbeddingZeroVector } from "@ai16z/eliza";
import {
Content,
HandlerCallback,
Expand Down Expand Up @@ -405,7 +405,7 @@ export class MessageManager {
roomId,
content,
createdAt: message.date * 1000,
embedding: embeddingZeroVector,
embedding: getEmbeddingZeroVector(this.runtime),
};

// Create memory
Expand Down Expand Up @@ -468,7 +468,7 @@ export class MessageManager {
inReplyTo: messageId,
},
createdAt: sentMessage.date * 1000,
embedding: embeddingZeroVector,
embedding: getEmbeddingZeroVector(this.runtime),
};

// Set action to CONTINUE for all messages except the last one
Expand Down
8 changes: 4 additions & 4 deletions packages/client-twitter/src/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {
Memory,
State,
UUID,
embeddingZeroVector,
getEmbeddingZeroVector,
elizaLogger,
stringToUuid,
} from "@ai16z/eliza";
Expand Down Expand Up @@ -420,7 +420,7 @@ export class ClientBase extends EventEmitter {
content: content,
agentId: this.runtime.agentId,
roomId,
embedding: embeddingZeroVector,
embedding: getEmbeddingZeroVector(this.runtime),
createdAt: tweet.timestamp * 1000,
});

Expand Down Expand Up @@ -533,7 +533,7 @@ export class ClientBase extends EventEmitter {
content: content,
agentId: this.runtime.agentId,
roomId,
embedding: embeddingZeroVector,
embedding: getEmbeddingZeroVector(this.runtime),
createdAt: tweet.timestamp * 1000,
});

Expand Down Expand Up @@ -575,7 +575,7 @@ export class ClientBase extends EventEmitter {
} else {
await this.runtime.messageManager.createMemory({
...message,
embedding: embeddingZeroVector,
embedding: getEmbeddingZeroVector(this.runtime),
});
}

Expand Down
33 changes: 20 additions & 13 deletions packages/client-twitter/src/interactions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ import {
State,
stringToUuid,
elizaLogger,
getEmbeddingZeroVector,
} from "@ai16z/eliza";
import { ClientBase } from "./base";
import { buildConversationThread, sendTweet, wait } from "./utils.ts";
import { embeddingZeroVector } from "@ai16z/eliza";

export const twitterMessageHandlerTemplate =
`{{timeline}}
Expand Down Expand Up @@ -130,13 +130,20 @@ export class TwitterInteractionClient {
BigInt(tweet.id) > this.client.lastCheckedTweetId
) {
// Generate the tweetId UUID the same way it's done in handleTweet
const tweetId = stringToUuid(tweet.id + "-" + this.runtime.agentId);
const tweetId = stringToUuid(
tweet.id + "-" + this.runtime.agentId
);

// Check if we've already processed this tweet
const existingResponse = await this.runtime.messageManager.getMemoryById(tweetId);
const existingResponse =
await this.runtime.messageManager.getMemoryById(
tweetId
);

if (existingResponse) {
elizaLogger.log(`Already responded to tweet ${tweet.id}, skipping`);
elizaLogger.log(
`Already responded to tweet ${tweet.id}, skipping`
);
continue;
}
elizaLogger.log("New Tweet found", tweet.permanentUrl);
Expand Down Expand Up @@ -280,10 +287,10 @@ export class TwitterInteractionClient {
url: tweet.permanentUrl,
inReplyTo: tweet.inReplyToStatusId
? stringToUuid(
tweet.inReplyToStatusId +
"-" +
this.runtime.agentId
)
tweet.inReplyToStatusId +
"-" +
this.runtime.agentId
)
: undefined,
},
userId: userIdUUID,
Expand Down Expand Up @@ -447,10 +454,10 @@ export class TwitterInteractionClient {
url: currentTweet.permanentUrl,
inReplyTo: currentTweet.inReplyToStatusId
? stringToUuid(
currentTweet.inReplyToStatusId +
"-" +
this.runtime.agentId
)
currentTweet.inReplyToStatusId +
"-" +
this.runtime.agentId
)
: undefined,
},
createdAt: currentTweet.timestamp * 1000,
Expand All @@ -459,7 +466,7 @@ export class TwitterInteractionClient {
currentTweet.userId === this.twitterUserId
? this.runtime.agentId
: stringToUuid(currentTweet.userId),
embedding: embeddingZeroVector,
embedding: getEmbeddingZeroVector(this.runtime),
});
}

Expand Down
8 changes: 4 additions & 4 deletions packages/client-twitter/src/post.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { Tweet } from "agent-twitter-client";
import {
composeContext,
generateText,
embeddingZeroVector,
getEmbeddingZeroVector,
IAgentRuntime,
ModelClass,
stringToUuid,
Expand Down Expand Up @@ -209,8 +209,8 @@ export class TwitterPostClient {
);
const body = await result.json();
if (!body?.data?.create_tweet?.tweet_results?.result) {
console.error("Error sending tweet; Bad response:", body);
return;
console.error("Error sending tweet; Bad response:", body);
return;
}
const tweetResult = body.data.create_tweet.tweet_results.result;

Expand Down Expand Up @@ -267,7 +267,7 @@ export class TwitterPostClient {
source: "twitter",
},
roomId,
embedding: embeddingZeroVector,
embedding: getEmbeddingZeroVector(this.runtime),
createdAt: tweet.timestamp * 1000,
});
} catch (error) {
Expand Down
Loading