From 0729b24eb400382c547dce985da9b9c367663dc3 Mon Sep 17 00:00:00 2001 From: patel Date: Thu, 11 Jan 2024 01:46:37 +0530 Subject: [PATCH 1/4] Enhanced chatbot with retrievers (query fusion) and reranker (mmr reranker) --- build/lib/lyzr/base/__init__.py | 2 + build/lib/lyzr/base/retrievers.py | 44 +++++++++ build/lib/lyzr/base/vector_store.py | 22 ++--- build/lib/lyzr/utils/chat_utils.py | 27 +++++- lyzr/base/__init__.py | 2 + lyzr/base/retrievers.py | 44 +++++++++ lyzr/base/vector_store.py | 22 ++--- lyzr/utils/chat_utils.py | 137 +++++++++++++++++++++++++--- 8 files changed, 264 insertions(+), 36 deletions(-) create mode 100644 build/lib/lyzr/base/retrievers.py create mode 100644 lyzr/base/retrievers.py diff --git a/build/lib/lyzr/base/__init__.py b/build/lib/lyzr/base/__init__.py index 5d2313c..b5e8499 100644 --- a/build/lib/lyzr/base/__init__.py +++ b/build/lib/lyzr/base/__init__.py @@ -3,6 +3,7 @@ from lyzr.base.llms import LLM, get_model from lyzr.base.service import LyzrService from lyzr.base.vector_store import LyzrVectorStoreIndex +from lyzr.base.retrievers import LyzrRetriever from lyzr.base.prompt import Prompt __all__ = [ @@ -14,4 +15,5 @@ "get_model", "read_file", "describe_dataset", + "LyzrRetriever", ] diff --git a/build/lib/lyzr/base/retrievers.py b/build/lib/lyzr/base/retrievers.py new file mode 100644 index 0000000..f1c8be6 --- /dev/null +++ b/build/lib/lyzr/base/retrievers.py @@ -0,0 +1,44 @@ +from typing import Optional, Sequence +from llama_index.retrievers import BaseRetriever +from llama_index.indices import VectorStoreIndex + + +def import_retriever_class(retriever_class_name: str): + module = __import__("llama_index.retrievers", fromlist=[retriever_class_name]) + class_ = getattr(module, retriever_class_name) + return class_ + + +class LyzrRetriever: + @staticmethod + def from_defaults( + retriever_type: str = "QueryFusionRetriever", + base_index: VectorStoreIndex = None, + **kwargs + ) -> BaseRetriever: + RetrieverClass = import_retriever_class(retriever_type) + + if retriever_type == "QueryFusionRetriever": + print("QueryFusionRetriever") + retriever = RetrieverClass( + retrievers=[ + base_index.as_retriever( + vector_store_query_mode="mmr", + similarity_top_k=3, + vector_store_kwargs={"mmr_threshold": 0.1}, + ), + base_index.as_retriever( + vector_store_query_mode="mmr", + similarity_top_k=3, + vector_store_kwargs={"mmr_threshold": 0.1}, + ), + ], + similarity_top_k=5, + num_queries=2, + use_async=False, + **kwargs + ) + return retriever + else: + retriever = RetrieverClass(**kwargs) + return retriever diff --git a/build/lib/lyzr/base/vector_store.py b/build/lib/lyzr/base/vector_store.py index a112123..1fca590 100644 --- a/build/lib/lyzr/base/vector_store.py +++ b/build/lib/lyzr/base/vector_store.py @@ -3,6 +3,7 @@ from llama_index import Document, ServiceContext, VectorStoreIndex, StorageContext from llama_index.node_parser import SimpleNodeParser + def import_vector_store_class(vector_store_class_name: str): module = __import__("llama_index.vector_stores", fromlist=[vector_store_class_name]) class_ = getattr(module, vector_store_class_name) @@ -12,12 +13,11 @@ def import_vector_store_class(vector_store_class_name: str): class LyzrVectorStoreIndex: @staticmethod def from_defaults( - vector_store_type: str = "LanceDBVectorStore", - documents: Optional[Sequence[Document]] = None, - service_context: Optional[ServiceContext] = None, - **kwargs + vector_store_type: str = "LanceDBVectorStore", + documents: Optional[Sequence[Document]] = None, + service_context: Optional[ServiceContext] = None, + **kwargs ) -> VectorStoreIndex: - if documents is None and vector_store_type == "SimpleVectorStore": raise ValueError("documents must be provided for SimpleVectorStore") @@ -36,12 +36,12 @@ def from_defaults( ) vector_store = vector_store_class(**kwargs) storage_context = StorageContext.from_defaults(vector_store=vector_store) - + index = VectorStoreIndex.from_documents( - documents=documents, - storage_context=storage_context, - service_context=service_context, - show_progress=True, - ) + documents=documents, + storage_context=storage_context, + service_context=service_context, + show_progress=True, + ) return index diff --git a/build/lib/lyzr/utils/chat_utils.py b/build/lib/lyzr/utils/chat_utils.py index bf26f7c..19ead30 100644 --- a/build/lib/lyzr/utils/chat_utils.py +++ b/build/lib/lyzr/utils/chat_utils.py @@ -2,10 +2,14 @@ from llama_index.chat_engine.types import BaseChatEngine, ChatMode from llama_index.embeddings.utils import EmbedType +from llama_index.chat_engine import ContextChatEngine +from llama_index.memory import ChatMemoryBuffer from lyzr.base.llm import LyzrLLMFactory from lyzr.base.service import LyzrService from lyzr.base.vector_store import LyzrVectorStoreIndex +from lyzr.base.retrievers import LyzrRetriever + from lyzr.utils.document_reading import ( read_pdf_as_documents, read_docx_as_documents, @@ -30,6 +34,7 @@ def pdf_chat_( vector_store_params: dict = None, service_context_params: dict = None, chat_engine_params: dict = None, + retriever_params: dict = None, ) -> BaseChatEngine: documents = read_pdf_as_documents( input_dir=input_dir, @@ -51,6 +56,12 @@ def pdf_chat_( ) chat_engine_params = {} if chat_engine_params is None else chat_engine_params + retriever_params = ( + {"retriever_type": "QueryFusionRetriever"} + if retriever_params is None + else retriever_params + ) + llm = LyzrLLMFactory.from_defaults(**llm_params) service_context = LyzrService.from_defaults( llm=llm, @@ -64,10 +75,22 @@ def pdf_chat_( **vector_store_params, documents=documents, service_context=service_context ) - return vector_store_index.as_chat_engine( - **chat_engine_params, chat_mode=ChatMode.CONTEXT, similarity_top_k=5 + retriever = LyzrRetriever.from_defaults( + **retriever_params, base_index=vector_store_index ) + memory = ChatMemoryBuffer.from_defaults(token_limit=4000) + + chat_engine = ContextChatEngine( + llm=llm, + memory=memory, + retriever=retriever, + prefix_messages=list(), + **chat_engine_params, + ) + + return chat_engine + def txt_chat_( input_dir: Optional[str] = None, diff --git a/lyzr/base/__init__.py b/lyzr/base/__init__.py index 5d2313c..b5e8499 100644 --- a/lyzr/base/__init__.py +++ b/lyzr/base/__init__.py @@ -3,6 +3,7 @@ from lyzr.base.llms import LLM, get_model from lyzr.base.service import LyzrService from lyzr.base.vector_store import LyzrVectorStoreIndex +from lyzr.base.retrievers import LyzrRetriever from lyzr.base.prompt import Prompt __all__ = [ @@ -14,4 +15,5 @@ "get_model", "read_file", "describe_dataset", + "LyzrRetriever", ] diff --git a/lyzr/base/retrievers.py b/lyzr/base/retrievers.py new file mode 100644 index 0000000..f1c8be6 --- /dev/null +++ b/lyzr/base/retrievers.py @@ -0,0 +1,44 @@ +from typing import Optional, Sequence +from llama_index.retrievers import BaseRetriever +from llama_index.indices import VectorStoreIndex + + +def import_retriever_class(retriever_class_name: str): + module = __import__("llama_index.retrievers", fromlist=[retriever_class_name]) + class_ = getattr(module, retriever_class_name) + return class_ + + +class LyzrRetriever: + @staticmethod + def from_defaults( + retriever_type: str = "QueryFusionRetriever", + base_index: VectorStoreIndex = None, + **kwargs + ) -> BaseRetriever: + RetrieverClass = import_retriever_class(retriever_type) + + if retriever_type == "QueryFusionRetriever": + print("QueryFusionRetriever") + retriever = RetrieverClass( + retrievers=[ + base_index.as_retriever( + vector_store_query_mode="mmr", + similarity_top_k=3, + vector_store_kwargs={"mmr_threshold": 0.1}, + ), + base_index.as_retriever( + vector_store_query_mode="mmr", + similarity_top_k=3, + vector_store_kwargs={"mmr_threshold": 0.1}, + ), + ], + similarity_top_k=5, + num_queries=2, + use_async=False, + **kwargs + ) + return retriever + else: + retriever = RetrieverClass(**kwargs) + return retriever diff --git a/lyzr/base/vector_store.py b/lyzr/base/vector_store.py index a112123..1fca590 100644 --- a/lyzr/base/vector_store.py +++ b/lyzr/base/vector_store.py @@ -3,6 +3,7 @@ from llama_index import Document, ServiceContext, VectorStoreIndex, StorageContext from llama_index.node_parser import SimpleNodeParser + def import_vector_store_class(vector_store_class_name: str): module = __import__("llama_index.vector_stores", fromlist=[vector_store_class_name]) class_ = getattr(module, vector_store_class_name) @@ -12,12 +13,11 @@ def import_vector_store_class(vector_store_class_name: str): class LyzrVectorStoreIndex: @staticmethod def from_defaults( - vector_store_type: str = "LanceDBVectorStore", - documents: Optional[Sequence[Document]] = None, - service_context: Optional[ServiceContext] = None, - **kwargs + vector_store_type: str = "LanceDBVectorStore", + documents: Optional[Sequence[Document]] = None, + service_context: Optional[ServiceContext] = None, + **kwargs ) -> VectorStoreIndex: - if documents is None and vector_store_type == "SimpleVectorStore": raise ValueError("documents must be provided for SimpleVectorStore") @@ -36,12 +36,12 @@ def from_defaults( ) vector_store = vector_store_class(**kwargs) storage_context = StorageContext.from_defaults(vector_store=vector_store) - + index = VectorStoreIndex.from_documents( - documents=documents, - storage_context=storage_context, - service_context=service_context, - show_progress=True, - ) + documents=documents, + storage_context=storage_context, + service_context=service_context, + show_progress=True, + ) return index diff --git a/lyzr/utils/chat_utils.py b/lyzr/utils/chat_utils.py index bf26f7c..86971b1 100644 --- a/lyzr/utils/chat_utils.py +++ b/lyzr/utils/chat_utils.py @@ -2,10 +2,14 @@ from llama_index.chat_engine.types import BaseChatEngine, ChatMode from llama_index.embeddings.utils import EmbedType +from llama_index.chat_engine import ContextChatEngine +from llama_index.memory import ChatMemoryBuffer from lyzr.base.llm import LyzrLLMFactory from lyzr.base.service import LyzrService from lyzr.base.vector_store import LyzrVectorStoreIndex +from lyzr.base.retrievers import LyzrRetriever + from lyzr.utils.document_reading import ( read_pdf_as_documents, read_docx_as_documents, @@ -30,6 +34,7 @@ def pdf_chat_( vector_store_params: dict = None, service_context_params: dict = None, chat_engine_params: dict = None, + retriever_params: dict = None, ) -> BaseChatEngine: documents = read_pdf_as_documents( input_dir=input_dir, @@ -51,6 +56,12 @@ def pdf_chat_( ) chat_engine_params = {} if chat_engine_params is None else chat_engine_params + retriever_params = ( + {"retriever_type": "QueryFusionRetriever"} + if retriever_params is None + else retriever_params + ) + llm = LyzrLLMFactory.from_defaults(**llm_params) service_context = LyzrService.from_defaults( llm=llm, @@ -64,10 +75,22 @@ def pdf_chat_( **vector_store_params, documents=documents, service_context=service_context ) - return vector_store_index.as_chat_engine( - **chat_engine_params, chat_mode=ChatMode.CONTEXT, similarity_top_k=5 + retriever = LyzrRetriever.from_defaults( + **retriever_params, base_index=vector_store_index ) + memory = ChatMemoryBuffer.from_defaults(token_limit=4000) + + chat_engine = ContextChatEngine( + llm=llm, + memory=memory, + retriever=retriever, + prefix_messages=list(), + **chat_engine_params, + ) + + return chat_engine + def txt_chat_( input_dir: Optional[str] = None, @@ -104,6 +127,12 @@ def txt_chat_( ) chat_engine_params = {} if chat_engine_params is None else chat_engine_params + retriever_params = ( + {"retriever_type": "QueryFusionRetriever"} + if retriever_params is None + else retriever_params + ) + llm = LyzrLLMFactory.from_defaults(**llm_params) service_context = LyzrService.from_defaults( llm=llm, @@ -117,10 +146,22 @@ def txt_chat_( **vector_store_params, documents=documents, service_context=service_context ) - return vector_store_index.as_chat_engine( - **chat_engine_params, chat_mode=ChatMode.CONTEXT, similarity_top_k=5 + retriever = LyzrRetriever.from_defaults( + **retriever_params, base_index=vector_store_index + ) + + memory = ChatMemoryBuffer.from_defaults(token_limit=4000) + + chat_engine = ContextChatEngine( + llm=llm, + memory=memory, + retriever=retriever, + prefix_messages=list(), + **chat_engine_params, ) + return chat_engine + def docx_chat_( input_dir: Optional[str] = None, @@ -157,6 +198,12 @@ def docx_chat_( ) chat_engine_params = {} if chat_engine_params is None else chat_engine_params + retriever_params = ( + {"retriever_type": "QueryFusionRetriever"} + if retriever_params is None + else retriever_params + ) + llm = LyzrLLMFactory.from_defaults(**llm_params) service_context = LyzrService.from_defaults( llm=llm, @@ -170,10 +217,22 @@ def docx_chat_( **vector_store_params, documents=documents, service_context=service_context ) - return vector_store_index.as_chat_engine( - **chat_engine_params, chat_mode=ChatMode.CONTEXT, similarity_top_k=5 + retriever = LyzrRetriever.from_defaults( + **retriever_params, base_index=vector_store_index ) + memory = ChatMemoryBuffer.from_defaults(token_limit=4000) + + chat_engine = ContextChatEngine( + llm=llm, + memory=memory, + retriever=retriever, + prefix_messages=list(), + **chat_engine_params, + ) + + return chat_engine + def webpage_chat_( url: str = None, @@ -200,6 +259,12 @@ def webpage_chat_( ) chat_engine_params = {} if chat_engine_params is None else chat_engine_params + retriever_params = ( + {"retriever_type": "QueryFusionRetriever"} + if retriever_params is None + else retriever_params + ) + llm = LyzrLLMFactory.from_defaults(**llm_params) service_context = LyzrService.from_defaults( llm=llm, @@ -213,10 +278,22 @@ def webpage_chat_( **vector_store_params, documents=documents, service_context=service_context ) - return vector_store_index.as_chat_engine( - **chat_engine_params, chat_mode=ChatMode.CONTEXT, similarity_top_k=5 + retriever = LyzrRetriever.from_defaults( + **retriever_params, base_index=vector_store_index + ) + + memory = ChatMemoryBuffer.from_defaults(token_limit=4000) + + chat_engine = ContextChatEngine( + llm=llm, + memory=memory, + retriever=retriever, + prefix_messages=list(), + **chat_engine_params, ) + return chat_engine + def website_chat_( url: str = None, @@ -243,6 +320,12 @@ def website_chat_( ) chat_engine_params = {} if chat_engine_params is None else chat_engine_params + retriever_params = ( + {"retriever_type": "QueryFusionRetriever"} + if retriever_params is None + else retriever_params + ) + llm = LyzrLLMFactory.from_defaults(**llm_params) service_context = LyzrService.from_defaults( llm=llm, @@ -256,10 +339,22 @@ def website_chat_( **vector_store_params, documents=documents, service_context=service_context ) - return vector_store_index.as_chat_engine( - **chat_engine_params, chat_mode=ChatMode.CONTEXT, similarity_top_k=5 + retriever = LyzrRetriever.from_defaults( + **retriever_params, base_index=vector_store_index ) + memory = ChatMemoryBuffer.from_defaults(token_limit=4000) + + chat_engine = ContextChatEngine( + llm=llm, + memory=memory, + retriever=retriever, + prefix_messages=list(), + **chat_engine_params, + ) + + return chat_engine + def youtube_chat_( urls: List[str] = None, @@ -286,6 +381,12 @@ def youtube_chat_( ) chat_engine_params = {} if chat_engine_params is None else chat_engine_params + retriever_params = ( + {"retriever_type": "QueryFusionRetriever"} + if retriever_params is None + else retriever_params + ) + llm = LyzrLLMFactory.from_defaults(**llm_params) service_context = LyzrService.from_defaults( llm=llm, @@ -299,6 +400,18 @@ def youtube_chat_( **vector_store_params, documents=documents, service_context=service_context ) - return vector_store_index.as_chat_engine( - **chat_engine_params, chat_mode=ChatMode.CONTEXT, similarity_top_k=5 + retriever = LyzrRetriever.from_defaults( + **retriever_params, base_index=vector_store_index ) + + memory = ChatMemoryBuffer.from_defaults(token_limit=4000) + + chat_engine = ContextChatEngine( + llm=llm, + memory=memory, + retriever=retriever, + prefix_messages=list(), + **chat_engine_params, + ) + + return chat_engine From 56726630fbebc504449080243579aba4e2faced9 Mon Sep 17 00:00:00 2001 From: patel Date: Thu, 11 Jan 2024 16:54:38 +0530 Subject: [PATCH 2/4] Enhanced QAbot with retrievers and rerankers, fixed Chatbot params --- build/lib/lyzr/utils/chat_utils.py | 110 ++++++++++++++++++++++++++--- build/lib/lyzr/utils/rag_utils.py | 23 ++++-- lyzr/chatqa/chatbot.py | 12 ++++ lyzr/chatqa/qa_bot.py | 12 ++++ lyzr/utils/chat_utils.py | 5 ++ lyzr/utils/rag_utils.py | 102 +++++++++++++++++++++++--- 6 files changed, 240 insertions(+), 24 deletions(-) diff --git a/build/lib/lyzr/utils/chat_utils.py b/build/lib/lyzr/utils/chat_utils.py index 19ead30..86971b1 100644 --- a/build/lib/lyzr/utils/chat_utils.py +++ b/build/lib/lyzr/utils/chat_utils.py @@ -127,6 +127,12 @@ def txt_chat_( ) chat_engine_params = {} if chat_engine_params is None else chat_engine_params + retriever_params = ( + {"retriever_type": "QueryFusionRetriever"} + if retriever_params is None + else retriever_params + ) + llm = LyzrLLMFactory.from_defaults(**llm_params) service_context = LyzrService.from_defaults( llm=llm, @@ -140,10 +146,22 @@ def txt_chat_( **vector_store_params, documents=documents, service_context=service_context ) - return vector_store_index.as_chat_engine( - **chat_engine_params, chat_mode=ChatMode.CONTEXT, similarity_top_k=5 + retriever = LyzrRetriever.from_defaults( + **retriever_params, base_index=vector_store_index ) + memory = ChatMemoryBuffer.from_defaults(token_limit=4000) + + chat_engine = ContextChatEngine( + llm=llm, + memory=memory, + retriever=retriever, + prefix_messages=list(), + **chat_engine_params, + ) + + return chat_engine + def docx_chat_( input_dir: Optional[str] = None, @@ -180,6 +198,12 @@ def docx_chat_( ) chat_engine_params = {} if chat_engine_params is None else chat_engine_params + retriever_params = ( + {"retriever_type": "QueryFusionRetriever"} + if retriever_params is None + else retriever_params + ) + llm = LyzrLLMFactory.from_defaults(**llm_params) service_context = LyzrService.from_defaults( llm=llm, @@ -193,10 +217,22 @@ def docx_chat_( **vector_store_params, documents=documents, service_context=service_context ) - return vector_store_index.as_chat_engine( - **chat_engine_params, chat_mode=ChatMode.CONTEXT, similarity_top_k=5 + retriever = LyzrRetriever.from_defaults( + **retriever_params, base_index=vector_store_index ) + memory = ChatMemoryBuffer.from_defaults(token_limit=4000) + + chat_engine = ContextChatEngine( + llm=llm, + memory=memory, + retriever=retriever, + prefix_messages=list(), + **chat_engine_params, + ) + + return chat_engine + def webpage_chat_( url: str = None, @@ -223,6 +259,12 @@ def webpage_chat_( ) chat_engine_params = {} if chat_engine_params is None else chat_engine_params + retriever_params = ( + {"retriever_type": "QueryFusionRetriever"} + if retriever_params is None + else retriever_params + ) + llm = LyzrLLMFactory.from_defaults(**llm_params) service_context = LyzrService.from_defaults( llm=llm, @@ -236,10 +278,22 @@ def webpage_chat_( **vector_store_params, documents=documents, service_context=service_context ) - return vector_store_index.as_chat_engine( - **chat_engine_params, chat_mode=ChatMode.CONTEXT, similarity_top_k=5 + retriever = LyzrRetriever.from_defaults( + **retriever_params, base_index=vector_store_index + ) + + memory = ChatMemoryBuffer.from_defaults(token_limit=4000) + + chat_engine = ContextChatEngine( + llm=llm, + memory=memory, + retriever=retriever, + prefix_messages=list(), + **chat_engine_params, ) + return chat_engine + def website_chat_( url: str = None, @@ -266,6 +320,12 @@ def website_chat_( ) chat_engine_params = {} if chat_engine_params is None else chat_engine_params + retriever_params = ( + {"retriever_type": "QueryFusionRetriever"} + if retriever_params is None + else retriever_params + ) + llm = LyzrLLMFactory.from_defaults(**llm_params) service_context = LyzrService.from_defaults( llm=llm, @@ -279,10 +339,22 @@ def website_chat_( **vector_store_params, documents=documents, service_context=service_context ) - return vector_store_index.as_chat_engine( - **chat_engine_params, chat_mode=ChatMode.CONTEXT, similarity_top_k=5 + retriever = LyzrRetriever.from_defaults( + **retriever_params, base_index=vector_store_index + ) + + memory = ChatMemoryBuffer.from_defaults(token_limit=4000) + + chat_engine = ContextChatEngine( + llm=llm, + memory=memory, + retriever=retriever, + prefix_messages=list(), + **chat_engine_params, ) + return chat_engine + def youtube_chat_( urls: List[str] = None, @@ -309,6 +381,12 @@ def youtube_chat_( ) chat_engine_params = {} if chat_engine_params is None else chat_engine_params + retriever_params = ( + {"retriever_type": "QueryFusionRetriever"} + if retriever_params is None + else retriever_params + ) + llm = LyzrLLMFactory.from_defaults(**llm_params) service_context = LyzrService.from_defaults( llm=llm, @@ -322,6 +400,18 @@ def youtube_chat_( **vector_store_params, documents=documents, service_context=service_context ) - return vector_store_index.as_chat_engine( - **chat_engine_params, chat_mode=ChatMode.CONTEXT, similarity_top_k=5 + retriever = LyzrRetriever.from_defaults( + **retriever_params, base_index=vector_store_index ) + + memory = ChatMemoryBuffer.from_defaults(token_limit=4000) + + chat_engine = ContextChatEngine( + llm=llm, + memory=memory, + retriever=retriever, + prefix_messages=list(), + **chat_engine_params, + ) + + return chat_engine diff --git a/build/lib/lyzr/utils/rag_utils.py b/build/lib/lyzr/utils/rag_utils.py index c26aa33..cd69156 100644 --- a/build/lib/lyzr/utils/rag_utils.py +++ b/build/lib/lyzr/utils/rag_utils.py @@ -2,8 +2,10 @@ from llama_index.embeddings.utils import EmbedType from llama_index.indices.query.base import BaseQueryEngine +from llama_index.query_engine import RetrieverQueryEngine from lyzr.base.llm import LyzrLLMFactory +from lyzr.base.retrievers import LyzrRetriever from lyzr.base.service import LyzrService from lyzr.base.vector_store import LyzrVectorStoreIndex from lyzr.utils.document_reading import ( @@ -30,6 +32,7 @@ def pdf_rag( vector_store_params: dict = None, service_context_params: dict = None, query_engine_params: dict = None, + retriever_params: dict = None, ) -> BaseQueryEngine: documents = read_pdf_as_documents( input_dir=input_dir, @@ -49,7 +52,13 @@ def pdf_rag( service_context_params = ( {} if service_context_params is None else service_context_params ) - query_engine_params = {} if query_engine_params is None else query_engine_params + chat_engine_params = {} if chat_engine_params is None else chat_engine_params + + retriever_params = ( + {"retriever_type": "QueryFusionRetriever"} + if retriever_params is None + else retriever_params + ) llm = LyzrLLMFactory.from_defaults(**llm_params) service_context = LyzrService.from_defaults( @@ -61,12 +70,16 @@ def pdf_rag( ) vector_store_index = LyzrVectorStoreIndex.from_defaults( - **vector_store_params, - documents=documents, - service_context=service_context, + **vector_store_params, documents=documents, service_context=service_context ) - return vector_store_index.as_query_engine(**query_engine_params, similarity_top_k=5) + retriever = LyzrRetriever.from_defaults( + **retriever_params, base_index=vector_store_index + ) + + query_engine = RetrieverQueryEngine.from_args(retriever, query_engine_params) + + return query_engine def txt_rag( diff --git a/lyzr/chatqa/chatbot.py b/lyzr/chatqa/chatbot.py index ba30750..2caed78 100644 --- a/lyzr/chatqa/chatbot.py +++ b/lyzr/chatqa/chatbot.py @@ -42,6 +42,7 @@ def pdf_chat( vector_store_params: dict = None, service_context_params: dict = None, chat_engine_params: dict = None, + retriever_params: dict = None, ) -> BaseChatEngine: return pdf_chat_( input_dir=input_dir, @@ -57,6 +58,7 @@ def pdf_chat( vector_store_params=vector_store_params, service_context_params=service_context_params, chat_engine_params=chat_engine_params, + retriever_params=retriever_params, ) @staticmethod @@ -74,6 +76,7 @@ def docx_chat( vector_store_params: dict = None, service_context_params: dict = None, chat_engine_params: dict = None, + retriever_params: dict = None, ) -> BaseChatEngine: return docx_chat_( input_dir=input_dir, @@ -89,6 +92,7 @@ def docx_chat( vector_store_params=vector_store_params, service_context_params=service_context_params, chat_engine_params=chat_engine_params, + retriever_params=retriever_params, ) @staticmethod @@ -106,6 +110,7 @@ def txt_chat( vector_store_params: dict = None, service_context_params: dict = None, chat_engine_params: dict = None, + retriever_params: dict = None, ) -> BaseChatEngine: return txt_chat_( input_dir=input_dir, @@ -121,6 +126,7 @@ def txt_chat( vector_store_params=vector_store_params, service_context_params=service_context_params, chat_engine_params=chat_engine_params, + retriever_params=retriever_params, ) @staticmethod @@ -133,6 +139,7 @@ def webpage_chat( vector_store_params: dict = None, service_context_params: dict = None, chat_engine_params: dict = None, + retriever_params: dict = None, ) -> BaseChatEngine: return webpage_chat_( url=url, @@ -143,6 +150,7 @@ def webpage_chat( vector_store_params=vector_store_params, service_context_params=service_context_params, chat_engine_params=chat_engine_params, + retriever_params=retriever_params, ) @staticmethod @@ -155,6 +163,7 @@ def website_chat( vector_store_params: dict = None, service_context_params: dict = None, chat_engine_params: dict = None, + retriever_params: dict = None, ) -> BaseChatEngine: return website_chat_( url=url, @@ -165,6 +174,7 @@ def website_chat( vector_store_params=vector_store_params, service_context_params=service_context_params, chat_engine_params=chat_engine_params, + retriever_params=retriever_params, ) @staticmethod @@ -177,6 +187,7 @@ def youtube_chat( vector_store_params: dict = None, service_context_params: dict = None, chat_engine_params: dict = None, + retriever_params: dict = None, ) -> BaseChatEngine: return youtube_chat_( urls=urls, @@ -187,4 +198,5 @@ def youtube_chat( vector_store_params=vector_store_params, service_context_params=service_context_params, chat_engine_params=chat_engine_params, + retriever_params=retriever_params, ) diff --git a/lyzr/chatqa/qa_bot.py b/lyzr/chatqa/qa_bot.py index 14a117a..2b2f145 100644 --- a/lyzr/chatqa/qa_bot.py +++ b/lyzr/chatqa/qa_bot.py @@ -41,6 +41,7 @@ def pdf_qa( vector_store_params: dict = None, service_context_params: dict = None, query_engine_params: dict = None, + retriever_params: dict = None, ) -> BaseQueryEngine: return pdf_rag( input_dir=input_dir, @@ -56,6 +57,7 @@ def pdf_qa( vector_store_params=vector_store_params, service_context_params=service_context_params, query_engine_params=query_engine_params, + retriever_params=retriever_params, ) @staticmethod @@ -73,6 +75,7 @@ def docx_qa( vector_store_params: dict = None, service_context_params: dict = None, query_engine_params: dict = None, + retriever_params: dict = None, ) -> BaseQueryEngine: return docx_rag( input_dir=input_dir, @@ -88,6 +91,7 @@ def docx_qa( vector_store_params=vector_store_params, service_context_params=service_context_params, query_engine_params=query_engine_params, + retriever_params=retriever_params, ) @staticmethod @@ -105,6 +109,7 @@ def txt_qa( vector_store_params: dict = None, service_context_params: dict = None, query_engine_params: dict = None, + retriever_params: dict = None, ) -> BaseQueryEngine: return txt_rag( input_dir=input_dir, @@ -120,6 +125,7 @@ def txt_qa( vector_store_params=vector_store_params, service_context_params=service_context_params, query_engine_params=query_engine_params, + retriever_params=retriever_params, ) @staticmethod @@ -132,6 +138,7 @@ def webpage_qa( vector_store_params: dict = None, service_context_params: dict = None, query_engine_params: dict = None, + retriever_params: dict = None, ) -> BaseQueryEngine: return webpage_rag( url=url, @@ -142,6 +149,7 @@ def webpage_qa( vector_store_params=vector_store_params, service_context_params=service_context_params, query_engine_params=query_engine_params, + retriever_params=retriever_params, ) @staticmethod @@ -154,6 +162,7 @@ def website_qa( vector_store_params: dict = None, service_context_params: dict = None, query_engine_params: dict = None, + retriever_params: dict = None, ) -> BaseQueryEngine: return website_rag( url=url, @@ -164,6 +173,7 @@ def website_qa( vector_store_params=vector_store_params, service_context_params=service_context_params, query_engine_params=query_engine_params, + retriever_params=retriever_params, ) @staticmethod @@ -176,6 +186,7 @@ def youtube_qa( vector_store_params: dict = None, service_context_params: dict = None, query_engine_params: dict = None, + retriever_params: dict = None, ) -> BaseQueryEngine: return youtube_rag( urls=urls, @@ -186,4 +197,5 @@ def youtube_qa( vector_store_params=vector_store_params, service_context_params=service_context_params, query_engine_params=query_engine_params, + retriever_params=retriever_params, ) diff --git a/lyzr/utils/chat_utils.py b/lyzr/utils/chat_utils.py index 86971b1..7a24e37 100644 --- a/lyzr/utils/chat_utils.py +++ b/lyzr/utils/chat_utils.py @@ -106,6 +106,7 @@ def txt_chat_( vector_store_params: dict = None, service_context_params: dict = None, chat_engine_params: dict = None, + retriever_params: dict = None, ) -> BaseChatEngine: documents = read_txt_as_documents( input_dir=input_dir, @@ -177,6 +178,7 @@ def docx_chat_( vector_store_params: dict = None, service_context_params: dict = None, chat_engine_params: dict = None, + retriever_params: dict = None, ) -> BaseChatEngine: documents = read_docx_as_documents( input_dir=input_dir, @@ -243,6 +245,7 @@ def webpage_chat_( vector_store_params: dict = None, service_context_params: dict = None, chat_engine_params: dict = None, + retriever_params: dict = None, ) -> BaseChatEngine: documents = read_webpage_as_documents( url=url, @@ -304,6 +307,7 @@ def website_chat_( vector_store_params: dict = None, service_context_params: dict = None, chat_engine_params: dict = None, + retriever_params: dict = None, ) -> BaseChatEngine: documents = read_website_as_documents( url=url, @@ -365,6 +369,7 @@ def youtube_chat_( vector_store_params: dict = None, service_context_params: dict = None, chat_engine_params: dict = None, + retriever_params: dict = None, ) -> BaseChatEngine: documents = read_youtube_as_documents( urls=urls, diff --git a/lyzr/utils/rag_utils.py b/lyzr/utils/rag_utils.py index c26aa33..cb75ecd 100644 --- a/lyzr/utils/rag_utils.py +++ b/lyzr/utils/rag_utils.py @@ -2,8 +2,10 @@ from llama_index.embeddings.utils import EmbedType from llama_index.indices.query.base import BaseQueryEngine +from llama_index.query_engine import RetrieverQueryEngine from lyzr.base.llm import LyzrLLMFactory +from lyzr.base.retrievers import LyzrRetriever from lyzr.base.service import LyzrService from lyzr.base.vector_store import LyzrVectorStoreIndex from lyzr.utils.document_reading import ( @@ -30,6 +32,7 @@ def pdf_rag( vector_store_params: dict = None, service_context_params: dict = None, query_engine_params: dict = None, + retriever_params: dict = None, ) -> BaseQueryEngine: documents = read_pdf_as_documents( input_dir=input_dir, @@ -51,7 +54,14 @@ def pdf_rag( ) query_engine_params = {} if query_engine_params is None else query_engine_params + retriever_params = ( + {"retriever_type": "QueryFusionRetriever"} + if retriever_params is None + else retriever_params + ) + llm = LyzrLLMFactory.from_defaults(**llm_params) + service_context = LyzrService.from_defaults( llm=llm, embed_model=embed_model, @@ -61,12 +71,16 @@ def pdf_rag( ) vector_store_index = LyzrVectorStoreIndex.from_defaults( - **vector_store_params, - documents=documents, - service_context=service_context, + **vector_store_params, documents=documents, service_context=service_context + ) + + retriever = LyzrRetriever.from_defaults( + **retriever_params, base_index=vector_store_index ) - return vector_store_index.as_query_engine(**query_engine_params, similarity_top_k=5) + query_engine = RetrieverQueryEngine.from_args(retriever, query_engine_params) + + return query_engine def txt_rag( @@ -83,6 +97,7 @@ def txt_rag( vector_store_params: dict = None, service_context_params: dict = None, query_engine_params: dict = None, + retriever_params: dict = None, ) -> BaseQueryEngine: documents = read_txt_as_documents( input_dir=input_dir, @@ -104,7 +119,14 @@ def txt_rag( ) query_engine_params = {} if query_engine_params is None else query_engine_params + retriever_params = ( + {"retriever_type": "QueryFusionRetriever"} + if retriever_params is None + else retriever_params + ) + llm = LyzrLLMFactory.from_defaults(**llm_params) + service_context = LyzrService.from_defaults( llm=llm, embed_model=embed_model, @@ -117,7 +139,13 @@ def txt_rag( **vector_store_params, documents=documents, service_context=service_context ) - return vector_store_index.as_query_engine(**query_engine_params, similarity_top_k=5) + retriever = LyzrRetriever.from_defaults( + **retriever_params, base_index=vector_store_index + ) + + query_engine = RetrieverQueryEngine.from_args(retriever, query_engine_params) + + return query_engine def docx_rag( @@ -134,6 +162,7 @@ def docx_rag( vector_store_params: dict = None, service_context_params: dict = None, query_engine_params: dict = None, + retriever_params: dict = None, ) -> BaseQueryEngine: documents = read_docx_as_documents( input_dir=input_dir, @@ -155,7 +184,14 @@ def docx_rag( ) query_engine_params = {} if query_engine_params is None else query_engine_params + retriever_params = ( + {"retriever_type": "QueryFusionRetriever"} + if retriever_params is None + else retriever_params + ) + llm = LyzrLLMFactory.from_defaults(**llm_params) + service_context = LyzrService.from_defaults( llm=llm, embed_model=embed_model, @@ -168,7 +204,13 @@ def docx_rag( **vector_store_params, documents=documents, service_context=service_context ) - return vector_store_index.as_query_engine(**query_engine_params, similarity_top_k=5) + retriever = LyzrRetriever.from_defaults( + **retriever_params, base_index=vector_store_index + ) + + query_engine = RetrieverQueryEngine.from_args(retriever, query_engine_params) + + return query_engine def webpage_rag( @@ -180,6 +222,7 @@ def webpage_rag( vector_store_params: dict = None, service_context_params: dict = None, query_engine_params: dict = None, + retriever_params: dict = None, ) -> BaseQueryEngine: documents = read_webpage_as_documents( url=url, @@ -196,7 +239,14 @@ def webpage_rag( ) query_engine_params = {} if query_engine_params is None else query_engine_params + retriever_params = ( + {"retriever_type": "QueryFusionRetriever"} + if retriever_params is None + else retriever_params + ) + llm = LyzrLLMFactory.from_defaults(**llm_params) + service_context = LyzrService.from_defaults( llm=llm, embed_model=embed_model, @@ -209,7 +259,13 @@ def webpage_rag( **vector_store_params, documents=documents, service_context=service_context ) - return vector_store_index.as_query_engine(**query_engine_params, similarity_top_k=5) + retriever = LyzrRetriever.from_defaults( + **retriever_params, base_index=vector_store_index + ) + + query_engine = RetrieverQueryEngine.from_args(retriever, query_engine_params) + + return query_engine def website_rag( @@ -221,6 +277,7 @@ def website_rag( vector_store_params: dict = None, service_context_params: dict = None, query_engine_params: dict = None, + retriever_params: dict = None, ) -> BaseQueryEngine: documents = read_website_as_documents( url=url, @@ -237,7 +294,14 @@ def website_rag( ) query_engine_params = {} if query_engine_params is None else query_engine_params + retriever_params = ( + {"retriever_type": "QueryFusionRetriever"} + if retriever_params is None + else retriever_params + ) + llm = LyzrLLMFactory.from_defaults(**llm_params) + service_context = LyzrService.from_defaults( llm=llm, embed_model=embed_model, @@ -250,7 +314,13 @@ def website_rag( **vector_store_params, documents=documents, service_context=service_context ) - return vector_store_index.as_query_engine(**query_engine_params, similarity_top_k=5) + retriever = LyzrRetriever.from_defaults( + **retriever_params, base_index=vector_store_index + ) + + query_engine = RetrieverQueryEngine.from_args(retriever, query_engine_params) + + return query_engine def youtube_rag( @@ -262,6 +332,7 @@ def youtube_rag( vector_store_params: dict = None, service_context_params: dict = None, query_engine_params: dict = None, + retriever_params: dict = None, ) -> BaseQueryEngine: documents = read_youtube_as_documents( urls=urls, @@ -278,7 +349,14 @@ def youtube_rag( ) query_engine_params = {} if query_engine_params is None else query_engine_params + retriever_params = ( + {"retriever_type": "QueryFusionRetriever"} + if retriever_params is None + else retriever_params + ) + llm = LyzrLLMFactory.from_defaults(**llm_params) + service_context = LyzrService.from_defaults( llm=llm, embed_model=embed_model, @@ -291,4 +369,10 @@ def youtube_rag( **vector_store_params, documents=documents, service_context=service_context ) - return vector_store_index.as_query_engine(**query_engine_params, similarity_top_k=5) + retriever = LyzrRetriever.from_defaults( + **retriever_params, base_index=vector_store_index + ) + + query_engine = RetrieverQueryEngine.from_args(retriever, query_engine_params) + + return query_engine From 35440e7fe7f5617588d04d01c6f88148711452a2 Mon Sep 17 00:00:00 2001 From: patel Date: Thu, 11 Jan 2024 17:02:27 +0530 Subject: [PATCH 3/4] Switch default vector store to Weaviate --- lyzr/base/vector_store.py | 39 ++++++++++++++++++++++++++++----------- lyzr/utils/chat_utils.py | 12 ++++++------ lyzr/utils/rag_utils.py | 12 ++++++------ setup.py | 2 +- 4 files changed, 41 insertions(+), 24 deletions(-) diff --git a/lyzr/base/vector_store.py b/lyzr/base/vector_store.py index 1fca590..6ad72bd 100644 --- a/lyzr/base/vector_store.py +++ b/lyzr/base/vector_store.py @@ -1,5 +1,9 @@ from typing import Optional, Sequence +import os +import uuid +import weaviate +from weaviate.embedded import EmbeddedOptions from llama_index import Document, ServiceContext, VectorStoreIndex, StorageContext from llama_index.node_parser import SimpleNodeParser @@ -13,7 +17,7 @@ def import_vector_store_class(vector_store_class_name: str): class LyzrVectorStoreIndex: @staticmethod def from_defaults( - vector_store_type: str = "LanceDBVectorStore", + vector_store_type: str = "WeaviateVectorStore", documents: Optional[Sequence[Document]] = None, service_context: Optional[ServiceContext] = None, **kwargs @@ -21,22 +25,35 @@ def from_defaults( if documents is None and vector_store_type == "SimpleVectorStore": raise ValueError("documents must be provided for SimpleVectorStore") - vector_store_class = import_vector_store_class(vector_store_type) + VectorStoreClass = import_vector_store_class(vector_store_type) + + if vector_store_type == "WeaviateVectorStore": + weaviate_client = weaviate.Client( + embedded_options=weaviate.embedded.EmbeddedOptions(), + additional_headers={"X-OpenAI-Api-Key": os.environ["OPENAI_API_KEY"]}, + ) + kwargs["weaviate_client"] = ( + weaviate_client + if "weaviate_client" not in kwargs + else kwargs["weaviate_client"] + ) + kwargs["index_name"] = ( + uuid if "index_name" not in kwargs else kwargs["index_name"] + ) + + vector_store = VectorStoreClass(**kwargs) + else: + vector_store = VectorStoreClass(**kwargs) if documents is None: - vector_store = vector_store_class(**kwargs) index = VectorStoreIndex.from_vector_store( vector_store=vector_store, service_context=service_context ) - else: - if vector_store_type == "LanceDBVectorStore": - kwargs["uri"] = "./.lancedb" if "uri" not in kwargs else kwargs["uri"] - kwargs["table_name"] = ( - "vectors" if "table_name" not in kwargs else kwargs["table_name"] - ) - vector_store = vector_store_class(**kwargs) - storage_context = StorageContext.from_defaults(vector_store=vector_store) + return index + + storage_context = StorageContext.from_defaults(vector_store=vector_store) + if documents is not None: index = VectorStoreIndex.from_documents( documents=documents, storage_context=storage_context, diff --git a/lyzr/utils/chat_utils.py b/lyzr/utils/chat_utils.py index 7a24e37..50d4e82 100644 --- a/lyzr/utils/chat_utils.py +++ b/lyzr/utils/chat_utils.py @@ -47,7 +47,7 @@ def pdf_chat_( llm_params = {} if llm_params is None else llm_params vector_store_params = ( - {"vector_store_type": "LanceDBVectorStore"} + {"vector_store_type": "WeaviateVectorStore"} if vector_store_params is None else vector_store_params ) @@ -119,7 +119,7 @@ def txt_chat_( llm_params = {} if llm_params is None else llm_params vector_store_params = ( - {"vector_store_type": "LanceDBVectorStore"} + {"vector_store_type": "WeaviateVectorStore"} if vector_store_params is None else vector_store_params ) @@ -191,7 +191,7 @@ def docx_chat_( llm_params = {} if llm_params is None else llm_params vector_store_params = ( - {"vector_store_type": "LanceDBVectorStore"} + {"vector_store_type": "WeaviateVectorStore"} if vector_store_params is None else vector_store_params ) @@ -253,7 +253,7 @@ def webpage_chat_( llm_params = {} if llm_params is None else llm_params vector_store_params = ( - {"vector_store_type": "LanceDBVectorStore"} + {"vector_store_type": "WeaviateVectorStore"} if vector_store_params is None else vector_store_params ) @@ -315,7 +315,7 @@ def website_chat_( llm_params = {} if llm_params is None else llm_params vector_store_params = ( - {"vector_store_type": "LanceDBVectorStore"} + {"vector_store_type": "WeaviateVectorStore"} if vector_store_params is None else vector_store_params ) @@ -377,7 +377,7 @@ def youtube_chat_( llm_params = {} if llm_params is None else llm_params vector_store_params = ( - {"vector_store_type": "LanceDBVectorStore"} + {"vector_store_type": "WeaviateVectorStore"} if vector_store_params is None else vector_store_params ) diff --git a/lyzr/utils/rag_utils.py b/lyzr/utils/rag_utils.py index cb75ecd..28b59e2 100644 --- a/lyzr/utils/rag_utils.py +++ b/lyzr/utils/rag_utils.py @@ -45,7 +45,7 @@ def pdf_rag( llm_params = {} if llm_params is None else llm_params vector_store_params = ( - {"vector_store_type": "LanceDBVectorStore"} + {"vector_store_type": "WeaviateVectorStore"} if vector_store_params is None else vector_store_params ) @@ -110,7 +110,7 @@ def txt_rag( llm_params = {} if llm_params is None else llm_params vector_store_params = ( - {"vector_store_type": "LanceDBVectorStore"} + {"vector_store_type": "WeaviateVectorStore"} if vector_store_params is None else vector_store_params ) @@ -175,7 +175,7 @@ def docx_rag( llm_params = {} if llm_params is None else llm_params vector_store_params = ( - {"vector_store_type": "LanceDBVectorStore"} + {"vector_store_type": "WeaviateVectorStore"} if vector_store_params is None else vector_store_params ) @@ -230,7 +230,7 @@ def webpage_rag( llm_params = {} if llm_params is None else llm_params vector_store_params = ( - {"vector_store_type": "LanceDBVectorStore"} + {"vector_store_type": "WeaviateVectorStore"} if vector_store_params is None else vector_store_params ) @@ -285,7 +285,7 @@ def website_rag( llm_params = {} if llm_params is None else llm_params vector_store_params = ( - {"vector_store_type": "LanceDBVectorStore"} + {"vector_store_type": "WeaviateVectorStore"} if vector_store_params is None else vector_store_params ) @@ -340,7 +340,7 @@ def youtube_rag( llm_params = {} if llm_params is None else llm_params vector_store_params = ( - {"vector_store_type": "LanceDBVectorStore"} + {"vector_store_type": "WeaviateVectorStore"} if vector_store_params is None else vector_store_params ) diff --git a/setup.py b/setup.py index daf9acd..e33fe3b 100644 --- a/setup.py +++ b/setup.py @@ -24,9 +24,9 @@ "llama-index==0.9.4", "langchain==0.0.339", "python-dotenv>=1.0.0", - "lancedb==0.3.3", "beautifulsoup4==4.12.2", "pandas==2.0.2", "matplotlib==3.8.2", + "weaviate-client==3.25.3", ], ) From 0a31c76bafffd0f75e4b184befc2374b08e6f186 Mon Sep 17 00:00:00 2001 From: patel Date: Fri, 12 Jan 2024 14:14:04 +0530 Subject: [PATCH 4/4] bump version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index e33fe3b..aae7724 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="lyzr", - version="0.1.21", + version="0.1.22", author="lyzr", description="", long_description=open("README.md").read(),