From a25cc2d129a45794098b43465a94e7b68637b6de Mon Sep 17 00:00:00 2001 From: Sanath Kumar Ramesh Date: Mon, 22 Jul 2019 14:24:40 -0700 Subject: [PATCH] Revert "feat(start-api): CloudFormation AWS::ApiGateway::RestApi support (#1238)" This reverts commit 7cb50b0174f85684e0e0e8ab15e056a309b6504b. --- samcli/commands/local/lib/api_collector.py | 215 -------- samcli/commands/local/lib/api_provider.py | 94 ---- samcli/commands/local/lib/cfn_api_provider.py | 69 --- .../local/lib/cfn_base_api_provider.py | 70 --- .../commands/local/lib/local_api_service.py | 22 +- samcli/commands/local/lib/provider.py | 57 +-- samcli/commands/local/lib/sam_api_provider.py | 462 +++++++++++++++--- .../commands/local/lib/sam_base_provider.py | 3 +- .../local/start_api/test_start_api.py | 106 ---- .../start_api/swagger-rest-api-template.yaml | 69 --- .../commands/local/lib/test_api_provider.py | 207 -------- .../local/lib/test_cfn_api_provider.py | 215 -------- .../local/lib/test_local_api_service.py | 4 +- .../local/lib/test_sam_api_provider.py | 84 ++-- 14 files changed, 459 insertions(+), 1218 deletions(-) delete mode 100644 samcli/commands/local/lib/api_collector.py delete mode 100644 samcli/commands/local/lib/api_provider.py delete mode 100644 samcli/commands/local/lib/cfn_api_provider.py delete mode 100644 samcli/commands/local/lib/cfn_base_api_provider.py delete mode 100644 tests/integration/testdata/start_api/swagger-rest-api-template.yaml delete mode 100644 tests/unit/commands/local/lib/test_api_provider.py delete mode 100644 tests/unit/commands/local/lib/test_cfn_api_provider.py diff --git a/samcli/commands/local/lib/api_collector.py b/samcli/commands/local/lib/api_collector.py deleted file mode 100644 index cbd198c6b7..0000000000 --- a/samcli/commands/local/lib/api_collector.py +++ /dev/null @@ -1,215 +0,0 @@ -""" -Class to store the API configurations in the SAM Template. This class helps store both implicit and explicit -APIs in a standardized format -""" - -import logging -from collections import namedtuple - -from six import string_types - -LOG = logging.getLogger(__name__) - - -class ApiCollector(object): - # Properties of each API. The structure is quite similar to the properties of AWS::Serverless::Api resource. - # This is intentional because it allows us to easily extend this class to support future properties on the API. - # We will store properties of Implicit APIs also in this format which converges the handling of implicit & explicit - # APIs. - Properties = namedtuple("Properties", ["apis", "binary_media_types", "cors", "stage_name", "stage_variables"]) - - def __init__(self): - # API properties stored per resource. Key is the LogicalId of the AWS::Serverless::Api resource and - # value is the properties - self.by_resource = {} - - def __iter__(self): - """ - Iterator to iterate through all the APIs stored in the collector. In each iteration, this yields the - LogicalId of the API resource and a list of APIs available in this resource. - - Yields - ------- - str - LogicalID of the AWS::Serverless::Api resource - list samcli.commands.local.lib.provider.Api - List of the API available in this resource along with additional configuration like binary media types. - """ - - for logical_id, _ in self.by_resource.items(): - yield logical_id, self._get_apis_with_config(logical_id) - - def add_apis(self, logical_id, apis): - """ - Stores the given APIs tagged under the given logicalId - - Parameters - ---------- - logical_id : str - LogicalId of the AWS::Serverless::Api resource - - apis : list of samcli.commands.local.lib.provider.Api - List of APIs available in this resource - """ - properties = self._get_properties(logical_id) - properties.apis.extend(apis) - - def add_binary_media_types(self, logical_id, binary_media_types): - """ - Stores the binary media type configuration for the API with given logical ID - - Parameters - ---------- - logical_id : str - LogicalId of the AWS::Serverless::Api resource - - binary_media_types : list of str - List of binary media types supported by this resource - - """ - properties = self._get_properties(logical_id) - - binary_media_types = binary_media_types or [] - for value in binary_media_types: - normalized_value = self._normalize_binary_media_type(value) - - # If the value is not supported, then just skip it. - if normalized_value: - properties.binary_media_types.add(normalized_value) - else: - LOG.debug("Unsupported data type of binary media type value of resource '%s'", logical_id) - - def add_stage_name(self, logical_id, stage_name): - """ - Stores the stage name for the API with the given local ID - - Parameters - ---------- - logical_id : str - LogicalId of the AWS::Serverless::Api resource - - stage_name : str - The stage_name string - - """ - properties = self._get_properties(logical_id) - properties = properties._replace(stage_name=stage_name) - self._set_properties(logical_id, properties) - - def add_stage_variables(self, logical_id, stage_variables): - """ - Stores the stage variables for the API with the given local ID - - Parameters - ---------- - logical_id : str - LogicalId of the AWS::Serverless::Api resource - - stage_variables : dict - A dictionary containing stage variables. - - """ - properties = self._get_properties(logical_id) - properties = properties._replace(stage_variables=stage_variables) - self._set_properties(logical_id, properties) - - def _get_apis_with_config(self, logical_id): - """ - Returns the list of APIs in this resource along with other extra configuration such as binary media types, - cors etc. Additional configuration is merged directly into the API data because these properties, although - defined globally, actually apply to each API. - - Parameters - ---------- - logical_id : str - Logical ID of the resource to fetch data for - - Returns - ------- - list of samcli.commands.local.lib.provider.Api - List of APIs with additional configurations for the resource with given logicalId. If there are no APIs, - then it returns an empty list - """ - - properties = self._get_properties(logical_id) - - # These configs need to be applied to each API - binary_media = sorted(list(properties.binary_media_types)) # Also sort the list to keep the ordering stable - cors = properties.cors - stage_name = properties.stage_name - stage_variables = properties.stage_variables - - result = [] - for api in properties.apis: - # Create a copy of the API with updated configuration - updated_api = api._replace(binary_media_types=binary_media, - cors=cors, - stage_name=stage_name, - stage_variables=stage_variables) - result.append(updated_api) - - return result - - def _get_properties(self, logical_id): - """ - Returns the properties of resource with given logical ID. If a resource is not found, then it returns an - empty data. - - Parameters - ---------- - logical_id : str - Logical ID of the resource - - Returns - ------- - samcli.commands.local.lib.sam_api_provider.ApiCollector.Properties - Properties object for this resource. - """ - - if logical_id not in self.by_resource: - self.by_resource[logical_id] = self.Properties(apis=[], - # Use a set() to be able to easily de-dupe - binary_media_types=set(), - cors=None, - stage_name=None, - stage_variables=None) - - return self.by_resource[logical_id] - - def _set_properties(self, logical_id, properties): - """ - Sets the properties of resource with given logical ID. If a resource is not found, it does nothing - - Parameters - ---------- - logical_id : str - Logical ID of the resource - properties : samcli.commands.local.lib.sam_api_provider.ApiCollector.Properties - Properties object for this resource. - """ - - if logical_id in self.by_resource: - self.by_resource[logical_id] = properties - - @staticmethod - def _normalize_binary_media_type(value): - """ - Converts binary media types values to the canonical format. Ex: image~1gif -> image/gif. If the value is not - a string, then this method just returns None - - Parameters - ---------- - value : str - Value to be normalized - - Returns - ------- - str or None - Normalized value. If the input was not a string, then None is returned - """ - - if not isinstance(value, string_types): - # It is possible that user specified a dict value for one of the binary media types. We just skip them - return None - - return value.replace("~1", "/") diff --git a/samcli/commands/local/lib/api_provider.py b/samcli/commands/local/lib/api_provider.py deleted file mode 100644 index afc686e166..0000000000 --- a/samcli/commands/local/lib/api_provider.py +++ /dev/null @@ -1,94 +0,0 @@ -"""Class that provides Apis from a SAM Template""" - -import logging - -from samcli.commands.local.lib.cfn_base_api_provider import CfnBaseApiProvider -from samcli.commands.local.lib.api_collector import ApiCollector -from samcli.commands.local.lib.provider import AbstractApiProvider -from samcli.commands.local.lib.sam_base_provider import SamBaseProvider -from samcli.commands.local.lib.sam_api_provider import SamApiProvider -from samcli.commands.local.lib.cfn_api_provider import CfnApiProvider - -LOG = logging.getLogger(__name__) - - -class ApiProvider(AbstractApiProvider): - - def __init__(self, template_dict, parameter_overrides=None, cwd=None): - """ - Initialize the class with SAM template data. The template_dict (SAM Templated) is assumed - to be valid, normalized and a dictionary. template_dict should be normalized by running any and all - pre-processing before passing to this class. - This class does not perform any syntactic validation of the template. - - After the class is initialized, changes to ``template_dict`` will not be reflected in here. - You will need to explicitly update the class with new template, if necessary. - - Parameters - ---------- - template_dict : dict - SAM Template as a dictionary - - cwd : str - Optional working directory with respect to which we will resolve relative path to Swagger file - """ - self.template_dict = SamBaseProvider.get_template(template_dict, parameter_overrides) - self.resources = self.template_dict.get("Resources", {}) - - LOG.debug("%d resources found in the template", len(self.resources)) - - # Store a set of apis - self.cwd = cwd - self.apis = self._extract_apis(self.resources) - - LOG.debug("%d APIs found in the template", len(self.apis)) - - def get_all(self): - """ - Yields all the Lambda functions with Api Events available in the SAM Template. - - :yields Api: namedtuple containing the Api information - """ - - for api in self.apis: - yield api - - def _extract_apis(self, resources): - """ - Extracts all the Apis by running through the one providers. The provider that has the first type matched - will be run across all the resources - - Parameters - ---------- - resources: dict - The dictionary containing the different resources within the template - Returns - --------- - list of Apis extracted from the resources - """ - collector = ApiCollector() - provider = self.find_api_provider(resources) - apis = provider.extract_resource_api(resources, collector, cwd=self.cwd) - return self.normalize_apis(apis) - - @staticmethod - def find_api_provider(resources): - """ - Finds the ApiProvider given the first api type of the resource - - Parameters - ----------- - resources: dict - The dictionary containing the different resources within the template - - Return - ---------- - Instance of the ApiProvider that will be run on the template with a default of SamApiProvider - """ - for _, resource in resources.items(): - if resource.get(CfnBaseApiProvider.RESOURCE_TYPE) in SamApiProvider.TYPES: - return SamApiProvider() - elif resource.get(CfnBaseApiProvider.RESOURCE_TYPE) in CfnApiProvider.TYPES: - return CfnApiProvider() - - return SamApiProvider() diff --git a/samcli/commands/local/lib/cfn_api_provider.py b/samcli/commands/local/lib/cfn_api_provider.py deleted file mode 100644 index 0e3919611c..0000000000 --- a/samcli/commands/local/lib/cfn_api_provider.py +++ /dev/null @@ -1,69 +0,0 @@ -"""Parses SAM given a template""" -import logging - -from samcli.commands.local.lib.cfn_base_api_provider import CfnBaseApiProvider - -LOG = logging.getLogger(__name__) - - -class CfnApiProvider(CfnBaseApiProvider): - APIGATEWAY_RESTAPI = "AWS::ApiGateway::RestApi" - TYPES = [ - APIGATEWAY_RESTAPI - ] - - def extract_resource_api(self, resources, collector, cwd=None): - """ - Extract the Api Object from a given resource and adds it to the ApiCollector. - - Parameters - ---------- - resources: dict - The dictionary containing the different resources within the template - - collector: ApiCollector - Instance of the API collector that where we will save the API information - - cwd : str - Optional working directory with respect to which we will resolve relative path to Swagger file - - Return - ------- - Returns a list of Apis - """ - for logical_id, resource in resources.items(): - resource_type = resource.get(CfnBaseApiProvider.RESOURCE_TYPE) - if resource_type == CfnApiProvider.APIGATEWAY_RESTAPI: - self._extract_cloud_formation_api(logical_id, resource, collector, cwd) - all_apis = [] - for _, apis in collector: - all_apis.extend(apis) - return all_apis - - def _extract_cloud_formation_api(self, logical_id, api_resource, collector, cwd=None): - """ - Extract APIs from AWS::ApiGateway::RestApi resource by reading and parsing Swagger documents. The result is - added to the collector. - - Parameters - ---------- - logical_id : str - Logical ID of the resource - - api_resource : dict - Resource definition, including its properties - - collector : ApiCollector - Instance of the API collector that where we will save the API information - """ - properties = api_resource.get("Properties", {}) - body = properties.get("Body") - body_s3_location = properties.get("BodyS3Location") - binary_media = properties.get("BinaryMediaTypes", []) - - if not body and not body_s3_location: - # Swagger is not found anywhere. - LOG.debug("Skipping resource '%s'. Swagger document not found in Body and BodyS3Location", - logical_id) - return - self.extract_swagger_api(logical_id, body, body_s3_location, binary_media, collector, cwd) diff --git a/samcli/commands/local/lib/cfn_base_api_provider.py b/samcli/commands/local/lib/cfn_base_api_provider.py deleted file mode 100644 index 79bc6d8f1d..0000000000 --- a/samcli/commands/local/lib/cfn_base_api_provider.py +++ /dev/null @@ -1,70 +0,0 @@ -"""Class that parses the CloudFormation Api Template""" - -import logging - -from samcli.commands.local.lib.swagger.parser import SwaggerParser -from samcli.commands.local.lib.swagger.reader import SamSwaggerReader - -LOG = logging.getLogger(__name__) - - -class CfnBaseApiProvider(object): - RESOURCE_TYPE = "Type" - - def extract_resource_api(self, resources, collector, cwd=None): - """ - Extract the Api Object from a given resource and adds it to the ApiCollector. - - Parameters - ---------- - resources: dict - The dictionary containing the different resources within the template - - collector: ApiCollector - Instance of the API collector that where we will save the API information - - cwd : str - Optional working directory with respect to which we will resolve relative path to Swagger file - - Return - ------- - Returns a list of Apis - """ - raise NotImplementedError("not implemented") - - @staticmethod - def extract_swagger_api(logical_id, body, uri, binary_media, collector, cwd=None): - """ - Parse the Swagger documents and adds it to the ApiCollector. - - Parameters - ---------- - logical_id : str - Logical ID of the resource - - body : dict - The body of the RestApi - - uri : str or dict - The url to location of the RestApi - - binary_media: list - The link to the binary media - - collector: ApiCollector - Instance of the API collector that where we will save the API information - - cwd : str - Optional working directory with respect to which we will resolve relative path to Swagger file - """ - reader = SamSwaggerReader(definition_body=body, - definition_uri=uri, - working_dir=cwd) - swagger = reader.read() - parser = SwaggerParser(swagger) - apis = parser.get_apis() - LOG.debug("Found '%s' APIs in resource '%s'", len(apis), logical_id) - - collector.add_apis(logical_id, apis) - collector.add_binary_media_types(logical_id, parser.get_binary_media_types()) # Binary media from swagger - collector.add_binary_media_types(logical_id, binary_media) # Binary media specified on resource in template diff --git a/samcli/commands/local/lib/local_api_service.py b/samcli/commands/local/lib/local_api_service.py index d456e67a83..d0ebbbd975 100644 --- a/samcli/commands/local/lib/local_api_service.py +++ b/samcli/commands/local/lib/local_api_service.py @@ -6,7 +6,7 @@ import logging from samcli.local.apigw.local_apigw_service import LocalApigwService, Route -from samcli.commands.local.lib.api_provider import ApiProvider +from samcli.commands.local.lib.sam_api_provider import SamApiProvider from samcli.commands.local.lib.exceptions import NoApisDefined LOG = logging.getLogger(__name__) @@ -38,9 +38,9 @@ def __init__(self, self.static_dir = static_dir self.cwd = lambda_invoke_context.get_cwd() - self.api_provider = ApiProvider(lambda_invoke_context.template, - parameter_overrides=lambda_invoke_context.parameter_overrides, - cwd=self.cwd) + self.api_provider = SamApiProvider(lambda_invoke_context.template, + parameter_overrides=lambda_invoke_context.parameter_overrides, + cwd=self.cwd) self.lambda_runner = lambda_invoke_context.local_lambda_runner self.stderr_stream = lambda_invoke_context.stderr @@ -89,7 +89,7 @@ def _make_routing_list(api_provider): Parameters ---------- - api_provider : samcli.commands.local.lib.api_provider.ApiProvider + api_provider : samcli.commands.local.lib.sam_api_provider.SamApiProvider Returns ------- @@ -116,14 +116,10 @@ def _print_routes(api_provider, host, port): Mounting Product at http://127.0.0.1:3000/path1/bar [GET, POST, DELETE] Mounting Product at http://127.0.0.1:3000/path2/bar [HEAD] - :param samcli.commands.local.lib.provider.AbstractApiProvider api_provider: - API Provider that can return a list of APIs - :param string host: - Host name where the service is running - :param int port: - Port number where the service is running - :returns list(string): - List of lines that were printed to the console. Helps with testing + :param samcli.commands.local.lib.provider.ApiProvider api_provider: API Provider that can return a list of APIs + :param string host: Host name where the service is running + :param int port: Port number where the service is running + :returns list(string): List of lines that were printed to the console. Helps with testing """ grouped_api_configs = {} diff --git a/samcli/commands/local/lib/provider.py b/samcli/commands/local/lib/provider.py index 959166e814..eead981089 100644 --- a/samcli/commands/local/lib/provider.py +++ b/samcli/commands/local/lib/provider.py @@ -222,10 +222,10 @@ def get_all(self): # The variables for that stage "stage_variables" ]) -_ApiTuple.__new__.__defaults__ = (None, # Cors is optional and defaults to None - [], # binary_media_types is optional and defaults to empty, - None, # Stage name is optional with default None - None # Stage variables is optional with default None +_ApiTuple.__new__.__defaults__ = (None, # Cors is optional and defaults to None + [], # binary_media_types is optional and defaults to empty, + None, # Stage name is optional with default None + None # Stage variables is optional with default None ) @@ -238,17 +238,10 @@ def __hash__(self): Cors = namedtuple("Cors", ["AllowOrigin", "AllowMethods", "AllowHeaders"]) -class AbstractApiProvider(object): +class ApiProvider(object): """ Abstract base class to return APIs and the functions they route to """ - _ANY_HTTP_METHODS = ["GET", - "DELETE", - "PUT", - "POST", - "HEAD", - "OPTIONS", - "PATCH"] def get_all(self): """ @@ -257,43 +250,3 @@ def get_all(self): :yields Api: namedtuple containing the API information """ raise NotImplementedError("not implemented") - - @staticmethod - def normalize_http_methods(http_method): - """ - Normalizes Http Methods. Api Gateway allows a Http Methods of ANY. This is a special verb to denote all - supported Http Methods on Api Gateway. - - :param str http_method: Http method - :yield str: Either the input http_method or one of the _ANY_HTTP_METHODS (normalized Http Methods) - """ - - if http_method.upper() == 'ANY': - for method in AbstractApiProvider._ANY_HTTP_METHODS: - yield method.upper() - else: - yield http_method.upper() - - @staticmethod - def normalize_apis(apis): - """ - Normalize the APIs to use standard method name - - Parameters - ---------- - apis : list of samcli.commands.local.lib.provider.Api - List of APIs to replace normalize - - Returns - ------- - list of samcli.commands.local.lib.provider.Api - List of normalized APIs - """ - - result = list() - for api in apis: - for normalized_method in AbstractApiProvider.normalize_http_methods(api.method): - # _replace returns a copy of the namedtuple. This is the official way of creating copies of namedtuple - result.append(api._replace(method=normalized_method)) - - return result diff --git a/samcli/commands/local/lib/sam_api_provider.py b/samcli/commands/local/lib/sam_api_provider.py index f0ec57b823..84336a8d2d 100644 --- a/samcli/commands/local/lib/sam_api_provider.py +++ b/samcli/commands/local/lib/sam_api_provider.py @@ -1,60 +1,111 @@ -"""Parses SAM given the template""" +"""Class that provides Apis from a SAM Template""" import logging +from collections import namedtuple -from samcli.commands.local.lib.provider import Api, AbstractApiProvider +from six import string_types + +from samcli.commands.local.lib.swagger.parser import SwaggerParser +from samcli.commands.local.lib.provider import ApiProvider, Api +from samcli.commands.local.lib.sam_base_provider import SamBaseProvider +from samcli.commands.local.lib.swagger.reader import SamSwaggerReader from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException -from samcli.commands.local.lib.cfn_base_api_provider import CfnBaseApiProvider LOG = logging.getLogger(__name__) -class SamApiProvider(CfnBaseApiProvider): - SERVERLESS_FUNCTION = "AWS::Serverless::Function" - SERVERLESS_API = "AWS::Serverless::Api" - TYPES = [ - SERVERLESS_FUNCTION, - SERVERLESS_API - ] +class SamApiProvider(ApiProvider): + _IMPLICIT_API_RESOURCE_ID = "ServerlessRestApi" + _SERVERLESS_FUNCTION = "AWS::Serverless::Function" + _SERVERLESS_API = "AWS::Serverless::Api" + _TYPE = "Type" + _FUNCTION_EVENT_TYPE_API = "Api" _FUNCTION_EVENT = "Events" _EVENT_PATH = "Path" _EVENT_METHOD = "Method" - _EVENT_TYPE = "Type" - IMPLICIT_API_RESOURCE_ID = "ServerlessRestApi" - def extract_resource_api(self, resources, collector, cwd=None): + _ANY_HTTP_METHODS = ["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"] + + def __init__(self, template_dict, parameter_overrides=None, cwd=None): """ - Extract the Api Object from a given resource and adds it to the ApiCollector. + Initialize the class with SAM template data. The template_dict (SAM Templated) is assumed + to be valid, normalized and a dictionary. template_dict should be normalized by running any and all + pre-processing before passing to this class. + This class does not perform any syntactic validation of the template. + + After the class is initialized, changes to ``template_dict`` will not be reflected in here. + You will need to explicitly update the class with new template, if necessary. Parameters ---------- - resources: dict - The dictionary containing the different resources within the template - - collector: ApiCollector - Instance of the API collector that where we will save the API information - + template_dict : dict + SAM Template as a dictionary cwd : str Optional working directory with respect to which we will resolve relative path to Swagger file + """ - Return - ------- - Returns a list of Apis + self.template_dict = SamBaseProvider.get_template(template_dict, parameter_overrides) + self.resources = self.template_dict.get("Resources", {}) + + LOG.debug("%d resources found in the template", len(self.resources)) + + # Store a set of apis + self.cwd = cwd + self.apis = self._extract_apis(self.resources) + + LOG.debug("%d APIs found in the template", len(self.apis)) + + def get_all(self): """ - # AWS::Serverless::Function is currently included when parsing of Apis because when SamBaseProvider is run on - # the template we are creating the implicit apis due to plugins that translate it in the SAM repo, - # which we later merge with the explicit ones in SamApiProvider.merge_apis. This requires the code to be - # parsed here and in InvokeContext. + Yields all the Lambda functions with Api Events available in the SAM Template. + + :yields Api: namedtuple containing the Api information + """ + + for api in self.apis: + yield api + + def _extract_apis(self, resources): + """ + Extract all Implicit Apis (Apis defined through Serverless Function with an Api Event + + :param dict resources: Dictionary of SAM/CloudFormation resources + :return: List of nametuple Api + """ + + # Some properties like BinaryMediaTypes, Cors are set once on the resource but need to be applied to each API. + # For Implicit APIs, which are defined on the Function resource, these properties + # are defined on a AWS::Serverless::Api resource with logical ID "ServerlessRestApi". Therefore, no matter + # if it is an implicit API or an explicit API, there is a corresponding resource of type AWS::Serverless::Api + # that contains these additional configurations. + # + # We use this assumption in the following loop to collect information from resources of type + # AWS::Serverless::Api. We also extract API from Serverless::Function resource and add them to the + # corresponding Serverless::Api resource. This is all done using the ``collector``. + + collector = ApiCollector() + for logical_id, resource in resources.items(): - resource_type = resource.get(CfnBaseApiProvider.RESOURCE_TYPE) - if resource_type == SamApiProvider.SERVERLESS_FUNCTION: + + resource_type = resource.get(SamApiProvider._TYPE) + + if resource_type == SamApiProvider._SERVERLESS_FUNCTION: self._extract_apis_from_function(logical_id, resource, collector) - if resource_type == SamApiProvider.SERVERLESS_API: - self._extract_from_serverless_api(logical_id, resource, collector, cwd) - return self.merge_apis(collector) - def _extract_from_serverless_api(self, logical_id, api_resource, collector, cwd=None): + if resource_type == SamApiProvider._SERVERLESS_API: + self._extract_from_serverless_api(logical_id, resource, collector) + + apis = SamApiProvider._merge_apis(collector) + return self._normalize_apis(apis) + + def _extract_from_serverless_api(self, logical_id, api_resource, collector): """ Extract APIs from AWS::Serverless::Api resource by reading and parsing Swagger documents. The result is added to the collector. @@ -83,11 +134,99 @@ def _extract_from_serverless_api(self, logical_id, api_resource, collector, cwd= LOG.debug("Skipping resource '%s'. Swagger document not found in DefinitionBody and DefinitionUri", logical_id) return - self.extract_swagger_api(logical_id, body, uri, binary_media, collector, cwd) + + reader = SamSwaggerReader(definition_body=body, + definition_uri=uri, + working_dir=self.cwd) + swagger = reader.read() + parser = SwaggerParser(swagger) + apis = parser.get_apis() + LOG.debug("Found '%s' APIs in resource '%s'", len(apis), logical_id) + + collector.add_apis(logical_id, apis) + collector.add_binary_media_types(logical_id, parser.get_binary_media_types()) # Binary media from swagger + collector.add_binary_media_types(logical_id, binary_media) # Binary media specified on resource in template + collector.add_stage_name(logical_id, stage_name) collector.add_stage_variables(logical_id, stage_variables) - def _extract_apis_from_function(self, logical_id, function_resource, collector): + @staticmethod + def _merge_apis(collector): + """ + Quite often, an API is defined both in Implicit and Explicit API definitions. In such cases, Implicit API + definition wins because that conveys clear intent that the API is backed by a function. This method will + merge two such list of Apis with the right order of precedence. If a Path+Method combination is defined + in both the places, only one wins. + + Parameters + ---------- + collector : ApiCollector + Collector object that holds all the APIs specified in the template + + Returns + ------- + list of samcli.commands.local.lib.provider.Api + List of APIs obtained by combining both the input lists. + """ + + implicit_apis = [] + explicit_apis = [] + + # Store implicit and explicit APIs separately in order to merge them later in the correct order + # Implicit APIs are defined on a resource with logicalID ServerlessRestApi + for logical_id, apis in collector: + if logical_id == SamApiProvider._IMPLICIT_API_RESOURCE_ID: + implicit_apis.extend(apis) + else: + explicit_apis.extend(apis) + + # We will use "path+method" combination as key to this dictionary and store the Api config for this combination. + # If an path+method combo already exists, then overwrite it if and only if this is an implicit API + all_apis = {} + + # By adding implicit APIs to the end of the list, they will be iterated last. If a configuration was already + # written by explicit API, it will be overriden by implicit API, just by virtue of order of iteration. + all_configs = explicit_apis + implicit_apis + + for config in all_configs: + # Normalize the methods before de-duping to allow an ANY method in implicit API to override a regular HTTP + # method on explicit API. + for normalized_method in SamApiProvider._normalize_http_methods(config.method): + key = config.path + normalized_method + all_apis[key] = config + + result = set(all_apis.values()) # Assign to a set() to de-dupe + LOG.debug("Removed duplicates from '%d' Explicit APIs and '%d' Implicit APIs to produce '%d' APIs", + len(explicit_apis), len(implicit_apis), len(result)) + + return list(result) + + @staticmethod + def _normalize_apis(apis): + """ + Normalize the APIs to use standard method name + + Parameters + ---------- + apis : list of samcli.commands.local.lib.provider.Api + List of APIs to replace normalize + + Returns + ------- + list of samcli.commands.local.lib.provider.Api + List of normalized APIs + """ + + result = list() + for api in apis: + for normalized_method in SamApiProvider._normalize_http_methods(api.method): + # _replace returns a copy of the namedtuple. This is the official way of creating copies of namedtuple + result.append(api._replace(method=normalized_method)) + + return result + + @staticmethod + def _extract_apis_from_function(logical_id, function_resource, collector): """ Fetches a list of APIs configured for this SAM Function resource. @@ -104,10 +243,11 @@ def _extract_apis_from_function(self, logical_id, function_resource, collector): """ resource_properties = function_resource.get("Properties", {}) - serverless_function_events = resource_properties.get(self._FUNCTION_EVENT, {}) - self.extract_apis_from_events(logical_id, serverless_function_events, collector) + serverless_function_events = resource_properties.get(SamApiProvider._FUNCTION_EVENT, {}) + SamApiProvider._extract_apis_from_events(logical_id, serverless_function_events, collector) - def extract_apis_from_events(self, function_logical_id, serverless_function_events, collector): + @staticmethod + def _extract_apis_from_events(function_logical_id, serverless_function_events, collector): """ Given an AWS::Serverless::Function Event Dictionary, extract out all 'Api' events and store within the collector @@ -126,8 +266,8 @@ def extract_apis_from_events(self, function_logical_id, serverless_function_even count = 0 for _, event in serverless_function_events.items(): - if self._FUNCTION_EVENT_TYPE_API == event.get(self._EVENT_TYPE): - api_resource_id, api = self._convert_event_api(function_logical_id, event.get("Properties")) + if SamApiProvider._FUNCTION_EVENT_TYPE_API == event.get(SamApiProvider._TYPE): + api_resource_id, api = SamApiProvider._convert_event_api(function_logical_id, event.get("Properties")) collector.add_apis(api_resource_id, [api]) count += 1 @@ -148,7 +288,7 @@ def _convert_event_api(lambda_logical_id, event_properties): # An API Event, can have RestApiId property which designates the resource that owns this API. If omitted, # the API is owned by Implicit API resource. This could either be a direct resource logical ID or a # "Ref" of the logicalID - api_resource_id = event_properties.get("RestApiId", SamApiProvider.IMPLICIT_API_RESOURCE_ID) + api_resource_id = event_properties.get("RestApiId", SamApiProvider._IMPLICIT_API_RESOURCE_ID) if isinstance(api_resource_id, dict) and "Ref" in api_resource_id: api_resource_id = api_resource_id["Ref"] @@ -162,52 +302,226 @@ def _convert_event_api(lambda_logical_id, event_properties): return api_resource_id, Api(path=path, method=method, function_name=lambda_logical_id) @staticmethod - def merge_apis(collector): + def _normalize_http_methods(http_method): """ - Quite often, an API is defined both in Implicit and Explicit API definitions. In such cases, Implicit API - definition wins because that conveys clear intent that the API is backed by a function. This method will - merge two such list of Apis with the right order of precedence. If a Path+Method combination is defined - in both the places, only one wins. + Normalizes Http Methods. Api Gateway allows a Http Methods of ANY. This is a special verb to denote all + supported Http Methods on Api Gateway. + + :param str http_method: Http method + :yield str: Either the input http_method or one of the _ANY_HTTP_METHODS (normalized Http Methods) + """ + + if http_method.upper() == 'ANY': + for method in SamApiProvider._ANY_HTTP_METHODS: + yield method.upper() + else: + yield http_method.upper() + + +class ApiCollector(object): + """ + Class to store the API configurations in the SAM Template. This class helps store both implicit and explicit + APIs in a standardized format + """ + + # Properties of each API. The structure is quite similar to the properties of AWS::Serverless::Api resource. + # This is intentional because it allows us to easily extend this class to support future properties on the API. + # We will store properties of Implicit APIs also in this format which converges the handling of implicit & explicit + # APIs. + Properties = namedtuple("Properties", ["apis", "binary_media_types", "cors", "stage_name", "stage_variables"]) + + def __init__(self): + # API properties stored per resource. Key is the LogicalId of the AWS::Serverless::Api resource and + # value is the properties + self.by_resource = {} + + def __iter__(self): + """ + Iterator to iterate through all the APIs stored in the collector. In each iteration, this yields the + LogicalId of the API resource and a list of APIs available in this resource. + + Yields + ------- + str + LogicalID of the AWS::Serverless::Api resource + list samcli.commands.local.lib.provider.Api + List of the API available in this resource along with additional configuration like binary media types. + """ + + for logical_id, _ in self.by_resource.items(): + yield logical_id, self._get_apis_with_config(logical_id) + + def add_apis(self, logical_id, apis): + """ + Stores the given APIs tagged under the given logicalId Parameters ---------- - collector : ApiCollector - Collector object that holds all the APIs specified in the template + logical_id : str + LogicalId of the AWS::Serverless::Api resource + + apis : list of samcli.commands.local.lib.provider.Api + List of APIs available in this resource + """ + properties = self._get_properties(logical_id) + properties.apis.extend(apis) + + def add_binary_media_types(self, logical_id, binary_media_types): + """ + Stores the binary media type configuration for the API with given logical ID + + Parameters + ---------- + logical_id : str + LogicalId of the AWS::Serverless::Api resource + + binary_media_types : list of str + List of binary media types supported by this resource + + """ + properties = self._get_properties(logical_id) + + binary_media_types = binary_media_types or [] + for value in binary_media_types: + normalized_value = self._normalize_binary_media_type(value) + + # If the value is not supported, then just skip it. + if normalized_value: + properties.binary_media_types.add(normalized_value) + else: + LOG.debug("Unsupported data type of binary media type value of resource '%s'", logical_id) + + def add_stage_name(self, logical_id, stage_name): + """ + Stores the stage name for the API with the given local ID + + Parameters + ---------- + logical_id : str + LogicalId of the AWS::Serverless::Api resource + + stage_name : str + The stage_name string + + """ + properties = self._get_properties(logical_id) + properties = properties._replace(stage_name=stage_name) + self._set_properties(logical_id, properties) + + def add_stage_variables(self, logical_id, stage_variables): + """ + Stores the stage variables for the API with the given local ID + + Parameters + ---------- + logical_id : str + LogicalId of the AWS::Serverless::Api resource + + stage_variables : dict + A dictionary containing stage variables. + + """ + properties = self._get_properties(logical_id) + properties = properties._replace(stage_variables=stage_variables) + self._set_properties(logical_id, properties) + + def _get_apis_with_config(self, logical_id): + """ + Returns the list of APIs in this resource along with other extra configuration such as binary media types, + cors etc. Additional configuration is merged directly into the API data because these properties, although + defined globally, actually apply to each API. + + Parameters + ---------- + logical_id : str + Logical ID of the resource to fetch data for Returns ------- list of samcli.commands.local.lib.provider.Api - List of APIs obtained by combining both the input lists. + List of APIs with additional configurations for the resource with given logicalId. If there are no APIs, + then it returns an empty list """ - implicit_apis = [] - explicit_apis = [] + properties = self._get_properties(logical_id) - # Store implicit and explicit APIs separately in order to merge them later in the correct order - # Implicit APIs are defined on a resource with logicalID ServerlessRestApi - for logical_id, apis in collector: - if logical_id == SamApiProvider.IMPLICIT_API_RESOURCE_ID: - implicit_apis.extend(apis) - else: - explicit_apis.extend(apis) + # These configs need to be applied to each API + binary_media = sorted(list(properties.binary_media_types)) # Also sort the list to keep the ordering stable + cors = properties.cors + stage_name = properties.stage_name + stage_variables = properties.stage_variables - # We will use "path+method" combination as key to this dictionary and store the Api config for this combination. - # If an path+method combo already exists, then overwrite it if and only if this is an implicit API - all_apis = {} + result = [] + for api in properties.apis: + # Create a copy of the API with updated configuration + updated_api = api._replace(binary_media_types=binary_media, + cors=cors, + stage_name=stage_name, + stage_variables=stage_variables) + result.append(updated_api) - # By adding implicit APIs to the end of the list, they will be iterated last. If a configuration was already - # written by explicit API, it will be overriden by implicit API, just by virtue of order of iteration. - all_configs = explicit_apis + implicit_apis + return result - for config in all_configs: - # Normalize the methods before de-duping to allow an ANY method in implicit API to override a regular HTTP - # method on explicit API. - for normalized_method in AbstractApiProvider.normalize_http_methods(config.method): - key = config.path + normalized_method - all_apis[key] = config + def _get_properties(self, logical_id): + """ + Returns the properties of resource with given logical ID. If a resource is not found, then it returns an + empty data. - result = set(all_apis.values()) # Assign to a set() to de-dupe - LOG.debug("Removed duplicates from '%d' Explicit APIs and '%d' Implicit APIs to produce '%d' APIs", - len(explicit_apis), len(implicit_apis), len(result)) + Parameters + ---------- + logical_id : str + Logical ID of the resource - return list(result) + Returns + ------- + samcli.commands.local.lib.sam_api_provider.ApiCollector.Properties + Properties object for this resource. + """ + + if logical_id not in self.by_resource: + self.by_resource[logical_id] = self.Properties(apis=[], + # Use a set() to be able to easily de-dupe + binary_media_types=set(), + cors=None, + stage_name=None, + stage_variables=None) + + return self.by_resource[logical_id] + + def _set_properties(self, logical_id, properties): + """ + Sets the properties of resource with given logical ID. If a resource is not found, it does nothing + + Parameters + ---------- + logical_id : str + Logical ID of the resource + properties : samcli.commands.local.lib.sam_api_provider.ApiCollector.Properties + Properties object for this resource. + """ + + if logical_id in self.by_resource: + self.by_resource[logical_id] = properties + + @staticmethod + def _normalize_binary_media_type(value): + """ + Converts binary media types values to the canonical format. Ex: image~1gif -> image/gif. If the value is not + a string, then this method just returns None + + Parameters + ---------- + value : str + Value to be normalized + + Returns + ------- + str or None + Normalized value. If the input was not a string, then None is returned + """ + + if not isinstance(value, string_types): + # It is possible that user specified a dict value for one of the binary media types. We just skip them + return None + + return value.replace("~1", "/") diff --git a/samcli/commands/local/lib/sam_base_provider.py b/samcli/commands/local/lib/sam_base_provider.py index 861e1fd47a..bbf4d6381b 100644 --- a/samcli/commands/local/lib/sam_base_provider.py +++ b/samcli/commands/local/lib/sam_base_provider.py @@ -10,6 +10,7 @@ from samcli.lib.samlib.wrapper import SamTranslatorWrapper from samcli.lib.samlib.resource_metadata_normalizer import ResourceMetadataNormalizer + LOG = logging.getLogger(__name__) @@ -88,7 +89,7 @@ def _resolve_parameters(template_dict, parameter_overrides): supported_intrinsics = {action.intrinsic_name: action() for action in SamBaseProvider._SUPPORTED_INTRINSICS} # Intrinsics resolver will mutate the original template - return IntrinsicsResolver(parameters=parameter_values, supported_intrinsics=supported_intrinsics) \ + return IntrinsicsResolver(parameters=parameter_values, supported_intrinsics=supported_intrinsics)\ .resolve_parameter_refs(template_dict) @staticmethod diff --git a/tests/integration/local/start_api/test_start_api.py b/tests/integration/local/start_api/test_start_api.py index 321741e0bf..700491260d 100644 --- a/tests/integration/local/start_api/test_start_api.py +++ b/tests/integration/local/start_api/test_start_api.py @@ -288,112 +288,6 @@ def test_binary_response(self): self.assertEquals(response.content, expected) -class TestStartApiWithSwaggerRestApis(StartApiIntegBaseClass): - template_path = "/testdata/start_api/swagger-rest-api-template.yaml" - binary_data_file = "testdata/start_api/binarydata.gif" - - def setUp(self): - self.url = "http://127.0.0.1:{}".format(self.port) - - def test_get_call_with_path_setup_with_any_swagger(self): - """ - Get Request to a path that was defined as ANY in SAM through Swagger - """ - response = requests.get(self.url + "/anyandall") - - self.assertEquals(response.status_code, 200) - self.assertEquals(response.json(), {'hello': 'world'}) - - def test_post_call_with_path_setup_with_any_swagger(self): - """ - Post Request to a path that was defined as ANY in SAM through Swagger - """ - response = requests.post(self.url + "/anyandall", json={}) - - self.assertEquals(response.status_code, 200) - self.assertEquals(response.json(), {'hello': 'world'}) - - def test_put_call_with_path_setup_with_any_swagger(self): - """ - Put Request to a path that was defined as ANY in SAM through Swagger - """ - response = requests.put(self.url + "/anyandall", json={}) - - self.assertEquals(response.status_code, 200) - self.assertEquals(response.json(), {'hello': 'world'}) - - def test_head_call_with_path_setup_with_any_swagger(self): - """ - Head Request to a path that was defined as ANY in SAM through Swagger - """ - response = requests.head(self.url + "/anyandall") - - self.assertEquals(response.status_code, 200) - - def test_delete_call_with_path_setup_with_any_swagger(self): - """ - Delete Request to a path that was defined as ANY in SAM through Swagger - """ - response = requests.delete(self.url + "/anyandall") - - self.assertEquals(response.status_code, 200) - self.assertEquals(response.json(), {'hello': 'world'}) - - def test_options_call_with_path_setup_with_any_swagger(self): - """ - Options Request to a path that was defined as ANY in SAM through Swagger - """ - response = requests.options(self.url + "/anyandall") - - self.assertEquals(response.status_code, 200) - - def test_patch_call_with_path_setup_with_any_swagger(self): - """ - Patch Request to a path that was defined as ANY in SAM through Swagger - """ - response = requests.patch(self.url + "/anyandall") - - self.assertEquals(response.status_code, 200) - self.assertEquals(response.json(), {'hello': 'world'}) - - def test_function_not_defined_in_template(self): - response = requests.get(self.url + "/nofunctionfound") - - self.assertEquals(response.status_code, 502) - self.assertEquals(response.json(), {"message": "No function defined for resource method"}) - - def test_lambda_function_resource_is_reachable(self): - response = requests.get(self.url + "/nonserverlessfunction") - - self.assertEquals(response.status_code, 200) - self.assertEquals(response.json(), {'hello': 'world'}) - - def test_binary_request(self): - """ - This tests that the service can accept and invoke a lambda when given binary data in a request - """ - input_data = self.get_binary_data(self.binary_data_file) - response = requests.post(self.url + '/echobase64eventbody', - headers={"Content-Type": "image/gif"}, - data=input_data) - - self.assertEquals(response.status_code, 200) - self.assertEquals(response.headers.get("Content-Type"), "image/gif") - self.assertEquals(response.content, input_data) - - def test_binary_response(self): - """ - Binary data is returned correctly - """ - expected = self.get_binary_data(self.binary_data_file) - - response = requests.get(self.url + '/base64response') - - self.assertEquals(response.status_code, 200) - self.assertEquals(response.headers.get("Content-Type"), "image/gif") - self.assertEquals(response.content, expected) - - class TestServiceResponses(StartApiIntegBaseClass): """ Test Class centered around the different responses that can happen in Lambda and pass through start-api diff --git a/tests/integration/testdata/start_api/swagger-rest-api-template.yaml b/tests/integration/testdata/start_api/swagger-rest-api-template.yaml deleted file mode 100644 index 5edeb8717f..0000000000 --- a/tests/integration/testdata/start_api/swagger-rest-api-template.yaml +++ /dev/null @@ -1,69 +0,0 @@ -AWSTemplateFormatVersion: '2010-09-09' - -Resources: - Base64ResponseFunction: - Properties: - Code: "." - Handler: main.base64_response - Runtime: python3.6 - Type: AWS::Lambda::Function - EchoBase64EventBodyFunction: - Properties: - Code: "." - Handler: main.echo_base64_event_body - Runtime: python3.6 - Type: AWS::Lambda::Function - MyApi: - Properties: - Body: - info: - title: - Ref: AWS::StackName - paths: - "/anyandall": - x-amazon-apigateway-any-method: - x-amazon-apigateway-integration: - httpMethod: POST - responses: {} - type: aws_proxy - uri: - Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${MyNonServerlessLambdaFunction.Arn}/invocations - "/base64response": - get: - x-amazon-apigateway-integration: - httpMethod: POST - type: aws_proxy - uri: - Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${Base64ResponseFunction.Arn}/invocations - "/echobase64eventbody": - post: - x-amazon-apigateway-integration: - httpMethod: POST - type: aws_proxy - uri: - Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${EchoBase64EventBodyFunction.Arn}/invocations - "/nofunctionfound": - get: - x-amazon-apigateway-integration: - httpMethod: POST - type: aws_proxy - uri: - Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${WhatFunction.Arn}/invocations - "/nonserverlessfunction": - get: - x-amazon-apigateway-integration: - httpMethod: POST - type: aws_proxy - uri: - Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${MyNonServerlessLambdaFunction.Arn}/invocations - swagger: '2.0' - x-amazon-apigateway-binary-media-types: - - image/gif - StageName: prod - Type: AWS::ApiGateway::RestApi - MyNonServerlessLambdaFunction: - Properties: - Code: "." - Handler: main.handler - Runtime: python3.6 - Type: AWS::Lambda::Function diff --git a/tests/unit/commands/local/lib/test_api_provider.py b/tests/unit/commands/local/lib/test_api_provider.py deleted file mode 100644 index 50b8d073d4..0000000000 --- a/tests/unit/commands/local/lib/test_api_provider.py +++ /dev/null @@ -1,207 +0,0 @@ -from collections import OrderedDict -from unittest import TestCase - -from mock import patch - -from samcli.commands.local.lib.api_provider import ApiProvider -from samcli.commands.local.lib.sam_api_provider import SamApiProvider -from samcli.commands.local.lib.cfn_api_provider import CfnApiProvider - - -class TestApiProvider_init(TestCase): - - @patch.object(ApiProvider, "_extract_apis") - @patch("samcli.commands.local.lib.api_provider.SamBaseProvider") - def test_provider_with_valid_template(self, SamBaseProviderMock, extract_api_mock): - extract_api_mock.return_value = {"set", "of", "values"} - - template = {"Resources": {"a": "b"}} - SamBaseProviderMock.get_template.return_value = template - - provider = ApiProvider(template) - - self.assertEquals(len(provider.apis), 3) - self.assertEquals(provider.apis, set(["set", "of", "values"])) - self.assertEquals(provider.template_dict, {"Resources": {"a": "b"}}) - self.assertEquals(provider.resources, {"a": "b"}) - - -class TestApiProviderSelection(TestCase): - def test_default_provider(self): - resources = { - "TestApi": { - "Type": "AWS::UNKNOWN_TYPE", - "Properties": { - "StageName": "dev", - "DefinitionBody": { - "paths": { - "/path": { - "get": { - "x-amazon-apigateway-integration": { - "httpMethod": "POST", - "type": "aws_proxy", - "uri": { - "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", - }, - "responses": {}, - }, - } - } - - } - } - } - } - } - - provider = ApiProvider.find_api_provider(resources) - self.assertTrue(isinstance(provider, SamApiProvider)) - - def test_api_provider_sam_api(self): - resources = { - "TestApi": { - "Type": "AWS::Serverless::Api", - "Properties": { - "StageName": "dev", - "DefinitionBody": { - "paths": { - "/path": { - "get": { - "x-amazon-apigateway-integration": { - "httpMethod": "POST", - "type": "aws_proxy", - "uri": { - "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", - }, - "responses": {}, - }, - } - } - - } - } - } - } - } - - provider = ApiProvider.find_api_provider(resources) - self.assertTrue(isinstance(provider, SamApiProvider)) - - def test_api_provider_sam_function(self): - resources = { - "TestApi": { - "Type": "AWS::Serverless::Function", - "Properties": { - "StageName": "dev", - "DefinitionBody": { - "paths": { - "/path": { - "get": { - "x-amazon-apigateway-integration": { - "httpMethod": "POST", - "type": "aws_proxy", - "uri": { - "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", - }, - "responses": {}, - }, - } - } - - } - } - } - } - } - - provider = ApiProvider.find_api_provider(resources) - - self.assertTrue(isinstance(provider, SamApiProvider)) - - def test_api_provider_cloud_formation(self): - resources = { - "TestApi": { - "Type": "AWS::ApiGateway::RestApi", - "Properties": { - "StageName": "dev", - "Body": { - "paths": { - "/path": { - "get": { - "x-amazon-apigateway-integration": { - "httpMethod": "POST", - "type": "aws_proxy", - "uri": { - "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", - }, - "responses": {}, - }, - } - } - - } - } - } - } - } - - provider = ApiProvider.find_api_provider(resources) - self.assertTrue(isinstance(provider, CfnApiProvider)) - - def test_multiple_api_provider_cloud_formation(self): - resources = OrderedDict() - resources["TestApi"] = { - "Type": "AWS::ApiGateway::RestApi", - "Properties": { - "StageName": "dev", - "Body": { - "paths": { - "/path": { - "get": { - "x-amazon-apigateway-integration": { - "httpMethod": "POST", - "type": "aws_proxy", - "uri": { - "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", - }, - "responses": {}, - }, - } - } - - } - } - } - } - resources["OtherApi"] = { - "Type": "AWS::Serverless::Api", - "Properties": { - "StageName": "dev", - "DefinitionBody": { - "paths": { - "/path": { - "get": { - "x-amazon-apigateway-integration": { - "httpMethod": "POST", - "type": "aws_proxy", - "uri": { - "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", - }, - "responses": {}, - }, - } - } - - } - } - } - } - - provider = ApiProvider.find_api_provider(resources) - self.assertTrue(isinstance(provider, CfnApiProvider)) diff --git a/tests/unit/commands/local/lib/test_cfn_api_provider.py b/tests/unit/commands/local/lib/test_cfn_api_provider.py deleted file mode 100644 index 723951eb11..0000000000 --- a/tests/unit/commands/local/lib/test_cfn_api_provider.py +++ /dev/null @@ -1,215 +0,0 @@ -import json -import tempfile -from unittest import TestCase - -from mock import patch -from six import assertCountEqual - -from samcli.commands.local.lib.api_provider import ApiProvider -from samcli.commands.local.lib.provider import Api -from tests.unit.commands.local.lib.test_sam_api_provider import make_swagger - - -class TestApiProviderWithApiGatewayRestApi(TestCase): - - def setUp(self): - self.binary_types = ["image/png", "image/jpg"] - self.input_apis = [ - Api(path="/path1", method="GET", function_name="SamFunc1", cors=None), - Api(path="/path1", method="POST", function_name="SamFunc1", cors=None), - - Api(path="/path2", method="PUT", function_name="SamFunc1", cors=None), - Api(path="/path2", method="GET", function_name="SamFunc1", cors=None), - - Api(path="/path3", method="DELETE", function_name="SamFunc1", cors=None) - ] - - def test_with_no_apis(self): - template = { - "Resources": { - - "Api1": { - "Type": "AWS::ApiGateway::RestApi", - "Properties": { - }, - - } - } - } - - provider = ApiProvider(template) - - self.assertEquals(provider.apis, []) - - def test_with_inline_swagger_apis(self): - template = { - "Resources": { - - "Api1": { - "Type": "AWS::ApiGateway::RestApi", - "Properties": { - "Body": make_swagger(self.input_apis) - } - } - } - } - - provider = ApiProvider(template) - assertCountEqual(self, self.input_apis, provider.apis) - - def test_with_swagger_as_local_file(self): - with tempfile.NamedTemporaryFile(mode='w') as fp: - filename = fp.name - - swagger = make_swagger(self.input_apis) - json.dump(swagger, fp) - fp.flush() - - template = { - "Resources": { - - "Api1": { - "Type": "AWS::ApiGateway::RestApi", - "Properties": { - "BodyS3Location": filename - } - } - } - } - - provider = ApiProvider(template) - assertCountEqual(self, self.input_apis, provider.apis) - - def test_body_with_swagger_as_local_file_expect_fail(self): - with tempfile.NamedTemporaryFile(mode='w') as fp: - filename = fp.name - - swagger = make_swagger(self.input_apis) - json.dump(swagger, fp) - fp.flush() - - template = { - "Resources": { - - "Api1": { - "Type": "AWS::ApiGateway::RestApi", - "Properties": { - "Body": filename - } - } - } - } - self.assertRaises(Exception, ApiProvider, template) - - @patch("samcli.commands.local.lib.cfn_base_api_provider.SamSwaggerReader") - def test_with_swagger_as_both_body_and_uri_called(self, SamSwaggerReaderMock): - body = {"some": "body"} - filename = "somefile.txt" - - template = { - "Resources": { - - "Api1": { - "Type": "AWS::ApiGateway::RestApi", - "Properties": { - "BodyS3Location": filename, - "Body": body - } - } - } - } - - SamSwaggerReaderMock.return_value.read.return_value = make_swagger(self.input_apis) - - cwd = "foo" - provider = ApiProvider(template, cwd=cwd) - assertCountEqual(self, self.input_apis, provider.apis) - SamSwaggerReaderMock.assert_called_with(definition_body=body, definition_uri=filename, working_dir=cwd) - - def test_swagger_with_any_method(self): - apis = [ - Api(path="/path", method="any", function_name="SamFunc1", cors=None) - ] - - expected_apis = [ - Api(path="/path", method="GET", function_name="SamFunc1", cors=None), - Api(path="/path", method="POST", function_name="SamFunc1", cors=None), - Api(path="/path", method="PUT", function_name="SamFunc1", cors=None), - Api(path="/path", method="DELETE", function_name="SamFunc1", cors=None), - Api(path="/path", method="HEAD", function_name="SamFunc1", cors=None), - Api(path="/path", method="OPTIONS", function_name="SamFunc1", cors=None), - Api(path="/path", method="PATCH", function_name="SamFunc1", cors=None) - ] - - template = { - "Resources": { - "Api1": { - "Type": "AWS::ApiGateway::RestApi", - "Properties": { - "Body": make_swagger(apis) - } - } - } - } - - provider = ApiProvider(template) - assertCountEqual(self, expected_apis, provider.apis) - - def test_with_binary_media_types(self): - template = { - "Resources": { - - "Api1": { - "Type": "AWS::ApiGateway::RestApi", - "Properties": { - "Body": make_swagger(self.input_apis, binary_media_types=self.binary_types) - } - } - } - } - - expected_binary_types = sorted(self.binary_types) - expected_apis = [ - Api(path="/path1", method="GET", function_name="SamFunc1", cors=None, - binary_media_types=expected_binary_types), - Api(path="/path1", method="POST", function_name="SamFunc1", cors=None, - binary_media_types=expected_binary_types), - - Api(path="/path2", method="PUT", function_name="SamFunc1", cors=None, - binary_media_types=expected_binary_types), - Api(path="/path2", method="GET", function_name="SamFunc1", cors=None, - binary_media_types=expected_binary_types), - - Api(path="/path3", method="DELETE", function_name="SamFunc1", cors=None, - binary_media_types=expected_binary_types) - ] - - provider = ApiProvider(template) - assertCountEqual(self, expected_apis, provider.apis) - - def test_with_binary_media_types_in_swagger_and_on_resource(self): - input_apis = [ - Api(path="/path", method="OPTIONS", function_name="SamFunc1"), - ] - extra_binary_types = ["text/html"] - - template = { - "Resources": { - - "Api1": { - "Type": "AWS::ApiGateway::RestApi", - "Properties": { - "BinaryMediaTypes": extra_binary_types, - "Body": make_swagger(input_apis, binary_media_types=self.binary_types) - } - } - } - } - - expected_binary_types = sorted(self.binary_types + extra_binary_types) - expected_apis = [ - Api(path="/path", method="OPTIONS", function_name="SamFunc1", binary_media_types=expected_binary_types), - ] - - provider = ApiProvider(template) - assertCountEqual(self, expected_apis, provider.apis) diff --git a/tests/unit/commands/local/lib/test_local_api_service.py b/tests/unit/commands/local/lib/test_local_api_service.py index cfa35af954..3cc5d2c4c3 100644 --- a/tests/unit/commands/local/lib/test_local_api_service.py +++ b/tests/unit/commands/local/lib/test_local_api_service.py @@ -35,7 +35,7 @@ def setUp(self): self.lambda_invoke_context_mock.stderr = self.stderr_mock @patch("samcli.commands.local.lib.local_api_service.LocalApigwService") - @patch("samcli.commands.local.lib.local_api_service.ApiProvider") + @patch("samcli.commands.local.lib.local_api_service.SamApiProvider") @patch.object(LocalApiService, "_make_static_dir_path") @patch.object(LocalApiService, "_print_routes") @patch.object(LocalApiService, "_make_routing_list") @@ -77,7 +77,7 @@ def test_must_start_service(self, self.apigw_service.run.assert_called_with() @patch("samcli.commands.local.lib.local_api_service.LocalApigwService") - @patch("samcli.commands.local.lib.local_api_service.ApiProvider") + @patch("samcli.commands.local.lib.local_api_service.SamApiProvider") @patch.object(LocalApiService, "_make_static_dir_path") @patch.object(LocalApiService, "_print_routes") @patch.object(LocalApiService, "_make_routing_list") diff --git a/tests/unit/commands/local/lib/test_sam_api_provider.py b/tests/unit/commands/local/lib/test_sam_api_provider.py index fa5f342e49..3ac01956be 100644 --- a/tests/unit/commands/local/lib/test_sam_api_provider.py +++ b/tests/unit/commands/local/lib/test_sam_api_provider.py @@ -7,11 +7,29 @@ from six import assertCountEqual -from samcli.commands.local.lib.api_provider import ApiProvider, SamApiProvider +from samcli.commands.local.lib.sam_api_provider import SamApiProvider from samcli.commands.local.lib.provider import Api from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException +class TestSamApiProvider_init(TestCase): + + @patch.object(SamApiProvider, "_extract_apis") + @patch("samcli.commands.local.lib.sam_api_provider.SamBaseProvider") + def test_provider_with_valid_template(self, SamBaseProviderMock, extract_api_mock): + extract_api_mock.return_value = {"set", "of", "values"} + + template = {"Resources": {"a": "b"}} + SamBaseProviderMock.get_template.return_value = template + + provider = SamApiProvider(template) + + self.assertEquals(len(provider.apis), 3) + self.assertEquals(provider.apis, set(["set", "of", "values"])) + self.assertEquals(provider.template_dict, {"Resources": {"a": "b"}}) + self.assertEquals(provider.resources, {"a": "b"}) + + class TestSamApiProviderWithImplicitApis(TestCase): def test_provider_with_no_resource_properties(self): @@ -24,8 +42,9 @@ def test_provider_with_no_resource_properties(self): } } - provider = ApiProvider(template) + provider = SamApiProvider(template) + self.assertEquals(len(provider.apis), 0) self.assertEquals(provider.apis, []) @parameterized.expand([("GET"), ("get")]) @@ -53,7 +72,7 @@ def test_provider_has_correct_api(self, method): } } - provider = ApiProvider(template) + provider = SamApiProvider(template) self.assertEquals(len(provider.apis), 1) self.assertEquals(list(provider.apis)[0], Api(path="/path", method="GET", function_name="SamFunc1", cors=None, @@ -90,7 +109,7 @@ def test_provider_creates_api_for_all_events(self): } } - provider = ApiProvider(template) + provider = SamApiProvider(template) api_event1 = Api(path="/path", method="GET", function_name="SamFunc1", cors=None, stage_name="Prod") api_event2 = Api(path="/path", method="POST", function_name="SamFunc1", cors=None, stage_name="Prod") @@ -140,7 +159,7 @@ def test_provider_has_correct_template(self): } } - provider = ApiProvider(template) + provider = SamApiProvider(template) api1 = Api(path="/path", method="GET", function_name="SamFunc1", cors=None, stage_name="Prod") api2 = Api(path="/path", method="POST", function_name="SamFunc2", cors=None, stage_name="Prod") @@ -171,7 +190,7 @@ def test_provider_with_no_api_events(self): } } - provider = ApiProvider(template) + provider = SamApiProvider(template) self.assertEquals(provider.apis, []) @@ -190,7 +209,7 @@ def test_provider_with_no_serverless_function(self): } } - provider = ApiProvider(template) + provider = SamApiProvider(template) self.assertEquals(provider.apis, []) @@ -235,7 +254,7 @@ def test_provider_get_all(self): } } - provider = ApiProvider(template) + provider = SamApiProvider(template) result = [f for f in provider.get_all()] @@ -248,7 +267,7 @@ def test_provider_get_all(self): def test_provider_get_all_with_no_apis(self): template = {} - provider = ApiProvider(template) + provider = SamApiProvider(template) result = [f for f in provider.get_all()] @@ -279,7 +298,7 @@ def test_provider_with_any_method(self, method): } } - provider = ApiProvider(template) + provider = SamApiProvider(template) api_get = Api(path="/path", method="GET", function_name="SamFunc1", cors=None, stage_name="Prod") api_post = Api(path="/path", method="POST", function_name="SamFunc1", cors=None, stage_name="Prod") @@ -332,7 +351,7 @@ def test_provider_must_support_binary_media_types(self): } } - provider = ApiProvider(template) + provider = SamApiProvider(template) self.assertEquals(len(provider.apis), 1) self.assertEquals(list(provider.apis)[0], Api(path="/path", method="GET", function_name="SamFunc1", @@ -384,7 +403,7 @@ def test_provider_must_support_binary_media_types_with_any_method(self): Api(path="/path", method="PATCH", function_name="SamFunc1", binary_media_types=binary, stage_name="Prod") ] - provider = ApiProvider(template) + provider = SamApiProvider(template) assertCountEqual(self, provider.apis, expected_apis) @@ -429,8 +448,9 @@ def test_with_no_apis(self): } } - provider = ApiProvider(template) + provider = SamApiProvider(template) + self.assertEquals(len(provider.apis), 0) self.assertEquals(provider.apis, []) def test_with_inline_swagger_apis(self): @@ -447,7 +467,7 @@ def test_with_inline_swagger_apis(self): } } - provider = ApiProvider(template) + provider = SamApiProvider(template) assertCountEqual(self, self.input_apis, provider.apis) def test_with_swagger_as_local_file(self): @@ -471,11 +491,11 @@ def test_with_swagger_as_local_file(self): } } - provider = ApiProvider(template) + provider = SamApiProvider(template) assertCountEqual(self, self.input_apis, provider.apis) - @patch("samcli.commands.local.lib.cfn_base_api_provider.SamSwaggerReader") - def test_with_swagger_as_both_body_and_uri_called(self, SamSwaggerReaderMock): + @patch("samcli.commands.local.lib.sam_api_provider.SamSwaggerReader") + def test_with_swagger_as_both_body_and_uri(self, SamSwaggerReaderMock): body = {"some": "body"} filename = "somefile.txt" @@ -496,7 +516,7 @@ def test_with_swagger_as_both_body_and_uri_called(self, SamSwaggerReaderMock): SamSwaggerReaderMock.return_value.read.return_value = make_swagger(self.input_apis) cwd = "foo" - provider = ApiProvider(template, cwd=cwd) + provider = SamApiProvider(template, cwd=cwd) assertCountEqual(self, self.input_apis, provider.apis) SamSwaggerReaderMock.assert_called_with(definition_body=body, definition_uri=filename, working_dir=cwd) @@ -527,7 +547,7 @@ def test_swagger_with_any_method(self): } } - provider = ApiProvider(template) + provider = SamApiProvider(template) assertCountEqual(self, expected_apis, provider.apis) def test_with_binary_media_types(self): @@ -560,7 +580,7 @@ def test_with_binary_media_types(self): binary_media_types=expected_binary_types, stage_name="Prod") ] - provider = ApiProvider(template) + provider = SamApiProvider(template) assertCountEqual(self, expected_apis, provider.apis) def test_with_binary_media_types_in_swagger_and_on_resource(self): @@ -589,7 +609,7 @@ def test_with_binary_media_types_in_swagger_and_on_resource(self): stage_name="Prod"), ] - provider = ApiProvider(template) + provider = SamApiProvider(template) assertCountEqual(self, expected_apis, provider.apis) @@ -666,7 +686,7 @@ def test_must_union_implicit_and_explicit(self): Api(path="/path3", method="POST", function_name="ImplicitFunc", cors=None, stage_name="Prod") ] - provider = ApiProvider(self.template) + provider = SamApiProvider(self.template) assertCountEqual(self, expected_apis, provider.apis) def test_must_prefer_implicit_api_over_explicit(self): @@ -703,7 +723,7 @@ def test_must_prefer_implicit_api_over_explicit(self): Api(path="/path3", method="GET", function_name="explicitfunction", cors=None, stage_name="Prod"), ] - provider = ApiProvider(self.template) + provider = SamApiProvider(self.template) assertCountEqual(self, expected_apis, provider.apis) def test_must_prefer_implicit_with_any_method(self): @@ -737,7 +757,7 @@ def test_must_prefer_implicit_with_any_method(self): Api(path="/path", method="PATCH", function_name="ImplicitFunc", cors=None, stage_name="Prod") ] - provider = ApiProvider(self.template) + provider = SamApiProvider(self.template) assertCountEqual(self, expected_apis, provider.apis) def test_with_any_method_on_both(self): @@ -782,7 +802,8 @@ def test_with_any_method_on_both(self): Api(path="/path2", method="POST", function_name="explicitfunction", cors=None, stage_name="Prod") ] - provider = ApiProvider(self.template) + provider = SamApiProvider(self.template) + print(provider.apis) assertCountEqual(self, expected_apis, provider.apis) def test_must_add_explicit_api_when_ref_with_rest_api_id(self): @@ -819,7 +840,7 @@ def test_must_add_explicit_api_when_ref_with_rest_api_id(self): Api(path="/newpath2", method="POST", function_name="ImplicitFunc", cors=None, stage_name="Prod") ] - provider = ApiProvider(self.template) + provider = SamApiProvider(self.template) assertCountEqual(self, expected_apis, provider.apis) def test_both_apis_must_get_binary_media_types(self): @@ -874,7 +895,7 @@ def test_both_apis_must_get_binary_media_types(self): stage_name="Prod") ] - provider = ApiProvider(self.template) + provider = SamApiProvider(self.template) assertCountEqual(self, expected_apis, provider.apis) def test_binary_media_types_with_rest_api_id_reference(self): @@ -934,7 +955,7 @@ def test_binary_media_types_with_rest_api_id_reference(self): stage_name="Prod") ] - provider = ApiProvider(self.template) + provider = SamApiProvider(self.template) assertCountEqual(self, expected_apis, provider.apis) @@ -970,7 +991,7 @@ def test_provider_parse_stage_name(self): } } } - provider = ApiProvider(template) + provider = SamApiProvider(template) api1 = Api(path='/path', method='GET', function_name='NoApiEventFunction', cors=None, binary_media_types=[], stage_name='dev', stage_variables=None) @@ -1012,7 +1033,7 @@ def test_provider_stage_variables(self): } } } - provider = ApiProvider(template) + provider = SamApiProvider(template) api1 = Api(path='/path', method='GET', function_name='NoApiEventFunction', cors=None, binary_media_types=[], stage_name='dev', stage_variables={ @@ -1098,7 +1119,8 @@ def test_multi_stage_get_all(self): } } } - provider = ApiProvider(template) + + provider = SamApiProvider(template) result = [f for f in provider.get_all()]