Skip to content

Commit 165eb0f

Browse files
authored
Store passthrough headers config in database (IBM#736)
* Store passthrough config in db Signed-off-by: Madhav Kandukuri <madhav165@gmail.com> * Fix failing test Signed-off-by: Madhav Kandukuri <madhav165@gmail.com> * Add tests and fix doctest Signed-off-by: Madhav Kandukuri <madhav165@gmail.com> * Fix test Signed-off-by: Madhav Kandukuri <madhav165@gmail.com> * Fix doctest Signed-off-by: Madhav Kandukuri <madhav165@gmail.com>
1 parent 65e7ee4 commit 165eb0f

File tree

4 files changed

+205
-21
lines changed

4 files changed

+205
-21
lines changed

mcpgateway/admin.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
from mcpgateway.services.tool_service import ToolError, ToolNotFoundError, ToolService
7272
from mcpgateway.utils.create_jwt_token import get_jwt_token
7373
from mcpgateway.utils.error_formatter import ErrorFormatter
74+
from mcpgateway.utils.passthrough_headers import PassthroughHeadersError
7475
from mcpgateway.utils.retry_manager import ResilientHttpClient
7576
from mcpgateway.utils.verify_credentials import require_auth, require_basic_auth
7677

@@ -176,14 +177,12 @@ async def wrapper(*args, request: Request = None, **kwargs):
176177
@admin_router.get("/config/passthrough-headers", response_model=GlobalConfigRead)
177178
@rate_limit(requests_per_minute=30) # Lower limit for config endpoints
178179
async def get_global_passthrough_headers(
179-
request: Request, # pylint: disable=unused-argument
180180
db: Session = Depends(get_db),
181181
_user: str = Depends(require_auth),
182182
) -> GlobalConfigRead:
183183
"""Get the global passthrough headers configuration.
184184
185185
Args:
186-
request: HTTP request object
187186
db: Database session
188187
_user: Authenticated user
189188
@@ -201,9 +200,11 @@ async def get_global_passthrough_headers(
201200
True
202201
"""
203202
config = db.query(GlobalConfig).first()
204-
if not config:
205-
config = GlobalConfig()
206-
return GlobalConfigRead(passthrough_headers=config.passthrough_headers)
203+
if config:
204+
passthrough_headers = config.passthrough_headers
205+
else:
206+
passthrough_headers = []
207+
return GlobalConfigRead(passthrough_headers=passthrough_headers)
207208

208209

209210
@admin_router.put("/config/passthrough-headers", response_model=GlobalConfigRead)
@@ -222,6 +223,9 @@ async def update_global_passthrough_headers(
222223
db: Database session
223224
_user: Authenticated user
224225
226+
Raises:
227+
HTTPException: If there is a conflict or validation error
228+
225229
Returns:
226230
GlobalConfigRead: The updated configuration
227231
@@ -235,14 +239,25 @@ async def update_global_passthrough_headers(
235239
>>> inspect.iscoroutinefunction(update_global_passthrough_headers)
236240
True
237241
"""
238-
config = db.query(GlobalConfig).first()
239-
if not config:
240-
config = GlobalConfig(passthrough_headers=config_update.passthrough_headers)
241-
db.add(config)
242-
else:
243-
config.passthrough_headers = config_update.passthrough_headers
244-
db.commit()
245-
return GlobalConfigRead(passthrough_headers=config.passthrough_headers)
242+
try:
243+
config = db.query(GlobalConfig).first()
244+
if not config:
245+
config = GlobalConfig(passthrough_headers=config_update.passthrough_headers)
246+
db.add(config)
247+
else:
248+
config.passthrough_headers = config_update.passthrough_headers
249+
db.commit()
250+
return GlobalConfigRead(passthrough_headers=config.passthrough_headers)
251+
except Exception as e:
252+
if isinstance(e, IntegrityError):
253+
db.rollback()
254+
raise HTTPException(status_code=409, detail="Passthrough headers conflict")
255+
if isinstance(e, ValidationError):
256+
db.rollback()
257+
raise HTTPException(status_code=422, detail="Invalid passthrough headers format")
258+
if isinstance(e, PassthroughHeadersError):
259+
db.rollback()
260+
raise HTTPException(status_code=500, detail=str(e))
246261

247262

248263
@admin_router.get("/servers", response_model=List[ServerRead])

mcpgateway/main.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@
9797
from mcpgateway.transports.streamablehttp_transport import SessionManagerWrapper, streamable_http_auth
9898
from mcpgateway.utils.db_isready import wait_for_db_ready
9999
from mcpgateway.utils.error_formatter import ErrorFormatter
100+
from mcpgateway.utils.passthrough_headers import set_global_passthrough_headers
100101
from mcpgateway.utils.redis_isready import wait_for_redis_ready
101102
from mcpgateway.utils.retry_manager import ResilientHttpClient
102103
from mcpgateway.utils.verify_credentials import require_auth, require_auth_override, verify_jwt_token
@@ -191,6 +192,14 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
191192
await plugin_manager.initialize()
192193
logger.info(f"Plugin manager initialized with {plugin_manager.plugin_count} plugins")
193194

195+
if settings.enable_header_passthrough:
196+
db_gen = get_db()
197+
db = next(db_gen) # pylint: disable=stop-iteration-return
198+
try:
199+
await set_global_passthrough_headers(db)
200+
finally:
201+
db.close()
202+
194203
await tool_service.initialize()
195204
await resource_service.initialize()
196205
await prompt_service.initialize()

mcpgateway/utils/passthrough_headers.py

Lines changed: 106 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,18 @@
5050
MAX_HEADER_VALUE_LENGTH = 4096
5151

5252

53+
class PassthroughHeadersError(Exception):
54+
"""Base class for passthrough headers-related errors.
55+
56+
Examples:
57+
>>> error = PassthroughHeadersError("Test error")
58+
>>> str(error)
59+
'Test error'
60+
>>> isinstance(error, Exception)
61+
True
62+
"""
63+
64+
5365
def sanitize_header_value(value: str, max_length: int = MAX_HEADER_VALUE_LENGTH) -> str:
5466
"""Sanitize header value for security.
5567
@@ -131,12 +143,15 @@ def get_passthrough_headers(request_headers: Dict[str, str], base_headers: Dict[
131143
132144
Examples:
133145
Feature disabled by default (secure by default):
134-
>>> from unittest.mock import Mock
135-
>>> mock_db = Mock()
136-
>>> request_headers = {"x-tenant-id": "should-be-ignored"}
137-
>>> base_headers = {"Content-Type": "application/json"}
138-
>>> result = get_passthrough_headers(request_headers, base_headers, mock_db)
139-
>>> result
146+
>>> from unittest.mock import Mock, patch
147+
>>> with patch(__name__ + ".settings") as mock_settings:
148+
... mock_settings.enable_header_passthrough = False
149+
... mock_settings.default_passthrough_headers = ["X-Tenant-Id"]
150+
... mock_db = Mock()
151+
... mock_db.query.return_value.first.return_value = None
152+
... request_headers = {"x-tenant-id": "should-be-ignored"}
153+
... base_headers = {"Content-Type": "application/json"}
154+
... get_passthrough_headers(request_headers, base_headers, mock_db)
140155
{'Content-Type': 'application/json'}
141156
142157
See comprehensive unit tests in tests/unit/mcpgateway/utils/test_passthrough_headers*.py
@@ -213,3 +228,88 @@ def get_passthrough_headers(request_headers: Dict[str, str], base_headers: Dict[
213228

214229
logger.debug(f"Final passthrough headers: {list(passthrough_headers.keys())}")
215230
return passthrough_headers
231+
232+
233+
async def set_global_passthrough_headers(db: Session) -> None:
234+
"""Set global passthrough headers in the database if not already configured.
235+
236+
This function checks if the global passthrough headers are already set in the
237+
GlobalConfig table. If not, it initializes them with the default headers from
238+
settings.default_passthrough_headers.
239+
240+
Args:
241+
db (Session): SQLAlchemy database session for querying and updating GlobalConfig.
242+
243+
Raises:
244+
PassthroughHeadersError: If unable to update passthrough headers in the database.
245+
246+
Examples:
247+
Successful insert of default headers:
248+
>>> import pytest
249+
>>> from unittest.mock import Mock, patch
250+
>>> @pytest.mark.asyncio
251+
... @patch("mcpgateway.utils.passthrough_headers.settings")
252+
... async def test_default_headers(mock_settings):
253+
... mock_settings.enable_header_passthrough = True
254+
... mock_settings.default_passthrough_headers = ["X-Tenant-Id", "X-Trace-Id"]
255+
... mock_db = Mock()
256+
... mock_db.query.return_value.first.return_value = None
257+
... await set_global_passthrough_headers(mock_db)
258+
... mock_db.add.assert_called_once()
259+
... mock_db.commit.assert_called_once()
260+
261+
Database write failure:
262+
>>> import pytest
263+
>>> from unittest.mock import Mock, patch
264+
>>> from mcpgateway.utils.passthrough_headers import PassthroughHeadersError
265+
>>> @pytest.mark.asyncio
266+
... @patch("mcpgateway.utils.passthrough_headers.settings")
267+
... async def test_db_write_failure(mock_settings):
268+
... mock_settings.enable_header_passthrough = True
269+
... mock_db = Mock()
270+
... mock_db.query.return_value.first.return_value = None
271+
... mock_db.commit.side_effect = Exception("DB write failed")
272+
... with pytest.raises(PassthroughHeadersError):
273+
... await set_global_passthrough_headers(mock_db)
274+
... mock_db.rollback.assert_called_once()
275+
276+
Config already exists (no DB write):
277+
>>> import pytest
278+
>>> from unittest.mock import Mock, patch
279+
>>> from mcpgateway.models import GlobalConfig
280+
>>> @pytest.mark.asyncio
281+
... @patch("mcpgateway.utils.passthrough_headers.settings")
282+
... async def test_existing_config(mock_settings):
283+
... mock_settings.enable_header_passthrough = True
284+
... mock_db = Mock()
285+
... existing = Mock(spec=GlobalConfig)
286+
... existing.passthrough_headers = ["X-Tenant-ID", "Authorization"]
287+
... mock_db.query.return_value.first.return_value = existing
288+
... await set_global_passthrough_headers(mock_db)
289+
... mock_db.add.assert_not_called()
290+
... mock_db.commit.assert_not_called()
291+
... assert existing.passthrough_headers == ["X-Tenant-ID", "Authorization"]
292+
293+
Note:
294+
This function is typically called during application startup to ensure
295+
global configuration is in place before any gateway operations.
296+
"""
297+
global_config = db.query(GlobalConfig).first()
298+
299+
if not global_config:
300+
config_headers = settings.default_passthrough_headers
301+
if config_headers:
302+
allowed_headers = []
303+
for header_name in config_headers:
304+
# Validate header name
305+
if not validate_header_name(header_name):
306+
logger.warning(f"Invalid header name '{header_name}' - skipping (must match pattern: {HEADER_NAME_REGEX.pattern})")
307+
continue
308+
309+
allowed_headers.append(header_name)
310+
try:
311+
db.add(GlobalConfig(passthrough_headers=allowed_headers))
312+
db.commit()
313+
except Exception as e:
314+
db.rollback()
315+
raise PassthroughHeadersError(f"Failed to update passthrough headers: {str(e)}")

tests/unit/mcpgateway/utils/test_passthrough_headers_fixed.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@
1313
# Standard
1414
import logging
1515
from unittest.mock import Mock, patch
16+
import pytest
1617

1718
# First-Party
1819
from mcpgateway.db import Gateway as DbGateway
1920
from mcpgateway.db import GlobalConfig
20-
from mcpgateway.utils.passthrough_headers import get_passthrough_headers
21+
from mcpgateway.utils.passthrough_headers import get_passthrough_headers, set_global_passthrough_headers, PassthroughHeadersError
2122

2223

2324
class TestPassthroughHeaders:
@@ -121,8 +122,11 @@ def test_authorization_conflict_bearer_auth(self, mock_settings, caplog):
121122
# Check warning was logged
122123
assert any("Skipping Authorization header passthrough due to bearer auth" in record.message for record in caplog.records)
123124

124-
def test_feature_disabled_by_default(self):
125+
@patch("mcpgateway.utils.passthrough_headers.settings")
126+
def test_feature_disabled_by_default(self, mock_settings):
125127
"""Test that feature is disabled by default."""
128+
mock_settings.enable_header_passthrough = False
129+
126130
mock_db = Mock()
127131
request_headers = {"x-tenant-id": "test"}
128132
base_headers = {"Content-Type": "application/json"}
@@ -154,3 +158,59 @@ def test_case_insensitive_header_matching(self, mock_settings):
154158
# Headers should preserve config case in output keys
155159
expected = {"X-Tenant-ID": "mixed-case-value", "Authorization": "bearer lowercase-header"}
156160
assert result == expected
161+
162+
@pytest.mark.asyncio
163+
@patch("mcpgateway.utils.passthrough_headers.settings")
164+
async def test_set_global_passthrough_headers_default(self, mock_settings):
165+
mock_settings.enable_header_passthrough = True
166+
mock_settings.default_passthrough_headers = ["X-Tenant-Id", "X-Trace-Id"]
167+
168+
mock_db = Mock()
169+
mock_db.query.return_value.first.return_value = None # Simulate no config in DB
170+
171+
# Act
172+
await set_global_passthrough_headers(mock_db)
173+
174+
# Assert
175+
mock_db.add.assert_called_once()
176+
added_config = mock_db.add.call_args[0][0]
177+
assert added_config.passthrough_headers == ["X-Tenant-Id", "X-Trace-Id"]
178+
179+
mock_db.commit.assert_called_once()
180+
181+
182+
@pytest.mark.asyncio
183+
@patch("mcpgateway.utils.passthrough_headers.settings")
184+
async def test_set_global_passthrough_headers_invalid_config(self, mock_settings):
185+
"""Should raise PassthroughHeadersError when config is invalid."""
186+
mock_settings.enable_header_passthrough = True
187+
188+
mock_db = Mock()
189+
mock_db.query.return_value.first.return_value = None
190+
mock_db.commit.side_effect = Exception("DB write failed")
191+
192+
with pytest.raises(PassthroughHeadersError) as exc_info:
193+
await set_global_passthrough_headers(mock_db)
194+
195+
assert "DB write failed" in str(exc_info.value) or str(exc_info.value)
196+
mock_db.rollback.assert_called_once()
197+
198+
@pytest.mark.asyncio
199+
@patch("mcpgateway.utils.passthrough_headers.settings")
200+
async def test_set_global_passthrough_headers_existing_config(self, mock_settings):
201+
"""Should raise PassthroughHeadersError when config is invalid."""
202+
mock_settings.enable_header_passthrough = True
203+
204+
mock_db = Mock()
205+
mock_global_config = Mock(spec=GlobalConfig)
206+
mock_global_config.passthrough_headers = ["X-Tenant-ID", "Authorization"]
207+
mock_db.query.return_value.first.return_value = mock_global_config
208+
209+
await set_global_passthrough_headers(mock_db)
210+
211+
mock_db.add.assert_not_called()
212+
mock_db.commit.assert_not_called()
213+
214+
# Ensure existing config is not modified
215+
assert mock_global_config.passthrough_headers == ["X-Tenant-ID", "Authorization"]
216+
mock_db.rollback.assert_not_called()

0 commit comments

Comments
 (0)