From eb34bf85fba766563a420fbd45c9d7c0f26dd88f Mon Sep 17 00:00:00 2001 From: Scott <146760070+scott-cohere@users.noreply.github.com> Date: Fri, 21 Jun 2024 11:52:33 -0400 Subject: [PATCH] [backend] agent chat fixes (#263) fixes --- src/backend/config/tools.py | 4 ++-- src/backend/services/chat.py | 9 +++++++++ src/backend/tests/routers/test_chat.py | 21 +++++++++++++++++++++ 3 files changed, 32 insertions(+), 2 deletions(-) diff --git a/src/backend/config/tools.py b/src/backend/config/tools.py index 5fde08cac2..97c4e5c761 100644 --- a/src/backend/config/tools.py +++ b/src/backend/config/tools.py @@ -29,9 +29,9 @@ class ToolName(StrEnum): Wiki_Retriever_LangChain = "wikipedia" Search_File = "search_file" Read_File = "read_document" - Python_Interpreter = "python_interpreter" + Python_Interpreter = "toolkit_python_interpreter" Calculator = "calculator" - Tavily_Internet_Search = "internet_search" + Tavily_Internet_Search = "web_search" ALL_TOOLS = { diff --git a/src/backend/services/chat.py b/src/backend/services/chat.py index 3b0306a54c..7956562acf 100644 --- a/src/backend/services/chat.py +++ b/src/backend/services/chat.py @@ -85,6 +85,15 @@ def process_chat( status_code=404, detail=f"Agent with ID {agent_id} not found." ) + tool_names = [tool.name for tool in chat_request.tools] + if chat_request.tools: + for tool in chat_request.tools: + if tool.name not in agent.tools: + raise HTTPException( + status_code=400, + detail=f"Tool {tool.name} not found in agent {agent.id}", + ) + # Set the agent settings in the chat request chat_request.preamble = agent.preamble chat_request.tools = [Tool(name=tool) for tool in agent.tools] diff --git a/src/backend/tests/routers/test_chat.py b/src/backend/tests/routers/test_chat.py index fc675f7c4e..b5c09658b6 100644 --- a/src/backend/tests/routers/test_chat.py +++ b/src/backend/tests/routers/test_chat.py @@ -166,6 +166,27 @@ def test_streaming_chat_with_existing_conversation_from_other_agent( } +@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") +def test_streaming_chat_with_tools_not_in_agent_tools( + session_client_chat: TestClient, session_chat: Session, user: User +): + agent = get_factory("Agent", session_chat).create(user_id=user.id, tools=[]) + response = session_client_chat.post( + "/v1/chat-stream", + headers={ + "User-Id": user.id, + "Deployment-Name": ModelDeploymentName.CoherePlatform, + }, + params={"agent_id": agent.id}, + json={"message": "Hello", "max_tokens": 10, "tools": [{"name": "web_search"}]}, + ) + + assert response.status_code == 400 + assert response.json() == { + "detail": f"Tool web_search not found in agent {agent.id}" + } + + @pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_streaming_existing_chat( session_client_chat: TestClient, session_chat: Session, user: User