Skip to content

Commit b90da29

Browse files
authored
Merge pull request #1459 from MohanLaksh/test_cases
Test cases
2 parents f96b701 + 4f0640a commit b90da29

File tree

8 files changed

+1594
-0
lines changed

8 files changed

+1594
-0
lines changed
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
# -*- coding: utf-8 -*-
2+
"""Location: ./tests/unit/mcpgateway/instrumentation/test_sqlalchemy.py
3+
Copyright 2025
4+
SPDX-License-Identifier: Apache-2.0
5+
Authors: Mihai Criveti
6+
7+
Unit tests for sqlalchemy instrumentation.
8+
"""
9+
import pytest
10+
import threading
11+
import queue
12+
import time
13+
from unittest.mock import MagicMock, patch
14+
15+
import mcpgateway.instrumentation.sqlalchemy as sa
16+
17+
18+
@pytest.fixture(autouse=True)
19+
def reset_globals():
20+
sa._query_tracking.clear()
21+
sa._instrumentation_context = threading.local()
22+
sa._span_queue = queue.Queue(maxsize=2)
23+
sa._shutdown_event.clear()
24+
sa._span_writer_thread = None
25+
yield
26+
sa._query_tracking.clear()
27+
28+
29+
def test_before_cursor_execute_stores_tracking():
30+
conn = MagicMock()
31+
sa._before_cursor_execute(conn, None, "SELECT * FROM test", {"id": 1}, None, False)
32+
conn_id = id(conn)
33+
assert conn_id in sa._query_tracking
34+
tracking = sa._query_tracking[conn_id]
35+
assert tracking["statement"] == "SELECT * FROM test"
36+
assert "start_time" in tracking
37+
38+
39+
def test_after_cursor_execute_no_tracking():
40+
conn = MagicMock()
41+
sa._after_cursor_execute(conn, None, "SELECT * FROM test", None, None, False)
42+
# Should not raise or enqueue anything
43+
assert sa._span_queue.empty()
44+
45+
46+
def test_after_cursor_execute_inside_span_creation_skips():
47+
conn = MagicMock()
48+
conn_id = id(conn)
49+
sa._query_tracking[conn_id] = {"start_time": time.time(), "statement": "SELECT 1", "parameters": None, "executemany": False}
50+
sa._instrumentation_context.inside_span_creation = True
51+
sa._after_cursor_execute(conn, MagicMock(), "SELECT 1", None, None, False)
52+
assert sa._span_queue.empty()
53+
54+
55+
def test_after_cursor_execute_observability_table_skips(caplog):
56+
import logging
57+
caplog.set_level(logging.DEBUG)
58+
sa.logger.setLevel(logging.DEBUG)
59+
sa.logger.propagate = True
60+
conn = MagicMock()
61+
conn_id = id(conn)
62+
sa._query_tracking[conn_id] = {"start_time": time.time(), "statement": "SELECT * FROM observability_spans", "parameters": None, "executemany": False}
63+
sa._after_cursor_execute(conn, MagicMock(), "SELECT * FROM observability_spans", None, None, False)
64+
assert "Skipping instrumentation" in caplog.text
65+
66+
67+
def test_after_cursor_execute_with_trace_id_calls_create_query_span():
68+
conn = MagicMock()
69+
conn.info = {"trace_id": "abc123"}
70+
conn_id = id(conn)
71+
sa._query_tracking[conn_id] = {"start_time": time.time(), "statement": "SELECT * FROM users", "parameters": None, "executemany": False}
72+
with patch.object(sa, "_create_query_span") as mock_create:
73+
sa._after_cursor_execute(conn, MagicMock(rowcount=5), "SELECT * FROM users", None, None, False)
74+
mock_create.assert_called_once()
75+
args, kwargs = mock_create.call_args
76+
assert kwargs["trace_id"] == "abc123"
77+
78+
79+
def test_after_cursor_execute_without_trace_id_logs_debug(caplog):
80+
import logging
81+
caplog.set_level(logging.DEBUG)
82+
sa.logger.setLevel(logging.DEBUG)
83+
sa.logger.propagate = True
84+
conn = MagicMock()
85+
conn.info = {}
86+
conn_id = id(conn)
87+
sa._query_tracking[conn_id] = {"start_time": time.time(), "statement": "SELECT * FROM users", "parameters": None, "executemany": False}
88+
sa._after_cursor_execute(conn, MagicMock(rowcount=5), "SELECT * FROM users", None, None, False)
89+
assert "Query executed without trace context" in caplog.text
90+
91+
92+
def test_create_query_span_enqueues_successfully(caplog):
93+
import logging
94+
caplog.set_level(logging.DEBUG)
95+
sa.logger.setLevel(logging.DEBUG)
96+
sa.logger.propagate = True
97+
sa._create_query_span("trace123", "SELECT * FROM test", 10.0, 5, False)
98+
assert not sa._span_queue.empty()
99+
assert "Enqueued span" in caplog.text
100+
101+
102+
def test_create_query_span_queue_full_warns(caplog):
103+
sa._span_queue = queue.Queue(maxsize=1)
104+
sa._span_queue.put({"dummy": "data"})
105+
sa._create_query_span("trace123", "SELECT * FROM test", 10.0, 5, False)
106+
assert "Span queue is full" in caplog.text
107+
108+
109+
def test_create_query_span_exception_does_not_raise(caplog):
110+
import logging
111+
caplog.set_level(logging.DEBUG)
112+
sa.logger.setLevel(logging.DEBUG)
113+
sa.logger.propagate = True
114+
with patch("mcpgateway.instrumentation.sqlalchemy._span_queue.put_nowait", side_effect=Exception("fail")):
115+
sa._create_query_span("trace123", "SELECT * FROM test", 10.0, 5, False)
116+
assert "Failed to enqueue query span" in caplog.text
117+
118+
119+
def test_write_span_to_db_success():
120+
span_data = {
121+
"trace_id": "t1",
122+
"name": "db.query.select",
123+
"kind": "client",
124+
"resource_type": "database",
125+
"resource_name": "SELECT",
126+
"start_attributes": {},
127+
"end_attributes": {},
128+
"status": "ok",
129+
"duration_ms": 10.0,
130+
"row_count": 1,
131+
}
132+
mock_service = MagicMock()
133+
mock_db = MagicMock()
134+
mock_span = MagicMock()
135+
mock_db.query().filter_by().first.return_value = mock_span
136+
with patch("mcpgateway.services.observability_service.ObservabilityService", return_value=mock_service), \
137+
patch("mcpgateway.db.SessionLocal", return_value=mock_db), \
138+
patch("mcpgateway.db.ObservabilitySpan", MagicMock()):
139+
sa._write_span_to_db(span_data)
140+
mock_service.start_span.assert_called_once()
141+
mock_service.end_span.assert_called_once()
142+
mock_db.commit.assert_called_once()
143+
144+
145+
def test_write_span_to_db_exception_logs_warning(caplog):
146+
with patch("mcpgateway.services.observability_service.ObservabilityService", side_effect=Exception("fail")):
147+
sa._write_span_to_db({})
148+
assert "Failed to write query span" in caplog.text
149+
150+
151+
def test_span_writer_worker_processes_queue(monkeypatch):
152+
span_data = {"trace_id": "t1", "name": "db.query.select", "kind": "client", "resource_type": "database", "resource_name": "SELECT", "start_attributes": {}, "end_attributes": {}, "status": "ok", "duration_ms": 10.0}
153+
sa._span_queue.put(span_data)
154+
mock_write = MagicMock()
155+
monkeypatch.setattr(sa, "_write_span_to_db", mock_write)
156+
thread = threading.Thread(target=lambda: (time.sleep(0.1), sa._shutdown_event.set()))
157+
thread.start()
158+
sa._span_writer_worker()
159+
mock_write.assert_called_once()
160+
161+
162+
def test_instrument_sqlalchemy_starts_thread_and_registers_events():
163+
engine = MagicMock()
164+
with patch("mcpgateway.instrumentation.sqlalchemy.event.listen") as mock_listen, \
165+
patch("mcpgateway.instrumentation.sqlalchemy.threading.Thread") as mock_thread:
166+
mock_thread.return_value = MagicMock(is_alive=lambda: False)
167+
sa.instrument_sqlalchemy(engine)
168+
assert mock_listen.call_count == 2
169+
mock_thread.assert_called_once()
170+
171+
172+
def test_attach_trace_to_session_sets_trace_id():
173+
session = MagicMock()
174+
connection = MagicMock()
175+
connection.info = {}
176+
session.bind = True
177+
session.connection.return_value = connection
178+
sa.attach_trace_to_session(session, "trace123")
179+
assert connection.info["trace_id"] == "trace123"
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# -*- coding: utf-8 -*-
2+
"""Location: ./tests/unit/mcpgateway/middleware/test_auth_middleware.py
3+
Copyright 2025
4+
SPDX-License-Identifier: Apache-2.0
5+
Authors: Mihai Criveti
6+
7+
Unit tests for auth middleware.
8+
"""
9+
10+
import pytest
11+
from unittest.mock import AsyncMock, MagicMock, patch
12+
from starlette.requests import Request
13+
from starlette.responses import Response
14+
from mcpgateway.middleware.auth_middleware import AuthContextMiddleware
15+
16+
17+
@pytest.mark.asyncio
18+
async def test_health_and_static_paths_skipped(monkeypatch):
19+
"""Ensure middleware skips health and static paths."""
20+
middleware = AuthContextMiddleware(app=AsyncMock())
21+
call_next = AsyncMock(return_value=Response("ok"))
22+
23+
for path in ["/health", "/healthz", "/ready", "/metrics", "/static/logo.png"]:
24+
request = MagicMock(spec=Request)
25+
request.url.path = path
26+
response = await middleware.dispatch(request, call_next)
27+
call_next.assert_awaited_once_with(request)
28+
assert response.status_code == 200
29+
call_next.reset_mock()
30+
31+
32+
@pytest.mark.asyncio
33+
async def test_no_token_continues(monkeypatch):
34+
"""If no token found, request continues without user context."""
35+
middleware = AuthContextMiddleware(app=AsyncMock())
36+
call_next = AsyncMock(return_value=Response("ok"))
37+
request = MagicMock(spec=Request)
38+
request.url.path = "/api/data"
39+
request.cookies = {}
40+
request.headers = {}
41+
42+
response = await middleware.dispatch(request, call_next)
43+
call_next.assert_awaited_once_with(request)
44+
assert response.status_code == 200
45+
# request.state is a MagicMock, so user may exist as mock attribute
46+
# Instead, ensure user was never set explicitly
47+
# Ensure user attribute was not explicitly set (MagicMock defaults to having attributes)
48+
assert "user" not in request.state.__dict__
49+
50+
51+
@pytest.mark.asyncio
52+
async def test_token_from_cookie(monkeypatch):
53+
"""Token extracted from cookie triggers authentication."""
54+
middleware = AuthContextMiddleware(app=AsyncMock())
55+
call_next = AsyncMock(return_value=Response("ok"))
56+
request = MagicMock(spec=Request)
57+
request.url.path = "/api/data"
58+
request.cookies = {"jwt_token": "cookie_token"}
59+
request.headers = {}
60+
61+
mock_user = MagicMock()
62+
mock_user.email = "user@example.com"
63+
64+
with patch("mcpgateway.middleware.auth_middleware.SessionLocal", return_value=MagicMock()) as mock_session, \
65+
patch("mcpgateway.middleware.auth_middleware.get_current_user", AsyncMock(return_value=mock_user)):
66+
response = await middleware.dispatch(request, call_next)
67+
68+
call_next.assert_awaited_once_with(request)
69+
assert response.status_code == 200
70+
assert request.state.user.email == "user@example.com"
71+
mock_session.return_value.close.assert_called_once()
72+
73+
74+
@pytest.mark.asyncio
75+
async def test_token_from_header(monkeypatch):
76+
"""Token extracted from Authorization header triggers authentication."""
77+
middleware = AuthContextMiddleware(app=AsyncMock())
78+
call_next = AsyncMock(return_value=Response("ok"))
79+
request = MagicMock(spec=Request)
80+
request.url.path = "/api/data"
81+
request.cookies = {}
82+
request.headers = {"authorization": "Bearer header_token"}
83+
84+
mock_user = MagicMock()
85+
mock_user.email = "header@example.com"
86+
87+
with patch("mcpgateway.middleware.auth_middleware.SessionLocal", return_value=MagicMock()) as mock_session, \
88+
patch("mcpgateway.middleware.auth_middleware.get_current_user", AsyncMock(return_value=mock_user)):
89+
response = await middleware.dispatch(request, call_next)
90+
91+
call_next.assert_awaited_once_with(request)
92+
assert response.status_code == 200
93+
assert request.state.user.email == "header@example.com"
94+
mock_session.return_value.close.assert_called_once()
95+
96+
97+
@pytest.mark.asyncio
98+
async def test_authentication_failure(monkeypatch):
99+
"""Authentication failure should log and continue without user context."""
100+
middleware = AuthContextMiddleware(app=AsyncMock())
101+
call_next = AsyncMock(return_value=Response("ok"))
102+
request = MagicMock(spec=Request)
103+
request.url.path = "/api/data"
104+
request.cookies = {"jwt_token": "bad_token"}
105+
request.headers = {}
106+
107+
with patch("mcpgateway.middleware.auth_middleware.SessionLocal", return_value=MagicMock()) as mock_session, \
108+
patch("mcpgateway.middleware.auth_middleware.get_current_user", AsyncMock(side_effect=Exception("Invalid token"))), \
109+
patch("mcpgateway.middleware.auth_middleware.logger") as mock_logger:
110+
response = await middleware.dispatch(request, call_next)
111+
112+
call_next.assert_awaited_once_with(request)
113+
assert response.status_code == 200
114+
# Ensure user attribute was not explicitly set (MagicMock defaults to having attributes)
115+
assert "user" not in request.state.__dict__
116+
# Verify log message contains failure text
117+
logged_messages = [args[0] for args, _ in mock_logger.info.call_args_list]
118+
assert any("✗ Auth context extraction failed" in msg for msg in logged_messages)
119+
mock_session.return_value.close.assert_called_once()
120+
121+
122+
@pytest.mark.asyncio
123+
async def test_db_close_exception(monkeypatch):
124+
"""Ensure db.close exceptions are logged but do not break flow."""
125+
middleware = AuthContextMiddleware(app=AsyncMock())
126+
call_next = AsyncMock(return_value=Response("ok"))
127+
request = MagicMock(spec=Request)
128+
request.url.path = "/api/data"
129+
request.cookies = {"jwt_token": "token"}
130+
request.headers = {}
131+
132+
mock_user = MagicMock()
133+
mock_user.email = "user@example.com"
134+
mock_db = MagicMock()
135+
mock_db.close.side_effect = Exception("close error")
136+
137+
with patch("mcpgateway.middleware.auth_middleware.SessionLocal", return_value=mock_db), \
138+
patch("mcpgateway.middleware.auth_middleware.get_current_user", AsyncMock(return_value=mock_user)), \
139+
patch("mcpgateway.middleware.auth_middleware.logger") as mock_logger:
140+
response = await middleware.dispatch(request, call_next)
141+
142+
call_next.assert_awaited_once_with(request)
143+
assert response.status_code == 200
144+
# Verify log message contains close error text
145+
logged_debugs = [args[0] for args, _ in mock_logger.debug.call_args_list]
146+
assert any("Failed to close database session" in msg for msg in logged_debugs)

0 commit comments

Comments
 (0)