Skip to content

Commit

Permalink
Have Khoj dynamically select conversation command(s) in chat (#641)
Browse files Browse the repository at this point in the history
* Have Khoj dynamically select which conversation command(s) are to be used in the chat flow
- Intercept the commands if in default mode, and have Khoj dynamically guess which tools would be the most relevant for answering the user's query
* Remove conditional for default to enter online search mode
* Add multiple-tool examples to the prompt and make the tools prompt more specific to information collection
  • Loading branch information
sabaimran authored Feb 11, 2024
1 parent 69344a6 commit a3eb17b
Show file tree
Hide file tree
Showing 10 changed files with 371 additions and 61 deletions.
2 changes: 1 addition & 1 deletion src/khoj/database/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ def save_conversation(
conversation_id: int = None,
user_message: str = None,
):
slug = user_message.strip()[:200] if not is_none_or_empty(user_message) else None
slug = user_message.strip()[:200] if user_message else None
if conversation_id:
conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id)
else:
Expand Down
23 changes: 10 additions & 13 deletions src/khoj/processor/conversation/offline/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def converse_offline(
model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf",
loaded_model: Union[Any, None] = None,
completion_func=None,
conversation_command=ConversationCommand.Default,
conversation_commands=[ConversationCommand.Default],
max_prompt_size=None,
tokenizer_name=None,
) -> Union[ThreadedGenerator, Iterator[str]]:
Expand All @@ -148,27 +148,24 @@ def converse_offline(
# Initialize Variables
compiled_references_message = "\n\n".join({f"{item}" for item in references})

conversation_primer = prompts.query_prompt.format(query=user_query)

# Get Conversation Primer appropriate to Conversation Type
if conversation_command == ConversationCommand.Notes and is_none_or_empty(compiled_references_message):
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references_message):
return iter([prompts.no_notes_found.format()])
elif conversation_command == ConversationCommand.Online and is_none_or_empty(online_results):
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
completion_func(chat_response=prompts.no_online_results_found.format())
return iter([prompts.no_online_results_found.format()])
elif conversation_command == ConversationCommand.Online:

if ConversationCommand.Online in conversation_commands:
simplified_online_results = online_results.copy()
for result in online_results:
if online_results[result].get("extracted_content"):
simplified_online_results[result] = online_results[result]["extracted_content"]

conversation_primer = prompts.online_search_conversation.format(
query=user_query, online_results=str(simplified_online_results)
)
elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references_message):
conversation_primer = user_query
else:
conversation_primer = prompts.notes_conversation_gpt4all.format(
query=user_query, references=compiled_references_message
)
conversation_primer = f"{prompts.online_search_conversation.format(online_results=str(simplified_online_results))}\n{conversation_primer}"
if ConversationCommand.Notes in conversation_commands:
conversation_primer = f"{prompts.notes_conversation_gpt4all.format(references=compiled_references_message)}\n{conversation_primer}"

# Setup Prompt with Primer or Conversation History
current_date = datetime.now().strftime("%Y-%m-%d")
Expand Down
21 changes: 10 additions & 11 deletions src/khoj/processor/conversation/openai/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def converse(
api_key: Optional[str] = None,
temperature: float = 0.2,
completion_func=None,
conversation_command=ConversationCommand.Default,
conversation_commands=[ConversationCommand.Default],
max_prompt_size=None,
tokenizer_name=None,
):
Expand All @@ -133,26 +133,25 @@ def converse(
current_date = datetime.now().strftime("%Y-%m-%d")
compiled_references = "\n\n".join({f"# {item}" for item in references})

conversation_primer = prompts.query_prompt.format(query=user_query)

# Get Conversation Primer appropriate to Conversation Type
if conversation_command == ConversationCommand.Notes and is_none_or_empty(compiled_references):
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
completion_func(chat_response=prompts.no_notes_found.format())
return iter([prompts.no_notes_found.format()])
elif conversation_command == ConversationCommand.Online and is_none_or_empty(online_results):
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
completion_func(chat_response=prompts.no_online_results_found.format())
return iter([prompts.no_online_results_found.format()])
elif conversation_command == ConversationCommand.Online:

if ConversationCommand.Online in conversation_commands:
simplified_online_results = online_results.copy()
for result in online_results:
if online_results[result].get("extracted_content"):
simplified_online_results[result] = online_results[result]["extracted_content"]

conversation_primer = prompts.online_search_conversation.format(
query=user_query, online_results=str(simplified_online_results)
)
elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references):
conversation_primer = prompts.general_conversation.format(query=user_query)
else:
conversation_primer = prompts.notes_conversation.format(query=user_query, references=compiled_references)
conversation_primer = f"{prompts.online_search_conversation.format(online_results=str(simplified_online_results))}\n{conversation_primer}"
if ConversationCommand.Notes in conversation_commands:
conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n{conversation_primer}"

# Setup Prompt with Primer or Conversation History
messages = generate_chatml_messages_with_context(
Expand Down
61 changes: 60 additions & 1 deletion src/khoj/processor/conversation/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@
"""
User's Notes:
{references}
Question: {query}
""".strip()
)

Expand All @@ -139,7 +138,13 @@
Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the online data or past conversations.
Information from the internet: {online_results}
""".strip()
)

## Query prompt
## --
## Wraps the user's raw query as "Query: {query}".
## The converse functions format this template first and then prepend any
## online-results / notes context to it to build the conversation primer.
query_prompt = PromptTemplate.from_template(
"""
Query: {query}""".strip()
)

Expand Down Expand Up @@ -285,6 +290,60 @@
""".strip()
)

## Pick relevant information collection tools
## --
## Prompt asking the model which data sources it wants to use to answer the user's query.
## Placeholders:
##   {tools}        - presumably the description of the available data sources; confirm against the caller
##   {chat_history} - prior conversation turns, rendered under "Chat History:"
##   {query}        - the user's current question, rendered after "Q:"
## The prompt instructs the model to reply with a list of strings, e.g. ["notes", "online"],
## as shown in the few-shot examples below.
pick_relevant_information_collection_tools = PromptTemplate.from_template(
"""
You are Khoj, a smart and helpful personal assistant. You have access to a variety of data sources to help you answer the user's question. You can use the data sources listed below to collect more relevant information. You can use any combination of these data sources to answer the user's question. Tell me which data sources you would like to use to answer the user's question.
{tools}
Here are some example responses:
Example 1:
Chat History:
User: I'm thinking of moving to a new city. I'm trying to decide between New York and San Francisco.
AI: Moving to a new city can be challenging. Both New York and San Francisco are great cities to live in. New York is known for its diverse culture and San Francisco is known for its tech scene.
Q: What is the population of each of those cities?
Khoj: ["online"]
Example 2:
Chat History:
User: I've been having a hard time at work. I'm thinking of quitting.
AI: I'm sorry to hear that. It's important to take care of your mental health. Have you considered talking to your manager about your concerns?
Q: What are the best ways to quit a job?
Khoj: ["general"]
Example 3:
Chat History:
User: I'm thinking of my next vacation idea. Ideally, I want to see something new and exciting.
AI: Excellent! Taking a vacation is a great way to relax and recharge.
Q: Where did Grandma grow up?
Khoj: ["notes"]
Example 4:
Chat History:
Q: I want to make chocolate cake. What was my recipe?
Khoj: ["notes"]
Example 5:
Chat History:
Q: What's the latest news with the first company I worked for?
Khoj: ["notes", "online"]
Now it's your turn to pick the tools you would like to use to answer the user's question. Provide your response as a list of strings.
Chat History:
{chat_history}
Q: {query}
A:
""".strip()
)

online_search_conversation_subqueries = PromptTemplate.from_template(
"""
You are Khoj, an extremely smart and helpful search assistant. You are tasked with constructing **up to three** search queries for Google to answer the user's question.
Expand Down
4 changes: 2 additions & 2 deletions src/khoj/routers/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,15 +274,15 @@ async def extract_references_and_questions(
q: str,
n: int,
d: float,
conversation_type: ConversationCommand = ConversationCommand.Default,
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
):
user = request.user.object if request.user.is_authenticated else None

# Initialize Variables
compiled_references: List[Any] = []
inferred_queries: List[str] = []

if conversation_type == ConversationCommand.General or conversation_type == ConversationCommand.Online:
if not ConversationCommand.Notes in conversation_commands:
return compiled_references, inferred_queries, q

if not await sync_to_async(EntryAdapters.user_has_entries)(user=user):
Expand Down
50 changes: 27 additions & 23 deletions src/khoj/routers/api_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
CommonQueryParams,
ConversationCommandRateLimiter,
agenerate_chat_response,
aget_relevant_information_sources,
get_conversation_command,
is_ready_to_chat,
text_to_image,
Expand Down Expand Up @@ -207,7 +208,7 @@ async def set_conversation_title(
)


@api_chat.get("", response_class=Response)
@api_chat.get("/", response_class=Response)
@requires(["authenticated"])
async def chat(
request: Request,
Expand All @@ -229,41 +230,44 @@ async def chat(
q = unquote(q)

await is_ready_to_chat(user)
conversation_command = get_conversation_command(query=q, any_references=True)
conversation_commands = [get_conversation_command(query=q, any_references=True)]

await conversation_command_rate_limiter.update_and_check_if_valid(request, conversation_command)

q = q.replace(f"/{conversation_command.value}", "").strip()
if conversation_commands == [ConversationCommand.Help]:
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
if conversation_config == None:
conversation_config = await ConversationAdapters.aget_default_conversation_config()
model_type = conversation_config.model_type
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)

meta_log = (
await ConversationAdapters.aget_conversation_by_user(user, request.user.client_app, conversation_id, slug)
).conversation_log

if conversation_commands == [ConversationCommand.Default]:
conversation_commands = await aget_relevant_information_sources(q, meta_log)

for cmd in conversation_commands:
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
q = q.replace(f"/{cmd.value}", "").strip()

compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
request, common, meta_log, q, (n or 5), (d or math.inf), conversation_command
request, common, meta_log, q, (n or 5), (d or math.inf), conversation_commands
)
online_results: Dict = dict()

if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references):
conversation_command = ConversationCommand.General

elif conversation_command == ConversationCommand.Help:
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
if conversation_config == None:
conversation_config = await ConversationAdapters.aget_default_conversation_config()
model_type = conversation_config.model_type
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)

elif conversation_command == ConversationCommand.Notes and not await EntryAdapters.auser_has_entries(user):
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
no_entries_found_format = no_entries_found.format()
if stream:
return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200)
else:
response_obj = {"response": no_entries_found_format}
return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200)

elif conversation_command == ConversationCommand.Online:
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
conversation_commands.remove(ConversationCommand.Notes)

if ConversationCommand.Online in conversation_commands:
try:
online_results = await search_with_google(defiltered_query, meta_log)
except ValueError as e:
Expand All @@ -272,12 +276,12 @@ async def chat(
media_type="text/event-stream",
status_code=200,
)
elif conversation_command == ConversationCommand.Image:
elif conversation_commands == [ConversationCommand.Image]:
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
metadata={"conversation_command": conversation_command.value},
metadata={"conversation_command": conversation_commands[0].value},
**common.__dict__,
)
image, status_code, improved_image_prompt = await text_to_image(q, meta_log)
Expand Down Expand Up @@ -308,13 +312,13 @@ async def chat(
compiled_references,
online_results,
inferred_queries,
conversation_command,
conversation_commands,
user,
request.user.client_app,
conversation_id,
)

chat_metadata.update({"conversation_command": conversation_command.value})
chat_metadata.update({"conversation_command": ",".join([cmd.value for cmd in conversation_commands])})

update_telemetry_state(
request=request,
Expand Down
Loading

0 comments on commit a3eb17b

Please sign in to comment.