Skip to content

Commit

Permalink
Merge pull request #8 from LyzrCore/wip-khush/refactor-chatbot
Browse files Browse the repository at this point in the history
Enhanced the chatbot and QA bot with retrievers (query fusion) and a reranker (MMR), and switched the default vector store to Weaviate
  • Loading branch information
patel-lyzr authored Jan 12, 2024
2 parents bc32b7c + 0a31c76 commit 079d01f
Show file tree
Hide file tree
Showing 13 changed files with 545 additions and 84 deletions.
2 changes: 2 additions & 0 deletions build/lib/lyzr/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from lyzr.base.llms import LLM, get_model
from lyzr.base.service import LyzrService
from lyzr.base.vector_store import LyzrVectorStoreIndex
from lyzr.base.retrievers import LyzrRetriever
from lyzr.base.prompt import Prompt

__all__ = [
Expand All @@ -14,4 +15,5 @@
"get_model",
"read_file",
"describe_dataset",
"LyzrRetriever",
]
44 changes: 44 additions & 0 deletions build/lib/lyzr/base/retrievers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import Optional, Sequence
from llama_index.retrievers import BaseRetriever
from llama_index.indices import VectorStoreIndex


def import_retriever_class(
    retriever_class_name: str,
    module_name: str = "llama_index.retrievers",
):
    """Dynamically load a retriever class by name.

    Args:
        retriever_class_name: Name of the class to look up, e.g.
            ``"QueryFusionRetriever"``.
        module_name: Dotted module path to import the class from.
            Defaults to ``"llama_index.retrievers"`` to preserve the
            original behavior.

    Returns:
        The class object named ``retriever_class_name``.

    Raises:
        ImportError: If ``module_name`` cannot be imported.
        AttributeError: If the module has no attribute named
            ``retriever_class_name``.
    """
    # importlib.import_module is the documented, preferred alternative
    # to calling __import__ directly; local import keeps this helper
    # self-contained without touching the file's import block.
    import importlib

    module = importlib.import_module(module_name)
    return getattr(module, retriever_class_name)


class LyzrRetriever:
    """Factory for building llama_index retrievers with Lyzr defaults."""

    @staticmethod
    def from_defaults(
        retriever_type: str = "QueryFusionRetriever",
        base_index: Optional[VectorStoreIndex] = None,
        **kwargs
    ) -> BaseRetriever:
        """Construct a retriever of the requested type.

        Args:
            retriever_type: Name of a retriever class exported by
                ``llama_index.retrievers``.
            base_index: Index whose ``as_retriever`` provides the base
                retrievers. Required when ``retriever_type`` is
                ``"QueryFusionRetriever"``; ignored otherwise.
            **kwargs: Extra keyword arguments forwarded to the retriever
                constructor.

        Returns:
            The constructed retriever instance.

        Raises:
            ValueError: If ``retriever_type`` is
                ``"QueryFusionRetriever"`` and ``base_index`` is None.
        """
        RetrieverClass = import_retriever_class(retriever_type)

        if retriever_type == "QueryFusionRetriever":
            # Fail fast with a clear message instead of an AttributeError
            # from base_index.as_retriever(...) below.
            if base_index is None:
                raise ValueError(
                    "base_index must be provided for QueryFusionRetriever"
                )
            # Two identically-configured MMR retrievers, matching the
            # original setup; the fusion retriever merges their results.
            mmr_retrievers = [
                base_index.as_retriever(
                    vector_store_query_mode="mmr",
                    similarity_top_k=3,
                    vector_store_kwargs={"mmr_threshold": 0.1},
                )
                for _ in range(2)
            ]
            return RetrieverClass(
                retrievers=mmr_retrievers,
                similarity_top_k=5,
                num_queries=2,
                use_async=False,
                **kwargs
            )

        # Any other retriever type: construct directly from kwargs.
        return RetrieverClass(**kwargs)
22 changes: 11 additions & 11 deletions build/lib/lyzr/base/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from llama_index import Document, ServiceContext, VectorStoreIndex, StorageContext
from llama_index.node_parser import SimpleNodeParser


def import_vector_store_class(vector_store_class_name: str):
module = __import__("llama_index.vector_stores", fromlist=[vector_store_class_name])
class_ = getattr(module, vector_store_class_name)
Expand All @@ -12,12 +13,11 @@ def import_vector_store_class(vector_store_class_name: str):
class LyzrVectorStoreIndex:
@staticmethod
def from_defaults(
vector_store_type: str = "LanceDBVectorStore",
documents: Optional[Sequence[Document]] = None,
service_context: Optional[ServiceContext] = None,
**kwargs
vector_store_type: str = "LanceDBVectorStore",
documents: Optional[Sequence[Document]] = None,
service_context: Optional[ServiceContext] = None,
**kwargs
) -> VectorStoreIndex:

if documents is None and vector_store_type == "SimpleVectorStore":
raise ValueError("documents must be provided for SimpleVectorStore")

Expand All @@ -36,12 +36,12 @@ def from_defaults(
)
vector_store = vector_store_class(**kwargs)
storage_context = StorageContext.from_defaults(vector_store=vector_store)

index = VectorStoreIndex.from_documents(
documents=documents,
storage_context=storage_context,
service_context=service_context,
show_progress=True,
)
documents=documents,
storage_context=storage_context,
service_context=service_context,
show_progress=True,
)

return index
137 changes: 125 additions & 12 deletions build/lib/lyzr/utils/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@

from llama_index.chat_engine.types import BaseChatEngine, ChatMode
from llama_index.embeddings.utils import EmbedType
from llama_index.chat_engine import ContextChatEngine
from llama_index.memory import ChatMemoryBuffer

from lyzr.base.llm import LyzrLLMFactory
from lyzr.base.service import LyzrService
from lyzr.base.vector_store import LyzrVectorStoreIndex
from lyzr.base.retrievers import LyzrRetriever

from lyzr.utils.document_reading import (
read_pdf_as_documents,
read_docx_as_documents,
Expand All @@ -30,6 +34,7 @@ def pdf_chat_(
vector_store_params: dict = None,
service_context_params: dict = None,
chat_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseChatEngine:
documents = read_pdf_as_documents(
input_dir=input_dir,
Expand All @@ -51,6 +56,12 @@ def pdf_chat_(
)
chat_engine_params = {} if chat_engine_params is None else chat_engine_params

retriever_params = (
{"retriever_type": "QueryFusionRetriever"}
if retriever_params is None
else retriever_params
)

llm = LyzrLLMFactory.from_defaults(**llm_params)
service_context = LyzrService.from_defaults(
llm=llm,
Expand All @@ -64,10 +75,22 @@ def pdf_chat_(
**vector_store_params, documents=documents, service_context=service_context
)

return vector_store_index.as_chat_engine(
**chat_engine_params, chat_mode=ChatMode.CONTEXT, similarity_top_k=5
retriever = LyzrRetriever.from_defaults(
**retriever_params, base_index=vector_store_index
)

memory = ChatMemoryBuffer.from_defaults(token_limit=4000)

chat_engine = ContextChatEngine(
llm=llm,
memory=memory,
retriever=retriever,
prefix_messages=list(),
**chat_engine_params,
)

return chat_engine


def txt_chat_(
input_dir: Optional[str] = None,
Expand Down Expand Up @@ -104,6 +127,12 @@ def txt_chat_(
)
chat_engine_params = {} if chat_engine_params is None else chat_engine_params

retriever_params = (
{"retriever_type": "QueryFusionRetriever"}
if retriever_params is None
else retriever_params
)

llm = LyzrLLMFactory.from_defaults(**llm_params)
service_context = LyzrService.from_defaults(
llm=llm,
Expand All @@ -117,10 +146,22 @@ def txt_chat_(
**vector_store_params, documents=documents, service_context=service_context
)

return vector_store_index.as_chat_engine(
**chat_engine_params, chat_mode=ChatMode.CONTEXT, similarity_top_k=5
retriever = LyzrRetriever.from_defaults(
**retriever_params, base_index=vector_store_index
)

memory = ChatMemoryBuffer.from_defaults(token_limit=4000)

chat_engine = ContextChatEngine(
llm=llm,
memory=memory,
retriever=retriever,
prefix_messages=list(),
**chat_engine_params,
)

return chat_engine


def docx_chat_(
input_dir: Optional[str] = None,
Expand Down Expand Up @@ -157,6 +198,12 @@ def docx_chat_(
)
chat_engine_params = {} if chat_engine_params is None else chat_engine_params

retriever_params = (
{"retriever_type": "QueryFusionRetriever"}
if retriever_params is None
else retriever_params
)

llm = LyzrLLMFactory.from_defaults(**llm_params)
service_context = LyzrService.from_defaults(
llm=llm,
Expand All @@ -170,10 +217,22 @@ def docx_chat_(
**vector_store_params, documents=documents, service_context=service_context
)

return vector_store_index.as_chat_engine(
**chat_engine_params, chat_mode=ChatMode.CONTEXT, similarity_top_k=5
retriever = LyzrRetriever.from_defaults(
**retriever_params, base_index=vector_store_index
)

memory = ChatMemoryBuffer.from_defaults(token_limit=4000)

chat_engine = ContextChatEngine(
llm=llm,
memory=memory,
retriever=retriever,
prefix_messages=list(),
**chat_engine_params,
)

return chat_engine


def webpage_chat_(
url: str = None,
Expand All @@ -200,6 +259,12 @@ def webpage_chat_(
)
chat_engine_params = {} if chat_engine_params is None else chat_engine_params

retriever_params = (
{"retriever_type": "QueryFusionRetriever"}
if retriever_params is None
else retriever_params
)

llm = LyzrLLMFactory.from_defaults(**llm_params)
service_context = LyzrService.from_defaults(
llm=llm,
Expand All @@ -213,10 +278,22 @@ def webpage_chat_(
**vector_store_params, documents=documents, service_context=service_context
)

return vector_store_index.as_chat_engine(
**chat_engine_params, chat_mode=ChatMode.CONTEXT, similarity_top_k=5
retriever = LyzrRetriever.from_defaults(
**retriever_params, base_index=vector_store_index
)

memory = ChatMemoryBuffer.from_defaults(token_limit=4000)

chat_engine = ContextChatEngine(
llm=llm,
memory=memory,
retriever=retriever,
prefix_messages=list(),
**chat_engine_params,
)

return chat_engine


def website_chat_(
url: str = None,
Expand All @@ -243,6 +320,12 @@ def website_chat_(
)
chat_engine_params = {} if chat_engine_params is None else chat_engine_params

retriever_params = (
{"retriever_type": "QueryFusionRetriever"}
if retriever_params is None
else retriever_params
)

llm = LyzrLLMFactory.from_defaults(**llm_params)
service_context = LyzrService.from_defaults(
llm=llm,
Expand All @@ -256,10 +339,22 @@ def website_chat_(
**vector_store_params, documents=documents, service_context=service_context
)

return vector_store_index.as_chat_engine(
**chat_engine_params, chat_mode=ChatMode.CONTEXT, similarity_top_k=5
retriever = LyzrRetriever.from_defaults(
**retriever_params, base_index=vector_store_index
)

memory = ChatMemoryBuffer.from_defaults(token_limit=4000)

chat_engine = ContextChatEngine(
llm=llm,
memory=memory,
retriever=retriever,
prefix_messages=list(),
**chat_engine_params,
)

return chat_engine


def youtube_chat_(
urls: List[str] = None,
Expand All @@ -286,6 +381,12 @@ def youtube_chat_(
)
chat_engine_params = {} if chat_engine_params is None else chat_engine_params

retriever_params = (
{"retriever_type": "QueryFusionRetriever"}
if retriever_params is None
else retriever_params
)

llm = LyzrLLMFactory.from_defaults(**llm_params)
service_context = LyzrService.from_defaults(
llm=llm,
Expand All @@ -299,6 +400,18 @@ def youtube_chat_(
**vector_store_params, documents=documents, service_context=service_context
)

return vector_store_index.as_chat_engine(
**chat_engine_params, chat_mode=ChatMode.CONTEXT, similarity_top_k=5
retriever = LyzrRetriever.from_defaults(
**retriever_params, base_index=vector_store_index
)

memory = ChatMemoryBuffer.from_defaults(token_limit=4000)

chat_engine = ContextChatEngine(
llm=llm,
memory=memory,
retriever=retriever,
prefix_messages=list(),
**chat_engine_params,
)

return chat_engine
Loading

0 comments on commit 079d01f

Please sign in to comment.