Skip to content

Commit

Permalink
Modularize chat models initialization with a reusable function
Browse files Browse the repository at this point in the history
The interactive chat model initialization flow is fairly similar across
the chat model providers.

This should simplify adding new chat model providers and reduce
chances of bugs in the interactive chat model initialization flow.
  • Loading branch information
debanjum committed Sep 21, 2024
1 parent 27b99dc commit f4827b3
Showing 1 changed file with 100 additions and 126 deletions.
226 changes: 100 additions & 126 deletions src/khoj/utils/initialization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
from typing import Tuple

from khoj.database.adapters import ConversationAdapters
from khoj.database.models import (
Expand Down Expand Up @@ -41,41 +42,18 @@ def _create_chat_configuration():
"🗣️ Configure chat models available to your server. You can always update these at /server/admin using your admin account"
)

# Set up OpenAI's online models
default_openai_api_key = os.getenv("OPENAI_API_KEY")
default_use_openai_model = {True: "y", False: "n"}[default_openai_api_key != None]
use_model_provider = default_use_openai_model if not interactive else input("Add OpenAI models? (y/n): ")
if use_model_provider == "y":
logger.info("️💬 Setting up your OpenAI configuration")
if interactive:
user_api_key = input(f"Enter your OpenAI API key (default: {default_openai_api_key}): ")
api_key = user_api_key if user_api_key != "" else default_openai_api_key
else:
api_key = default_openai_api_key
chat_model_provider = OpenAIProcessorConversationConfig.objects.create(api_key=api_key, name="OpenAI")

if interactive:
chat_model_names = input(
f"Enter the OpenAI chat models you want to use (default: {','.join(default_openai_chat_models)}): "
)
chat_models = chat_model_names.split(",") if chat_model_names != "" else default_openai_chat_models
chat_models = [model.strip() for model in chat_models]
else:
chat_models = default_openai_chat_models

# Add OpenAI chat models
for chat_model in chat_models:
vision_enabled = chat_model in ["gpt-4o-mini", "gpt-4o"]
default_max_tokens = model_to_prompt_size.get(chat_model)
ChatModelOptions.objects.create(
chat_model=chat_model,
model_type=ChatModelOptions.ModelType.OPENAI,
max_prompt_size=default_max_tokens,
openai_config=chat_model_provider,
vision_enabled=vision_enabled,
)
# Set up OpenAI's online chat models
openai_configured, openai_provider = _setup_chat_model_provider(
ChatModelOptions.ModelType.OPENAI,
default_openai_chat_models,
default_api_key=os.getenv("OPENAI_API_KEY"),
vision_enabled=True,
is_offline=False,
interactive=interactive,
)

# Add OpenAI speech to text model
# Setup OpenAI speech to text model
if openai_configured:
default_speech2text_model = "whisper-1"
if interactive:
openai_speech2text_model = input(
Expand All @@ -88,7 +66,8 @@ def _create_chat_configuration():
model_name=openai_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OPENAI
)

# Add OpenAI text to image model
# Setup OpenAI text to image model
if openai_configured:
default_text_to_image_model = "dall-e-3"
if interactive:
openai_text_to_image_model = input(
Expand All @@ -98,107 +77,44 @@ def _create_chat_configuration():
else:
openai_text_to_image_model = default_text_to_image_model
TextToImageModelConfig.objects.create(
model_name=openai_text_to_image_model, model_type=TextToImageModelConfig.ModelType.OPENAI
model_name=openai_text_to_image_model,
model_type=TextToImageModelConfig.ModelType.OPENAI,
openai_config=openai_provider,
)

# Set up Google's Gemini online chat models
default_gemini_api_key = os.getenv("GEMINI_API_KEY")
default_use_gemini_model = {True: "y", False: "n"}[default_gemini_api_key != None]
use_model_provider = default_use_gemini_model if not interactive else input("Add Google's chat models? (y/n): ")
if use_model_provider == "y":
logger.info("️💬 Setting up your Google Gemini configuration")
if interactive:
user_api_key = input(f"Enter your Gemini API key (default: {default_gemini_api_key}): ")
api_key = user_api_key if user_api_key != "" else default_gemini_api_key
else:
api_key = default_gemini_api_key
chat_model_provider = OpenAIProcessorConversationConfig.objects.create(api_key=api_key, name="Gemini")

if interactive:
chat_model_names = input(
f"Enter the Gemini chat models you want to use (default: {','.join(default_gemini_chat_models)}): "
)
chat_models = chat_model_names.split(",") if chat_model_names != "" else default_gemini_chat_models
chat_models = [model.strip() for model in chat_models]
else:
chat_models = default_gemini_chat_models

# Add Gemini chat models
for chat_model in chat_models:
default_max_tokens = model_to_prompt_size.get(chat_model)
vision_enabled = False
ChatModelOptions.objects.create(
chat_model=chat_model,
model_type=ChatModelOptions.ModelType.GOOGLE,
max_prompt_size=default_max_tokens,
openai_config=chat_model_provider,
vision_enabled=False,
)
_setup_chat_model_provider(
ChatModelOptions.ModelType.GOOGLE,
default_gemini_chat_models,
default_api_key=os.getenv("GEMINI_API_KEY"),
vision_enabled=False,
is_offline=False,
interactive=interactive,
provider_name="Google Gemini",
)

# Set up Anthropic's online chat models
default_anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
default_use_anthropic_model = {True: "y", False: "n"}[default_anthropic_api_key != None]
use_model_provider = (
default_use_anthropic_model if not interactive else input("Add Anthropic's chat models? (y/n): ")
_setup_chat_model_provider(
ChatModelOptions.ModelType.ANTHROPIC,
default_anthropic_chat_models,
default_api_key=os.getenv("ANTHROPIC_API_KEY"),
vision_enabled=False,
is_offline=False,
interactive=interactive,
)
if use_model_provider == "y":
logger.info("️💬 Setting up your Anthropic configuration")
if interactive:
user_api_key = input(f"Enter your Anthropic API key (default: {default_anthropic_api_key}): ")
api_key = user_api_key if user_api_key != "" else default_anthropic_api_key
else:
api_key = default_anthropic_api_key
chat_model_provider = OpenAIProcessorConversationConfig.objects.create(api_key=api_key, name="Anthropic")

if interactive:
chat_model_names = input(
f"Enter the Anthropic chat models you want to use (default: {','.join(default_anthropic_chat_models)}): "
)
chat_models = chat_model_names.split(",") if chat_model_names != "" else default_anthropic_chat_models
chat_models = [model.strip() for model in chat_models]
else:
chat_models = default_anthropic_chat_models

# Add Anthropic chat models
for chat_model in chat_models:
vision_enabled = False
default_max_tokens = model_to_prompt_size.get(chat_model)
ChatModelOptions.objects.create(
chat_model=chat_model,
model_type=ChatModelOptions.ModelType.ANTHROPIC,
max_prompt_size=default_max_tokens,
openai_config=chat_model_provider,
vision_enabled=False,
)

# Set up offline chat models
use_model_provider = "y" if not interactive else input("Add Offline chat models? (y/n): ")
if use_model_provider == "y":
logger.info("️💬 Setting up Offline chat models")

if interactive:
chat_model_names = input(
f"Enter the offline chat models you want to use. See HuggingFace for available GGUF models (default: {','.join(default_offline_chat_models)}): "
)
chat_models = chat_model_names.split(",") if chat_model_names != "" else default_offline_chat_models
chat_models = [model.strip() for model in chat_models]
else:
chat_models = default_offline_chat_models

# Add chat models
for chat_model in chat_models:
default_max_tokens = model_to_prompt_size.get(chat_model)
default_tokenizer = model_to_tokenizer.get(chat_model)
ChatModelOptions.objects.create(
chat_model=chat_model,
model_type=ChatModelOptions.ModelType.OFFLINE,
max_prompt_size=default_max_tokens,
tokenizer=default_tokenizer,
)

chat_models_configured = ChatModelOptions.objects.count()
_setup_chat_model_provider(
ChatModelOptions.ModelType.OFFLINE,
default_offline_chat_models,
default_api_key=None,
vision_enabled=False,
is_offline=True,
interactive=interactive,
)

# Explicitly set default chat model
chat_models_configured = ChatModelOptions.objects.count()
if chat_models_configured > 0:
default_chat_model_name = ChatModelOptions.objects.first().chat_model
# If there are multiple chat models, ask the user to choose the default chat model
Expand Down Expand Up @@ -236,6 +152,64 @@ def _create_chat_configuration():

logger.info(f"🗣️ Offline speech to text model configured to {offline_speech2text_model}")

def _setup_chat_model_provider(
    model_type: ChatModelOptions.ModelType,
    default_chat_models: list,
    default_api_key: str,
    interactive: bool,
    vision_enabled: bool = False,
    is_offline: bool = False,
    provider_name: str = None,
) -> Tuple[bool, OpenAIProcessorConversationConfig]:
    """Create the provider config and chat model entries for one chat model provider.

    Args:
        model_type: Model type stored on every created ChatModelOptions row.
            Its ``name`` is also used to derive the display name when
            ``provider_name`` is not given.
        default_chat_models: Model names used when not interactive, or when
            the user enters no model names at the prompt.
        default_api_key: API key used when not interactive, or when the user
            enters nothing at the API key prompt. May be None (e.g. for
            offline models or when the environment variable is unset).
        interactive: When True, ask the user whether to add the provider,
            which API key to use and which models to add. When False, the
            provider is added only if an API key is available or it is offline.
        vision_enabled: Whether the provider may have vision capable models.
            A model is only marked vision enabled if this is True and the
            model is in the list of supported vision models.
        is_offline: When True, no API key is requested and no provider config
            row is created.
        provider_name: Display name used in prompts, logs and as the name of
            the provider config. Defaults to the capitalized model type name.

    Returns:
        A ``(configured, provider)`` tuple. ``configured`` is False when the
        provider was skipped, in which case ``provider`` is None. ``provider``
        is also None for offline providers.
    """
    supported_vision_models = ["gpt-4o-mini", "gpt-4o"]
    provider_name = provider_name or model_type.name.capitalize()

    if interactive:
        # Normalize the answer so "Y", " y " etc. are not silently read as "no".
        use_model_provider = input(f"Add {provider_name} chat models? (y/n): ").strip().lower()
    else:
        use_model_provider = "y" if (default_api_key is not None or is_offline) else "n"

    if use_model_provider != "y":
        return False, None

    logger.info(f"️💬 Setting up your {provider_name} chat configuration")

    # Offline models need no API key, so no provider config is created for them.
    chat_model_provider = None
    if not is_offline:
        if interactive:
            user_api_key = input(f"Enter your {provider_name} API key (default: {default_api_key}): ")
            api_key = user_api_key if user_api_key != "" else default_api_key
        else:
            api_key = default_api_key
        chat_model_provider = OpenAIProcessorConversationConfig.objects.create(api_key=api_key, name=provider_name)

    chat_models = default_chat_models
    if interactive:
        chat_model_names = input(
            f"Enter the {provider_name} chat models you want to use (default: {','.join(default_chat_models)}): "
        )
        # Drop blank entries (e.g. from a trailing comma) so no chat model with
        # an empty name is created; fall back to the defaults if nothing is left.
        entered_models = [name.strip() for name in chat_model_names.split(",") if name.strip()]
        chat_models = entered_models or default_chat_models

    for chat_model in chat_models:
        # Decide vision support per model. The `vision_enabled` argument must not
        # be overwritten here, otherwise the first non-vision model would disable
        # vision for every model that follows it.
        model_vision_enabled = vision_enabled and chat_model in supported_vision_models

        ChatModelOptions.objects.create(
            chat_model=chat_model,
            model_type=model_type,
            max_prompt_size=model_to_prompt_size.get(chat_model),
            vision_enabled=model_vision_enabled,
            tokenizer=model_to_tokenizer.get(chat_model),
            openai_config=chat_model_provider,
        )

    logger.info(f"🗣️ {provider_name} chat model configuration complete")
    return True, chat_model_provider

admin_user = KhojUser.objects.filter(is_staff=True).first()
if admin_user is None:
while True:
Expand Down

0 comments on commit f4827b3

Please sign in to comment.