Skip to content

Commit 55df373

Browse files
Fixing test cases for pooling
Signed-off-by: Mohan Lakshmaiah <mohalaks@in.ibm.com>
1 parent ae2383b commit 55df373

File tree

5 files changed

+53
-31
lines changed

5 files changed

+53
-31
lines changed

mcpgateway/cache/session_registry.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1235,7 +1235,12 @@ def _db_cleanup() -> int:
12351235
transport = session_data['transport']
12361236
try:
12371237
if not await transport.is_connected():
1238-
await self.remove_session(session_id)
1238+
if pooled:
1239+
# For pooled sessions, remove from registry but don't disconnect
1240+
await self.remove_session_from_registry_only(session_id)
1241+
else:
1242+
# For non-pooled sessions, full removal with disconnect
1243+
await self.remove_session(session_id)
12391244
continue
12401245

12411246
# Refresh session in database

mcpgateway/main.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3967,13 +3967,13 @@ async def websocket_endpoint(websocket: WebSocket):
39673967
Args:
39683968
websocket: The WebSocket connection instance.
39693969
"""
3970+
transport = None
3971+
proxy_user = None
3972+
token = None
39703973
try:
39713974
# Authenticate WebSocket connection
39723975
if settings.mcp_client_auth_enabled or settings.auth_required:
39733976
# Extract auth from query params or headers
3974-
token = None
3975-
proxy_user = None
3976-
39773977
# Try to get token from query parameter
39783978
if "token" in websocket.query_params:
39793979
token = websocket.query_params["token"]
@@ -4001,9 +4001,6 @@ async def websocket_endpoint(websocket: WebSocket):
40014001
await websocket.close(code=1008, reason="Invalid authentication")
40024002
return
40034003

4004-
await websocket.accept()
4005-
logger.info("WebSocket connection accepted")
4006-
40074004
# Identify user and server for pooling key
40084005
user_id = proxy_user or "anonymous"
40094006
server_id = websocket.query_params.get("server_id", "default-server")
@@ -4013,21 +4010,16 @@ async def websocket_endpoint(websocket: WebSocket):
40134010
transport = None
40144011
if await should_use_session_pooling(server_id):
40154012
# Use existing or create pooled session
4016-
# Note: WebSocket transport needs the actual WebSocket object, so pooling works differently
40174013
transport = WebSocketTransport(websocket, pooled=True, pool_key=f"{user_id}:{server_id}")
40184014
await transport.connect()
40194015
await session_registry.add_session(transport.session_id, transport, pooled=True)
4020-
logger.info(
4021-
f"Created pooled WebSocket session for user={user_id}, server={server_id}, session={transport.session_id}"
4022-
)
4016+
logger.info(f"Created pooled WebSocket session for user={user_id}, server={server_id}, session={transport.session_id}")
40234017
else:
40244018
# Fallback: create new transport
40254019
transport = WebSocketTransport(websocket)
40264020
await transport.connect()
40274021
await session_registry.add_session(transport.session_id, transport)
4028-
logger.info(
4029-
f"Created new WebSocket session for user={user_id}, server={server_id}, session={transport.session_id}"
4030-
)
4022+
logger.info(f"Created new WebSocket session for user={user_id}, server={server_id}, session={transport.session_id}")
40314023

40324024
while True:
40334025
try:
@@ -4058,10 +4050,8 @@ async def websocket_endpoint(websocket: WebSocket):
40584050
break
40594051
except WebSocketDisconnect:
40604052
logger.info("WebSocket disconnected")
4061-
# Cleanup pooled session if needed
40624053
if transport and hasattr(transport, '_pooled') and transport._pooled:
4063-
# For pooled sessions, we don't immediately remove from registry
4064-
# They get cleaned up by the pool's background task
4054+
# For pooled sessions, we don't immediately remove from registry. They get cleaned up by the pool's background task
40654055
pass
40664056
else:
40674057
# For non-pooled sessions, remove from registry

mcpgateway/services/server_service.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import asyncio
1616
from datetime import datetime, timezone
1717
from typing import Any, AsyncGenerator, Dict, List, Optional
18+
import builtins
1819

1920
# Third-Party
2021
import httpx
@@ -51,7 +52,7 @@ class ServerNotFoundError(ServerError):
5152
"""Raised when a requested server is not found."""
5253

5354

54-
class PermissionError(ServerError):
55+
class PermissionError(builtins.PermissionError, ServerError):
5556
"""Raised when a user does not have permission to perform an action on a server."""
5657

5758

@@ -647,7 +648,7 @@ async def get_server(self, db: Session, server_id: str) -> ServerRead:
647648
raise ServerNotFoundError(f"Server not found: {server_id}")
648649

649650
try:
650-
effective_strategy = await self.get_session_strategy(db, server_id)
651+
effective_strategy = await self.get_session_strategy(db, server_id, server=server)
651652
logger.debug(f"Server {server_id} effective session strategy: {effective_strategy}")
652653
except Exception as e:
653654
logger.warning(f"Could not determine session strategy for server {server_id}: {e}")
@@ -998,13 +999,14 @@ async def delete_server(self, db: Session, server_id: str, user_email: Optional[
998999
if not server:
9991000
raise ServerNotFoundError(f"Server not found: {server_id}")
10001001

1001-
# Check ownership if user_email provided
1002+
# Always perform ownership check if user_email is provided
10021003
if user_email:
10031004
# First-Party
10041005
from mcpgateway.services.permission_service import PermissionService # pylint: disable=import-outside-toplevel
10051006

10061007
permission_service = PermissionService(db)
1007-
if not await permission_service.check_resource_ownership(user_email, server):
1008+
can_delete = await permission_service.check_resource_ownership(user_email, server)
1009+
if not can_delete:
10081010
raise PermissionError("Only the owner can delete this server")
10091011

10101012
server_info = {"id": server.id, "name": server.name}
@@ -1147,7 +1149,7 @@ async def _notify_server_deleted(self, server_info: Dict[str, Any]) -> None:
11471149
}
11481150
await self._publish_event(event)
11491151

1150-
async def get_session_strategy(self, db: Session, server_id: str) -> str:
1152+
async def get_session_strategy(self, db: Session, server_id: str, server: Optional[DbServer] = None) -> str:
11511153
"""Determine effective session strategy for server.
11521154
11531155
This method resolves the session strategy for a specific server, taking into account:
@@ -1180,7 +1182,12 @@ async def get_session_strategy(self, db: Session, server_id: str) -> str:
11801182
>>> result == settings.session_pool_strategy
11811183
True
11821184
"""
1183-
server = db.get(DbServer, server_id)
1185+
# server = db.get(DbServer, server_id)
1186+
# if not server:
1187+
# raise ServerNotFoundError(f"Server not found: {server_id}")
1188+
# Allow callers to pass an already-loaded server object to avoid repeated DB lookups.
1189+
if server is None:
1190+
server = db.get(DbServer, server_id)
11841191
if not server:
11851192
raise ServerNotFoundError(f"Server not found: {server_id}")
11861193

mcpgateway/transports/websocket_transport.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,9 @@ async def connect(self) -> None:
128128
>>> mock_ws.accept.called
129129
True
130130
"""
131-
await self._websocket.accept()
132-
self._connected = True
131+
if not self._connected:
132+
await self._websocket.accept()
133+
self._connected = True
133134

134135
# Start ping task
135136
if settings.websocket_ping_interval > 0:

tests/unit/mcpgateway/services/test_resource_ownership.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -202,18 +202,37 @@ async def test_delete_server_non_owner_denied(self, server_service, mock_db_sess
202202
mock_server = MagicMock(spec=Server)
203203
mock_server.id = "server-1"
204204
mock_server.owner_email = "owner@example.com"
205-
205+
mock_server.team_id = None
206+
mock_server.visibility = "private"
207+
mock_server.name = "Test Server"
208+
mock_server.session_pooling_strategy = "inherit"
209+
206210
mock_db_session.get.return_value = mock_server
211+
mock_db_session.rollback = MagicMock()
212+
mock_db_session.commit = MagicMock()
213+
mock_db_session.delete = MagicMock()
207214

208215
with patch("mcpgateway.services.permission_service.PermissionService") as mock_perm_service_class:
209216
mock_perm_service = mock_perm_service_class.return_value
210217
mock_perm_service.check_resource_ownership = AsyncMock(return_value=False)
211-
212-
with pytest.raises(PermissionError, match="Only the owner can delete this server"):
213-
await server_service.delete_server(mock_db_session, "server-1", user_email="other@example.com")
214-
218+
server_service._notify_server_deleted = AsyncMock()
219+
220+
try:
221+
with pytest.raises(PermissionError, match="Only the owner can delete this server"):
222+
await server_service.delete_server(mock_db_session, "server-1", user_email="other@example.com")
223+
except AssertionError as e:
224+
# This will help us understand if we're getting a different error message
225+
print(f"Test failed because: {e}")
226+
raise
227+
except Exception as e:
228+
print(f"Unexpected error: {e}")
229+
raise
230+
231+
# Verify the expectations
215232
mock_db_session.delete.assert_not_called()
216-
233+
mock_db_session.rollback.assert_called_once()
234+
mock_perm_service.check_resource_ownership.assert_called_once_with("other@example.com", mock_server)
235+
mock_db_session.commit.assert_not_called()
217236

218237
class TestToolServiceOwnership:
219238
"""Test ownership checks in ToolService delete/update methods."""

0 commit comments

Comments
 (0)