Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add decorator #899

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 39 additions & 59 deletions src/corsheaders/conf.py
Original file line number Diff line number Diff line change
@@ -1,77 +1,57 @@
from __future__ import annotations

from typing import cast
from typing import List
import re
from dataclasses import dataclass
from typing import Any
from typing import Pattern
from typing import Sequence
from typing import Tuple
from typing import Union

from django.conf import settings
from django.conf import settings as _django_settings

from corsheaders.defaults import default_headers
from corsheaders.defaults import default_methods

# Kept here for backwards compatibility


@dataclass
class Settings:
CORS_ALLOW_HEADERS: Sequence[str] = default_headers
CORS_ALLOW_METHODS: Sequence[str] = default_methods
CORS_ALLOW_CREDENTIALS: bool = False
CORS_ALLOW_PRIVATE_NETWORK: bool = False
CORS_PREFLIGHT_MAX_AGE: int = 86400
CORS_ALLOW_ALL_ORIGINS: bool = False
CORS_ALLOWED_ORIGINS: list[str] | tuple[str] = () # type: ignore
CORS_ALLOWED_ORIGIN_REGEXES: Sequence[str | Pattern[str]] = ()
CORS_EXPOSE_HEADERS: Sequence[str] = ()
CORS_URLS_REGEX: str | Pattern[str] = re.compile(r"^.*$")


_RENAMED_SETTINGS = {
# New name -> Old name
"CORS_ALLOW_ALL_ORIGINS": "CORS_ORIGIN_ALLOW_ALL",
"CORS_ALLOWED_ORIGINS": "CORS_ORIGIN_WHITELIST",
"CORS_ALLOWED_ORIGIN_REGEXES": "CORS_ORIGIN_REGEX_WHITELIST",
}


class DjangoConfig(Settings):
"""
Shadow Django's settings with a little logic
"""

@property
def CORS_ALLOW_HEADERS(self) -> Sequence[str]:
return getattr(settings, "CORS_ALLOW_HEADERS", default_headers)

@property
def CORS_ALLOW_METHODS(self) -> Sequence[str]:
return getattr(settings, "CORS_ALLOW_METHODS", default_methods)

@property
def CORS_ALLOW_CREDENTIALS(self) -> bool:
return getattr(settings, "CORS_ALLOW_CREDENTIALS", False)

@property
def CORS_ALLOW_PRIVATE_NETWORK(self) -> bool:
return getattr(settings, "CORS_ALLOW_PRIVATE_NETWORK", False)
A version of Settings that prefers to read from Django's settings.

@property
def CORS_PREFLIGHT_MAX_AGE(self) -> int:
return getattr(settings, "CORS_PREFLIGHT_MAX_AGE", 86400)

@property
def CORS_ALLOW_ALL_ORIGINS(self) -> bool:
return getattr(
settings,
"CORS_ALLOW_ALL_ORIGINS",
getattr(settings, "CORS_ORIGIN_ALLOW_ALL", False),
)

@property
def CORS_ALLOWED_ORIGINS(self) -> list[str] | tuple[str]:
value = getattr(
settings,
"CORS_ALLOWED_ORIGINS",
getattr(settings, "CORS_ORIGIN_WHITELIST", ()),
)
return cast(Union[List[str], Tuple[str]], value)

@property
def CORS_ALLOWED_ORIGIN_REGEXES(self) -> Sequence[str | Pattern[str]]:
return getattr(
settings,
"CORS_ALLOWED_ORIGIN_REGEXES",
getattr(settings, "CORS_ORIGIN_REGEX_WHITELIST", ()),
)

@property
def CORS_EXPOSE_HEADERS(self) -> Sequence[str]:
return getattr(settings, "CORS_EXPOSE_HEADERS", ())
Falls back to its own values if the setting is not configured
in Django.
"""

@property
def CORS_URLS_REGEX(self) -> str | Pattern[str]:
return getattr(settings, "CORS_URLS_REGEX", r"^.*$")
def __getattribute__(self, name: str) -> Any:
default = object.__getattribute__(self, name)
if name in _RENAMED_SETTINGS:
# Renamed settings are used if the new setting
# is not configured in Django,
old_name = _RENAMED_SETTINGS[name]
default = getattr(_django_settings, old_name, default)
return getattr(_django_settings, name, default)


conf = Settings()
conf = DjangoConfig()
46 changes: 46 additions & 0 deletions src/corsheaders/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from __future__ import annotations

import asyncio
import functools
from typing import Any
from typing import Callable
from typing import cast
from typing import TypeVar

from django.http import HttpRequest
from django.http.response import HttpResponseBase

from corsheaders.conf import conf as _conf
from corsheaders.conf import Settings
from corsheaders.middleware import CorsMiddleware

F = TypeVar("F", bound=Callable[..., HttpResponseBase])


def cors(func: F | None = None, *, conf: Settings = _conf) -> F | Callable[[F], F]:
if func is None:
return cast(Callable[[F], F], functools.partial(cors, conf=conf))

assert callable(func)

if asyncio.iscoroutinefunction(func):

async def inner(
_request: HttpRequest, *args: Any, **kwargs: Any
) -> HttpResponseBase:
async def get_response(request: HttpRequest) -> HttpResponseBase:
return await func(request, *args, **kwargs)

return await CorsMiddleware(get_response, conf=conf)(_request)

else:

def inner(_request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponseBase:
def get_response(request: HttpRequest) -> HttpResponseBase:
return func(request, *args, **kwargs)

return CorsMiddleware(get_response, conf=conf)(_request)

wrapper = functools.wraps(func)(inner)
wrapper._skip_cors_middleware = True # type: ignore [attr-defined]
return cast(F, wrapper)
83 changes: 55 additions & 28 deletions src/corsheaders/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import re
from typing import Any
from typing import Awaitable
from typing import Callable
from urllib.parse import SplitResult
Expand All @@ -13,6 +14,7 @@
from django.utils.cache import patch_vary_headers

from corsheaders.conf import conf
from corsheaders.conf import Settings
from corsheaders.signals import check_request_enabled

ACCESS_CONTROL_ALLOW_ORIGIN = "access-control-allow-origin"
Expand All @@ -35,8 +37,10 @@ def __init__(
Callable[[HttpRequest], HttpResponseBase]
| Callable[[HttpRequest], Awaitable[HttpResponseBase]]
),
conf: Settings = conf,
) -> None:
self.get_response = get_response
self.conf = conf
if asyncio.iscoroutinefunction(self.get_response):
# Mark the class as async-capable, but do the actual switch
# inside __call__ to avoid swapping out dunder methods
Expand All @@ -51,22 +55,40 @@ def __call__(
) -> HttpResponseBase | Awaitable[HttpResponseBase]:
if self._is_coroutine:
return self.__acall__(request)
response: HttpResponseBase | None = self.check_preflight(request)
if response is None:
result = self.get_response(request)
assert isinstance(result, HttpResponseBase)
response = result
self.add_response_headers(request, response)
return response
result = self.get_response(request)
assert isinstance(result, HttpResponseBase)
response = result
if getattr(response, "_cors_processing_done", False):
return response
else:
# Request wasn't processed (e.g. because of a 404)
return self.add_response_headers(
request, self.check_preflight(request) or response
)

async def __acall__(self, request: HttpRequest) -> HttpResponseBase:
response = self.check_preflight(request)
if response is None:
result = self.get_response(request)
assert not isinstance(result, HttpResponseBase)
response = await result
self.add_response_headers(request, response)
return response
result = self.get_response(request)
assert not isinstance(result, HttpResponseBase)
response = await result
if getattr(response, "_cors_processing_done", False):
return response
else:
# View wasn't processed (e.g. because of a 404)
return self.add_response_headers(
request, self.check_preflight(request) or response
)

def process_view(
self,
request: HttpRequest,
callback: Callable[[HttpRequest], HttpResponseBase],
callback_args: Any,
callback_kwargs: Any,
) -> HttpResponseBase | None:
if getattr(callback, "_skip_cors_middleware", False):
# View is decorated and will add CORS headers itself
return None
return self.check_preflight(request)

def check_preflight(self, request: HttpRequest) -> HttpResponseBase | None:
"""
Expand All @@ -87,6 +109,7 @@ def add_response_headers(
"""
Add the respective CORS headers
"""
response._cors_processing_done = True
enabled = getattr(request, "_cors_enabled", None)
if enabled is None:
enabled = self.is_enabled(request)
Expand All @@ -105,34 +128,38 @@ def add_response_headers(
except ValueError:
return response

if conf.CORS_ALLOW_CREDENTIALS:
if self.conf.CORS_ALLOW_CREDENTIALS:
response[ACCESS_CONTROL_ALLOW_CREDENTIALS] = "true"

if (
not conf.CORS_ALLOW_ALL_ORIGINS
not self.conf.CORS_ALLOW_ALL_ORIGINS
and not self.origin_found_in_white_lists(origin, url)
and not self.check_signal(request)
):
return response

if conf.CORS_ALLOW_ALL_ORIGINS and not conf.CORS_ALLOW_CREDENTIALS:
if self.conf.CORS_ALLOW_ALL_ORIGINS and not self.conf.CORS_ALLOW_CREDENTIALS:
response[ACCESS_CONTROL_ALLOW_ORIGIN] = "*"
else:
response[ACCESS_CONTROL_ALLOW_ORIGIN] = origin

if len(conf.CORS_EXPOSE_HEADERS):
if len(self.conf.CORS_EXPOSE_HEADERS):
response[ACCESS_CONTROL_EXPOSE_HEADERS] = ", ".join(
conf.CORS_EXPOSE_HEADERS
self.conf.CORS_EXPOSE_HEADERS
)

if request.method == "OPTIONS":
response[ACCESS_CONTROL_ALLOW_HEADERS] = ", ".join(conf.CORS_ALLOW_HEADERS)
response[ACCESS_CONTROL_ALLOW_METHODS] = ", ".join(conf.CORS_ALLOW_METHODS)
if conf.CORS_PREFLIGHT_MAX_AGE:
response[ACCESS_CONTROL_MAX_AGE] = str(conf.CORS_PREFLIGHT_MAX_AGE)
response[ACCESS_CONTROL_ALLOW_HEADERS] = ", ".join(
self.conf.CORS_ALLOW_HEADERS
)
response[ACCESS_CONTROL_ALLOW_METHODS] = ", ".join(
self.conf.CORS_ALLOW_METHODS
)
if self.conf.CORS_PREFLIGHT_MAX_AGE:
response[ACCESS_CONTROL_MAX_AGE] = str(self.conf.CORS_PREFLIGHT_MAX_AGE)

if (
conf.CORS_ALLOW_PRIVATE_NETWORK
self.conf.CORS_ALLOW_PRIVATE_NETWORK
and request.headers.get(ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK) == "true"
):
response[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] = "true"
Expand All @@ -141,28 +168,28 @@ def add_response_headers(

def origin_found_in_white_lists(self, origin: str, url: SplitResult) -> bool:
return (
(origin == "null" and origin in conf.CORS_ALLOWED_ORIGINS)
(origin == "null" and origin in self.conf.CORS_ALLOWED_ORIGINS)
or self._url_in_whitelist(url)
or self.regex_domain_match(origin)
)

def regex_domain_match(self, origin: str) -> bool:
return any(
re.match(domain_pattern, origin)
for domain_pattern in conf.CORS_ALLOWED_ORIGIN_REGEXES
for domain_pattern in self.conf.CORS_ALLOWED_ORIGIN_REGEXES
)

def is_enabled(self, request: HttpRequest) -> bool:
return bool(
re.match(conf.CORS_URLS_REGEX, request.path_info)
re.match(self.conf.CORS_URLS_REGEX, request.path_info)
) or self.check_signal(request)

def check_signal(self, request: HttpRequest) -> bool:
signal_responses = check_request_enabled.send(sender=None, request=request)
return any(return_value for function, return_value in signal_responses)

def _url_in_whitelist(self, url: SplitResult) -> bool:
origins = [urlsplit(o) for o in conf.CORS_ALLOWED_ORIGINS]
origins = [urlsplit(o) for o in self.conf.CORS_ALLOWED_ORIGINS]
return any(
origin.scheme == url.scheme and origin.netloc == url.netloc
for origin in origins
Expand Down
5 changes: 5 additions & 0 deletions tests/test_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@


class ConfTests(SimpleTestCase):
@override_settings(SECRET_KEY="foo")
def test_other_setting(self):
# Only proxy settings that are defined in the Settings class.
self.assertRaises(AttributeError, getattr, conf, "SECRET_KEY")

@override_settings(CORS_ALLOW_HEADERS=["foo"])
def test_can_override(self):
assert conf.CORS_ALLOW_HEADERS == ["foo"]
Expand Down
Loading