Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CONSVC-2063] test: extract middleware unit tests #106

Merged
merged 9 commits into from
Nov 16, 2022
28 changes: 28 additions & 0 deletions docs/dev/testing.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ def test_with_suggestion_request(srequest: SuggestionRequestFixture) -> None:
result: list[BaseSuggestion] = await provider.query(request)
```

#### ScopeFixture, ReceiveMockFixture & SendMockFixture
For use when testing middleware, these fixtures initialize or mock the common Scope,
Receive and Send object dependencies.

_**Usage:**_
```python
def test_middleware(scope: Scope, receive_mock: Receive, send_mock: Send) -> None:
pass
```

## Integration Tests

Expand Down Expand Up @@ -89,6 +98,25 @@ def test_with_test_client_with_event(client_with_events: TestClient):
response: Response = client_with_events.get("/api/v1/endpoint")
```

#### RequestSummaryLogDataFixture
This fixture will extract the extra log data from a captured 'request.summary'
LogRecord for verification.

_**Usage:**_
```python
def test_with_log_data(
caplog: LogCaptureFixture,
filter_caplog: FilterCaplogFixture,
extract_request_summary_log_data: LogDataFixture
):
records: list[LogRecord] = filter_caplog(caplog.records, "request.summary")
assert len(records) == 1

record: LogRecord = records[0]
log_data: dict[str, Any] = extract_request_summary_log_data(record)
assert log_data == expected_log_data
```

#### InjectProvidersFixture & ProvidersFixture
These fixtures will set up and tear down the given providers.

Expand Down
2 changes: 1 addition & 1 deletion merino/middleware/featureflags.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, app: ASGIApp) -> None:

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Insert session id before handing request"""
if scope["type"] != "http": # pragma: no cover
if scope["type"] != "http":
Trinaa marked this conversation as resolved.
Show resolved Hide resolved
await self.app(scope, receive, send)
return

Expand Down
53 changes: 9 additions & 44 deletions merino/middleware/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
from datetime import datetime
from typing import Pattern

from starlette.datastructures import Headers
from starlette.requests import Request
from starlette.types import ASGIApp, Message, Receive, Scope, Send

from merino.middleware import ScopeKey
from merino.util.log_data_creators import (
create_request_summary_log_data,
create_suggest_log_data,
)

# web.suggest.request is used for logs coming from the /suggest endpoint
suggest_request_logger = logging.getLogger("web.suggest.request")
Expand All @@ -24,61 +26,24 @@ class LoggingMiddleware:
"""An ASGI middleware for logging."""

def __init__(self, app: ASGIApp) -> None:
"""Initilize."""
"""Initialize."""
Trinaa marked this conversation as resolved.
Show resolved Hide resolved
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send):
"""Log requests."""
if scope["type"] != "http": # pragma: no cover
if scope["type"] != "http":
await self.app(scope, receive, send)
return

async def send_wrapper(message: Message) -> None:
if message["type"] == "http.response.start":
request = Request(scope=scope)
dt: datetime = datetime.fromtimestamp(time.time())
if PATTERN.match(request.url.path):
location = scope[ScopeKey.GEOLOCATION]
ua = scope[ScopeKey.USER_AGENT]
data = {
"sensitive": True,
"path": request.url.path,
"method": request.method,
"query": request.query_params.get("q"),
"errno": 0,
"code": message["status"],
"time": datetime.fromtimestamp(time.time()).isoformat(),
# Provided by the asgi-correlation-id middleware.
"rid": Headers(scope=message)["X-Request-ID"],
"session_id": request.query_params.get("sid"),
"sequence_no": int(seq)
if (seq := request.query_params.get("seq"))
else None,
"country": location.country,
"region": location.region,
"city": location.city,
"dma": location.dma,
"client_variants": request.query_params.get(
"client_variants", ""
),
"requested_providers": request.query_params.get(
"providers", ""
),
"browser": ua.browser,
"os_family": ua.os_family,
"form_factor": ua.form_factor,
}
data = create_suggest_log_data(request, message, dt)
suggest_request_logger.info("", extra=data)
else:
data = {
"agent": request.headers.get("User-Agent"),
"path": request.url.path,
"method": request.method,
"lang": request.headers.get("Accept-Language"),
"querystring": dict(request.query_params),
"errno": 0,
"code": message["status"],
"time": datetime.fromtimestamp(time.time()).isoformat(),
}
data = create_request_summary_log_data(request, message, dt)
logger.info("", extra=data)

await send(message)
Expand Down
2 changes: 1 addition & 1 deletion merino/middleware/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self, app: ASGIApp) -> None:

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Wrap the request with metrics."""
if scope["type"] != "http": # pragma: no cover
if scope["type"] != "http":
await self.app(scope, receive, send)
return

Expand Down
2 changes: 1 addition & 1 deletion merino/middleware/user_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Parse user agent information through "User-Agent" and store the result
to `scope`.
"""
if scope["type"] != "http": # pragma: no cover
if scope["type"] != "http":
await self.app(scope, receive, send)
return

Trinaa marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
73 changes: 73 additions & 0 deletions merino/util/log_data_creators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""A utility module for log data creation"""
from datetime import datetime
from typing import Any

from starlette.datastructures import Headers
from starlette.requests import Request
from starlette.types import Message

from merino.middleware import ScopeKey
from merino.middleware.geolocation import Location
from merino.middleware.user_agent import UserAgent


def create_request_summary_log_data(
    request: Request, message: Message, dt: datetime
) -> dict[str, Any]:
    """Assemble the log data dictionary for a general API request.

    Args:
        request: The incoming request; its headers, URL path, method and
            query parameters are read.
        message: The response message; its ``"status"`` entry is recorded
            under ``"code"``.
        dt: The timestamp to record, stored in ISO 8601 form under ``"time"``.

    Returns:
        A dictionary with the keys ``errno``, ``time``, ``agent``, ``path``,
        ``method``, ``lang``, ``querystring`` and ``code``.
    """
    headers = request.headers

    # Start from the fields that do not depend on the request itself.
    log_data: dict[str, Any] = {"errno": 0, "time": dt.isoformat()}

    # Then append the request/response specific fields, in a stable order.
    log_data["agent"] = headers.get("User-Agent")
    log_data["path"] = request.url.path
    log_data["method"] = request.method
    log_data["lang"] = headers.get("Accept-Language")
    log_data["querystring"] = dict(request.query_params)
    log_data["code"] = message["status"]

    return log_data


def create_suggest_log_data(
    request: Request, message: Message, dt: datetime
) -> dict[str, Any]:
    """Assemble the log data dictionary for a suggest API request.

    Args:
        request: The incoming request; its URL path, method, query parameters
            and the geolocation / user agent entries of its scope are read.
        message: The response message; its ``"status"`` entry and its
            ``X-Request-ID`` header are recorded.
        dt: The timestamp to record, stored in ISO 8601 form under ``"time"``.

    Returns:
        A dictionary combining general, request, location and user agent
        fields. The entry is always flagged with ``"sensitive": True``.

    Raises:
        KeyError: If the ``X-Request-ID`` header is absent from ``message`` or
            the geolocation / user agent keys are absent from the scope.
        ValueError: If the ``seq`` query parameter is non-empty and is not a
            valid integer.
    """
    params = request.query_params

    log_data: dict[str, Any] = {
        "sensitive": True,
        "errno": 0,
        "time": dt.isoformat(),
        "path": request.url.path,
        "method": request.method,
        "query": params.get("q"),
        "code": message["status"],
    }

    # Provided by the asgi-correlation-id middleware.
    response_headers = Headers(scope=message)
    log_data["rid"] = response_headers["X-Request-ID"]

    log_data["session_id"] = params.get("sid")
    raw_seq = params.get("seq")
    # A missing or empty sequence number is logged as None.
    log_data["sequence_no"] = int(raw_seq) if raw_seq else None
    log_data["client_variants"] = params.get("client_variants", "")
    log_data["requested_providers"] = params.get("providers", "")

    location: Location = request.scope[ScopeKey.GEOLOCATION]
    for attr in ("country", "region", "city", "dma"):
        log_data[attr] = getattr(location, attr)

    user_agent: UserAgent = request.scope[ScopeKey.USER_AGENT]
    for attr in ("browser", "os_family", "form_factor"):
        log_data[attr] = getattr(user_agent, attr)

    return log_data
Loading