Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion samcli/commands/local/lib/local_api_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def _make_routing_list(api_provider):
routes = []
for api in api_provider.get_all():
route = Route(methods=[api.method], function_name=api.function_name, path=api.path,
binary_types=api.binary_media_types)
binary_types=api.binary_media_types, cors=api.cors)
routes.append(route)

return routes
Expand Down
2 changes: 1 addition & 1 deletion samcli/commands/local/lib/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def __hash__(self):
return hash(self.path) * hash(self.method) * hash(self.function_name)


Cors = namedtuple("Cors", ["AllowOrigin", "AllowMethods", "AllowHeaders"])
Cors = namedtuple("Cors", ["allow_origin", "allow_methods", "allow_headers", "max_age"])


class ApiProvider(object):
Expand Down
27 changes: 26 additions & 1 deletion samcli/commands/local/lib/sam_api_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from six import string_types

from samcli.commands.local.lib.swagger.parser import SwaggerParser
from samcli.commands.local.lib.provider import ApiProvider, Api
from samcli.commands.local.lib.provider import ApiProvider, Api, Cors
from samcli.commands.local.lib.sam_base_provider import SamBaseProvider
from samcli.commands.local.lib.swagger.reader import SamSwaggerReader
from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException
Expand Down Expand Up @@ -127,6 +127,7 @@ def _extract_from_serverless_api(self, logical_id, api_resource, collector):
body = properties.get("DefinitionBody")
uri = properties.get("DefinitionUri")
binary_media = properties.get("BinaryMediaTypes", [])
cors = properties.get("Cors", {})

if not body and not uri:
# Swagger is not found anywhere.
Expand All @@ -146,6 +147,13 @@ def _extract_from_serverless_api(self, logical_id, api_resource, collector):
collector.add_binary_media_types(logical_id, parser.get_binary_media_types()) # Binary media from swagger
collector.add_binary_media_types(logical_id, binary_media) # Binary media specified on resource in template

collector.add_cors(logical_id, Cors(
allow_origin=cors.get("AllowOrigin"),
allow_methods=cors.get("AllowMethods"),
allow_headers=cors.get("AllowHeaders"),
max_age=cors.get("MaxAge")
))

@staticmethod
def _merge_apis(collector):
"""
Expand Down Expand Up @@ -387,6 +395,23 @@ def add_binary_media_types(self, logical_id, binary_media_types):
else:
LOG.debug("Unsupported data type of binary media type value of resource '%s'", logical_id)

def add_cors(self, logical_id, cors):
"""
Stores the CORS configuration for the API with given logical ID

Parameters
----------
logical_id : str
LogicalId of the AWS::Serverless::Api resource

apis : samcli.commands.local.lib.provider.Cors
CORS configuration for the given resource
"""
old_properties = self._get_properties(logical_id)
self.by_resource[logical_id] = self.Properties(apis=old_properties.apis,
binary_media_types=old_properties.binary_media_types,
cors=cors)

def _get_apis_with_config(self, logical_id):
"""
Returns the list of APIs in this resource along with other extra configuration such as binary media types,
Expand Down
47 changes: 44 additions & 3 deletions samcli/local/apigw/local_apigw_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,20 @@

class Route(object):

def __init__(self, methods, function_name, path, binary_types=None):
def __init__(self, methods, function_name, path, binary_types=None, cors=None):
"""
Creates an ApiGatewayRoute

:param list(str) methods: List of HTTP Methods
:param function_name: Name of the Lambda function this API is connected to
:param str path: Path off the base url
:param Cors cors: Cors configuration for the route
"""
self.methods = methods
self.function_name = function_name
self.path = path
self.binary_types = binary_types or []
self.cors = cors


class LocalApigwService(BaseLocalService):
Expand All @@ -43,7 +45,7 @@ def __init__(self, routing_list, lambda_runner, static_dir=None, port=None, host

Parameters
----------
routing_list list(ApiGatewayCallModel)
routing_list list(Route)
A list of the Model that represent the service paths to create.
lambda_runner samcli.commands.local.lib.local_lambda.LocalLambdaRunner
The Lambda runner class capable of invoking the function
Expand Down Expand Up @@ -81,10 +83,18 @@ def create(self):
path):
self._dict_of_routes[route_key] = api_gateway_route

methods = api_gateway_route.methods.copy()
if api_gateway_route.cors is not None:
methods.append('OPTIONS')

route_key = self._route_key('OPTIONS', api_gateway_route.path)
if route_key not in self._dict_of_routes:
self._dict_of_routes[route_key] = api_gateway_route

self._app.add_url_rule(path,
endpoint=path,
view_func=self._request_handler,
methods=api_gateway_route.methods,
methods=methods,
provide_automatic_options=False)

self._construct_error_handling()
Expand Down Expand Up @@ -141,6 +151,9 @@ def _request_handler(self, **kwargs):
"""
route = self._get_current_route(request)

if request.method == 'OPTIONS':
return self.service_response('', LocalApigwService._cors_to_headers(route.cors), 200)

try:
event = self._construct_event(request, self.port, route.binary_types)
except UnicodeDecodeError:
Expand Down Expand Up @@ -358,3 +371,31 @@ def _should_base64_encode(binary_types, request_mimetype):

"""
return request_mimetype in binary_types or "*/*" in binary_types


@staticmethod
def _cors_to_headers(cors):
"""
Convert CORS object to headers dictionary

Parameters
----------
cors list(samcli.commands.local.lib.provider.Cors)
CORS configuration objcet

Returns
-------
Dictionary with CORS headers

"""
headers = {}
if cors.allow_origin is not None:
headers['Access-Control-Allow-Origin'] = cors.allow_origin[1:-1]
if cors.allow_methods is not None:
headers['Access-Control-Allow-Methods'] = cors.allow_methods[1:-1]
if cors.allow_headers is not None:
headers['Access-Control-Allow-Headers'] = cors.allow_headers[1:-1]
if cors.max_age is not None:
headers['Access-Control-Max-Age'] = cors.max_age[1:-1]

return headers