Skip to content

Commit

Permalink
Fix CORS support for exceptions.
Browse files Browse the repository at this point in the history
  • Loading branch information
almet committed Mar 8, 2013
1 parent 26606e4 commit 21fde22
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 57 deletions.
5 changes: 3 additions & 2 deletions cornice/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from cornice.service import Service # NOQA
from cornice.pyramidhook import (
wrap_request,
register_service_views
register_service_views,
handle_exceptions
)

logger = logging.getLogger('cornice')
Expand All @@ -35,5 +36,5 @@ def includeme(config):
config.add_directive('add_cornice_service', register_service_views)
config.add_subscriber(add_renderer_globals, BeforeRender)
config.add_subscriber(wrap_request, NewRequest)
config.add_tween('cornice.pyramidhook.tween_factory')
config.add_renderer('simplejson', util.json_renderer)
config.add_view(handle_exceptions, context=Exception)
54 changes: 26 additions & 28 deletions cornice/cors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import fnmatch
import functools


CORS_PARAMETERS = ('cors_headers', 'cors_enabled', 'cors_origins',
Expand Down Expand Up @@ -78,19 +79,14 @@ def _get_method(request):
return method


def get_cors_validator(service):
"""Create a cornice validator to handle CORS-related verifications.
Checks, if an "Origin" header is present, that the origin is authorized
(and issue an error if not)
"""
def ensure_origin(service, request, response=None):
"""Ensure that the origin header is set and allowed."""
response = response or request.response

def _cors_validator(request):
response = request.response
# Don't check this twice.
if not request.info.get('cors_checked', False):
method = _get_method(request)

# If we have an "Origin" header, check it's authorized and add the
# response headers accordingly.
origin = request.headers.get('Origin')
if origin:
if not any([fnmatch.fnmatchcase(origin, o)
Expand All @@ -99,30 +95,32 @@ def _cors_validator(request):
'%s not allowed' % origin)
else:
response.headers['Access-Control-Allow-Origin'] = origin
return _cors_validator
request.info['cors_checked'] = True
return response


def get_cors_validator(service):
return functools.partial(ensure_origin, service)


def get_cors_filter(service):
"""Create a cornice filter to handle CORS-related post-request
things.
def apply_cors_post_request(service, request, response):
"""Handles CORS-related post-request things.
Add some response headers, such as the Expose-Headers and the
Allow-Credentials ones.
"""
response = ensure_origin(service, request, response)
method = _get_method(request)

def _cors_filter(response, request):
method = _get_method(request)
if (service.cors_support_credentials(method) and
not 'Access-Control-Allow-Credentials' in response.headers):
response.headers['Access-Control-Allow-Credentials'] = 'true'

if (service.cors_support_credentials(method) and
not 'Access-Control-Allow-Credentials' in response.headers):
response.headers['Access-Control-Allow-Credentials'] = 'true'

if request.method is not 'OPTIONS':
# Which headers are exposed?
supported_headers = service.cors_supported_headers
if supported_headers:
response.headers['Access-Control-Expose-Headers'] = (
', '.join(supported_headers))
if request.method is not 'OPTIONS':
# Which headers are exposed?
supported_headers = service.cors_supported_headers
if supported_headers:
response.headers['Access-Control-Expose-Headers'] = (
', '.join(supported_headers))

return response
return _cors_filter
return response
59 changes: 35 additions & 24 deletions cornice/pyramidhook.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@
from cornice.service import decorate_view
from cornice.errors import Errors
from cornice.util import is_string, to_list
from cornice.cors import (get_cors_filter, get_cors_validator,
get_cors_preflight_view, CORS_PARAMETERS)
from cornice.cors import (
get_cors_validator,
get_cors_preflight_view,
apply_cors_post_request,
CORS_PARAMETERS
)


def match_accept_header(func, info, request):
Expand Down Expand Up @@ -82,33 +86,42 @@ def _fallback_view(request):
return _fallback_view


def tween_factory(handler, registry):
"""Wraps the default WSGI workflow to provide cornice utilities"""
def cornice_tween(request):
response = handler(request)
if request.matched_route is not None:
# do some sanity checking on the response using filters
services = request.registry.get('cornice_services', {})
pattern = request.matched_route.pattern
service = services.get(pattern, None)
if service is not None:
kwargs, ob = getattr(request, "cornice_args", ({}, None))
for _filter in kwargs.get('filters', []):
if is_string(_filter) and ob is not None:
_filter = getattr(ob, _filter)
try:
response = _filter(response, request)
except TypeError:
response = _filter(response)
return response
return cornice_tween
def apply_filters(request, response):
if request.matched_route is not None:
# do some sanity checking on the response using filters
services = request.registry.get('cornice_services', {})
pattern = request.matched_route.pattern
service = services.get(pattern, None)
if service is not None:
kwargs, ob = getattr(request, "cornice_args", ({}, None))
for _filter in kwargs.get('filters', []):
if is_string(_filter) and ob is not None:
_filter = getattr(ob, _filter)
try:
response = _filter(response, request)
except TypeError:
response = _filter(response)

if service.cors_enabled:
apply_cors_post_request(service, request, response)

return response


def handle_exceptions(exc, request):
# At this stage, the checks done by the validators had been removed because
# a new response started (the exception), so we need to do that again.
request.info['cors_checked'] = False
return apply_filters(request, exc)


def wrap_request(event):
"""Adds a "validated" dict, a custom "errors" object and an "info" dict to
the request object if they don't already exists
"""
request = event.request
request.add_response_callback(apply_filters)

if not hasattr(request, 'validated'):
setattr(request, 'validated', {})

Expand Down Expand Up @@ -139,7 +152,6 @@ def register_service_views(config, service):
# register the fallback view, which takes care of returning good error
# messages to the user-agent
cors_validator = get_cors_validator(service)
cors_filter = get_cors_filter(service)

for method, view, args in service.definitions:

Expand All @@ -148,7 +160,6 @@ def register_service_views(config, service):

if service.cors_enabled:
args['validators'].insert(0, cors_validator)
args['filters'].append(cors_filter)

decorated_view = decorate_view(view, dict(args), method)

Expand Down
3 changes: 3 additions & 0 deletions cornice/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,9 @@ def wrapper(request):

# check for errors and return them if any
if len(request.errors) > 0:
# We already checked for CORS, but since the response is created
# again, we want to do that again before returning the response.
request.info['cors_checked'] = False
return args['error_handler'](request.errors)

# We can't apply filters at this level, since "response" may not have
Expand Down
5 changes: 2 additions & 3 deletions cornice/tests/test_cors.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,9 @@ def is_bacon_good(request):

@bacon.get(validators=is_bacon_good)
def get_some_bacon(request):
# if you got here, the only kind of bacon existing is 'good'.
# Okay, you there. Bear in mind, the only kind of bacon existing is 'good'.
if request.matchdict['type'] != 'good':
raise NotFound('Not. Found.')

return "yay"


Expand Down Expand Up @@ -208,7 +207,7 @@ def test_preflight_headers_arent_case_sensitive(self):
'Access-Control-Request-Method': 'GET',
'Access-Control-Request-Headers': 'x-my-header', })

def test_400_return_CORS_headers(self):
def test_400_returns_CORS_headers(self):
resp = self.app.get('/bacon/not', status=400,
headers={'Origin': 'notmyidea.org'})
self.assertIn('Access-Control-Allow-Origin', resp.headers)
Expand Down

0 comments on commit 21fde22

Please sign in to comment.