From b87c45b6af509f260539634a7ac7a1d338984dc5 Mon Sep 17 00:00:00 2001 From: jcarlyl Date: Fri, 28 Apr 2017 09:07:29 -0700 Subject: [PATCH] Reread the cors spec and the apigateway documentation and updated. I misunderstood the spec slightly, this update corrects that misunderstanding and increases code coverage. --- chalice/app.py | 21 ++++++++++++++------- chalice/app.pyi | 2 ++ chalice/deploy/swagger.py | 19 ++++++++----------- tests/unit/deploy/test_swagger.py | 23 ++++++++++++++++++----- tests/unit/test_local.py | 28 +++++++++++++++++++++++++++- 5 files changed, 69 insertions(+), 24 deletions(-) diff --git a/chalice/app.py b/chalice/app.py index 73549790b6..3b4f6f3dee 100644 --- a/chalice/app.py +++ b/chalice/app.py @@ -111,9 +111,9 @@ def __init__(self, allow_origin='*', allow_headers=None, self.allow_origin = allow_origin if allow_headers is None: - allow_headers = list(self._REQUIRED_HEADERS) + allow_headers = set(self._REQUIRED_HEADERS) else: - allow_headers.extend(self._REQUIRED_HEADERS) + allow_headers = set(allow_headers + self._REQUIRED_HEADERS) self._allow_headers = allow_headers if expose_headers is None: @@ -125,12 +125,12 @@ def __init__(self, allow_origin='*', allow_headers=None, @property def allow_headers(self): - return ','.join(self._allow_headers) + return ','.join(sorted(self._allow_headers)) def get_access_control_headers(self): headers = { 'Access-Control-Allow-Origin': self.allow_origin, - 'Access-Control-Allow-Headers': ','.join(self._allow_headers), + 'Access-Control-Allow-Headers': self.allow_headers } if self._expose_headers: headers.update({ @@ -140,9 +140,9 @@ def get_access_control_headers(self): headers.update({ 'Access-Control-Max-Age': str(self._max_age) }) - if self._allow_credentials is not None: + if self._allow_credentials is True: headers.update({ - 'Access-Control-Allow-Credentials': self._allow_credentials + 'Access-Control-Allow-Credentials': 'true' }) return headers @@ -215,6 +215,11 @@ def __init__(self, view_function, view_name, path, methods, #: e.g, '/foo/{bar}/{baz}/qux -> ['bar', 'baz'] self.view_args = self._parse_view_args() self.content_types = content_types + # cors is passed as either a boolean or a CORSConfig object. If it is a + # boolean it needs to be replaced with a real CORSConfig object to + # pass the typechecker. None in this context will not inject any cors + # headers, otherwise the CORSConfig object will determine which + # headers are injected. if cors is True: cors = CORSConfig() elif cors is False: @@ -409,4 +414,6 @@ def _cors_enabled_for_route(self, route_entry): return route_entry.cors is not None def _add_cors_headers(self, response, cors): - response.headers.update(cors.get_access_control_headers()) + for name, value in cors.get_access_control_headers().items(): + if name not in response.headers: + response.headers[name] = value diff --git a/chalice/app.pyi b/chalice/app.pyi index 456ae963dd..e3003e7e63 100644 --- a/chalice/app.pyi +++ b/chalice/app.pyi @@ -17,6 +17,8 @@ ALL_ERRORS = ... # type: List[ChaliceViewError] class CORSConfig: allow_origin = ... # type: str allow_headers = ... # type: str + get_access_control_headers = ... # type: Callable[..., Dict[str, str]] + class Request: query_params = ... # type: Dict[str, str] diff --git a/chalice/deploy/swagger.py b/chalice/deploy/swagger.py index 55113ef9c9..ccbe14ebb8 100644 --- a/chalice/deploy/swagger.py +++ b/chalice/deploy/swagger.py @@ -144,14 +144,15 @@ def _add_preflight_request(self, view, swagger_for_path): cors = view.cors methods = view.methods + ['OPTIONS'] allowed_methods = ','.join(methods) + response_params = { - "method.response.header.Access-Control-Allow-Methods": ( - "'%s'" % allowed_methods), - "method.response.header.Access-Control-Allow-Origin": ( - "'%s'" % cors.allow_origin), - "method.response.header.Access-Control-Allow-Headers": ( - "'%s'" % cors.allow_headers) + 'Access-Control-Allow-Methods': '%s' % allowed_methods } + response_params.update(cors.get_access_control_headers()) + + headers = {k: {'type': 'string'} for k, _ in response_params.items()} + response_params = {'method.response.header.%s' % k: "'%s'" % v for k, v + in response_params.items()} options_request = { "consumes": ["application/json"], @@ -160,11 +161,7 @@ def _add_preflight_request(self, view, swagger_for_path): "200": { "description": "200 response", "schema": {"$ref": "#/definitions/Empty"}, - "headers": { - "Access-Control-Allow-Origin": {"type": "string"}, - "Access-Control-Allow-Methods": {"type": "string"}, - "Access-Control-Allow-Headers": {"type": "string"}, - } + "headers": headers } }, "x-amazon-apigateway-integration": { diff --git a/tests/unit/deploy/test_swagger.py b/tests/unit/deploy/test_swagger.py index de5adae3a7..8ffb94cfcb 100644 --- a/tests/unit/deploy/test_swagger.py +++ b/tests/unit/deploy/test_swagger.py @@ -74,7 +74,10 @@ def multiple_methods(): def test_can_add_preflight_cors(sample_app, swagger_gen): @sample_app.route('/cors', methods=['GET', 'POST'], cors=CORSConfig( allow_origin='http://foo.com', - allow_headers=['X-Special-Header'])) + allow_headers=['X-ZZ-Top', 'X-Special-Header'], + expose_headers=['X-Exposed', 'X-Special'], + max_age=600, + allow_credentials=True)) def cors_request(): pass @@ -88,10 +91,17 @@ def cors_request(): 'method.response.header.Access-Control-Allow-Methods': ( "'GET,POST,OPTIONS'"), 'method.response.header.Access-Control-Allow-Headers': ( - "'X-Special-Header,Content-Type,X-Amz-Date,Authorization," - "X-Api-Key,X-Amz-Security-Token'"), + "'Authorization,Content-Type,X-Amz-Date,X-Amz-Security-Token," + "X-Api-Key,X-Special-Header,X-ZZ-Top'"), 'method.response.header.Access-Control-Allow-Origin': ( "'http://foo.com'"), + 'method.response.header.Access-Control-Expose-Headers': ( + "'X-Exposed,X-Special'"), + 'method.response.header.Access-Control-Max-Age': ( + "'600'"), + 'method.response.header.Access-Control-Allow-Credentials': ( + "'true'"), + } assert options == { 'consumes': ['application/json'], @@ -106,6 +116,9 @@ def cors_request(): 'Access-Control-Allow-Origin': {'type': 'string'}, 'Access-Control-Allow-Methods': {'type': 'string'}, 'Access-Control-Allow-Headers': {'type': 'string'}, + 'Access-Control-Expose-Headers': {'type': 'string'}, + 'Access-Control-Max-Age': {'type': 'string'}, + 'Access-Control-Allow-Credentials': {'type': 'string'}, } } }, @@ -140,8 +153,8 @@ def cors_request(): 'method.response.header.Access-Control-Allow-Methods': ( "'GET,POST,OPTIONS'"), 'method.response.header.Access-Control-Allow-Headers': ( - "'Content-Type,X-Amz-Date,Authorization," - "X-Api-Key,X-Amz-Security-Token'"), + "'Authorization,Content-Type,X-Amz-Date,X-Amz-Security-Token," + "X-Api-Key'"), 'method.response.header.Access-Control-Allow-Origin': "'*'", } assert options == { diff --git a/tests/unit/test_local.py b/tests/unit/test_local.py index 42a7a6d9bb..841a23d1c6 100644 --- a/tests/unit/test_local.py +++ b/tests/unit/test_local.py @@ -1,4 +1,4 @@ -from chalice import local, BadRequestError +from chalice import local, BadRequestError, CORSConfig import json import decimal import pytest @@ -42,6 +42,16 @@ def put(): def cors(): return {'cors': True} + @demo.route('/custom_cors', methods=['GET', 'PUT'], cors=CORSConfig( + allow_origin='https://foo.bar', + allow_headers=['Header-A', 'Header-B'], + expose_headers=['Header-A', 'Header-B'], + max_age=600, + allow_credentials=True + )) + def custom_cors(): + return {'cors': True} + @demo.route('/options', methods=['OPTIONS']) def options(): return {'options': True} @@ -129,6 +139,22 @@ def test_will_respond_with_cors_enabled(handler): assert b'Access-Control-Allow-Origin: *' in response_lines +def test_will_respond_with_custom_cors_enabled(handler): + headers = {'content-type': 'application/json', 'origin': 'null'} + set_current_request(handler, method='GET', path='/custom_cors', + headers=headers) + handler.do_GET() + response = handler.wfile.getvalue().splitlines() + print(response) + assert b'Access-Control-Allow-Origin: https://foo.bar' in response + assert (b'Access-Control-Allow-Headers: Authorization,Content-Type,' + b'Header-A,Header-B,X-Amz-Date,X-Amz-Security-Token,' + b'X-Api-Key') in response + assert b'Access-Control-Expose-Headers: Header-A,Header-B' in response + assert b'Access-Control-Max-Age: 600' in response + assert b'Access-Control-Allow-Credentials: true' in response + + def test_can_preflight_request(handler): headers = {'content-type': 'application/json', 'origin': 'null'} set_current_request(handler, method='OPTIONS', path='/cors',