Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support using image generation models like Flux via Replicate #909

Merged
merged 4 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Generated by Django 5.0.7 on 2024-09-12 05:43

from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("database", "0060_merge_20240905_1828"),
]

operations = [
migrations.AlterField(
model_name="texttoimagemodelconfig",
name="model_type",
field=models.CharField(
choices=[("openai", "Openai"), ("stability-ai", "Stabilityai"), ("replicate", "Replicate")],
default="openai",
max_length=200,
),
),
]
14 changes: 14 additions & 0 deletions src/khoj/database/migrations/0062_merge_20240913_0222.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Generated by Django 5.0.8 on 2024-09-13 02:22

from typing import List

from django.db import migrations


class Migration(migrations.Migration):
dependencies = [
("database", "0061_alter_chatmodeloptions_model_type"),
("database", "0061_alter_texttoimagemodelconfig_model_type"),
]

operations: List[str] = []
1 change: 1 addition & 0 deletions src/khoj/database/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ class TextToImageModelConfig(BaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
STABILITYAI = "stability-ai"
REPLICATE = "replicate"

model_name = models.CharField(max_length=200, default="dall-e-3")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
Expand Down
8 changes: 4 additions & 4 deletions src/khoj/processor/conversation/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@
## --

image_generation_improve_prompt_base = """
You are a talented creator with the ability to describe images to compose in vivid, fine detail.
Use the provided context and user prompt to generate a more detailed prompt to create an image:
You are a talented media artist with the ability to describe images to compose in professional, fine detail.
Generate a vivid description of the image to be rendered using the provided context and user prompt below:

Today's Date: {current_date}
User's Location: {location}
Expand All @@ -145,10 +145,10 @@

User Prompt: "{query}"

Now generate an improved prompt describing the image to generate in vivid, fine detail.
Now generate an professional description of the image to generate in vivid, fine detail.
- Use today's date, user's location, user's notes and online references to weave in any context that will improve the image generation.
- Retain any important information and follow any instructions in the conversation log or user prompt.
- Add specific, fine position details to compose the image.
- Add specific, fine position details. Mention painting style, camera parameters to compose the image.
- Ensure your improved prompt is in prose format."""

image_generation_improve_prompt_dalle = PromptTemplate.from_template(
Expand Down
212 changes: 212 additions & 0 deletions src/khoj/processor/image/generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
import base64
import io
import logging
import time
from typing import Any, Callable, Dict, List, Optional

import openai
import requests

from khoj.database.adapters import ConversationAdapters
from khoj.database.models import KhojUser, TextToImageModelConfig
from khoj.routers.helpers import ChatEvent, generate_better_image_prompt
from khoj.routers.storage import upload_image
from khoj.utils import state
from khoj.utils.helpers import ImageIntentType, convert_image_to_webp, timer
from khoj.utils.rawconfig import LocationData

logger = logging.getLogger(__name__)


async def text_to_image(
message: str,
user: KhojUser,
conversation_log: dict,
location_data: LocationData,
references: List[Dict[str, Any]],
online_results: Dict[str, Any],
subscribed: bool = False,
send_status_func: Optional[Callable] = None,
uploaded_image_url: Optional[str] = None,
):
status_code = 200
image = None
image_url = None
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3

text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user)
if not text_to_image_config:
# If the user has not configured a text to image model, return an unsupported on server error
status_code = 501
message = "Failed to generate image. Setup image generation on the server."
yield image_url or image, status_code, message, intent_type.value
return

text2image_model = text_to_image_config.model_name
chat_history = ""
for chat in conversation_log.get("chat", [])[-4:]:
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]:
chat_history += f"Q: {chat['intent']['query']}\n"
chat_history += f"A: {chat['message']}\n"
elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):
chat_history += f"Q: Prompt: {chat['intent']['query']}\n"
chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n"

if send_status_func:
async for event in send_status_func("**Enhancing the Painting Prompt**"):
yield {ChatEvent.STATUS: event}

# Generate a better image prompt
# Use the user's message, chat history, and other context
image_prompt = await generate_better_image_prompt(
message,
chat_history,
location_data=location_data,
note_references=references,
online_results=online_results,
model_type=text_to_image_config.model_type,
subscribed=subscribed,
uploaded_image_url=uploaded_image_url,
)

if send_status_func:
async for event in send_status_func(f"**Painting to Imagine**:\n{image_prompt}"):
yield {ChatEvent.STATUS: event}

# Generate image using the configured model and API
with timer(f"Generate image with {text_to_image_config.model_type}", logger):
try:
if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
webp_image_bytes = generate_image_with_openai(image_prompt, text_to_image_config, text2image_model)
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
webp_image_bytes = generate_image_with_stability(image_prompt, text_to_image_config, text2image_model)
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.REPLICATE:
webp_image_bytes = generate_image_with_replicate(image_prompt, text_to_image_config, text2image_model)
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
if "content_policy_violation" in e.message:
logger.error(f"Image Generation blocked by OpenAI: {e}")
status_code = e.status_code # type: ignore
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
yield image_url or image, status_code, message, intent_type.value
return
else:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
status_code = e.status_code # type: ignore
yield image_url or image, status_code, message, intent_type.value
return
except requests.RequestException as e:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed with error: {e}"
status_code = 502
yield image_url or image, status_code, message, intent_type.value
return

# Decide how to store the generated image
with timer("Upload image to S3", logger):
image_url = upload_image(webp_image_bytes, user.uuid)
if image_url:
intent_type = ImageIntentType.TEXT_TO_IMAGE2
else:
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
image = base64.b64encode(webp_image_bytes).decode("utf-8")

yield image_url or image, status_code, image_prompt, intent_type.value


def generate_image_with_openai(
improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
):
"Generate image using OpenAI API"

# Get the API key from the user's configuration
if text_to_image_config.api_key:
api_key = text_to_image_config.api_key
elif text_to_image_config.openai_config:
api_key = text_to_image_config.openai_config.api_key
elif state.openai_client:
api_key = state.openai_client.api_key
auth_header = {"Authorization": f"Bearer {api_key}"} if api_key else {}

# Generate image using OpenAI API
OPENAI_IMAGE_GEN_STYLE = "vivid"
response = state.openai_client.images.generate(
prompt=improved_image_prompt,
model=text2image_model,
style=OPENAI_IMAGE_GEN_STYLE,
response_format="b64_json",
extra_headers=auth_header,
)

# Extract the base64 image from the response
image = response.data[0].b64_json
# Decode base64 png and convert it to webp for faster loading
return convert_image_to_webp(base64.b64decode(image))


def generate_image_with_stability(
improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
):
"Generate image using Stability AI"

# Call Stability AI API to generate image
response = requests.post(
f"https://api.stability.ai/v2beta/stable-image/generate/sd3",
headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"},
files={"none": ""},
data={
"prompt": improved_image_prompt,
"model": text2image_model,
"mode": "text-to-image",
"output_format": "png",
"aspect_ratio": "1:1",
},
)
# Convert png to webp for faster loading
return convert_image_to_webp(response.content)


def generate_image_with_replicate(
improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
):
"Generate image using Replicate API"

# Create image generation task on Replicate
replicate_create_prediction_url = f"https://api.replicate.com/v1/models/{text2image_model}/predictions"
headers = {
"Authorization": f"Bearer {text_to_image_config.api_key}",
"Content-Type": "application/json",
}
json = {
"input": {
"prompt": improved_image_prompt,
"num_outputs": 1,
"aspect_ratio": "1:1",
"output_format": "webp",
"output_quality": 100,
}
}
create_prediction = requests.post(replicate_create_prediction_url, headers=headers, json=json).json()

# Get status of image generation task
get_prediction_url = create_prediction["urls"]["get"]
get_prediction = requests.get(get_prediction_url, headers=headers).json()
status = get_prediction["status"]
retry_count = 1

# Poll the image generation task for completion status
while status not in ["succeeded", "failed", "canceled"] and retry_count < 20:
time.sleep(2)
get_prediction = requests.get(get_prediction_url, headers=headers).json()
status = get_prediction["status"]
retry_count += 1

# Raise exception if the image generation task fails
if status != "succeeded":
if retry_count >= 10:
raise requests.RequestException("Image generation timed out")
raise requests.RequestException(f"Image generation failed with status: {status}")

# Get the generated image
image_url = get_prediction["output"][0] if isinstance(get_prediction["output"], list) else get_prediction["output"]
return io.BytesIO(requests.get(image_url).content).getvalue()
2 changes: 1 addition & 1 deletion src/khoj/routers/api_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from khoj.database.models import KhojUser
from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.conversation.utils import save_to_conversation_log
from khoj.processor.image.generate import text_to_image
from khoj.processor.speech.text_to_speech import generate_text_to_speech
from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.routers.api import extract_references_and_questions
Expand All @@ -44,7 +45,6 @@
is_query_empty,
is_ready_to_chat,
read_chat_stream,
text_to_image,
update_telemetry_state,
validate_conversation_config,
)
Expand Down
Loading
Loading