Skip to content

Commit

Permalink
[backend] enforce agent update with user-id (#246)
Browse files Browse the repository at this point in the history
* updates

* remove client changes

* remove logs

* use better header user id check

* fix validators

* typo
  • Loading branch information
scott-cohere authored Jun 20, 2024
1 parent c22c659 commit c460c2a
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 16 deletions.
11 changes: 5 additions & 6 deletions src/backend/routers/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ async def get_agent_by_id(

if not agent:
raise HTTPException(
status_code=404,
status_code=400,
detail=f"Agent with ID: {agent_id} not found.",
)

Expand Down Expand Up @@ -127,11 +127,10 @@ async def update_agent(
HTTPException: If the agent with the given ID is not found.
"""
agent = agent_crud.get_agent_by_id(session, agent_id)

if not agent:
raise HTTPException(
status_code=404,
detail=f"Agent with ID: {agent_id} not found.",
status_code=400,
detail=f"Agent with ID {agent_id} not found.",
)

try:
Expand Down Expand Up @@ -164,8 +163,8 @@ async def delete_agent(

if not agent:
raise HTTPException(
status_code=404,
detail=f"Agent with ID: {agent_id} not found.",
status_code=400,
detail=f"Agent with ID {agent_id} not found.",
)

try:
Expand Down
19 changes: 17 additions & 2 deletions src/backend/services/request_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from backend.crud import agent as agent_crud
from backend.crud import conversation as conversation_crud
from backend.database_models.database import DBSessionDep
from backend.services.auth.utils import get_header_user_id


def validate_user_header(request: Request):
Expand Down Expand Up @@ -188,7 +189,7 @@ async def validate_create_agent_request(session: DBSessionDep, request: Request)
)


async def validate_update_agent_request(request: Request):
async def validate_update_agent_request(session: DBSessionDep, request: Request):
"""
Validate that the update agent request has valid tools, deployments, and compatible models.
Expand All @@ -198,8 +199,22 @@ async def validate_update_agent_request(request: Request):
Raises:
HTTPException: If the request does not have the appropriate values in the body
"""
body = await request.json()
agent_id = request.path_params.get("agent_id")
if not agent_id:
raise HTTPException(status_code=400, detail="Agent ID is required.")

agent = agent_crud.get_agent_by_id(session, agent_id)
if not agent:
raise HTTPException(
status_code=400, detail=f"Agent with ID {agent_id} not found."
)

if agent.user_id != get_header_user_id(request):
raise HTTPException(
status_code=401, detail=f"Agent with ID {agent_id} does not belong to user."
)

body = await request.json()
# Validate tools
tools = body.get("tools")
if tools:
Expand Down
42 changes: 34 additions & 8 deletions src/backend/tests/routers/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def test_get_agent(session_client: TestClient, session: Session) -> None:

def test_get_nonexistent_agent(session_client: TestClient, session: Session) -> None:
response = session_client.get("/v1/agents/456", headers={"User-Id": "123"})
assert response.status_code == 404
assert response.status_code == 400
assert response.json() == {"detail": "Agent with ID: 456 not found."}


Expand All @@ -255,6 +255,7 @@ def test_update_agent(session_client: TestClient, session: Session) -> None:
temperature=0.5,
model="command-r-plus",
deployment=ModelDeploymentName.CoherePlatform,
user_id="123",
)

request_json = {
Expand All @@ -268,7 +269,9 @@ def test_update_agent(session_client: TestClient, session: Session) -> None:
}

response = session_client.put(
f"/v1/agents/{agent.id}", json=request_json, headers={"User-Id": "123"}
f"/v1/agents/{agent.id}",
json=request_json,
headers={"User-Id": "123"},
)

assert response.status_code == 200
Expand All @@ -292,6 +295,7 @@ def test_partial_update_agent(session_client: TestClient, session: Session) -> N
model="command-r-plus",
deployment=ModelDeploymentName.CoherePlatform,
tools=[ToolName.Calculator],
user_id="123",
)

request_json = {
Expand All @@ -300,7 +304,9 @@ def test_partial_update_agent(session_client: TestClient, session: Session) -> N
}

response = session_client.put(
f"/v1/agents/{agent.id}", json=request_json, headers={"User-Id": "123"}
f"/v1/agents/{agent.id}",
json=request_json,
headers={"User-Id": "123"},
)
assert response.status_code == 200
updated_agent = response.json()
Expand All @@ -321,8 +327,23 @@ def test_update_nonexistent_agent(session_client: TestClient, session: Session)
response = session_client.put(
"/v1/agents/456", json=request_json, headers={"User-Id": "123"}
)
assert response.status_code == 404
assert response.json() == {"detail": "Agent with ID: 456 not found."}
assert response.status_code == 400
assert response.json() == {"detail": "Agent with ID 456 not found."}


def test_update_agent_wrong_user(session_client: TestClient, session: Session) -> None:
agent = get_factory("Agent", session).create(user_id="123")
request_json = {
"name": "updated name",
}

response = session_client.put(
f"/v1/agents/{agent.id}", json=request_json, headers={"User-Id": "456"}
)
assert response.status_code == 401
assert response.json() == {
"detail": f"Agent with ID {agent.id} does not belong to user."
}


def test_update_agent_invalid_model(
Expand All @@ -336,6 +357,7 @@ def test_update_agent_invalid_model(
temperature=0.5,
model="command-r-plus",
deployment=ModelDeploymentName.CoherePlatform,
user_id="123",
)

request_json = {
Expand Down Expand Up @@ -363,6 +385,7 @@ def test_update_agent_invalid_deployment(
temperature=0.5,
model="command-r-plus",
deployment=ModelDeploymentName.CoherePlatform,
user_id="123",
)

request_json = {
Expand Down Expand Up @@ -390,6 +413,7 @@ def test_update_agent_model_without_deployment(
temperature=0.5,
model="command-r-plus",
deployment=ModelDeploymentName.CoherePlatform,
user_id="123",
)

request_json = {
Expand All @@ -416,6 +440,7 @@ def test_update_agent_deployment_without_model(
temperature=0.5,
model="command-r-plus",
deployment=ModelDeploymentName.CoherePlatform,
user_id="123",
)

request_json = {
Expand All @@ -442,6 +467,7 @@ def test_update_agent_invalid_tool(
temperature=0.5,
model="command-r-plus",
deployment=ModelDeploymentName.CoherePlatform,
user_id="123",
)

request_json = {
Expand All @@ -458,7 +484,7 @@ def test_update_agent_invalid_tool(


def test_delete_agent(session_client: TestClient, session: Session) -> None:
agent = get_factory("Agent", session).create()
agent = get_factory("Agent", session).create(user_id="123")
response = session_client.delete(
f"/v1/agents/{agent.id}", headers={"User-Id": "123"}
)
Expand All @@ -473,5 +499,5 @@ def test_fail_delete_nonexistent_agent(
session_client: TestClient, session: Session
) -> None:
response = session_client.delete("/v1/agents/456", headers={"User-Id": "123"})
assert response.status_code == 404
assert response.json() == {"detail": "Agent with ID: 456 not found."}
assert response.status_code == 400
assert response.json() == {"detail": "Agent with ID 456 not found."}

0 comments on commit c460c2a

Please sign in to comment.