-
Notifications
You must be signed in to change notification settings - Fork 159
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
…#2937) Backported-from: main (25.1) Backported-to: 24.09 Backport-of: 2937
- Loading branch information
1 parent
e62e170
commit ac692ff
Showing
4 changed files
with
200 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Add reject middleware for web security |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
from typing import Callable, Iterable, Self | ||
|
||
from aiohttp import web | ||
from aiohttp.typedefs import Handler | ||
|
||
|
||
@web.middleware | ||
async def security_policy_middleware(request: web.Request, handler: Handler) -> web.StreamResponse: | ||
security_policy: SecurityPolicy = request.app["security_policy"] | ||
security_policy.check_request_policies(request) | ||
response = await handler(request) | ||
return security_policy.apply_response_policies(response) | ||
|
||
|
||
class SecurityPolicy: | ||
_request_policies: Iterable[Callable[[web.Request], None]] | ||
_response_policies: Iterable[Callable[[web.StreamResponse], web.StreamResponse]] | ||
|
||
def __init__( | ||
self, | ||
request_policies: Iterable[Callable[[web.Request], None]], | ||
response_policies: Iterable[Callable[[web.StreamResponse], web.StreamResponse]], | ||
) -> None: | ||
self._request_policies = request_policies | ||
self._response_policies = response_policies | ||
|
||
@classmethod | ||
def default_policy(cls) -> Self: | ||
request_policies = [reject_metadata_local_link_policy, reject_access_for_unsafe_file_policy] | ||
response_policies = [add_self_content_security_policy, set_content_type_nosniff_policy] | ||
return cls(request_policies, response_policies) | ||
|
||
def check_request_policies(self, request: web.Request) -> None: | ||
for policy in self._request_policies: | ||
policy(request) | ||
|
||
def apply_response_policies(self, response: web.StreamResponse) -> web.StreamResponse: | ||
for policy in self._response_policies: | ||
response = policy(response) | ||
return response | ||
|
||
|
||
def reject_metadata_local_link_policy(request: web.Request) -> None: | ||
metadata_local_link_map = { | ||
"metadata.google.internal": True, | ||
"169.254.169.254": True, | ||
"100.100.100.200": True, | ||
"alibaba.zaproxy.org": True, | ||
"metadata.oraclecloud.com": True, | ||
} | ||
if metadata_local_link_map.get(request.host): | ||
raise web.HTTPForbidden() | ||
|
||
|
||
def reject_access_for_unsafe_file_policy(request: web.Request) -> None: | ||
unsafe_file_map = { | ||
"._darcs": True, | ||
".bzr": True, | ||
".hg": True, | ||
"BitKeeper": True, | ||
".bak": True, | ||
".log": True, | ||
".git": True, | ||
".svn": True, | ||
} | ||
file_name = request.path.split("/")[-1] | ||
if unsafe_file_map.get(file_name): | ||
raise web.HTTPForbidden() | ||
|
||
|
||
def add_self_content_security_policy(response: web.StreamResponse) -> web.StreamResponse: | ||
response.headers["Content-Security-Policy"] = ( | ||
"default-src 'self'; frame-ancestors 'none'; form-action 'self';" | ||
) | ||
return response | ||
|
||
|
||
def set_content_type_nosniff_policy(response: web.StreamResponse) -> web.StreamResponse: | ||
response.headers["X-Content-Type-Options"] = "nosniff" | ||
return response |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import pytest | ||
from aiohttp import web | ||
from aiohttp.test_utils import make_mocked_request | ||
from aiohttp.typedefs import Handler | ||
|
||
from ai.backend.web.security import ( | ||
SecurityPolicy, | ||
add_self_content_security_policy, | ||
reject_access_for_unsafe_file_policy, | ||
reject_metadata_local_link_policy, | ||
security_policy_middleware, | ||
set_content_type_nosniff_policy, | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def default_app(): | ||
app = web.Application(middlewares=[security_policy_middleware]) | ||
app["security_policy"] = SecurityPolicy.default_policy() | ||
return app | ||
|
||
|
||
@pytest.fixture | ||
async def async_handler() -> Handler: | ||
async def handler(request): | ||
return web.Response() | ||
|
||
return handler | ||
|
||
|
||
async def test_default_security_policy_reject_metadata_local_link( | ||
default_app, async_handler | ||
) -> None: | ||
request = make_mocked_request("GET", "/", headers={"Host": "169.254.169.254"}, app=default_app) | ||
with pytest.raises(web.HTTPForbidden): | ||
await security_policy_middleware(request, async_handler) | ||
|
||
|
||
async def test_default_security_policy_response(default_app, async_handler) -> None: | ||
request = make_mocked_request("GET", "/", headers={"Host": "localhost"}, app=default_app) | ||
response = await security_policy_middleware(request, async_handler) | ||
assert ( | ||
response.headers["Content-Security-Policy"] | ||
== "default-src 'self'; frame-ancestors 'none'; form-action 'self';" | ||
) | ||
assert response.headers["X-Content-Type-Options"] == "nosniff" | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"meta_local_link", | ||
[ | ||
"metadata.google.internal", | ||
"169.254.169.254", | ||
"100.100.100.200", | ||
"alibaba.zaproxy.org", | ||
"metadata.oraclecloud.com", | ||
], | ||
) | ||
async def test_reject_metadata_local_link_policy(async_handler, meta_local_link) -> None: | ||
test_app = web.Application() | ||
test_app["security_policy"] = SecurityPolicy( | ||
request_policies=[reject_metadata_local_link_policy], response_policies=[] | ||
) | ||
request = make_mocked_request("GET", "/", headers={"Host": meta_local_link}, app=test_app) | ||
with pytest.raises(web.HTTPForbidden): | ||
await security_policy_middleware(request, async_handler) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"url_suffix", | ||
[ | ||
"._darcs", | ||
".bzr", | ||
".hg", | ||
"BitKeeper", | ||
".bak", | ||
".log", | ||
".git", | ||
".svn", | ||
], | ||
) | ||
async def test_reject_access_for_unsafe_file_policy(async_handler, url_suffix) -> None: | ||
test_app = web.Application() | ||
test_app["security_policy"] = SecurityPolicy( | ||
request_policies=[reject_access_for_unsafe_file_policy], response_policies=[] | ||
) | ||
request = make_mocked_request( | ||
"GET", f"/{url_suffix}", headers={"Host": "localhost"}, app=test_app | ||
) | ||
with pytest.raises(web.HTTPForbidden): | ||
await security_policy_middleware(request, async_handler) | ||
|
||
|
||
async def test_add_self_content_security_policy(async_handler) -> None: | ||
test_app = web.Application() | ||
test_app["security_policy"] = SecurityPolicy( | ||
request_policies=[], response_policies=[add_self_content_security_policy] | ||
) | ||
request = make_mocked_request("GET", "/", headers={"Host": "localhost"}, app=test_app) | ||
response = await security_policy_middleware(request, async_handler) | ||
assert ( | ||
response.headers["Content-Security-Policy"] | ||
== "default-src 'self'; frame-ancestors 'none'; form-action 'self';" | ||
) | ||
|
||
|
||
async def test_set_content_type_nosniff_policy(async_handler) -> None: | ||
test_app = web.Application() | ||
test_app["security_policy"] = SecurityPolicy( | ||
request_policies=[], response_policies=[set_content_type_nosniff_policy] | ||
) | ||
request = make_mocked_request("GET", "/", headers={"Host": "localhost"}, app=test_app) | ||
response = await security_policy_middleware(request, async_handler) | ||
assert response.headers["X-Content-Type-Options"] == "nosniff" |