diff --git a/samcli/commands/local/lib/local_api_service.py b/samcli/commands/local/lib/local_api_service.py index df15675191..a15585c03d 100644 --- a/samcli/commands/local/lib/local_api_service.py +++ b/samcli/commands/local/lib/local_api_service.py @@ -100,7 +100,7 @@ def _make_routing_list(api_provider): routes = [] for api in api_provider.get_all(): route = Route(methods=[api.method], function_name=api.function_name, path=api.path, - binary_types=api.binary_media_types) + binary_types=api.binary_media_types, cors=api.cors) routes.append(route) return routes diff --git a/samcli/commands/local/lib/provider.py b/samcli/commands/local/lib/provider.py index c8265c81e6..324b577cdd 100644 --- a/samcli/commands/local/lib/provider.py +++ b/samcli/commands/local/lib/provider.py @@ -229,7 +229,7 @@ def __hash__(self): return hash(self.path) * hash(self.method) * hash(self.function_name) -Cors = namedtuple("Cors", ["AllowOrigin", "AllowMethods", "AllowHeaders"]) +Cors = namedtuple("Cors", ["allow_origin", "allow_methods", "allow_headers", "max_age"]) class ApiProvider(object): diff --git a/samcli/commands/local/lib/sam_api_provider.py b/samcli/commands/local/lib/sam_api_provider.py index 19b0559a3b..fb4584ebdb 100644 --- a/samcli/commands/local/lib/sam_api_provider.py +++ b/samcli/commands/local/lib/sam_api_provider.py @@ -6,7 +6,7 @@ from six import string_types from samcli.commands.local.lib.swagger.parser import SwaggerParser -from samcli.commands.local.lib.provider import ApiProvider, Api +from samcli.commands.local.lib.provider import ApiProvider, Api, Cors from samcli.commands.local.lib.sam_base_provider import SamBaseProvider from samcli.commands.local.lib.swagger.reader import SamSwaggerReader from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException @@ -127,6 +127,7 @@ def _extract_from_serverless_api(self, logical_id, api_resource, collector): body = properties.get("DefinitionBody") uri = properties.get("DefinitionUri") binary_media = properties.get("BinaryMediaTypes", []) + cors = properties.get("Cors", {}) if not body and not uri: # Swagger is not found anywhere. @@ -146,6 +147,13 @@ def _extract_from_serverless_api(self, logical_id, api_resource, collector): collector.add_binary_media_types(logical_id, parser.get_binary_media_types()) # Binary media from swagger collector.add_binary_media_types(logical_id, binary_media) # Binary media specified on resource in template + collector.add_cors(logical_id, Cors( + allow_origin=cors.get("AllowOrigin"), + allow_methods=cors.get("AllowMethods"), + allow_headers=cors.get("AllowHeaders"), + max_age=cors.get("MaxAge") + )) + @staticmethod def _merge_apis(collector): """ @@ -387,6 +395,23 @@ def add_binary_media_types(self, logical_id, binary_media_types): else: LOG.debug("Unsupported data type of binary media type value of resource '%s'", logical_id) + def add_cors(self, logical_id, cors): + """ + Stores the CORS configuration for the API with given logical ID + + Parameters + ---------- + logical_id : str + LogicalId of the AWS::Serverless::Api resource + + apis : samcli.commands.local.lib.provider.Cors + CORS configuration for the given resource + """ + old_properties = self._get_properties(logical_id) + self.by_resource[logical_id] = self.Properties(apis=old_properties.apis, + binary_media_types=old_properties.binary_media_types, + cors=cors) + def _get_apis_with_config(self, logical_id): """ Returns the list of APIs in this resource along with other extra configuration such as binary media types, diff --git a/samcli/local/apigw/local_apigw_service.py b/samcli/local/apigw/local_apigw_service.py index fb71e0ed6a..2afb165b7f 100644 --- a/samcli/local/apigw/local_apigw_service.py +++ b/samcli/local/apigw/local_apigw_service.py @@ -18,18 +18,20 @@ class Route(object): - def __init__(self, methods, function_name, path, binary_types=None): + def __init__(self, methods, function_name, path, binary_types=None, cors=None): """ Creates an ApiGatewayRoute :param list(str) methods: List of HTTP Methods :param function_name: Name of the Lambda function this API is connected to :param str path: Path off the base url + :param Cors cors: Cors configuration for the route """ self.methods = methods self.function_name = function_name self.path = path self.binary_types = binary_types or [] + self.cors = cors class LocalApigwService(BaseLocalService): @@ -43,7 +45,7 @@ def __init__(self, routing_list, lambda_runner, static_dir=None, port=None, host Parameters ---------- - routing_list list(ApiGatewayCallModel) + routing_list list(Route) A list of the Model that represent the service paths to create. lambda_runner samcli.commands.local.lib.local_lambda.LocalLambdaRunner The Lambda runner class capable of invoking the function @@ -81,10 +83,18 @@ def create(self): path): self._dict_of_routes[route_key] = api_gateway_route + methods = api_gateway_route.methods.copy() + if api_gateway_route.cors is not None: + methods.append('OPTIONS') + + route_key = self._route_key('OPTIONS', api_gateway_route.path) + if route_key not in self._dict_of_routes: + self._dict_of_routes[route_key] = api_gateway_route + self._app.add_url_rule(path, endpoint=path, view_func=self._request_handler, - methods=api_gateway_route.methods, + methods=methods, provide_automatic_options=False) self._construct_error_handling() @@ -141,6 +151,9 @@ def _request_handler(self, **kwargs): """ route = self._get_current_route(request) + if request.method == 'OPTIONS': + return self.service_response('', LocalApigwService._cors_to_headers(route.cors), 200) + try: event = self._construct_event(request, self.port, route.binary_types) except UnicodeDecodeError: @@ -358,3 +371,31 @@ def _should_base64_encode(binary_types, request_mimetype): """ return request_mimetype in binary_types or "*/*" in binary_types + + + @staticmethod + def _cors_to_headers(cors): + """ + Convert CORS object to headers dictionary + + Parameters + ---------- + cors list(samcli.commands.local.lib.provider.Cors) + CORS configuration objcet + + Returns + ------- + Dictionary with CORS headers + + """ + headers = {} + if cors.allow_origin is not None: + headers['Access-Control-Allow-Origin'] = cors.allow_origin[1:-1] + if cors.allow_methods is not None: + headers['Access-Control-Allow-Methods'] = cors.allow_methods[1:-1] + if cors.allow_headers is not None: + headers['Access-Control-Allow-Headers'] = cors.allow_headers[1:-1] + if cors.max_age is not None: + headers['Access-Control-Max-Age'] = cors.max_age[1:-1] + + return headers