Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/tomguluson92/eliza into HEAD
Browse files Browse the repository at this point in the history
  • Loading branch information
lalalune committed Dec 2, 2024
2 parents ba21ec8 + a747ae6 commit 2a9010e
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 19 deletions.
25 changes: 11 additions & 14 deletions packages/core/src/generation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -806,11 +806,6 @@ export const generateImage = async (
data?: string[];
error?: any;
}> => {
const { prompt, width, height } = data;
let { count } = data;
if (!count) {
count = 1;
}

const model = getModel(runtime.imageModelProvider, ModelClass.IMAGE);
const modelSettings = models[runtime.imageModelProvider].imageSettings;
Expand Down Expand Up @@ -866,16 +861,18 @@ export const generateImage = async (
const imageURL = await response.json();
return { success: true, data: [imageURL] };
} else if (
// TODO: Fix LLAMACLOUD -> Together?
runtime.imageModelProvider === ModelProviderName.LLAMACLOUD
) {
const together = new Together({ apiKey: apiKey as string });
// Fix: steps 4 is for schnell; 28 is for dev.
const response = await together.images.create({
model: "black-forest-labs/FLUX.1-schnell",
prompt,
width,
height,
prompt: data.prompt,
width: data.width,
height: data.height,
steps: modelSettings?.steps ?? 4,
n: count,
n: data.count,
});
const urls: string[] = [];
for (let i = 0; i < response.data.length; i++) {
Expand All @@ -902,11 +899,11 @@ export const generateImage = async (

// Prepare the input parameters according to their schema
const input = {
prompt: prompt,
prompt: data.prompt,
image_size: "square" as const,
num_inference_steps: modelSettings?.steps ?? 50,
guidance_scale: 3.5,
num_images: count,
guidance_scale: data.guidanceScale || 3.5,
num_images: data.count,
enable_safety_checker: true,
output_format: "png" as const,
seed: data.seed ?? 6252023,
Expand Down Expand Up @@ -956,9 +953,9 @@ export const generateImage = async (
const openai = new OpenAI({ apiKey: apiKey as string });
const response = await openai.images.generate({
model,
prompt,
prompt: data.prompt,
size: targetSize as "1024x1024" | "1792x1024" | "1024x1792",
n: count,
n: data.count,
response_format: "b64_json",
});
const base64s = response.data.map(
Expand Down
26 changes: 21 additions & 5 deletions packages/plugin-image-generation/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import { generateImage } from "@ai16z/eliza";

import fs from "fs";
import path from "path";
import { validateImageGenConfig } from "./enviroment";
import { validateImageGenConfig } from "./environment";

export function saveBase64Image(base64Data: string, filename: string): string {
// Create generatedImages directory if it doesn't exist
Expand Down Expand Up @@ -97,7 +97,17 @@ const imageGeneration: Action = {
runtime: IAgentRuntime,
message: Memory,
state: State,
options: any,
options: {
width?: number;
height?: number;
count?: number;
negativePrompt?: string;
numIterations?: number;
guidanceScale?: number;
seed?: number;
modelId?: string;
jobId?: string;
},
callback: HandlerCallback
) => {
elizaLogger.log("Composing state for message:", message);
Expand All @@ -116,9 +126,15 @@ const imageGeneration: Action = {
const images = await generateImage(
{
prompt: imagePrompt,
width: 1024,
height: 1024,
count: 1,
...(options.width !== undefined ? { width: options.width || 1024 } : {}),
...(options.height !== undefined ? { height: options.height || 1024 } : {}),
...(options.count !== undefined ? { count: options.count || 1 } : {}),
...(options.negativePrompt !== undefined ? { negativePrompt: options.negativePrompt } : {}),
...(options.numIterations !== undefined ? { numIterations: options.numIterations } : {}),
...(options.guidanceScale !== undefined ? { guidanceScale: options.guidanceScale } : {}),
...(options.seed !== undefined ? { seed: options.seed } : {}),
...(options.modelId !== undefined ? { modelId: options.modelId } : {}),
...(options.jobId !== undefined ? { jobId: options.jobId } : {})
},
runtime
);
Expand Down

0 comments on commit 2a9010e

Please sign in to comment.