From 6c8007e23b2e51434e6874a338b9ede4db28f503 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Tue, 10 Dec 2024 16:54:21 -0800 Subject: [PATCH] Improve handling of multiple output modes - Use the generated descriptions / inferred queries to supply context to the model about what it's created and give a richer response - Stop sending the generated image in the user message. This seemed to be confusing the model more than helping. - Also, rename the OpenAI converse method to converse_openai to follow patterns with other providers --- .../conversation/anthropic/anthropic_chat.py | 6 +- .../conversation/google/gemini_chat.py | 6 +- .../conversation/offline/chat_model.py | 4 +- src/khoj/processor/conversation/openai/gpt.py | 8 +-- src/khoj/processor/conversation/prompts.py | 54 +++++++++------- src/khoj/processor/conversation/utils.py | 61 ++++++++++--------- src/khoj/routers/api_chat.py | 15 ++++- src/khoj/routers/helpers.py | 20 +++--- tests/test_online_chat_actors.py | 24 ++++---- 9 files changed, 111 insertions(+), 87 deletions(-) diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index 688514ca8..fa5ff9d82 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -157,10 +157,9 @@ def converse_anthropic( query_images: Optional[list[str]] = None, vision_available: bool = False, query_files: str = None, - generated_images: Optional[list[str]] = None, generated_files: List[FileAttachment] = None, - generated_excalidraw_diagram: Optional[str] = None, program_execution_context: Optional[List[str]] = None, + generated_asset_results: Dict[str, Dict] = {}, tracer: dict = {}, ): """ @@ -221,9 +220,8 @@ def converse_anthropic( vision_enabled=vision_available, model_type=ChatModelOptions.ModelType.ANTHROPIC, query_files=query_files, - generated_excalidraw_diagram=generated_excalidraw_diagram, 
generated_files=generated_files, - generated_images=generated_images, + generated_asset_results=generated_asset_results, program_execution_context=program_execution_context, ) diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index ad10acda8..3567efed4 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -167,9 +167,8 @@ def converse_gemini( query_images: Optional[list[str]] = None, vision_available: bool = False, query_files: str = None, - generated_images: Optional[list[str]] = None, generated_files: List[FileAttachment] = None, - generated_excalidraw_diagram: Optional[str] = None, + generated_asset_results: Dict[str, Dict] = {}, program_execution_context: List[str] = None, tracer={}, ): @@ -232,9 +231,8 @@ def converse_gemini( vision_enabled=vision_available, model_type=ChatModelOptions.ModelType.GOOGLE, query_files=query_files, - generated_excalidraw_diagram=generated_excalidraw_diagram, generated_files=generated_files, - generated_images=generated_images, + generated_asset_results=generated_asset_results, program_execution_context=program_execution_context, ) diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index d81c194cc..2091d0a97 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -3,7 +3,7 @@ import os from datetime import datetime, timedelta from threading import Thread -from typing import Any, Iterator, List, Optional, Union +from typing import Any, Dict, Iterator, List, Optional, Union import pyjson5 from langchain.schema import ChatMessage @@ -166,6 +166,7 @@ def converse_offline( query_files: str = None, generated_files: List[FileAttachment] = None, additional_context: List[str] = None, + generated_asset_results: Dict[str, Dict] = {}, tracer: dict = {}, 
) -> Union[ThreadedGenerator, Iterator[str]]: """ @@ -234,6 +235,7 @@ def converse_offline( model_type=ChatModelOptions.ModelType.OFFLINE, query_files=query_files, generated_files=generated_files, + generated_asset_results=generated_asset_results, program_execution_context=additional_context, ) diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index c8faf25e7..83e6d0dfa 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -137,7 +137,7 @@ def send_message_to_model( ) -def converse( +def converse_openai( references, user_query, online_results: Optional[Dict[str, Dict]] = None, @@ -157,9 +157,8 @@ def converse( query_images: Optional[list[str]] = None, vision_available: bool = False, query_files: str = None, - generated_images: Optional[list[str]] = None, generated_files: List[FileAttachment] = None, - generated_excalidraw_diagram: Optional[str] = None, + generated_asset_results: Dict[str, Dict] = {}, program_execution_context: List[str] = None, tracer: dict = {}, ): @@ -223,9 +222,8 @@ def converse( vision_enabled=vision_available, model_type=ChatModelOptions.ModelType.OPENAI, query_files=query_files, - generated_excalidraw_diagram=generated_excalidraw_diagram, generated_files=generated_files, - generated_images=generated_images, + generated_asset_results=generated_asset_results, program_execution_context=program_execution_context, ) logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}") diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 11b50d9cf..f173b9f40 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -178,40 +178,41 @@ """.strip() ) -generated_image_attachment = PromptTemplate.from_template( - f""" -Here is the image you generated based on my query. You can follow-up with a general response to my query. 
Limit to 1-2 sentences. -""".strip() -) +generated_assets_context = PromptTemplate.from_template( + """ +Assets that you created have already been created to respond to the query. Below, there are references to the descriptions used to create the assets. +You can provide a summary of your reasoning from the information below or use it to respond to the original query. -generated_diagram_attachment = PromptTemplate.from_template( - f""" -I've successfully created a diagram based on the user's query. The diagram will automatically be shared with the user. I can follow-up with a general response or summary. Limit to 1-2 sentences. +Generated Assets: +{generated_assets} + +Limit your response to 3 sentences max. Be succinct, clear, and informative. """.strip() ) + ## Diagram Generation ## -- improve_diagram_description_prompt = PromptTemplate.from_template( """ -you are an architect working with a novice digital artist using a diagramming software. +You are an architect working with a novice digital artist using a diagramming software. {personality_context} -you need to convert the user's query to a description format that the novice artist can use very well. you are allowed to use primitives like -- text -- rectangle -- ellipse -- line -- arrow +You need to convert the user's query to a description format that the novice artist can use very well. you are allowed to use primitives like +- Text +- Rectangle +- Ellipse +- Line +- Arrow -use these primitives to describe what sort of diagram the drawer should create. the artist must recreate the diagram every time, so include all relevant prior information in your description. +Use these primitives to describe what sort of diagram the drawer should create. The artist must recreate the diagram every time, so include all relevant prior information in your description. -- include the full, exact description. the artist does not have much experience, so be precise. -- describe the layout. -- you can only use straight lines. 
-- use simple, concise language. -- keep it simple and easy to understand. the artist is easily distracted. +- Include the full, exact description. the artist does not have much experience, so be precise. +- Describe the layout. +- You can only use straight lines. +- Use simple, concise language. +- Keep it simple and easy to understand. the artist is easily distracted. Today's Date: {current_date} User's Location: {location} @@ -337,6 +338,17 @@ """.strip() ) +failed_diagram_generation = PromptTemplate.from_template( + """ +You attempted to programmatically generate a diagram but failed due to a system issue. You are normally able to generate diagrams, but you encountered a system issue this time. + +You can create an ASCII image of the diagram in response instead. + +This is the diagram you attempted to make: +{attempted_diagram} +""".strip() +) + ## Online Search Conversation ## -- online_search_conversation = PromptTemplate.from_template( diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 12fc736f9..32f755b38 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -40,6 +40,7 @@ merge_dicts, ) from khoj.utils.rawconfig import FileAttachment +from khoj.utils.yaml import yaml_dump logger = logging.getLogger(__name__) @@ -380,9 +381,8 @@ def generate_chatml_messages_with_context( model_type="", context_message="", query_files: str = None, - generated_images: Optional[list[str]] = None, generated_files: List[FileAttachment] = None, - generated_excalidraw_diagram: str = None, + generated_asset_results: Dict[str, Dict] = {}, program_execution_context: List[str] = [], ): """Generate chat messages with appropriate context from previous conversation to send to the chat model""" @@ -402,11 +402,15 @@ def generate_chatml_messages_with_context( message_context = "" message_attached_files = "" + generated_assets = {} + chat_message = chat.get("message") role = "user" if 
chat["by"] == "you" else "assistant" + # Legacy code to handle excalidraw diagrams prior to Dec 2024 if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""): chat_message = chat["intent"].get("inferred-queries")[0] + if not is_none_or_empty(chat.get("context")): references = "\n\n".join( { @@ -433,15 +437,23 @@ def generate_chatml_messages_with_context( reconstructed_context_message = ChatMessage(content=message_context, role="user") chatml_messages.insert(0, reconstructed_context_message) - if chat.get("images") and role == "assistant": - # Issue: the assistant role cannot accept an image as a message content, so send it in a separate user message. - file_attachment_message = construct_structured_message( - message=prompts.generated_image_attachment.format(), - images=chat.get("images"), - model_type=model_type, - vision_enabled=vision_enabled, + if not is_none_or_empty(chat.get("images")) and role == "assistant": + generated_assets["image"] = { + "query": chat.get("intent", {}).get("inferred-queries", [user_message])[0], + } + + if not is_none_or_empty(chat.get("excalidrawDiagram")) and role == "assistant": + generated_assets["diagram"] = { + "query": chat.get("intent", {}).get("inferred-queries", [user_message])[0], + } + + if not is_none_or_empty(generated_assets): + chatml_messages.append( + ChatMessage( + content=f"{prompts.generated_assets_context.format(generated_assets=yaml_dump(generated_assets))}\n", + role="user", + ) ) - chatml_messages.append(ChatMessage(content=file_attachment_message, role="user")) message_content = construct_structured_message( chat_message, chat.get("images") if role == "user" else [], model_type, vision_enabled @@ -464,33 +476,22 @@ def generate_chatml_messages_with_context( role="user", ) ) - if not is_none_or_empty(context_message): - messages.append(ChatMessage(content=context_message, role="user")) - - if generated_images: - messages.append( - ChatMessage( - content=construct_structured_message( - 
prompts.generated_image_attachment.format(), generated_images, model_type, vision_enabled - ), - role="user", - ) - ) if generated_files: message_attached_files = gather_raw_query_files({file.name: file.content for file in generated_files}) messages.append(ChatMessage(content=message_attached_files, role="assistant")) - if generated_excalidraw_diagram: - messages.append(ChatMessage(content=prompts.generated_diagram_attachment.format(), role="assistant")) + if not is_none_or_empty(generated_asset_results): + context_message += ( + f"{prompts.generated_assets_context.format(generated_assets=yaml_dump(generated_asset_results))}\n\n" + ) if program_execution_context: - messages.append( - ChatMessage( - content=prompts.additional_program_context.format(context="\n".join(program_execution_context)), - role="assistant", - ) - ) + program_context_text = "\n".join(program_execution_context) + context_message += f"{prompts.additional_program_context.format(context=program_context_text)}\n" + + if not is_none_or_empty(context_message): + messages.append(ChatMessage(content=context_message, role="user")) if len(chatml_messages) > 0: messages += chatml_messages diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 68d77475c..eea291d32 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -23,6 +23,7 @@ aget_user_name, ) from khoj.database.models import Agent, KhojUser +from khoj.processor.conversation import prompts from khoj.processor.conversation.prompts import help_message, no_entries_found from khoj.processor.conversation.utils import defilter_query, save_to_conversation_log from khoj.processor.image.generate import text_to_image @@ -765,6 +766,7 @@ def collect_telemetry(): researched_results = "" online_results: Dict = dict() code_results: Dict = dict() + generated_asset_results: Dict = dict() ## Extract Document References compiled_references: List[Any] = [] inferred_queries: List[Any] = [] @@ -1128,6 +1130,10 @@ def 
collect_telemetry(): else: generated_images.append(generated_image) + generated_asset_results["images"] = { + "query": improved_image_prompt, + } + async for result in send_event( ChatEvent.GENERATED_ASSETS, { @@ -1166,6 +1172,10 @@ def collect_telemetry(): generated_excalidraw_diagram = diagram_description + generated_asset_results["diagrams"] = { + "query": better_diagram_description_prompt, + } + async for result in send_event( ChatEvent.GENERATED_ASSETS, { @@ -1176,7 +1186,9 @@ def collect_telemetry(): else: error_message = "Failed to generate diagram. Please try again later." program_execution_context.append( - f"AI attempted to programmatically generate a diagram but failed due to a program issue. Generally, it is able to do so, but encountered a system issue this time. AI can suggest text description or rendering of the diagram or user can try again with a simpler prompt." + prompts.failed_diagram_generation.format( + attempted_diagram=better_diagram_description_prompt + ) ) async for result in send_event(ChatEvent.STATUS, error_message): @@ -1209,6 +1221,7 @@ def collect_telemetry(): generated_files, generated_excalidraw_diagram, program_execution_context, + generated_asset_results, tracer, ) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 3ec701b7b..36a8d0085 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -88,7 +88,10 @@ converse_offline, send_message_to_model_offline, ) -from khoj.processor.conversation.openai.gpt import converse, send_message_to_model +from khoj.processor.conversation.openai.gpt import ( + converse_openai, + send_message_to_model, +) from khoj.processor.conversation.utils import ( ChatEvent, ThreadedGenerator, @@ -751,7 +754,7 @@ async def generate_excalidraw_diagram( ) except Exception as e: logger.error(f"Error generating Excalidraw diagram for {user.email}: {e}", exc_info=True) - yield None, None + yield better_diagram_description_prompt, None return scratchpad = 
excalidraw_diagram_description.get("scratchpad") @@ -1189,6 +1192,7 @@ def generate_chat_response( raw_generated_files: List[FileAttachment] = [], generated_excalidraw_diagram: str = None, program_execution_context: List[str] = [], + generated_asset_results: Dict[str, Dict] = {}, tracer: dict = {}, ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]: # Initialize Variables @@ -1251,6 +1255,7 @@ def generate_chat_response( agent=agent, query_files=query_files, generated_files=raw_generated_files, + generated_asset_results=generated_asset_results, tracer=tracer, ) @@ -1258,7 +1263,7 @@ def generate_chat_response( openai_chat_config = conversation_config.ai_model_api api_key = openai_chat_config.api_key chat_model = conversation_config.chat_model - chat_response = converse( + chat_response = converse_openai( compiled_references, query_to_run, query_images=query_images, @@ -1278,8 +1283,7 @@ def generate_chat_response( vision_available=vision_available, query_files=query_files, generated_files=raw_generated_files, - generated_images=generated_images, - generated_excalidraw_diagram=generated_excalidraw_diagram, + generated_asset_results=generated_asset_results, program_execution_context=program_execution_context, tracer=tracer, ) @@ -1305,8 +1309,7 @@ def generate_chat_response( vision_available=vision_available, query_files=query_files, generated_files=raw_generated_files, - generated_images=generated_images, - generated_excalidraw_diagram=generated_excalidraw_diagram, + generated_asset_results=generated_asset_results, program_execution_context=program_execution_context, tracer=tracer, ) @@ -1331,8 +1334,7 @@ def generate_chat_response( vision_available=vision_available, query_files=query_files, generated_files=raw_generated_files, - generated_images=generated_images, - generated_excalidraw_diagram=generated_excalidraw_diagram, + generated_asset_results=generated_asset_results, program_execution_context=program_execution_context, tracer=tracer, ) diff 
--git a/tests/test_online_chat_actors.py b/tests/test_online_chat_actors.py index 3a6db6f85..f873be450 100644 --- a/tests/test_online_chat_actors.py +++ b/tests/test_online_chat_actors.py @@ -4,7 +4,7 @@ import pytest from freezegun import freeze_time -from khoj.processor.conversation.openai.gpt import converse, extract_questions +from khoj.processor.conversation.openai.gpt import converse_openai, extract_questions from khoj.processor.conversation.utils import message_to_log from khoj.routers.helpers import ( aget_data_sources_and_output_format, @@ -158,7 +158,7 @@ def test_generate_search_query_using_question_and_answer_from_chat_history(): @pytest.mark.chatquality def test_chat_with_no_chat_history_or_retrieved_content(): # Act - response_gen = converse( + response_gen = converse_openai( references=[], # Assume no context retrieved from notes for the user_query user_query="Hello, my name is Testatron. Who are you?", api_key=api_key, @@ -183,7 +183,7 @@ def test_answer_from_chat_history_and_no_content(): ] # Act - response_gen = converse( + response_gen = converse_openai( references=[], # Assume no context retrieved from notes for the user_query user_query="What is my name?", conversation_log=populate_chat_history(message_list), @@ -214,7 +214,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(): ] # Act - response_gen = converse( + response_gen = converse_openai( references=[], # Assume no context retrieved from notes for the user_query user_query="Where was I born?", conversation_log=populate_chat_history(message_list), @@ -239,7 +239,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(): ] # Act - response_gen = converse( + response_gen = converse_openai( references=[ {"compiled": "Testatron was born on 1st April 1984 in Testville.", "file": "background.md"} ], # Assume context retrieved from notes for the user_query @@ -265,7 +265,7 @@ def test_refuse_answering_unanswerable_question(): ] # Act - response_gen = converse( 
+ response_gen = converse_openai( references=[], # Assume no context retrieved from notes for the user_query user_query="Where was I born?", conversation_log=populate_chat_history(message_list), @@ -318,7 +318,7 @@ def test_answer_requires_current_date_awareness(): ] # Act - response_gen = converse( + response_gen = converse_openai( references=context, # Assume context retrieved from notes for the user_query user_query="What did I have for Dinner today?", api_key=api_key, @@ -362,7 +362,7 @@ def test_answer_requires_date_aware_aggregation_across_provided_notes(): ] # Act - response_gen = converse( + response_gen = converse_openai( references=context, # Assume context retrieved from notes for the user_query user_query="How much did I spend on dining this year?", api_key=api_key, @@ -386,7 +386,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(): ] # Act - response_gen = converse( + response_gen = converse_openai( references=[], # Assume no context retrieved from notes for the user_query user_query="Write a haiku about unit testing in 3 lines. 
Do not say anything else", conversation_log=populate_chat_history(message_list), @@ -426,7 +426,7 @@ def test_ask_for_clarification_if_not_enough_context_in_question(): ] # Act - response_gen = converse( + response_gen = converse_openai( references=context, # Assume context retrieved from notes for the user_query user_query="How many kids does my older sister have?", api_key=api_key, @@ -459,13 +459,13 @@ def test_agent_prompt_should_be_used(openai_agent): expected_responses = ["9.50", "9.5"] # Act - response_gen = converse( + response_gen = converse_openai( references=context, # Assume context retrieved from notes for the user_query user_query="What did I buy?", api_key=api_key, ) no_agent_response = "".join([response_chunk for response_chunk in response_gen]) - response_gen = converse( + response_gen = converse_openai( references=context, # Assume context retrieved from notes for the user_query user_query="What did I buy?", api_key=api_key,