Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tools: unify retrievers/functions and add file tools #164

Merged
merged 20 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ xmltodict = "^0.13.0"
authlib = "^1.3.0"
itsdangerous = "^2.2.0"
bcrypt = "^4.1.2"
pypdf = "^4.2.0"

[tool.poetry.group.dev]
optional = true
Expand Down
30 changes: 30 additions & 0 deletions src/backend/alembic/versions/78b159bee0a6_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""empty message

Revision ID: 78b159bee0a6
Revises: c15b848babe3
Create Date: 2024-06-03 12:46:37.439991

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "78b159bee0a6"
down_revision: Union[str, None] = "c15b848babe3"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Apply the migration: add a non-null ``file_content`` string column to ``files``.

    NOTE(review): ``nullable=False`` with no ``server_default`` will fail if the
    ``files`` table already contains rows — confirm the table is empty at deploy
    time, or add a server default.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    content_column = sa.Column("file_content", sa.String(), nullable=False)
    op.add_column("files", content_column)
    # ### end Alembic commands ###


def downgrade() -> None:
    """Revert the migration: drop the ``file_content`` column from ``files``."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column("files", "file_content")
    # ### end Alembic commands ###
143 changes: 103 additions & 40 deletions src/backend/chat/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,78 +3,141 @@

from backend.model_deployments.base import BaseDeployment

RELEVANCE_THRESHOLD = 0.5


def combine_documents(
lusmoura marked this conversation as resolved.
Show resolved Hide resolved
documents: Dict[str, List[Dict[str, Any]]],
tool_results: List[Dict[str, Any]],
model: BaseDeployment,
) -> List[Dict[str, Any]]:
"""
Combines documents from different retrievers and reranks them.

Args:
documents (Dict[str, List[Dict[str, Any]]]): Dictionary from queries of lists of documents.
tool_results (List[Dict[str, Any]]): List of tool_results from different retrievers.
Each tool_result contains a ToolCall and a list of Outputs.
model (BaseDeployment): Model deployment.

Returns:
List[Dict[str, Any]]: List of combined documents.
"""
reranked_documents = rerank(documents, model)
return interleave(reranked_documents)
return rerank_and_chunk(tool_results, model)


def rerank(
documents_by_query: Dict[str, List[Dict[str, Any]]], model: BaseDeployment
def rerank_and_chunk(
tool_results: List[Dict[str, Any]], model: BaseDeployment
) -> Dict[str, List[Dict[str, Any]]]:
"""
Takes a dictionary from queries of lists of documents and
internally rerank the documents for each query e.g:
    Takes a list of tool_results and internally reranks the documents for each query, if one is present, e.g.:
    {"q1": [1, 2, 3], "q2": [4, 5, 6]} -> {"q1": [2, 3, 1], "q2": [4, 6, 5]}

Args:
documents_by_query (Dict[str, List[Dict[str, Any]]]): Dictionary from queries of lists of documents.
tool_results (List[Dict[str, Any]]): List of tool_results from different retrievers.
Each tool_result contains a ToolCall and a list of Outputs.
model (BaseDeployment): Model deployment.

Returns:
Dict[str, List[Dict[str, Any]]]: Dictionary from queries of lists of reranked documents.
"""
# If rerank is not enabled return documents as is:
if not model.rerank_enabled:
return documents_by_query
return tool_results

reranked_results = {}
non_reranked_results = {}
for tool_result in tool_results:
tool_call = tool_result["call"]

# Only rerank if there is a query
if not tool_call.parameters.get("query") and not tool_call.parameters.get(
lusmoura marked this conversation as resolved.
Show resolved Hide resolved
"search_query"
lusmoura marked this conversation as resolved.
Show resolved Hide resolved
):
tool_call_hashable = str(tool_call)
if tool_call_hashable not in non_reranked_results.keys():
non_reranked_results[tool_call_hashable] = {
"call": tool_call,
"outputs": [],
}

non_reranked_results[tool_call_hashable]["outputs"].extend(
tool_result["outputs"]
)
continue

# rerank the documents by each query
all_rerank_docs = {}
for query, documents in documents_by_query.items():
# Only rerank on text of document
# TODO handle no text in document
docs_to_rerank = [doc["text"] for doc in documents]
query = tool_call.parameters.get("query") or tool_call.parameters.get(
"search_query"
)
lusmoura marked this conversation as resolved.
Show resolved Hide resolved

# If no documents to rerank, continue to the next query
if not docs_to_rerank:
continue
chunked_outputs = []
for output in tool_result["outputs"]:
text = output.get("text")

res = model.invoke_rerank(query=query, documents=docs_to_rerank)
# Sort the results by relevance score
res.results.sort(key=lambda x: x.relevance_score, reverse=True)
# Map the results back to the original documents
all_rerank_docs[query] = [documents[r.index] for r in res.results]
if not text:
lusmoura marked this conversation as resolved.
Show resolved Hide resolved
chunked_outputs.append([output])
continue

return all_rerank_docs
chunks = chunk(text)
chunked_outputs.extend([dict(output, text=chunk) for chunk in chunks])

# If no documents to rerank, continue to the next query
if not chunked_outputs:
continue

def interleave(documents: Dict[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
"""
Takes a dictionary from queries of lists of documents and interleaves them
for example [{"q1":[1, 2, 3],"q2": [4, 5, 6]] -> [1, 4, 2, 5, 3, 6]
res = model.invoke_rerank(query=query, documents=chunked_outputs)

Args:
documents (Dict[str, List[Dict[str, Any]]]): Dictionary from queries of lists of documents.
# Sort the results by relevance score
res.results.sort(key=lambda x: x.relevance_score, reverse=True)

Returns:
List[Dict[str, Any]]: List of interleaved documents.
"""
return [
y
for x in zip_longest(*documents.values(), fillvalue=None)
for y in x
if y is not None
]
# Map the results back to the original documents
# Merges the results with the same tool call and parameters
tool_call_hashable = str(tool_call)
if tool_call_hashable not in reranked_results.keys():
reranked_results[tool_call_hashable] = {"call": tool_call, "outputs": []}

reranked_results[tool_call_hashable]["outputs"].extend(
lusmoura marked this conversation as resolved.
Show resolved Hide resolved
[
chunked_outputs[r.index]
for r in res.results
if r.relevance_score > RELEVANCE_THRESHOLD
]
)

# Return the reranked results followed by the non-reranked results
return list(reranked_results.values()) + list(non_reranked_results.values())


def chunk(content, compact_mode=False, soft_word_cut_off=100, hard_word_cut_off=300):
    """
    Split a text into chunks of at most ``hard_word_cut_off`` words each.

    A chunk is closed early once it exceeds ``soft_word_cut_off`` words and the
    current word ends with a period, so boundaries prefer sentence ends.
    ``hard_word_cut_off`` is an unconditional upper bound on chunk size.

    Args:
        content (str): Text to split.
        compact_mode (bool): If True, replace newlines with spaces before splitting.
        soft_word_cut_off (int): Preferred (sentence-aligned) chunk size, in words.
        hard_word_cut_off (int): Maximum chunk size, in words.

    Returns:
        List[str]: The text chunks, in order. Empty input yields an empty list.
    """
    if compact_mode:
        content = content.replace("\n", " ")

    chunks = []
    current_chunk = ""
    word_count = 0

    # str.split() yields whitespace-free tokens, so each word always counts as
    # exactly 1 (the original `len(word.split())` was a redundant recomputation).
    for word in content.split():
        if word_count + 1 > hard_word_cut_off:
            # Hard limit reached: finalize the current chunk before adding the word.
            chunks.append(current_chunk)
            current_chunk = ""
            word_count = 0

        if word_count + 1 > soft_word_cut_off and word.endswith("."):
            # Soft limit exceeded at a sentence end: include the word, then close the chunk.
            current_chunk += " " + word
            chunks.append(current_chunk.strip())
            current_chunk = ""
            word_count = 0
        else:
            # Append the word to the current chunk.
            current_chunk = word if not current_chunk else current_chunk + " " + word
            word_count += 1

    # Flush any remaining words as the final chunk.
    if current_chunk:
        chunks.append(current_chunk.strip())

    return chunks
Loading
Loading