Skip to content

Commit

Permalink
Merge pull request #1002 from khoj-ai/features/improve-multiple-outpu…
Browse files Browse the repository at this point in the history
…t-mode-generation

Improve handling of multiple output modes

- Use the generated descriptions / inferred queries to supply context to the model about what it's created and give a richer response
- Stop sending the generated image in user message. This seemed to be confusing the model more than helping.
- Collect generated assets in a structured objects to provide model context. This seems to help it follow instructions and separate responsibility better
- Also, rename the open ai converse method to converse_openai to follow patterns with other providers
  • Loading branch information
sabaimran authored Dec 11, 2024
2 parents 4bc5c13 + 2bb14c5 commit 01d000e
Show file tree
Hide file tree
Showing 9 changed files with 111 additions and 87 deletions.
6 changes: 2 additions & 4 deletions src/khoj/processor/conversation/anthropic/anthropic_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,9 @@ def converse_anthropic(
query_images: Optional[list[str]] = None,
vision_available: bool = False,
query_files: str = None,
generated_images: Optional[list[str]] = None,
generated_files: List[FileAttachment] = None,
generated_excalidraw_diagram: Optional[str] = None,
program_execution_context: Optional[List[str]] = None,
generated_asset_results: Dict[str, Dict] = {},
tracer: dict = {},
):
"""
Expand Down Expand Up @@ -221,9 +220,8 @@ def converse_anthropic(
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.ANTHROPIC,
query_files=query_files,
generated_excalidraw_diagram=generated_excalidraw_diagram,
generated_files=generated_files,
generated_images=generated_images,
generated_asset_results=generated_asset_results,
program_execution_context=program_execution_context,
)

Expand Down
6 changes: 2 additions & 4 deletions src/khoj/processor/conversation/google/gemini_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,8 @@ def converse_gemini(
query_images: Optional[list[str]] = None,
vision_available: bool = False,
query_files: str = None,
generated_images: Optional[list[str]] = None,
generated_files: List[FileAttachment] = None,
generated_excalidraw_diagram: Optional[str] = None,
generated_asset_results: Dict[str, Dict] = {},
program_execution_context: List[str] = None,
tracer={},
):
Expand Down Expand Up @@ -232,9 +231,8 @@ def converse_gemini(
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.GOOGLE,
query_files=query_files,
generated_excalidraw_diagram=generated_excalidraw_diagram,
generated_files=generated_files,
generated_images=generated_images,
generated_asset_results=generated_asset_results,
program_execution_context=program_execution_context,
)

Expand Down
4 changes: 3 additions & 1 deletion src/khoj/processor/conversation/offline/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
from datetime import datetime, timedelta
from threading import Thread
from typing import Any, Iterator, List, Optional, Union
from typing import Any, Dict, Iterator, List, Optional, Union

import pyjson5
from langchain.schema import ChatMessage
Expand Down Expand Up @@ -166,6 +166,7 @@ def converse_offline(
query_files: str = None,
generated_files: List[FileAttachment] = None,
additional_context: List[str] = None,
generated_asset_results: Dict[str, Dict] = {},
tracer: dict = {},
) -> Union[ThreadedGenerator, Iterator[str]]:
"""
Expand Down Expand Up @@ -234,6 +235,7 @@ def converse_offline(
model_type=ChatModelOptions.ModelType.OFFLINE,
query_files=query_files,
generated_files=generated_files,
generated_asset_results=generated_asset_results,
program_execution_context=additional_context,
)

Expand Down
8 changes: 3 additions & 5 deletions src/khoj/processor/conversation/openai/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def send_message_to_model(
)


def converse(
def converse_openai(
references,
user_query,
online_results: Optional[Dict[str, Dict]] = None,
Expand All @@ -157,9 +157,8 @@ def converse(
query_images: Optional[list[str]] = None,
vision_available: bool = False,
query_files: str = None,
generated_images: Optional[list[str]] = None,
generated_files: List[FileAttachment] = None,
generated_excalidraw_diagram: Optional[str] = None,
generated_asset_results: Dict[str, Dict] = {},
program_execution_context: List[str] = None,
tracer: dict = {},
):
Expand Down Expand Up @@ -223,9 +222,8 @@ def converse(
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.OPENAI,
query_files=query_files,
generated_excalidraw_diagram=generated_excalidraw_diagram,
generated_files=generated_files,
generated_images=generated_images,
generated_asset_results=generated_asset_results,
program_execution_context=program_execution_context,
)
logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")
Expand Down
54 changes: 33 additions & 21 deletions src/khoj/processor/conversation/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,40 +178,41 @@
""".strip()
)

generated_image_attachment = PromptTemplate.from_template(
f"""
Here is the image you generated based on my query. You can follow-up with a general response to my query. Limit to 1-2 sentences.
""".strip()
)
generated_assets_context = PromptTemplate.from_template(
"""
Assets that you created have already been created to respond to the query. Below, there are references to the descriptions used to create the assets.
You can provide a summary of your reasoning from the information below or use it to respond to the original query.
generated_diagram_attachment = PromptTemplate.from_template(
f"""
I've successfully created a diagram based on the user's query. The diagram will automatically be shared with the user. I can follow-up with a general response or summary. Limit to 1-2 sentences.
Generated Assets:
{generated_assets}
Limit your response to 3 sentences max. Be succinct, clear, and informative.
""".strip()
)


## Diagram Generation
## --

improve_diagram_description_prompt = PromptTemplate.from_template(
"""
you are an architect working with a novice digital artist using a diagramming software.
You are an architect working with a novice digital artist using a diagramming software.
{personality_context}
you need to convert the user's query to a description format that the novice artist can use very well. you are allowed to use primitives like
- text
- rectangle
- ellipse
- line
- arrow
You need to convert the user's query to a description format that the novice artist can use very well. you are allowed to use primitives like
- Text
- Rectangle
- Ellipse
- Line
- Arrow
use these primitives to describe what sort of diagram the drawer should create. the artist must recreate the diagram every time, so include all relevant prior information in your description.
Use these primitives to describe what sort of diagram the drawer should create. The artist must recreate the diagram every time, so include all relevant prior information in your description.
- include the full, exact description. the artist does not have much experience, so be precise.
- describe the layout.
- you can only use straight lines.
- use simple, concise language.
- keep it simple and easy to understand. the artist is easily distracted.
- Include the full, exact description. the artist does not have much experience, so be precise.
- Describe the layout.
- You can only use straight lines.
- Use simple, concise language.
- Keep it simple and easy to understand. the artist is easily distracted.
Today's Date: {current_date}
User's Location: {location}
Expand Down Expand Up @@ -337,6 +338,17 @@
""".strip()
)

failed_diagram_generation = PromptTemplate.from_template(
"""
You attempted to programmatically generate a diagram but failed due to a system issue. You are normally able to generate diagrams, but you encountered a system issue this time.
You can create an ASCII image of the diagram in response instead.
This is the diagram you attempted to make:
{attempted_diagram}
""".strip()
)

## Online Search Conversation
## --
online_search_conversation = PromptTemplate.from_template(
Expand Down
61 changes: 31 additions & 30 deletions src/khoj/processor/conversation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
merge_dicts,
)
from khoj.utils.rawconfig import FileAttachment
from khoj.utils.yaml import yaml_dump

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -381,9 +382,8 @@ def generate_chatml_messages_with_context(
model_type="",
context_message="",
query_files: str = None,
generated_images: Optional[list[str]] = None,
generated_files: List[FileAttachment] = None,
generated_excalidraw_diagram: str = None,
generated_asset_results: Dict[str, Dict] = {},
program_execution_context: List[str] = [],
):
"""Generate chat messages with appropriate context from previous conversation to send to the chat model"""
Expand All @@ -403,11 +403,15 @@ def generate_chatml_messages_with_context(
message_context = ""
message_attached_files = ""

generated_assets = {}

chat_message = chat.get("message")
role = "user" if chat["by"] == "you" else "assistant"

# Legacy code to handle excalidraw diagrams prior to Dec 2024
if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""):
chat_message = chat["intent"].get("inferred-queries")[0]

if not is_none_or_empty(chat.get("context")):
references = "\n\n".join(
{
Expand All @@ -434,15 +438,23 @@ def generate_chatml_messages_with_context(
reconstructed_context_message = ChatMessage(content=message_context, role="user")
chatml_messages.insert(0, reconstructed_context_message)

if chat.get("images") and role == "assistant":
# Issue: the assistant role cannot accept an image as a message content, so send it in a separate user message.
file_attachment_message = construct_structured_message(
message=prompts.generated_image_attachment.format(),
images=chat.get("images"),
model_type=model_type,
vision_enabled=vision_enabled,
if not is_none_or_empty(chat.get("images")) and role == "assistant":
generated_assets["image"] = {
"query": chat.get("intent", {}).get("inferred-queries", [user_message])[0],
}

if not is_none_or_empty(chat.get("excalidrawDiagram")) and role == "assistant":
generated_assets["diagram"] = {
"query": chat.get("intent", {}).get("inferred-queries", [user_message])[0],
}

if not is_none_or_empty(generated_assets):
chatml_messages.append(
ChatMessage(
content=f"{prompts.generated_assets_context.format(generated_assets=yaml_dump(generated_assets))}\n",
role="user",
)
)
chatml_messages.append(ChatMessage(content=file_attachment_message, role="user"))

message_content = construct_structured_message(
chat_message, chat.get("images") if role == "user" else [], model_type, vision_enabled
Expand All @@ -465,33 +477,22 @@ def generate_chatml_messages_with_context(
role="user",
)
)
if not is_none_or_empty(context_message):
messages.append(ChatMessage(content=context_message, role="user"))

if generated_images:
messages.append(
ChatMessage(
content=construct_structured_message(
prompts.generated_image_attachment.format(), generated_images, model_type, vision_enabled
),
role="user",
)
)

if generated_files:
message_attached_files = gather_raw_query_files({file.name: file.content for file in generated_files})
messages.append(ChatMessage(content=message_attached_files, role="assistant"))

if generated_excalidraw_diagram:
messages.append(ChatMessage(content=prompts.generated_diagram_attachment.format(), role="assistant"))
if not is_none_or_empty(generated_asset_results):
context_message += (
f"{prompts.generated_assets_context.format(generated_assets=yaml_dump(generated_asset_results))}\n\n"
)

if program_execution_context:
messages.append(
ChatMessage(
content=prompts.additional_program_context.format(context="\n".join(program_execution_context)),
role="assistant",
)
)
program_context_text = "\n".join(program_execution_context)
context_message += f"{prompts.additional_program_context.format(context=program_context_text)}\n"

if not is_none_or_empty(context_message):
messages.append(ChatMessage(content=context_message, role="user"))

if len(chatml_messages) > 0:
messages += chatml_messages
Expand Down
15 changes: 14 additions & 1 deletion src/khoj/routers/api_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
aget_user_name,
)
from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.conversation.utils import defilter_query, save_to_conversation_log
from khoj.processor.image.generate import text_to_image
Expand Down Expand Up @@ -765,6 +766,7 @@ def collect_telemetry():
researched_results = ""
online_results: Dict = dict()
code_results: Dict = dict()
generated_asset_results: Dict = dict()
## Extract Document References
compiled_references: List[Any] = []
inferred_queries: List[Any] = []
Expand Down Expand Up @@ -1128,6 +1130,10 @@ def collect_telemetry():
else:
generated_images.append(generated_image)

generated_asset_results["images"] = {
"query": improved_image_prompt,
}

async for result in send_event(
ChatEvent.GENERATED_ASSETS,
{
Expand Down Expand Up @@ -1166,6 +1172,10 @@ def collect_telemetry():

generated_excalidraw_diagram = diagram_description

generated_asset_results["diagrams"] = {
"query": better_diagram_description_prompt,
}

async for result in send_event(
ChatEvent.GENERATED_ASSETS,
{
Expand All @@ -1176,7 +1186,9 @@ def collect_telemetry():
else:
error_message = "Failed to generate diagram. Please try again later."
program_execution_context.append(
f"AI attempted to programmatically generate a diagram but failed due to a program issue. Generally, it is able to do so, but encountered a system issue this time. AI can suggest text description or rendering of the diagram or user can try again with a simpler prompt."
prompts.failed_diagram_generation.format(
attempted_diagram=better_diagram_description_prompt
)
)

async for result in send_event(ChatEvent.STATUS, error_message):
Expand Down Expand Up @@ -1209,6 +1221,7 @@ def collect_telemetry():
generated_files,
generated_excalidraw_diagram,
program_execution_context,
generated_asset_results,
tracer,
)

Expand Down
Loading

0 comments on commit 01d000e

Please sign in to comment.