Skip to content

Commit

Permalink
Merge pull request #14 from Azure-Samples/add-keyword-search
Browse files Browse the repository at this point in the history
Add keyword search
  • Loading branch information
john0isaac authored Jun 9, 2024
2 parents 16d00a5 + 8f2f372 commit 86a4adc
Show file tree
Hide file tree
Showing 9 changed files with 350 additions and 95 deletions.
3 changes: 3 additions & 0 deletions src/quartapp/approaches/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from langchain_community.vectorstores import AzureCosmosDBVectorSearch
from langchain_core.documents import Document
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
from pymongo.collection import Collection


class ApproachesBase(ABC):
Expand All @@ -11,10 +12,12 @@ def __init__(
vector_store: AzureCosmosDBVectorSearch,
embedding: AzureOpenAIEmbeddings,
chat: AzureChatOpenAI,
data_collection: Collection | None,
):
self._vector_store = vector_store
self._embedding = embedding
self._chat = chat
self._data_collection = data_collection

@abstractmethod
async def run(
Expand Down
20 changes: 20 additions & 0 deletions src/quartapp/approaches/keyword.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from langchain_core.documents import Document

from quartapp.approaches.base import ApproachesBase


class KeyWord(ApproachesBase):
    """Keyword retrieval approach backed by a MongoDB ``$text`` search."""

    async def run(
        self, messages: list, temperature: float, limit: int, score_threshold: float
    ) -> tuple[list[Document], str]:
        """Search the data collection for the text of the latest message.

        Args:
            messages: Chat history; the ``"content"`` of the last entry is the search query.
            temperature: Unused here; present to match the shared ``run`` signature.
            limit: Value passed to the cursor's ``limit``.
            score_threshold: Unused here; present to match the shared ``run`` signature.

        Returns:
            ``(documents, answer)`` where ``documents`` holds one ``Document`` per match
            and ``answer`` is the ``page_content`` of the first match. Returns
            ``([], "")`` when there are no messages, no data collection, or no matches.
        """
        # pymongo Collection objects raise NotImplementedError when truth-tested,
        # so availability has to be checked with an explicit `is None` comparison.
        if not messages or self._data_collection is None:
            return [], ""

        query = messages[-1]["content"]
        keyword_response = self._data_collection.find({"$text": {"$search": query}}).limit(limit)
        documents_list: list[Document] = [
            Document(page_content=document["textContent"], metadata={"source": document["source"]})
            for document in keyword_response
        ]

        # A search with no hits yields an empty list; indexing [0] would raise IndexError.
        if not documents_list:
            return [], ""
        return documents_list, documents_list[0].page_content
22 changes: 21 additions & 1 deletion src/quartapp/approaches/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,15 @@
from pydantic.v1 import SecretStr
from pymongo.collection import Collection

from quartapp.approaches.keyword import KeyWord
from quartapp.approaches.rag import RAG
from quartapp.approaches.utils import chat_api, embeddings_api, setup_users_collection, vector_store_api
from quartapp.approaches.utils import (
chat_api,
embeddings_api,
setup_data_collection,
setup_users_collection,
vector_store_api,
)
from quartapp.approaches.vector import Vector


Expand All @@ -29,13 +36,15 @@ def __init__(
index_name: str,
vector_store_api: AzureCosmosDBVectorSearch,
users_collection: Collection,
data_collection: Collection | None,
):
self._connection_string = connection_string
self._database_name = database_name
self._collection_name = collection_name
self._index_name = index_name
self._vector_store_api = vector_store_api
self._users_collection = users_collection
self._data_collection = data_collection


class Setup(ABC):
Expand Down Expand Up @@ -80,15 +89,26 @@ def __init__(
embedding=self._openai_setup._embeddings_api,
),
users_collection=setup_users_collection(connection_string=connection_string, database_name=database_name),
data_collection=setup_data_collection(
connection_string=connection_string, database_name=database_name, collection_name=collection_name
),
)

self.vector_search = Vector(
vector_store=self._database_setup._vector_store_api,
embedding=self._openai_setup._embeddings_api,
chat=self._openai_setup._chat_api,
data_collection=self._database_setup._data_collection,
)
self.rag = RAG(
vector_store=self._database_setup._vector_store_api,
embedding=self._openai_setup._embeddings_api,
chat=self._openai_setup._chat_api,
data_collection=self._database_setup._data_collection,
)
self.keyword = KeyWord(
vector_store=self._database_setup._vector_store_api,
embedding=self._openai_setup._embeddings_api,
chat=self._openai_setup._chat_api,
data_collection=self._database_setup._data_collection,
)
12 changes: 12 additions & 0 deletions src/quartapp/approaches/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pydantic.v1 import SecretStr
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.errors import ServerSelectionTimeoutError


def embeddings_api(
Expand Down Expand Up @@ -48,3 +49,14 @@ def setup_users_collection(connection_string: str, database_name: str) -> Collec
db = mongo_client[database_name]
collection: Collection = db["Users"]
return collection


def setup_data_collection(connection_string: str, database_name: str, collection_name: str) -> Collection | None:
    """Open the data collection and make sure its text index exists.

    Connects with a 1000 ms server-selection timeout, then creates the
    ``search_text_index`` text index on the ``textContent`` field.

    Returns:
        The collection, or ``None`` when a ``ServerSelectionTimeoutError`` is raised.
    """
    try:
        client: MongoClient = MongoClient(connection_string, serverSelectionTimeoutMS=1000)
        data_collection: Collection = client[database_name][collection_name]
        data_collection.create_index({"textContent": "text"}, name="search_text_index")
    except ServerSelectionTimeoutError:
        return None
    return data_collection
135 changes: 55 additions & 80 deletions src/quartapp/config.py
Original file line number Diff line number Diff line change
@@ -1,75 +1,65 @@
import json
import os
from typing import Any
from uuid import uuid4

from pydantic.v1 import SecretStr
from pymongo.errors import (
ConfigurationError,
InvalidName,
InvalidOperation,
OperationFailure,
)

from quartapp.approaches.schemas import Context, DataPoint, JSONDataPoint, Message, RetrievalResponse, Thought
from quartapp.approaches.setup import Setup


class AppConfig:
def __init__(self) -> None:
openai_embeddings_model = os.getenv("AZURE_OPENAI_EMBEDDINGS_MODEL_NAME", "text-embedding-ada-002")
openai_embeddings_deployment = os.getenv("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT_NAME", "text-embedding")
openai_chat_model = os.getenv("AZURE_OPENAI_CHAT_MODEL_NAME", "gpt-35-turbo")
openai_chat_deployment = os.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME", "chat-gpt")
connection_string = os.getenv("AZURE_COSMOS_CONNECTION_STRING", "<YOUR-COSMOS-DB-CONNECTION-STRING>")
database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME", "<COSMOS-DB-NEW-UNIQUE-DATABASE-NAME>")
collection_name = os.getenv("AZURE_COSMOS_COLLECTION_NAME", "<COSMOS-DB-NEW-UNIQUE-DATABASE-NAME>")
index_name = os.getenv("AZURE_COSMOS_INDEX_NAME", "<COSMOS-DB-NEW-UNIQUE-INDEX-NAME>")
api_key = SecretStr(os.getenv("AZURE_OPENAI_API_KEY", "<YOUR-DEPLOYMENT-KEY>"))
api_version = os.getenv("OPENAI_API_VERSION", "2023-09-15-preview")
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT", "https://<YOUR-OPENAI-DEPLOYMENT-NAME>.openai.azure.com/")
self.setup = Setup(
openai_embeddings_model=openai_embeddings_model,
openai_embeddings_deployment=openai_embeddings_deployment,
openai_chat_model=openai_chat_model,
openai_chat_deployment=openai_chat_deployment,
connection_string=connection_string,
database_name=database_name,
collection_name=collection_name,
index_name=index_name,
api_key=api_key,
api_version=api_version,
azure_endpoint=azure_endpoint,
from quartapp.config_base import AppConfigBase


class AppConfig(AppConfigBase):

async def run_keyword(
self, session_state: str | None, messages: list, temperature: float, limit: int, score_threshold: float
) -> RetrievalResponse:
keyword_response, answer = await self.setup.keyword.run(messages, temperature, limit, score_threshold)

new_session_state: str = session_state if session_state else str(uuid4())

if keyword_response is None or len(keyword_response) == 0:
return RetrievalResponse(
session_state=new_session_state,
context=Context(DataPoint([JSONDataPoint()]), [Thought()]),
delta={"role": "assistant"},
message=Message(content="No results found", role="assistant"),
)
top_result = json.loads(answer)

message_content = f"""
Name: {top_result.get('name')}
Description: {top_result.get('description')}
Price: {top_result.get('price')}
Category: {top_result.get('category')}
Collection: {self.setup._database_setup._collection_name}
"""

data_points: DataPoint = DataPoint(json=[])
thoughts: list[Thought] = []

thoughts.append(Thought(description=keyword_response[0].metadata.get("source"), title="Source"))

for res in keyword_response:
raw_data = json.loads(res.page_content)
json_data_point: JSONDataPoint = JSONDataPoint()
json_data_point.name = raw_data.get("name")
json_data_point.description = raw_data.get("description")
json_data_point.price = raw_data.get("price")
json_data_point.category = raw_data.get("category")
json_data_point.collection = self.setup._database_setup._collection_name
data_points.json.append(json_data_point)

context: Context = Context(data_points=data_points, thoughts=thoughts)

delta: dict[str, Any] = {"role": "assistant"}
message: Message = Message(content=message_content, role="assistant")

self.add_to_cosmos(
old_messages=messages,
new_message=message.to_dict(),
session_state=session_state,
new_session_state=new_session_state,
)

def add_to_cosmos(
self, old_messages: list, new_message: dict, session_state: str | None, new_session_state: str
) -> bool:
is_first_message: bool = True if not session_state else False
if is_first_message:
try:
if len(old_messages) == 0 or len(new_message) == 0 or len(new_session_state) == 0:
raise IndexError
old_messages.append(new_message)
self.setup._database_setup._users_collection.insert_one(
{"_id": new_session_state, "messages": old_messages}
)
return True
except (AttributeError, ConfigurationError, InvalidName, InvalidOperation, OperationFailure, IndexError):
return False
else:
try:
if len(old_messages) == 0 or len(new_message) == 0 or len(new_session_state) == 0:
raise IndexError
self.setup._database_setup._users_collection.update_one(
{"_id": new_session_state}, {"$push": {"messages": old_messages[-1]}}
)
self.setup._database_setup._users_collection.update_one(
{"_id": new_session_state}, {"$push": {"messages": new_message}}
)
return True
except (AttributeError, ConfigurationError, InvalidName, InvalidOperation, OperationFailure, IndexError):
return False
return RetrievalResponse(context, delta, message, new_session_state)

async def run_vector(
self, session_state: str | None, messages: list, temperature: float, limit: int, score_threshold: float
Expand Down Expand Up @@ -175,18 +165,3 @@ async def run_rag(
)

return RetrievalResponse(context, delta, message, new_session_state)

async def run_keyword(
self, session_state: str | None, messages: list, temperature: float, limit: int, score_threshold: float
) -> RetrievalResponse:
keyword_response = None

new_session_state: str = session_state if session_state else str(uuid4())

if keyword_response is None or len(keyword_response) == 0:
return RetrievalResponse(
session_state=new_session_state,
context=Context(DataPoint([JSONDataPoint()]), [Thought()]),
delta={"role": "assistant"},
message=Message(content="No results found", role="assistant"),
)
85 changes: 85 additions & 0 deletions src/quartapp/config_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import os
from abc import ABC, abstractmethod

from pydantic.v1 import SecretStr
from pymongo.errors import (
ConfigurationError,
InvalidName,
InvalidOperation,
OperationFailure,
)

from quartapp.approaches.schemas import RetrievalResponse
from quartapp.approaches.setup import Setup


class AppConfigBase(ABC):
    """Base application configuration: builds ``Setup`` from the environment and stores chat history."""

    def __init__(self) -> None:
        """Read settings from environment variables (with placeholder defaults) and build ``Setup``."""
        openai_embeddings_model = os.getenv("AZURE_OPENAI_EMBEDDINGS_MODEL_NAME", "text-embedding-ada-002")
        openai_embeddings_deployment = os.getenv("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT_NAME", "text-embedding")
        openai_chat_model = os.getenv("AZURE_OPENAI_CHAT_MODEL_NAME", "gpt-35-turbo")
        openai_chat_deployment = os.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME", "chat-gpt")
        connection_string = os.getenv("AZURE_COSMOS_CONNECTION_STRING", "<YOUR-COSMOS-DB-CONNECTION-STRING>")
        database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME", "<COSMOS-DB-NEW-UNIQUE-DATABASE-NAME>")
        # NOTE(review): the collection default reuses the database placeholder text - confirm this is intended.
        collection_name = os.getenv("AZURE_COSMOS_COLLECTION_NAME", "<COSMOS-DB-NEW-UNIQUE-DATABASE-NAME>")
        index_name = os.getenv("AZURE_COSMOS_INDEX_NAME", "<COSMOS-DB-NEW-UNIQUE-INDEX-NAME>")
        api_key = SecretStr(os.getenv("AZURE_OPENAI_API_KEY", "<YOUR-DEPLOYMENT-KEY>"))
        api_version = os.getenv("OPENAI_API_VERSION", "2023-09-15-preview")
        azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT", "https://<YOUR-OPENAI-DEPLOYMENT-NAME>.openai.azure.com/")
        self.setup = Setup(
            openai_embeddings_model=openai_embeddings_model,
            openai_embeddings_deployment=openai_embeddings_deployment,
            openai_chat_model=openai_chat_model,
            openai_chat_deployment=openai_chat_deployment,
            connection_string=connection_string,
            database_name=database_name,
            collection_name=collection_name,
            index_name=index_name,
            api_key=api_key,
            api_version=api_version,
            azure_endpoint=azure_endpoint,
        )

    def add_to_cosmos(
        self, old_messages: list, new_message: dict, session_state: str | None, new_session_state: str
    ) -> bool:
        """Persist the latest exchange in the users collection.

        Args:
            old_messages: Messages so far; its last entry is the newest user message.
            new_message: The assistant reply to store.
            session_state: Falsy for the first message of a session, otherwise the existing session id.
            new_session_state: Document ``_id`` to write under.

        Returns:
            ``True`` when the write was issued without error; ``False`` when any
            argument is empty or the database call raised one of the handled errors.
        """
        # One guard for both paths instead of raising IndexError as control flow.
        if not old_messages or not new_message or not new_session_state:
            return False

        try:
            users_collection = self.setup._database_setup._users_collection
            if not session_state:
                # First message: create the session document with the full history.
                # A new list is built so the caller's list is left untouched.
                users_collection.insert_one(
                    {"_id": new_session_state, "messages": [*old_messages, new_message]}
                )
            else:
                # Follow-up: append the user message and the reply in a single update,
                # so a failure cannot leave the question stored without its answer.
                users_collection.update_one(
                    {"_id": new_session_state},
                    {"$push": {"messages": {"$each": [old_messages[-1], new_message]}}},
                )
        except (AttributeError, ConfigurationError, InvalidName, InvalidOperation, OperationFailure):
            return False
        return True

    @abstractmethod
    async def run_vector(
        self, session_state: str | None, messages: list, temperature: float, limit: int, score_threshold: float
    ) -> RetrievalResponse:
        """Handle a chat turn with the vector approach and return its response."""

    @abstractmethod
    async def run_rag(
        self, session_state: str | None, messages: list, temperature: float, limit: int, score_threshold: float
    ) -> RetrievalResponse:
        """Handle a chat turn with the RAG approach and return its response."""

    @abstractmethod
    async def run_keyword(
        self, session_state: str | None, messages: list, temperature: float, limit: int, score_threshold: float
    ) -> RetrievalResponse:
        """Handle a chat turn with the keyword approach and return its response."""
Loading

0 comments on commit 86a4adc

Please sign in to comment.