Skip to content

Commit

Permalink
chunk all docs
Browse files Browse the repository at this point in the history
  • Loading branch information
lusmoura committed Jun 5, 2024
1 parent 84baa28 commit 79f34b9
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 109 deletions.
122 changes: 81 additions & 41 deletions src/backend/chat/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,78 +3,118 @@

from backend.model_deployments.base import BaseDeployment

RELEVANCE_THRESHOLD = 0.5

# NOTE(review): this block is a pasted unified diff with the +/- markers lost.
# BOTH the old signature/body (documents -> rerank + interleave) and the new
# one (tool_results -> rerank_and_chunk) are present, so the second `return`
# is unreachable. Confirm against the real file — only one variant should
# survive (the commit message suggests the rerank_and_chunk version is new).
def combine_documents(
documents: Dict[str, List[Dict[str, Any]]],
tool_results: List[Dict[str, Any]],
model: BaseDeployment,
) -> List[Dict[str, Any]]:
"""
Combines documents from different retrievers and reranks them.
Args:
documents (Dict[str, List[Dict[str, Any]]]): Dictionary from queries of lists of documents.
tool_results (List[Dict[str, Any]]): List of tool_results from different retrievers.
Each tool_result contains a ToolCall and a list of Outputs.
model (BaseDeployment): Model deployment.
Returns:
List[Dict[str, Any]]: List of combined documents.
"""
# NOTE(review): old implementation (presumably removed by the diff).
reranked_documents = rerank(documents, model)
return interleave(reranked_documents)
# NOTE(review): new implementation — unreachable as rendered here.
return rerank_and_chunk(tool_results, model)


# NOTE(review): pasted diff region — the OLD `rerank` function and the NEW
# `rerank_and_chunk` are interleaved line-by-line (two `def` headers, two
# docstring intros, two `invoke_rerank` calls). This is not valid Python as
# rendered; reconstruct from the actual commit before editing.
def rerank(
documents_by_query: Dict[str, List[Dict[str, Any]]], model: BaseDeployment
def rerank_and_chunk(
tool_resuls: List[Dict[str, Any]],
model: BaseDeployment
) -> Dict[str, List[Dict[str, Any]]]:
"""
Takes a dictionary from queries of lists of documents and
internally rerank the documents for each query e.g:
Takes a list of tool_results and internally reranks the documents for each query, if there's one e.g:
{"q1": [1, 2, 3], "q2": [4, 5, 6]} -> {"q1": [2, 3, 1], "q2": [4, 6, 5]}
Args:
documents_by_query (Dict[str, List[Dict[str, Any]]]): Dictionary from queries of lists of documents.
tool_resuls (List[Dict[str, Any]]): List of tool_results from different retrievers.
Each tool_result contains a ToolCall and a list of Outputs.
model (BaseDeployment): Model deployment.
Returns:
Dict[str, List[Dict[str, Any]]]: Dictionary from queries of lists of reranked documents.
"""
# NOTE(review): parameter name `tool_resuls` is a typo for `tool_results`
# in the new code — worth fixing at the source before callers depend on it.
# If rerank is not enabled return documents as is:
if not model.rerank_enabled:
return documents_by_query
 
# rerank the documents by each query
all_rerank_docs = {}
for query, documents in documents_by_query.items():
# Only rerank on text of document
# TODO handle no text in document
docs_to_rerank = [doc["text"] for doc in documents]
# NOTE(review): new code path — rerank disabled returns inputs untouched.
return tool_resuls
 
reranked_results = {}
for tool_result in tool_resuls:
tool_call = tool_result["call"]
 
# Skip tool calls that carry no query to rerank against.
if not tool_call.parameters.get("query") and not tool_call.parameters.get("search_query"):
continue
 
query = tool_call.parameters.get("query") or tool_call.parameters.get("search_query")
 
chunked_outputs = []
for output in tool_result["outputs"]:
text = output.get("text")
if not text:
continue
# create dict with all the existing keys, but replace the text with the chunked text
chunks = chunk(text)
chunked_outputs.extend([dict(output, text=chunk) for chunk in chunks])
 
# If no documents to rerank, continue to the next query
if not docs_to_rerank:
if not chunked_outputs:
continue
 
# NOTE(review): old call (docs_to_rerank) and new call (chunked_outputs)
# are both present; only one belongs in the real file.
res = model.invoke_rerank(query=query, documents=docs_to_rerank)
res = model.invoke_rerank(query=query, documents=chunked_outputs)
 
# Sort the results by relevance score
res.results.sort(key=lambda x: x.relevance_score, reverse=True)
 
# Map the results back to the original documents
all_rerank_docs[query] = [documents[r.index] for r in res.results]
 
return all_rerank_docs


def interleave(documents: Dict[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
    """
    Round-robin merge of the per-query document lists into one flat list,
    e.g. {"q1": [1, 2, 3], "q2": [4, 5, 6]} -> [1, 4, 2, 5, 3, 6].
    Args:
        documents (Dict[str, List[Dict[str, Any]]]): Dictionary from queries of lists of documents.
    Returns:
        List[Dict[str, Any]]: List of interleaved documents.
    """
    columns = list(documents.values())
    depth = max((len(col) for col in columns), default=0)
    merged = []
    for row in range(depth):
        for col in columns:
            # Shorter lists simply run out early; None entries are dropped,
            # matching zip_longest(fillvalue=None) + the `is not None` filter.
            if row < len(col) and col[row] is not None:
                merged.append(col[row])
    return merged
# NOTE(review): orphaned tail of the new `rerank_and_chunk` — the diff
# rendering displaced it below the deleted `interleave`. It belongs inside
# the per-tool_result loop / at the end of rerank_and_chunk.
# Merges the results with the same tool call and parameters
tool_call_hashable = str(tool_call)
if tool_call_hashable not in reranked_results.keys():
reranked_results[tool_call_hashable] = {"call": tool_call, "outputs": []}
 
# Keep only chunks scoring above RELEVANCE_THRESHOLD, mapped back to outputs.
reranked_results[tool_call_hashable]["outputs"].extend([chunked_outputs[r.index] for r in res.results if r.relevance_score > RELEVANCE_THRESHOLD])
 
return list(reranked_results.values())


def chunk(
    content, compact_mode=False, soft_word_cut_off=100, hard_word_cut_off=300
):
    """
    Split text into chunks of at most ``hard_word_cut_off`` words.

    Once a chunk holds more than ``soft_word_cut_off`` words, it is closed
    at the next word that ends with "." so chunks prefer sentence
    boundaries; the hard limit forces a break regardless.

    Args:
        content (str): Text to split.
        compact_mode (bool): If True, replace newlines with spaces first.
            (str.split() already treats newlines as separators, so this
            does not change the resulting chunks.)
        soft_word_cut_off (int): Preferred word limit, applied at
            sentence-ending words.
        hard_word_cut_off (int): Absolute word limit per chunk.

    Returns:
        list[str]: The chunks in order; empty input yields an empty list.
    """
    if compact_mode:
        content = content.replace("\n", " ")

    chunks = []
    current_chunk = ""
    word_count = 0

    for word in content.split():
        # Tokens from str.split() contain no whitespace, so each word
        # counts as exactly 1 (the original `len(word.split())` was
        # always 1 — redundant re-splitting per word).
        if word_count + 1 > hard_word_cut_off:
            # Hard limit: close the current chunk before adding this word.
            chunks.append(current_chunk)
            current_chunk = ""
            word_count = 0

        if word_count + 1 > soft_word_cut_off and word.endswith("."):
            # Soft limit at a sentence boundary: include the word, then close.
            current_chunk += " " + word
            chunks.append(current_chunk.strip())
            current_chunk = ""
            word_count = 0
        else:
            current_chunk = word if not current_chunk else current_chunk + " " + word
            word_count += 1

    # Flush any remaining words as the final chunk.
    if current_chunk:
        chunks.append(current_chunk.strip())

    return chunks
11 changes: 6 additions & 5 deletions src/backend/chat/custom/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
"""
# Choose the deployment model - validation already performed by request validator
deployment_model = get_deployment(kwargs.get("deployment_name"), **kwargs)
print(f"Using deployment {deployment_model.__class__.__name__}")
self.logger.info(f"Using deployment {deployment_model.__class__.__name__}")

if len(chat_request.tools) > 0 and len(chat_request.documents) > 0:
raise HTTPException(
Expand Down Expand Up @@ -121,7 +121,6 @@ def get_tool_results(
deployment_model: BaseDeployment,
kwargs: Any,
) -> Any:
tool_results = []
"""
Invokes the tools and returns the results. If no tools calls are generated, it returns the chat response
as a direct answer.
Expand All @@ -138,6 +137,7 @@ def get_tool_results(
Any: The tool results or the chat response, and a boolean indicating if a direct answer was generated
"""
tool_results = []

# If the tool is Read_File or SearchFile, add the available files to the chat history
# so that the model knows what files are available
Expand All @@ -150,7 +150,7 @@ def get_tool_results(
kwargs.get("user_id"),
)

print(f"Invoking tools: {tools}")
self.logger.info(f"Invoking tools: {tools}")
stream = deployment_model.invoke_tools(
message, tools, chat_history=chat_history
)
Expand All @@ -168,7 +168,7 @@ def get_tool_results(
if second_event["event_type"] == StreamEvent.TOOL_CALLS_GENERATION:
tool_calls = second_event["tool_calls"]

print(f"Using tools: {tool_calls}")
self.logger.info(f"Tool calls: {tool_calls}")

# TODO: parallelize tool calls
for tool_call in tool_calls:
Expand All @@ -190,7 +190,8 @@ def get_tool_results(
for output in outputs:
tool_results.append({"call": tool_call, "outputs": [output]})

print(f"Tool results: {tool_results}")
self.logger.info(f"Tool results: {tool_results}")
tool_results = combine_documents(tool_results, deployment_model)
yield tool_results, False

else:
Expand Down
1 change: 0 additions & 1 deletion src/backend/services/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,6 @@ def generate_chat_stream(

stream_event = None
for event in model_deployment_stream:
print(f"Event: {event}")
if event["event_type"] == StreamEvent.STREAM_START:
stream_event = StreamStart.model_validate(event)
response_message.generation_id = event["generation_id"]
Expand Down
73 changes: 11 additions & 62 deletions src/backend/tools/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from pypdf import PdfReader

import backend.crud.file as file_crud
from backend.chat.collate import combine_documents
from backend.tools.base import BaseTool


Expand Down Expand Up @@ -35,18 +34,14 @@ def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]:
return []

file = files[0]
chunks = chunk_document(file.file_content)
result = [
return [
{
"text": chunk,
"text": file.file_content,
"title": file.file_name,
"url": file.file_path,
}
for chunk in chunks
]

return result


class SearchFileTool(BaseTool):
"""
Expand Down Expand Up @@ -77,25 +72,17 @@ def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]:
if not files:
return []

files_dicts = []
results = []
for file in files:
chunks = chunk_document(file.file_content)
for chunk in chunks:
files_dicts.append(
{
"text": chunk,
"title": file.file_name,
"url": file.file_path,
}
)

# Combine and rerank the documents
result = combine_documents({query: files_dicts}, model_deployment)

# return top results
num_chunks = min(len(result), self.MAX_NUM_CHUNKS)
return result[:num_chunks]
results.append(
{
"text": file.file_content,
"title": file.file_name,
"url": file.file_path,
}
)

return results

def get_file_content(file_path):
# Currently only supports PDF files
Expand All @@ -106,41 +93,3 @@ def get_file_content(file_path):

return text


def chunk_document(
    content, compact_mode=False, soft_word_cut_off=100, hard_word_cut_off=300
):
    """
    Break *content* into word-count-bounded chunks, preferring to end a
    chunk at a sentence boundary (a word ending in ".") once the soft
    limit is passed, and forcing a break once the hard limit would be
    exceeded.
    """
    if compact_mode:
        content = content.replace("\n", " ")

    pieces = []
    buffer = ""
    count = 0

    for token in content.split():
        weight = len(token.split())
        if count + weight > hard_word_cut_off:
            # Forced break: the buffer is full, flush it as-is.
            pieces.append(buffer)
            buffer = ""
            count = 0

        if count + weight > soft_word_cut_off and token.endswith("."):
            # Preferred break: past the soft limit and the token ends a sentence.
            buffer += " " + token
            pieces.append(buffer.strip())
            buffer = ""
            count = 0
        else:
            buffer = token if buffer == "" else buffer + " " + token
            count += weight

    if buffer != "":
        # Whatever remains becomes the final chunk.
        pieces.append(buffer.strip())

    return pieces

0 comments on commit 79f34b9

Please sign in to comment.