diff --git a/src/stac_auth_proxy/utils/middleware.py b/src/stac_auth_proxy/utils/middleware.py index d8d966b..756fd7d 100644 --- a/src/stac_auth_proxy/utils/middleware.py +++ b/src/stac_auth_proxy/utils/middleware.py @@ -18,6 +18,10 @@ class JsonResponseMiddleware(ABC): app: ASGIApp + # Expected data type for JSON responses. Only responses matching this type will be transformed. + # If None, all JSON responses will be transformed regardless of type. + expected_data_type: Optional[type] = dict + @abstractmethod def should_transform_response( self, request: Request, scope: Scope @@ -97,8 +101,21 @@ async def transform_response(message: Message) -> None: ) await response(scope, receive, send) return - transformed = self.transform_json(data, request=request) - body = json.dumps(transformed).encode() + + if self.expected_data_type is None or isinstance( + data, self.expected_data_type + ): + transformed = self.transform_json(data, request=request) + body = json.dumps(transformed).encode() + else: + logger.warning( + "Received JSON response with unexpected data type %r from upstream server (%r %r), " + "skipping transformation (expected: %r)", + type(data).__name__, + request.method, + request.url, + self.expected_data_type.__name__, + ) # Update content-length header headers["content-length"] = str(len(body)) diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 2d114c9..221e1c8 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -1,7 +1,9 @@ """Tests for middleware utilities.""" from typing import Any +from unittest.mock import patch +import pytest from fastapi import FastAPI, Response from starlette.datastructures import Headers from starlette.requests import Request @@ -17,6 +19,7 @@ class ExampleJsonResponseMiddleware(JsonResponseMiddleware): def __init__(self, app: ASGIApp): """Initialize the middleware.""" self.app = app + # Use default expected_data_type (dict) def should_transform_response(self, request: Request, scope: Scope) -> bool: """Transform JSON responses based on content type.""" @@ -24,11 +27,65 @@ def should_transform_response(self, request: Request, scope: Scope) -> bool: def transform_json(self, data: Any, request: Request) -> Any: """Add a test field to the response.""" - if isinstance(data, dict): - data["transformed"] = True + data["transformed"] = True return data +class ExampleStringJsonResponseMiddleware(JsonResponseMiddleware): + """Example implementation that expects string JSON responses.""" + + def __init__(self, app: ASGIApp): + """Initialize the middleware.""" + self.app = app + self.expected_data_type = str + + def should_transform_response(self, request: Request, scope: Scope) -> bool: + """Transform JSON responses based on content type.""" + return Headers(scope=scope).get("content-type", "") == "application/json" + + def transform_json(self, data: Any, request: Request) -> Any: + """Transform string responses by adding a prefix.""" + if isinstance(data, str): + return f"transformed: {data}" + return data + + +class ExampleListJsonResponseMiddleware(JsonResponseMiddleware): + """Example implementation that expects list JSON responses.""" + + def __init__(self, app: ASGIApp): + """Initialize the middleware.""" + self.app = app + self.expected_data_type = list + + def should_transform_response(self, request: Request, scope: Scope) -> bool: + """Transform JSON responses based on content type.""" + return Headers(scope=scope).get("content-type", "") == "application/json" + + def transform_json(self, data: Any, request: Request) -> Any: + """Transform list responses by adding a new item.""" + if isinstance(data, list): + return data + ["transformed"] + return data + + +class ExampleAnyJsonResponseMiddleware(JsonResponseMiddleware): + """Example implementation that transforms any JSON response type.""" + + def __init__(self, app: ASGIApp): + """Initialize the middleware.""" + self.app = app + self.expected_data_type = None # Transform any JSON type + + def should_transform_response(self, request: Request, scope: Scope) -> bool: + """Transform JSON responses based on content type.""" + return Headers(scope=scope).get("content-type", "") == "application/json" + + def transform_json(self, data: Any, request: Request) -> Any: + """Transform any JSON response by wrapping it.""" + return {"transformed": True, "data": data} + + def test_json_response_middleware(): """Test that JSON responses are properly transformed.""" app = FastAPI() @@ -119,3 +176,131 @@ async def test_endpoint(): assert response.headers["content-type"] == "application/json" data = response.json() assert data == {"error": "Received invalid JSON from upstream server"} + + +@pytest.mark.parametrize( + "content,expected_data", + [ + ('"hello world"', "hello world"), + ('[1, 2, 3, "test"]', [1, 2, 3, "test"]), + ("42", 42), + ("true", True), + ("null", None), + ], +) +def test_json_response_middleware_non_dict_json(content, expected_data): + """Test that non-dict JSON responses are not transformed by default middleware.""" + app = FastAPI() + app.add_middleware(ExampleJsonResponseMiddleware) + + @app.get("/test") + async def test_endpoint(): + return Response(content=content, media_type="application/json") + + client = TestClient(app) + response = client.get("/test") + assert response.status_code == 200 + assert response.headers["content-type"] == "application/json" + data = response.json() + assert data == expected_data # Should remain unchanged + + +@pytest.mark.parametrize( + "middleware_class, test_data, expected_result, should_transform", + [ + # String middleware tests + ( + ExampleStringJsonResponseMiddleware, + "this is a string", + "transformed: this is a string", + True, + ), + ( + ExampleStringJsonResponseMiddleware, + {"message": "not a string"}, + {"message": "not a string"}, + False, + ), + # List middleware tests + ( + ExampleListJsonResponseMiddleware, + [1, 2, 3], + [1, 2, 3, "transformed"], + True, + ), + ( + ExampleListJsonResponseMiddleware, + "not a list", + "not a list", + False, + ), + # Dict middleware tests (default) + ( + ExampleJsonResponseMiddleware, + {"message": "test"}, + {"message": "test", "transformed": True}, + True, + ), + ( + ExampleJsonResponseMiddleware, + "not a dict", + "not a dict", + False, + ), + ], +) +def test_json_response_middleware_type_specific( + middleware_class, test_data, expected_result, should_transform +): + """Test that middleware transforms only expected data types.""" + with patch.object( + middleware_class, "transform_json", return_value=expected_result + ) as mock_method: + app = FastAPI() + app.add_middleware(middleware_class) + + @app.get("/test") + async def test_endpoint(): + return test_data + + client = TestClient(app) + response = client.get("/test") + + data = response.json() + assert response.status_code == 200 + assert response.headers["content-type"] == "application/json" + assert mock_method.call_count == (1 if should_transform else 0) + if should_transform: + assert mock_method.call_args[0][0] == test_data + assert data == expected_result + + +@pytest.mark.parametrize( + "test_data", + [ + {"message": "test"}, + "hello world", + [1, 2, 3], + 42, + True, + None, + ], +) +def test_json_response_middleware_expected_none_type(test_data): + """Test that middleware with expected_data_type=None transforms all JSON response types.""" + app = FastAPI() + app.add_middleware(ExampleAnyJsonResponseMiddleware) + + @app.get("/test") + async def test_endpoint(): + return test_data + + client = TestClient(app) + response = client.get("/test") + assert response.status_code == 200 + assert response.headers["content-type"] == "application/json" + data = response.json() + + # Verify the simplified transformation behavior + assert data["transformed"] is True + assert data["data"] == test_data