Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Google embedding support & update setup #550

Merged
merged 1 commit into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions flowsettings.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

KH_ENABLE_FIRST_SETUP = True
KH_DEMO_MODE = config("KH_DEMO_MODE", default=False, cast=bool)
KH_OLLAMA_URL = config("KH_OLLAMA_URL", default="http://localhost:11434/v1/")

# App can be run from anywhere and it's not trivial to decide where to store app data.
# So let's use the same directory as the flowsettings.py file.
Expand Down Expand Up @@ -162,7 +163,7 @@
KH_LLMS["ollama"] = {
"spec": {
"__type__": "kotaemon.llms.ChatOpenAI",
"base_url": "http://localhost:11434/v1/",
"base_url": KH_OLLAMA_URL,
"model": config("LOCAL_MODEL", default="llama3.1:8b"),
"api_key": "ollama",
},
Expand All @@ -171,7 +172,7 @@
KH_EMBEDDINGS["ollama"] = {
"spec": {
"__type__": "kotaemon.embeddings.OpenAIEmbeddings",
"base_url": "http://localhost:11434/v1/",
"base_url": KH_OLLAMA_URL,
"model": config("LOCAL_MODEL_EMBEDDINGS", default="nomic-embed-text"),
"api_key": "ollama",
},
Expand All @@ -195,11 +196,11 @@
},
"default": False,
}
KH_LLMS["gemini"] = {
KH_LLMS["google"] = {
"spec": {
"__type__": "kotaemon.llms.chats.LCGeminiChat",
"model_name": "gemini-1.5-pro",
"api_key": "your-key",
"model_name": "gemini-1.5-flash",
"api_key": config("GOOGLE_API_KEY", default="your-key"),
},
"default": False,
}
Expand Down Expand Up @@ -231,6 +232,13 @@
},
"default": False,
}
KH_EMBEDDINGS["google"] = {
"spec": {
"__type__": "kotaemon.embeddings.LCGoogleEmbeddings",
"model": "models/text-embedding-004",
"google_api_key": config("GOOGLE_API_KEY", default="your-key"),
}
}
# KH_EMBEDDINGS["huggingface"] = {
# "spec": {
# "__type__": "kotaemon.embeddings.LCHuggingFaceEmbeddings",
Expand Down
2 changes: 2 additions & 0 deletions libs/kotaemon/kotaemon/embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .langchain_based import (
LCAzureOpenAIEmbeddings,
LCCohereEmbeddings,
LCGoogleEmbeddings,
LCHuggingFaceEmbeddings,
LCOpenAIEmbeddings,
)
Expand All @@ -18,6 +19,7 @@
"LCAzureOpenAIEmbeddings",
"LCCohereEmbeddings",
"LCHuggingFaceEmbeddings",
"LCGoogleEmbeddings",
"OpenAIEmbeddings",
"AzureOpenAIEmbeddings",
"FastEmbedEmbeddings",
Expand Down
35 changes: 35 additions & 0 deletions libs/kotaemon/kotaemon/embeddings/langchain_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,38 @@ def _get_lc_class(self):
from langchain.embeddings import HuggingFaceBgeEmbeddings

return HuggingFaceBgeEmbeddings


class LCGoogleEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
    """Wrapper around Langchain's Google GenAI embedding, focusing on key parameters.

    Delegates the actual embedding work to
    ``langchain_google_genai.GoogleGenerativeAIEmbeddings``, which is imported
    lazily in ``_get_lc_class`` so the extra dependency is only required when
    this embedding is actually used.
    """

    google_api_key: str = Param(
        help="API key (https://aistudio.google.com/app/apikey)",
        default=None,
        required=True,
    )
    model: str = Param(
        help="Model name to use (https://ai.google.dev/gemini-api/docs/models/gemini#text-embedding-and-embedding)",  # noqa
        default="models/text-embedding-004",
        required=True,
    )

    def __init__(
        self,
        model: str = "models/text-embedding-004",
        google_api_key: Optional[str] = None,
        **params,
    ):
        """Store the model name and API key.

        Args:
            model: Google embedding model identifier (``models/...`` form).
            google_api_key: API key from Google AI Studio; ``None`` defers to
                whatever resolution the underlying Langchain class performs.
            **params: forwarded verbatim to the base class.
        """
        super().__init__(
            model=model,
            google_api_key=google_api_key,
            **params,
        )

    def _get_lc_class(self):
        """Return the Langchain embedding class, importing it lazily.

        Raises:
            ImportError: if ``langchain-google-genai`` is not installed.
        """
        try:
            from langchain_google_genai import GoogleGenerativeAIEmbeddings
        except ImportError as e:
            # Chain the original exception (B904) so the underlying
            # module-resolution error is preserved in the traceback.
            raise ImportError("Please install langchain-google-genai") from e

        return GoogleGenerativeAIEmbeddings
2 changes: 2 additions & 0 deletions libs/ktem/ktem/embeddings/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def load_vendors(self):
AzureOpenAIEmbeddings,
FastEmbedEmbeddings,
LCCohereEmbeddings,
LCGoogleEmbeddings,
LCHuggingFaceEmbeddings,
OpenAIEmbeddings,
TeiEndpointEmbeddings,
Expand All @@ -68,6 +69,7 @@ def load_vendors(self):
FastEmbedEmbeddings,
LCCohereEmbeddings,
LCHuggingFaceEmbeddings,
LCGoogleEmbeddings,
TeiEndpointEmbeddings,
]

Expand Down
67 changes: 57 additions & 10 deletions libs/ktem/ktem/pages/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
from theflow.settings import settings as flowsettings

KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False)
DEFAULT_OLLAMA_URL = "http://localhost:11434/api"
KH_OLLAMA_URL = getattr(flowsettings, "KH_OLLAMA_URL", "http://localhost:11434/v1/")
DEFAULT_OLLAMA_URL = KH_OLLAMA_URL.replace("v1", "api")
if DEFAULT_OLLAMA_URL.endswith("/"):
DEFAULT_OLLAMA_URL = DEFAULT_OLLAMA_URL[:-1]


DEMO_MESSAGE = (
Expand Down Expand Up @@ -55,8 +58,9 @@ def on_building_ui(self):
gr.Markdown(f"# Welcome to {self._app.app_name} first setup!")
self.radio_model = gr.Radio(
[
("Cohere API (*free registration* available) - recommended", "cohere"),
("OpenAI API (for more advance models)", "openai"),
("Cohere API (*free registration*) - recommended", "cohere"),
("Google API (*free registration*)", "google"),
("OpenAI API (for GPT-based models)", "openai"),
("Local LLM (for completely *private RAG*)", "ollama"),
],
label="Select your model provider",
Expand Down Expand Up @@ -92,6 +96,18 @@ def on_building_ui(self):
show_label=False, placeholder="Cohere API Key"
)

with gr.Column(visible=False) as self.google_option:
gr.Markdown(
(
"#### Google API Key\n\n"
"(register your free API key "
"at https://aistudio.google.com/app/apikey)"
)
)
self.google_api_key = gr.Textbox(
show_label=False, placeholder="Google API Key"
)

with gr.Column(visible=False) as self.ollama_option:
gr.Markdown(
(
Expand Down Expand Up @@ -119,7 +135,12 @@ def on_register_events(self):
self.openai_api_key.submit,
],
fn=self.update_model,
inputs=[self.cohere_api_key, self.openai_api_key, self.radio_model],
inputs=[
self.cohere_api_key,
self.openai_api_key,
self.google_api_key,
self.radio_model,
],
outputs=[self.setup_log],
show_progress="hidden",
)
Expand Down Expand Up @@ -147,13 +168,19 @@ def on_register_events(self):
fn=self.switch_options_view,
inputs=[self.radio_model],
show_progress="hidden",
outputs=[self.cohere_option, self.openai_option, self.ollama_option],
outputs=[
self.cohere_option,
self.openai_option,
self.ollama_option,
self.google_option,
],
)

def update_model(
self,
cohere_api_key,
openai_api_key,
google_api_key,
radio_model_value,
):
# skip if KH_DEMO_MODE
Expand Down Expand Up @@ -221,12 +248,32 @@ def update_model(
},
default=True,
)
elif radio_model_value == "google":
if google_api_key:
llms.update(
name="google",
spec={
"__type__": "kotaemon.llms.chats.LCGeminiChat",
"model_name": "gemini-1.5-flash",
"api_key": google_api_key,
},
default=True,
)
embeddings.update(
name="google",
spec={
"__type__": "kotaemon.embeddings.LCGoogleEmbeddings",
"model": "models/text-embedding-004",
"google_api_key": google_api_key,
},
default=True,
)
elif radio_model_value == "ollama":
llms.update(
name="ollama",
spec={
"__type__": "kotaemon.llms.ChatOpenAI",
"base_url": "http://localhost:11434/v1/",
"base_url": KH_OLLAMA_URL,
"model": "llama3.1:8b",
"api_key": "ollama",
},
Expand All @@ -236,7 +283,7 @@ def update_model(
name="ollama",
spec={
"__type__": "kotaemon.embeddings.OpenAIEmbeddings",
"base_url": "http://localhost:11434/v1/",
"base_url": KH_OLLAMA_URL,
"model": "nomic-embed-text",
"api_key": "ollama",
},
Expand Down Expand Up @@ -270,7 +317,7 @@ def update_model(
yield log_content
except Exception as e:
log_content += (
"Make sure you have download and installed Ollama correctly."
"Make sure you have download and installed Ollama correctly. "
f"Got error: {str(e)}"
)
yield log_content
Expand Down Expand Up @@ -345,9 +392,9 @@ def update_default_settings(self, radio_model_value, default_settings):
return default_settings

def switch_options_view(self, radio_model_value):
components_visible = [gr.update(visible=False) for _ in range(3)]
components_visible = [gr.update(visible=False) for _ in range(4)]

values = ["cohere", "openai", "ollama", None]
values = ["cohere", "openai", "ollama", "google", None]
assert radio_model_value in values, f"Invalid value {radio_model_value}"

if radio_model_value is not None:
Expand Down
Loading