Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Anthropic models #760

Merged
merged 5 commits into from
May 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ dependencies = [
"pytz ~= 2024.1",
"cron-descriptor == 1.4.3",
"django_apscheduler == 0.6.2",
"anthropic == 0.26.1",
]
dynamic = ["version"]

Expand Down
4 changes: 3 additions & 1 deletion src/khoj/database/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,7 +833,9 @@ def get_valid_conversation_config(user: KhojUser, conversation: Conversation):

return conversation_config

if conversation_config.model_type == "openai" and conversation_config.openai_config:
if (
conversation_config.model_type == "openai" or conversation_config.model_type == "anthropic"
) and conversation_config.openai_config:
return conversation_config

else:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Generated by Django 4.2.10 on 2024-05-26 12:35

from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("database", "0042_serverchatsettings"),
]

operations = [
migrations.AlterField(
model_name="chatmodeloptions",
name="model_type",
field=models.CharField(
choices=[("openai", "Openai"), ("offline", "Offline"), ("anthropic", "Anthropic")],
default="offline",
max_length=200,
),
),
]
1 change: 1 addition & 0 deletions src/khoj/database/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class ChatModelOptions(BaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
OFFLINE = "offline"
ANTHROPIC = "anthropic"

max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
Expand Down
Empty file.
204 changes: 204 additions & 0 deletions src/khoj/processor/conversation/anthropic/anthropic_chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
import json
import logging
import re
from datetime import datetime, timedelta
from typing import Dict, Optional

from langchain.schema import ChatMessage

from khoj.database.models import Agent
from khoj.processor.conversation import prompts
from khoj.processor.conversation.anthropic.utils import (
anthropic_chat_completion_with_backoff,
anthropic_completion_with_backoff,
)
from khoj.processor.conversation.utils import generate_chatml_messages_with_context
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
from khoj.utils.rawconfig import LocationData

logger = logging.getLogger(__name__)


def extract_questions_anthropic(
text,
model: Optional[str] = "claude-instant-1.2",
conversation_log={},
api_key=None,
temperature=0,
max_tokens=100,
location_data: LocationData = None,
):
"""
Infer search queries to retrieve relevant notes to answer user query
"""
# Extract Past User Message and Inferred Questions from Conversation Log
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"

# Extract Past User Message and Inferred Questions from Conversation Log
chat_history = "".join(
[
f'Q: {chat["intent"]["query"]}\nKhoj: {{"queries": {chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}}}\nA: {chat["message"]}\n\n'
for chat in conversation_log.get("chat", [])[-4:]
if chat["by"] == "khoj" and "text-to-image" not in chat["intent"].get("type")
]
)

# Get dates relative to today for prompt creation
today = datetime.today()
current_new_year = today.replace(month=1, day=1)
last_new_year = current_new_year.replace(year=today.year - 1)

system_prompt = prompts.extract_questions_anthropic_system_prompt.format(
current_date=today.strftime("%Y-%m-%d"),
day_of_week=today.strftime("%A"),
last_new_year=last_new_year.strftime("%Y"),
last_new_year_date=last_new_year.strftime("%Y-%m-%d"),
current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
location=location,
)

prompt = prompts.extract_questions_anthropic_user_message.format(
chat_history=chat_history,
text=text,
)

messages = [ChatMessage(content=prompt, role="user")]

response = anthropic_completion_with_backoff(
messages=messages,
system_prompt=system_prompt,
model_name=model,
temperature=temperature,
api_key=api_key,
max_tokens=max_tokens,
)

# Extract, Clean Message from Claude's Response
try:
response = response.strip()
match = re.search(r"\{.*?\}", response)
if match:
response = match.group()
response = json.loads(response)
response = [q.strip() for q in response["queries"] if q.strip()]
if not isinstance(response, list) or not response:
logger.error(f"Invalid response for constructing subqueries: {response}")
return [text]
return response
except:
logger.warning(f"Claude returned invalid JSON. Falling back to using user message as search query.\n{response}")
questions = [text]
logger.debug(f"Extracted Questions by Claude: {questions}")
return questions


def anthropic_send_message_to_model(messages, api_key, model):
"""
Send message to model
"""
# Anthropic requires the first message to be a 'user' message, and the system prompt is not to be sent in the messages parameter
system_prompt = None

if len(messages) == 1:
messages[0].role = "user"
else:
system_prompt = ""
for message in messages.copy():
if message.role == "system":
system_prompt += message.content
messages.remove(message)

# Get Response from GPT. Don't use response_type because Anthropic doesn't support it.
return anthropic_completion_with_backoff(
messages=messages,
system_prompt=system_prompt,
model_name=model,
api_key=api_key,
)


def converse_anthropic(
references,
user_query,
online_results: Optional[Dict[str, Dict]] = None,
conversation_log={},
model: Optional[str] = "claude-instant-1.2",
api_key: Optional[str] = None,
completion_func=None,
conversation_commands=[ConversationCommand.Default],
max_prompt_size=None,
tokenizer_name=None,
location_data: LocationData = None,
user_name: str = None,
agent: Agent = None,
):
"""
Converse with user using Anthropic's Claude
"""
# Initialize Variables
current_date = datetime.now().strftime("%Y-%m-%d")
compiled_references = "\n\n".join({f"# {item}" for item in references})

conversation_primer = prompts.query_prompt.format(query=user_query)

if agent and agent.personality:
system_prompt = prompts.custom_personality.format(
name=agent.name, bio=agent.personality, current_date=current_date
)
else:
system_prompt = prompts.personality.format(current_date=current_date)

if location_data:
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
location_prompt = prompts.user_location.format(location=location)
system_prompt = f"{system_prompt}\n{location_prompt}"

if user_name:
user_name_prompt = prompts.user_name.format(name=user_name)
system_prompt = f"{system_prompt}\n{user_name_prompt}"

# Get Conversation Primer appropriate to Conversation Type
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
completion_func(chat_response=prompts.no_notes_found.format())
return iter([prompts.no_notes_found.format()])
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
completion_func(chat_response=prompts.no_online_results_found.format())
return iter([prompts.no_online_results_found.format()])

if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
conversation_primer = (
f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}"
)
if not is_none_or_empty(compiled_references):
conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}"

# Setup Prompt with Primer or Conversation History
messages = generate_chatml_messages_with_context(
conversation_primer,
conversation_log=conversation_log,
model_name=model,
max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name,
)

for message in messages.copy():
if message.role == "system":
system_prompt += message.content
messages.remove(message)

truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
logger.debug(f"Conversation Context for Claude: {truncated_messages}")

# Get Response from Claude
return anthropic_chat_completion_with_backoff(
messages=messages,
compiled_references=references,
online_results=online_results,
model_name=model,
temperature=0,
api_key=api_key,
system_prompt=system_prompt,
completion_func=completion_func,
max_prompt_size=max_prompt_size,
)
116 changes: 116 additions & 0 deletions src/khoj/processor/conversation/anthropic/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import logging
from threading import Thread
from typing import Dict, List

import anthropic
from tenacity import (
before_sleep_log,
retry,
stop_after_attempt,
wait_exponential,
wait_random_exponential,
)

from khoj.processor.conversation.utils import ThreadedGenerator

logger = logging.getLogger(__name__)

anthropic_clients: Dict[str, anthropic.Anthropic] = {}


DEFAULT_MAX_TOKENS_ANTHROPIC = 3000


@retry(
wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(2),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
def anthropic_completion_with_backoff(
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None
) -> str:
if api_key not in anthropic_clients:
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
anthropic_clients[api_key] = client
else:
client = anthropic_clients[api_key]

formatted_messages = [{"role": message.role, "content": message.content} for message in messages]

aggregated_response = ""
max_tokens = max_tokens or DEFAULT_MAX_TOKENS_ANTHROPIC

model_kwargs = model_kwargs or dict()
if system_prompt:
model_kwargs["system"] = system_prompt

with client.messages.stream(
messages=formatted_messages,
model=model_name, # type: ignore
temperature=temperature,
timeout=20,
max_tokens=max_tokens,
**(model_kwargs),
) as stream:
for text in stream.text_stream:
aggregated_response += text

return aggregated_response


@retry(
wait=wait_exponential(multiplier=1, min=4, max=10),
stop=stop_after_attempt(2),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
def anthropic_chat_completion_with_backoff(
messages,
compiled_references,
online_results,
model_name,
temperature,
api_key,
system_prompt,
max_prompt_size=None,
completion_func=None,
model_kwargs=None,
):
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
t = Thread(
target=anthropic_llm_thread,
args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs),
)
t.start()
return g


def anthropic_llm_thread(
g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None
):
if api_key not in anthropic_clients:
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
anthropic_clients[api_key] = client
else:
client: anthropic.Anthropic = anthropic_clients[api_key]

formatted_messages: List[anthropic.types.MessageParam] = [
anthropic.types.MessageParam(role=message.role, content=message.content) for message in messages
]

max_prompt_size = max_prompt_size or DEFAULT_MAX_TOKENS_ANTHROPIC

with client.messages.stream(
messages=formatted_messages,
model=model_name, # type: ignore
temperature=temperature,
system=system_prompt,
timeout=20,
max_tokens=max_prompt_size,
**(model_kwargs or dict()),
) as stream:
for text in stream.text_stream:
g.send(text)

g.close()
Loading
Loading