Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tools: unify retrievers/functions and add file tools #164

Merged
merged 20 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ xmltodict = "^0.13.0"
authlib = "^1.3.0"
itsdangerous = "^2.2.0"
bcrypt = "^4.1.2"
pypdf = "^4.2.0"
pyjwt = "^2.8.0"

[tool.poetry.group.dev]
Expand Down
30 changes: 30 additions & 0 deletions src/backend/alembic/versions/78b159bee0a6_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""empty message

Revision ID: 78b159bee0a6
Revises: c15b848babe3
Create Date: 2024-06-03 12:46:37.439991

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "78b159bee0a6"
down_revision: Union[str, None] = "c15b848babe3"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    # Add the file_content column used to store extracted file text.
    # NOTE(review): adding a NOT NULL column without a server_default will fail
    # if the "files" table already contains rows — confirm the table is empty
    # (or add server_default="") before running this in production.
    op.add_column("files", sa.Column("file_content", sa.String(), nullable=False))
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    # Reverse of upgrade(): drop the file_content column from "files".
    # Any stored file content is lost irreversibly on downgrade.
    op.drop_column("files", "file_content")
    # ### end Alembic commands ###
154 changes: 99 additions & 55 deletions src/backend/chat/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,78 +3,122 @@

from backend.model_deployments.base import BaseDeployment


def combine_documents(
    documents: Dict[str, List[Dict[str, Any]]],
    model: BaseDeployment,
) -> List[Dict[str, Any]]:
    """
    Reranks the documents gathered per query, then interleaves the reranked
    lists into a single combined list.

    Args:
        documents (Dict[str, List[Dict[str, Any]]]): Mapping from query to its
            list of retrieved documents.
        model (BaseDeployment): Model deployment used for reranking.

    Returns:
        List[Dict[str, Any]]: Combined, reranked documents.
    """
    return interleave(rerank(documents, model))
RELEVANCE_THRESHOLD = 0.5


def rerank_and_chunk(
    tool_results: List[Dict[str, Any]], model: BaseDeployment
) -> List[Dict[str, Any]]:
    """
    Merges tool results that share the same tool call, then chunks and reranks
    each group's documents against the call's query, when one exists.

    Args:
        tool_results (List[Dict[str, Any]]): Tool results from different
            retrievers. Each entry holds a "call" (the ToolCall) and
            "outputs" (the documents it produced).
        model (BaseDeployment): Model deployment used for reranking.

    Returns:
        List[Dict[str, Any]]: Tool results whose outputs have been chunked,
        reranked, and filtered to those scoring above RELEVANCE_THRESHOLD.
        Results without a query are passed through unchanged.
    """
    # If rerank is not enabled, return the results untouched.
    if not model.rerank_enabled:
        return tool_results

    # Merge all outputs that belong to the same tool call and parameters.
    unified_tool_results = {}
    for tool_result in tool_results:
        tool_call = tool_result["call"]
        # str() gives a hashable key for grouping identical tool calls.
        tool_call_hashable = str(tool_call)

        if tool_call_hashable not in unified_tool_results:
            unified_tool_results[tool_call_hashable] = {
                "call": tool_call,
                "outputs": [],
            }

        unified_tool_results[tool_call_hashable]["outputs"].extend(
            tool_result["outputs"]
        )

    # Chunk and rerank the documents for each query.
    reranked_results = {}
    for tool_call_hashable, tool_result in unified_tool_results.items():
        tool_call = tool_result["call"]
        query = tool_call.parameters.get("query") or tool_call.parameters.get(
            "search_query"
        )

        # Only rerank if there is a query to rank against; otherwise pass
        # the merged result through unchanged.
        if not query:
            reranked_results[tool_call_hashable] = tool_result
            continue

        chunked_outputs = []
        for output in tool_result["outputs"]:
            text = output.get("text")

            # Outputs without text cannot be chunked; keep them as-is.
            # NOTE(review): appending [output] (a one-element list) mixes list
            # entries into chunked_outputs alongside dicts — confirm
            # invoke_rerank and the index mapping below handle this shape.
            if not text:
                chunked_outputs.append([output])
                continue

            chunks = chunk(text)
            chunked_outputs.extend([dict(output, text=chunk) for chunk in chunks])

        # Nothing to rerank for this call.
        if not chunked_outputs:
            continue

        res = model.invoke_rerank(query=query, documents=chunked_outputs)

        # Sort the results by relevance score, most relevant first.
        res.results.sort(key=lambda x: x.relevance_score, reverse=True)

        # Map the reranked indices back to the chunked documents, dropping
        # anything at or below the relevance threshold.
        reranked_results[tool_call_hashable] = {
            "call": tool_call,
            "outputs": [
                chunked_outputs[r.index]
                for r in res.results
                if r.relevance_score > RELEVANCE_THRESHOLD
            ],
        }

    return list(reranked_results.values())


def chunk(content, compact_mode=False, soft_word_cut_off=100, hard_word_cut_off=300):
    """
    Splits text into chunks bounded by word counts.

    A chunk is finalized when it reaches ``hard_word_cut_off`` words, or when
    it has at least ``soft_word_cut_off`` words and the current word ends a
    sentence (ends with a period).

    Args:
        content (str): Text to split.
        compact_mode (bool): If True, replaces newlines with spaces before
            splitting. NOTE(review): ``str.split()`` already treats newlines
            as whitespace, so this currently has no effect on the output.
        soft_word_cut_off (int): Preferred chunk size; a chunk may end at the
            first sentence boundary after this many words.
        hard_word_cut_off (int): Maximum number of words per chunk.

    Returns:
        list[str]: Non-empty text chunks, in original order.
    """
    if compact_mode:
        content = content.replace("\n", " ")

    chunks = []
    current_chunk = ""
    word_count = 0

    # Each token from split() is a single whitespace-free word, so every word
    # contributes exactly one to the running count.
    for word in content.split():
        if word_count + 1 > hard_word_cut_off:
            # Hard limit reached: finalize the current chunk before this word.
            chunks.append(current_chunk.strip())
            current_chunk = ""
            word_count = 0

        if word_count + 1 > soft_word_cut_off and word.endswith("."):
            # Soft limit reached at a sentence boundary: include the word,
            # then finalize the chunk.
            current_chunk += " " + word
            chunks.append(current_chunk.strip())
            current_chunk = ""
            word_count = 0
        else:
            # Append the word to the current chunk.
            current_chunk = word if not current_chunk else current_chunk + " " + word
            word_count += 1

    # Flush any remaining words as the final chunk.
    if current_chunk:
        chunks.append(current_chunk.strip())

    return chunks
Loading
Loading