Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Modified default index naming in vector_store to use a unique identifier #9

Merged
Merged 2 commits on Jan 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 28 additions & 11 deletions build/lib/lyzr/base/vector_store.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from typing import Optional, Sequence

import os
import uuid
import weaviate
from weaviate.embedded import EmbeddedOptions
from llama_index import Document, ServiceContext, VectorStoreIndex, StorageContext
from llama_index.node_parser import SimpleNodeParser

Expand All @@ -13,30 +17,43 @@ def import_vector_store_class(vector_store_class_name: str):
class LyzrVectorStoreIndex:
@staticmethod
def from_defaults(
vector_store_type: str = "LanceDBVectorStore",
vector_store_type: str = "WeaviateVectorStore",
documents: Optional[Sequence[Document]] = None,
service_context: Optional[ServiceContext] = None,
**kwargs
) -> VectorStoreIndex:
if documents is None and vector_store_type == "SimpleVectorStore":
raise ValueError("documents must be provided for SimpleVectorStore")

vector_store_class = import_vector_store_class(vector_store_type)
VectorStoreClass = import_vector_store_class(vector_store_type)

if vector_store_type == "WeaviateVectorStore":
weaviate_client = weaviate.Client(
embedded_options=weaviate.embedded.EmbeddedOptions(),
additional_headers={"X-OpenAI-Api-Key": os.environ["OPENAI_API_KEY"]},
)
kwargs["weaviate_client"] = (
weaviate_client
if "weaviate_client" not in kwargs
else kwargs["weaviate_client"]
)
kwargs["index_name"] = (
f"DB_{uuid.uuid4().hex}" if "index_name" not in kwargs else kwargs["index_name"]
)

vector_store = VectorStoreClass(**kwargs)
else:
vector_store = VectorStoreClass(**kwargs)

if documents is None:
vector_store = vector_store_class(**kwargs)
index = VectorStoreIndex.from_vector_store(
vector_store=vector_store, service_context=service_context
)
else:
if vector_store_type == "LanceDBVectorStore":
kwargs["uri"] = "./.lancedb" if "uri" not in kwargs else kwargs["uri"]
kwargs["table_name"] = (
"vectors" if "table_name" not in kwargs else kwargs["table_name"]
)
vector_store = vector_store_class(**kwargs)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
return index

storage_context = StorageContext.from_defaults(vector_store=vector_store)

if documents is not None:
index = VectorStoreIndex.from_documents(
documents=documents,
storage_context=storage_context,
Expand Down
12 changes: 12 additions & 0 deletions build/lib/lyzr/chatqa/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def pdf_chat(
vector_store_params: dict = None,
service_context_params: dict = None,
chat_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseChatEngine:
return pdf_chat_(
input_dir=input_dir,
Expand All @@ -57,6 +58,7 @@ def pdf_chat(
vector_store_params=vector_store_params,
service_context_params=service_context_params,
chat_engine_params=chat_engine_params,
retriever_params=retriever_params,
)

@staticmethod
Expand All @@ -74,6 +76,7 @@ def docx_chat(
vector_store_params: dict = None,
service_context_params: dict = None,
chat_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseChatEngine:
return docx_chat_(
input_dir=input_dir,
Expand All @@ -89,6 +92,7 @@ def docx_chat(
vector_store_params=vector_store_params,
service_context_params=service_context_params,
chat_engine_params=chat_engine_params,
retriever_params=retriever_params,
)

@staticmethod
Expand All @@ -106,6 +110,7 @@ def txt_chat(
vector_store_params: dict = None,
service_context_params: dict = None,
chat_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseChatEngine:
return txt_chat_(
input_dir=input_dir,
Expand All @@ -121,6 +126,7 @@ def txt_chat(
vector_store_params=vector_store_params,
service_context_params=service_context_params,
chat_engine_params=chat_engine_params,
retriever_params=retriever_params,
)

@staticmethod
Expand All @@ -133,6 +139,7 @@ def webpage_chat(
vector_store_params: dict = None,
service_context_params: dict = None,
chat_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseChatEngine:
return webpage_chat_(
url=url,
Expand All @@ -143,6 +150,7 @@ def webpage_chat(
vector_store_params=vector_store_params,
service_context_params=service_context_params,
chat_engine_params=chat_engine_params,
retriever_params=retriever_params,
)

@staticmethod
Expand All @@ -155,6 +163,7 @@ def website_chat(
vector_store_params: dict = None,
service_context_params: dict = None,
chat_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseChatEngine:
return website_chat_(
url=url,
Expand All @@ -165,6 +174,7 @@ def website_chat(
vector_store_params=vector_store_params,
service_context_params=service_context_params,
chat_engine_params=chat_engine_params,
retriever_params=retriever_params,
)

@staticmethod
Expand All @@ -177,6 +187,7 @@ def youtube_chat(
vector_store_params: dict = None,
service_context_params: dict = None,
chat_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseChatEngine:
return youtube_chat_(
urls=urls,
Expand All @@ -187,4 +198,5 @@ def youtube_chat(
vector_store_params=vector_store_params,
service_context_params=service_context_params,
chat_engine_params=chat_engine_params,
retriever_params=retriever_params,
)
12 changes: 12 additions & 0 deletions build/lib/lyzr/chatqa/qa_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def pdf_qa(
vector_store_params: dict = None,
service_context_params: dict = None,
query_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseQueryEngine:
return pdf_rag(
input_dir=input_dir,
Expand All @@ -56,6 +57,7 @@ def pdf_qa(
vector_store_params=vector_store_params,
service_context_params=service_context_params,
query_engine_params=query_engine_params,
retriever_params=retriever_params,
)

@staticmethod
Expand All @@ -73,6 +75,7 @@ def docx_qa(
vector_store_params: dict = None,
service_context_params: dict = None,
query_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseQueryEngine:
return docx_rag(
input_dir=input_dir,
Expand All @@ -88,6 +91,7 @@ def docx_qa(
vector_store_params=vector_store_params,
service_context_params=service_context_params,
query_engine_params=query_engine_params,
retriever_params=retriever_params,
)

@staticmethod
Expand All @@ -105,6 +109,7 @@ def txt_qa(
vector_store_params: dict = None,
service_context_params: dict = None,
query_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseQueryEngine:
return txt_rag(
input_dir=input_dir,
Expand All @@ -120,6 +125,7 @@ def txt_qa(
vector_store_params=vector_store_params,
service_context_params=service_context_params,
query_engine_params=query_engine_params,
retriever_params=retriever_params,
)

@staticmethod
Expand All @@ -132,6 +138,7 @@ def webpage_qa(
vector_store_params: dict = None,
service_context_params: dict = None,
query_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseQueryEngine:
return webpage_rag(
url=url,
Expand All @@ -142,6 +149,7 @@ def webpage_qa(
vector_store_params=vector_store_params,
service_context_params=service_context_params,
query_engine_params=query_engine_params,
retriever_params=retriever_params,
)

@staticmethod
Expand All @@ -154,6 +162,7 @@ def website_qa(
vector_store_params: dict = None,
service_context_params: dict = None,
query_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseQueryEngine:
return website_rag(
url=url,
Expand All @@ -164,6 +173,7 @@ def website_qa(
vector_store_params=vector_store_params,
service_context_params=service_context_params,
query_engine_params=query_engine_params,
retriever_params=retriever_params,
)

@staticmethod
Expand All @@ -176,6 +186,7 @@ def youtube_qa(
vector_store_params: dict = None,
service_context_params: dict = None,
query_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseQueryEngine:
return youtube_rag(
urls=urls,
Expand All @@ -186,4 +197,5 @@ def youtube_qa(
vector_store_params=vector_store_params,
service_context_params=service_context_params,
query_engine_params=query_engine_params,
retriever_params=retriever_params,
)
17 changes: 11 additions & 6 deletions build/lib/lyzr/utils/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def pdf_chat_(

llm_params = {} if llm_params is None else llm_params
vector_store_params = (
{"vector_store_type": "LanceDBVectorStore"}
{"vector_store_type": "WeaviateVectorStore"}
if vector_store_params is None
else vector_store_params
)
Expand Down Expand Up @@ -106,6 +106,7 @@ def txt_chat_(
vector_store_params: dict = None,
service_context_params: dict = None,
chat_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseChatEngine:
documents = read_txt_as_documents(
input_dir=input_dir,
Expand All @@ -118,7 +119,7 @@ def txt_chat_(

llm_params = {} if llm_params is None else llm_params
vector_store_params = (
{"vector_store_type": "LanceDBVectorStore"}
{"vector_store_type": "WeaviateVectorStore"}
if vector_store_params is None
else vector_store_params
)
Expand Down Expand Up @@ -177,6 +178,7 @@ def docx_chat_(
vector_store_params: dict = None,
service_context_params: dict = None,
chat_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseChatEngine:
documents = read_docx_as_documents(
input_dir=input_dir,
Expand All @@ -189,7 +191,7 @@ def docx_chat_(

llm_params = {} if llm_params is None else llm_params
vector_store_params = (
{"vector_store_type": "LanceDBVectorStore"}
{"vector_store_type": "WeaviateVectorStore"}
if vector_store_params is None
else vector_store_params
)
Expand Down Expand Up @@ -243,14 +245,15 @@ def webpage_chat_(
vector_store_params: dict = None,
service_context_params: dict = None,
chat_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseChatEngine:
documents = read_webpage_as_documents(
url=url,
)

llm_params = {} if llm_params is None else llm_params
vector_store_params = (
{"vector_store_type": "LanceDBVectorStore"}
{"vector_store_type": "WeaviateVectorStore"}
if vector_store_params is None
else vector_store_params
)
Expand Down Expand Up @@ -304,14 +307,15 @@ def website_chat_(
vector_store_params: dict = None,
service_context_params: dict = None,
chat_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseChatEngine:
documents = read_website_as_documents(
url=url,
)

llm_params = {} if llm_params is None else llm_params
vector_store_params = (
{"vector_store_type": "LanceDBVectorStore"}
{"vector_store_type": "WeaviateVectorStore"}
if vector_store_params is None
else vector_store_params
)
Expand Down Expand Up @@ -365,14 +369,15 @@ def youtube_chat_(
vector_store_params: dict = None,
service_context_params: dict = None,
chat_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseChatEngine:
documents = read_youtube_as_documents(
urls=urls,
)

llm_params = {} if llm_params is None else llm_params
vector_store_params = (
{"vector_store_type": "LanceDBVectorStore"}
{"vector_store_type": "WeaviateVectorStore"}
if vector_store_params is None
else vector_store_params
)
Expand Down
Loading