Tools: unify retrievers/functions and add file tools #164

Merged: 20 commits, merged Jun 6, 2024.
Changes from 8 commits.
7 changes: 5 additions & 2 deletions poetry.lock

Generated file; diff not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -36,6 +36,7 @@ xmltodict = "^0.13.0"
authlib = "^1.3.0"
itsdangerous = "^2.2.0"
bcrypt = "^4.1.2"
pypdf = "^4.2.0"

[tool.poetry.group.dev]
optional = true
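
Note on the new dependency: pypdf powers the file-content extraction behind the new file tools. A minimal sketch of the kind of usage it enables (the file name is illustrative, not taken from this PR):

# Minimal pypdf sketch: extract the text of an uploaded PDF (file name illustrative).
from pypdf import PdfReader

reader = PdfReader("example.pdf")
file_content = "\n".join(page.extract_text() or "" for page in reader.pages)
print(file_content[:200])
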
30 changes: 30 additions & 0 deletions src/backend/alembic/versions/78b159bee0a6_.py
@@ -0,0 +1,30 @@
"""empty message

Revision ID: 78b159bee0a6
Revises: c15b848babe3
Create Date: 2024-06-03 12:46:37.439991

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "78b159bee0a6"
down_revision: Union[str, None] = "c15b848babe3"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("files", sa.Column("file_content", sa.String(), nullable=False))
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column("files", "file_content")
    # ### end Alembic commands ###
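
This migration is applied with the standard alembic upgrade head command (and reverted with alembic downgrade -1). For orientation, a hedged sketch of the ORM attribute the new column implies; the real File model lives elsewhere in the backend, so the class layout and primary key here are assumptions:

# Hypothetical sketch of the model-side counterpart of this migration.
import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class File(Base):
    __tablename__ = "files"
    id: Mapped[str] = mapped_column(sa.String, primary_key=True)  # assumed primary key
    file_content: Mapped[str] = mapped_column(sa.String, nullable=False)  # added by revision 78b159bee0a6
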
277 changes: 167 additions & 110 deletions src/backend/chat/custom/custom.py
@@ -1,13 +1,16 @@
import logging
from typing import Any
from typing import Any, Dict, Generator, List

from fastapi import HTTPException

from backend.chat.base import BaseChat
from backend.chat.collate import combine_documents
from backend.chat.custom.utils import get_deployment
from backend.chat.enums import StreamEvent
from backend.config.tools import AVAILABLE_TOOLS, ToolName
from backend.crud.file import get_files_by_conversation_id
from backend.model_deployments.base import BaseDeployment
from backend.schemas.chat import ChatMessage
from backend.schemas.cohere_chat import CohereChatRequest
from backend.schemas.tool import Category, Tool
from backend.services.logger import get_logger
@@ -31,140 +34,194 @@ def chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
"""
# Choose the deployment model - validation already performed by request validator
deployment_model = get_deployment(kwargs.get("deployment_name"), **kwargs)
self.logger.info(f"Using deployment {deployment_model.__class__.__name__}")
print(f"Using deployment {deployment_model.__class__.__name__}")

if len(chat_request.tools) > 0 and len(chat_request.documents) > 0:
raise HTTPException(
status_code=400, detail="Both tools and documents cannot be provided."
)

# Handles managed tools.
# If a direct answer is generated instead of tool calls, the chat will not be called again
# Instead, the direct answer will be returned from the stream
should_call_chat = True
if kwargs.get("managed_tools", True):
# Generate Search Queries
chat_history = [message.to_dict() for message in chat_request.chat_history]

function_tools: list[Tool] = []
for tool in chat_request.tools:
available_tool = AVAILABLE_TOOLS.get(tool.name)
if available_tool and available_tool.category == Category.Function:
function_tools.append(Tool(**available_tool.model_dump()))

if len(function_tools) > 0:
tool_results = self.get_tool_results(
chat_request.message, function_tools, deployment_model
)
stream = self.handle_managed_tools(chat_request, deployment_model, **kwargs)

chat_request.tools = None
if kwargs.get("stream", True) is True:
return deployment_model.invoke_chat_stream(
chat_request,
tool_results=tool_results,
)
for event, generated_direct_answer in stream:
if generated_direct_answer:
should_call_chat = False
                else:
                    return deployment_model.invoke_chat(
                        chat_request,
                        tool_results=tool_results,
                    )
                    chat_request = event

            queries = deployment_model.invoke_search_queries(
                chat_request.message, chat_history
            )
            self.logger.info(f"Search queries generated: {queries}")

            # Fetch Documents
            retrievers = self.get_retrievers(
                kwargs.get("file_paths", []), [tool.name for tool in chat_request.tools]
            )
            self.logger.info(
                f"Using retrievers: {[retriever.__class__.__name__ for retriever in retrievers]}"
            )

            # No search queries were generated but retrievers were selected, use user message as query
            if len(queries) == 0 and len(retrievers) > 0:
                queries = [chat_request.message]

            all_documents = {}
            # TODO: call in parallel and error handling
            # TODO: merge with regular function tools after multihop implemented
            for retriever in retrievers:
                for query in queries:
                    parameters = {"query": query}
                    all_documents.setdefault(query, []).extend(
                        retriever.call(parameters)
                    )

            # Collate Documents
            documents = combine_documents(all_documents, deployment_model)
            chat_request.documents = documents
            chat_request.tools = []
                break

        # Generate Response
        if kwargs.get("stream", True) is True:
            return deployment_model.invoke_chat_stream(chat_request)
        if should_call_chat:
            if kwargs.get("stream", True) is True:
                for event in deployment_model.invoke_chat_stream(chat_request):
                    yield event
            else:
                for event in deployment_model.invoke_chat(chat_request):
                    yield event
        else:
            return deployment_model.invoke_chat(chat_request)
            for event in stream:
                for e in event:
                    if type(e) == bool:
                        continue
                    yield e
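
Since chat() is now a generator in every branch, callers consume it uniformly. A rough sketch of the calling pattern; the class name, request object, and handler are assumptions, not shown in this diff:

# Hypothetical caller of the generator-based chat() (names assumed).
chat = CustomChat()
for event in chat.chat(chat_request, deployment_name="Cohere Platform", stream=True):
    send_to_client(event)  # e.g. serialize each event into a server-sent event
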

    def handle_managed_tools(
        self,
        chat_request: CohereChatRequest,
        deployment_model: BaseDeployment,
        **kwargs: Any,
    ) -> Generator[Any, None, None]:
        """
        This function handles the managed tools.

    def get_retrievers(
        self, file_paths: list[str], req_tools: list[ToolName]
    ) -> list[Any]:
        Args:
            chat_request (CohereChatRequest): The chat request
            deployment_model (BaseDeployment): The deployment model
            **kwargs (Any): The keyword arguments

        Returns:
            Generator[Any, None, None]: The tool results or the chat response, and a boolean indicating if a direct answer was generated
        """
        tools = [
            Tool(**AVAILABLE_TOOLS.get(tool.name).model_dump())
            for tool in chat_request.tools
            if AVAILABLE_TOOLS.get(tool.name)
        ]

        if not tools:
            yield chat_request, False

        for event, should_return in self.get_tool_results(
            chat_request.message,
            chat_request.chat_history,
            tools,
            kwargs.get("conversation_id"),
            deployment_model,
            kwargs,
        ):
            if should_return:
                yield event, True
            else:
                chat_request.tool_results = event
                chat_request.tools = tools
                return chat_request, False
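
Every item this generator yields is a (payload, generated_direct_answer) pair. An illustrative, self-contained consumer using a stand-in stream (values hypothetical):

# Stand-in stream demonstrating the (payload, generated_direct_answer) contract.
def fake_stream():
    yield {"text": "direct answer"}, True

for payload, generated_direct_answer in fake_stream():
    if generated_direct_answer:
        print("Forward payload as the final answer:", payload)
    else:
        print("Continue chatting with the updated request:", payload)
    break  # as in chat() above, only the first pair decides the path
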

    def get_tool_results(
        self,
        message: str,
        chat_history: List[Dict[str, str]],
        tools: list[Tool],
        conversation_id: str,
        deployment_model: BaseDeployment,
        kwargs: Any,
    ) -> Any:
        tool_results = []
        """
        Get retrievers for the required tools.
        Invokes the tools and returns the results. If no tool calls are generated, it returns the chat response
        as a direct answer.

        Args:
            file_paths (list[str]): File paths.
            req_tools (list[str]): Required tools.
            message (str): The message to be processed
            chat_history (List[Dict[str, str]]): The chat history
            tools (list[Tool]): The tools to be invoked
            conversation_id (str): The conversation ID
            deployment_model (BaseDeployment): The deployment model
            kwargs (Any): The keyword arguments

        Returns:
            list[Any]: Retriever implementations.
            Any: The tool results or the chat response, and a boolean indicating if a direct answer was generated

        """
        retrievers = []

        # If no tools are required, return an empty list
        if not req_tools:
            return retrievers

        # Iterate through the required tools and check if they are available
        # If so, add the implementation to the list of retrievers
        # If not, raise an HTTPException
        for tool_name in req_tools:
            tool = AVAILABLE_TOOLS.get(tool_name)
            if tool is None:
                raise HTTPException(
                    status_code=404, detail=f"Tool {tool_name} not found."
                )

            # Check if the tool is visible, if not, skip it
            if not tool.is_visible:
                continue
        # If the tool is Read_File or Search_File, add the available files to the chat history
        # so that the model knows what files are available
        tool_names = [tool.name for tool in tools]
        if ToolName.Read_File in tool_names or ToolName.Search_File in tool_names:
            chat_history = self.add_files_to_chat_history(
                chat_history,
                conversation_id,
                kwargs.get("session"),
                kwargs.get("user_id"),
            )

            if tool.category == Category.FileLoader and file_paths is not None:
                for file_path in file_paths:
                    retrievers.append(tool.implementation(file_path, **tool.kwargs))
            elif tool.category != Category.FileLoader:
                retrievers.append(tool.implementation(**tool.kwargs))
        print(f"Invoking tools: {tools}")
        stream = deployment_model.invoke_tools(
            message, tools, chat_history=chat_history
        )

        return retrievers
        # invoke_tools can return a direct answer or a stream of events with the tool calls
        # If the second event is a tool call, the tools are invoked and the results are returned
        # Otherwise, the chat response is returned as a direct answer
        stream_start_event = next(stream)

    def get_tool_results(
        self, message: str, tools: list[Tool], model: BaseDeployment
    ) -> list[dict[str, Any]]:
        tool_results = []
        tools_to_use = model.invoke_tools(message, tools)
        if stream_start_event is None:
            yield [], False

        tool_calls = tools_to_use.tool_calls if tools_to_use.tool_calls else []
        for tool_call in tool_calls:
            tool = AVAILABLE_TOOLS.get(tool_call.name)
            if not tool:
                logging.warning(f"Couldn't find tool {tool_call.name}")
                continue
        second_event = next(stream)

            outputs = tool.implementation().call(
                parameters=tool_call.parameters,
            )
        if second_event["event_type"] == StreamEvent.TOOL_CALLS_GENERATION:
            tool_calls = second_event["tool_calls"]

            # If the tool returns a list of outputs, append each output to the tool_results list
            # Otherwise, append the single output to the tool_results list
            outputs = outputs if isinstance(outputs, list) else [outputs]
            for output in outputs:
                tool_results.append({"call": tool_call, "outputs": [output]})
            print(f"Using tools: {tool_calls}")

        return tool_results
            # TODO: parallelize tool calls
            for tool_call in tool_calls:
                tool = AVAILABLE_TOOLS.get(tool_call.name)
                if not tool:
                    logging.warning(f"Couldn't find tool {tool_call.name}")
                    continue

                outputs = tool.implementation().call(
                    parameters=tool_call.parameters,
                    session=kwargs.get("session"),
                    model_deployment=deployment_model,
                    user_id=kwargs.get("user_id"),
                )

                # If the tool returns a list of outputs, append each output to the tool_results list
                # Otherwise, append the single output to the tool_results list
                outputs = outputs if isinstance(outputs, list) else [outputs]
                for output in outputs:
                    tool_results.append({"call": tool_call, "outputs": [output]})

            print(f"Tool results: {tool_results}")
            yield tool_results, False

        else:
            yield stream_start_event, True
            yield second_event, True
            for event in stream:
                yield event, True
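
For reference, the tool_results payload yielded on the tool-call path is shaped as below; the values are illustrative, and the real "call" entry holds the model-generated tool call object rather than a plain dict:

# Illustrative shape of the yielded tool_results list (values hypothetical).
tool_results = [
    {
        "call": {"name": "search_file", "parameters": {"search_query": "revenue"}},
        "outputs": [{"text": "Revenue grew 12% quarter over quarter..."}],
    },
]
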

    def add_files_to_chat_history(
        self,
        chat_history: List[Dict[str, str]],
        conversation_id: str,
        session: Any,
        user_id: str,
    ) -> List[Dict[str, str]]:
        if session is None or conversation_id is None or len(conversation_id) == 0:
            return chat_history

        available_files = get_files_by_conversation_id(
            session, conversation_id, user_id
        )
        files_message = "The user uploaded the following attachments:\n"

        for file in available_files:
            word_count = len(file.file_content.split())

            # Use the first 25 words as the document preview in the preamble
            num_words = min(25, word_count)
            preview = " ".join(file.file_content.split()[:num_words])

            files_message += f"Filename: {file.file_name}\nWord Count: {word_count} Preview: {preview}\n\n"

        chat_history.append(ChatMessage(message=files_message, role="SYSTEM"))
        return chat_history
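
A standalone sketch of the 25-word preview logic used above, runnable outside the backend:

# Mirrors the preview construction in add_files_to_chat_history.
def build_preview(file_content: str, max_words: int = 25) -> str:
    words = file_content.split()
    return " ".join(words[: min(max_words, len(words))])

print(build_preview("word " * 40))  # prints the first 25 words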