Vector search (#424)
pablocastro authored Jul 18, 2023
1 parent 6bfb2cc commit 85791db
Showing 18 changed files with 297 additions and 92 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -75,6 +75,7 @@ It will look like the following:
1. Run `azd env set AZURE_OPENAI_RESOURCE_GROUP {Name of existing resource group that OpenAI service is provisioned to}`
1. Run `azd env set AZURE_OPENAI_CHATGPT_DEPLOYMENT {Name of existing ChatGPT deployment}`. Only needed if your ChatGPT deployment is not the default 'chat'.
1. Run `azd env set AZURE_OPENAI_GPT_DEPLOYMENT {Name of existing GPT deployment}`. Only needed if your GPT deployment is not the default 'davinci'.
1. Run `azd env set AZURE_OPENAI_EMB_DEPLOYMENT {Name of existing GPT embedding deployment}`. Only needed if your embeddings deployment is not the default 'embedding'.
1. Run `azd up`

> NOTE: You can also use existing Search and Storage Accounts. See `./infra/main.parameters.json` for the list of environment variables to pass to `azd env set` to configure those existing resources.
16 changes: 11 additions & 5 deletions app/backend/app.py
@@ -22,6 +22,7 @@
AZURE_OPENAI_GPT_DEPLOYMENT = os.environ.get("AZURE_OPENAI_GPT_DEPLOYMENT") or "davinci"
AZURE_OPENAI_CHATGPT_DEPLOYMENT = os.environ.get("AZURE_OPENAI_CHATGPT_DEPLOYMENT") or "chat"
AZURE_OPENAI_CHATGPT_MODEL = os.environ.get("AZURE_OPENAI_CHATGPT_MODEL") or "gpt-35-turbo"
AZURE_OPENAI_EMB_DEPLOYMENT = os.environ.get("AZURE_OPENAI_EMB_DEPLOYMENT") or "embedding"

KB_FIELDS_CONTENT = os.environ.get("KB_FIELDS_CONTENT") or "content"
KB_FIELDS_CATEGORY = os.environ.get("KB_FIELDS_CATEGORY") or "category"
@@ -31,7 +32,7 @@
# just use 'az login' locally, and managed identity when deployed on Azure). If you need to use keys, use separate AzureKeyCredential instances with the
# keys for each service
# If you encounter a blocking error during a DefaultAzureCredential resolution, you can exclude the problematic credential by using a parameter (e.g. exclude_shared_token_cache_credential=True)
azure_credential = DefaultAzureCredential()
azure_credential = DefaultAzureCredential(exclude_shared_token_cache_credential=True)

# Used by the OpenAI SDK
openai.api_type = "azure"
@@ -56,13 +57,18 @@
# Various approaches to integrate GPT and external knowledge. Most applications will use a single one of these patterns
# or some derivative; here we include several for exploration purposes.
ask_approaches = {
"rtr": RetrieveThenReadApproach(search_client, AZURE_OPENAI_CHATGPT_DEPLOYMENT, AZURE_OPENAI_CHATGPT_MODEL, KB_FIELDS_SOURCEPAGE, KB_FIELDS_CONTENT),
"rrr": ReadRetrieveReadApproach(search_client, AZURE_OPENAI_GPT_DEPLOYMENT, KB_FIELDS_SOURCEPAGE, KB_FIELDS_CONTENT),
"rda": ReadDecomposeAsk(search_client, AZURE_OPENAI_GPT_DEPLOYMENT, KB_FIELDS_SOURCEPAGE, KB_FIELDS_CONTENT)
"rtr": RetrieveThenReadApproach(search_client, AZURE_OPENAI_CHATGPT_DEPLOYMENT, AZURE_OPENAI_CHATGPT_MODEL, AZURE_OPENAI_EMB_DEPLOYMENT, KB_FIELDS_SOURCEPAGE, KB_FIELDS_CONTENT),
"rrr": ReadRetrieveReadApproach(search_client, AZURE_OPENAI_GPT_DEPLOYMENT, AZURE_OPENAI_EMB_DEPLOYMENT, KB_FIELDS_SOURCEPAGE, KB_FIELDS_CONTENT),
"rda": ReadDecomposeAsk(search_client, AZURE_OPENAI_GPT_DEPLOYMENT, AZURE_OPENAI_EMB_DEPLOYMENT, KB_FIELDS_SOURCEPAGE, KB_FIELDS_CONTENT)
}

chat_approaches = {
"rrr": ChatReadRetrieveReadApproach(search_client, AZURE_OPENAI_CHATGPT_DEPLOYMENT, AZURE_OPENAI_CHATGPT_MODEL, KB_FIELDS_SOURCEPAGE, KB_FIELDS_CONTENT)
"rrr": ChatReadRetrieveReadApproach(search_client,
AZURE_OPENAI_CHATGPT_DEPLOYMENT,
AZURE_OPENAI_CHATGPT_MODEL,
AZURE_OPENAI_EMB_DEPLOYMENT,
KB_FIELDS_SOURCEPAGE,
KB_FIELDS_CONTENT)
}

app = Flask(__name__)
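The `retrieval_mode` override referenced throughout the approach changes below is supplied by the client per request. As a quick way to exercise the new modes end to end, a request like the following should work once the app is running; this is a sketch that assumes the `/chat` route forwards the JSON body's `overrides` dict to the approach unchanged (as the rest of this backend does) and that the app is listening on Flask's default development port.

```python
# Hedged smoke test for the new retrieval_mode override. Assumes the /chat
# route passes request.json["overrides"] straight through to the approach,
# and that the backend is running locally on Flask's default port 5000.
import requests

payload = {
    "history": [{"user": "What does my health plan cover?"}],
    "overrides": {
        "retrieval_mode": "hybrid",  # "text", "vectors", or "hybrid"
        "semantic_ranker": True,     # reranker only applies when text retrieval is in play
        "top": 3,
    },
}

response = requests.post("http://127.0.0.1:5000/chat", json=payload)
response.raise_for_status()
print(response.json()["answer"])
```

Omitting `retrieval_mode` behaves like hybrid, since both the `has_text` and `has_vector` checks in the approaches treat `None` as a match.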
69 changes: 45 additions & 24 deletions app/backend/approaches/chatreadretrieveread.py
@@ -23,24 +23,23 @@ class ChatReadRetrieveReadApproach(Approach):
"""
system_message_chat_conversation = """Assistant helps the company employees with their healthcare plan questions, and questions about the employee handbook. Be brief in your answers.
Answer ONLY with the facts listed in the list of sources below. If there isn't enough information below, say you don't know. Do not generate answers that don't use the sources below. If asking a clarifying question to the user would help, ask the question.
For tabular information return it as an html table. Do not return markdown format.
For tabular information return it as an html table. Do not return markdown format. If the question is not in English, answer in the language used in the question.
Each source has a name followed by colon and the actual information, always include the source name for each fact you use in the response. Use square brackets to reference the source, e.g. [info1.txt]. Don't combine sources, list each source separately, e.g. [info1.txt][info2.pdf].
{follow_up_questions_prompt}
{injected_prompt}
"""
follow_up_questions_prompt_content = """Generate three very brief follow-up questions that the user would likely ask next about their healthcare plan and employee handbook.
Use double angle brackets to reference the questions, e.g. <<Are there exclusions for prescriptions?>>.
Try not to repeat questions that have already been asked.
Only generate questions and do not generate any text before or after the questions, such as 'Next Questions'"""

query_prompt_template = """Below is a history of the conversation so far, and a new question asked by the user that needs to be answered by searching in a knowledge base about employee healthcare plans and the employee handbook.
Generate a search query based on the conversation and the new question.
Do not include cited source filenames and document names e.g. info.txt or doc.pdf in the search query terms.
Do not include any text inside [] or <<>> in the search query terms.
Do not include any special characters like '+'.
If the question is not in English, translate the question to English before generating the search query.
Search Query:
If you cannot generate a search query, return just the number 0.
"""
query_prompt_few_shots = [
{'role' : USER, 'content' : 'What are my health plans?' },
@@ -49,16 +48,19 @@ class ChatReadRetrieveReadApproach(Approach):
{'role' : ASSISTANT, 'content' : 'Health plan cardio coverage' }
]

def __init__(self, search_client: SearchClient, chatgpt_deployment: str, chatgpt_model: str, sourcepage_field: str, content_field: str):
def __init__(self, search_client: SearchClient, chatgpt_deployment: str, chatgpt_model: str, embedding_deployment: str, sourcepage_field: str, content_field: str):
self.search_client = search_client
self.chatgpt_deployment = chatgpt_deployment
self.chatgpt_model = chatgpt_model
self.embedding_deployment = embedding_deployment
self.sourcepage_field = sourcepage_field
self.content_field = content_field
self.chatgpt_token_limit = get_token_limit(chatgpt_model)

def run(self, history: Sequence[dict[str, str]], overrides: dict[str, Any]) -> Any:
use_semantic_captions = True if overrides.get("semantic_captions") else False
has_text = overrides.get("retrieval_mode") in ["text", "hybrid", None]
has_vector = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
use_semantic_captions = True if overrides.get("semantic_captions") and has_text else False
top = overrides.get("top") or 3
exclude_category = overrides.get("exclude_category") or None
filter = "category ne '{}'".format(exclude_category.replace("'", "''")) if exclude_category else None
@@ -83,20 +85,42 @@ def run(self, history: Sequence[dict[str, str]], overrides: dict[str, Any]) -> Any:
max_tokens=32,
n=1)

q = chat_completion.choices[0].message.content
query_text = chat_completion.choices[0].message.content
if query_text.strip() == "0":
query_text = history[-1]["user"] # Use the last user input if we failed to generate a better query

# STEP 2: Retrieve relevant documents from the search index with the GPT optimized query
if overrides.get("semantic_ranker"):
r = self.search_client.search(q,

# If retrieval mode includes vectors, compute an embedding for the query
if has_vector:
query_vector = openai.Embedding.create(engine=self.embedding_deployment, input=query_text)["data"][0]["embedding"]
else:
query_vector = None

# Only keep the text query if the retrieval mode uses text, otherwise drop it
if not has_text:
query_text = None

# Use semantic L2 reranker if requested and if retrieval mode is text or hybrid (vectors + text)
if overrides.get("semantic_ranker") and has_text:
r = self.search_client.search(query_text,
filter=filter,
query_type=QueryType.SEMANTIC,
query_language="en-us",
query_speller="lexicon",
semantic_configuration_name="default",
top=top,
query_caption="extractive|highlight-false" if use_semantic_captions else None)
query_caption="extractive|highlight-false" if use_semantic_captions else None,
vector=query_vector,
top_k=50 if query_vector else None,
vector_fields="embedding" if query_vector else None)
else:
r = self.search_client.search(q, filter=filter, top=top)
r = self.search_client.search(query_text,
filter=filter,
top=top,
vector=query_vector,
top_k=50 if query_vector else None,
vector_fields="embedding" if query_vector else None)
if use_semantic_captions:
results = [doc[self.sourcepage_field] + ": " + nonewlines(" . ".join([c.text for c in doc['@search.captions']])) for doc in r]
else:
@@ -116,14 +140,11 @@ def run(self, history: Sequence[dict[str, str]], overrides: dict[str, Any]) -> Any:
else:
system_message = prompt_override.format(follow_up_questions_prompt=follow_up_questions_prompt)

# latest conversation
user_content = history[-1]["user"] + " \nSources:" + content

messages = self.get_messages_from_history(
system_message,
system_message + "\n\nSources:\n" + content,
self.chatgpt_model,
history,
user_content,
history[-1]["user"],
max_tokens=self.chatgpt_token_limit)

chat_completion = openai.ChatCompletion.create(
@@ -138,7 +159,7 @@ def run(self, history: Sequence[dict[str, str]], overrides: dict[str, Any]) -> Any:

msg_to_display = '\n\n'.join([str(message) for message in messages])

return {"data_points": results, "answer": chat_content, "thoughts": f"Searched for:<br>{q}<br><br>Conversations:<br>" + msg_to_display.replace('\n', '<br>')}
return {"data_points": results, "answer": chat_content, "thoughts": f"Searched for:<br>{query_text}<br><br>Conversations:<br>" + msg_to_display.replace('\n', '<br>')}

def get_messages_from_history(self, system_prompt: str, model_id: str, history: Sequence[dict[str, str]], user_conv: str, few_shots=[], max_tokens: int = 4096) -> list:
message_builder = MessageBuilder(system_prompt, model_id)
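Taken together, the changes in this file implement one retrieval pattern: derive `has_text`/`has_vector` from the `retrieval_mode` override, embed the query when vectors are in play, and pass the vector alongside (or instead of) the text query. A condensed, self-contained sketch of that pattern is below; the service, index, and field names are placeholders, and the `vector`/`top_k`/`vector_fields` parameters assume the preview azure-search-documents SDK this commit targets.

```python
# Condensed sketch of the hybrid retrieval flow introduced in this commit.
# Placeholder service/index/field names; assumes openai.api_* is already
# configured for Azure OpenAI, as app.py does.
import openai
from azure.identity import DefaultAzureCredential
from azure.search.documents import SearchClient

search_client = SearchClient("https://<service>.search.windows.net",
                             "gptkbindex",
                             DefaultAzureCredential())

def retrieve(query_text: str, retrieval_mode: str = "hybrid", top: int = 3):
    has_text = retrieval_mode in ("text", "hybrid")
    has_vector = retrieval_mode in ("vectors", "hybrid")

    # Vector half: embed the query with the Azure OpenAI embedding deployment.
    query_vector = None
    if has_vector:
        query_vector = openai.Embedding.create(
            engine="embedding", input=query_text)["data"][0]["embedding"]

    # Text half: the keyword query is dropped entirely in pure-vector mode.
    results = search_client.search(
        query_text if has_text else None,
        top=top,
        vector=query_vector,
        top_k=50 if query_vector else None,
        vector_fields="embedding" if query_vector else None)
    return [doc["sourcepage"] + ": " + doc["content"] for doc in results]
```

Note the asymmetry between the two limits: `top` caps the final result count, while `top_k=50` controls how many nearest neighbors the vector side retrieves before its results are fused with the text results.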
37 changes: 29 additions & 8 deletions app/backend/approaches/readdecomposeask.py
@@ -13,29 +13,50 @@
from typing import Any, List, Optional

class ReadDecomposeAsk(Approach):
def __init__(self, search_client: SearchClient, openai_deployment: str, sourcepage_field: str, content_field: str):
def __init__(self, search_client: SearchClient, openai_deployment: str, embedding_deployment: str, sourcepage_field: str, content_field: str):
self.search_client = search_client
self.openai_deployment = openai_deployment
self.embedding_deployment = embedding_deployment
self.sourcepage_field = sourcepage_field
self.content_field = content_field

def search(self, q: str, overrides: dict[str, Any]) -> str:
use_semantic_captions = True if overrides.get("semantic_captions") else False
def search(self, query_text: str, overrides: dict[str, Any]) -> str:
has_text = overrides.get("retrieval_mode") in ["text", "hybrid", None]
has_vector = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
use_semantic_captions = True if overrides.get("semantic_captions") and has_text else False
top = overrides.get("top") or 3
exclude_category = overrides.get("exclude_category") or None
filter = "category ne '{}'".format(exclude_category.replace("'", "''")) if exclude_category else None

if overrides.get("semantic_ranker"):
r = self.search_client.search(q,
# If retrieval mode includes vectors, compute an embedding for the query
if has_vector:
query_vector = openai.Embedding.create(engine=self.embedding_deployment, input=query_text)["data"][0]["embedding"]
else:
query_vector = None

# Only keep the text query if the retrieval mode uses text, otherwise drop it
if not has_text:
query_text = None

if overrides.get("semantic_ranker") and has_text:
r = self.search_client.search(query_text,
filter=filter,
query_type=QueryType.SEMANTIC,
query_language="en-us",
query_speller="lexicon",
semantic_configuration_name="default",
top = top,
query_caption="extractive|highlight-false" if use_semantic_captions else None)
top=top,
query_caption="extractive|highlight-false" if use_semantic_captions else None,
vector=query_vector,
top_k=50 if query_vector else None,
vector_fields="embedding" if query_vector else None)
else:
r = self.search_client.search(q, filter=filter, top=top)
r = self.search_client.search(query_text,
filter=filter,
top=top,
vector=query_vector,
top_k=50 if query_vector else None,
vector_fields="embedding" if query_vector else None)
if use_semantic_captions:
self.results = [doc[self.sourcepage_field] + ":" + nonewlines(" . ".join([c.text for c in doc['@search.captions'] ])) for doc in r]
else:
36 changes: 29 additions & 7 deletions app/backend/approaches/readretrieveread.py
@@ -44,29 +44,51 @@ class ReadRetrieveReadApproach(Approach):

CognitiveSearchToolDescription = "useful for searching the Microsoft employee benefits information such as healthcare plans, retirement plans, etc."

def __init__(self, search_client: SearchClient, openai_deployment: str, sourcepage_field: str, content_field: str):
def __init__(self, search_client: SearchClient, openai_deployment: str, embedding_deployment: str, sourcepage_field: str, content_field: str):
self.search_client = search_client
self.openai_deployment = openai_deployment
self.embedding_deployment = embedding_deployment
self.sourcepage_field = sourcepage_field
self.content_field = content_field

def retrieve(self, q: str, overrides: dict[str, Any]) -> Any:
use_semantic_captions = True if overrides.get("semantic_captions") else False
def retrieve(self, query_text: str, overrides: dict[str, Any]) -> Any:
has_text = overrides.get("retrieval_mode") in ["text", "hybrid", None]
has_vector = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
use_semantic_captions = True if overrides.get("semantic_captions") and has_text else False
top = overrides.get("top") or 3
exclude_category = overrides.get("exclude_category") or None
filter = "category ne '{}'".format(exclude_category.replace("'", "''")) if exclude_category else None

if overrides.get("semantic_ranker"):
r = self.search_client.search(q,
# If retrieval mode includes vectors, compute an embedding for the query
if has_vector:
query_vector = openai.Embedding.create(engine=self.embedding_deployment, input=query_text)["data"][0]["embedding"]
else:
query_vector = None

# Only keep the text query if the retrieval mode uses text, otherwise drop it
if not has_text:
query_text = None

# Use semantic ranker if requested and if retrieval mode is text or hybrid (vectors + text)
if overrides.get("semantic_ranker") and has_text:
r = self.search_client.search(query_text,
filter=filter,
query_type=QueryType.SEMANTIC,
query_language="en-us",
query_speller="lexicon",
semantic_configuration_name="default",
top = top,
query_caption="extractive|highlight-false" if use_semantic_captions else None)
query_caption="extractive|highlight-false" if use_semantic_captions else None,
vector=query_vector,
top_k=50 if query_vector else None,
vector_fields="embedding" if query_vector else None)
else:
r = self.search_client.search(q, filter=filter, top=top)
r = self.search_client.search(query_text,
filter=filter,
top=top,
vector=query_vector,
top_k=50 if query_vector else None,
vector_fields="embedding" if query_vector else None)
if use_semantic_captions:
self.results = [doc[self.sourcepage_field] + ":" + nonewlines(" -.- ".join([c.text for c in doc['@search.captions']])) for doc in r]
else:
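As the last three files show, the same `has_text`/`has_vector`/`query_vector` block is now repeated in every approach. If that duplication becomes a maintenance burden, the mode handling could be lifted into a small shared helper; a possible sketch (not part of this commit, names hypothetical):

```python
# Hypothetical shared helper (not in this commit): centralizes the
# retrieval_mode handling that is duplicated across the approaches.
from typing import Any, Optional

import openai

def build_search_args(query_text: str, overrides: dict[str, Any],
                      embedding_deployment: str) -> tuple[Optional[str], dict[str, Any]]:
    has_text = overrides.get("retrieval_mode") in ["text", "hybrid", None]
    has_vector = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]

    # Embed the query only when the retrieval mode includes vectors.
    vector = None
    if has_vector:
        vector = openai.Embedding.create(
            engine=embedding_deployment, input=query_text)["data"][0]["embedding"]

    # Return the (possibly dropped) text query plus keyword arguments that
    # can be splatted into SearchClient.search(...).
    return (query_text if has_text else None,
            {"vector": vector,
             "top_k": 50 if vector else None,
             "vector_fields": "embedding" if vector else None})
```

Each approach would then reduce to `query_text, vector_args = build_search_args(...)` followed by `self.search_client.search(query_text, ..., **vector_args)`.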