Skip to content

Commit

Permalink
fix(BA-82): Add request & response policy middleware for web security (
Browse files Browse the repository at this point in the history
…#2937)

Backported-from: main (25.1)
Backported-to: 24.09
Backport-of: 2937
  • Loading branch information
HyeockJinKim committed Feb 3, 2025
1 parent e62e170 commit ac692ff
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 1 deletion.
1 change: 1 addition & 0 deletions changes/2937.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add reject middleware for web security
80 changes: 80 additions & 0 deletions src/ai/backend/web/security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from typing import Callable, Iterable, Self

from aiohttp import web
from aiohttp.typedefs import Handler


@web.middleware
async def security_policy_middleware(request: web.Request, handler: Handler) -> web.StreamResponse:
security_policy: SecurityPolicy = request.app["security_policy"]
security_policy.check_request_policies(request)
response = await handler(request)
return security_policy.apply_response_policies(response)


class SecurityPolicy:
_request_policies: Iterable[Callable[[web.Request], None]]
_response_policies: Iterable[Callable[[web.StreamResponse], web.StreamResponse]]

def __init__(
self,
request_policies: Iterable[Callable[[web.Request], None]],
response_policies: Iterable[Callable[[web.StreamResponse], web.StreamResponse]],
) -> None:
self._request_policies = request_policies
self._response_policies = response_policies

@classmethod
def default_policy(cls) -> Self:
request_policies = [reject_metadata_local_link_policy, reject_access_for_unsafe_file_policy]
response_policies = [add_self_content_security_policy, set_content_type_nosniff_policy]
return cls(request_policies, response_policies)

def check_request_policies(self, request: web.Request) -> None:
for policy in self._request_policies:
policy(request)

def apply_response_policies(self, response: web.StreamResponse) -> web.StreamResponse:
for policy in self._response_policies:
response = policy(response)
return response


def reject_metadata_local_link_policy(request: web.Request) -> None:
metadata_local_link_map = {
"metadata.google.internal": True,
"169.254.169.254": True,
"100.100.100.200": True,
"alibaba.zaproxy.org": True,
"metadata.oraclecloud.com": True,
}
if metadata_local_link_map.get(request.host):
raise web.HTTPForbidden()


def reject_access_for_unsafe_file_policy(request: web.Request) -> None:
unsafe_file_map = {
"._darcs": True,
".bzr": True,
".hg": True,
"BitKeeper": True,
".bak": True,
".log": True,
".git": True,
".svn": True,
}
file_name = request.path.split("/")[-1]
if unsafe_file_map.get(file_name):
raise web.HTTPForbidden()


def add_self_content_security_policy(response: web.StreamResponse) -> web.StreamResponse:
response.headers["Content-Security-Policy"] = (
"default-src 'self'; frame-ancestors 'none'; form-action 'self';"
)
return response


def set_content_type_nosniff_policy(response: web.StreamResponse) -> web.StreamResponse:
response.headers["X-Content-Type-Options"] = "nosniff"
return response
6 changes: 5 additions & 1 deletion src/ai/backend/web/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from ai.backend.common.web.session import setup as setup_session
from ai.backend.common.web.session.redis_storage import RedisStorage
from ai.backend.logging import BraceStyleAdapter, Logger, LogLevel
from ai.backend.web.security import SecurityPolicy, security_policy_middleware

from . import __version__, user_agent
from .auth import fill_forwarding_hdrs_to_api_session, get_client_ip
Expand Down Expand Up @@ -603,8 +604,11 @@ async def server_main(
args: Tuple[Any, ...],
) -> AsyncIterator[None]:
config = args[0]
app = web.Application(middlewares=[decrypt_payload, track_active_handlers])
app = web.Application(
middlewares=[decrypt_payload, track_active_handlers, security_policy_middleware]
)
app["config"] = config
app["security_policy"] = SecurityPolicy.default_policy()
j2env = jinja2.Environment(
extensions=[
"ai.backend.web.template.TOMLField",
Expand Down
114 changes: 114 additions & 0 deletions tests/webserver/test_security_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import pytest
from aiohttp import web
from aiohttp.test_utils import make_mocked_request
from aiohttp.typedefs import Handler

from ai.backend.web.security import (
SecurityPolicy,
add_self_content_security_policy,
reject_access_for_unsafe_file_policy,
reject_metadata_local_link_policy,
security_policy_middleware,
set_content_type_nosniff_policy,
)


@pytest.fixture
def default_app():
app = web.Application(middlewares=[security_policy_middleware])
app["security_policy"] = SecurityPolicy.default_policy()
return app


@pytest.fixture
async def async_handler() -> Handler:
async def handler(request):
return web.Response()

return handler


async def test_default_security_policy_reject_metadata_local_link(
default_app, async_handler
) -> None:
request = make_mocked_request("GET", "/", headers={"Host": "169.254.169.254"}, app=default_app)
with pytest.raises(web.HTTPForbidden):
await security_policy_middleware(request, async_handler)


async def test_default_security_policy_response(default_app, async_handler) -> None:
request = make_mocked_request("GET", "/", headers={"Host": "localhost"}, app=default_app)
response = await security_policy_middleware(request, async_handler)
assert (
response.headers["Content-Security-Policy"]
== "default-src 'self'; frame-ancestors 'none'; form-action 'self';"
)
assert response.headers["X-Content-Type-Options"] == "nosniff"


@pytest.mark.parametrize(
"meta_local_link",
[
"metadata.google.internal",
"169.254.169.254",
"100.100.100.200",
"alibaba.zaproxy.org",
"metadata.oraclecloud.com",
],
)
async def test_reject_metadata_local_link_policy(async_handler, meta_local_link) -> None:
test_app = web.Application()
test_app["security_policy"] = SecurityPolicy(
request_policies=[reject_metadata_local_link_policy], response_policies=[]
)
request = make_mocked_request("GET", "/", headers={"Host": meta_local_link}, app=test_app)
with pytest.raises(web.HTTPForbidden):
await security_policy_middleware(request, async_handler)


@pytest.mark.parametrize(
"url_suffix",
[
"._darcs",
".bzr",
".hg",
"BitKeeper",
".bak",
".log",
".git",
".svn",
],
)
async def test_reject_access_for_unsafe_file_policy(async_handler, url_suffix) -> None:
test_app = web.Application()
test_app["security_policy"] = SecurityPolicy(
request_policies=[reject_access_for_unsafe_file_policy], response_policies=[]
)
request = make_mocked_request(
"GET", f"/{url_suffix}", headers={"Host": "localhost"}, app=test_app
)
with pytest.raises(web.HTTPForbidden):
await security_policy_middleware(request, async_handler)


async def test_add_self_content_security_policy(async_handler) -> None:
test_app = web.Application()
test_app["security_policy"] = SecurityPolicy(
request_policies=[], response_policies=[add_self_content_security_policy]
)
request = make_mocked_request("GET", "/", headers={"Host": "localhost"}, app=test_app)
response = await security_policy_middleware(request, async_handler)
assert (
response.headers["Content-Security-Policy"]
== "default-src 'self'; frame-ancestors 'none'; form-action 'self';"
)


async def test_set_content_type_nosniff_policy(async_handler) -> None:
test_app = web.Application()
test_app["security_policy"] = SecurityPolicy(
request_policies=[], response_policies=[set_content_type_nosniff_policy]
)
request = make_mocked_request("GET", "/", headers={"Host": "localhost"}, app=test_app)
response = await security_policy_middleware(request, async_handler)
assert response.headers["X-Content-Type-Options"] == "nosniff"

0 comments on commit ac692ff

Please sign in to comment.