Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Refactor image interface and update to move llama cloud -> together provider #777

Merged
merged 6 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions agent/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ export function getTokenForProvider(
settings.ETERNALAI_API_KEY
);
case ModelProviderName.LLAMACLOUD:
case ModelProviderName.TOGETHER:
return (
character.settings?.secrets?.LLAMACLOUD_API_KEY ||
settings.LLAMACLOUD_API_KEY ||
Expand Down
53 changes: 30 additions & 23 deletions packages/core/src/generation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,25 @@ export async function generateText({

// if runtime.getSetting("LLAMACLOUD_MODEL_LARGE") is true and modelProvider is LLAMACLOUD, then use the large model
if (
runtime.getSetting("LLAMACLOUD_MODEL_LARGE") &&
provider === ModelProviderName.LLAMACLOUD
(runtime.getSetting("LLAMACLOUD_MODEL_LARGE") &&
provider === ModelProviderName.LLAMACLOUD) ||
(runtime.getSetting("TOGETHER_MODEL_LARGE") &&
provider === ModelProviderName.TOGETHER)
) {
model = runtime.getSetting("LLAMACLOUD_MODEL_LARGE");
model =
runtime.getSetting("LLAMACLOUD_MODEL_LARGE") ||
runtime.getSetting("TOGETHER_MODEL_LARGE");
}

if (
runtime.getSetting("LLAMACLOUD_MODEL_SMALL") &&
provider === ModelProviderName.LLAMACLOUD
(runtime.getSetting("LLAMACLOUD_MODEL_SMALL") &&
provider === ModelProviderName.LLAMACLOUD) ||
(runtime.getSetting("TOGETHER_MODEL_SMALL") &&
provider === ModelProviderName.TOGETHER)
) {
model = runtime.getSetting("LLAMACLOUD_MODEL_SMALL");
model =
runtime.getSetting("LLAMACLOUD_MODEL_SMALL") ||
runtime.getSetting("TOGETHER_MODEL_SMALL");
}

elizaLogger.info("Selected model:", model);
Expand Down Expand Up @@ -120,7 +128,8 @@ export async function generateText({
case ModelProviderName.ETERNALAI:
case ModelProviderName.ALI_BAILIAN:
case ModelProviderName.VOLENGINE:
case ModelProviderName.LLAMACLOUD: {
case ModelProviderName.LLAMACLOUD:
case ModelProviderName.TOGETHER: {
elizaLogger.debug("Initializing OpenAI model.");
const openai = createOpenAI({ apiKey, baseURL: endpoint });

Expand Down Expand Up @@ -806,12 +815,6 @@ export const generateImage = async (
data?: string[];
error?: any;
}> => {
const { prompt, width, height } = data;
let { count } = data;
if (!count) {
count = 1;
}

const model = getModel(runtime.imageModelProvider, ModelClass.IMAGE);
const modelSettings = models[runtime.imageModelProvider].imageSettings;

Expand Down Expand Up @@ -866,16 +869,19 @@ export const generateImage = async (
const imageURL = await response.json();
return { success: true, data: [imageURL] };
} else if (
runtime.imageModelProvider === ModelProviderName.TOGETHER ||
// for backwards compat
runtime.imageModelProvider === ModelProviderName.LLAMACLOUD
) {
const together = new Together({ apiKey: apiKey as string });
// Fix: steps 4 is for schnell; 28 is for dev.
const response = await together.images.create({
model: "black-forest-labs/FLUX.1-schnell",
prompt,
width,
height,
prompt: data.prompt,
width: data.width,
height: data.height,
steps: modelSettings?.steps ?? 4,
n: count,
n: data.count,
});
const urls: string[] = [];
for (let i = 0; i < response.data.length; i++) {
Expand All @@ -902,11 +908,11 @@ export const generateImage = async (

// Prepare the input parameters according to their schema
const input = {
prompt: prompt,
prompt: data.prompt,
image_size: "square" as const,
num_inference_steps: modelSettings?.steps ?? 50,
guidance_scale: 3.5,
num_images: count,
guidance_scale: data.guidanceScale || 3.5,
num_images: data.count,
enable_safety_checker: true,
output_format: "png" as const,
seed: data.seed ?? 6252023,
Expand Down Expand Up @@ -945,7 +951,7 @@ export const generateImage = async (
const base64s = await Promise.all(base64Promises);
return { success: true, data: base64s };
} else {
let targetSize = `${width}x${height}`;
let targetSize = `${data.width}x${data.height}`;
if (
targetSize !== "1024x1024" &&
targetSize !== "1792x1024" &&
Expand All @@ -956,9 +962,9 @@ export const generateImage = async (
const openai = new OpenAI({ apiKey: apiKey as string });
const response = await openai.images.generate({
model,
prompt,
prompt: data.prompt,
size: targetSize as "1024x1024" | "1792x1024" | "1024x1792",
n: count,
n: data.count,
response_format: "b64_json",
});
const base64s = response.data.map(
Expand Down Expand Up @@ -1157,6 +1163,7 @@ export async function handleProvider(
case ModelProviderName.ALI_BAILIAN:
case ModelProviderName.VOLENGINE:
case ModelProviderName.LLAMACLOUD:
case ModelProviderName.TOGETHER:
return await handleOpenAI(options);
case ModelProviderName.ANTHROPIC:
return await handleAnthropic(options);
Expand Down
21 changes: 21 additions & 0 deletions packages/core/src/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,27 @@ export const models: Models = {
[ModelClass.IMAGE]: "black-forest-labs/FLUX.1-schnell",
},
},
[ModelProviderName.TOGETHER]: {
settings: {
stop: [],
maxInputTokens: 128000,
maxOutputTokens: 8192,
repetition_penalty: 0.4,
temperature: 0.7,
},
imageSettings: {
steps: 4,
},
endpoint: "https://api.together.ai/v1",
model: {
[ModelClass.SMALL]: "meta-llama/Llama-3.2-3B-Instruct-Turbo",
[ModelClass.MEDIUM]: "meta-llama-3.1-8b-instruct",
[ModelClass.LARGE]: "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
[ModelClass.EMBEDDING]:
"togethercomputer/m2-bert-80M-32k-retrieval",
[ModelClass.IMAGE]: "black-forest-labs/FLUX.1-schnell",
},
},
[ModelProviderName.LLAMALOCAL]: {
settings: {
stop: ["<|eot_id|>", "<|eom_id|>"],
Expand Down
102 changes: 56 additions & 46 deletions packages/core/src/tests/generation.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import { describe, expect, it, vi, beforeEach } from "vitest";
import { ModelProviderName, IAgentRuntime } from "../types";
import { models } from "../models";
import { generateText, generateTrueOrFalse, splitChunks, trimTokens } from "../generation";
import {
generateText,
generateTrueOrFalse,
splitChunks,
trimTokens,
} from "../generation";
import type { TiktokenModel } from "js-tiktoken";

// Mock the elizaLogger
Expand Down Expand Up @@ -42,6 +47,8 @@ describe("Generation", () => {
getSetting: vi.fn().mockImplementation((key: string) => {
if (key === "LLAMACLOUD_MODEL_LARGE") return false;
if (key === "LLAMACLOUD_MODEL_SMALL") return false;
if (key === "TOGETHER_MODEL_LARGE") return false;
if (key === "TOGETHER_MODEL_SMALL") return false;
return undefined;
}),
} as unknown as IAgentRuntime;
Expand Down Expand Up @@ -122,53 +129,56 @@ describe("Generation", () => {
});
});

describe("trimTokens", () => {
const model = "gpt-4" as TiktokenModel;

it("should return empty string for empty input", () => {
const result = trimTokens("", 100, model);
expect(result).toBe("");
});

it("should throw error for negative maxTokens", () => {
expect(() => trimTokens("test", -1, model)).toThrow("maxTokens must be positive");
});

it("should return unchanged text if within token limit", () => {
const shortText = "This is a short text";
const result = trimTokens(shortText, 10, model);
expect(result).toBe(shortText);
});

it("should truncate text to specified token limit", () => {
// Using a longer text that we know will exceed the token limit
const longText = "This is a much longer text that will definitely exceed our very small token limit and need to be truncated to fit within the specified constraints."
const result = trimTokens(longText, 5, model);

// The exact result will depend on the tokenizer, but we can verify:
// 1. Result is shorter than original
expect(result.length).toBeLessThan(longText.length);
// 2. Result is not empty
expect(result.length).toBeGreaterThan(0);
// 3. Result is a proper substring of the original text
expect(longText.includes(result)).toBe(true);
});

it("should handle non-ASCII characters", () => {
const unicodeText = "Hello 👋 World 🌍";
const result = trimTokens(unicodeText, 5, model);
expect(result.length).toBeGreaterThan(0);
});

it("should handle multiline text", () => {
const multilineText = `Line 1
describe("trimTokens", () => {
const model = "gpt-4" as TiktokenModel;

it("should return empty string for empty input", () => {
const result = trimTokens("", 100, model);
expect(result).toBe("");
});

it("should throw error for negative maxTokens", () => {
expect(() => trimTokens("test", -1, model)).toThrow(
"maxTokens must be positive"
);
});

it("should return unchanged text if within token limit", () => {
const shortText = "This is a short text";
const result = trimTokens(shortText, 10, model);
expect(result).toBe(shortText);
});

it("should truncate text to specified token limit", () => {
// Using a longer text that we know will exceed the token limit
const longText =
"This is a much longer text that will definitely exceed our very small token limit and need to be truncated to fit within the specified constraints.";
const result = trimTokens(longText, 5, model);

// The exact result will depend on the tokenizer, but we can verify:
// 1. Result is shorter than original
expect(result.length).toBeLessThan(longText.length);
// 2. Result is not empty
expect(result.length).toBeGreaterThan(0);
// 3. Result is a proper substring of the original text
expect(longText.includes(result)).toBe(true);
});

it("should handle non-ASCII characters", () => {
const unicodeText = "Hello 👋 World 🌍";
const result = trimTokens(unicodeText, 5, model);
expect(result.length).toBeGreaterThan(0);
});

it("should handle multiline text", () => {
const multilineText = `Line 1
Line 2
Line 3
Line 4
Line 5`;
const result = trimTokens(multilineText, 5, model);
expect(result.length).toBeGreaterThan(0);
expect(result.length).toBeLessThan(multilineText.length);
});
});
const result = trimTokens(multilineText, 5, model);
expect(result.length).toBeGreaterThan(0);
expect(result.length).toBeLessThan(multilineText.length);
});
});
});
2 changes: 2 additions & 0 deletions packages/core/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ export type Models = {
[ModelProviderName.GROK]: Model;
[ModelProviderName.GROQ]: Model;
[ModelProviderName.LLAMACLOUD]: Model;
[ModelProviderName.TOGETHER]: Model;
[ModelProviderName.LLAMALOCAL]: Model;
[ModelProviderName.GOOGLE]: Model;
[ModelProviderName.CLAUDE_VERTEX]: Model;
Expand All @@ -216,6 +217,7 @@ export enum ModelProviderName {
GROK = "grok",
GROQ = "groq",
LLAMACLOUD = "llama_cloud",
TOGETHER = "together",
LLAMALOCAL = "llama_local",
GOOGLE = "google",
CLAUDE_VERTEX = "claude_vertex",
Expand Down
34 changes: 29 additions & 5 deletions packages/plugin-image-generation/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import { generateImage } from "@ai16z/eliza";

import fs from "fs";
import path from "path";
import { validateImageGenConfig } from "./enviroment";
import { validateImageGenConfig } from "./environment";

export function saveBase64Image(base64Data: string, filename: string): string {
// Create generatedImages directory if it doesn't exist
Expand Down Expand Up @@ -97,7 +97,17 @@ const imageGeneration: Action = {
runtime: IAgentRuntime,
message: Memory,
state: State,
options: any,
options: {
width?: number;
height?: number;
count?: number;
negativePrompt?: string;
numIterations?: number;
guidanceScale?: number;
seed?: number;
modelId?: string;
jobId?: string;
},
callback: HandlerCallback
) => {
elizaLogger.log("Composing state for message:", message);
Expand All @@ -116,9 +126,23 @@ const imageGeneration: Action = {
const images = await generateImage(
{
prompt: imagePrompt,
width: 1024,
height: 1024,
count: 1,
width: options.width || 1024,
height: options.height || 1024,
...(options.count != null ? { count: options.count || 1 } : {}),
...(options.negativePrompt != null
? { negativePrompt: options.negativePrompt }
: {}),
...(options.numIterations != null
? { numIterations: options.numIterations }
: {}),
...(options.guidanceScale != null
? { guidanceScale: options.guidanceScale }
: {}),
...(options.seed != null ? { seed: options.seed } : {}),
...(options.modelId != null
? { modelId: options.modelId }
: {}),
...(options.jobId != null ? { jobId: options.jobId } : {}),
},
runtime
);
Expand Down
Loading