feat: Add custom fetch logic for agent #1010

Merged · 2 commits · Dec 12, 2024
agent/src/index.ts (7 additions, 0 deletions)

@@ -60,6 +60,12 @@ export const wait = (minTime: number = 1000, maxTime: number = 3000) => {
     return new Promise((resolve) => setTimeout(resolve, waitTime));
 };
 
+const logFetch = async (url: string, options: any) => {
+    elizaLogger.info(`Fetching ${url}`);
+    elizaLogger.info(options);
+    return fetch(url, options);
+};
+
 export function parseArguments(): {
     character?: string;
     characters?: string;
@@ -473,6 +479,7 @@ export async function createAgent(
         services: [],
         managers: [],
         cacheManager: cache,
+        fetch: logFetch,
     });
 }
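
One caveat worth noting: logFetch logs the options object verbatim, and for most providers options.headers carries the API key in an Authorization header. A minimal sketch of a safer variant, assuming the same elizaLogger import; timedFetch and its header-stripping are illustrative additions, not part of this PR:

const timedFetch = async (
    url: RequestInfo | URL,
    options?: RequestInit
): Promise<Response> => {
    const target = url instanceof Request ? url.url : url.toString();
    const started = Date.now();
    // Log everything except headers so API keys never reach the logs.
    const { headers: _headers, ...safeOptions } = options ?? {};
    elizaLogger.info(`Fetching ${target}`);
    elizaLogger.info(safeOptions);
    const response = await fetch(url, options);
    elizaLogger.info(
        `Fetched ${target} -> ${response.status} (${Date.now() - started}ms)`
    );
    return response;
};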
packages/core/src/generation.ts (95 additions, 43 deletions)

@@ -78,47 +78,68 @@ export async function generateText({

     // allow character.json settings => secrets to override models
     // FIXME: add MODEL_MEDIUM support
-    switch(provider) {
+    switch (provider) {
         // if runtime.getSetting("LLAMACLOUD_MODEL_LARGE") is true and modelProvider is LLAMACLOUD, then use the large model
-        case ModelProviderName.LLAMACLOUD: {
-            switch(modelClass) {
-                case ModelClass.LARGE: {
-                    model = runtime.getSetting("LLAMACLOUD_MODEL_LARGE") || model;
-                }
-                break;
-                case ModelClass.SMALL: {
-                    model = runtime.getSetting("LLAMACLOUD_MODEL_SMALL") || model;
-                }
-                break;
-            }
-        }
-        break;
+        case ModelProviderName.LLAMACLOUD:
+            {
+                switch (modelClass) {
+                    case ModelClass.LARGE:
+                        {
+                            model =
+                                runtime.getSetting("LLAMACLOUD_MODEL_LARGE") ||
+                                model;
+                        }
+                        break;
+                    case ModelClass.SMALL:
+                        {
+                            model =
+                                runtime.getSetting("LLAMACLOUD_MODEL_SMALL") ||
+                                model;
+                        }
+                        break;
+                }
+            }
+            break;
-        case ModelProviderName.TOGETHER: {
-            switch(modelClass) {
-                case ModelClass.LARGE: {
-                    model = runtime.getSetting("TOGETHER_MODEL_LARGE") || model;
-                }
-                break;
-                case ModelClass.SMALL: {
-                    model = runtime.getSetting("TOGETHER_MODEL_SMALL") || model;
-                }
-                break;
-            }
-        }
-        break;
+        case ModelProviderName.TOGETHER:
+            {
+                switch (modelClass) {
+                    case ModelClass.LARGE:
+                        {
+                            model =
+                                runtime.getSetting("TOGETHER_MODEL_LARGE") ||
+                                model;
+                        }
+                        break;
+                    case ModelClass.SMALL:
+                        {
+                            model =
+                                runtime.getSetting("TOGETHER_MODEL_SMALL") ||
+                                model;
+                        }
+                        break;
+                }
+            }
+            break;
-        case ModelProviderName.OPENROUTER: {
-            switch(modelClass) {
-                case ModelClass.LARGE: {
-                    model = runtime.getSetting("LARGE_OPENROUTER_MODEL") || model;
-                }
-                break;
-                case ModelClass.SMALL: {
-                    model = runtime.getSetting("SMALL_OPENROUTER_MODEL") || model;
-                }
-                break;
-            }
-        }
-        break;
+        case ModelProviderName.OPENROUTER:
+            {
+                switch (modelClass) {
+                    case ModelClass.LARGE:
+                        {
+                            model =
+                                runtime.getSetting("LARGE_OPENROUTER_MODEL") ||
+                                model;
+                        }
+                        break;
+                    case ModelClass.SMALL:
+                        {
+                            model =
+                                runtime.getSetting("SMALL_OPENROUTER_MODEL") ||
+                                model;
+                        }
+                        break;
+                }
+            }
+            break;
     }
 
     elizaLogger.info("Selected model:", model);
@@ -155,7 +176,11 @@ export async function generateText({
         case ModelProviderName.HYPERBOLIC:
         case ModelProviderName.TOGETHER: {
             elizaLogger.debug("Initializing OpenAI model.");
-            const openai = createOpenAI({ apiKey, baseURL: endpoint });
+            const openai = createOpenAI({
+                apiKey,
+                baseURL: endpoint,
+                fetch: runtime.fetch,
+            });
 
             const { text: openaiResponse } = await aiGenerateText({
                 model: openai.languageModel(model),
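
The pattern repeated in each provider case below is identical: the AI SDK client factories (createOpenAI, createAnthropic, createGroq, createOllama, createGoogleGenerativeAI) accept an optional fetch, so passing runtime.fetch lets the agent's wrapper observe every model request, and when it is unset the SDK falls back to the global fetch. Because the runtime field is typed typeof fetch | null (see types.ts below), a strictly typed caller might normalize null away; a sketch, not from the PR:

const openai = createOpenAI({
    apiKey,
    baseURL: endpoint,
    // Normalize null to undefined so an optional fetch parameter is satisfied.
    fetch: runtime.fetch ?? undefined,
});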
@@ -176,7 +201,9 @@ export async function generateText({
         }
 
         case ModelProviderName.GOOGLE: {
-            const google = createGoogleGenerativeAI();
+            const google = createGoogleGenerativeAI({
+                fetch: runtime.fetch,
+            });
 
             const { text: googleResponse } = await aiGenerateText({
                 model: google(model),
@@ -199,7 +226,10 @@ export async function generateText({
         case ModelProviderName.ANTHROPIC: {
             elizaLogger.debug("Initializing Anthropic model.");
 
-            const anthropic = createAnthropic({ apiKey });
+            const anthropic = createAnthropic({
+                apiKey,
+                fetch: runtime.fetch,
+            });
 
             const { text: anthropicResponse } = await aiGenerateText({
                 model: anthropic.languageModel(model),
@@ -222,7 +252,10 @@ export async function generateText({
         case ModelProviderName.CLAUDE_VERTEX: {
             elizaLogger.debug("Initializing Claude Vertex model.");
 
-            const anthropic = createAnthropic({ apiKey });
+            const anthropic = createAnthropic({
+                apiKey,
+                fetch: runtime.fetch,
+            });
 
             const { text: anthropicResponse } = await aiGenerateText({
                 model: anthropic.languageModel(model),
@@ -246,7 +279,11 @@ export async function generateText({
 
         case ModelProviderName.GROK: {
             elizaLogger.debug("Initializing Grok model.");
-            const grok = createOpenAI({ apiKey, baseURL: endpoint });
+            const grok = createOpenAI({
+                apiKey,
+                baseURL: endpoint,
+                fetch: runtime.fetch,
+            });
 
             const { text: grokResponse } = await aiGenerateText({
                 model: grok.languageModel(model, {
@@ -269,7 +306,7 @@ export async function generateText({
         }
 
         case ModelProviderName.GROQ: {
-            const groq = createGroq({ apiKey });
+            const groq = createGroq({ apiKey, fetch: runtime.fetch });
 
             const { text: groqResponse } = await aiGenerateText({
                 model: groq.languageModel(model),
@@ -316,7 +353,11 @@ export async function generateText({
         case ModelProviderName.REDPILL: {
             elizaLogger.debug("Initializing RedPill model.");
             const serverUrl = models[provider].endpoint;
-            const openai = createOpenAI({ apiKey, baseURL: serverUrl });
+            const openai = createOpenAI({
+                apiKey,
+                baseURL: serverUrl,
+                fetch: runtime.fetch,
+            });
 
             const { text: redpillResponse } = await aiGenerateText({
                 model: openai.languageModel(model),
@@ -339,7 +380,11 @@ export async function generateText({
         case ModelProviderName.OPENROUTER: {
             elizaLogger.debug("Initializing OpenRouter model.");
             const serverUrl = models[provider].endpoint;
-            const openrouter = createOpenAI({ apiKey, baseURL: serverUrl });
+            const openrouter = createOpenAI({
+                apiKey,
+                baseURL: serverUrl,
+                fetch: runtime.fetch,
+            });
 
             const { text: openrouterResponse } = await aiGenerateText({
                 model: openrouter.languageModel(model),
@@ -365,6 +410,7 @@ export async function generateText({
 
             const ollamaProvider = createOllama({
                 baseURL: models[provider].endpoint + "/api",
+                fetch: runtime.fetch,
             });
             const ollama = ollamaProvider(model);
 
@@ -389,6 +435,7 @@ export async function generateText({
             const heurist = createOpenAI({
                 apiKey: apiKey,
                 baseURL: endpoint,
+                fetch: runtime.fetch,
             });
 
             const { text: heuristResponse } = await aiGenerateText({
@@ -434,7 +481,11 @@ export async function generateText({
 
             elizaLogger.debug("Using GAIANET model with baseURL:", baseURL);
 
-            const openai = createOpenAI({ apiKey, baseURL: endpoint });
+            const openai = createOpenAI({
+                apiKey,
+                baseURL: endpoint,
+                fetch: runtime.fetch,
+            });
 
             const { text: openaiResponse } = await aiGenerateText({
                 model: openai.languageModel(model),
@@ -459,6 +510,7 @@ export async function generateText({
             const galadriel = createOpenAI({
                 apiKey: apiKey,
                 baseURL: endpoint,
+                fetch: runtime.fetch,
            });
 
             const { text: galadrielResponse } = await aiGenerateText({
packages/core/src/types.ts (2 additions, 0 deletions)

@@ -992,6 +992,8 @@ export interface IAgentRuntime {
     evaluators: Evaluator[];
     plugins: Plugin[];
 
+    fetch?: typeof fetch | null;
+
     messageManager: IMemoryManager;
     descriptionManager: IMemoryManager;
     documentsManager: IMemoryManager;
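
Since the new field is optional and nullable, code that consumes IAgentRuntime directly should fall back to the global fetch when no custom hook is configured. A minimal usage sketch; fetchWithRuntime is a hypothetical helper, not part of this PR:

async function fetchWithRuntime(
    runtime: IAgentRuntime,
    url: string,
    init?: RequestInit
): Promise<Response> {
    // Use the runtime-supplied hook (e.g. logFetch above) when present.
    const doFetch = runtime.fetch ?? fetch;
    return doFetch(url, init);
}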