Skip to content

Commit

Permalink
[backend] agent chat fixes (#263)
Browse files Browse the repository at this point in the history
fixes
  • Loading branch information
scott-cohere authored Jun 21, 2024
1 parent 9184edf commit eb34bf8
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/backend/config/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ class ToolName(StrEnum):
Wiki_Retriever_LangChain = "wikipedia"
Search_File = "search_file"
Read_File = "read_document"
Python_Interpreter = "python_interpreter"
Python_Interpreter = "toolkit_python_interpreter"
Calculator = "calculator"
Tavily_Internet_Search = "internet_search"
Tavily_Internet_Search = "web_search"


ALL_TOOLS = {
Expand Down
9 changes: 9 additions & 0 deletions src/backend/services/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,15 @@ def process_chat(
status_code=404, detail=f"Agent with ID {agent_id} not found."
)

tool_names = [tool.name for tool in chat_request.tools]
if chat_request.tools:
for tool in chat_request.tools:
if tool.name not in agent.tools:
raise HTTPException(
status_code=400,
detail=f"Tool {tool.name} not found in agent {agent.id}",
)

# Set the agent settings in the chat request
chat_request.preamble = agent.preamble
chat_request.tools = [Tool(name=tool) for tool in agent.tools]
Expand Down
21 changes: 21 additions & 0 deletions src/backend/tests/routers/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,27 @@ def test_streaming_chat_with_existing_conversation_from_other_agent(
}


@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set")
def test_streaming_chat_with_tools_not_in_agent_tools(
session_client_chat: TestClient, session_chat: Session, user: User
):
agent = get_factory("Agent", session_chat).create(user_id=user.id, tools=[])
response = session_client_chat.post(
"/v1/chat-stream",
headers={
"User-Id": user.id,
"Deployment-Name": ModelDeploymentName.CoherePlatform,
},
params={"agent_id": agent.id},
json={"message": "Hello", "max_tokens": 10, "tools": [{"name": "web_search"}]},
)

assert response.status_code == 400
assert response.json() == {
"detail": f"Tool web_search not found in agent {agent.id}"
}


@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set")
def test_streaming_existing_chat(
session_client_chat: TestClient, session_chat: Session, user: User
Expand Down

0 comments on commit eb34bf8

Please sign in to comment.