diff --git a/samcli/commands/local/lib/api_collector.py b/samcli/commands/local/lib/api_collector.py index be18cea8c8..4c4c1abe8c 100644 --- a/samcli/commands/local/lib/api_collector.py +++ b/samcli/commands/local/lib/api_collector.py @@ -25,6 +25,7 @@ def __init__(self): self.binary_media_types_set = set() self.stage_name = None self.stage_variables = None + self.cors = None def __iter__(self): """ @@ -103,12 +104,40 @@ def get_api(self): An Api object with all the properties """ api = Api() - api.routes = self.dedupe_function_routes(self.routes) + routes = self.dedupe_function_routes(self.routes) + routes = self.normalize_cors_methods(routes, self.cors) + api.routes = routes api.binary_media_types_set = self.binary_media_types_set api.stage_name = self.stage_name api.stage_variables = self.stage_variables + api.cors = self.cors return api + @staticmethod + def normalize_cors_methods(routes, cors): + """ + Adds OPTIONS method to all the route methods if cors exists + + Parameters + ----------- + routes: list(samcli.local.apigw.local_apigw_service.Route) + List of Routes + + cors: samcli.commands.local.lib.provider.Cors + the cors object for the api + + Return + ------- + A list of routes without duplicate routes with the same function_name and method + """ + + def add_options_to_route(route): + if "OPTIONS" not in route.methods: + route.methods.append("OPTIONS") + return route + + return routes if not cors else [add_options_to_route(route) for route in routes] + @staticmethod def dedupe_function_routes(routes): """ diff --git a/samcli/commands/local/lib/provider.py b/samcli/commands/local/lib/provider.py index 94789ba799..a61891dbfc 100644 --- a/samcli/commands/local/lib/provider.py +++ b/samcli/commands/local/lib/provider.py @@ -227,7 +227,41 @@ def binary_media_types(self): return list(self.binary_media_types_set) -Cors = namedtuple("Cors", ["AllowOrigin", "AllowMethods", "AllowHeaders"]) +_CorsTuple = namedtuple("Cors", ["allow_origin", "allow_methods", "allow_headers", "max_age"]) + + +_CorsTuple.__new__.__defaults__ = (None, # Allow Origin defaults to None + None, # Allow Methods is optional and defaults to empty + None, # Allow Headers is optional and defaults to empty + None # MaxAge is optional and defaults to empty + ) + + +class Cors(_CorsTuple): + + @staticmethod + def cors_to_headers(cors): + """ + Convert CORS object to headers dictionary + Parameters + ---------- + cors list(samcli.commands.local.lib.provider.Cors) + CORS configuration objcet + Returns + ------- + Dictionary with CORS headers + """ + if not cors: + return {} + headers = { + 'Access-Control-Allow-Origin': cors.allow_origin, + 'Access-Control-Allow-Methods': cors.allow_methods, + 'Access-Control-Allow-Headers': cors.allow_headers, + 'Access-Control-Max-Age': cors.max_age + } + # Filters out items in the headers dictionary that isn't empty. + # This is required because the flask Headers dict will send an invalid 'None' string + return {h_key: h_value for h_key, h_value in headers.items() if h_value is not None} class AbstractApiProvider(object): diff --git a/samcli/commands/local/lib/sam_api_provider.py b/samcli/commands/local/lib/sam_api_provider.py index 1710edbf2d..9fda35a934 100644 --- a/samcli/commands/local/lib/sam_api_provider.py +++ b/samcli/commands/local/lib/sam_api_provider.py @@ -2,6 +2,9 @@ import logging +from six import string_types + +from samcli.commands.local.lib.provider import Cors from samcli.commands.local.lib.cfn_base_api_provider import CfnBaseApiProvider from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException from samcli.local.apigw.local_apigw_service import Route @@ -77,9 +80,9 @@ def _extract_from_serverless_api(self, logical_id, api_resource, collector, cwd= body = properties.get("DefinitionBody") uri = properties.get("DefinitionUri") binary_media = properties.get("BinaryMediaTypes", []) + cors = self.extract_cors(properties.get("Cors", {})) stage_name = properties.get("StageName") stage_variables = properties.get("Variables") - if not body and not uri: # Swagger is not found anywhere. LOG.debug("Skipping resource '%s'. Swagger document not found in DefinitionBody and DefinitionUri", @@ -88,6 +91,65 @@ def _extract_from_serverless_api(self, logical_id, api_resource, collector, cwd= self.extract_swagger_route(logical_id, body, uri, binary_media, collector, cwd=cwd) collector.stage_name = stage_name collector.stage_variables = stage_variables + collector.cors = cors + + def extract_cors(self, cors_prop): + """ + Extract Cors property from AWS::Serverless::Api resource by reading and parsing Swagger documents. The result + is added to the Api. + + Parameters + ---------- + cors_prop : dict + Resource properties for Cors + """ + cors = None + if cors_prop and isinstance(cors_prop, dict): + allow_methods = cors_prop.get("AllowMethods", ','.join(sorted(Route.ANY_HTTP_METHODS))) + allow_methods = self.normalize_cors_allow_methods(allow_methods) + cors = Cors( + allow_origin=cors_prop.get("AllowOrigin"), + allow_methods=allow_methods, + allow_headers=cors_prop.get("AllowHeaders"), + max_age=cors_prop.get("MaxAge") + ) + elif cors_prop and isinstance(cors_prop, string_types): + cors = Cors( + allow_origin=cors_prop, + allow_methods=','.join(sorted(Route.ANY_HTTP_METHODS)), + allow_headers=None, + max_age=None + ) + return cors + + @staticmethod + def normalize_cors_allow_methods(allow_methods): + """ + Normalize cors AllowMethods and Options to the methods if it's missing. + + Parameters + ---------- + allow_methods : str + The allow_methods string provided in the query + + Return + ------- + A string with normalized route + """ + if allow_methods == "*": + return ','.join(sorted(Route.ANY_HTTP_METHODS)) + methods = allow_methods.split(",") + normalized_methods = [] + for method in methods: + normalized_method = method.strip().upper() + if normalized_method not in Route.ANY_HTTP_METHODS: + raise InvalidSamDocumentException("The method {} is not a valid CORS method".format(normalized_method)) + normalized_methods.append(normalized_method) + + if "OPTIONS" not in normalized_methods: + normalized_methods.append("OPTIONS") + + return ','.join(sorted(normalized_methods)) def _extract_routes_from_function(self, logical_id, function_resource, collector): """ @@ -96,7 +158,7 @@ def _extract_routes_from_function(self, logical_id, function_resource, collector Parameters ---------- logical_id : str - Logical ID of the resource + Logical ID of the resourc function_resource : dict Contents of the function resource including its properties diff --git a/samcli/local/apigw/local_apigw_service.py b/samcli/local/apigw/local_apigw_service.py index f3eb02dc78..37425e2520 100644 --- a/samcli/local/apigw/local_apigw_service.py +++ b/samcli/local/apigw/local_apigw_service.py @@ -7,6 +7,7 @@ from flask import Flask, request from werkzeug.datastructures import Headers +from samcli.commands.local.lib.provider import Cors from samcli.local.services.base_local_service import BaseLocalService, LambdaOutputParser from samcli.lib.utils.stream_writer import StreamWriter from samcli.local.lambdafn.exceptions import FunctionNotFound @@ -170,6 +171,12 @@ def _request_handler(self, **kwargs): """ route = self._get_current_route(request) + cors_headers = Cors.cors_to_headers(self.api.cors) + + method, _ = self.get_request_methods_endpoints(request) + if method == 'OPTIONS': + headers = Headers(cors_headers) + return self.service_response('', headers, 200) try: event = self._construct_event(request, self.port, self.api.binary_media_types, self.api.stage_name, @@ -209,8 +216,7 @@ def _get_current_route(self, flask_request): :param request flask_request: Flask Request :return: Route matching the endpoint and method of the request """ - endpoint = flask_request.endpoint - method = flask_request.method + method, endpoint = self.get_request_methods_endpoints(flask_request) route_key = self._route_key(method, endpoint) route = self._dict_of_routes.get(route_key, None) @@ -223,6 +229,16 @@ def _get_current_route(self, flask_request): return route + def get_request_methods_endpoints(self, flask_request): + """ + Separated out for testing requests in request handler + :param request flask_request: Flask Request + :return: the request's endpoint and method + """ + endpoint = flask_request.endpoint + method = flask_request.method + return method, endpoint + # Consider moving this out to its own class. Logic is started to get dense and looks messy @jfuss @staticmethod def _parse_lambda_output(lambda_output, binary_types, flask_request): @@ -451,6 +467,8 @@ def _event_headers(flask_request, port): Request from Flask int port Forwarded Port + cors_headers dict + Dict of the Cors properties Returns dict (str: str), dict (str: list of str) ------- @@ -471,7 +489,6 @@ def _event_headers(flask_request, port): headers_dict["X-Forwarded-Port"] = str(port) multi_value_headers_dict["X-Forwarded-Port"] = [str(port)] - return headers_dict, multi_value_headers_dict @staticmethod diff --git a/tests/integration/local/start_api/test_start_api.py b/tests/integration/local/start_api/test_start_api.py index b32c341901..63763e4282 100644 --- a/tests/integration/local/start_api/test_start_api.py +++ b/tests/integration/local/start_api/test_start_api.py @@ -2,6 +2,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from time import time +from samcli.local.apigw.local_apigw_service import Route from .start_api_integ_base import StartApiIntegBaseClass @@ -664,6 +665,67 @@ def test_swagger_stage_variable(self): self.assertEquals(response_data.get("stageVariables"), {'VarName': 'varValue'}) +class TestServiceCorsSwaggerRequests(StartApiIntegBaseClass): + """ + Test to check that the correct headers are being added with Cors with swagger code + """ + template_path = "/testdata/start_api/swagger-template.yaml" + binary_data_file = "testdata/start_api/binarydata.gif" + + def setUp(self): + self.url = "http://127.0.0.1:{}".format(self.port) + + def test_cors_swagger_options(self): + """ + This tests that the Cors are added to option requests in the swagger template + """ + response = requests.options(self.url + '/echobase64eventbody') + + self.assertEquals(response.status_code, 200) + + self.assertEquals(response.headers.get("Access-Control-Allow-Origin"), "*") + self.assertEquals(response.headers.get("Access-Control-Allow-Headers"), "origin, x-requested-with") + self.assertEquals(response.headers.get("Access-Control-Allow-Methods"), "GET,OPTIONS") + self.assertEquals(response.headers.get("Access-Control-Max-Age"), '510') + + +class TestServiceCorsGlobalRequests(StartApiIntegBaseClass): + """ + Test to check that the correct headers are being added with Cors with the global property + """ + template_path = "/testdata/start_api/template.yaml" + + def setUp(self): + self.url = "http://127.0.0.1:{}".format(self.port) + + def test_cors_global(self): + """ + This tests that the Cors are added to options requests when the global property is set + """ + response = requests.options(self.url + '/echobase64eventbody') + + self.assertEquals(response.status_code, 200) + self.assertEquals(response.headers.get("Access-Control-Allow-Origin"), "*") + self.assertEquals(response.headers.get("Access-Control-Allow-Headers"), None) + self.assertEquals(response.headers.get("Access-Control-Allow-Methods"), + ','.join(sorted(Route.ANY_HTTP_METHODS))) + self.assertEquals(response.headers.get("Access-Control-Max-Age"), None) + + def test_cors_global_get(self): + """ + This tests that the Cors are added to post requests when the global property is set + """ + response = requests.get(self.url + "/onlysetstatuscode") + + self.assertEquals(response.status_code, 200) + self.assertEquals(response.content.decode('utf-8'), "no data") + self.assertEquals(response.headers.get("Content-Type"), "application/json") + self.assertEquals(response.headers.get("Access-Control-Allow-Origin"), None) + self.assertEquals(response.headers.get("Access-Control-Allow-Headers"), None) + self.assertEquals(response.headers.get("Access-Control-Allow-Methods"), None) + self.assertEquals(response.headers.get("Access-Control-Max-Age"), None) + + class TestStartApiWithCloudFormationStage(StartApiIntegBaseClass): """ Test Class centered around the different responses that can happen in Lambda and pass through start-api diff --git a/tests/integration/testdata/start_api/swagger-template.yaml b/tests/integration/testdata/start_api/swagger-template.yaml index 9f987c0d8c..cff33b1f43 100644 --- a/tests/integration/testdata/start_api/swagger-template.yaml +++ b/tests/integration/testdata/start_api/swagger-template.yaml @@ -1,4 +1,4 @@ -AWSTemplateFormatVersion : '2010-09-09' +AWSTemplateFormatVersion: '2010-09-09' Transform: AWS::Serverless-2016-10-31 Globals: @@ -14,6 +14,11 @@ Resources: StageName: dev Variables: VarName: varValue + Cors: + AllowOrigin: "*" + AllowMethods: "GET" + AllowHeaders: "origin, x-requested-with" + MaxAge: 510 DefinitionBody: swagger: "2.0" info: diff --git a/tests/integration/testdata/start_api/template.yaml b/tests/integration/testdata/start_api/template.yaml index 763009613a..ec0b65978c 100644 --- a/tests/integration/testdata/start_api/template.yaml +++ b/tests/integration/testdata/start_api/template.yaml @@ -9,6 +9,7 @@ Globals: - image~1png Variables: VarName: varValue + Cors: "*" Resources: HelloWorldFunction: Type: AWS::Serverless::Function diff --git a/tests/unit/commands/local/lib/test_sam_api_provider.py b/tests/unit/commands/local/lib/test_sam_api_provider.py index 03fde51e4c..3e582ebac4 100644 --- a/tests/unit/commands/local/lib/test_sam_api_provider.py +++ b/tests/unit/commands/local/lib/test_sam_api_provider.py @@ -7,21 +7,15 @@ from nose_parameterized import parameterized from six import assertCountEqual +from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException from samcli.commands.local.lib.api_provider import ApiProvider +from samcli.commands.local.lib.provider import Cors from samcli.local.apigw.local_apigw_service import Route class TestSamApiProviderWithImplicitApis(TestCase): - def test_provider_with_no_resource_properties(self): - template = { - "Resources": { - - "SamFunc1": { - "Type": "AWS::Lambda::Function" - } - } - } + template = {"Resources": {"SamFunc1": {"Type": "AWS::Lambda::Function"}}} provider = ApiProvider(template) @@ -31,7 +25,6 @@ def test_provider_with_no_resource_properties(self): def test_provider_has_correct_api(self, method): template = { "Resources": { - "SamFunc1": { "Type": "AWS::Serverless::Function", "Properties": { @@ -41,13 +34,10 @@ def test_provider_has_correct_api(self, method): "Events": { "Event1": { "Type": "Api", - "Properties": { - "Path": "/path", - "Method": method - } + "Properties": {"Path": "/path", "Method": method}, } - } - } + }, + }, } } } @@ -55,12 +45,14 @@ def test_provider_has_correct_api(self, method): provider = ApiProvider(template) self.assertEquals(len(provider.routes), 1) - self.assertEquals(list(provider.routes)[0], Route(path="/path", methods=["GET"], function_name="SamFunc1")) + self.assertEquals( + list(provider.routes)[0], + Route(path="/path", methods=["GET"], function_name="SamFunc1"), + ) def test_provider_creates_api_for_all_events(self): template = { "Resources": { - "SamFunc1": { "Type": "AWS::Serverless::Function", "Properties": { @@ -70,20 +62,14 @@ def test_provider_creates_api_for_all_events(self): "Events": { "Event1": { "Type": "Api", - "Properties": { - "Path": "/path", - "Method": "GET" - } + "Properties": {"Path": "/path", "Method": "GET"}, }, "Event2": { "Type": "Api", - "Properties": { - "Path": "/path", - "Method": "POST" - } - } - } - } + "Properties": {"Path": "/path", "Method": "POST"}, + }, + }, + }, } } } @@ -98,7 +84,6 @@ def test_provider_creates_api_for_all_events(self): def test_provider_has_correct_template(self): template = { "Resources": { - "SamFunc1": { "Type": "AWS::Serverless::Function", "Properties": { @@ -108,13 +93,10 @@ def test_provider_has_correct_template(self): "Events": { "Event1": { "Type": "Api", - "Properties": { - "Path": "/path", - "Method": "GET" - } + "Properties": {"Path": "/path", "Method": "GET"}, } - } - } + }, + }, }, "SamFunc2": { "Type": "AWS::Serverless::Function", @@ -125,14 +107,11 @@ def test_provider_has_correct_template(self): "Events": { "Event1": { "Type": "Api", - "Properties": { - "Path": "/path", - "Method": "POST" - } + "Properties": {"Path": "/path", "Method": "POST"}, } - } - } - } + }, + }, + }, } } @@ -147,7 +126,6 @@ def test_provider_has_correct_template(self): def test_provider_with_no_api_events(self): template = { "Resources": { - "SamFunc1": { "Type": "AWS::Serverless::Function", "Properties": { @@ -157,12 +135,10 @@ def test_provider_with_no_api_events(self): "Events": { "Event1": { "Type": "S3", - "Properties": { - "Property1": "value" - } + "Properties": {"Property1": "value"}, } - } - } + }, + }, } } } @@ -174,14 +150,13 @@ def test_provider_with_no_api_events(self): def test_provider_with_no_serverless_function(self): template = { "Resources": { - "SamFunc1": { "Type": "AWS::Lambda::Function", "Properties": { "CodeUri": "/usr/foo/bar", "Runtime": "nodejs4.3", - "Handler": "index.handler" - } + "Handler": "index.handler", + }, } } } @@ -193,7 +168,6 @@ def test_provider_with_no_serverless_function(self): def test_provider_get_all(self): template = { "Resources": { - "SamFunc1": { "Type": "AWS::Serverless::Function", "Properties": { @@ -203,13 +177,10 @@ def test_provider_get_all(self): "Events": { "Event1": { "Type": "Api", - "Properties": { - "Path": "/path", - "Method": "GET" - } + "Properties": {"Path": "/path", "Method": "GET"}, } - } - } + }, + }, }, "SamFunc2": { "Type": "AWS::Serverless::Function", @@ -220,14 +191,11 @@ def test_provider_get_all(self): "Events": { "Event1": { "Type": "Api", - "Properties": { - "Path": "/path", - "Method": "POST" - } + "Properties": {"Path": "/path", "Method": "POST"}, } - } - } - } + }, + }, + }, } } @@ -255,7 +223,6 @@ def test_provider_get_all_with_no_routes(self): def test_provider_with_any_method(self, method): template = { "Resources": { - "SamFunc1": { "Type": "AWS::Serverless::Function", "Properties": { @@ -265,26 +232,21 @@ def test_provider_with_any_method(self, method): "Events": { "Event1": { "Type": "Api", - "Properties": { - "Path": "/path", - "Method": method - } + "Properties": {"Path": "/path", "Method": method}, } - } - } + }, + }, } } } provider = ApiProvider(template) - api1 = Route(path="/path", methods=["GET", - "DELETE", - "PUT", - "POST", - "HEAD", - "OPTIONS", - "PATCH"], function_name="SamFunc1") + api1 = Route( + path="/path", + methods=["GET", "DELETE", "PUT", "POST", "HEAD", "OPTIONS", "PATCH"], + function_name="SamFunc1", + ) self.assertEquals(len(provider.routes), 1) self.assertIn(api1, provider.routes) @@ -297,12 +259,11 @@ def test_provider_must_support_binary_media_types(self): "image~1gif", "image~1png", "image~1png", # Duplicates must be ignored - {"Ref": "SomeParameter"} # Refs are ignored as well + {"Ref": "SomeParameter"}, # Refs are ignored as well ] } }, "Resources": { - "SamFunc1": { "Type": "AWS::Serverless::Function", "Properties": { @@ -312,37 +273,32 @@ def test_provider_must_support_binary_media_types(self): "Events": { "Event1": { "Type": "Api", - "Properties": { - "Path": "/path", - "Method": "get" - } + "Properties": {"Path": "/path", "Method": "get"}, } - } - } + }, + }, } - } + }, } provider = ApiProvider(template) self.assertEquals(len(provider.routes), 1) - self.assertEquals(list(provider.routes)[0], Route(path="/path", methods=["GET"], function_name="SamFunc1")) - assertCountEqual(self, provider.api.binary_media_types, ["image/gif", "image/png"]) + self.assertEquals( + list(provider.routes)[0], + Route(path="/path", methods=["GET"], function_name="SamFunc1"), + ) + assertCountEqual( + self, provider.api.binary_media_types, ["image/gif", "image/png"] + ) self.assertEquals(provider.api.stage_name, "Prod") def test_provider_must_support_binary_media_types_with_any_method(self): template = { "Globals": { - "Api": { - "BinaryMediaTypes": [ - "image~1gif", - "image~1png", - "text/html" - ] - } + "Api": {"BinaryMediaTypes": ["image~1gif", "image~1png", "text/html"]} }, "Resources": { - "SamFunc1": { "Type": "AWS::Serverless::Function", "Properties": { @@ -352,27 +308,22 @@ def test_provider_must_support_binary_media_types_with_any_method(self): "Events": { "Event1": { "Type": "Api", - "Properties": { - "Path": "/path", - "Method": "any" - } + "Properties": {"Path": "/path", "Method": "any"}, } - } - } + }, + }, } - } + }, } binary = ["image/gif", "image/png", "text/html"] expected_routes = [ - Route(path="/path", methods=["GET", - "DELETE", - "PUT", - "POST", - "HEAD", - "OPTIONS", - "PATCH"], function_name="SamFunc1") + Route( + path="/path", + methods=["GET", "DELETE", "PUT", "POST", "HEAD", "OPTIONS", "PATCH"], + function_name="SamFunc1", + ) ] provider = ApiProvider(template) @@ -382,25 +333,21 @@ def test_provider_must_support_binary_media_types_with_any_method(self): class TestSamApiProviderWithExplicitApis(TestCase): - def setUp(self): self.binary_types = ["image/png", "image/jpg"] self.stage_name = "Prod" self.input_routes = [ Route(path="/path1", methods=["GET", "POST"], function_name="SamFunc1"), Route(path="/path2", methods=["PUT", "GET"], function_name="SamFunc1"), - Route(path="/path3", methods=["DELETE"], function_name="SamFunc1") + Route(path="/path3", methods=["DELETE"], function_name="SamFunc1"), ] def test_with_no_routes(self): template = { "Resources": { - "Api1": { "Type": "AWS::Serverless::Api", - "Properties": { - "StageName": "Prod" - } + "Properties": {"StageName": "Prod"}, } } } @@ -412,13 +359,12 @@ def test_with_no_routes(self): def test_with_inline_swagger_routes(self): template = { "Resources": { - "Api1": { "Type": "AWS::Serverless::Api", "Properties": { "StageName": "Prod", - "DefinitionBody": make_swagger(self.input_routes) - } + "DefinitionBody": make_swagger(self.input_routes), + }, } } } @@ -427,7 +373,7 @@ def test_with_inline_swagger_routes(self): assertCountEqual(self, self.input_routes, provider.routes) def test_with_swagger_as_local_file(self): - with tempfile.NamedTemporaryFile(mode='w', delete=False) as fp: + with tempfile.NamedTemporaryFile(mode="w", delete=False) as fp: filename = fp.name swagger = make_swagger(self.input_routes) @@ -436,13 +382,9 @@ def test_with_swagger_as_local_file(self): template = { "Resources": { - "Api1": { "Type": "AWS::Serverless::Api", - "Properties": { - "StageName": "Prod", - "DefinitionUri": filename - } + "Properties": {"StageName": "Prod", "DefinitionUri": filename}, } } } @@ -457,39 +399,37 @@ def test_with_swagger_as_both_body_and_uri_called(self, SwaggerReaderMock): template = { "Resources": { - "Api1": { "Type": "AWS::Serverless::Api", "Properties": { "StageName": "Prod", "DefinitionUri": filename, - "DefinitionBody": body - } + "DefinitionBody": body, + }, } } } - SwaggerReaderMock.return_value.read.return_value = make_swagger(self.input_routes) + SwaggerReaderMock.return_value.read.return_value = make_swagger( + self.input_routes + ) cwd = "foo" provider = ApiProvider(template, cwd=cwd) assertCountEqual(self, self.input_routes, provider.routes) - SwaggerReaderMock.assert_called_with(definition_body=body, definition_uri=filename, working_dir=cwd) + SwaggerReaderMock.assert_called_with( + definition_body=body, definition_uri=filename, working_dir=cwd + ) def test_swagger_with_any_method(self): - routes = [ - Route(path="/path", methods=["any"], function_name="SamFunc1") - ] + routes = [Route(path="/path", methods=["any"], function_name="SamFunc1")] expected_routes = [ - Route(path="/path", methods=["GET", - "DELETE", - "PUT", - "POST", - "HEAD", - "OPTIONS", - "PATCH"], - function_name="SamFunc1") + Route( + path="/path", + methods=["GET", "DELETE", "PUT", "POST", "HEAD", "OPTIONS", "PATCH"], + function_name="SamFunc1", + ) ] template = { @@ -498,8 +438,8 @@ def test_swagger_with_any_method(self): "Type": "AWS::Serverless::Api", "Properties": { "StageName": "Prod", - "DefinitionBody": make_swagger(routes) - } + "DefinitionBody": make_swagger(routes), + }, } } } @@ -510,13 +450,14 @@ def test_swagger_with_any_method(self): def test_with_binary_media_types(self): template = { "Resources": { - "Api1": { "Type": "AWS::Serverless::Api", "Properties": { "StageName": "Prod", - "DefinitionBody": make_swagger(self.input_routes, binary_media_types=self.binary_types) - } + "DefinitionBody": make_swagger( + self.input_routes, binary_media_types=self.binary_types + ), + }, } } } @@ -525,7 +466,7 @@ def test_with_binary_media_types(self): expected_routes = [ Route(path="/path1", methods=["GET", "POST"], function_name="SamFunc1"), Route(path="/path2", methods=["GET", "PUT"], function_name="SamFunc1"), - Route(path="/path3", methods=["DELETE"], function_name="SamFunc1") + Route(path="/path3", methods=["DELETE"], function_name="SamFunc1"), ] provider = ApiProvider(template) @@ -534,27 +475,28 @@ def test_with_binary_media_types(self): def test_with_binary_media_types_in_swagger_and_on_resource(self): input_routes = [ - Route(path="/path", methods=["OPTIONS"], function_name="SamFunc1"), + Route(path="/path", methods=["OPTIONS"], function_name="SamFunc1") ] extra_binary_types = ["text/html"] template = { "Resources": { - "Api1": { "Type": "AWS::Serverless::Api", "Properties": { "BinaryMediaTypes": extra_binary_types, "StageName": "Prod", - "DefinitionBody": make_swagger(input_routes, binary_media_types=self.binary_types) - } + "DefinitionBody": make_swagger( + input_routes, binary_media_types=self.binary_types + ), + }, } } } expected_binary_types = sorted(self.binary_types + extra_binary_types) expected_routes = [ - Route(path="/path", methods=["OPTIONS"], function_name="SamFunc1"), + Route(path="/path", methods=["OPTIONS"], function_name="SamFunc1") ] provider = ApiProvider(template) @@ -563,35 +505,30 @@ def test_with_binary_media_types_in_swagger_and_on_resource(self): class TestSamApiProviderWithExplicitAndImplicitApis(TestCase): - def setUp(self): self.stage_name = "Prod" self.explicit_routes = [ Route(path="/path1", methods=["GET"], function_name="explicitfunction"), Route(path="/path2", methods=["GET"], function_name="explicitfunction"), - Route(path="/path3", methods=["GET"], function_name="explicitfunction") + Route(path="/path3", methods=["GET"], function_name="explicitfunction"), ] self.swagger = make_swagger(self.explicit_routes) self.template = { "Resources": { - "Api1": { "Type": "AWS::Serverless::Api", - "Properties": { - "StageName": "Prod", - } + "Properties": {"StageName": "Prod"}, }, - "ImplicitFunc": { "Type": "AWS::Serverless::Function", "Properties": { "CodeUri": "/usr/foo/bar", "Runtime": "nodejs4.3", - "Handler": "index.handler" - } - } + "Handler": "index.handler", + }, + }, } } @@ -599,30 +536,21 @@ def test_must_union_implicit_and_explicit(self): events = { "Event1": { "Type": "Api", - "Properties": { - "Path": "/path1", - "Method": "POST" - } + "Properties": {"Path": "/path1", "Method": "POST"}, }, - "Event2": { "Type": "Api", - "Properties": { - "Path": "/path2", - "Method": "POST" - } + "Properties": {"Path": "/path2", "Method": "POST"}, }, - "Event3": { "Type": "Api", - "Properties": { - "Path": "/path3", - "Method": "POST" - } - } + "Properties": {"Path": "/path3", "Method": "POST"}, + }, } - self.template["Resources"]["Api1"]["Properties"]["DefinitionBody"] = self.swagger + self.template["Resources"]["Api1"]["Properties"][ + "DefinitionBody" + ] = self.swagger self.template["Resources"]["ImplicitFunc"]["Properties"]["Events"] = events expected_routes = [ @@ -633,7 +561,7 @@ def test_must_union_implicit_and_explicit(self): # From Implicit APIs Route(path="/path1", methods=["POST"], function_name="ImplicitFunc"), Route(path="/path2", methods=["POST"], function_name="ImplicitFunc"), - Route(path="/path3", methods=["POST"], function_name="ImplicitFunc") + Route(path="/path3", methods=["POST"], function_name="ImplicitFunc"), ] provider = ApiProvider(self.template) @@ -646,30 +574,28 @@ def test_must_prefer_implicit_api_over_explicit(self): "Properties": { # This API is duplicated between implicit & explicit "Path": "/path1", - "Method": "get" - } + "Method": "get", + }, }, - "Event2": { "Type": "Api", - "Properties": { - "Path": "/path2", - "Method": "POST" - } - } + "Properties": {"Path": "/path2", "Method": "POST"}, + }, } - self.template["Resources"]["Api1"]["Properties"]["DefinitionBody"] = self.swagger - self.template["Resources"]["ImplicitFunc"]["Properties"]["Events"] = implicit_routes + self.template["Resources"]["Api1"]["Properties"][ + "DefinitionBody" + ] = self.swagger + self.template["Resources"]["ImplicitFunc"]["Properties"][ + "Events" + ] = implicit_routes expected_routes = [ Route(path="/path1", methods=["GET"], function_name="ImplicitFunc"), # Comes from Implicit - Route(path="/path2", methods=["GET"], function_name="explicitfunction"), Route(path="/path2", methods=["POST"], function_name="ImplicitFunc"), # Comes from implicit - Route(path="/path3", methods=["GET"], function_name="explicitfunction"), ] @@ -683,8 +609,8 @@ def test_must_prefer_implicit_with_any_method(self): "Properties": { # This API is duplicated between implicit & explicit "Path": "/path", - "Method": "ANY" - } + "Method": "ANY", + }, } } @@ -694,18 +620,19 @@ def test_must_prefer_implicit_with_any_method(self): Route(path="/path", methods=["DELETE"], function_name="explicitfunction"), ] - self.template["Resources"]["Api1"]["Properties"]["DefinitionBody"] = make_swagger(explicit_routes) - self.template["Resources"]["ImplicitFunc"]["Properties"]["Events"] = implicit_routes + self.template["Resources"]["Api1"]["Properties"][ + "DefinitionBody" + ] = make_swagger(explicit_routes) + self.template["Resources"]["ImplicitFunc"]["Properties"][ + "Events" + ] = implicit_routes expected_routes = [ - Route(path="/path", methods=["GET", - "DELETE", - "PUT", - "POST", - "HEAD", - "OPTIONS", - "PATCH"], - function_name="ImplicitFunc") + Route( + path="/path", + methods=["GET", "DELETE", "PUT", "POST", "HEAD", "OPTIONS", "PATCH"], + function_name="ImplicitFunc", + ) ] provider = ApiProvider(self.template) @@ -718,17 +645,17 @@ def test_with_any_method_on_both(self): "Properties": { # This API is duplicated between implicit & explicit "Path": "/path", - "Method": "ANY" - } + "Method": "ANY", + }, }, "Event2": { "Type": "Api", "Properties": { # This API is duplicated between implicit & explicit "Path": "/path2", - "Method": "GET" - } - } + "Method": "GET", + }, + }, } explicit_routes = [ @@ -737,22 +664,21 @@ def test_with_any_method_on_both(self): Route(path="/path2", methods=["POST"], function_name="explicitfunction"), ] - self.template["Resources"]["Api1"]["Properties"]["DefinitionBody"] = make_swagger(explicit_routes) - self.template["Resources"]["ImplicitFunc"]["Properties"]["Events"] = implicit_routes + self.template["Resources"]["Api1"]["Properties"][ + "DefinitionBody" + ] = make_swagger(explicit_routes) + self.template["Resources"]["ImplicitFunc"]["Properties"][ + "Events" + ] = implicit_routes expected_routes = [ - Route(path="/path", methods=["GET", - "DELETE", - "PUT", - "POST", - "HEAD", - "OPTIONS", - "PATCH"], - function_name="ImplicitFunc"), - - Route(path="/path2", methods=["GET"], - function_name="ImplicitFunc"), - Route(path="/path2", methods=["POST"], function_name="explicitfunction") + Route( + path="/path", + methods=["GET", "DELETE", "PUT", "POST", "HEAD", "OPTIONS", "PATCH"], + function_name="ImplicitFunc", + ), + Route(path="/path2", methods=["GET"], function_name="ImplicitFunc"), + Route(path="/path2", methods=["POST"], function_name="explicitfunction"), ] provider = ApiProvider(self.template) @@ -765,21 +691,24 @@ def test_must_add_explicit_api_when_ref_with_rest_api_id(self): "Properties": { "Path": "/newpath1", "Method": "POST", - "RestApiId": "Api1" # This path must get added to this API - } + "RestApiId": "Api1", # This path must get added to this API + }, }, - "Event2": { "Type": "Api", "Properties": { "Path": "/newpath2", "Method": "POST", - "RestApiId": {"Ref": "Api1"} # This path must get added to this API - } - } + "RestApiId": { + "Ref": "Api1" + }, # This path must get added to this API + }, + }, } - self.template["Resources"]["Api1"]["Properties"]["DefinitionBody"] = self.swagger + self.template["Resources"]["Api1"]["Properties"][ + "DefinitionBody" + ] = self.swagger self.template["Resources"]["ImplicitFunc"]["Properties"]["Events"] = events expected_routes = [ @@ -789,7 +718,7 @@ def test_must_add_explicit_api_when_ref_with_rest_api_id(self): Route(path="/path3", methods=["GET"], function_name="explicitfunction"), # From Implicit APIs Route(path="/newpath1", methods=["POST"], function_name="ImplicitFunc"), - Route(path="/newpath2", methods=["POST"], function_name="ImplicitFunc") + Route(path="/newpath2", methods=["POST"], function_name="ImplicitFunc"), ] provider = ApiProvider(self.template) @@ -799,35 +728,36 @@ def test_both_routes_must_get_binary_media_types(self): events = { "Event1": { "Type": "Api", - "Properties": { - "Path": "/newpath1", - "Method": "POST" - } + "Properties": {"Path": "/newpath1", "Method": "POST"}, }, - "Event2": { "Type": "Api", - "Properties": { - "Path": "/newpath2", - "Method": "POST" - } - } + "Properties": {"Path": "/newpath2", "Method": "POST"}, + }, } # Binary type for implicit self.template["Globals"] = { - "Api": { - "BinaryMediaTypes": ["image~1gif", "image~1png"] - } + "Api": {"BinaryMediaTypes": ["image~1gif", "image~1png"]} } self.template["Resources"]["ImplicitFunc"]["Properties"]["Events"] = events - self.template["Resources"]["Api1"]["Properties"]["DefinitionBody"] = self.swagger + self.template["Resources"]["Api1"]["Properties"][ + "DefinitionBody" + ] = self.swagger # Binary type for explicit - self.template["Resources"]["Api1"]["Properties"]["BinaryMediaTypes"] = ["explicit/type1", "explicit/type2"] + self.template["Resources"]["Api1"]["Properties"]["BinaryMediaTypes"] = [ + "explicit/type1", + "explicit/type2", + ] # Because of Globals, binary types will be concatenated on the explicit API - expected_explicit_binary_types = ["explicit/type1", "explicit/type2", "image/gif", "image/png"] + expected_explicit_binary_types = [ + "explicit/type1", + "explicit/type2", + "image/gif", + "image/png", + ] expected_routes = [ # From Explicit APIs @@ -836,12 +766,14 @@ def test_both_routes_must_get_binary_media_types(self): Route(path="/path3", methods=["GET"], function_name="explicitfunction"), # From Implicit APIs Route(path="/newpath1", methods=["POST"], function_name="ImplicitFunc"), - Route(path="/newpath2", methods=["POST"], function_name="ImplicitFunc") + Route(path="/newpath2", methods=["POST"], function_name="ImplicitFunc"), ] provider = ApiProvider(self.template) assertCountEqual(self, expected_routes, provider.routes) - assertCountEqual(self, provider.api.binary_media_types, expected_explicit_binary_types) + assertCountEqual( + self, provider.api.binary_media_types, expected_explicit_binary_types + ) def test_binary_media_types_with_rest_api_id_reference(self): events = { @@ -850,33 +782,37 @@ def test_binary_media_types_with_rest_api_id_reference(self): "Properties": { "Path": "/connected-to-explicit-path", "Method": "POST", - "RestApiId": "Api1" - } + "RestApiId": "Api1", + }, }, - "Event2": { "Type": "Api", - "Properties": { - "Path": "/true-implicit-path", - "Method": "POST" - } - } + "Properties": {"Path": "/true-implicit-path", "Method": "POST"}, + }, } # Binary type for implicit self.template["Globals"] = { - "Api": { - "BinaryMediaTypes": ["image~1gif", "image~1png"] - } + "Api": {"BinaryMediaTypes": ["image~1gif", "image~1png"]} } self.template["Resources"]["ImplicitFunc"]["Properties"]["Events"] = events - self.template["Resources"]["Api1"]["Properties"]["DefinitionBody"] = self.swagger + self.template["Resources"]["Api1"]["Properties"][ + "DefinitionBody" + ] = self.swagger # Binary type for explicit - self.template["Resources"]["Api1"]["Properties"]["BinaryMediaTypes"] = ["explicit/type1", "explicit/type2"] + self.template["Resources"]["Api1"]["Properties"]["BinaryMediaTypes"] = [ + "explicit/type1", + "explicit/type2", + ] # Because of Globals, binary types will be concatenated on the explicit API - expected_explicit_binary_types = ["explicit/type1", "explicit/type2", "image/gif", "image/png"] + expected_explicit_binary_types = [ + "explicit/type1", + "explicit/type2", + "image/gif", + "image/png", + ] # expected_implicit_binary_types = ["image/gif", "image/png"] expected_routes = [ @@ -884,26 +820,32 @@ def test_binary_media_types_with_rest_api_id_reference(self): Route(path="/path1", methods=["GET"], function_name="explicitfunction"), Route(path="/path2", methods=["GET"], function_name="explicitfunction"), Route(path="/path3", methods=["GET"], function_name="explicitfunction"), - # Because of the RestApiId, Implicit APIs will also get the binary media types inherited from # the corresponding Explicit API - Route(path="/connected-to-explicit-path", methods=["POST"], function_name="ImplicitFunc"), - + Route( + path="/connected-to-explicit-path", + methods=["POST"], + function_name="ImplicitFunc", + ), # This is still just a true implicit API because it does not have RestApiId property - Route(path="/true-implicit-path", methods=["POST"], function_name="ImplicitFunc") + Route( + path="/true-implicit-path", + methods=["POST"], + function_name="ImplicitFunc", + ), ] provider = ApiProvider(self.template) assertCountEqual(self, expected_routes, provider.routes) - assertCountEqual(self, provider.api.binary_media_types, expected_explicit_binary_types) + assertCountEqual( + self, provider.api.binary_media_types, expected_explicit_binary_types + ) class TestSamStageValues(TestCase): - def test_provider_parse_stage_name(self): template = { "Resources": { - "TestApi": { "Type": "AWS::Serverless::Api", "Properties": { @@ -917,21 +859,22 @@ def test_provider_parse_stage_name(self): "type": "aws_proxy", "uri": { "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", + "/functions/${NoApiEventFunction.Arn}/invocations" }, "responses": {}, - }, + } } } - } - } - } + }, + }, } } } provider = ApiProvider(template) - route1 = Route(path='/path', methods=['GET'], function_name='NoApiEventFunction') + route1 = Route( + path="/path", methods=["GET"], function_name="NoApiEventFunction" + ) self.assertIn(route1, provider.routes) self.assertEquals(provider.api.stage_name, "dev") @@ -940,16 +883,11 @@ def test_provider_parse_stage_name(self): def test_provider_stage_variables(self): template = { "Resources": { - "TestApi": { "Type": "AWS::Serverless::Api", "Properties": { "StageName": "dev", - "Variables": { - "vis": "data", - "random": "test", - "foo": "bar" - }, + "Variables": {"vis": "data", "random": "test", "foo": "bar"}, "DefinitionBody": { "paths": { "/path": { @@ -959,43 +897,37 @@ def test_provider_stage_variables(self): "type": "aws_proxy", "uri": { "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", + "/functions/${NoApiEventFunction.Arn}/invocations" }, "responses": {}, - }, + } } } - } - } - } + }, + }, } } } provider = ApiProvider(template) - route1 = Route(path='/path', methods=['GET'], function_name='NoApiEventFunction') + route1 = Route( + path="/path", methods=["GET"], function_name="NoApiEventFunction" + ) self.assertIn(route1, provider.routes) self.assertEquals(provider.api.stage_name, "dev") - self.assertEquals(provider.api.stage_variables, { - "vis": "data", - "random": "test", - "foo": "bar" - }) + self.assertEquals( + provider.api.stage_variables, + {"vis": "data", "random": "test", "foo": "bar"}, + ) def test_multi_stage_get_all(self): - template = OrderedDict({ - "Resources": {} - }) + template = OrderedDict({"Resources": {}}) template["Resources"]["TestApi"] = { "Type": "AWS::Serverless::Api", "Properties": { "StageName": "dev", - "Variables": { - "vis": "data", - "random": "test", - "foo": "bar" - }, + "Variables": {"vis": "data", "random": "test", "foo": "bar"}, "DefinitionBody": { "paths": { "/path2": { @@ -1005,26 +937,22 @@ def test_multi_stage_get_all(self): "type": "aws_proxy", "uri": { "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", + "/functions/${NoApiEventFunction.Arn}/invocations" }, "responses": {}, - }, + } } } } - } - } + }, + }, } template["Resources"]["ProductionApi"] = { "Type": "AWS::Serverless::Api", "Properties": { "StageName": "Production", - "Variables": { - "vis": "prod data", - "random": "test", - "foo": "bar" - }, + "Variables": {"vis": "prod data", "random": "test", "foo": "bar"}, "DefinitionBody": { "paths": { "/path": { @@ -1034,10 +962,10 @@ def test_multi_stage_get_all(self): "type": "aws_proxy", "uri": { "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", + "/functions/${NoApiEventFunction.Arn}/invocations" }, "responses": {}, - }, + } } }, "/anotherpath": { @@ -1047,16 +975,15 @@ def test_multi_stage_get_all(self): "type": "aws_proxy", "uri": { "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", + "/functions/${NoApiEventFunction.Arn}/invocations" }, "responses": {}, - }, + } } - } - + }, } - } - } + }, + }, } provider = ApiProvider(template) @@ -1064,20 +991,403 @@ def test_multi_stage_get_all(self): result = [f for f in provider.get_all()] routes = result[0].routes - route1 = Route(path='/path2', methods=['GET'], function_name='NoApiEventFunction') - route2 = Route(path='/path', methods=['GET'], function_name='NoApiEventFunction') - route3 = Route(path='/anotherpath', methods=['POST'], function_name='NoApiEventFunction') + route1 = Route( + path="/path2", methods=["GET"], function_name="NoApiEventFunction" + ) + route2 = Route( + path="/path", methods=["GET"], function_name="NoApiEventFunction" + ) + route3 = Route( + path="/anotherpath", methods=["POST"], function_name="NoApiEventFunction" + ) self.assertEquals(len(routes), 3) self.assertIn(route1, routes) self.assertIn(route2, routes) self.assertIn(route3, routes) self.assertEquals(provider.api.stage_name, "Production") - self.assertEquals(provider.api.stage_variables, { - "vis": "prod data", - "random": "test", - "foo": "bar" - }) + self.assertEquals( + provider.api.stage_variables, + {"vis": "prod data", "random": "test", "foo": "bar"}, + ) + + +class TestSamCors(TestCase): + def test_provider_parse_cors_string(self): + template = { + "Resources": { + "TestApi": { + "Type": "AWS::Serverless::Api", + "Properties": { + "StageName": "Prod", + "Cors": "*", + "DefinitionBody": { + "paths": { + "/path2": { + "post": { + "x-amazon-apigateway-integration": { + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations" + }, + "responses": {}, + } + } + }, + "/path": { + "get": { + "x-amazon-apigateway-integration": { + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations" + }, + "responses": {}, + } + } + }, + } + }, + }, + } + } + } + + provider = ApiProvider(template) + + routes = provider.routes + cors = Cors( + allow_origin="*", + allow_methods=",".join( + sorted(["GET", "DELETE", "PUT", "POST", "HEAD", "OPTIONS", "PATCH"]) + ), + ) + route1 = Route( + path="/path2", + methods=["POST", "OPTIONS"], + function_name="NoApiEventFunction", + ) + route2 = Route( + path="/path", methods=["GET", "OPTIONS"], function_name="NoApiEventFunction" + ) + + self.assertEquals(len(routes), 2) + self.assertIn(route1, routes) + self.assertIn(route2, routes) + self.assertEquals(provider.api.cors, cors) + + def test_provider_parse_cors_dict(self): + template = { + "Resources": { + "TestApi": { + "Type": "AWS::Serverless::Api", + "Properties": { + "StageName": "Prod", + "Cors": { + "AllowMethods": "POST, GET", + "AllowOrigin": "*", + "AllowHeaders": "Upgrade-Insecure-Requests", + "MaxAge": 600, + }, + "DefinitionBody": { + "paths": { + "/path2": { + "post": { + "x-amazon-apigateway-integration": { + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations" + }, + "responses": {}, + } + } + }, + "/path": { + "post": { + "x-amazon-apigateway-integration": { + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations" + }, + "responses": {}, + } + } + }, + } + }, + }, + } + } + } + + provider = ApiProvider(template) + + routes = provider.routes + cors = Cors( + allow_origin="*", + allow_methods=",".join(sorted(["POST", "GET", "OPTIONS"])), + allow_headers="Upgrade-Insecure-Requests", + max_age=600, + ) + route1 = Route( + path="/path2", + methods=["POST", "OPTIONS"], + function_name="NoApiEventFunction", + ) + route2 = Route( + path="/path", + methods=["POST", "OPTIONS"], + function_name="NoApiEventFunction", + ) + + self.assertEquals(len(routes), 2) + self.assertIn(route1, routes) + self.assertIn(route2, routes) + self.assertEquals(provider.api.cors, cors) + + def test_provider_parse_cors_dict_star_allow(self): + template = { + "Resources": { + "TestApi": { + "Type": "AWS::Serverless::Api", + "Properties": { + "StageName": "Prod", + "Cors": { + "AllowMethods": "*", + "AllowOrigin": "*", + "AllowHeaders": "Upgrade-Insecure-Requests", + "MaxAge": 600, + }, + "DefinitionBody": { + "paths": { + "/path2": { + "post": { + "x-amazon-apigateway-integration": { + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations" + }, + "responses": {}, + } + } + }, + "/path": { + "post": { + "x-amazon-apigateway-integration": { + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations" + }, + "responses": {}, + } + } + }, + } + }, + }, + } + } + } + + provider = ApiProvider(template) + + routes = provider.routes + cors = Cors( + allow_origin="*", + allow_methods=",".join(sorted(Route.ANY_HTTP_METHODS)), + allow_headers="Upgrade-Insecure-Requests", + max_age=600, + ) + route1 = Route( + path="/path2", + methods=["POST", "OPTIONS"], + function_name="NoApiEventFunction", + ) + route2 = Route( + path="/path", + methods=["POST", "OPTIONS"], + function_name="NoApiEventFunction", + ) + + self.assertEquals(len(routes), 2) + self.assertIn(route1, routes) + self.assertIn(route2, routes) + self.assertEquals(provider.api.cors, cors) + + def test_invalid_cors_dict_allow_methods(self): + template = { + "Resources": { + "TestApi": { + "Type": "AWS::Serverless::Api", + "Properties": { + "StageName": "Prod", + "Cors": { + "AllowMethods": "GET, INVALID_METHOD", + "AllowOrigin": "*", + "AllowHeaders": "Upgrade-Insecure-Requests", + "MaxAge": 600, + }, + "DefinitionBody": { + "paths": { + "/path2": { + "post": { + "x-amazon-apigateway-integration": { + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations" + }, + "responses": {}, + } + } + }, + "/path": { + "post": { + "x-amazon-apigateway-integration": { + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations" + }, + "responses": {}, + } + } + }, + } + }, + }, + } + } + } + with self.assertRaises( + InvalidSamDocumentException, + msg="ApiProvider should fail for Invalid Cors Allow method", + ): + ApiProvider(template) + + def test_default_cors_dict_prop(self): + template = { + "Resources": { + "TestApi": { + "Type": "AWS::Serverless::Api", + "Properties": { + "StageName": "Prod", + "Cors": {"AllowOrigin": "www.domain.com"}, + "DefinitionBody": { + "paths": { + "/path2": { + "get": { + "x-amazon-apigateway-integration": { + "httpMethod": "POST", + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations" + }, + "responses": {}, + } + } + } + } + }, + }, + } + } + } + + provider = ApiProvider(template) + + routes = provider.routes + cors = Cors( + allow_origin="www.domain.com", + allow_methods=",".join(sorted(Route.ANY_HTTP_METHODS)), + ) + route1 = Route( + path="/path2", + methods=["GET", "OPTIONS"], + function_name="NoApiEventFunction", + ) + self.assertEquals(len(routes), 1) + self.assertIn(route1, routes) + self.assertEquals(provider.api.cors, cors) + + def test_global_cors(self): + template = { + "Globals": { + "Api": { + "Cors": { + "AllowMethods": "GET", + "AllowOrigin": "*", + "AllowHeaders": "Upgrade-Insecure-Requests", + "MaxAge": 600, + } + } + }, + "Resources": { + "TestApi": { + "Type": "AWS::Serverless::Api", + "Properties": { + "StageName": "Prod", + "DefinitionBody": { + "paths": { + "/path2": { + "get": { + "x-amazon-apigateway-integration": { + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations" + }, + "responses": {}, + } + } + }, + "/path": { + "get": { + "x-amazon-apigateway-integration": { + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations" + }, + "responses": {}, + } + } + }, + } + }, + }, + } + }, + } + + provider = ApiProvider(template) + + routes = provider.routes + cors = Cors( + allow_origin="*", + allow_headers="Upgrade-Insecure-Requests", + allow_methods=",".join(["GET", "OPTIONS"]), + max_age=600, + ) + route1 = Route( + path="/path2", + methods=["GET", "OPTIONS"], + function_name="NoApiEventFunction", + ) + route2 = Route( + path="/path", methods=["GET", "OPTIONS"], function_name="NoApiEventFunction" + ) + + self.assertEquals(len(routes), 2) + self.assertIn(route1, routes) + self.assertIn(route2, routes) + self.assertEquals(provider.api.cors, cors) def make_swagger(routes, binary_media_types=None): @@ -1095,10 +1405,7 @@ def make_swagger(routes, binary_media_types=None): Swagger document """ - swagger = { - "paths": { - } - } + swagger = {"paths": {}} for api in routes: swagger["paths"].setdefault(api.path, {}) @@ -1107,8 +1414,9 @@ def make_swagger(routes, binary_media_types=None): "x-amazon-apigateway-integration": { "type": "aws_proxy", "uri": "arn:aws:apigateway:us-east-1:lambda:path/2015-03-31/functions/arn:aws:lambda:us-east-1" - ":123456789012:function:{}/invocations".format( - api.function_name) # NOQA + ":123456789012:function:{}/invocations".format( + api.function_name + ), # NOQA } } for method in api.methods: diff --git a/tests/unit/local/apigw/test_local_apigw_service.py b/tests/unit/local/apigw/test_local_apigw_service.py index 9bbf52cc62..6ff3ee025d 100644 --- a/tests/unit/local/apigw/test_local_apigw_service.py +++ b/tests/unit/local/apigw/test_local_apigw_service.py @@ -1,13 +1,14 @@ +import base64 import copy -from unittest import TestCase -from mock import Mock, patch, ANY import json -import base64 +from unittest import TestCase +from mock import Mock, patch, ANY, MagicMock from parameterized import parameterized, param from werkzeug.datastructures import Headers from samcli.commands.local.lib.provider import Api +from samcli.commands.local.lib.provider import Cors from samcli.local.apigw.local_apigw_service import LocalApigwService, Route from samcli.local.lambdafn.exceptions import FunctionNotFound @@ -30,21 +31,25 @@ def setUp(self): host='127.0.0.1', stderr=self.stderr) - def test_request_must_invoke_lambda(self): + @patch.object(LocalApigwService, "get_request_methods_endpoints") + def test_request_must_invoke_lambda(self, request_mock): make_response_mock = Mock() self.service.service_response = make_response_mock - self.service._get_current_route = Mock() + self.service._get_current_route = MagicMock() + self.service._get_current_route.methods = [] self.service._construct_event = Mock() parse_output_mock = Mock() - parse_output_mock.return_value = ("status_code", "headers", "body") + parse_output_mock.return_value = ("status_code", Headers({"headers": "headers"}), "body") self.service._parse_lambda_output = parse_output_mock service_response_mock = Mock() service_response_mock.return_value = make_response_mock self.service.service_response = service_response_mock + request_mock.return_value = ('test', 'test') + result = self.service._request_handler() self.assertEquals(result, make_response_mock) @@ -53,16 +58,19 @@ def test_request_must_invoke_lambda(self): stdout=ANY, stderr=self.stderr) + @patch.object(LocalApigwService, "get_request_methods_endpoints") @patch('samcli.local.apigw.local_apigw_service.LambdaOutputParser') - def test_request_handler_returns_process_stdout_when_making_response(self, lambda_output_parser_mock): + def test_request_handler_returns_process_stdout_when_making_response(self, lambda_output_parser_mock, request_mock): make_response_mock = Mock() - + request_mock.return_value = ('test', 'test') self.service.service_response = make_response_mock - self.service._get_current_route = Mock() + self.service._get_current_route = MagicMock() + self.service._get_current_route.methods = [] + self.service._construct_event = Mock() parse_output_mock = Mock() - parse_output_mock.return_value = ("status_code", "headers", "body") + parse_output_mock.return_value = ("status_code", Headers({"headers": "headers"}), "body") self.service._parse_lambda_output = parse_output_mock lambda_logs = "logs" @@ -83,21 +91,24 @@ def test_request_handler_returns_process_stdout_when_making_response(self, lambd # Make sure the logs are written to stderr self.stderr.write.assert_called_with(lambda_logs) - def test_request_handler_returns_make_response(self): + @patch.object(LocalApigwService, "get_request_methods_endpoints") + def test_request_handler_returns_make_response(self, request_mock): make_response_mock = Mock() self.service.service_response = make_response_mock - self.service._get_current_route = Mock() + self.service._get_current_route = MagicMock() self.service._construct_event = Mock() + self.service._get_current_route.methods = [] parse_output_mock = Mock() - parse_output_mock.return_value = ("status_code", "headers", "body") + parse_output_mock.return_value = ("status_code", Headers({"headers": "headers"}), "body") self.service._parse_lambda_output = parse_output_mock service_response_mock = Mock() service_response_mock.return_value = make_response_mock self.service.service_response = service_response_mock + request_mock.return_value = ('test', 'test') result = self.service._request_handler() self.assertEquals(result, make_response_mock) @@ -152,30 +163,37 @@ def test_initalize_with_values(self): self.assertEquals(local_service.static_dir, 'dir/static') self.assertEquals(local_service.lambda_runner, lambda_runner) + @patch.object(LocalApigwService, "get_request_methods_endpoints") @patch('samcli.local.apigw.local_apigw_service.ServiceErrorResponses') - def test_request_handles_error_when_invoke_cant_find_function(self, service_error_responses_patch): + def test_request_handles_error_when_invoke_cant_find_function(self, service_error_responses_patch, request_mock): not_found_response_mock = Mock() self.service._construct_event = Mock() - self.service._get_current_route = Mock() + self.service._get_current_route = MagicMock() + self.service._get_current_route.methods = [] + service_error_responses_patch.lambda_not_found_response.return_value = not_found_response_mock self.lambda_runner.invoke.side_effect = FunctionNotFound() - + request_mock.return_value = ('test', 'test') response = self.service._request_handler() self.assertEquals(response, not_found_response_mock) - def test_request_throws_when_invoke_fails(self): + @patch.object(LocalApigwService, "get_request_methods_endpoints") + def test_request_throws_when_invoke_fails(self, request_mock): self.lambda_runner.invoke.side_effect = Exception() self.service._construct_event = Mock() self.service._get_current_route = Mock() + request_mock.return_value = ('test', 'test') with self.assertRaises(Exception): self.service._request_handler() + @patch.object(LocalApigwService, "get_request_methods_endpoints") @patch('samcli.local.apigw.local_apigw_service.ServiceErrorResponses') - def test_request_handler_errors_when_parse_lambda_output_raises_keyerror(self, service_error_responses_patch): + def test_request_handler_errors_when_parse_lambda_output_raises_keyerror(self, service_error_responses_patch, + request_mock): parse_output_mock = Mock() parse_output_mock.side_effect = KeyError() self.service._parse_lambda_output = parse_output_mock @@ -185,8 +203,10 @@ def test_request_handler_errors_when_parse_lambda_output_raises_keyerror(self, s service_error_responses_patch.lambda_failure_response.return_value = failure_response_mock self.service._construct_event = Mock() - self.service._get_current_route = Mock() + self.service._get_current_route = MagicMock() + self.service._get_current_route.methods = [] + request_mock.return_value = ('test', 'test') result = self.service._request_handler() self.assertEquals(result, failure_response_mock) @@ -200,16 +220,20 @@ def test_request_handler_errors_when_get_current_route_fails(self, service_error with self.assertRaises(KeyError): self.service._request_handler() + @patch.object(LocalApigwService, "get_request_methods_endpoints") @patch('samcli.local.apigw.local_apigw_service.ServiceErrorResponses') - def test_request_handler_errors_when_unable_to_read_binary_data(self, service_error_responses_patch): + def test_request_handler_errors_when_unable_to_read_binary_data(self, service_error_responses_patch, request_mock): _construct_event = Mock() _construct_event.side_effect = UnicodeDecodeError("utf8", b"obj", 1, 2, "reason") - self.service._get_current_route = Mock() + self.service._get_current_route = MagicMock() + self.service._get_current_route.methods = [] + self.service._construct_event = _construct_event failure_mock = Mock() service_error_responses_patch.lambda_failure_response.return_value = failure_mock + request_mock.return_value = ('test', 'test') result = self.service._request_handler() self.assertEquals(result, failure_mock) @@ -589,6 +613,24 @@ def test_should_base64_encode_returns_false(self, test_case_name, binary_types, self.assertFalse(LocalApigwService._should_base64_encode(binary_types, mimetype)) +class TestServiceCorsToHeaders(TestCase): + def test_basic_conversion(self): + cors = Cors(allow_origin="*", allow_methods=','.join(["POST", "OPTIONS"]), allow_headers="UPGRADE-HEADER", + max_age=6) + headers = Cors.cors_to_headers(cors) + + self.assertEquals(headers, {'Access-Control-Allow-Origin': '*', 'Access-Control-Allow-Methods': 'POST,OPTIONS', + 'Access-Control-Allow-Headers': 'UPGRADE-HEADER', 'Access-Control-Max-Age': 6}) + + def test_empty_elements(self): + cors = Cors(allow_origin="www.domain.com", allow_methods=','.join(["GET", "POST", "OPTIONS"])) + headers = Cors.cors_to_headers(cors) + + self.assertEquals(headers, + {'Access-Control-Allow-Origin': 'www.domain.com', + 'Access-Control-Allow-Methods': 'GET,POST,OPTIONS'}) + + class TestRouteEqualsHash(TestCase): def test_route_in_list(self):