Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tools: unify retrievers/functions and add file tools #164

Merged
merged 20 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ xmltodict = "^0.13.0"
authlib = "^1.3.0"
itsdangerous = "^2.2.0"
bcrypt = "^4.1.2"
pypdf = "^4.2.0"

[tool.poetry.group.dev]
optional = true
Expand Down
30 changes: 30 additions & 0 deletions src/backend/alembic/versions/78b159bee0a6_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""empty message

Revision ID: 78b159bee0a6
Revises: c15b848babe3
Create Date: 2024-06-03 12:46:37.439991

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "78b159bee0a6"
down_revision: Union[str, None] = "c15b848babe3"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Apply the migration: add a ``file_content`` string column to the ``files`` table."""
    # ### commands auto generated by Alembic - please adjust! ###
    # NOTE(review): adding a NOT NULL column without a server_default will fail
    # on a `files` table that already contains rows — confirm the table is
    # empty at migration time, or add a default / backfill step.
    op.add_column("files", sa.Column("file_content", sa.String(), nullable=False))
    # ### end Alembic commands ###


def downgrade() -> None:
    """Revert the migration: drop the ``file_content`` column from the ``files`` table."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Destructive: any stored file content is lost on downgrade.
    op.drop_column("files", "file_content")
    # ### end Alembic commands ###
126 changes: 87 additions & 39 deletions src/backend/chat/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,78 +3,126 @@

from backend.model_deployments.base import BaseDeployment

RELEVANCE_THRESHOLD = 0.5


def combine_documents(
lusmoura marked this conversation as resolved.
Show resolved Hide resolved
documents: Dict[str, List[Dict[str, Any]]],
tool_results: List[Dict[str, Any]],
model: BaseDeployment,
) -> List[Dict[str, Any]]:
"""
Combines documents from different retrievers and reranks them.

Args:
documents (Dict[str, List[Dict[str, Any]]]): Dictionary from queries of lists of documents.
tool_results (List[Dict[str, Any]]): List of tool_results from different retrievers.
Each tool_result contains a ToolCall and a list of Outputs.
model (BaseDeployment): Model deployment.

Returns:
List[Dict[str, Any]]: List of combined documents.
"""
reranked_documents = rerank(documents, model)
return interleave(reranked_documents)
return rerank_and_chunk(tool_results, model)


def rerank(
documents_by_query: Dict[str, List[Dict[str, Any]]], model: BaseDeployment
def rerank_and_chunk(
tool_resuls: List[Dict[str, Any]], model: BaseDeployment
) -> Dict[str, List[Dict[str, Any]]]:
"""
Takes a dictionary from queries of lists of documents and
internally rerank the documents for each query e.g:
Takes a list of tool_results and internally reranks the documents for each query, if there's one e.g:
[{"q1":[1, 2, 3],"q2": [4, 5, 6]] -> [{"q1":[2 , 3, 1],"q2": [4, 6, 5]]

Args:
documents_by_query (Dict[str, List[Dict[str, Any]]]): Dictionary from queries of lists of documents.
tool_resuls (List[Dict[str, Any]]): List of tool_results from different retrievers.
Each tool_result contains a ToolCall and a list of Outputs.
model (BaseDeployment): Model deployment.

Returns:
Dict[str, List[Dict[str, Any]]]: Dictionary from queries of lists of reranked documents.
"""
# If rerank is not enabled return documents as is:
if not model.rerank_enabled:
return documents_by_query
return tool_resuls
lusmoura marked this conversation as resolved.
Show resolved Hide resolved

# rerank the documents by each query
all_rerank_docs = {}
for query, documents in documents_by_query.items():
# Only rerank on text of document
# TODO handle no text in document
docs_to_rerank = [doc["text"] for doc in documents]
reranked_results = {}
for tool_result in tool_resuls:
tool_call = tool_result["call"]

# If no documents to rerank, continue to the next query
if not docs_to_rerank:
if not tool_call.parameters.get("query") and not tool_call.parameters.get(
lusmoura marked this conversation as resolved.
Show resolved Hide resolved
"search_query"
lusmoura marked this conversation as resolved.
Show resolved Hide resolved
):
continue

res = model.invoke_rerank(query=query, documents=docs_to_rerank)
# Sort the results by relevance score
res.results.sort(key=lambda x: x.relevance_score, reverse=True)
# Map the results back to the original documents
all_rerank_docs[query] = [documents[r.index] for r in res.results]
query = tool_call.parameters.get("query") or tool_call.parameters.get(
"search_query"
)
lusmoura marked this conversation as resolved.
Show resolved Hide resolved

return all_rerank_docs
chunked_outputs = []
for output in tool_result["outputs"]:
text = output.get("text")
if not text:
lusmoura marked this conversation as resolved.
Show resolved Hide resolved
continue
# create dict with all the existing keys, but replace the text with the chunked text
chunks = chunk(text)
chunked_outputs.extend([dict(output, text=chunk) for chunk in chunks])

# If no documents to rerank, continue to the next query
if not chunked_outputs:
continue

def interleave(documents: Dict[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
"""
Takes a dictionary from queries of lists of documents and interleaves them
for example [{"q1":[1, 2, 3],"q2": [4, 5, 6]] -> [1, 4, 2, 5, 3, 6]
res = model.invoke_rerank(query=query, documents=chunked_outputs)

Args:
documents (Dict[str, List[Dict[str, Any]]]): Dictionary from queries of lists of documents.
# Sort the results by relevance score
res.results.sort(key=lambda x: x.relevance_score, reverse=True)

Returns:
List[Dict[str, Any]]: List of interleaved documents.
"""
return [
y
for x in zip_longest(*documents.values(), fillvalue=None)
for y in x
if y is not None
]
# Map the results back to the original documents
# Merges the results with the same tool call and parameters
tool_call_hashable = str(tool_call)
if tool_call_hashable not in reranked_results.keys():
reranked_results[tool_call_hashable] = {"call": tool_call, "outputs": []}

reranked_results[tool_call_hashable]["outputs"].extend(
lusmoura marked this conversation as resolved.
Show resolved Hide resolved
[
chunked_outputs[r.index]
for r in res.results
if r.relevance_score > RELEVANCE_THRESHOLD
]
)

return list(reranked_results.values())


def chunk(content, compact_mode=False, soft_word_cut_off=100, hard_word_cut_off=300):
    """Split text into chunks of at most ``hard_word_cut_off`` words.

    A chunk is closed early (a "soft" cut) as soon as it exceeds
    ``soft_word_cut_off`` words AND the current word ends with a period,
    so chunks prefer to break at sentence boundaries. A chunk is always
    closed before it would exceed ``hard_word_cut_off`` words.

    Fixes over the previous version: drops the per-word
    ``len(word.split())`` computation (always 1, since words produced by
    ``str.split`` contain no whitespace) and replaces quadratic string
    concatenation with list accumulation + ``" ".join``.

    Args:
        content (str): Text to split into chunks.
        compact_mode (bool): If True, newlines are replaced with spaces
            first. This has no effect on the output (``str.split`` already
            splits on newlines); kept for interface compatibility.
        soft_word_cut_off (int): Preferred max words per chunk, enforced
            only at sentence-ending words.
        hard_word_cut_off (int): Absolute max words per chunk.

    Returns:
        list[str]: Chunks in original order; empty input yields ``[]``.
    """
    if compact_mode:
        content = content.replace("\n", " ")

    chunks = []
    current_words = []  # words accumulated for the chunk being built

    for word in content.split():
        # Hard limit: flush before the current chunk would overflow.
        if len(current_words) + 1 > hard_word_cut_off:
            chunks.append(" ".join(current_words))
            current_words = []

        current_words.append(word)

        # Soft limit: break at a sentence boundary once past the soft cap.
        if len(current_words) > soft_word_cut_off and word.endswith("."):
            chunks.append(" ".join(current_words))
            current_words = []

    # Flush any remaining words as the final chunk.
    if current_words:
        chunks.append(" ".join(current_words))

    return chunks
Loading
Loading