Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix recorded path in mounted Starlette apps #80

Merged
merged 1 commit into from
Jan 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion apitally/starlette.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def get_path(request: Request) -> Optional[str]:
for route in request.app.routes:
match, _ = route.matches(request.scope)
if match == Match.FULL:
return route.path
return request.scope.get("root_path", "") + route.path
return None

def get_consumer(self, request: Request) -> Optional[ApitallyConsumer]:
Expand Down
86 changes: 49 additions & 37 deletions tests/test_starlette.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ async def app(request: FixtureRequest, module_mocker: MockerFixture) -> Starlett
def get_starlette_app() -> Starlette:
from starlette.applications import Starlette
from starlette.responses import PlainTextResponse, StreamingResponse
from starlette.routing import Route
from starlette.routing import Mount, Route

from apitally.starlette import ApitallyConsumer, ApitallyMiddleware, RequestLoggingConfig

Expand Down Expand Up @@ -79,16 +79,22 @@ def task_func_with_error():
def identify_consumer(request: Request) -> Optional[ApitallyConsumer]:
return ApitallyConsumer("test", name="Test")

routes = [
Route("/foo/", foo),
Route("/foo/{bar}/", foo_bar),
Route("/bar/", bar, methods=["POST"]),
Route("/baz/", baz, methods=["POST"]),
Route("/val/", val),
Route("/stream/", stream),
Route("/task/", task, methods=["POST"]),
]
app = Starlette(routes=routes)
sub_app = Starlette(
routes=[
Route("/foo", foo),
Route("/foo/{bar}", foo_bar),
Route("/bar", bar, methods=["POST"]),
Route("/baz", baz, methods=["POST"]),
Route("/val", val),
]
)
app = Starlette(
routes=[
Mount("/api", sub_app),
Route("/stream", stream),
Route("/task", task, methods=["POST"]),
]
)
app.add_middleware(
ApitallyMiddleware,
client_id=CLIENT_ID,
Expand All @@ -104,7 +110,7 @@ def identify_consumer(request: Request) -> Optional[ApitallyConsumer]:


def get_fastapi_app() -> Starlette:
from fastapi import FastAPI, Query
from fastapi import APIRouter, FastAPI, Query
from fastapi.responses import PlainTextResponse, StreamingResponse

from apitally.fastapi import ApitallyConsumer, ApitallyMiddleware, RequestLoggingConfig
Expand All @@ -125,43 +131,47 @@ def identify_consumer(request: Request) -> Optional[ApitallyConsumer]:
identify_consumer_callback=identify_consumer,
)

@app.get("/foo/")
router = APIRouter()

@router.get("/foo")
def foo():
return "foo"

@app.get("/foo/{bar}/")
@router.get("/foo/{bar}")
def foo_bar(bar: str):
return PlainTextResponse(f"foo: {bar}")

@app.post("/bar/")
@router.post("/bar")
async def bar(request: Request):
body = await request.body()
return PlainTextResponse("bar: " + body.decode())

@app.post("/baz/")
@router.post("/baz")
def baz():
raise ValueError("baz")

@app.get("/val/")
@router.get("/val")
def val(foo: int = Query()):
return "val"

@app.get("/stream/")
@app.get("/stream")
def stream():
def stream_response():
yield b"foo"
yield b"bar"

return StreamingResponse(stream_response())

@app.post("/task/")
@app.post("/task")
def task(background_tasks: BackgroundTasks):
def task_func_with_error():
raise ValueError("task")

background_tasks.add_task(task_func_with_error)
return "ok"

app.include_router(router, prefix="/api")

return app


Expand All @@ -171,23 +181,23 @@ def test_middleware_requests_ok(app: Starlette, mocker: MockerFixture):
mock = mocker.patch("apitally.client.requests.RequestCounter.add_request")
client = TestClient(app)

response = client.get("/foo/")
response = client.get("/api/foo/")
assert response.status_code == 200
mock.assert_called_once()
assert mock.call_args is not None
assert mock.call_args.kwargs["consumer"] == "test"
assert mock.call_args.kwargs["method"] == "GET"
assert mock.call_args.kwargs["path"] == "/foo/"
assert mock.call_args.kwargs["path"] == "/api/foo"
assert mock.call_args.kwargs["status_code"] == 200
assert mock.call_args.kwargs["response_time"] > 0

response = client.get("/foo/123/")
response = client.get("/api/foo/123/")
assert response.status_code == 200
assert mock.call_count == 2
assert mock.call_args is not None
assert mock.call_args.kwargs["path"] == "/foo/{bar}/"
assert mock.call_args.kwargs["path"] == "/api/foo/{bar}"

response = client.post("/bar/")
response = client.post("/api/bar/")
assert response.status_code == 200
assert mock.call_count == 3
assert mock.call_args is not None
Expand All @@ -207,12 +217,12 @@ def test_middleware_requests_error(app: Starlette, mocker: MockerFixture):
mock2 = mocker.patch("apitally.client.server_errors.ServerErrorCounter.add_server_error")
client = TestClient(app, raise_server_exceptions=False)

response = client.post("/baz/")
response = client.post("/api/baz")
assert response.status_code == 500
mock1.assert_called_once()
assert mock1.call_args is not None
assert mock1.call_args.kwargs["method"] == "POST"
assert mock1.call_args.kwargs["path"] == "/baz/"
assert mock1.call_args.kwargs["path"] == "/api/baz"
assert mock1.call_args.kwargs["status_code"] == 500
assert mock1.call_args.kwargs["response_time"] > 0

Expand All @@ -222,7 +232,7 @@ def test_middleware_requests_error(app: Starlette, mocker: MockerFixture):
assert isinstance(exception, ValueError)

# Throws a ValueError in a background task, but returns 200
response = client.post("/task/")
response = client.post("/task")
assert response.status_code == 200
assert mock1.call_count == 2
assert mock1.call_args is not None
Expand All @@ -236,7 +246,7 @@ def test_middleware_requests_unhandled(app: Starlette, mocker: MockerFixture):
mock = mocker.patch("apitally.client.requests.RequestCounter.add_request")
client = TestClient(app)

response = client.post("/xxx/")
response = client.post("/xxx")
assert response.status_code == 404
mock.assert_not_called()

Expand All @@ -248,15 +258,15 @@ def test_middleware_validation_error(app: Starlette, mocker: MockerFixture):
client = TestClient(app)

# Validation error as foo must be an integer
response = client.get("/val?foo=bar")
response = client.get("/api/val?foo=bar")
assert response.status_code == 422

# FastAPI only
if response.headers["Content-Type"] == "application/json":
mock.assert_called_once()
assert mock.call_args is not None
assert mock.call_args.kwargs["method"] == "GET"
assert mock.call_args.kwargs["path"] == "/val/"
assert mock.call_args.kwargs["path"] == "/api/val"
assert len(mock.call_args.kwargs["detail"]) == 1
assert mock.call_args.kwargs["detail"][0]["loc"] == ["query", "foo"]

Expand All @@ -269,13 +279,13 @@ def test_middleware_request_logging(app: Starlette, mocker: MockerFixture):
mock = mocker.patch("apitally.client.request_logging.RequestLogger.log_request")
client = TestClient(app)

response = client.get("/foo/123/?foo=bar", headers={"Test-Header": "test"})
response = client.get("/api/foo/123?foo=bar", headers={"Test-Header": "test"})
assert response.status_code == 200
mock.assert_called_once()
assert mock.call_args is not None
assert mock.call_args.kwargs["request"]["method"] == "GET"
assert mock.call_args.kwargs["request"]["path"] == "/foo/{bar}/"
assert mock.call_args.kwargs["request"]["url"] == "http://testserver/foo/123/?foo=bar"
assert mock.call_args.kwargs["request"]["path"] == "/api/foo/{bar}"
assert mock.call_args.kwargs["request"]["url"] == "http://testserver/api/foo/123?foo=bar"
assert ("test-header", "test") in mock.call_args.kwargs["request"]["headers"]
assert mock.call_args.kwargs["request"]["consumer"] == "test"
assert mock.call_args.kwargs["response"]["status_code"] == 200
Expand All @@ -284,18 +294,18 @@ def test_middleware_request_logging(app: Starlette, mocker: MockerFixture):
assert mock.call_args.kwargs["response"]["size"] > 0
assert mock.call_args.kwargs["response"]["body"] == b"foo: 123"

response = client.post("/bar/", content=b"foo")
response = client.post("/api/bar", content=b"foo")
assert response.status_code == 200
assert mock.call_count == 2
assert mock.call_args is not None
assert mock.call_args.kwargs["request"]["method"] == "POST"
assert mock.call_args.kwargs["request"]["path"] == "/bar/"
assert mock.call_args.kwargs["request"]["url"] == "http://testserver/bar/"
assert mock.call_args.kwargs["request"]["path"] == "/api/bar"
assert mock.call_args.kwargs["request"]["url"] == "http://testserver/api/bar"
assert mock.call_args.kwargs["request"]["body"] == b"foo"
assert mock.call_args.kwargs["response"]["body"] == b"bar: foo"

mocker.patch("apitally.starlette.MAX_BODY_SIZE", 2)
response = client.post("/bar/", content=b"foo")
response = client.post("/api/bar", content=b"foo")
assert response.status_code == 200
assert mock.call_count == 3
assert mock.call_args is not None
Expand All @@ -312,6 +322,8 @@ def test_get_startup_data(app: Starlette, mocker: MockerFixture):

data = _get_startup_data(app=app.middleware_stack, app_version="1.2.3", openapi_url=None)
assert len(data["paths"]) == 7
assert {"method": "get", "path": "/api/foo"} in data["paths"]
assert {"method": "get", "path": "/stream"} in data["paths"]
assert data["versions"]["starlette"]
assert data["versions"]["app"] == "1.2.3"
assert data["client"] == "python:starlette"
Loading