Add support for Anthropic models (#760)
* Add support for chatting with Anthropic's suite of models

- Used a custom class because the Anthropic SDK works differently enough that it was cleaner to separate out the logic. The extract-questions flow needed a modified system prompt to work as intended with the Haiku model.
sabaimran authored May 26, 2024
1 parent e292296 commit 01cdc54
Showing 10 changed files with 454 additions and 5 deletions.
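
For context, a minimal sketch of how the two new entry points added by this commit can be exercised. Illustrative only: the model name, API key, and note contents below are placeholders, not values from this commit.

    from khoj.processor.conversation.anthropic.anthropic_chat import (
        converse_anthropic,
        extract_questions_anthropic,
    )

    # Infer search queries from a user message (placeholder model and key)
    queries = extract_questions_anthropic(
        "What did I write about Claude last week?",
        model="claude-3-haiku-20240307",
        api_key="sk-ant-...",
    )

    # Stream a chat response grounded in retrieved notes
    response = converse_anthropic(
        references=["Note: tried the Haiku model for question extraction."],
        user_query="Summarize my notes on Anthropic models",
        model="claude-3-haiku-20240307",
        api_key="sk-ant-...",
    )
    print("".join(response))  # the returned ThreadedGenerator yields text chunks as they arrive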
1 change: 1 addition & 0 deletions pyproject.toml
@@ -85,6 +85,7 @@ dependencies = [
"pytz ~= 2024.1",
"cron-descriptor == 1.4.3",
"django_apscheduler == 0.6.2",
"anthropic == 0.26.1",
]
dynamic = ["version"]

4 changes: 3 additions & 1 deletion src/khoj/database/adapters/__init__.py
@@ -833,7 +833,9 @@ def get_valid_conversation_config(user: KhojUser, conversation: Conversation):

return conversation_config

-    if conversation_config.model_type == "openai" and conversation_config.openai_config:
+    if (
+        conversation_config.model_type == "openai" or conversation_config.model_type == "anthropic"
+    ) and conversation_config.openai_config:
return conversation_config

else:
@@ -0,0 +1,21 @@
# Generated by Django 4.2.10 on 2024-05-26 12:35

from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("database", "0042_serverchatsettings"),
]

operations = [
migrations.AlterField(
model_name="chatmodeloptions",
name="model_type",
field=models.CharField(
choices=[("openai", "Openai"), ("offline", "Offline"), ("anthropic", "Anthropic")],
default="offline",
max_length=200,
),
),
]
1 change: 1 addition & 0 deletions src/khoj/database/models/__init__.py
@@ -84,6 +84,7 @@ class ChatModelOptions(BaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
OFFLINE = "offline"
ANTHROPIC = "anthropic"

max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
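
With the new enum value, an Anthropic model can be registered like any other chat model. A minimal sketch, assuming the existing chat_model field on ChatModelOptions; the model name is a placeholder:

    from khoj.database.models import ChatModelOptions

    # Register a Claude model as a chat option (placeholder model name)
    ChatModelOptions.objects.create(
        chat_model="claude-3-haiku-20240307",
        model_type=ChatModelOptions.ModelType.ANTHROPIC,
    )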
Empty file.
204 changes: 204 additions & 0 deletions src/khoj/processor/conversation/anthropic/anthropic_chat.py
@@ -0,0 +1,204 @@
import json
import logging
import re
from datetime import datetime, timedelta
from typing import Dict, Optional

from langchain.schema import ChatMessage

from khoj.database.models import Agent
from khoj.processor.conversation import prompts
from khoj.processor.conversation.anthropic.utils import (
anthropic_chat_completion_with_backoff,
anthropic_completion_with_backoff,
)
from khoj.processor.conversation.utils import generate_chatml_messages_with_context
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
from khoj.utils.rawconfig import LocationData

logger = logging.getLogger(__name__)


def extract_questions_anthropic(
text,
model: Optional[str] = "claude-instant-1.2",
conversation_log={},
api_key=None,
temperature=0,
max_tokens=100,
location_data: LocationData = None,
):
"""
Infer search queries to retrieve relevant notes to answer user query
"""
    # Format location for the prompt
    location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"

# Extract Past User Message and Inferred Questions from Conversation Log
chat_history = "".join(
[
f'Q: {chat["intent"]["query"]}\nKhoj: {{"queries": {chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}}}\nA: {chat["message"]}\n\n'
for chat in conversation_log.get("chat", [])[-4:]
if chat["by"] == "khoj" and "text-to-image" not in chat["intent"].get("type")
]
)

# Get dates relative to today for prompt creation
today = datetime.today()
current_new_year = today.replace(month=1, day=1)
last_new_year = current_new_year.replace(year=today.year - 1)

system_prompt = prompts.extract_questions_anthropic_system_prompt.format(
current_date=today.strftime("%Y-%m-%d"),
day_of_week=today.strftime("%A"),
last_new_year=last_new_year.strftime("%Y"),
last_new_year_date=last_new_year.strftime("%Y-%m-%d"),
current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
location=location,
)

prompt = prompts.extract_questions_anthropic_user_message.format(
chat_history=chat_history,
text=text,
)

messages = [ChatMessage(content=prompt, role="user")]

response = anthropic_completion_with_backoff(
messages=messages,
system_prompt=system_prompt,
model_name=model,
temperature=temperature,
api_key=api_key,
max_tokens=max_tokens,
)

    # Extract, Clean Message from Claude's Response
    try:
        response = response.strip()
        match = re.search(r"\{.*?\}", response, re.DOTALL)  # re.DOTALL so multi-line JSON is captured
        if match:
            response = match.group()
        response = json.loads(response)
        questions = [q.strip() for q in response["queries"] if q.strip()]
        if not questions:
            logger.error(f"Invalid response for constructing subqueries: {response}")
            return [text]
    except (json.JSONDecodeError, KeyError, TypeError):
        logger.warning(f"Claude returned invalid JSON. Falling back to using user message as search query.\n{response}")
        questions = [text]
    logger.debug(f"Extracted Questions by Claude: {questions}")
    return questions


def anthropic_send_message_to_model(messages, api_key, model):
"""
Send message to model
"""
# Anthropic requires the first message to be a 'user' message, and the system prompt is not to be sent in the messages parameter
system_prompt = None

if len(messages) == 1:
messages[0].role = "user"
else:
system_prompt = ""
for message in messages.copy():
if message.role == "system":
system_prompt += message.content
messages.remove(message)

    # Get Response from Claude. Don't use response_type because Anthropic doesn't support it.
return anthropic_completion_with_backoff(
messages=messages,
system_prompt=system_prompt,
model_name=model,
api_key=api_key,
)


def converse_anthropic(
references,
user_query,
online_results: Optional[Dict[str, Dict]] = None,
conversation_log={},
model: Optional[str] = "claude-instant-1.2",
api_key: Optional[str] = None,
completion_func=None,
conversation_commands=[ConversationCommand.Default],
max_prompt_size=None,
tokenizer_name=None,
location_data: LocationData = None,
user_name: str = None,
agent: Agent = None,
):
"""
Converse with user using Anthropic's Claude
"""
# Initialize Variables
current_date = datetime.now().strftime("%Y-%m-%d")
compiled_references = "\n\n".join({f"# {item}" for item in references})

conversation_primer = prompts.query_prompt.format(query=user_query)

if agent and agent.personality:
system_prompt = prompts.custom_personality.format(
name=agent.name, bio=agent.personality, current_date=current_date
)
else:
system_prompt = prompts.personality.format(current_date=current_date)

if location_data:
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
location_prompt = prompts.user_location.format(location=location)
system_prompt = f"{system_prompt}\n{location_prompt}"

if user_name:
user_name_prompt = prompts.user_name.format(name=user_name)
system_prompt = f"{system_prompt}\n{user_name_prompt}"

# Get Conversation Primer appropriate to Conversation Type
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
completion_func(chat_response=prompts.no_notes_found.format())
return iter([prompts.no_notes_found.format()])
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
completion_func(chat_response=prompts.no_online_results_found.format())
return iter([prompts.no_online_results_found.format()])

if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
conversation_primer = (
f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}"
)
if not is_none_or_empty(compiled_references):
conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}"

# Setup Prompt with Primer or Conversation History
messages = generate_chatml_messages_with_context(
conversation_primer,
conversation_log=conversation_log,
model_name=model,
max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name,
)

for message in messages.copy():
if message.role == "system":
system_prompt += message.content
messages.remove(message)

truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
logger.debug(f"Conversation Context for Claude: {truncated_messages}")

# Get Response from Claude
return anthropic_chat_completion_with_backoff(
messages=messages,
compiled_references=references,
online_results=online_results,
model_name=model,
temperature=0,
api_key=api_key,
system_prompt=system_prompt,
completion_func=completion_func,
max_prompt_size=max_prompt_size,
)
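
As the comment in anthropic_send_message_to_model notes, Anthropic's API expects the first message to come from the user and takes the system prompt as a separate parameter rather than as a chat message. A minimal sketch of that split (message contents are hypothetical):

    from langchain.schema import ChatMessage

    messages = [
        ChatMessage(content="You are Khoj, a helpful personal assistant.", role="system"),
        ChatMessage(content="What's on my calendar today?", role="user"),
    ]

    # anthropic_send_message_to_model pulls the system message out, leaving:
    #   system_prompt = "You are Khoj, a helpful personal assistant."
    #   messages = [ChatMessage(content="What's on my calendar today?", role="user")]
    # anthropic_completion_with_backoff then forwards system_prompt via the
    # SDK's separate "system" parameter on client.messages.stream(...).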
116 changes: 116 additions & 0 deletions src/khoj/processor/conversation/anthropic/utils.py
@@ -0,0 +1,116 @@
import logging
from threading import Thread
from typing import Dict, List

import anthropic
from tenacity import (
before_sleep_log,
retry,
stop_after_attempt,
wait_exponential,
wait_random_exponential,
)

from khoj.processor.conversation.utils import ThreadedGenerator

logger = logging.getLogger(__name__)

anthropic_clients: Dict[str, anthropic.Anthropic] = {}


DEFAULT_MAX_TOKENS_ANTHROPIC = 3000


@retry(
wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(2),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
def anthropic_completion_with_backoff(
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None
) -> str:
if api_key not in anthropic_clients:
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
anthropic_clients[api_key] = client
else:
client = anthropic_clients[api_key]

formatted_messages = [{"role": message.role, "content": message.content} for message in messages]

aggregated_response = ""
max_tokens = max_tokens or DEFAULT_MAX_TOKENS_ANTHROPIC

model_kwargs = model_kwargs or dict()
if system_prompt:
model_kwargs["system"] = system_prompt

with client.messages.stream(
messages=formatted_messages,
model=model_name, # type: ignore
temperature=temperature,
timeout=20,
max_tokens=max_tokens,
**(model_kwargs),
) as stream:
for text in stream.text_stream:
aggregated_response += text

return aggregated_response


@retry(
wait=wait_exponential(multiplier=1, min=4, max=10),
stop=stop_after_attempt(2),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
def anthropic_chat_completion_with_backoff(
messages,
compiled_references,
online_results,
model_name,
temperature,
api_key,
system_prompt,
max_prompt_size=None,
completion_func=None,
model_kwargs=None,
):
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
t = Thread(
target=anthropic_llm_thread,
args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs),
)
t.start()
return g


def anthropic_llm_thread(
g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None
):
if api_key not in anthropic_clients:
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
anthropic_clients[api_key] = client
else:
client: anthropic.Anthropic = anthropic_clients[api_key]

formatted_messages: List[anthropic.types.MessageParam] = [
anthropic.types.MessageParam(role=message.role, content=message.content) for message in messages
]

max_prompt_size = max_prompt_size or DEFAULT_MAX_TOKENS_ANTHROPIC

    try:
        with client.messages.stream(
            messages=formatted_messages,
            model=model_name,  # type: ignore
            temperature=temperature,
            system=system_prompt,
            timeout=20,
            max_tokens=max_prompt_size,
            **(model_kwargs or dict()),
        ) as stream:
            for text in stream.text_stream:
                g.send(text)
    finally:
        # Always close the generator so the consuming thread is never left blocked
        g.close()
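
anthropic_chat_completion_with_backoff returns a ThreadedGenerator immediately while the background thread above feeds it streamed chunks. A simplified, self-contained stand-in for that pattern (not khoj's actual ThreadedGenerator implementation):

    import queue
    from threading import Thread

    class SimpleThreadedGenerator:
        # Queue-backed generator: a worker thread pushes chunks, the caller iterates.
        _DONE = object()  # sentinel marking end of stream

        def __init__(self):
            self._queue: queue.Queue = queue.Queue()

        def send(self, chunk: str):
            self._queue.put(chunk)

        def close(self):
            self._queue.put(self._DONE)

        def __iter__(self):
            while (item := self._queue.get()) is not self._DONE:
                yield item

    def worker(g: SimpleThreadedGenerator):
        for chunk in ["Hello", ", ", "world"]:
            g.send(chunk)
        g.close()

    g = SimpleThreadedGenerator()
    Thread(target=worker, args=(g,)).start()
    print("".join(g))  # blocks until close(), then prints "Hello, world"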
