|
50 | 50 | MAX_HEADER_VALUE_LENGTH = 4096 |
51 | 51 |
|
52 | 52 |
|
| 53 | +class PassthroughHeadersError(Exception): |
| 54 | + """Base class for passthrough headers-related errors. |
| 55 | +
|
| 56 | + Examples: |
| 57 | + >>> error = PassthroughHeadersError("Test error") |
| 58 | + >>> str(error) |
| 59 | + 'Test error' |
| 60 | + >>> isinstance(error, Exception) |
| 61 | + True |
| 62 | + """ |
| 63 | + |
| 64 | + |
53 | 65 | def sanitize_header_value(value: str, max_length: int = MAX_HEADER_VALUE_LENGTH) -> str: |
54 | 66 | """Sanitize header value for security. |
55 | 67 |
|
@@ -131,12 +143,15 @@ def get_passthrough_headers(request_headers: Dict[str, str], base_headers: Dict[ |
131 | 143 |
|
132 | 144 | Examples: |
133 | 145 | Feature disabled by default (secure by default): |
134 | | - >>> from unittest.mock import Mock |
135 | | - >>> mock_db = Mock() |
136 | | - >>> request_headers = {"x-tenant-id": "should-be-ignored"} |
137 | | - >>> base_headers = {"Content-Type": "application/json"} |
138 | | - >>> result = get_passthrough_headers(request_headers, base_headers, mock_db) |
139 | | - >>> result |
| 146 | + >>> from unittest.mock import Mock, patch |
| 147 | + >>> with patch(__name__ + ".settings") as mock_settings: |
| 148 | + ... mock_settings.enable_header_passthrough = False |
| 149 | + ... mock_settings.default_passthrough_headers = ["X-Tenant-Id"] |
| 150 | + ... mock_db = Mock() |
| 151 | + ... mock_db.query.return_value.first.return_value = None |
| 152 | + ... request_headers = {"x-tenant-id": "should-be-ignored"} |
| 153 | + ... base_headers = {"Content-Type": "application/json"} |
| 154 | + ... get_passthrough_headers(request_headers, base_headers, mock_db) |
140 | 155 | {'Content-Type': 'application/json'} |
141 | 156 |
|
142 | 157 | See comprehensive unit tests in tests/unit/mcpgateway/utils/test_passthrough_headers*.py |
@@ -213,3 +228,88 @@ def get_passthrough_headers(request_headers: Dict[str, str], base_headers: Dict[ |
213 | 228 |
|
214 | 229 | logger.debug(f"Final passthrough headers: {list(passthrough_headers.keys())}") |
215 | 230 | return passthrough_headers |
| 231 | + |
| 232 | + |
| 233 | +async def set_global_passthrough_headers(db: Session) -> None: |
| 234 | + """Set global passthrough headers in the database if not already configured. |
| 235 | +
|
| 236 | + This function checks if the global passthrough headers are already set in the |
| 237 | + GlobalConfig table. If not, it initializes them with the default headers from |
| 238 | + settings.default_passthrough_headers. |
| 239 | +
|
| 240 | + Args: |
| 241 | + db (Session): SQLAlchemy database session for querying and updating GlobalConfig. |
| 242 | +
|
| 243 | + Raises: |
| 244 | + PassthroughHeadersError: If unable to update passthrough headers in the database. |
| 245 | +
|
| 246 | + Examples: |
| 247 | + Successful insert of default headers: |
| 248 | + >>> import pytest |
| 249 | + >>> from unittest.mock import Mock, patch |
| 250 | + >>> @pytest.mark.asyncio |
| 251 | + ... @patch("mcpgateway.utils.passthrough_headers.settings") |
| 252 | + ... async def test_default_headers(mock_settings): |
| 253 | + ... mock_settings.enable_header_passthrough = True |
| 254 | + ... mock_settings.default_passthrough_headers = ["X-Tenant-Id", "X-Trace-Id"] |
| 255 | + ... mock_db = Mock() |
| 256 | + ... mock_db.query.return_value.first.return_value = None |
| 257 | + ... await set_global_passthrough_headers(mock_db) |
| 258 | + ... mock_db.add.assert_called_once() |
| 259 | + ... mock_db.commit.assert_called_once() |
| 260 | +
|
| 261 | + Database write failure: |
| 262 | + >>> import pytest |
| 263 | + >>> from unittest.mock import Mock, patch |
| 264 | + >>> from mcpgateway.utils.passthrough_headers import PassthroughHeadersError |
| 265 | + >>> @pytest.mark.asyncio |
| 266 | + ... @patch("mcpgateway.utils.passthrough_headers.settings") |
| 267 | + ... async def test_db_write_failure(mock_settings): |
| 268 | + ... mock_settings.enable_header_passthrough = True |
| 269 | + ... mock_db = Mock() |
| 270 | + ... mock_db.query.return_value.first.return_value = None |
| 271 | + ... mock_db.commit.side_effect = Exception("DB write failed") |
| 272 | + ... with pytest.raises(PassthroughHeadersError): |
| 273 | + ... await set_global_passthrough_headers(mock_db) |
| 274 | + ... mock_db.rollback.assert_called_once() |
| 275 | +
|
| 276 | + Config already exists (no DB write): |
| 277 | + >>> import pytest |
| 278 | + >>> from unittest.mock import Mock, patch |
| 279 | + >>> from mcpgateway.models import GlobalConfig |
| 280 | + >>> @pytest.mark.asyncio |
| 281 | + ... @patch("mcpgateway.utils.passthrough_headers.settings") |
| 282 | + ... async def test_existing_config(mock_settings): |
| 283 | + ... mock_settings.enable_header_passthrough = True |
| 284 | + ... mock_db = Mock() |
| 285 | + ... existing = Mock(spec=GlobalConfig) |
| 286 | + ... existing.passthrough_headers = ["X-Tenant-ID", "Authorization"] |
| 287 | + ... mock_db.query.return_value.first.return_value = existing |
| 288 | + ... await set_global_passthrough_headers(mock_db) |
| 289 | + ... mock_db.add.assert_not_called() |
| 290 | + ... mock_db.commit.assert_not_called() |
| 291 | + ... assert existing.passthrough_headers == ["X-Tenant-ID", "Authorization"] |
| 292 | +
|
| 293 | + Note: |
| 294 | + This function is typically called during application startup to ensure |
| 295 | + global configuration is in place before any gateway operations. |
| 296 | + """ |
| 297 | + global_config = db.query(GlobalConfig).first() |
| 298 | + |
| 299 | + if not global_config: |
| 300 | + config_headers = settings.default_passthrough_headers |
| 301 | + if config_headers: |
| 302 | + allowed_headers = [] |
| 303 | + for header_name in config_headers: |
| 304 | + # Validate header name |
| 305 | + if not validate_header_name(header_name): |
| 306 | + logger.warning(f"Invalid header name '{header_name}' - skipping (must match pattern: {HEADER_NAME_REGEX.pattern})") |
| 307 | + continue |
| 308 | + |
| 309 | + allowed_headers.append(header_name) |
| 310 | + try: |
| 311 | + db.add(GlobalConfig(passthrough_headers=allowed_headers)) |
| 312 | + db.commit() |
| 313 | + except Exception as e: |
| 314 | + db.rollback() |
| 315 | + raise PassthroughHeadersError(f"Failed to update passthrough headers: {str(e)}") |
0 commit comments