Skip to content

Commit

Permalink
feat: OpenAPI plugins send CSRF request header (#3754)
Browse files Browse the repository at this point in the history
* feat: Swagger sends CSRF request header

* feat: RapiDoc sends CSRF request header

* test: Add tests for Swagger & RapiDoc with CSRF

* test: csrf config with httponly cookie
  • Loading branch information
floxay authored Sep 22, 2024
1 parent 88c313c commit 5f01bb9
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 1 deletion.
45 changes: 44 additions & 1 deletion litestar/openapi/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from litestar.serialization import encode_json, get_serializer

if TYPE_CHECKING:
from litestar.config.csrf import CSRFConfig
from litestar.connection import Request
from litestar.router import Router

Expand All @@ -30,6 +31,11 @@
_default_style = "<style>body { margin: 0; padding: 0 }</style>"


def _get_cookie_value_or_undefined(cookie_name: str) -> str:
"""Javascript code as a string to get the value of a cookie by name or undefined."""
return f"document.cookie.split('; ').find((row) => row.startsWith('{cookie_name}='))?.split('=')[1];"


class OpenAPIRenderPlugin(ABC):
"""Base class for OpenAPI UI render plugins."""

Expand Down Expand Up @@ -221,6 +227,25 @@ def render(self, request: Request, openapi_schema: dict[str, Any]) -> bytes:
A rendered html string.
"""

def create_request_interceptor(csrf_config: CSRFConfig) -> str:
if csrf_config.cookie_httponly:
return ""

return f"""
<script>
window.addEventListener('DOMContentLoaded', (event) => {{
const rapidocEl = document.getElementsByTagName("rapi-doc")[0];
rapidocEl.addEventListener('before-try', (e) => {{
const csrf_token = {_get_cookie_value_or_undefined(csrf_config.cookie_name)};
if (csrf_token !== undefined) {{
e.detail.request.headers.append('{csrf_config.header_name}', csrf_token);
}}
}});
}});
</script>"""

head = f"""
<head>
<title>{openapi_schema["info"]["title"]}</title>
Expand All @@ -235,6 +260,7 @@ def render(self, request: Request, openapi_schema: dict[str, Any]) -> bytes:
body = f"""
<body>
<rapi-doc spec-url="{self.get_openapi_json_route(request)}" />
{create_request_interceptor(request.app.csrf_config) if request.app.csrf_config else ""}
</body>
"""

Expand Down Expand Up @@ -520,6 +546,21 @@ def render(self, request: Request, openapi_schema: dict[str, Any]) -> bytes:
A rendered html string.
"""

def create_request_interceptor(csrf_config: CSRFConfig) -> bytes:
if csrf_config.cookie_httponly:
return b""

return f"""
requestInterceptor: (request) => {{
const csrf_token = {_get_cookie_value_or_undefined(csrf_config.cookie_name)};
if (csrf_token !== undefined) {{
request.headers['{csrf_config.header_name}'] = csrf_token;
}}
return request;
}},""".encode()

head = f"""
<head>
<title>{openapi_schema["info"]["title"]}</title>
Expand Down Expand Up @@ -550,7 +591,9 @@ def render(self, request: Request, openapi_schema: dict[str, Any]) -> bytes:
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.SwaggerUIStandalonePreset
],
],""",
create_request_interceptor(request.app.csrf_config) if request.app.csrf_config else b"",
b"""
})
ui.initOAuth(""",
encode_json(self.init_oauth),
Expand Down
62 changes: 62 additions & 0 deletions tests/unit/test_openapi/test_plugins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from litestar import Litestar
from litestar.config.csrf import CSRFConfig
from litestar.openapi.config import OpenAPIConfig
from litestar.openapi.plugins import RapidocRenderPlugin, SwaggerRenderPlugin
from litestar.testing import TestClient

rapidoc_fragment = ".addEventListener('before-try',"
swagger_fragment = "requestInterceptor:"


def test_rapidoc_csrf() -> None:
app = Litestar(
csrf_config=CSRFConfig(secret="litestar"),
openapi_config=OpenAPIConfig(
title="Litestar Example",
version="0.0.1",
render_plugins=[RapidocRenderPlugin()],
),
)

with TestClient(app=app) as client:
resp = client.get("/schema/rapidoc")
assert resp.status_code == 200
assert resp.headers["content-type"] == "text/html; charset=utf-8"
assert rapidoc_fragment in resp.text


def test_swagger_ui_csrf() -> None:
app = Litestar(
csrf_config=CSRFConfig(secret="litestar"),
openapi_config=OpenAPIConfig(
title="Litestar Example",
version="0.0.1",
render_plugins=[SwaggerRenderPlugin()],
),
)

with TestClient(app=app) as client:
resp = client.get("/schema/swagger")
assert resp.status_code == 200
assert resp.headers["content-type"] == "text/html; charset=utf-8"
assert swagger_fragment in resp.text


def test_plugins_csrf_httponly() -> None:
app = Litestar(
csrf_config=CSRFConfig(secret="litestar", cookie_httponly=True),
openapi_config=OpenAPIConfig(
title="Litestar Example",
version="0.0.1",
render_plugins=[RapidocRenderPlugin(), SwaggerRenderPlugin()],
),
)

with TestClient(app=app) as client:
resp = client.get("/schema/rapidoc")
assert resp.status_code == 200
assert rapidoc_fragment not in resp.text

resp = client.get("/schema/swagger")
assert resp.status_code == 200
assert swagger_fragment not in resp.text

0 comments on commit 5f01bb9

Please sign in to comment.