diff --git a/samcli/commands/local/lib/api_collector.py b/samcli/commands/local/lib/api_collector.py index cbd198c6b7..be18cea8c8 100644 --- a/samcli/commands/local/lib/api_collector.py +++ b/samcli/commands/local/lib/api_collector.py @@ -1,207 +1,173 @@ """ Class to store the API configurations in the SAM Template. This class helps store both implicit and explicit -APIs in a standardized format +routes in a standardized format """ import logging -from collections import namedtuple +from collections import defaultdict from six import string_types +from samcli.local.apigw.local_apigw_service import Route +from samcli.commands.local.lib.provider import Api + LOG = logging.getLogger(__name__) class ApiCollector(object): - # Properties of each API. The structure is quite similar to the properties of AWS::Serverless::Api resource. - # This is intentional because it allows us to easily extend this class to support future properties on the API. - # We will store properties of Implicit APIs also in this format which converges the handling of implicit & explicit - # APIs. - Properties = namedtuple("Properties", ["apis", "binary_media_types", "cors", "stage_name", "stage_variables"]) def __init__(self): - # API properties stored per resource. Key is the LogicalId of the AWS::Serverless::Api resource and - # value is the properties - self.by_resource = {} + # Route properties stored per resource. + self._route_per_resource = defaultdict(list) + + # processed values to be set before creating the api + self._routes = [] + self.binary_media_types_set = set() + self.stage_name = None + self.stage_variables = None def __iter__(self): """ - Iterator to iterate through all the APIs stored in the collector. In each iteration, this yields the - LogicalId of the API resource and a list of APIs available in this resource. - + Iterator to iterate through all the routes stored in the collector. In each iteration, this yields the + LogicalId of the route resource and a list of routes available in this resource. Yields ------- str - LogicalID of the AWS::Serverless::Api resource + LogicalID of the AWS::Serverless::Api or AWS::ApiGateway::RestApi resource list samcli.commands.local.lib.provider.Api List of the API available in this resource along with additional configuration like binary media types. """ - for logical_id, _ in self.by_resource.items(): - yield logical_id, self._get_apis_with_config(logical_id) + for logical_id, _ in self._route_per_resource.items(): + yield logical_id, self._get_routes(logical_id) - def add_apis(self, logical_id, apis): + def add_routes(self, logical_id, routes): """ - Stores the given APIs tagged under the given logicalId - + Stores the given routes tagged under the given logicalId Parameters ---------- logical_id : str - LogicalId of the AWS::Serverless::Api resource - - apis : list of samcli.commands.local.lib.provider.Api - List of APIs available in this resource + LogicalId of the AWS::Serverless::Api or AWS::ApiGateway::RestApi resource + routes : list of samcli.commands.local.agiw.local_apigw_service.Route + List of routes available in this resource """ - properties = self._get_properties(logical_id) - properties.apis.extend(apis) + self._get_routes(logical_id).extend(routes) - def add_binary_media_types(self, logical_id, binary_media_types): + def _get_routes(self, logical_id): """ - Stores the binary media type configuration for the API with given logical ID - + Returns the properties of resource with given logical ID. If a resource is not found, then it returns an + empty data. Parameters ---------- logical_id : str - LogicalId of the AWS::Serverless::Api resource - - binary_media_types : list of str - List of binary media types supported by this resource - - """ - properties = self._get_properties(logical_id) - - binary_media_types = binary_media_types or [] - for value in binary_media_types: - normalized_value = self._normalize_binary_media_type(value) - - # If the value is not supported, then just skip it. - if normalized_value: - properties.binary_media_types.add(normalized_value) - else: - LOG.debug("Unsupported data type of binary media type value of resource '%s'", logical_id) - - def add_stage_name(self, logical_id, stage_name): + Logical ID of the resource + Returns + ------- + samcli.commands.local.lib.Routes + Properties object for this resource. """ - Stores the stage name for the API with the given local ID - Parameters - ---------- - logical_id : str - LogicalId of the AWS::Serverless::Api resource + return self._route_per_resource[logical_id] - stage_name : str - The stage_name string + @property + def routes(self): + return self._routes if self._routes else self.all_routes() - """ - properties = self._get_properties(logical_id) - properties = properties._replace(stage_name=stage_name) - self._set_properties(logical_id, properties) + @routes.setter + def routes(self, routes): + self._routes = routes - def add_stage_variables(self, logical_id, stage_variables): + def all_routes(self): """ - Stores the stage variables for the API with the given local ID - - Parameters - ---------- - logical_id : str - LogicalId of the AWS::Serverless::Api resource - - stage_variables : dict - A dictionary containing stage variables. + Gets all the routes within the _route_per_resource + Return + ------- + All the routes within the _route_per_resource """ - properties = self._get_properties(logical_id) - properties = properties._replace(stage_variables=stage_variables) - self._set_properties(logical_id, properties) + routes = [] + for logical_id in self._route_per_resource.keys(): + routes.extend(self._get_routes(logical_id)) + return routes - def _get_apis_with_config(self, logical_id): + def get_api(self): """ - Returns the list of APIs in this resource along with other extra configuration such as binary media types, - cors etc. Additional configuration is merged directly into the API data because these properties, although - defined globally, actually apply to each API. + Creates the api using the parts from the ApiCollector. The routes are also deduped so that there is no + duplicate routes with the same function name, path, but different method. - Parameters - ---------- - logical_id : str - Logical ID of the resource to fetch data for + The normalised_routes are the routes that have been processed. By default, this will get all the routes. + However, it can be changed to override the default value of normalised routes such as in SamApiProvider - Returns + Return ------- - list of samcli.commands.local.lib.provider.Api - List of APIs with additional configurations for the resource with given logicalId. If there are no APIs, - then it returns an empty list + An Api object with all the properties """ + api = Api() + api.routes = self.dedupe_function_routes(self.routes) + api.binary_media_types_set = self.binary_media_types_set + api.stage_name = self.stage_name + api.stage_variables = self.stage_variables + return api - properties = self._get_properties(logical_id) + @staticmethod + def dedupe_function_routes(routes): + """ + Remove duplicate routes that have the same function_name and method - # These configs need to be applied to each API - binary_media = sorted(list(properties.binary_media_types)) # Also sort the list to keep the ordering stable - cors = properties.cors - stage_name = properties.stage_name - stage_variables = properties.stage_variables + route: list(Route) + List of Routes - result = [] - for api in properties.apis: - # Create a copy of the API with updated configuration - updated_api = api._replace(binary_media_types=binary_media, - cors=cors, - stage_name=stage_name, - stage_variables=stage_variables) - result.append(updated_api) + Return + ------- + A list of routes without duplicate routes with the same function_name and method + """ + grouped_routes = {} - return result + for route in routes: + key = "{}-{}".format(route.function_name, route.path) + config = grouped_routes.get(key, None) + methods = route.methods + if config: + methods += config.methods + sorted_methods = sorted(methods) + grouped_routes[key] = Route(function_name=route.function_name, path=route.path, methods=sorted_methods) + return list(grouped_routes.values()) - def _get_properties(self, logical_id): + def add_binary_media_types(self, logical_id, binary_media_types): """ - Returns the properties of resource with given logical ID. If a resource is not found, then it returns an - empty data. - + Stores the binary media type configuration for the API with given logical ID Parameters ---------- - logical_id : str - Logical ID of the resource - Returns - ------- - samcli.commands.local.lib.sam_api_provider.ApiCollector.Properties - Properties object for this resource. - """ - - if logical_id not in self.by_resource: - self.by_resource[logical_id] = self.Properties(apis=[], - # Use a set() to be able to easily de-dupe - binary_media_types=set(), - cors=None, - stage_name=None, - stage_variables=None) + logical_id : str + LogicalId of the AWS::Serverless::Api resource - return self.by_resource[logical_id] + api: samcli.commands.local.lib.provider.Api + Instance of the Api which will save all the api configurations - def _set_properties(self, logical_id, properties): + binary_media_types : list of str + List of binary media types supported by this resource """ - Sets the properties of resource with given logical ID. If a resource is not found, it does nothing - Parameters - ---------- - logical_id : str - Logical ID of the resource - properties : samcli.commands.local.lib.sam_api_provider.ApiCollector.Properties - Properties object for this resource. - """ + binary_media_types = binary_media_types or [] + for value in binary_media_types: + normalized_value = self.normalize_binary_media_type(value) - if logical_id in self.by_resource: - self.by_resource[logical_id] = properties + # If the value is not supported, then just skip it. + if normalized_value: + self.binary_media_types_set.add(normalized_value) + else: + LOG.debug("Unsupported data type of binary media type value of resource '%s'", logical_id) @staticmethod - def _normalize_binary_media_type(value): + def normalize_binary_media_type(value): """ Converts binary media types values to the canonical format. Ex: image~1gif -> image/gif. If the value is not a string, then this method just returns None - Parameters ---------- value : str Value to be normalized - Returns ------- str or None diff --git a/samcli/commands/local/lib/api_provider.py b/samcli/commands/local/lib/api_provider.py index afc686e166..20d31039f7 100644 --- a/samcli/commands/local/lib/api_provider.py +++ b/samcli/commands/local/lib/api_provider.py @@ -1,13 +1,13 @@ -"""Class that provides Apis from a SAM Template""" +"""Class that provides the Api with a list of routes from a Template""" import logging -from samcli.commands.local.lib.cfn_base_api_provider import CfnBaseApiProvider from samcli.commands.local.lib.api_collector import ApiCollector +from samcli.commands.local.lib.cfn_api_provider import CfnApiProvider +from samcli.commands.local.lib.cfn_base_api_provider import CfnBaseApiProvider from samcli.commands.local.lib.provider import AbstractApiProvider -from samcli.commands.local.lib.sam_base_provider import SamBaseProvider from samcli.commands.local.lib.sam_api_provider import SamApiProvider -from samcli.commands.local.lib.cfn_api_provider import CfnApiProvider +from samcli.commands.local.lib.sam_base_provider import SamBaseProvider LOG = logging.getLogger(__name__) @@ -16,7 +16,7 @@ class ApiProvider(AbstractApiProvider): def __init__(self, template_dict, parameter_overrides=None, cwd=None): """ - Initialize the class with SAM template data. The template_dict (SAM Templated) is assumed + Initialize the class with template data. The template_dict is assumed to be valid, normalized and a dictionary. template_dict should be normalized by running any and all pre-processing before passing to this class. This class does not perform any syntactic validation of the template. @@ -27,7 +27,7 @@ def __init__(self, template_dict, parameter_overrides=None, cwd=None): Parameters ---------- template_dict : dict - SAM Template as a dictionary + Template as a dictionary cwd : str Optional working directory with respect to which we will resolve relative path to Swagger file @@ -39,23 +39,22 @@ def __init__(self, template_dict, parameter_overrides=None, cwd=None): # Store a set of apis self.cwd = cwd - self.apis = self._extract_apis(self.resources) - - LOG.debug("%d APIs found in the template", len(self.apis)) + self.api = self._extract_api(self.resources) + self.routes = self.api.routes + LOG.debug("%d APIs found in the template", len(self.routes)) def get_all(self): """ - Yields all the Lambda functions with Api Events available in the SAM Template. + Yields all the Apis in the current Provider - :yields Api: namedtuple containing the Api information + :yields api: an Api object with routes and properties """ - for api in self.apis: - yield api + yield self.api - def _extract_apis(self, resources): + def _extract_api(self, resources): """ - Extracts all the Apis by running through the one providers. The provider that has the first type matched + Extracts all the routes by running through the one providers. The provider that has the first type matched will be run across all the resources Parameters @@ -64,12 +63,12 @@ def _extract_apis(self, resources): The dictionary containing the different resources within the template Returns --------- - list of Apis extracted from the resources + An Api from the parsed template """ collector = ApiCollector() provider = self.find_api_provider(resources) - apis = provider.extract_resource_api(resources, collector, cwd=self.cwd) - return self.normalize_apis(apis) + provider.extract_resources(resources, collector, cwd=self.cwd) + return collector.get_api() @staticmethod def find_api_provider(resources): diff --git a/samcli/commands/local/lib/cfn_api_provider.py b/samcli/commands/local/lib/cfn_api_provider.py index 0e3919611c..dc1c16848f 100644 --- a/samcli/commands/local/lib/cfn_api_provider.py +++ b/samcli/commands/local/lib/cfn_api_provider.py @@ -1,6 +1,7 @@ """Parses SAM given a template""" import logging +from samcli.commands.local.cli_common.user_exceptions import InvalidSamTemplateException from samcli.commands.local.lib.cfn_base_api_provider import CfnBaseApiProvider LOG = logging.getLogger(__name__) @@ -8,20 +9,22 @@ class CfnApiProvider(CfnBaseApiProvider): APIGATEWAY_RESTAPI = "AWS::ApiGateway::RestApi" + APIGATEWAY_STAGE = "AWS::ApiGateway::Stage" TYPES = [ - APIGATEWAY_RESTAPI + APIGATEWAY_RESTAPI, + APIGATEWAY_STAGE ] - def extract_resource_api(self, resources, collector, cwd=None): + def extract_resources(self, resources, collector, cwd=None): """ - Extract the Api Object from a given resource and adds it to the ApiCollector. + Extract the Route Object from a given resource and adds it to the RouteCollector. Parameters ---------- resources: dict The dictionary containing the different resources within the template - collector: ApiCollector + collector: samcli.commands.local.lib.route_collector.RouteCollector Instance of the API collector that where we will save the API information cwd : str @@ -29,18 +32,17 @@ def extract_resource_api(self, resources, collector, cwd=None): Return ------- - Returns a list of Apis + Returns a list of routes """ for logical_id, resource in resources.items(): resource_type = resource.get(CfnBaseApiProvider.RESOURCE_TYPE) if resource_type == CfnApiProvider.APIGATEWAY_RESTAPI: - self._extract_cloud_formation_api(logical_id, resource, collector, cwd) - all_apis = [] - for _, apis in collector: - all_apis.extend(apis) - return all_apis + self._extract_cloud_formation_route(logical_id, resource, collector, cwd=cwd) - def _extract_cloud_formation_api(self, logical_id, api_resource, collector, cwd=None): + if resource_type == CfnApiProvider.APIGATEWAY_STAGE: + self._extract_cloud_formation_stage(resources, resource, collector) + + def _extract_cloud_formation_route(self, logical_id, api_resource, collector, cwd=None): """ Extract APIs from AWS::ApiGateway::RestApi resource by reading and parsing Swagger documents. The result is added to the collector. @@ -66,4 +68,38 @@ def _extract_cloud_formation_api(self, logical_id, api_resource, collector, cwd= LOG.debug("Skipping resource '%s'. Swagger document not found in Body and BodyS3Location", logical_id) return - self.extract_swagger_api(logical_id, body, body_s3_location, binary_media, collector, cwd) + self.extract_swagger_route(logical_id, body, body_s3_location, binary_media, collector, cwd) + + @staticmethod + def _extract_cloud_formation_stage(resources, stage_resource, collector): + """ + Extract the stage from AWS::ApiGateway::Stage resource by reading and adds it to the collector. + Parameters + ---------- + resources: dict + All Resource definition, including its properties + + stage_resource : dict + Stage Resource definition, including its properties + + collector : ApiCollector + Instance of the API collector that where we will save the API information + """ + properties = stage_resource.get("Properties", {}) + stage_name = properties.get("StageName") + stage_variables = properties.get("Variables") + + # Currently, we aren't resolving any Refs or other intrinsic properties that come with it + # A separate pr will need to fully resolve intrinsics + logical_id = properties.get("RestApiId") + if not logical_id: + raise InvalidSamTemplateException("The AWS::ApiGateway::Stage must have a RestApiId property") + + rest_api_resource_type = resources.get(logical_id, {}).get("Type") + if rest_api_resource_type != CfnApiProvider.APIGATEWAY_RESTAPI: + raise InvalidSamTemplateException( + "The AWS::ApiGateway::Stage must have a valid RestApiId that points to RestApi resource {}".format( + logical_id)) + + collector.stage_name = stage_name + collector.stage_variables = stage_variables diff --git a/samcli/commands/local/lib/cfn_base_api_provider.py b/samcli/commands/local/lib/cfn_base_api_provider.py index 79bc6d8f1d..8d0d4c3774 100644 --- a/samcli/commands/local/lib/cfn_base_api_provider.py +++ b/samcli/commands/local/lib/cfn_base_api_provider.py @@ -1,9 +1,8 @@ """Class that parses the CloudFormation Api Template""" - import logging from samcli.commands.local.lib.swagger.parser import SwaggerParser -from samcli.commands.local.lib.swagger.reader import SamSwaggerReader +from samcli.commands.local.lib.swagger.reader import SwaggerReader LOG = logging.getLogger(__name__) @@ -11,16 +10,16 @@ class CfnBaseApiProvider(object): RESOURCE_TYPE = "Type" - def extract_resource_api(self, resources, collector, cwd=None): + def extract_resources(self, resources, collector, cwd=None): """ - Extract the Api Object from a given resource and adds it to the ApiCollector. + Extract the Route Object from a given resource and adds it to the RouteCollector. Parameters ---------- resources: dict The dictionary containing the different resources within the template - collector: ApiCollector + collector: samcli.commands.local.lib.route_collector.RouteCollector Instance of the API collector that where we will save the API information cwd : str @@ -28,12 +27,11 @@ def extract_resource_api(self, resources, collector, cwd=None): Return ------- - Returns a list of Apis + Returns a list of routes """ raise NotImplementedError("not implemented") - @staticmethod - def extract_swagger_api(logical_id, body, uri, binary_media, collector, cwd=None): + def extract_swagger_route(self, logical_id, body, uri, binary_media, collector, cwd=None): """ Parse the Swagger documents and adds it to the ApiCollector. @@ -51,20 +49,21 @@ def extract_swagger_api(logical_id, body, uri, binary_media, collector, cwd=None binary_media: list The link to the binary media - collector: ApiCollector - Instance of the API collector that where we will save the API information + collector: samcli.commands.local.lib.route_collector.RouteCollector + Instance of the Route collector that where we will save the route information cwd : str Optional working directory with respect to which we will resolve relative path to Swagger file """ - reader = SamSwaggerReader(definition_body=body, - definition_uri=uri, - working_dir=cwd) + reader = SwaggerReader(definition_body=body, + definition_uri=uri, + working_dir=cwd) swagger = reader.read() parser = SwaggerParser(swagger) - apis = parser.get_apis() - LOG.debug("Found '%s' APIs in resource '%s'", len(apis), logical_id) + routes = parser.get_routes() + LOG.debug("Found '%s' APIs in resource '%s'", len(routes), logical_id) + + collector.add_routes(logical_id, routes) - collector.add_apis(logical_id, apis) collector.add_binary_media_types(logical_id, parser.get_binary_media_types()) # Binary media from swagger collector.add_binary_media_types(logical_id, binary_media) # Binary media specified on resource in template diff --git a/samcli/commands/local/lib/local_api_service.py b/samcli/commands/local/lib/local_api_service.py index d456e67a83..441d6c3cbc 100644 --- a/samcli/commands/local/lib/local_api_service.py +++ b/samcli/commands/local/lib/local_api_service.py @@ -2,20 +2,20 @@ Connects the CLI with Local API Gateway service. """ -import os import logging +import os -from samcli.local.apigw.local_apigw_service import LocalApigwService, Route -from samcli.commands.local.lib.api_provider import ApiProvider from samcli.commands.local.lib.exceptions import NoApisDefined +from samcli.local.apigw.local_apigw_service import LocalApigwService +from samcli.commands.local.lib.api_provider import ApiProvider LOG = logging.getLogger(__name__) class LocalApiService(object): """ - Implementation of Local API service that is capable of serving APIs defined in a SAM file that invoke a Lambda - function. + Implementation of Local API service that is capable of serving API defined in a configuration file that invoke a + Lambda function. """ def __init__(self, @@ -53,10 +53,8 @@ def start(self): NOTE: This is a blocking call that will not return until the thread is interrupted with SIGINT/SIGTERM """ - routing_list = self._make_routing_list(self.api_provider) - - if not routing_list: - raise NoApisDefined("No APIs available in SAM template") + if not self.api_provider.api.routes: + raise NoApisDefined("No APIs available in template") static_dir_path = self._make_static_dir_path(self.cwd, self.static_dir) @@ -64,7 +62,7 @@ def start(self): # contains the response to the API which is sent out as HTTP response. Only stderr needs to be printed # to the console or a log file. stderr from Docker container contains runtime logs and output of print # statements from the Lambda function - service = LocalApigwService(routing_list=routing_list, + service = LocalApigwService(api=self.api_provider.api, lambda_runner=self.lambda_runner, static_dir=static_dir_path, port=self.port, @@ -74,7 +72,7 @@ def start(self): service.create() # Print out the list of routes that will be mounted - self._print_routes(self.api_provider, self.host, self.port) + self._print_routes(self.api_provider.api.routes, self.host, self.port) LOG.info("You can now browse to the above endpoints to invoke your functions. " "You do not need to restart/reload SAM CLI while working on your functions, " "changes will be reflected instantly/automatically. You only need to restart " @@ -83,30 +81,7 @@ def start(self): service.run() @staticmethod - def _make_routing_list(api_provider): - """ - Returns a list of routes to configure the Local API Service based on the APIs configured in the template. - - Parameters - ---------- - api_provider : samcli.commands.local.lib.api_provider.ApiProvider - - Returns - ------- - list(samcli.local.apigw.service.Route) - List of Routes to pass to the service - """ - - routes = [] - for api in api_provider.get_all(): - route = Route(methods=[api.method], function_name=api.function_name, path=api.path, - binary_types=api.binary_media_types, stage_name=api.stage_name, - stage_variables=api.stage_variables) - routes.append(route) - return routes - - @staticmethod - def _print_routes(api_provider, host, port): + def _print_routes(routes, host, port): """ Helper method to print the APIs that will be mounted. This method is purely for printing purposes. This method takes in a list of Route Configurations and prints out the Routes grouped by path. @@ -116,8 +91,8 @@ def _print_routes(api_provider, host, port): Mounting Product at http://127.0.0.1:3000/path1/bar [GET, POST, DELETE] Mounting Product at http://127.0.0.1:3000/path2/bar [HEAD] - :param samcli.commands.local.lib.provider.AbstractApiProvider api_provider: - API Provider that can return a list of APIs + :param list(Route) routes: + List of routes grouped by the same function_name and path :param string host: Host name where the service is running :param int port: @@ -125,28 +100,15 @@ def _print_routes(api_provider, host, port): :returns list(string): List of lines that were printed to the console. Helps with testing """ - grouped_api_configs = {} - - for api in api_provider.get_all(): - key = "{}-{}".format(api.function_name, api.path) - - config = grouped_api_configs.get(key, {}) - config.setdefault("methods", []) - - config["function_name"] = api.function_name - config["path"] = api.path - config["methods"].append(api.method) - - grouped_api_configs[key] = config print_lines = [] - for _, config in grouped_api_configs.items(): - methods_str = "[{}]".format(', '.join(config["methods"])) + for route in routes: + methods_str = "[{}]".format(', '.join(route.methods)) output = "Mounting {} at http://{}:{}{} {}".format( - config["function_name"], + route.function_name, host, port, - config["path"], + route.path, methods_str) print_lines.append(output) diff --git a/samcli/commands/local/lib/provider.py b/samcli/commands/local/lib/provider.py index 959166e814..94789ba799 100644 --- a/samcli/commands/local/lib/provider.py +++ b/samcli/commands/local/lib/provider.py @@ -199,40 +199,32 @@ def get_all(self): raise NotImplementedError("not implemented") -_ApiTuple = namedtuple("Api", [ +class Api(object): + def __init__(self, routes=None): + if routes is None: + routes = [] + self.routes = routes - # String. Path that this API serves. Ex: /foo, /bar/baz - "path", + # Optional Dictionary containing CORS configuration on this path+method If this configuration is set, + # then API server will automatically respond to OPTIONS HTTP method on this path and respond with appropriate + # CORS headers based on configuration. - # String. HTTP Method this API responds with - "method", + self.cors = None + # If this configuration is set, then API server will automatically respond to OPTIONS HTTP method on this + # path and - # String. Name of the Function this API connects to - "function_name", + self.binary_media_types_set = set() - # Optional Dictionary containing CORS configuration on this path+method - # If this configuration is set, then API server will automatically respond to OPTIONS HTTP method on this path and - # respond with appropriate CORS headers based on configuration. - "cors", + self.stage_name = None + self.stage_variables = None - # List(Str). List of the binary media types the API - "binary_media_types", - # The Api stage name - "stage_name", - # The variables for that stage - "stage_variables" -]) -_ApiTuple.__new__.__defaults__ = (None, # Cors is optional and defaults to None - [], # binary_media_types is optional and defaults to empty, - None, # Stage name is optional with default None - None # Stage variables is optional with default None - ) - - -class Api(_ApiTuple): def __hash__(self): # Other properties are not a part of the hash - return hash(self.path) * hash(self.method) * hash(self.function_name) + return hash(self.routes) * hash(self.cors) * hash(self.binary_media_types_set) + + @property + def binary_media_types(self): + return list(self.binary_media_types_set) Cors = namedtuple("Cors", ["AllowOrigin", "AllowMethods", "AllowHeaders"]) @@ -242,13 +234,6 @@ class AbstractApiProvider(object): """ Abstract base class to return APIs and the functions they route to """ - _ANY_HTTP_METHODS = ["GET", - "DELETE", - "PUT", - "POST", - "HEAD", - "OPTIONS", - "PATCH"] def get_all(self): """ @@ -257,43 +242,3 @@ def get_all(self): :yields Api: namedtuple containing the API information """ raise NotImplementedError("not implemented") - - @staticmethod - def normalize_http_methods(http_method): - """ - Normalizes Http Methods. Api Gateway allows a Http Methods of ANY. This is a special verb to denote all - supported Http Methods on Api Gateway. - - :param str http_method: Http method - :yield str: Either the input http_method or one of the _ANY_HTTP_METHODS (normalized Http Methods) - """ - - if http_method.upper() == 'ANY': - for method in AbstractApiProvider._ANY_HTTP_METHODS: - yield method.upper() - else: - yield http_method.upper() - - @staticmethod - def normalize_apis(apis): - """ - Normalize the APIs to use standard method name - - Parameters - ---------- - apis : list of samcli.commands.local.lib.provider.Api - List of APIs to replace normalize - - Returns - ------- - list of samcli.commands.local.lib.provider.Api - List of normalized APIs - """ - - result = list() - for api in apis: - for normalized_method in AbstractApiProvider.normalize_http_methods(api.method): - # _replace returns a copy of the namedtuple. This is the official way of creating copies of namedtuple - result.append(api._replace(method=normalized_method)) - - return result diff --git a/samcli/commands/local/lib/sam_api_provider.py b/samcli/commands/local/lib/sam_api_provider.py index f0ec57b823..1710edbf2d 100644 --- a/samcli/commands/local/lib/sam_api_provider.py +++ b/samcli/commands/local/lib/sam_api_provider.py @@ -2,9 +2,9 @@ import logging -from samcli.commands.local.lib.provider import Api, AbstractApiProvider -from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException from samcli.commands.local.lib.cfn_base_api_provider import CfnBaseApiProvider +from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException +from samcli.local.apigw.local_apigw_service import Route LOG = logging.getLogger(__name__) @@ -23,24 +23,21 @@ class SamApiProvider(CfnBaseApiProvider): _EVENT_TYPE = "Type" IMPLICIT_API_RESOURCE_ID = "ServerlessRestApi" - def extract_resource_api(self, resources, collector, cwd=None): + def extract_resources(self, resources, collector, cwd=None): """ - Extract the Api Object from a given resource and adds it to the ApiCollector. + Extract the Route Object from a given resource and adds it to the RouteCollector. Parameters ---------- resources: dict The dictionary containing the different resources within the template - collector: ApiCollector + collector: samcli.commands.local.lib.route_collector.ApiCollector Instance of the API collector that where we will save the API information cwd : str Optional working directory with respect to which we will resolve relative path to Swagger file - Return - ------- - Returns a list of Apis """ # AWS::Serverless::Function is currently included when parsing of Apis because when SamBaseProvider is run on # the template we are creating the implicit apis due to plugins that translate it in the SAM repo, @@ -49,10 +46,11 @@ def extract_resource_api(self, resources, collector, cwd=None): for logical_id, resource in resources.items(): resource_type = resource.get(CfnBaseApiProvider.RESOURCE_TYPE) if resource_type == SamApiProvider.SERVERLESS_FUNCTION: - self._extract_apis_from_function(logical_id, resource, collector) + self._extract_routes_from_function(logical_id, resource, collector) if resource_type == SamApiProvider.SERVERLESS_API: - self._extract_from_serverless_api(logical_id, resource, collector, cwd) - return self.merge_apis(collector) + self._extract_from_serverless_api(logical_id, resource, collector, cwd=cwd) + + collector.routes = self.merge_routes(collector) def _extract_from_serverless_api(self, logical_id, api_resource, collector, cwd=None): """ @@ -67,8 +65,12 @@ def _extract_from_serverless_api(self, logical_id, api_resource, collector, cwd= api_resource : dict Resource definition, including its properties - collector : ApiCollector + collector: samcli.commands.local.lib.route_collector.RouteCollector Instance of the API collector that where we will save the API information + + cwd : str + Optional working directory with respect to which we will resolve relative path to Swagger file + """ properties = api_resource.get("Properties", {}) @@ -83,13 +85,13 @@ def _extract_from_serverless_api(self, logical_id, api_resource, collector, cwd= LOG.debug("Skipping resource '%s'. Swagger document not found in DefinitionBody and DefinitionUri", logical_id) return - self.extract_swagger_api(logical_id, body, uri, binary_media, collector, cwd) - collector.add_stage_name(logical_id, stage_name) - collector.add_stage_variables(logical_id, stage_variables) + self.extract_swagger_route(logical_id, body, uri, binary_media, collector, cwd=cwd) + collector.stage_name = stage_name + collector.stage_variables = stage_variables - def _extract_apis_from_function(self, logical_id, function_resource, collector): + def _extract_routes_from_function(self, logical_id, function_resource, collector): """ - Fetches a list of APIs configured for this SAM Function resource. + Fetches a list of routes configured for this SAM Function resource. Parameters ---------- @@ -99,17 +101,17 @@ def _extract_apis_from_function(self, logical_id, function_resource, collector): function_resource : dict Contents of the function resource including its properties - collector : ApiCollector + collector: samcli.commands.local.lib.route_collector.RouteCollector Instance of the API collector that where we will save the API information """ resource_properties = function_resource.get("Properties", {}) serverless_function_events = resource_properties.get(self._FUNCTION_EVENT, {}) - self.extract_apis_from_events(logical_id, serverless_function_events, collector) + self.extract_routes_from_events(logical_id, serverless_function_events, collector) - def extract_apis_from_events(self, function_logical_id, serverless_function_events, collector): + def extract_routes_from_events(self, function_logical_id, serverless_function_events, collector): """ - Given an AWS::Serverless::Function Event Dictionary, extract out all 'Api' events and store within the + Given an AWS::Serverless::Function Event Dictionary, extract out all 'route' events and store within the collector Parameters @@ -120,27 +122,27 @@ def extract_apis_from_events(self, function_logical_id, serverless_function_even serverless_function_events : dict Event Dictionary of a AWS::Serverless::Function - collector : ApiCollector - Instance of the API collector that where we will save the API information + collector: samcli.commands.local.lib.route_collector.RouteCollector + Instance of the Route collector that where we will save the route information """ count = 0 for _, event in serverless_function_events.items(): if self._FUNCTION_EVENT_TYPE_API == event.get(self._EVENT_TYPE): - api_resource_id, api = self._convert_event_api(function_logical_id, event.get("Properties")) - collector.add_apis(api_resource_id, [api]) + route_resource_id, route = self._convert_event_route(function_logical_id, event.get("Properties")) + collector.add_routes(route_resource_id, [route]) count += 1 LOG.debug("Found '%d' API Events in Serverless function with name '%s'", count, function_logical_id) @staticmethod - def _convert_event_api(lambda_logical_id, event_properties): + def _convert_event_route(lambda_logical_id, event_properties): """ - Converts a AWS::Serverless::Function's Event Property to an Api configuration usable by the provider. + Converts a AWS::Serverless::Function's Event Property to an Route configuration usable by the provider. :param str lambda_logical_id: Logical Id of the AWS::Serverless::Function :param dict event_properties: Dictionary of the Event's Property - :return tuple: tuple of API resource name and Api namedTuple + :return tuple: tuple of route resource name and route """ path = event_properties.get(SamApiProvider._EVENT_PATH) method = event_properties.get(SamApiProvider._EVENT_METHOD) @@ -159,55 +161,54 @@ def _convert_event_api(lambda_logical_id, event_properties): "It should either be a LogicalId string or a Ref of a Logical Id string" .format(lambda_logical_id)) - return api_resource_id, Api(path=path, method=method, function_name=lambda_logical_id) + return api_resource_id, Route(path=path, methods=[method], function_name=lambda_logical_id) @staticmethod - def merge_apis(collector): + def merge_routes(collector): """ - Quite often, an API is defined both in Implicit and Explicit API definitions. In such cases, Implicit API + Quite often, an API is defined both in Implicit and Explicit Route definitions. In such cases, Implicit API definition wins because that conveys clear intent that the API is backed by a function. This method will - merge two such list of Apis with the right order of precedence. If a Path+Method combination is defined + merge two such list of routes with the right order of precedence. If a Path+Method combination is defined in both the places, only one wins. Parameters ---------- - collector : ApiCollector + collector: samcli.commands.local.lib.route_collector.RouteCollector Collector object that holds all the APIs specified in the template Returns ------- - list of samcli.commands.local.lib.provider.Api - List of APIs obtained by combining both the input lists. + list of samcli.local.apigw.local_apigw_service.Route + List of routes obtained by combining both the input lists. """ - implicit_apis = [] - explicit_apis = [] + implicit_routes = [] + explicit_routes = [] # Store implicit and explicit APIs separately in order to merge them later in the correct order # Implicit APIs are defined on a resource with logicalID ServerlessRestApi for logical_id, apis in collector: if logical_id == SamApiProvider.IMPLICIT_API_RESOURCE_ID: - implicit_apis.extend(apis) + implicit_routes.extend(apis) else: - explicit_apis.extend(apis) + explicit_routes.extend(apis) # We will use "path+method" combination as key to this dictionary and store the Api config for this combination. # If an path+method combo already exists, then overwrite it if and only if this is an implicit API - all_apis = {} + all_routes = {} # By adding implicit APIs to the end of the list, they will be iterated last. If a configuration was already # written by explicit API, it will be overriden by implicit API, just by virtue of order of iteration. - all_configs = explicit_apis + implicit_apis + all_configs = explicit_routes + implicit_routes for config in all_configs: # Normalize the methods before de-duping to allow an ANY method in implicit API to override a regular HTTP - # method on explicit API. - for normalized_method in AbstractApiProvider.normalize_http_methods(config.method): + # method on explicit route. + for normalized_method in config.methods: key = config.path + normalized_method - all_apis[key] = config + all_routes[key] = config - result = set(all_apis.values()) # Assign to a set() to de-dupe + result = set(all_routes.values()) # Assign to a set() to de-dupe LOG.debug("Removed duplicates from '%d' Explicit APIs and '%d' Implicit APIs to produce '%d' APIs", - len(explicit_apis), len(implicit_apis), len(result)) - + len(explicit_routes), len(implicit_routes), len(result)) return list(result) diff --git a/samcli/commands/local/lib/swagger/parser.py b/samcli/commands/local/lib/swagger/parser.py index 076161993c..072e71c378 100644 --- a/samcli/commands/local/lib/swagger/parser.py +++ b/samcli/commands/local/lib/swagger/parser.py @@ -2,8 +2,8 @@ import logging -from samcli.commands.local.lib.provider import Api from samcli.commands.local.lib.swagger.integration_uri import LambdaUri, IntegrationType +from samcli.local.apigw.local_apigw_service import Route LOG = logging.getLogger(__name__) @@ -34,7 +34,7 @@ def get_binary_media_types(self): """ return self.swagger.get(self._BINARY_MEDIA_TYPES_EXTENSION_KEY) or [] - def get_apis(self): + def get_routes(self): """ Parses a swagger document and returns a list of APIs configured in the document. @@ -62,15 +62,13 @@ def get_apis(self): Returns ------- - list of samcli.commands.local.lib.provider.Api + list of list of samcli.commands.local.apigw.local_apigw_service.Route List of APIs that are configured in the Swagger document """ result = [] paths_dict = self.swagger.get("paths", {}) - binary_media_types = self.get_binary_media_types() - for full_path, path_config in paths_dict.items(): for method, method_config in path_config.items(): @@ -83,11 +81,8 @@ def get_apis(self): if method.lower() == self._ANY_METHOD_EXTENSION_KEY: # Convert to a more commonly used method notation method = self._ANY_METHOD - - api = Api(path=full_path, method=method, function_name=function_name, cors=None, - binary_media_types=binary_media_types) - result.append(api) - + route = Route(function_name, full_path, methods=[method]) + result.append(route) return result def _get_integration_function_name(self, method_config): diff --git a/samcli/commands/local/lib/swagger/reader.py b/samcli/commands/local/lib/swagger/reader.py index d3235170c6..02c2c1edb7 100644 --- a/samcli/commands/local/lib/swagger/reader.py +++ b/samcli/commands/local/lib/swagger/reader.py @@ -57,7 +57,7 @@ def parse_aws_include_transform(data): return location -class SamSwaggerReader(object): +class SwaggerReader(object): """ Class to read and parse Swagger document from a variety of sources. This class accepts the same data formats as available in Serverless::Api SAM resource diff --git a/samcli/local/apigw/local_apigw_service.py b/samcli/local/apigw/local_apigw_service.py index 7aef82dc06..c18304064a 100644 --- a/samcli/local/apigw/local_apigw_service.py +++ b/samcli/local/apigw/local_apigw_service.py @@ -18,35 +18,63 @@ class Route(object): - - def __init__(self, methods, function_name, path, binary_types=None, stage_name=None, stage_variables=None): + _ANY_HTTP_METHODS = ["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"] + + def __init__(self, function_name, path, methods): """ Creates an ApiGatewayRoute - :param list(str) methods: List of HTTP Methods + :param list(str) methods: http method :param function_name: Name of the Lambda function this API is connected to :param str path: Path off the base url """ - self.methods = methods + self.methods = self.normalize_method(methods) self.function_name = function_name self.path = path - self.binary_types = binary_types or [] - self.stage_name = stage_name - self.stage_variables = stage_variables + + def __eq__(self, other): + return isinstance(other, Route) and \ + sorted(self.methods) == sorted( + other.methods) and self.function_name == other.function_name and self.path == other.path + + def __hash__(self): + route_hash = hash(self.function_name) * hash(self.path) + for method in sorted(self.methods): + route_hash *= hash(method) + return route_hash + + def normalize_method(self, methods): + """ + Normalizes Http Methods. Api Gateway allows a Http Methods of ANY. This is a special verb to denote all + supported Http Methods on Api Gateway. + + :param list methods: Http methods + :return list: Either the input http_method or one of the _ANY_HTTP_METHODS (normalized Http Methods) + """ + methods = [method.upper() for method in methods] + if "ANY" in methods: + return self._ANY_HTTP_METHODS + return methods class LocalApigwService(BaseLocalService): _DEFAULT_PORT = 3000 _DEFAULT_HOST = '127.0.0.1' - def __init__(self, routing_list, lambda_runner, static_dir=None, port=None, host=None, stderr=None): + def __init__(self, api, lambda_runner, static_dir=None, port=None, host=None, stderr=None): """ Creates an ApiGatewayService Parameters ---------- - routing_list list(ApiGatewayCallModel) - A list of the Model that represent the service paths to create. + api: Api + an Api object that contains the list of routes and properties lambda_runner samcli.commands.local.lib.local_lambda.LocalLambdaRunner The Lambda runner class capable of invoking the function static_dir str @@ -61,7 +89,7 @@ def __init__(self, routing_list, lambda_runner, static_dir=None, port=None, host Optional stream writer where the stderr from Docker container should be written to """ super(LocalApigwService, self).__init__(lambda_runner.is_debugging(), port=port, host=host) - self.routing_list = routing_list + self.api = api self.lambda_runner = lambda_runner self.static_dir = static_dir self._dict_of_routes = {} @@ -77,12 +105,11 @@ def create(self): static_folder=self.static_dir # Serve static files from this directory ) - for api_gateway_route in self.routing_list: + for api_gateway_route in self.api.routes: path = PathConverter.convert_path_to_flask(api_gateway_route.path) for route_key in self._generate_route_keys(api_gateway_route.methods, path): self._dict_of_routes[route_key] = api_gateway_route - self._app.add_url_rule(path, endpoint=path, view_func=self._request_handler, @@ -144,8 +171,8 @@ def _request_handler(self, **kwargs): route = self._get_current_route(request) try: - event = self._construct_event(request, self.port, route.binary_types, route.stage_name, - route.stage_variables) + event = self._construct_event(request, self.port, self.api.binary_media_types, self.api.stage_name, + self.api.stage_variables) except UnicodeDecodeError: return ServiceErrorResponses.lambda_failure_response() @@ -165,7 +192,7 @@ def _request_handler(self, **kwargs): try: (status_code, headers, body) = self._parse_lambda_output(lambda_response, - route.binary_types, + self.api.binary_media_types, request) except (KeyError, TypeError, ValueError): LOG.error("Function returned an invalid response (must include one of: body, headers, multiValueHeaders or " diff --git a/tests/functional/commands/local/lib/test_local_api_service.py b/tests/functional/commands/local/lib/test_local_api_service.py index a507304bae..23df3e9025 100644 --- a/tests/functional/commands/local/lib/test_local_api_service.py +++ b/tests/functional/commands/local/lib/test_local_api_service.py @@ -10,6 +10,8 @@ import time import logging +from samcli.commands.local.lib.provider import Api +from samcli.local.apigw.local_apigw_service import Route from samcli.commands.local.lib import provider from samcli.commands.local.lib.local_lambda import LocalLambdaRunner from samcli.local.lambdafn.runtime import LambdaRuntime @@ -42,7 +44,7 @@ def setUp(self): self.static_dir = "mystaticdir" self.static_file_name = "myfile.txt" self.static_file_content = "This is a static file" - self._setup_static_file(os.path.join(self.cwd, self.static_dir), # Create static directory with in cwd + self._setup_static_file(os.path.join(self.cwd, self.static_dir), # Create static directory with in cwd self.static_file_name, self.static_file_content) @@ -56,12 +58,14 @@ def setUp(self): self.mock_function_provider.get.return_value = self.function # Setup two APIs pointing to the same function - apis = [ - provider.Api(path="/get", method="GET", function_name=self.function_name, cors="cors"), - provider.Api(path="/post", method="POST", function_name=self.function_name, cors="cors"), + routes = [ + Route(path="/get", methods=["GET"], function_name=self.function_name), + Route(path="/post", methods=["POST"], function_name=self.function_name), ] + api = Api(routes=routes) + self.api_provider_mock = Mock() - self.api_provider_mock.get_all.return_value = apis + self.api_provider_mock.get_all.return_value = api # Now wire up the Lambda invoker and pass it through the context self.lambda_invoke_context_mock = Mock() @@ -69,7 +73,9 @@ def setUp(self): layer_downloader = LayerDownloader("./", "./") lambda_image = LambdaImage(layer_downloader, False, False) local_runtime = LambdaRuntime(manager, lambda_image) - lambda_runner = LocalLambdaRunner(local_runtime, self.mock_function_provider, self.cwd, env_vars_values=None, + lambda_runner = LocalLambdaRunner(local_runtime, + self.mock_function_provider, + self.cwd, debug_context=None) self.lambda_invoke_context_mock.local_lambda_runner = lambda_runner self.lambda_invoke_context_mock.get_cwd.return_value = self.cwd @@ -77,7 +83,7 @@ def setUp(self): def tearDown(self): shutil.rmtree(self.code_abs_path) - @patch("samcli.commands.local.lib.local_api_service.SamApiProvider") + @patch("samcli.commands.local.lib.sam_api_provider.SamApiProvider") def test_must_start_service_and_serve_endpoints(self, sam_api_provider_mock): sam_api_provider_mock.return_value = self.api_provider_mock @@ -97,7 +103,7 @@ def test_must_start_service_and_serve_endpoints(self, sam_api_provider_mock): response = requests.get(self.url + '/post') self.assertEquals(response.status_code, 403) # "HTTP GET /post" must not exist - @patch("samcli.commands.local.lib.local_api_service.SamApiProvider") + @patch("samcli.commands.local.lib.sam_api_provider.SamApiProvider") def test_must_serve_static_files(self, sam_api_provider_mock): sam_api_provider_mock.return_value = self.api_provider_mock @@ -123,10 +129,8 @@ def _start_service_thread(service): @staticmethod def _setup_static_file(directory, filename, contents): - if not os.path.isdir(directory): os.mkdir(directory) with open(os.path.join(directory, filename), "w") as fp: fp.write(contents) - diff --git a/tests/integration/local/start_api/test_start_api.py b/tests/integration/local/start_api/test_start_api.py index 321741e0bf..84fbf2da79 100644 --- a/tests/integration/local/start_api/test_start_api.py +++ b/tests/integration/local/start_api/test_start_api.py @@ -656,3 +656,31 @@ def test_swagger_stage_variable(self): response_data = response.json() self.assertEquals(response_data.get("stageVariables"), {'VarName': 'varValue'}) + + +class TestStartApiWithCloudFormationStage(StartApiIntegBaseClass): + """ + Test Class centered around the different responses that can happen in Lambda and pass through start-api + """ + template_path = "/testdata/start_api/swagger-rest-api-template.yaml" + + def setUp(self): + self.url = "http://127.0.0.1:{}".format(self.port) + + def test_default_stage_name(self): + response = requests.get(self.url + "/echoeventbody") + + self.assertEquals(response.status_code, 200) + + response_data = response.json() + print(response_data) + self.assertEquals(response_data.get("requestContext", {}).get("stage"), "Dev") + + def test_global_stage_variables(self): + response = requests.get(self.url + "/echoeventbody") + + self.assertEquals(response.status_code, 200) + + response_data = response.json() + + self.assertEquals(response_data.get("stageVariables"), {"Stack": "Dev"}) diff --git a/tests/integration/testdata/start_api/swagger-rest-api-template.yaml b/tests/integration/testdata/start_api/swagger-rest-api-template.yaml index 5edeb8717f..5e7be3a95e 100644 --- a/tests/integration/testdata/start_api/swagger-rest-api-template.yaml +++ b/tests/integration/testdata/start_api/swagger-rest-api-template.yaml @@ -13,6 +13,12 @@ Resources: Handler: main.echo_base64_event_body Runtime: python3.6 Type: AWS::Lambda::Function + EchoEventBodyFunction: + Properties: + Code: "." + Handler: main.echo_event_handler + Runtime: python3.6 + Type: AWS::Lambda::Function MyApi: Properties: Body: @@ -35,6 +41,13 @@ Resources: type: aws_proxy uri: Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${Base64ResponseFunction.Arn}/invocations + "/echoeventbody": + get: + x-amazon-apigateway-integration: + httpMethod: POST + type: aws_proxy + uri: + Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${EchoEventBodyFunction.Arn}/invocations "/echobase64eventbody": post: x-amazon-apigateway-integration: @@ -61,6 +74,13 @@ Resources: - image/gif StageName: prod Type: AWS::ApiGateway::RestApi + Dev: + Type: AWS::ApiGateway::Stage + Properties: + StageName: Dev + RestApiId: MyApi + Variables: + Stack: Dev MyNonServerlessLambdaFunction: Properties: Code: "." diff --git a/tests/unit/commands/local/lib/swagger/test_parser.py b/tests/unit/commands/local/lib/swagger/test_parser.py index 59db1ea969..827f49be1c 100644 --- a/tests/unit/commands/local/lib/swagger/test_parser.py +++ b/tests/unit/commands/local/lib/swagger/test_parser.py @@ -1,14 +1,14 @@ """ Test the swagger parser """ - -from samcli.commands.local.lib.swagger.parser import SwaggerParser -from samcli.commands.local.lib.provider import Api - from unittest import TestCase + from mock import patch, Mock from parameterized import parameterized, param +from samcli.commands.local.lib.swagger.parser import SwaggerParser +from samcli.local.apigw.local_apigw_service import Route + class TestSwaggerParser_get_apis(TestCase): @@ -31,8 +31,8 @@ def test_with_one_path_method(self): parser._get_integration_function_name = Mock() parser._get_integration_function_name.return_value = function_name - expected = [Api(path="/path1", method="get", function_name=function_name, cors=None)] - result = parser.get_apis() + expected = [Route(path="/path1", methods=["get"], function_name=function_name)] + result = parser.get_routes() self.assertEquals(expected, result) parser._get_integration_function_name.assert_called_with({ @@ -77,11 +77,11 @@ def test_with_combination_of_paths_methods(self): parser._get_integration_function_name.return_value = function_name expected = { - Api(path="/path1", method="get", function_name=function_name, cors=None), - Api(path="/path1", method="delete", function_name=function_name, cors=None), - Api(path="/path2", method="post", function_name=function_name, cors=None), + Route(path="/path1", methods=["get"], function_name=function_name), + Route(path="/path1", methods=["delete"], function_name=function_name), + Route(path="/path2", methods=["post"], function_name=function_name), } - result = parser.get_apis() + result = parser.get_routes() self.assertEquals(expected, set(result)) @@ -104,8 +104,9 @@ def test_with_any_method(self): parser._get_integration_function_name = Mock() parser._get_integration_function_name.return_value = function_name - expected = [Api(path="/path1", method="ANY", function_name=function_name, cors=None)] - result = parser.get_apis() + expected = [Route(methods=["ANY"], path="/path1", + function_name=function_name)] + result = parser.get_routes() self.assertEquals(expected, result) @@ -128,7 +129,7 @@ def test_does_not_have_function_name(self): parser._get_integration_function_name.return_value = None # Function Name could not be resolved expected = [] - result = parser.get_apis() + result = parser.get_routes() self.assertEquals(expected, result) @@ -146,9 +147,8 @@ def test_does_not_have_function_name(self): }}) ]) def test_invalid_swagger(self, test_case_name, swagger): - parser = SwaggerParser(swagger) - result = parser.get_apis() + result = parser.get_routes() expected = [] self.assertEquals(expected, result) diff --git a/tests/unit/commands/local/lib/swagger/test_reader.py b/tests/unit/commands/local/lib/swagger/test_reader.py index 8112b2f21c..9ecb4d276d 100644 --- a/tests/unit/commands/local/lib/swagger/test_reader.py +++ b/tests/unit/commands/local/lib/swagger/test_reader.py @@ -8,7 +8,7 @@ from parameterized import parameterized, param from mock import Mock, patch -from samcli.commands.local.lib.swagger.reader import parse_aws_include_transform, SamSwaggerReader +from samcli.commands.local.lib.swagger.reader import parse_aws_include_transform, SwaggerReader class TestParseAwsIncludeTransform(TestCase): @@ -57,7 +57,7 @@ class TestSamSwaggerReader_init(TestCase): def test_definition_body_and_uri_required(self): with self.assertRaises(ValueError): - SamSwaggerReader() + SwaggerReader() class TestSamSwaggerReader_read(TestCase): @@ -67,7 +67,7 @@ def test_must_read_first_from_definition_body(self): uri = "./file.txt" expected = {"some": "value"} - reader = SamSwaggerReader(definition_body=body, definition_uri=uri) + reader = SwaggerReader(definition_body=body, definition_uri=uri) reader._download_swagger = Mock() reader._read_from_definition_body = Mock() reader._read_from_definition_body.return_value = expected @@ -82,7 +82,7 @@ def test_read_from_definition_uri(self): uri = "./file.txt" expected = {"some": "value"} - reader = SamSwaggerReader(definition_uri=uri) + reader = SwaggerReader(definition_uri=uri) reader._download_swagger = Mock() reader._download_swagger.return_value = expected @@ -96,7 +96,7 @@ def test_must_use_definition_uri_if_body_does_not_exist(self): uri = "./file.txt" expected = {"some": "value"} - reader = SamSwaggerReader(definition_body=body, definition_uri=uri) + reader = SwaggerReader(definition_body=body, definition_uri=uri) reader._download_swagger = Mock() reader._download_swagger.return_value = expected @@ -119,7 +119,7 @@ def test_must_work_with_include_transform(self, parse_mock): expected = {'k': 'v'} location = "some location" - reader = SamSwaggerReader(definition_body=body) + reader = SwaggerReader(definition_body=body) reader._download_swagger = Mock() reader._download_swagger.return_value = expected parse_mock.return_value = location @@ -132,7 +132,7 @@ def test_must_work_with_include_transform(self, parse_mock): def test_must_get_body_directly(self, parse_mock): body = {'this': 'swagger'} - reader = SamSwaggerReader(definition_body=body) + reader = SwaggerReader(definition_body=body) parse_mock.return_value = None # No location is returned from aws_include parser actual = reader._read_from_definition_body() @@ -151,7 +151,7 @@ def test_must_download_from_s3_for_s3_locations(self, yaml_parse_mock): swagger_str = "some swagger str" expected = "some data" - reader = SamSwaggerReader(definition_uri=location) + reader = SwaggerReader(definition_uri=location) reader._download_from_s3 = Mock() reader._download_from_s3.return_value = swagger_str yaml_parse_mock.return_value = expected @@ -169,7 +169,7 @@ def test_must_skip_non_s3_dictionaries(self, yaml_parse_mock): location = {"some": "value"} - reader = SamSwaggerReader(definition_uri=location) + reader = SwaggerReader(definition_uri=location) reader._download_from_s3 = Mock() actual = reader._download_swagger(location) @@ -193,7 +193,7 @@ def test_must_read_from_local_file(self, yaml_parse_mock): cwd = os.path.dirname(filepath) filename = os.path.basename(filepath) - reader = SamSwaggerReader(definition_uri=filename, working_dir=cwd) + reader = SwaggerReader(definition_uri=filename, working_dir=cwd) actual = reader._download_swagger(filename) self.assertEquals(actual, expected) @@ -211,7 +211,7 @@ def test_must_read_from_local_file_without_working_directory(self, yaml_parse_mo json.dump(data, fp) fp.flush() - reader = SamSwaggerReader(definition_uri=filepath) + reader = SwaggerReader(definition_uri=filepath) actual = reader._download_swagger(filepath) self.assertEquals(actual, expected) @@ -222,7 +222,7 @@ def test_must_return_none_if_file_not_found(self, yaml_parse_mock): expected = "parsed result" yaml_parse_mock.return_value = expected - reader = SamSwaggerReader(definition_uri="somepath") + reader = SwaggerReader(definition_uri="somepath") actual = reader._download_swagger("abcdefgh.txt") self.assertIsNone(actual) @@ -230,7 +230,7 @@ def test_must_return_none_if_file_not_found(self, yaml_parse_mock): def test_with_invalid_location(self): - reader = SamSwaggerReader(definition_uri="something") + reader = SwaggerReader(definition_uri="something") actual = reader._download_swagger({}) self.assertIsNone(actual) @@ -256,7 +256,7 @@ def test_must_download_file_from_s3(self, tempfilemock, botomock): expected = "data from file" fp_mock.read.return_value = expected - actual = SamSwaggerReader._download_from_s3(self.bucket, self.key, self.version) + actual = SwaggerReader._download_from_s3(self.bucket, self.key, self.version) self.assertEquals(actual, expected) s3_mock.download_fileobj.assert_called_with(self.bucket, self.key, fp_mock, @@ -277,7 +277,7 @@ def test_must_fail_on_download_from_s3(self, tempfilemock, botomock): "download_file") with self.assertRaises(Exception) as cm: - SamSwaggerReader._download_from_s3(self.bucket, self.key) + SwaggerReader._download_from_s3(self.bucket, self.key) self.assertIn(cm.exception.__class__, (botocore.exceptions.NoCredentialsError, botocore.exceptions.ClientError)) @@ -294,7 +294,7 @@ def test_must_work_without_object_version_id(self, tempfilemock, botomock): expected = "data from file" fp_mock.read.return_value = expected - actual = SamSwaggerReader._download_from_s3(self.bucket, self.key) + actual = SwaggerReader._download_from_s3(self.bucket, self.key) self.assertEquals(actual, expected) s3_mock.download_fileobj.assert_called_with(self.bucket, self.key, fp_mock, @@ -313,7 +313,7 @@ def test_must_log_on_download_exception(self, tempfilemock, botomock): "download_file") with self.assertRaises(botocore.exceptions.ClientError): - SamSwaggerReader._download_from_s3(self.bucket, self.key) + SwaggerReader._download_from_s3(self.bucket, self.key) fp_mock.read.assert_not_called() @@ -332,7 +332,7 @@ def test_must_parse_valid_dict(self): "Version": self.version } - result = SamSwaggerReader._parse_s3_location(location) + result = SwaggerReader._parse_s3_location(location) self.assertEquals(result, (self.bucket, self.key, self.version)) def test_must_parse_dict_without_version(self): @@ -341,19 +341,19 @@ def test_must_parse_dict_without_version(self): "Key": self.key } - result = SamSwaggerReader._parse_s3_location(location) + result = SwaggerReader._parse_s3_location(location) self.assertEquals(result, (self.bucket, self.key, None)) def test_must_parse_s3_uri_string(self): location = "s3://{}/{}?versionId={}".format(self.bucket, self.key, self.version) - result = SamSwaggerReader._parse_s3_location(location) + result = SwaggerReader._parse_s3_location(location) self.assertEquals(result, (self.bucket, self.key, self.version)) def test_must_parse_s3_uri_string_without_version_id(self): location = "s3://{}/{}".format(self.bucket, self.key) - result = SamSwaggerReader._parse_s3_location(location) + result = SwaggerReader._parse_s3_location(location) self.assertEquals(result, (self.bucket, self.key, None)) @parameterized.expand([ @@ -364,5 +364,5 @@ def test_must_parse_s3_uri_string_without_version_id(self): ]) def test_must_parse_invalid_location(self, location): - result = SamSwaggerReader._parse_s3_location(location) + result = SwaggerReader._parse_s3_location(location) self.assertEquals(result, (None, None, None)) diff --git a/tests/unit/commands/local/lib/test_api_provider.py b/tests/unit/commands/local/lib/test_api_provider.py index 50b8d073d4..013405429a 100644 --- a/tests/unit/commands/local/lib/test_api_provider.py +++ b/tests/unit/commands/local/lib/test_api_provider.py @@ -3,6 +3,7 @@ from mock import patch +from samcli.commands.local.lib.provider import Api from samcli.commands.local.lib.api_provider import ApiProvider from samcli.commands.local.lib.sam_api_provider import SamApiProvider from samcli.commands.local.lib.cfn_api_provider import CfnApiProvider @@ -10,18 +11,17 @@ class TestApiProvider_init(TestCase): - @patch.object(ApiProvider, "_extract_apis") + @patch.object(ApiProvider, "_extract_api") @patch("samcli.commands.local.lib.api_provider.SamBaseProvider") def test_provider_with_valid_template(self, SamBaseProviderMock, extract_api_mock): - extract_api_mock.return_value = {"set", "of", "values"} - + extract_api_mock.return_value = Api(routes={"set", "of", "values"}) template = {"Resources": {"a": "b"}} SamBaseProviderMock.get_template.return_value = template provider = ApiProvider(template) + self.assertEquals(len(provider.routes), 3) + self.assertEquals(provider.routes, set(["set", "of", "values"])) - self.assertEquals(len(provider.apis), 3) - self.assertEquals(provider.apis, set(["set", "of", "values"])) self.assertEquals(provider.template_dict, {"Resources": {"a": "b"}}) self.assertEquals(provider.resources, {"a": "b"}) diff --git a/tests/unit/commands/local/lib/test_cfn_api_provider.py b/tests/unit/commands/local/lib/test_cfn_api_provider.py index 723951eb11..d4f45171e5 100644 --- a/tests/unit/commands/local/lib/test_cfn_api_provider.py +++ b/tests/unit/commands/local/lib/test_cfn_api_provider.py @@ -1,27 +1,24 @@ import json import tempfile +from collections import OrderedDict from unittest import TestCase from mock import patch from six import assertCountEqual from samcli.commands.local.lib.api_provider import ApiProvider -from samcli.commands.local.lib.provider import Api +from samcli.local.apigw.local_apigw_service import Route from tests.unit.commands.local.lib.test_sam_api_provider import make_swagger -class TestApiProviderWithApiGatewayRestApi(TestCase): +class TestApiProviderWithApiGatewayRestRoute(TestCase): def setUp(self): self.binary_types = ["image/png", "image/jpg"] - self.input_apis = [ - Api(path="/path1", method="GET", function_name="SamFunc1", cors=None), - Api(path="/path1", method="POST", function_name="SamFunc1", cors=None), - - Api(path="/path2", method="PUT", function_name="SamFunc1", cors=None), - Api(path="/path2", method="GET", function_name="SamFunc1", cors=None), - - Api(path="/path3", method="DELETE", function_name="SamFunc1", cors=None) + self.input_routes = [ + Route(path="/path1", methods=["GET", "POST"], function_name="SamFunc1"), + Route(path="/path2", methods=["PUT", "GET"], function_name="SamFunc1"), + Route(path="/path3", methods=["DELETE"], function_name="SamFunc1") ] def test_with_no_apis(self): @@ -39,7 +36,7 @@ def test_with_no_apis(self): provider = ApiProvider(template) - self.assertEquals(provider.apis, []) + self.assertEquals(provider.routes, []) def test_with_inline_swagger_apis(self): template = { @@ -48,20 +45,21 @@ def test_with_inline_swagger_apis(self): "Api1": { "Type": "AWS::ApiGateway::RestApi", "Properties": { - "Body": make_swagger(self.input_apis) + "Body": make_swagger(self.input_routes) } } } } provider = ApiProvider(template) - assertCountEqual(self, self.input_apis, provider.apis) + assertCountEqual(self, self.input_routes, provider.routes) def test_with_swagger_as_local_file(self): with tempfile.NamedTemporaryFile(mode='w') as fp: filename = fp.name - swagger = make_swagger(self.input_apis) + swagger = make_swagger(self.input_routes) + json.dump(swagger, fp) fp.flush() @@ -78,13 +76,14 @@ def test_with_swagger_as_local_file(self): } provider = ApiProvider(template) - assertCountEqual(self, self.input_apis, provider.apis) + assertCountEqual(self, self.input_routes, provider.routes) def test_body_with_swagger_as_local_file_expect_fail(self): with tempfile.NamedTemporaryFile(mode='w') as fp: filename = fp.name - swagger = make_swagger(self.input_apis) + swagger = make_swagger(self.input_routes) + json.dump(swagger, fp) fp.flush() @@ -101,8 +100,8 @@ def test_body_with_swagger_as_local_file_expect_fail(self): } self.assertRaises(Exception, ApiProvider, template) - @patch("samcli.commands.local.lib.cfn_base_api_provider.SamSwaggerReader") - def test_with_swagger_as_both_body_and_uri_called(self, SamSwaggerReaderMock): + @patch("samcli.commands.local.lib.cfn_base_api_provider.SwaggerReader") + def test_with_swagger_as_both_body_and_uri_called(self, SwaggerReaderMock): body = {"some": "body"} filename = "somefile.txt" @@ -119,26 +118,26 @@ def test_with_swagger_as_both_body_and_uri_called(self, SamSwaggerReaderMock): } } - SamSwaggerReaderMock.return_value.read.return_value = make_swagger(self.input_apis) + SwaggerReaderMock.return_value.read.return_value = make_swagger(self.input_routes) cwd = "foo" provider = ApiProvider(template, cwd=cwd) - assertCountEqual(self, self.input_apis, provider.apis) - SamSwaggerReaderMock.assert_called_with(definition_body=body, definition_uri=filename, working_dir=cwd) + assertCountEqual(self, self.input_routes, provider.routes) + SwaggerReaderMock.assert_called_with(definition_body=body, definition_uri=filename, working_dir=cwd) def test_swagger_with_any_method(self): - apis = [ - Api(path="/path", method="any", function_name="SamFunc1", cors=None) + routes = [ + Route(path="/path", methods=["any"], function_name="SamFunc1") ] - expected_apis = [ - Api(path="/path", method="GET", function_name="SamFunc1", cors=None), - Api(path="/path", method="POST", function_name="SamFunc1", cors=None), - Api(path="/path", method="PUT", function_name="SamFunc1", cors=None), - Api(path="/path", method="DELETE", function_name="SamFunc1", cors=None), - Api(path="/path", method="HEAD", function_name="SamFunc1", cors=None), - Api(path="/path", method="OPTIONS", function_name="SamFunc1", cors=None), - Api(path="/path", method="PATCH", function_name="SamFunc1", cors=None) + expected_routes = [ + Route(path="/path", methods=["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"], function_name="SamFunc1") ] template = { @@ -146,14 +145,14 @@ def test_swagger_with_any_method(self): "Api1": { "Type": "AWS::ApiGateway::RestApi", "Properties": { - "Body": make_swagger(apis) + "Body": make_swagger(routes) } } } } provider = ApiProvider(template) - assertCountEqual(self, expected_apis, provider.apis) + assertCountEqual(self, expected_routes, provider.routes) def test_with_binary_media_types(self): template = { @@ -162,7 +161,7 @@ def test_with_binary_media_types(self): "Api1": { "Type": "AWS::ApiGateway::RestApi", "Properties": { - "Body": make_swagger(self.input_apis, binary_media_types=self.binary_types) + "Body": make_swagger(self.input_routes, binary_media_types=self.binary_types) } } } @@ -170,26 +169,19 @@ def test_with_binary_media_types(self): expected_binary_types = sorted(self.binary_types) expected_apis = [ - Api(path="/path1", method="GET", function_name="SamFunc1", cors=None, - binary_media_types=expected_binary_types), - Api(path="/path1", method="POST", function_name="SamFunc1", cors=None, - binary_media_types=expected_binary_types), - - Api(path="/path2", method="PUT", function_name="SamFunc1", cors=None, - binary_media_types=expected_binary_types), - Api(path="/path2", method="GET", function_name="SamFunc1", cors=None, - binary_media_types=expected_binary_types), - - Api(path="/path3", method="DELETE", function_name="SamFunc1", cors=None, - binary_media_types=expected_binary_types) + Route(path="/path1", methods=["GET", "POST"], function_name="SamFunc1"), + Route(path="/path2", methods=["PUT", "GET"], function_name="SamFunc1"), + Route(path="/path3", methods=["DELETE"], function_name="SamFunc1") ] provider = ApiProvider(template) - assertCountEqual(self, expected_apis, provider.apis) + assertCountEqual(self, expected_apis, provider.routes) + assertCountEqual(self, provider.api.binary_media_types, expected_binary_types) def test_with_binary_media_types_in_swagger_and_on_resource(self): - input_apis = [ - Api(path="/path", method="OPTIONS", function_name="SamFunc1"), + input_routes = [ + Route(path="/path", methods=["OPTIONS"], function_name="SamFunc1"), + ] extra_binary_types = ["text/html"] @@ -200,16 +192,201 @@ def test_with_binary_media_types_in_swagger_and_on_resource(self): "Type": "AWS::ApiGateway::RestApi", "Properties": { "BinaryMediaTypes": extra_binary_types, - "Body": make_swagger(input_apis, binary_media_types=self.binary_types) + "Body": make_swagger(input_routes, binary_media_types=self.binary_types) } } } } expected_binary_types = sorted(self.binary_types + extra_binary_types) - expected_apis = [ - Api(path="/path", method="OPTIONS", function_name="SamFunc1", binary_media_types=expected_binary_types), + expected_routes = [ + Route(path="/path", methods=["OPTIONS"], function_name="SamFunc1"), ] provider = ApiProvider(template) - assertCountEqual(self, expected_apis, provider.apis) + assertCountEqual(self, expected_routes, provider.routes) + assertCountEqual(self, provider.api.binary_media_types, expected_binary_types) + + +class TestCloudFormationStageValues(TestCase): + def setUp(self): + self.binary_types = ["image/png", "image/jpg"] + self.input_routes = [ + Route(path="/path1", methods=["GET", "POST"], function_name="SamFunc1"), + Route(path="/path2", methods=["PUT", "GET"], function_name="SamFunc1"), + Route(path="/path3", methods=["DELETE"], function_name="SamFunc1") + ] + + def test_provider_parse_stage_name(self): + template = { + "Resources": { + "Stage": { + "Type": "AWS::ApiGateway::Stage", + "Properties": { + "StageName": "dev", + "RestApiId": "TestApi" + } + }, + "TestApi": { + "Type": "AWS::ApiGateway::RestApi", + "Properties": { + "Body": { + "paths": { + "/path": { + "get": { + "x-amazon-apigateway-integration": { + "httpMethod": "POST", + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } + } + + } + } + } + } + } + } + provider = ApiProvider(template) + route1 = Route(path='/path', methods=['GET'], function_name='NoApiEventFunction') + + self.assertIn(route1, provider.routes) + self.assertEquals(provider.api.stage_name, "dev") + self.assertEquals(provider.api.stage_variables, None) + + def test_provider_stage_variables(self): + template = { + "Resources": { + "Stage": { + "Type": "AWS::ApiGateway::Stage", + "Properties": { + "StageName": "dev", + "Variables": { + "vis": "data", + "random": "test", + "foo": "bar" + }, + "RestApiId": "TestApi" + } + }, + "TestApi": { + "Type": "AWS::ApiGateway::RestApi", + "Properties": { + "Body": { + "paths": { + "/path": { + "get": { + "x-amazon-apigateway-integration": { + "httpMethod": "POST", + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } + } + + } + } + } + } + } + } + provider = ApiProvider(template) + route1 = Route(path='/path', methods=['GET'], function_name='NoApiEventFunction') + self.assertIn(route1, provider.routes) + self.assertEquals(provider.api.stage_name, "dev") + self.assertEquals(provider.api.stage_variables, { + "vis": "data", + "random": "test", + "foo": "bar" + }) + + def test_multi_stage_get_all(self): + resources = OrderedDict({ + "ProductionApi": { + "Type": "AWS::ApiGateway::RestApi", + "Properties": { + "Body": { + "paths": { + "/path": { + "get": { + "x-amazon-apigateway-integration": { + "httpMethod": "POST", + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } + }, + "/anotherpath": { + "post": { + "x-amazon-apigateway-integration": { + "httpMethod": "POST", + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } + } + + } + } + } + } + }) + resources["StageDev"] = { + "Type": "AWS::ApiGateway::Stage", + "Properties": { + "StageName": "dev", + "Variables": { + "vis": "data", + "random": "test", + "foo": "bar" + }, + "RestApiId": "ProductionApi" + } + } + resources["StageProd"] = { + "Type": "AWS::ApiGateway::Stage", + "Properties": { + "StageName": "Production", + "Variables": { + "vis": "prod data", + "random": "test", + "foo": "bar" + }, + "RestApiId": "ProductionApi" + }, + } + template = {"Resources": resources} + provider = ApiProvider(template) + + result = [f for f in provider.get_all()] + routes = result[0].routes + + route1 = Route(path='/path', methods=['GET'], function_name='NoApiEventFunction') + route2 = Route(path='/anotherpath', methods=['POST'], function_name='NoApiEventFunction') + self.assertEquals(len(routes), 2) + self.assertIn(route1, routes) + self.assertIn(route2, routes) + + self.assertEquals(provider.api.stage_name, "Production") + self.assertEquals(provider.api.stage_variables, { + "vis": "prod data", + "random": "test", + "foo": "bar" + }) diff --git a/tests/unit/commands/local/lib/test_local_api_service.py b/tests/unit/commands/local/lib/test_local_api_service.py index cfa35af954..f43f93713e 100644 --- a/tests/unit/commands/local/lib/test_local_api_service.py +++ b/tests/unit/commands/local/lib/test_local_api_service.py @@ -6,10 +6,11 @@ from mock import Mock, patch -from samcli.commands.local.lib import provider +from samcli.commands.local.lib.provider import Api +from samcli.commands.local.lib.api_collector import ApiCollector +from samcli.commands.local.lib.api_provider import ApiProvider from samcli.commands.local.lib.exceptions import NoApisDefined from samcli.commands.local.lib.local_api_service import LocalApiService -from samcli.commands.local.lib.provider import Api from samcli.local.apigw.local_apigw_service import Route @@ -38,9 +39,7 @@ def setUp(self): @patch("samcli.commands.local.lib.local_api_service.ApiProvider") @patch.object(LocalApiService, "_make_static_dir_path") @patch.object(LocalApiService, "_print_routes") - @patch.object(LocalApiService, "_make_routing_list") def test_must_start_service(self, - make_routing_list_mock, log_routes_mock, make_static_dir_mock, SamApiProviderMock, @@ -48,7 +47,6 @@ def test_must_start_service(self, routing_list = [1, 2, 3] # something static_dir_path = "/foo/bar" - make_routing_list_mock.return_value = routing_list make_static_dir_mock.return_value = static_dir_path SamApiProviderMock.return_value = self.api_provider_mock @@ -56,6 +54,7 @@ def test_must_start_service(self, # Now start the service local_service = LocalApiService(self.lambda_invoke_context_mock, self.port, self.host, self.static_dir) + local_service.api_provider.api.routes = routing_list local_service.start() # Make sure the right methods are called @@ -63,10 +62,9 @@ def test_must_start_service(self, cwd=self.cwd, parameter_overrides=self.lambda_invoke_context_mock.parameter_overrides) - make_routing_list_mock.assert_called_with(self.api_provider_mock) - log_routes_mock.assert_called_with(self.api_provider_mock, self.host, self.port) + log_routes_mock.assert_called_with(routing_list, self.host, self.port) make_static_dir_mock.assert_called_with(self.cwd, self.static_dir) - ApiGwServiceMock.assert_called_with(routing_list=routing_list, + ApiGwServiceMock.assert_called_with(api=self.api_provider_mock.api, lambda_runner=self.lambda_runner_mock, static_dir=static_dir_path, port=self.port, @@ -80,72 +78,47 @@ def test_must_start_service(self, @patch("samcli.commands.local.lib.local_api_service.ApiProvider") @patch.object(LocalApiService, "_make_static_dir_path") @patch.object(LocalApiService, "_print_routes") - @patch.object(LocalApiService, "_make_routing_list") + @patch.object(ApiProvider, "_extract_api") def test_must_raise_if_route_not_available(self, - make_routing_list_mock, + extract_api, log_routes_mock, make_static_dir_mock, SamApiProviderMock, ApiGwServiceMock): routing_list = [] # Empty - - make_routing_list_mock.return_value = routing_list - + api = Api() + extract_api.return_value = api + SamApiProviderMock.extract_api.return_value = api SamApiProviderMock.return_value = self.api_provider_mock ApiGwServiceMock.return_value = self.apigw_service # Now start the service local_service = LocalApiService(self.lambda_invoke_context_mock, self.port, self.host, self.static_dir) - + local_service.api_provider.api.routes = routing_list with self.assertRaises(NoApisDefined): local_service.start() -class TestLocalApiService_make_routing_list(TestCase): - - def test_must_return_routing_list_from_apis(self): - api_provider = Mock() - apis = [ - Api(path="/1", method="GET1", function_name="name1", cors="CORS1"), - Api(path="/2", method="GET2", function_name="name2", cors="CORS2"), - Api(path="/3", method="GET3", function_name="name3", cors="CORS3"), - ] - expected = [ - Route(path="/1", methods=["GET1"], function_name="name1"), - Route(path="/2", methods=["GET2"], function_name="name2"), - Route(path="/3", methods=["GET3"], function_name="name3") - ] - - api_provider.get_all.return_value = apis - - result = LocalApiService._make_routing_list(api_provider) - self.assertEquals(len(result), len(expected)) - for index, r in enumerate(result): - self.assertEquals(r.__dict__, expected[index].__dict__) - - class TestLocalApiService_print_routes(TestCase): def test_must_print_routes(self): host = "host" port = 123 - api_provider = Mock() apis = [ - Api(path="/1", method="GET", function_name="name1", cors="CORS1"), - Api(path="/1", method="POST", function_name="name1", cors="CORS1"), - Api(path="/1", method="DELETE", function_name="othername1", cors="CORS1"), - Api(path="/2", method="GET2", function_name="name2", cors="CORS2"), - Api(path="/3", method="GET3", function_name="name3", cors="CORS3"), + Route(path="/1", methods=["GET"], function_name="name1"), + Route(path="/1", methods=["POST"], function_name="name1"), + Route(path="/1", methods=["DELETE"], function_name="othername1"), + Route(path="/2", methods=["GET2"], function_name="name2"), + Route(path="/3", methods=["GET3"], function_name="name3"), ] - api_provider.get_all.return_value = apis - + apis = ApiCollector.dedupe_function_routes(apis) expected = {"Mounting name1 at http://host:123/1 [GET, POST]", "Mounting othername1 at http://host:123/1 [DELETE]", "Mounting name2 at http://host:123/2 [GET2]", "Mounting name3 at http://host:123/3 [GET3]"} - actual = LocalApiService._print_routes(api_provider, host, port) + actual = LocalApiService._print_routes(apis, host, port) self.assertEquals(expected, set(actual)) @@ -181,39 +154,3 @@ def test_must_return_none_if_path_not_exists(self, os_mock): result = LocalApiService._make_static_dir_path(cwd, static_dir) self.assertIsNone(result) - - -class TestRoutingList(TestCase): - - def setUp(self): - self.function_name = "routingTest" - apis = [ - provider.Api(path="/get", method="GET", function_name=self.function_name, cors="cors"), - provider.Api(path="/get", method="GET", function_name=self.function_name, cors="cors", stage_name="Dev"), - provider.Api(path="/post", method="POST", function_name=self.function_name, cors="cors", stage_name="Prod"), - provider.Api(path="/get", method="GET", function_name=self.function_name, cors="cors", - stage_variables={"test": "data"}), - provider.Api(path="/post", method="POST", function_name=self.function_name, cors="cors", stage_name="Prod", - stage_variables={"data": "more data"}), - ] - self.api_provider_mock = Mock() - self.api_provider_mock.get_all.return_value = apis - - def test_make_routing_list(self): - routing_list = LocalApiService._make_routing_list(self.api_provider_mock) - - expected_routes = [ - Route(function_name=self.function_name, methods=['GET'], path='/get', stage_name=None, - stage_variables=None), - Route(function_name=self.function_name, methods=['GET'], path='/get', stage_name='Dev', - stage_variables=None), - Route(function_name=self.function_name, methods=['POST'], path='/post', stage_name='Prod', - stage_variables=None), - Route(function_name=self.function_name, methods=['GET'], path='/get', stage_name=None, - stage_variables={'test': 'data'}), - Route(function_name=self.function_name, methods=['POST'], path='/post', stage_name='Prod', - stage_variables={'data': 'more data'}), - ] - self.assertEquals(len(routing_list), len(expected_routes)) - for index, r in enumerate(routing_list): - self.assertEquals(r.__dict__, expected_routes[index].__dict__) diff --git a/tests/unit/commands/local/lib/test_sam_api_provider.py b/tests/unit/commands/local/lib/test_sam_api_provider.py index fa5f342e49..a210b95eb3 100644 --- a/tests/unit/commands/local/lib/test_sam_api_provider.py +++ b/tests/unit/commands/local/lib/test_sam_api_provider.py @@ -1,15 +1,14 @@ -import tempfile import json - +import tempfile +from collections import OrderedDict from unittest import TestCase + from mock import patch from nose_parameterized import parameterized - from six import assertCountEqual -from samcli.commands.local.lib.api_provider import ApiProvider, SamApiProvider -from samcli.commands.local.lib.provider import Api -from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException +from samcli.commands.local.lib.api_provider import ApiProvider +from samcli.local.apigw.local_apigw_service import Route class TestSamApiProviderWithImplicitApis(TestCase): @@ -26,7 +25,7 @@ def test_provider_with_no_resource_properties(self): provider = ApiProvider(template) - self.assertEquals(provider.apis, []) + self.assertEquals(provider.routes, []) @parameterized.expand([("GET"), ("get")]) def test_provider_has_correct_api(self, method): @@ -55,9 +54,8 @@ def test_provider_has_correct_api(self, method): provider = ApiProvider(template) - self.assertEquals(len(provider.apis), 1) - self.assertEquals(list(provider.apis)[0], Api(path="/path", method="GET", function_name="SamFunc1", cors=None, - stage_name="Prod")) + self.assertEquals(len(provider.routes), 1) + self.assertEquals(list(provider.routes)[0], Route(path="/path", methods=["GET"], function_name="SamFunc1")) def test_provider_creates_api_for_all_events(self): template = { @@ -92,12 +90,10 @@ def test_provider_creates_api_for_all_events(self): provider = ApiProvider(template) - api_event1 = Api(path="/path", method="GET", function_name="SamFunc1", cors=None, stage_name="Prod") - api_event2 = Api(path="/path", method="POST", function_name="SamFunc1", cors=None, stage_name="Prod") + api = Route(path="/path", methods=["GET", "POST"], function_name="SamFunc1") - self.assertIn(api_event1, provider.apis) - self.assertIn(api_event2, provider.apis) - self.assertEquals(len(provider.apis), 2) + self.assertIn(api, provider.routes) + self.assertEquals(len(provider.routes), 1) def test_provider_has_correct_template(self): template = { @@ -142,11 +138,11 @@ def test_provider_has_correct_template(self): provider = ApiProvider(template) - api1 = Api(path="/path", method="GET", function_name="SamFunc1", cors=None, stage_name="Prod") - api2 = Api(path="/path", method="POST", function_name="SamFunc2", cors=None, stage_name="Prod") + api1 = Route(path="/path", methods=["GET"], function_name="SamFunc1") + api2 = Route(path="/path", methods=["POST"], function_name="SamFunc2") - self.assertIn(api1, provider.apis) - self.assertIn(api2, provider.apis) + self.assertIn(api1, provider.routes) + self.assertIn(api2, provider.routes) def test_provider_with_no_api_events(self): template = { @@ -173,7 +169,7 @@ def test_provider_with_no_api_events(self): provider = ApiProvider(template) - self.assertEquals(provider.apis, []) + self.assertEquals(provider.routes, []) def test_provider_with_no_serverless_function(self): template = { @@ -192,7 +188,7 @@ def test_provider_with_no_serverless_function(self): provider = ApiProvider(template) - self.assertEquals(provider.apis, []) + self.assertEquals(provider.routes, []) def test_provider_get_all(self): template = { @@ -238,21 +234,22 @@ def test_provider_get_all(self): provider = ApiProvider(template) result = [f for f in provider.get_all()] + routes = result[0].routes + route1 = Route(path="/path", methods=["GET"], function_name="SamFunc1") + route2 = Route(path="/path", methods=["POST"], function_name="SamFunc2") - api1 = Api(path="/path", method="GET", function_name="SamFunc1", stage_name="Prod") - api2 = Api(path="/path", method="POST", function_name="SamFunc2", stage_name="Prod") - - self.assertIn(api1, result) - self.assertIn(api2, result) + self.assertIn(route1, routes) + self.assertIn(route2, routes) - def test_provider_get_all_with_no_apis(self): + def test_provider_get_all_with_no_routes(self): template = {} provider = ApiProvider(template) result = [f for f in provider.get_all()] + routes = result[0].routes - self.assertEquals(result, []) + self.assertEquals(routes, []) @parameterized.expand([("ANY"), ("any")]) def test_provider_with_any_method(self, method): @@ -281,22 +278,16 @@ def test_provider_with_any_method(self, method): provider = ApiProvider(template) - api_get = Api(path="/path", method="GET", function_name="SamFunc1", cors=None, stage_name="Prod") - api_post = Api(path="/path", method="POST", function_name="SamFunc1", cors=None, stage_name="Prod") - api_put = Api(path="/path", method="PUT", function_name="SamFunc1", cors=None, stage_name="Prod") - api_delete = Api(path="/path", method="DELETE", function_name="SamFunc1", cors=None, stage_name="Prod") - api_patch = Api(path="/path", method="PATCH", function_name="SamFunc1", cors=None, stage_name="Prod") - api_head = Api(path="/path", method="HEAD", function_name="SamFunc1", cors=None, stage_name="Prod") - api_options = Api(path="/path", method="OPTIONS", function_name="SamFunc1", cors=None, stage_name="Prod") - - self.assertEquals(len(provider.apis), 7) - self.assertIn(api_get, provider.apis) - self.assertIn(api_post, provider.apis) - self.assertIn(api_put, provider.apis) - self.assertIn(api_delete, provider.apis) - self.assertIn(api_patch, provider.apis) - self.assertIn(api_head, provider.apis) - self.assertIn(api_options, provider.apis) + api1 = Route(path="/path", methods=["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"], function_name="SamFunc1") + + self.assertEquals(len(provider.routes), 1) + self.assertIn(api1, provider.routes) def test_provider_must_support_binary_media_types(self): template = { @@ -334,10 +325,10 @@ def test_provider_must_support_binary_media_types(self): provider = ApiProvider(template) - self.assertEquals(len(provider.apis), 1) - self.assertEquals(list(provider.apis)[0], Api(path="/path", method="GET", function_name="SamFunc1", - binary_media_types=["image/gif", "image/png"], cors=None, - stage_name="Prod")) + self.assertEquals(len(provider.routes), 1) + self.assertEquals(list(provider.routes)[0], Route(path="/path", methods=["GET"], function_name="SamFunc1")) + assertCountEqual(self, provider.api.binary_media_types, ["image/gif", "image/png"]) + self.assertEquals(provider.api.stage_name, "Prod") def test_provider_must_support_binary_media_types_with_any_method(self): template = { @@ -374,49 +365,34 @@ def test_provider_must_support_binary_media_types_with_any_method(self): binary = ["image/gif", "image/png", "text/html"] - expected_apis = [ - Api(path="/path", method="GET", function_name="SamFunc1", binary_media_types=binary, stage_name="Prod"), - Api(path="/path", method="POST", function_name="SamFunc1", binary_media_types=binary, stage_name="Prod"), - Api(path="/path", method="PUT", function_name="SamFunc1", binary_media_types=binary, stage_name="Prod"), - Api(path="/path", method="DELETE", function_name="SamFunc1", binary_media_types=binary, stage_name="Prod"), - Api(path="/path", method="HEAD", function_name="SamFunc1", binary_media_types=binary, stage_name="Prod"), - Api(path="/path", method="OPTIONS", function_name="SamFunc1", binary_media_types=binary, stage_name="Prod"), - Api(path="/path", method="PATCH", function_name="SamFunc1", binary_media_types=binary, stage_name="Prod") + expected_routes = [ + Route(path="/path", methods=["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"], function_name="SamFunc1") ] provider = ApiProvider(template) - assertCountEqual(self, provider.apis, expected_apis) - - def test_convert_event_api_with_invalid_event_properties(self): - properties = { - "Path": "/foo", - "Method": "get", - "RestApiId": { - # This is not supported. Only Ref is supported - "Fn::Sub": "foo" - } - } - - with self.assertRaises(InvalidSamDocumentException): - SamApiProvider._convert_event_api("logicalId", properties) + assertCountEqual(self, provider.routes, expected_routes) + assertCountEqual(self, provider.api.binary_media_types, binary) class TestSamApiProviderWithExplicitApis(TestCase): def setUp(self): self.binary_types = ["image/png", "image/jpg"] - self.input_apis = [ - Api(path="/path1", method="GET", function_name="SamFunc1", cors=None, stage_name="Prod"), - Api(path="/path1", method="POST", function_name="SamFunc1", cors=None, stage_name="Prod"), - - Api(path="/path2", method="PUT", function_name="SamFunc1", cors=None, stage_name="Prod"), - Api(path="/path2", method="GET", function_name="SamFunc1", cors=None, stage_name="Prod"), - - Api(path="/path3", method="DELETE", function_name="SamFunc1", cors=None, stage_name="Prod") + self.stage_name = "Prod" + self.input_routes = [ + Route(path="/path1", methods=["GET", "POST"], function_name="SamFunc1"), + Route(path="/path2", methods=["PUT", "GET"], function_name="SamFunc1"), + Route(path="/path3", methods=["DELETE"], function_name="SamFunc1") ] - def test_with_no_apis(self): + def test_with_no_routes(self): template = { "Resources": { @@ -431,9 +407,9 @@ def test_with_no_apis(self): provider = ApiProvider(template) - self.assertEquals(provider.apis, []) + self.assertEquals(provider.routes, []) - def test_with_inline_swagger_apis(self): + def test_with_inline_swagger_routes(self): template = { "Resources": { @@ -441,20 +417,20 @@ def test_with_inline_swagger_apis(self): "Type": "AWS::Serverless::Api", "Properties": { "StageName": "Prod", - "DefinitionBody": make_swagger(self.input_apis) + "DefinitionBody": make_swagger(self.input_routes) } } } } provider = ApiProvider(template) - assertCountEqual(self, self.input_apis, provider.apis) + assertCountEqual(self, self.input_routes, provider.routes) def test_with_swagger_as_local_file(self): with tempfile.NamedTemporaryFile(mode='w') as fp: filename = fp.name - swagger = make_swagger(self.input_apis) + swagger = make_swagger(self.input_routes) json.dump(swagger, fp) fp.flush() @@ -472,10 +448,10 @@ def test_with_swagger_as_local_file(self): } provider = ApiProvider(template) - assertCountEqual(self, self.input_apis, provider.apis) + assertCountEqual(self, self.input_routes, provider.routes) - @patch("samcli.commands.local.lib.cfn_base_api_provider.SamSwaggerReader") - def test_with_swagger_as_both_body_and_uri_called(self, SamSwaggerReaderMock): + @patch("samcli.commands.local.lib.cfn_base_api_provider.SwaggerReader") + def test_with_swagger_as_both_body_and_uri_called(self, SwaggerReaderMock): body = {"some": "body"} filename = "somefile.txt" @@ -493,26 +469,27 @@ def test_with_swagger_as_both_body_and_uri_called(self, SamSwaggerReaderMock): } } - SamSwaggerReaderMock.return_value.read.return_value = make_swagger(self.input_apis) + SwaggerReaderMock.return_value.read.return_value = make_swagger(self.input_routes) cwd = "foo" provider = ApiProvider(template, cwd=cwd) - assertCountEqual(self, self.input_apis, provider.apis) - SamSwaggerReaderMock.assert_called_with(definition_body=body, definition_uri=filename, working_dir=cwd) + assertCountEqual(self, self.input_routes, provider.routes) + SwaggerReaderMock.assert_called_with(definition_body=body, definition_uri=filename, working_dir=cwd) def test_swagger_with_any_method(self): - apis = [ - Api(path="/path", method="any", function_name="SamFunc1", cors=None) + routes = [ + Route(path="/path", methods=["any"], function_name="SamFunc1") ] - expected_apis = [ - Api(path="/path", method="GET", function_name="SamFunc1", cors=None, stage_name="Prod"), - Api(path="/path", method="POST", function_name="SamFunc1", cors=None, stage_name="Prod"), - Api(path="/path", method="PUT", function_name="SamFunc1", cors=None, stage_name="Prod"), - Api(path="/path", method="DELETE", function_name="SamFunc1", cors=None, stage_name="Prod"), - Api(path="/path", method="HEAD", function_name="SamFunc1", cors=None, stage_name="Prod"), - Api(path="/path", method="OPTIONS", function_name="SamFunc1", cors=None, stage_name="Prod"), - Api(path="/path", method="PATCH", function_name="SamFunc1", cors=None, stage_name="Prod") + expected_routes = [ + Route(path="/path", methods=["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"], + function_name="SamFunc1") ] template = { @@ -521,14 +498,14 @@ def test_swagger_with_any_method(self): "Type": "AWS::Serverless::Api", "Properties": { "StageName": "Prod", - "DefinitionBody": make_swagger(apis) + "DefinitionBody": make_swagger(routes) } } } } provider = ApiProvider(template) - assertCountEqual(self, expected_apis, provider.apis) + assertCountEqual(self, expected_routes, provider.routes) def test_with_binary_media_types(self): template = { @@ -538,34 +515,26 @@ def test_with_binary_media_types(self): "Type": "AWS::Serverless::Api", "Properties": { "StageName": "Prod", - "DefinitionBody": make_swagger(self.input_apis, binary_media_types=self.binary_types) + "DefinitionBody": make_swagger(self.input_routes, binary_media_types=self.binary_types) } } } } expected_binary_types = sorted(self.binary_types) - expected_apis = [ - Api(path="/path1", method="GET", function_name="SamFunc1", cors=None, - binary_media_types=expected_binary_types, stage_name="Prod"), - Api(path="/path1", method="POST", function_name="SamFunc1", cors=None, - binary_media_types=expected_binary_types, stage_name="Prod"), - - Api(path="/path2", method="PUT", function_name="SamFunc1", cors=None, - binary_media_types=expected_binary_types, stage_name="Prod"), - Api(path="/path2", method="GET", function_name="SamFunc1", cors=None, - binary_media_types=expected_binary_types, stage_name="Prod"), - - Api(path="/path3", method="DELETE", function_name="SamFunc1", cors=None, - binary_media_types=expected_binary_types, stage_name="Prod") + expected_routes = [ + Route(path="/path1", methods=["GET", "POST"], function_name="SamFunc1"), + Route(path="/path2", methods=["GET", "PUT"], function_name="SamFunc1"), + Route(path="/path3", methods=["DELETE"], function_name="SamFunc1") ] provider = ApiProvider(template) - assertCountEqual(self, expected_apis, provider.apis) + assertCountEqual(self, expected_routes, provider.routes) + assertCountEqual(self, provider.api.binary_media_types, expected_binary_types) def test_with_binary_media_types_in_swagger_and_on_resource(self): - input_apis = [ - Api(path="/path", method="OPTIONS", function_name="SamFunc1", stage_name="Prod"), + input_routes = [ + Route(path="/path", methods=["OPTIONS"], function_name="SamFunc1"), ] extra_binary_types = ["text/html"] @@ -577,32 +546,33 @@ def test_with_binary_media_types_in_swagger_and_on_resource(self): "Properties": { "BinaryMediaTypes": extra_binary_types, "StageName": "Prod", - "DefinitionBody": make_swagger(input_apis, binary_media_types=self.binary_types) + "DefinitionBody": make_swagger(input_routes, binary_media_types=self.binary_types) } } } } expected_binary_types = sorted(self.binary_types + extra_binary_types) - expected_apis = [ - Api(path="/path", method="OPTIONS", function_name="SamFunc1", binary_media_types=expected_binary_types, - stage_name="Prod"), + expected_routes = [ + Route(path="/path", methods=["OPTIONS"], function_name="SamFunc1"), ] provider = ApiProvider(template) - assertCountEqual(self, expected_apis, provider.apis) + assertCountEqual(self, expected_routes, provider.routes) + assertCountEqual(self, provider.api.binary_media_types, expected_binary_types) class TestSamApiProviderWithExplicitAndImplicitApis(TestCase): def setUp(self): - self.explicit_apis = [ - Api(path="/path1", method="GET", function_name="explicitfunction", cors=None, stage_name="Prod"), - Api(path="/path2", method="GET", function_name="explicitfunction", cors=None, stage_name="Prod"), - Api(path="/path3", method="GET", function_name="explicitfunction", cors=None, stage_name="Prod") + self.stage_name = "Prod" + self.explicit_routes = [ + Route(path="/path1", methods=["GET"], function_name="explicitfunction"), + Route(path="/path2", methods=["GET"], function_name="explicitfunction"), + Route(path="/path3", methods=["GET"], function_name="explicitfunction") ] - self.swagger = make_swagger(self.explicit_apis) + self.swagger = make_swagger(self.explicit_routes) self.template = { "Resources": { @@ -655,22 +625,22 @@ def test_must_union_implicit_and_explicit(self): self.template["Resources"]["Api1"]["Properties"]["DefinitionBody"] = self.swagger self.template["Resources"]["ImplicitFunc"]["Properties"]["Events"] = events - expected_apis = [ + expected_routes = [ # From Explicit APIs - Api(path="/path1", method="GET", function_name="explicitfunction", cors=None, stage_name="Prod"), - Api(path="/path2", method="GET", function_name="explicitfunction", cors=None, stage_name="Prod"), - Api(path="/path3", method="GET", function_name="explicitfunction", cors=None, stage_name="Prod"), + Route(path="/path1", methods=["GET"], function_name="explicitfunction"), + Route(path="/path2", methods=["GET"], function_name="explicitfunction"), + Route(path="/path3", methods=["GET"], function_name="explicitfunction"), # From Implicit APIs - Api(path="/path1", method="POST", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/path2", method="POST", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/path3", method="POST", function_name="ImplicitFunc", cors=None, stage_name="Prod") + Route(path="/path1", methods=["POST"], function_name="ImplicitFunc"), + Route(path="/path2", methods=["POST"], function_name="ImplicitFunc"), + Route(path="/path3", methods=["POST"], function_name="ImplicitFunc") ] provider = ApiProvider(self.template) - assertCountEqual(self, expected_apis, provider.apis) + assertCountEqual(self, expected_routes, provider.routes) def test_must_prefer_implicit_api_over_explicit(self): - implicit_apis = { + implicit_routes = { "Event1": { "Type": "Api", "Properties": { @@ -690,24 +660,24 @@ def test_must_prefer_implicit_api_over_explicit(self): } self.template["Resources"]["Api1"]["Properties"]["DefinitionBody"] = self.swagger - self.template["Resources"]["ImplicitFunc"]["Properties"]["Events"] = implicit_apis + self.template["Resources"]["ImplicitFunc"]["Properties"]["Events"] = implicit_routes - expected_apis = [ - Api(path="/path1", method="GET", function_name="ImplicitFunc", stage_name="Prod"), + expected_routes = [ + Route(path="/path1", methods=["GET"], function_name="ImplicitFunc"), # Comes from Implicit - Api(path="/path2", method="GET", function_name="explicitfunction", cors=None, stage_name="Prod"), - Api(path="/path2", method="POST", function_name="ImplicitFunc", stage_name="Prod"), + Route(path="/path2", methods=["GET"], function_name="explicitfunction"), + Route(path="/path2", methods=["POST"], function_name="ImplicitFunc"), # Comes from implicit - Api(path="/path3", method="GET", function_name="explicitfunction", cors=None, stage_name="Prod"), + Route(path="/path3", methods=["GET"], function_name="explicitfunction"), ] provider = ApiProvider(self.template) - assertCountEqual(self, expected_apis, provider.apis) + assertCountEqual(self, expected_routes, provider.routes) def test_must_prefer_implicit_with_any_method(self): - implicit_apis = { + implicit_routes = { "Event1": { "Type": "Api", "Properties": { @@ -718,30 +688,31 @@ def test_must_prefer_implicit_with_any_method(self): } } - explicit_apis = [ + explicit_routes = [ # Explicit should be over masked completely by implicit, because of "ANY" - Api(path="/path", method="GET", function_name="explicitfunction", cors=None), - Api(path="/path", method="DELETE", function_name="explicitfunction", cors=None), + Route(path="/path", methods=["GET"], function_name="explicitfunction"), + Route(path="/path", methods=["DELETE"], function_name="explicitfunction"), ] - self.template["Resources"]["Api1"]["Properties"]["DefinitionBody"] = make_swagger(explicit_apis) - self.template["Resources"]["ImplicitFunc"]["Properties"]["Events"] = implicit_apis - - expected_apis = [ - Api(path="/path", method="GET", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/path", method="POST", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/path", method="PUT", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/path", method="DELETE", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/path", method="HEAD", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/path", method="OPTIONS", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/path", method="PATCH", function_name="ImplicitFunc", cors=None, stage_name="Prod") + self.template["Resources"]["Api1"]["Properties"]["DefinitionBody"] = make_swagger(explicit_routes) + self.template["Resources"]["ImplicitFunc"]["Properties"]["Events"] = implicit_routes + + expected_routes = [ + Route(path="/path", methods=["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"], + function_name="ImplicitFunc") ] provider = ApiProvider(self.template) - assertCountEqual(self, expected_apis, provider.apis) + assertCountEqual(self, expected_routes, provider.routes) def test_with_any_method_on_both(self): - implicit_apis = { + implicit_routes = { "Event1": { "Type": "Api", "Properties": { @@ -760,30 +731,32 @@ def test_with_any_method_on_both(self): } } - explicit_apis = [ + explicit_routes = [ # Explicit should be over masked completely by implicit, because of "ANY" - Api(path="/path", method="ANY", function_name="explicitfunction", cors=None), - Api(path="/path2", method="POST", function_name="explicitfunction", cors=None), + Route(path="/path", methods=["ANY"], function_name="explicitfunction"), + Route(path="/path2", methods=["POST"], function_name="explicitfunction"), ] - self.template["Resources"]["Api1"]["Properties"]["DefinitionBody"] = make_swagger(explicit_apis) - self.template["Resources"]["ImplicitFunc"]["Properties"]["Events"] = implicit_apis - - expected_apis = [ - Api(path="/path", method="GET", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/path", method="POST", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/path", method="PUT", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/path", method="DELETE", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/path", method="HEAD", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/path", method="OPTIONS", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/path", method="PATCH", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - - Api(path="/path2", method="GET", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/path2", method="POST", function_name="explicitfunction", cors=None, stage_name="Prod") + self.template["Resources"]["Api1"]["Properties"]["DefinitionBody"] = make_swagger(explicit_routes) + self.template["Resources"]["ImplicitFunc"]["Properties"]["Events"] = implicit_routes + + expected_routes = [ + Route(path="/path", methods=["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"], + function_name="ImplicitFunc"), + + Route(path="/path2", methods=["GET"], + function_name="ImplicitFunc"), + Route(path="/path2", methods=["POST"], function_name="explicitfunction") ] provider = ApiProvider(self.template) - assertCountEqual(self, expected_apis, provider.apis) + assertCountEqual(self, expected_routes, provider.routes) def test_must_add_explicit_api_when_ref_with_rest_api_id(self): events = { @@ -809,20 +782,20 @@ def test_must_add_explicit_api_when_ref_with_rest_api_id(self): self.template["Resources"]["Api1"]["Properties"]["DefinitionBody"] = self.swagger self.template["Resources"]["ImplicitFunc"]["Properties"]["Events"] = events - expected_apis = [ + expected_routes = [ # From Explicit APIs - Api(path="/path1", method="GET", function_name="explicitfunction", cors=None, stage_name="Prod"), - Api(path="/path2", method="GET", function_name="explicitfunction", cors=None, stage_name="Prod"), - Api(path="/path3", method="GET", function_name="explicitfunction", cors=None, stage_name="Prod"), + Route(path="/path1", methods=["GET"], function_name="explicitfunction"), + Route(path="/path2", methods=["GET"], function_name="explicitfunction"), + Route(path="/path3", methods=["GET"], function_name="explicitfunction"), # From Implicit APIs - Api(path="/newpath1", method="POST", function_name="ImplicitFunc", cors=None, stage_name="Prod"), - Api(path="/newpath2", method="POST", function_name="ImplicitFunc", cors=None, stage_name="Prod") + Route(path="/newpath1", methods=["POST"], function_name="ImplicitFunc"), + Route(path="/newpath2", methods=["POST"], function_name="ImplicitFunc") ] provider = ApiProvider(self.template) - assertCountEqual(self, expected_apis, provider.apis) + assertCountEqual(self, expected_routes, provider.routes) - def test_both_apis_must_get_binary_media_types(self): + def test_both_routes_must_get_binary_media_types(self): events = { "Event1": { "Type": "Api", @@ -855,27 +828,20 @@ def test_both_apis_must_get_binary_media_types(self): # Because of Globals, binary types will be concatenated on the explicit API expected_explicit_binary_types = ["explicit/type1", "explicit/type2", "image/gif", "image/png"] - expected_implicit_binary_types = ["image/gif", "image/png"] - expected_apis = [ + expected_routes = [ # From Explicit APIs - Api(path="/path1", method="GET", function_name="explicitfunction", - binary_media_types=expected_explicit_binary_types, stage_name="Prod"), - Api(path="/path2", method="GET", function_name="explicitfunction", - binary_media_types=expected_explicit_binary_types, stage_name="Prod"), - Api(path="/path3", method="GET", function_name="explicitfunction", - binary_media_types=expected_explicit_binary_types, stage_name="Prod"), + Route(path="/path1", methods=["GET"], function_name="explicitfunction"), + Route(path="/path2", methods=["GET"], function_name="explicitfunction"), + Route(path="/path3", methods=["GET"], function_name="explicitfunction"), # From Implicit APIs - Api(path="/newpath1", method="POST", function_name="ImplicitFunc", - binary_media_types=expected_implicit_binary_types, - stage_name="Prod"), - Api(path="/newpath2", method="POST", function_name="ImplicitFunc", - binary_media_types=expected_implicit_binary_types, - stage_name="Prod") + Route(path="/newpath1", methods=["POST"], function_name="ImplicitFunc"), + Route(path="/newpath2", methods=["POST"], function_name="ImplicitFunc") ] provider = ApiProvider(self.template) - assertCountEqual(self, expected_apis, provider.apis) + assertCountEqual(self, expected_routes, provider.routes) + assertCountEqual(self, provider.api.binary_media_types, expected_explicit_binary_types) def test_binary_media_types_with_rest_api_id_reference(self): events = { @@ -911,31 +877,25 @@ def test_binary_media_types_with_rest_api_id_reference(self): # Because of Globals, binary types will be concatenated on the explicit API expected_explicit_binary_types = ["explicit/type1", "explicit/type2", "image/gif", "image/png"] - expected_implicit_binary_types = ["image/gif", "image/png"] + # expected_implicit_binary_types = ["image/gif", "image/png"] - expected_apis = [ + expected_routes = [ # From Explicit APIs - Api(path="/path1", method="GET", function_name="explicitfunction", - binary_media_types=expected_explicit_binary_types, stage_name="Prod"), - Api(path="/path2", method="GET", function_name="explicitfunction", - binary_media_types=expected_explicit_binary_types, stage_name="Prod"), - Api(path="/path3", method="GET", function_name="explicitfunction", - binary_media_types=expected_explicit_binary_types, stage_name="Prod"), + Route(path="/path1", methods=["GET"], function_name="explicitfunction"), + Route(path="/path2", methods=["GET"], function_name="explicitfunction"), + Route(path="/path3", methods=["GET"], function_name="explicitfunction"), # Because of the RestApiId, Implicit APIs will also get the binary media types inherited from # the corresponding Explicit API - Api(path="/connected-to-explicit-path", method="POST", function_name="ImplicitFunc", - binary_media_types=expected_explicit_binary_types, - stage_name="Prod"), + Route(path="/connected-to-explicit-path", methods=["POST"], function_name="ImplicitFunc"), # This is still just a true implicit API because it does not have RestApiId property - Api(path="/true-implicit-path", method="POST", function_name="ImplicitFunc", - binary_media_types=expected_implicit_binary_types, - stage_name="Prod") + Route(path="/true-implicit-path", methods=["POST"], function_name="ImplicitFunc") ] provider = ApiProvider(self.template) - assertCountEqual(self, expected_apis, provider.apis) + assertCountEqual(self, expected_routes, provider.routes) + assertCountEqual(self, provider.api.binary_media_types, expected_explicit_binary_types) class TestSamStageValues(TestCase): @@ -971,11 +931,11 @@ def test_provider_parse_stage_name(self): } } provider = ApiProvider(template) - api1 = Api(path='/path', method='GET', function_name='NoApiEventFunction', cors=None, binary_media_types=[], - stage_name='dev', - stage_variables=None) + route1 = Route(path='/path', methods=['GET'], function_name='NoApiEventFunction') - self.assertIn(api1, provider.apis) + self.assertIn(route1, provider.routes) + self.assertEquals(provider.api.stage_name, "dev") + self.assertEquals(provider.api.stage_variables, None) def test_provider_stage_variables(self): template = { @@ -1013,125 +973,120 @@ def test_provider_stage_variables(self): } } provider = ApiProvider(template) - api1 = Api(path='/path', method='GET', function_name='NoApiEventFunction', cors=None, binary_media_types=[], - stage_name='dev', - stage_variables={ - "vis": "data", - "random": "test", - "foo": "bar" - }) + route1 = Route(path='/path', methods=['GET'], function_name='NoApiEventFunction') - self.assertIn(api1, provider.apis) + self.assertIn(route1, provider.routes) + self.assertEquals(provider.api.stage_name, "dev") + self.assertEquals(provider.api.stage_variables, { + "vis": "data", + "random": "test", + "foo": "bar" + }) def test_multi_stage_get_all(self): - template = { - "Resources": { - "TestApi": { - "Type": "AWS::Serverless::Api", - "Properties": { - "StageName": "dev", - "Variables": { - "vis": "data", - "random": "test", - "foo": "bar" - }, - "DefinitionBody": { - "paths": { - "/path2": { - "get": { - "x-amazon-apigateway-integration": { - "httpMethod": "POST", - "type": "aws_proxy", - "uri": { - "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", - }, - "responses": {}, - }, - } - } + template = OrderedDict({ + "Resources": {} + }) + template["Resources"]["TestApi"] = { + "Type": "AWS::Serverless::Api", + "Properties": { + "StageName": "dev", + "Variables": { + "vis": "data", + "random": "test", + "foo": "bar" + }, + "DefinitionBody": { + "paths": { + "/path2": { + "get": { + "x-amazon-apigateway-integration": { + "httpMethod": "POST", + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, } } } + } + } + } + + template["Resources"]["ProductionApi"] = { + "Type": "AWS::Serverless::Api", + "Properties": { + "StageName": "Production", + "Variables": { + "vis": "prod data", + "random": "test", + "foo": "bar" }, - "ProductionApi": { - "Type": "AWS::Serverless::Api", - "Properties": { - "StageName": "Production", - "Variables": { - "vis": "prod data", - "random": "test", - "foo": "bar" + "DefinitionBody": { + "paths": { + "/path": { + "get": { + "x-amazon-apigateway-integration": { + "httpMethod": "POST", + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } }, - "DefinitionBody": { - "paths": { - "/path": { - "get": { - "x-amazon-apigateway-integration": { - "httpMethod": "POST", - "type": "aws_proxy", - "uri": { - "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", - }, - "responses": {}, - }, - } + "/anotherpath": { + "post": { + "x-amazon-apigateway-integration": { + "httpMethod": "POST", + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, }, - "/anotherpath": { - "post": { - "x-amazon-apigateway-integration": { - "httpMethod": "POST", - "type": "aws_proxy", - "uri": { - "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", - }, - "responses": {}, - }, - } - } - } } + } } } } + provider = ApiProvider(template) result = [f for f in provider.get_all()] + routes = result[0].routes + + route1 = Route(path='/path2', methods=['GET'], function_name='NoApiEventFunction') + route2 = Route(path='/path', methods=['GET'], function_name='NoApiEventFunction') + route3 = Route(path='/anotherpath', methods=['POST'], function_name='NoApiEventFunction') + self.assertEquals(len(routes), 3) + self.assertIn(route1, routes) + self.assertIn(route2, routes) + self.assertIn(route3, routes) + + self.assertEquals(provider.api.stage_name, "Production") + self.assertEquals(provider.api.stage_variables, { + "vis": "prod data", + "random": "test", + "foo": "bar" + }) - api1 = Api(path='/path2', method='GET', function_name='NoApiEventFunction', cors=None, binary_media_types=[], - stage_name='dev', - stage_variables={ - "vis": "data", - "random": "test", - "foo": "bar" - }) - api2 = Api(path='/path', method='GET', function_name='NoApiEventFunction', cors=None, binary_media_types=[], - stage_name='Production', stage_variables={'vis': 'prod data', 'random': 'test', 'foo': 'bar'}) - api3 = Api(path='/anotherpath', method='POST', function_name='NoApiEventFunction', cors=None, - binary_media_types=[], - stage_name='Production', - stage_variables={ - "vis": "prod data", - "random": "test", - "foo": "bar" - }) - self.assertEquals(len(result), 3) - self.assertIn(api1, result) - self.assertIn(api2, result) - self.assertIn(api3, result) - - -def make_swagger(apis, binary_media_types=None): + +def make_swagger(routes, binary_media_types=None): """ Given a list of API configurations named tuples, returns a Swagger document Parameters ---------- - apis : list of samcli.commands.local.lib.provider.Api + routes : list of samcli.commands.local.agiw.local_agiw_service.Route binary_media_types : list of str Returns @@ -1145,7 +1100,7 @@ def make_swagger(apis, binary_media_types=None): } } - for api in apis: + for api in routes: swagger["paths"].setdefault(api.path, {}) integration = { @@ -1156,12 +1111,11 @@ def make_swagger(apis, binary_media_types=None): api.function_name) # NOQA } } + for method in api.methods: + if method.lower() == "any": + method = "x-amazon-apigateway-any-method" - method = api.method - if method.lower() == "any": - method = "x-amazon-apigateway-any-method" - - swagger["paths"][api.path][method] = integration + swagger["paths"][api.path][method] = integration if binary_media_types: swagger["x-amazon-apigateway-binary-media-types"] = binary_media_types diff --git a/tests/unit/local/apigw/test_local_apigw_service.py b/tests/unit/local/apigw/test_local_apigw_service.py index ba2d6316b5..9bbf52cc62 100644 --- a/tests/unit/local/apigw/test_local_apigw_service.py +++ b/tests/unit/local/apigw/test_local_apigw_service.py @@ -1,3 +1,4 @@ +import copy from unittest import TestCase from mock import Mock, patch, ANY import json @@ -6,6 +7,7 @@ from parameterized import parameterized, param from werkzeug.datastructures import Headers +from samcli.commands.local.lib.provider import Api from samcli.local.apigw.local_apigw_service import LocalApigwService, Route from samcli.local.lambdafn.exceptions import FunctionNotFound @@ -14,14 +16,15 @@ class TestApiGatewayService(TestCase): def setUp(self): self.function_name = Mock() - self.api_gateway_route = Route(['GET'], self.function_name, '/') + self.api_gateway_route = Route(methods=['GET'], function_name=self.function_name, path='/') self.list_of_routes = [self.api_gateway_route] self.lambda_runner = Mock() self.lambda_runner.is_debugging.return_value = False self.stderr = Mock() - self.service = LocalApigwService(self.list_of_routes, + self.api = Api(routes=self.list_of_routes) + self.service = LocalApigwService(self.api, self.lambda_runner, port=3000, host='127.0.0.1', @@ -102,14 +105,15 @@ def test_request_handler_returns_make_response(self): def test_create_creates_dict_of_routes(self): function_name_1 = Mock() function_name_2 = Mock() - api_gateway_route_1 = Route(['GET'], function_name_1, '/') - api_gateway_route_2 = Route(['POST'], function_name_2, '/') + api_gateway_route_1 = Route(methods=["GET"], function_name=function_name_1, path='/') + api_gateway_route_2 = Route(methods=["POST"], function_name=function_name_2, path='/') list_of_routes = [api_gateway_route_1, api_gateway_route_2] lambda_runner = Mock() - service = LocalApigwService(list_of_routes, lambda_runner) + api = Api(routes=list_of_routes) + service = LocalApigwService(api, lambda_runner) service.create() @@ -135,16 +139,16 @@ def test_create_creates_flask_app_with_url_rules(self, flask): def test_initalize_creates_default_values(self): self.assertEquals(self.service.port, 3000) self.assertEquals(self.service.host, '127.0.0.1') - self.assertEquals(self.service.routing_list, self.list_of_routes) + self.assertEquals(self.service.api.routes, self.list_of_routes) self.assertIsNone(self.service.static_dir) self.assertEquals(self.service.lambda_runner, self.lambda_runner) def test_initalize_with_values(self): lambda_runner = Mock() - local_service = LocalApigwService([], lambda_runner, static_dir='dir/static', port=5000, host='129.0.0.0') + local_service = LocalApigwService(Api(), lambda_runner, static_dir='dir/static', port=5000, host='129.0.0.0') self.assertEquals(local_service.port, 5000) self.assertEquals(local_service.host, '129.0.0.0') - self.assertEquals(local_service.routing_list, []) + self.assertEquals(local_service.api.routes, []) self.assertEquals(local_service.static_dir, 'dir/static') self.assertEquals(local_service.lambda_runner, lambda_runner) @@ -250,19 +254,12 @@ class TestApiGatewayModel(TestCase): def setUp(self): self.function_name = "name" - self.stage_name = "Dev" - self.stage_variables = { - "test": "sample" - } - self.api_gateway = Route(['POST'], self.function_name, '/', stage_name=self.stage_name, - stage_variables=self.stage_variables) + self.api_gateway = Route(function_name=self.function_name, methods=["Post"], path="/") def test_class_initialization(self): self.assertEquals(self.api_gateway.methods, ['POST']) self.assertEquals(self.api_gateway.function_name, self.function_name) self.assertEquals(self.api_gateway.path, '/') - self.assertEqual(self.api_gateway.stage_name, "Dev") - self.assertEqual(self.api_gateway.stage_variables, {"test": "sample"}) class TestLambdaHeaderDictionaryMerge(TestCase): @@ -488,7 +485,7 @@ def setUp(self): '"Custom User Agent String", "caller": null, "cognitoAuthenticationType": null, "sourceIp": ' \ '"190.0.0.0", "user": null}, "accountId": "123456789012"}, "headers": {"Content-Type": ' \ '"application/json", "X-Test": "Value", "X-Forwarded-Port": "3000", "X-Forwarded-Proto": "http"}, ' \ - '"multiValueHeaders": {"Content-Type": ["application/json"], "X-Test": ["Value"], '\ + '"multiValueHeaders": {"Content-Type": ["application/json"], "X-Test": ["Value"], ' \ '"X-Forwarded-Port": ["3000"], "X-Forwarded-Proto": ["http"]}, ' \ '"stageVariables": null, "path": "path", "pathParameters": {"path": "params"}, ' \ '"isBase64Encoded": false}' @@ -590,3 +587,60 @@ def test_should_base64_encode_returns_true(self, test_case_name, binary_types, m ]) def test_should_base64_encode_returns_false(self, test_case_name, binary_types, mimetype): self.assertFalse(LocalApigwService._should_base64_encode(binary_types, mimetype)) + + +class TestRouteEqualsHash(TestCase): + + def test_route_in_list(self): + route = Route(function_name="test", path="/test", methods=["POST"]) + routes = [route] + self.assertIn(route, routes) + + def test_route_method_order_equals(self): + route1 = Route(function_name="test", path="/test", methods=["POST", "GET"]) + route2 = Route(function_name="test", path="/test", methods=["GET", "POST"]) + self.assertEquals(route1, route2) + + def test_route_hash(self): + route1 = Route(function_name="test", path="/test", methods=["POST", "GET"]) + dic = {route1: "test"} + self.assertEquals(dic[route1], "test") + + def test_route_object_equals(self): + route1 = Route(function_name="test", path="/test", methods=["POST", "GET"]) + route2 = type('obj', (object,), {'function_name': 'test', "path": "/test", "methods": ["GET", "POST"]}) + + self.assertNotEqual(route1, route2) + + def test_route_function_name_equals(self): + route1 = Route(function_name="test1", path="/test", methods=["GET", "POST"]) + route2 = Route(function_name="test2", path="/test", methods=["GET", "POST"]) + self.assertNotEqual(route1, route2) + + def test_route_different_path_equals(self): + route1 = Route(function_name="test", path="/test1", methods=["GET", "POST"]) + route2 = Route(function_name="test", path="/test2", methods=["GET", "POST"]) + self.assertNotEqual(route1, route2) + + def test_same_object_equals(self): + route1 = Route(function_name="test", path="/test", methods=["POST", "GET"]) + self.assertEquals(route1, copy.deepcopy(route1)) + + def test_route_function_name_hash(self): + route1 = Route(function_name="test1", path="/test", methods=["GET", "POST"]) + route2 = Route(function_name="test2", path="/test", methods=["GET", "POST"]) + self.assertNotEqual(route1.__hash__(), route2.__hash__()) + + def test_route_different_path_hash(self): + route1 = Route(function_name="test", path="/test1", methods=["GET", "POST"]) + route2 = Route(function_name="test", path="/test2", methods=["GET", "POST"]) + self.assertNotEqual(route1.__hash__(), route2.__hash__()) + + def test_same_object_hash(self): + route1 = Route(function_name="test", path="/test", methods=["POST", "GET"]) + self.assertEquals(route1.__hash__(), copy.deepcopy(route1).__hash__()) + + def test_route_method_order_hash(self): + route1 = Route(function_name="test", path="/test", methods=["POST", "GET"]) + route2 = Route(function_name="test", path="/test", methods=["GET", "POST"]) + self.assertEquals(route1.__hash__(), route2.__hash__())