From 7cb50b0174f85684e0e0e8ab15e056a309b6504b Mon Sep 17 00:00:00 2001 From: Vikranth Srivatsa <51216482+viksrivat@users.noreply.github.com> Date: Thu, 11 Jul 2019 06:26:03 -0700 Subject: [PATCH 1/7] feat(start-api): CloudFormation AWS::ApiGateway::RestApi support (#1238) --- samcli/commands/local/lib/api_collector.py | 215 ++++++++ samcli/commands/local/lib/api_provider.py | 94 ++++ samcli/commands/local/lib/cfn_api_provider.py | 69 +++ .../local/lib/cfn_base_api_provider.py | 70 +++ .../commands/local/lib/local_api_service.py | 22 +- samcli/commands/local/lib/provider.py | 57 ++- samcli/commands/local/lib/sam_api_provider.py | 462 +++--------------- .../commands/local/lib/sam_base_provider.py | 3 +- .../local/start_api/test_start_api.py | 106 ++++ .../start_api/swagger-rest-api-template.yaml | 69 +++ .../commands/local/lib/test_api_provider.py | 207 ++++++++ .../local/lib/test_cfn_api_provider.py | 215 ++++++++ .../local/lib/test_local_api_service.py | 4 +- .../local/lib/test_sam_api_provider.py | 84 ++-- 14 files changed, 1218 insertions(+), 459 deletions(-) create mode 100644 samcli/commands/local/lib/api_collector.py create mode 100644 samcli/commands/local/lib/api_provider.py create mode 100644 samcli/commands/local/lib/cfn_api_provider.py create mode 100644 samcli/commands/local/lib/cfn_base_api_provider.py create mode 100644 tests/integration/testdata/start_api/swagger-rest-api-template.yaml create mode 100644 tests/unit/commands/local/lib/test_api_provider.py create mode 100644 tests/unit/commands/local/lib/test_cfn_api_provider.py diff --git a/samcli/commands/local/lib/api_collector.py b/samcli/commands/local/lib/api_collector.py new file mode 100644 index 0000000000..cbd198c6b7 --- /dev/null +++ b/samcli/commands/local/lib/api_collector.py @@ -0,0 +1,215 @@ +""" +Class to store the API configurations in the SAM Template. This class helps store both implicit and explicit +APIs in a standardized format +""" + +import logging +from collections import namedtuple + +from six import string_types + +LOG = logging.getLogger(__name__) + + +class ApiCollector(object): + # Properties of each API. The structure is quite similar to the properties of AWS::Serverless::Api resource. + # This is intentional because it allows us to easily extend this class to support future properties on the API. + # We will store properties of Implicit APIs also in this format which converges the handling of implicit & explicit + # APIs. + Properties = namedtuple("Properties", ["apis", "binary_media_types", "cors", "stage_name", "stage_variables"]) + + def __init__(self): + # API properties stored per resource. Key is the LogicalId of the AWS::Serverless::Api resource and + # value is the properties + self.by_resource = {} + + def __iter__(self): + """ + Iterator to iterate through all the APIs stored in the collector. In each iteration, this yields the + LogicalId of the API resource and a list of APIs available in this resource. + + Yields + ------- + str + LogicalID of the AWS::Serverless::Api resource + list samcli.commands.local.lib.provider.Api + List of the API available in this resource along with additional configuration like binary media types. + """ + + for logical_id, _ in self.by_resource.items(): + yield logical_id, self._get_apis_with_config(logical_id) + + def add_apis(self, logical_id, apis): + """ + Stores the given APIs tagged under the given logicalId + + Parameters + ---------- + logical_id : str + LogicalId of the AWS::Serverless::Api resource + + apis : list of samcli.commands.local.lib.provider.Api + List of APIs available in this resource + """ + properties = self._get_properties(logical_id) + properties.apis.extend(apis) + + def add_binary_media_types(self, logical_id, binary_media_types): + """ + Stores the binary media type configuration for the API with given logical ID + + Parameters + ---------- + logical_id : str + LogicalId of the AWS::Serverless::Api resource + + binary_media_types : list of str + List of binary media types supported by this resource + + """ + properties = self._get_properties(logical_id) + + binary_media_types = binary_media_types or [] + for value in binary_media_types: + normalized_value = self._normalize_binary_media_type(value) + + # If the value is not supported, then just skip it. + if normalized_value: + properties.binary_media_types.add(normalized_value) + else: + LOG.debug("Unsupported data type of binary media type value of resource '%s'", logical_id) + + def add_stage_name(self, logical_id, stage_name): + """ + Stores the stage name for the API with the given local ID + + Parameters + ---------- + logical_id : str + LogicalId of the AWS::Serverless::Api resource + + stage_name : str + The stage_name string + + """ + properties = self._get_properties(logical_id) + properties = properties._replace(stage_name=stage_name) + self._set_properties(logical_id, properties) + + def add_stage_variables(self, logical_id, stage_variables): + """ + Stores the stage variables for the API with the given local ID + + Parameters + ---------- + logical_id : str + LogicalId of the AWS::Serverless::Api resource + + stage_variables : dict + A dictionary containing stage variables. + + """ + properties = self._get_properties(logical_id) + properties = properties._replace(stage_variables=stage_variables) + self._set_properties(logical_id, properties) + + def _get_apis_with_config(self, logical_id): + """ + Returns the list of APIs in this resource along with other extra configuration such as binary media types, + cors etc. Additional configuration is merged directly into the API data because these properties, although + defined globally, actually apply to each API. + + Parameters + ---------- + logical_id : str + Logical ID of the resource to fetch data for + + Returns + ------- + list of samcli.commands.local.lib.provider.Api + List of APIs with additional configurations for the resource with given logicalId. If there are no APIs, + then it returns an empty list + """ + + properties = self._get_properties(logical_id) + + # These configs need to be applied to each API + binary_media = sorted(list(properties.binary_media_types)) # Also sort the list to keep the ordering stable + cors = properties.cors + stage_name = properties.stage_name + stage_variables = properties.stage_variables + + result = [] + for api in properties.apis: + # Create a copy of the API with updated configuration + updated_api = api._replace(binary_media_types=binary_media, + cors=cors, + stage_name=stage_name, + stage_variables=stage_variables) + result.append(updated_api) + + return result + + def _get_properties(self, logical_id): + """ + Returns the properties of resource with given logical ID. If a resource is not found, then it returns an + empty data. + + Parameters + ---------- + logical_id : str + Logical ID of the resource + + Returns + ------- + samcli.commands.local.lib.sam_api_provider.ApiCollector.Properties + Properties object for this resource. + """ + + if logical_id not in self.by_resource: + self.by_resource[logical_id] = self.Properties(apis=[], + # Use a set() to be able to easily de-dupe + binary_media_types=set(), + cors=None, + stage_name=None, + stage_variables=None) + + return self.by_resource[logical_id] + + def _set_properties(self, logical_id, properties): + """ + Sets the properties of resource with given logical ID. If a resource is not found, it does nothing + + Parameters + ---------- + logical_id : str + Logical ID of the resource + properties : samcli.commands.local.lib.sam_api_provider.ApiCollector.Properties + Properties object for this resource. + """ + + if logical_id in self.by_resource: + self.by_resource[logical_id] = properties + + @staticmethod + def _normalize_binary_media_type(value): + """ + Converts binary media types values to the canonical format. Ex: image~1gif -> image/gif. If the value is not + a string, then this method just returns None + + Parameters + ---------- + value : str + Value to be normalized + + Returns + ------- + str or None + Normalized value. If the input was not a string, then None is returned + """ + + if not isinstance(value, string_types): + # It is possible that user specified a dict value for one of the binary media types. We just skip them + return None + + return value.replace("~1", "/") diff --git a/samcli/commands/local/lib/api_provider.py b/samcli/commands/local/lib/api_provider.py new file mode 100644 index 0000000000..afc686e166 --- /dev/null +++ b/samcli/commands/local/lib/api_provider.py @@ -0,0 +1,94 @@ +"""Class that provides Apis from a SAM Template""" + +import logging + +from samcli.commands.local.lib.cfn_base_api_provider import CfnBaseApiProvider +from samcli.commands.local.lib.api_collector import ApiCollector +from samcli.commands.local.lib.provider import AbstractApiProvider +from samcli.commands.local.lib.sam_base_provider import SamBaseProvider +from samcli.commands.local.lib.sam_api_provider import SamApiProvider +from samcli.commands.local.lib.cfn_api_provider import CfnApiProvider + +LOG = logging.getLogger(__name__) + + +class ApiProvider(AbstractApiProvider): + + def __init__(self, template_dict, parameter_overrides=None, cwd=None): + """ + Initialize the class with SAM template data. The template_dict (SAM Templated) is assumed + to be valid, normalized and a dictionary. template_dict should be normalized by running any and all + pre-processing before passing to this class. + This class does not perform any syntactic validation of the template. + + After the class is initialized, changes to ``template_dict`` will not be reflected in here. + You will need to explicitly update the class with new template, if necessary. + + Parameters + ---------- + template_dict : dict + SAM Template as a dictionary + + cwd : str + Optional working directory with respect to which we will resolve relative path to Swagger file + """ + self.template_dict = SamBaseProvider.get_template(template_dict, parameter_overrides) + self.resources = self.template_dict.get("Resources", {}) + + LOG.debug("%d resources found in the template", len(self.resources)) + + # Store a set of apis + self.cwd = cwd + self.apis = self._extract_apis(self.resources) + + LOG.debug("%d APIs found in the template", len(self.apis)) + + def get_all(self): + """ + Yields all the Lambda functions with Api Events available in the SAM Template. + + :yields Api: namedtuple containing the Api information + """ + + for api in self.apis: + yield api + + def _extract_apis(self, resources): + """ + Extracts all the Apis by running through the one providers. The provider that has the first type matched + will be run across all the resources + + Parameters + ---------- + resources: dict + The dictionary containing the different resources within the template + Returns + --------- + list of Apis extracted from the resources + """ + collector = ApiCollector() + provider = self.find_api_provider(resources) + apis = provider.extract_resource_api(resources, collector, cwd=self.cwd) + return self.normalize_apis(apis) + + @staticmethod + def find_api_provider(resources): + """ + Finds the ApiProvider given the first api type of the resource + + Parameters + ----------- + resources: dict + The dictionary containing the different resources within the template + + Return + ---------- + Instance of the ApiProvider that will be run on the template with a default of SamApiProvider + """ + for _, resource in resources.items(): + if resource.get(CfnBaseApiProvider.RESOURCE_TYPE) in SamApiProvider.TYPES: + return SamApiProvider() + elif resource.get(CfnBaseApiProvider.RESOURCE_TYPE) in CfnApiProvider.TYPES: + return CfnApiProvider() + + return SamApiProvider() diff --git a/samcli/commands/local/lib/cfn_api_provider.py b/samcli/commands/local/lib/cfn_api_provider.py new file mode 100644 index 0000000000..0e3919611c --- /dev/null +++ b/samcli/commands/local/lib/cfn_api_provider.py @@ -0,0 +1,69 @@ +"""Parses SAM given a template""" +import logging + +from samcli.commands.local.lib.cfn_base_api_provider import CfnBaseApiProvider + +LOG = logging.getLogger(__name__) + + +class CfnApiProvider(CfnBaseApiProvider): + APIGATEWAY_RESTAPI = "AWS::ApiGateway::RestApi" + TYPES = [ + APIGATEWAY_RESTAPI + ] + + def extract_resource_api(self, resources, collector, cwd=None): + """ + Extract the Api Object from a given resource and adds it to the ApiCollector. + + Parameters + ---------- + resources: dict + The dictionary containing the different resources within the template + + collector: ApiCollector + Instance of the API collector that where we will save the API information + + cwd : str + Optional working directory with respect to which we will resolve relative path to Swagger file + + Return + ------- + Returns a list of Apis + """ + for logical_id, resource in resources.items(): + resource_type = resource.get(CfnBaseApiProvider.RESOURCE_TYPE) + if resource_type == CfnApiProvider.APIGATEWAY_RESTAPI: + self._extract_cloud_formation_api(logical_id, resource, collector, cwd) + all_apis = [] + for _, apis in collector: + all_apis.extend(apis) + return all_apis + + def _extract_cloud_formation_api(self, logical_id, api_resource, collector, cwd=None): + """ + Extract APIs from AWS::ApiGateway::RestApi resource by reading and parsing Swagger documents. The result is + added to the collector. + + Parameters + ---------- + logical_id : str + Logical ID of the resource + + api_resource : dict + Resource definition, including its properties + + collector : ApiCollector + Instance of the API collector that where we will save the API information + """ + properties = api_resource.get("Properties", {}) + body = properties.get("Body") + body_s3_location = properties.get("BodyS3Location") + binary_media = properties.get("BinaryMediaTypes", []) + + if not body and not body_s3_location: + # Swagger is not found anywhere. + LOG.debug("Skipping resource '%s'. Swagger document not found in Body and BodyS3Location", + logical_id) + return + self.extract_swagger_api(logical_id, body, body_s3_location, binary_media, collector, cwd) diff --git a/samcli/commands/local/lib/cfn_base_api_provider.py b/samcli/commands/local/lib/cfn_base_api_provider.py new file mode 100644 index 0000000000..79bc6d8f1d --- /dev/null +++ b/samcli/commands/local/lib/cfn_base_api_provider.py @@ -0,0 +1,70 @@ +"""Class that parses the CloudFormation Api Template""" + +import logging + +from samcli.commands.local.lib.swagger.parser import SwaggerParser +from samcli.commands.local.lib.swagger.reader import SamSwaggerReader + +LOG = logging.getLogger(__name__) + + +class CfnBaseApiProvider(object): + RESOURCE_TYPE = "Type" + + def extract_resource_api(self, resources, collector, cwd=None): + """ + Extract the Api Object from a given resource and adds it to the ApiCollector. + + Parameters + ---------- + resources: dict + The dictionary containing the different resources within the template + + collector: ApiCollector + Instance of the API collector that where we will save the API information + + cwd : str + Optional working directory with respect to which we will resolve relative path to Swagger file + + Return + ------- + Returns a list of Apis + """ + raise NotImplementedError("not implemented") + + @staticmethod + def extract_swagger_api(logical_id, body, uri, binary_media, collector, cwd=None): + """ + Parse the Swagger documents and adds it to the ApiCollector. + + Parameters + ---------- + logical_id : str + Logical ID of the resource + + body : dict + The body of the RestApi + + uri : str or dict + The url to location of the RestApi + + binary_media: list + The link to the binary media + + collector: ApiCollector + Instance of the API collector that where we will save the API information + + cwd : str + Optional working directory with respect to which we will resolve relative path to Swagger file + """ + reader = SamSwaggerReader(definition_body=body, + definition_uri=uri, + working_dir=cwd) + swagger = reader.read() + parser = SwaggerParser(swagger) + apis = parser.get_apis() + LOG.debug("Found '%s' APIs in resource '%s'", len(apis), logical_id) + + collector.add_apis(logical_id, apis) + collector.add_binary_media_types(logical_id, parser.get_binary_media_types()) # Binary media from swagger + collector.add_binary_media_types(logical_id, binary_media) # Binary media specified on resource in template diff --git a/samcli/commands/local/lib/local_api_service.py b/samcli/commands/local/lib/local_api_service.py index d0ebbbd975..d456e67a83 100644 --- a/samcli/commands/local/lib/local_api_service.py +++ b/samcli/commands/local/lib/local_api_service.py @@ -6,7 +6,7 @@ import logging from samcli.local.apigw.local_apigw_service import LocalApigwService, Route -from samcli.commands.local.lib.sam_api_provider import SamApiProvider +from samcli.commands.local.lib.api_provider import ApiProvider from samcli.commands.local.lib.exceptions import NoApisDefined LOG = logging.getLogger(__name__) @@ -38,9 +38,9 @@ def __init__(self, self.static_dir = static_dir self.cwd = lambda_invoke_context.get_cwd() - self.api_provider = SamApiProvider(lambda_invoke_context.template, - parameter_overrides=lambda_invoke_context.parameter_overrides, - cwd=self.cwd) + self.api_provider = ApiProvider(lambda_invoke_context.template, + parameter_overrides=lambda_invoke_context.parameter_overrides, + cwd=self.cwd) self.lambda_runner = lambda_invoke_context.local_lambda_runner self.stderr_stream = lambda_invoke_context.stderr @@ -89,7 +89,7 @@ def _make_routing_list(api_provider): Parameters ---------- - api_provider : samcli.commands.local.lib.sam_api_provider.SamApiProvider + api_provider : samcli.commands.local.lib.api_provider.ApiProvider Returns ------- @@ -116,10 +116,14 @@ def _print_routes(api_provider, host, port): Mounting Product at http://127.0.0.1:3000/path1/bar [GET, POST, DELETE] Mounting Product at http://127.0.0.1:3000/path2/bar [HEAD] - :param samcli.commands.local.lib.provider.ApiProvider api_provider: API Provider that can return a list of APIs - :param string host: Host name where the service is running - :param int port: Port number where the service is running - :returns list(string): List of lines that were printed to the console. Helps with testing + :param samcli.commands.local.lib.provider.AbstractApiProvider api_provider: + API Provider that can return a list of APIs + :param string host: + Host name where the service is running + :param int port: + Port number where the service is running + :returns list(string): + List of lines that were printed to the console. Helps with testing """ grouped_api_configs = {} diff --git a/samcli/commands/local/lib/provider.py b/samcli/commands/local/lib/provider.py index eead981089..959166e814 100644 --- a/samcli/commands/local/lib/provider.py +++ b/samcli/commands/local/lib/provider.py @@ -222,10 +222,10 @@ def get_all(self): # The variables for that stage "stage_variables" ]) -_ApiTuple.__new__.__defaults__ = (None, # Cors is optional and defaults to None - [], # binary_media_types is optional and defaults to empty, - None, # Stage name is optional with default None - None # Stage variables is optional with default None +_ApiTuple.__new__.__defaults__ = (None, # Cors is optional and defaults to None + [], # binary_media_types is optional and defaults to empty, + None, # Stage name is optional with default None + None # Stage variables is optional with default None ) @@ -238,10 +238,17 @@ def __hash__(self): Cors = namedtuple("Cors", ["AllowOrigin", "AllowMethods", "AllowHeaders"]) -class ApiProvider(object): +class AbstractApiProvider(object): """ Abstract base class to return APIs and the functions they route to """ + _ANY_HTTP_METHODS = ["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"] def get_all(self): """ @@ -250,3 +257,43 @@ def get_all(self): :yields Api: namedtuple containing the API information """ raise NotImplementedError("not implemented") + + @staticmethod + def normalize_http_methods(http_method): + """ + Normalizes Http Methods. Api Gateway allows a Http Methods of ANY. This is a special verb to denote all + supported Http Methods on Api Gateway. + + :param str http_method: Http method + :yield str: Either the input http_method or one of the _ANY_HTTP_METHODS (normalized Http Methods) + """ + + if http_method.upper() == 'ANY': + for method in AbstractApiProvider._ANY_HTTP_METHODS: + yield method.upper() + else: + yield http_method.upper() + + @staticmethod + def normalize_apis(apis): + """ + Normalize the APIs to use standard method name + + Parameters + ---------- + apis : list of samcli.commands.local.lib.provider.Api + List of APIs to replace normalize + + Returns + ------- + list of samcli.commands.local.lib.provider.Api + List of normalized APIs + """ + + result = list() + for api in apis: + for normalized_method in AbstractApiProvider.normalize_http_methods(api.method): + # _replace returns a copy of the namedtuple. This is the official way of creating copies of namedtuple + result.append(api._replace(method=normalized_method)) + + return result diff --git a/samcli/commands/local/lib/sam_api_provider.py b/samcli/commands/local/lib/sam_api_provider.py index 84336a8d2d..f0ec57b823 100644 --- a/samcli/commands/local/lib/sam_api_provider.py +++ b/samcli/commands/local/lib/sam_api_provider.py @@ -1,111 +1,60 @@ -"""Class that provides Apis from a SAM Template""" +"""Parses SAM given the template""" import logging -from collections import namedtuple -from six import string_types - -from samcli.commands.local.lib.swagger.parser import SwaggerParser -from samcli.commands.local.lib.provider import ApiProvider, Api -from samcli.commands.local.lib.sam_base_provider import SamBaseProvider -from samcli.commands.local.lib.swagger.reader import SamSwaggerReader +from samcli.commands.local.lib.provider import Api, AbstractApiProvider from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException +from samcli.commands.local.lib.cfn_base_api_provider import CfnBaseApiProvider LOG = logging.getLogger(__name__) -class SamApiProvider(ApiProvider): - _IMPLICIT_API_RESOURCE_ID = "ServerlessRestApi" - _SERVERLESS_FUNCTION = "AWS::Serverless::Function" - _SERVERLESS_API = "AWS::Serverless::Api" - _TYPE = "Type" - +class SamApiProvider(CfnBaseApiProvider): + SERVERLESS_FUNCTION = "AWS::Serverless::Function" + SERVERLESS_API = "AWS::Serverless::Api" + TYPES = [ + SERVERLESS_FUNCTION, + SERVERLESS_API + ] _FUNCTION_EVENT_TYPE_API = "Api" _FUNCTION_EVENT = "Events" _EVENT_PATH = "Path" _EVENT_METHOD = "Method" + _EVENT_TYPE = "Type" + IMPLICIT_API_RESOURCE_ID = "ServerlessRestApi" - _ANY_HTTP_METHODS = ["GET", - "DELETE", - "PUT", - "POST", - "HEAD", - "OPTIONS", - "PATCH"] - - def __init__(self, template_dict, parameter_overrides=None, cwd=None): + def extract_resource_api(self, resources, collector, cwd=None): """ - Initialize the class with SAM template data. The template_dict (SAM Templated) is assumed - to be valid, normalized and a dictionary. template_dict should be normalized by running any and all - pre-processing before passing to this class. - This class does not perform any syntactic validation of the template. - - After the class is initialized, changes to ``template_dict`` will not be reflected in here. - You will need to explicitly update the class with new template, if necessary. + Extract the Api Object from a given resource and adds it to the ApiCollector. Parameters ---------- - template_dict : dict - SAM Template as a dictionary - cwd : str - Optional working directory with respect to which we will resolve relative path to Swagger file - """ - - self.template_dict = SamBaseProvider.get_template(template_dict, parameter_overrides) - self.resources = self.template_dict.get("Resources", {}) - - LOG.debug("%d resources found in the template", len(self.resources)) - - # Store a set of apis - self.cwd = cwd - self.apis = self._extract_apis(self.resources) - - LOG.debug("%d APIs found in the template", len(self.apis)) - - def get_all(self): - """ - Yields all the Lambda functions with Api Events available in the SAM Template. + resources: dict + The dictionary containing the different resources within the template - :yields Api: namedtuple containing the Api information - """ - - for api in self.apis: - yield api + collector: ApiCollector + Instance of the API collector that where we will save the API information - def _extract_apis(self, resources): - """ - Extract all Implicit Apis (Apis defined through Serverless Function with an Api Event + cwd : str + Optional working directory with respect to which we will resolve relative path to Swagger file - :param dict resources: Dictionary of SAM/CloudFormation resources - :return: List of nametuple Api + Return + ------- + Returns a list of Apis """ - - # Some properties like BinaryMediaTypes, Cors are set once on the resource but need to be applied to each API. - # For Implicit APIs, which are defined on the Function resource, these properties - # are defined on a AWS::Serverless::Api resource with logical ID "ServerlessRestApi". Therefore, no matter - # if it is an implicit API or an explicit API, there is a corresponding resource of type AWS::Serverless::Api - # that contains these additional configurations. - # - # We use this assumption in the following loop to collect information from resources of type - # AWS::Serverless::Api. We also extract API from Serverless::Function resource and add them to the - # corresponding Serverless::Api resource. This is all done using the ``collector``. - - collector = ApiCollector() - + # AWS::Serverless::Function is currently included when parsing of Apis because when SamBaseProvider is run on + # the template we are creating the implicit apis due to plugins that translate it in the SAM repo, + # which we later merge with the explicit ones in SamApiProvider.merge_apis. This requires the code to be + # parsed here and in InvokeContext. for logical_id, resource in resources.items(): - - resource_type = resource.get(SamApiProvider._TYPE) - - if resource_type == SamApiProvider._SERVERLESS_FUNCTION: + resource_type = resource.get(CfnBaseApiProvider.RESOURCE_TYPE) + if resource_type == SamApiProvider.SERVERLESS_FUNCTION: self._extract_apis_from_function(logical_id, resource, collector) + if resource_type == SamApiProvider.SERVERLESS_API: + self._extract_from_serverless_api(logical_id, resource, collector, cwd) + return self.merge_apis(collector) - if resource_type == SamApiProvider._SERVERLESS_API: - self._extract_from_serverless_api(logical_id, resource, collector) - - apis = SamApiProvider._merge_apis(collector) - return self._normalize_apis(apis) - - def _extract_from_serverless_api(self, logical_id, api_resource, collector): + def _extract_from_serverless_api(self, logical_id, api_resource, collector, cwd=None): """ Extract APIs from AWS::Serverless::Api resource by reading and parsing Swagger documents. The result is added to the collector. @@ -134,99 +83,11 @@ def _extract_from_serverless_api(self, logical_id, api_resource, collector): LOG.debug("Skipping resource '%s'. Swagger document not found in DefinitionBody and DefinitionUri", logical_id) return - - reader = SamSwaggerReader(definition_body=body, - definition_uri=uri, - working_dir=self.cwd) - swagger = reader.read() - parser = SwaggerParser(swagger) - apis = parser.get_apis() - LOG.debug("Found '%s' APIs in resource '%s'", len(apis), logical_id) - - collector.add_apis(logical_id, apis) - collector.add_binary_media_types(logical_id, parser.get_binary_media_types()) # Binary media from swagger - collector.add_binary_media_types(logical_id, binary_media) # Binary media specified on resource in template - + self.extract_swagger_api(logical_id, body, uri, binary_media, collector, cwd) collector.add_stage_name(logical_id, stage_name) collector.add_stage_variables(logical_id, stage_variables) - @staticmethod - def _merge_apis(collector): - """ - Quite often, an API is defined both in Implicit and Explicit API definitions. In such cases, Implicit API - definition wins because that conveys clear intent that the API is backed by a function. This method will - merge two such list of Apis with the right order of precedence. If a Path+Method combination is defined - in both the places, only one wins. - - Parameters - ---------- - collector : ApiCollector - Collector object that holds all the APIs specified in the template - - Returns - ------- - list of samcli.commands.local.lib.provider.Api - List of APIs obtained by combining both the input lists. - """ - - implicit_apis = [] - explicit_apis = [] - - # Store implicit and explicit APIs separately in order to merge them later in the correct order - # Implicit APIs are defined on a resource with logicalID ServerlessRestApi - for logical_id, apis in collector: - if logical_id == SamApiProvider._IMPLICIT_API_RESOURCE_ID: - implicit_apis.extend(apis) - else: - explicit_apis.extend(apis) - - # We will use "path+method" combination as key to this dictionary and store the Api config for this combination. - # If an path+method combo already exists, then overwrite it if and only if this is an implicit API - all_apis = {} - - # By adding implicit APIs to the end of the list, they will be iterated last. If a configuration was already - # written by explicit API, it will be overriden by implicit API, just by virtue of order of iteration. - all_configs = explicit_apis + implicit_apis - - for config in all_configs: - # Normalize the methods before de-duping to allow an ANY method in implicit API to override a regular HTTP - # method on explicit API. - for normalized_method in SamApiProvider._normalize_http_methods(config.method): - key = config.path + normalized_method - all_apis[key] = config - - result = set(all_apis.values()) # Assign to a set() to de-dupe - LOG.debug("Removed duplicates from '%d' Explicit APIs and '%d' Implicit APIs to produce '%d' APIs", - len(explicit_apis), len(implicit_apis), len(result)) - - return list(result) - - @staticmethod - def _normalize_apis(apis): - """ - Normalize the APIs to use standard method name - - Parameters - ---------- - apis : list of samcli.commands.local.lib.provider.Api - List of APIs to replace normalize - - Returns - ------- - list of samcli.commands.local.lib.provider.Api - List of normalized APIs - """ - - result = list() - for api in apis: - for normalized_method in SamApiProvider._normalize_http_methods(api.method): - # _replace returns a copy of the namedtuple. This is the official way of creating copies of namedtuple - result.append(api._replace(method=normalized_method)) - - return result - - @staticmethod - def _extract_apis_from_function(logical_id, function_resource, collector): + def _extract_apis_from_function(self, logical_id, function_resource, collector): """ Fetches a list of APIs configured for this SAM Function resource. @@ -243,11 +104,10 @@ def _extract_apis_from_function(logical_id, function_resource, collector): """ resource_properties = function_resource.get("Properties", {}) - serverless_function_events = resource_properties.get(SamApiProvider._FUNCTION_EVENT, {}) - SamApiProvider._extract_apis_from_events(logical_id, serverless_function_events, collector) + serverless_function_events = resource_properties.get(self._FUNCTION_EVENT, {}) + self.extract_apis_from_events(logical_id, serverless_function_events, collector) - @staticmethod - def _extract_apis_from_events(function_logical_id, serverless_function_events, collector): + def extract_apis_from_events(self, function_logical_id, serverless_function_events, collector): """ Given an AWS::Serverless::Function Event Dictionary, extract out all 'Api' events and store within the collector @@ -266,8 +126,8 @@ def _extract_apis_from_events(function_logical_id, serverless_function_events, c count = 0 for _, event in serverless_function_events.items(): - if SamApiProvider._FUNCTION_EVENT_TYPE_API == event.get(SamApiProvider._TYPE): - api_resource_id, api = SamApiProvider._convert_event_api(function_logical_id, event.get("Properties")) + if self._FUNCTION_EVENT_TYPE_API == event.get(self._EVENT_TYPE): + api_resource_id, api = self._convert_event_api(function_logical_id, event.get("Properties")) collector.add_apis(api_resource_id, [api]) count += 1 @@ -288,7 +148,7 @@ def _convert_event_api(lambda_logical_id, event_properties): # An API Event, can have RestApiId property which designates the resource that owns this API. If omitted, # the API is owned by Implicit API resource. This could either be a direct resource logical ID or a # "Ref" of the logicalID - api_resource_id = event_properties.get("RestApiId", SamApiProvider._IMPLICIT_API_RESOURCE_ID) + api_resource_id = event_properties.get("RestApiId", SamApiProvider.IMPLICIT_API_RESOURCE_ID) if isinstance(api_resource_id, dict) and "Ref" in api_resource_id: api_resource_id = api_resource_id["Ref"] @@ -302,226 +162,52 @@ def _convert_event_api(lambda_logical_id, event_properties): return api_resource_id, Api(path=path, method=method, function_name=lambda_logical_id) @staticmethod - def _normalize_http_methods(http_method): - """ - Normalizes Http Methods. Api Gateway allows a Http Methods of ANY. This is a special verb to denote all - supported Http Methods on Api Gateway. - - :param str http_method: Http method - :yield str: Either the input http_method or one of the _ANY_HTTP_METHODS (normalized Http Methods) - """ - - if http_method.upper() == 'ANY': - for method in SamApiProvider._ANY_HTTP_METHODS: - yield method.upper() - else: - yield http_method.upper() - - -class ApiCollector(object): - """ - Class to store the API configurations in the SAM Template. This class helps store both implicit and explicit - APIs in a standardized format - """ - - # Properties of each API. The structure is quite similar to the properties of AWS::Serverless::Api resource. - # This is intentional because it allows us to easily extend this class to support future properties on the API. - # We will store properties of Implicit APIs also in this format which converges the handling of implicit & explicit - # APIs. - Properties = namedtuple("Properties", ["apis", "binary_media_types", "cors", "stage_name", "stage_variables"]) - - def __init__(self): - # API properties stored per resource. Key is the LogicalId of the AWS::Serverless::Api resource and - # value is the properties - self.by_resource = {} - - def __iter__(self): - """ - Iterator to iterate through all the APIs stored in the collector. In each iteration, this yields the - LogicalId of the API resource and a list of APIs available in this resource. - - Yields - ------- - str - LogicalID of the AWS::Serverless::Api resource - list samcli.commands.local.lib.provider.Api - List of the API available in this resource along with additional configuration like binary media types. - """ - - for logical_id, _ in self.by_resource.items(): - yield logical_id, self._get_apis_with_config(logical_id) - - def add_apis(self, logical_id, apis): - """ - Stores the given APIs tagged under the given logicalId - - Parameters - ---------- - logical_id : str - LogicalId of the AWS::Serverless::Api resource - - apis : list of samcli.commands.local.lib.provider.Api - List of APIs available in this resource - """ - properties = self._get_properties(logical_id) - properties.apis.extend(apis) - - def add_binary_media_types(self, logical_id, binary_media_types): - """ - Stores the binary media type configuration for the API with given logical ID - - Parameters - ---------- - logical_id : str - LogicalId of the AWS::Serverless::Api resource - - binary_media_types : list of str - List of binary media types supported by this resource - - """ - properties = self._get_properties(logical_id) - - binary_media_types = binary_media_types or [] - for value in binary_media_types: - normalized_value = self._normalize_binary_media_type(value) - - # If the value is not supported, then just skip it. - if normalized_value: - properties.binary_media_types.add(normalized_value) - else: - LOG.debug("Unsupported data type of binary media type value of resource '%s'", logical_id) - - def add_stage_name(self, logical_id, stage_name): - """ - Stores the stage name for the API with the given local ID - - Parameters - ---------- - logical_id : str - LogicalId of the AWS::Serverless::Api resource - - stage_name : str - The stage_name string - - """ - properties = self._get_properties(logical_id) - properties = properties._replace(stage_name=stage_name) - self._set_properties(logical_id, properties) - - def add_stage_variables(self, logical_id, stage_variables): - """ - Stores the stage variables for the API with the given local ID - - Parameters - ---------- - logical_id : str - LogicalId of the AWS::Serverless::Api resource - - stage_variables : dict - A dictionary containing stage variables. - - """ - properties = self._get_properties(logical_id) - properties = properties._replace(stage_variables=stage_variables) - self._set_properties(logical_id, properties) - - def _get_apis_with_config(self, logical_id): + def merge_apis(collector): """ - Returns the list of APIs in this resource along with other extra configuration such as binary media types, - cors etc. Additional configuration is merged directly into the API data because these properties, although - defined globally, actually apply to each API. + Quite often, an API is defined both in Implicit and Explicit API definitions. In such cases, Implicit API + definition wins because that conveys clear intent that the API is backed by a function. This method will + merge two such list of Apis with the right order of precedence. If a Path+Method combination is defined + in both the places, only one wins. Parameters ---------- - logical_id : str - Logical ID of the resource to fetch data for + collector : ApiCollector + Collector object that holds all the APIs specified in the template Returns ------- list of samcli.commands.local.lib.provider.Api - List of APIs with additional configurations for the resource with given logicalId. If there are no APIs, - then it returns an empty list - """ - - properties = self._get_properties(logical_id) - - # These configs need to be applied to each API - binary_media = sorted(list(properties.binary_media_types)) # Also sort the list to keep the ordering stable - cors = properties.cors - stage_name = properties.stage_name - stage_variables = properties.stage_variables - - result = [] - for api in properties.apis: - # Create a copy of the API with updated configuration - updated_api = api._replace(binary_media_types=binary_media, - cors=cors, - stage_name=stage_name, - stage_variables=stage_variables) - result.append(updated_api) - - return result - - def _get_properties(self, logical_id): - """ - Returns the properties of resource with given logical ID. If a resource is not found, then it returns an - empty data. - - Parameters - ---------- - logical_id : str - Logical ID of the resource - - Returns - ------- - samcli.commands.local.lib.sam_api_provider.ApiCollector.Properties - Properties object for this resource. - """ - - if logical_id not in self.by_resource: - self.by_resource[logical_id] = self.Properties(apis=[], - # Use a set() to be able to easily de-dupe - binary_media_types=set(), - cors=None, - stage_name=None, - stage_variables=None) - - return self.by_resource[logical_id] - - def _set_properties(self, logical_id, properties): + List of APIs obtained by combining both the input lists. """ - Sets the properties of resource with given logical ID. If a resource is not found, it does nothing - Parameters - ---------- - logical_id : str - Logical ID of the resource - properties : samcli.commands.local.lib.sam_api_provider.ApiCollector.Properties - Properties object for this resource. - """ + implicit_apis = [] + explicit_apis = [] - if logical_id in self.by_resource: - self.by_resource[logical_id] = properties + # Store implicit and explicit APIs separately in order to merge them later in the correct order + # Implicit APIs are defined on a resource with logicalID ServerlessRestApi + for logical_id, apis in collector: + if logical_id == SamApiProvider.IMPLICIT_API_RESOURCE_ID: + implicit_apis.extend(apis) + else: + explicit_apis.extend(apis) - @staticmethod - def _normalize_binary_media_type(value): - """ - Converts binary media types values to the canonical format. Ex: image~1gif -> image/gif. If the value is not - a string, then this method just returns None + # We will use "path+method" combination as key to this dictionary and store the Api config for this combination. + # If an path+method combo already exists, then overwrite it if and only if this is an implicit API + all_apis = {} - Parameters - ---------- - value : str - Value to be normalized + # By adding implicit APIs to the end of the list, they will be iterated last. If a configuration was already + # written by explicit API, it will be overriden by implicit API, just by virtue of order of iteration. + all_configs = explicit_apis + implicit_apis - Returns - ------- - str or None - Normalized value. If the input was not a string, then None is returned - """ + for config in all_configs: + # Normalize the methods before de-duping to allow an ANY method in implicit API to override a regular HTTP + # method on explicit API. + for normalized_method in AbstractApiProvider.normalize_http_methods(config.method): + key = config.path + normalized_method + all_apis[key] = config - if not isinstance(value, string_types): - # It is possible that user specified a dict value for one of the binary media types. We just skip them - return None + result = set(all_apis.values()) # Assign to a set() to de-dupe + LOG.debug("Removed duplicates from '%d' Explicit APIs and '%d' Implicit APIs to produce '%d' APIs", + len(explicit_apis), len(implicit_apis), len(result)) - return value.replace("~1", "/") + return list(result) diff --git a/samcli/commands/local/lib/sam_base_provider.py b/samcli/commands/local/lib/sam_base_provider.py index bbf4d6381b..861e1fd47a 100644 --- a/samcli/commands/local/lib/sam_base_provider.py +++ b/samcli/commands/local/lib/sam_base_provider.py @@ -10,7 +10,6 @@ from samcli.lib.samlib.wrapper import SamTranslatorWrapper from samcli.lib.samlib.resource_metadata_normalizer import ResourceMetadataNormalizer - LOG = logging.getLogger(__name__) @@ -89,7 +88,7 @@ def _resolve_parameters(template_dict, parameter_overrides): supported_intrinsics = {action.intrinsic_name: action() for action in SamBaseProvider._SUPPORTED_INTRINSICS} # Intrinsics resolver will mutate the original template - return IntrinsicsResolver(parameters=parameter_values, supported_intrinsics=supported_intrinsics)\ + return IntrinsicsResolver(parameters=parameter_values, supported_intrinsics=supported_intrinsics) \ .resolve_parameter_refs(template_dict) @staticmethod diff --git a/tests/integration/local/start_api/test_start_api.py b/tests/integration/local/start_api/test_start_api.py index 700491260d..321741e0bf 100644 --- a/tests/integration/local/start_api/test_start_api.py +++ b/tests/integration/local/start_api/test_start_api.py @@ -288,6 +288,112 @@ def test_binary_response(self): self.assertEquals(response.content, expected) +class TestStartApiWithSwaggerRestApis(StartApiIntegBaseClass): + template_path = "/testdata/start_api/swagger-rest-api-template.yaml" + binary_data_file = "testdata/start_api/binarydata.gif" + + def setUp(self): + self.url = "http://127.0.0.1:{}".format(self.port) + + def test_get_call_with_path_setup_with_any_swagger(self): + """ + Get Request to a path that was defined as ANY in SAM through Swagger + """ + response = requests.get(self.url + "/anyandall") + + self.assertEquals(response.status_code, 200) + self.assertEquals(response.json(), {'hello': 'world'}) + + def test_post_call_with_path_setup_with_any_swagger(self): + """ + Post Request to a path that was defined as ANY in SAM through Swagger + """ + response = requests.post(self.url + "/anyandall", json={}) + + self.assertEquals(response.status_code, 200) + self.assertEquals(response.json(), {'hello': 'world'}) + + def test_put_call_with_path_setup_with_any_swagger(self): + """ + Put Request to a path that was defined as ANY in SAM through Swagger + """ + response = requests.put(self.url + "/anyandall", json={}) + + self.assertEquals(response.status_code, 200) + self.assertEquals(response.json(), {'hello': 'world'}) + + def test_head_call_with_path_setup_with_any_swagger(self): + """ + Head Request to a path that was defined as ANY in SAM through Swagger + """ + response = requests.head(self.url + "/anyandall") + + self.assertEquals(response.status_code, 200) + + def test_delete_call_with_path_setup_with_any_swagger(self): + """ + Delete Request to a path that was defined as ANY in SAM through Swagger + """ + response = requests.delete(self.url + "/anyandall") + + self.assertEquals(response.status_code, 200) + self.assertEquals(response.json(), {'hello': 'world'}) + + def test_options_call_with_path_setup_with_any_swagger(self): + """ + Options Request to a path that was defined as ANY in SAM through Swagger + """ + response = requests.options(self.url + "/anyandall") + + self.assertEquals(response.status_code, 200) + + def test_patch_call_with_path_setup_with_any_swagger(self): + """ + Patch Request to a path that was defined as ANY in SAM through Swagger + """ + response = requests.patch(self.url + "/anyandall") + + self.assertEquals(response.status_code, 200) + self.assertEquals(response.json(), {'hello': 'world'}) + + def test_function_not_defined_in_template(self): + response = requests.get(self.url + "/nofunctionfound") + + self.assertEquals(response.status_code, 502) + self.assertEquals(response.json(), {"message": "No function defined for resource method"}) + + def test_lambda_function_resource_is_reachable(self): + response = requests.get(self.url + "/nonserverlessfunction") + + self.assertEquals(response.status_code, 200) + self.assertEquals(response.json(), {'hello': 'world'}) + + def test_binary_request(self): + """ + This tests that the service can accept and invoke a lambda when given binary data in a request + """ + input_data = self.get_binary_data(self.binary_data_file) + response = requests.post(self.url + '/echobase64eventbody', + headers={"Content-Type": "image/gif"}, + data=input_data) + + self.assertEquals(response.status_code, 200) + self.assertEquals(response.headers.get("Content-Type"), "image/gif") + self.assertEquals(response.content, input_data) + + def test_binary_response(self): + """ + Binary data is returned correctly + """ + expected = self.get_binary_data(self.binary_data_file) + + response = requests.get(self.url + '/base64response') + + self.assertEquals(response.status_code, 200) + self.assertEquals(response.headers.get("Content-Type"), "image/gif") + self.assertEquals(response.content, expected) + + class TestServiceResponses(StartApiIntegBaseClass): """ Test Class centered around the different responses that can happen in Lambda and pass through start-api diff --git a/tests/integration/testdata/start_api/swagger-rest-api-template.yaml b/tests/integration/testdata/start_api/swagger-rest-api-template.yaml new file mode 100644 index 0000000000..5edeb8717f --- /dev/null +++ b/tests/integration/testdata/start_api/swagger-rest-api-template.yaml @@ -0,0 +1,69 @@ +AWSTemplateFormatVersion: '2010-09-09' + +Resources: + Base64ResponseFunction: + Properties: + Code: "." + Handler: main.base64_response + Runtime: python3.6 + Type: AWS::Lambda::Function + EchoBase64EventBodyFunction: + Properties: + Code: "." + Handler: main.echo_base64_event_body + Runtime: python3.6 + Type: AWS::Lambda::Function + MyApi: + Properties: + Body: + info: + title: + Ref: AWS::StackName + paths: + "/anyandall": + x-amazon-apigateway-any-method: + x-amazon-apigateway-integration: + httpMethod: POST + responses: {} + type: aws_proxy + uri: + Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${MyNonServerlessLambdaFunction.Arn}/invocations + "/base64response": + get: + x-amazon-apigateway-integration: + httpMethod: POST + type: aws_proxy + uri: + Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${Base64ResponseFunction.Arn}/invocations + "/echobase64eventbody": + post: + x-amazon-apigateway-integration: + httpMethod: POST + type: aws_proxy + uri: + Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${EchoBase64EventBodyFunction.Arn}/invocations + "/nofunctionfound": + get: + x-amazon-apigateway-integration: + httpMethod: POST + type: aws_proxy + uri: + Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${WhatFunction.Arn}/invocations + "/nonserverlessfunction": + get: + x-amazon-apigateway-integration: + httpMethod: POST + type: aws_proxy + uri: + Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${MyNonServerlessLambdaFunction.Arn}/invocations + swagger: '2.0' + x-amazon-apigateway-binary-media-types: + - image/gif + StageName: prod + Type: AWS::ApiGateway::RestApi + MyNonServerlessLambdaFunction: + Properties: + Code: "." + Handler: main.handler + Runtime: python3.6 + Type: AWS::Lambda::Function diff --git a/tests/unit/commands/local/lib/test_api_provider.py b/tests/unit/commands/local/lib/test_api_provider.py new file mode 100644 index 0000000000..50b8d073d4 --- /dev/null +++ b/tests/unit/commands/local/lib/test_api_provider.py @@ -0,0 +1,207 @@ +from collections import OrderedDict +from unittest import TestCase + +from mock import patch + +from samcli.commands.local.lib.api_provider import ApiProvider +from samcli.commands.local.lib.sam_api_provider import SamApiProvider +from samcli.commands.local.lib.cfn_api_provider import CfnApiProvider + + +class TestApiProvider_init(TestCase): + + @patch.object(ApiProvider, "_extract_apis") + @patch("samcli.commands.local.lib.api_provider.SamBaseProvider") + def test_provider_with_valid_template(self, SamBaseProviderMock, extract_api_mock): + extract_api_mock.return_value = {"set", "of", "values"} + + template = {"Resources": {"a": "b"}} + SamBaseProviderMock.get_template.return_value = template + + provider = ApiProvider(template) + + self.assertEquals(len(provider.apis), 3) + self.assertEquals(provider.apis, set(["set", "of", "values"])) + self.assertEquals(provider.template_dict, {"Resources": {"a": "b"}}) + self.assertEquals(provider.resources, {"a": "b"}) + + +class TestApiProviderSelection(TestCase): + def test_default_provider(self): + resources = { + "TestApi": { + "Type": "AWS::UNKNOWN_TYPE", + "Properties": { + "StageName": "dev", + "DefinitionBody": { + "paths": { + "/path": { + "get": { + "x-amazon-apigateway-integration": { + "httpMethod": "POST", + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } + } + + } + } + } + } + } + + provider = ApiProvider.find_api_provider(resources) + self.assertTrue(isinstance(provider, SamApiProvider)) + + def test_api_provider_sam_api(self): + resources = { + "TestApi": { + "Type": "AWS::Serverless::Api", + "Properties": { + "StageName": "dev", + "DefinitionBody": { + "paths": { + "/path": { + "get": { + "x-amazon-apigateway-integration": { + "httpMethod": "POST", + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } + } + + } + } + } + } + } + + provider = ApiProvider.find_api_provider(resources) + self.assertTrue(isinstance(provider, SamApiProvider)) + + def test_api_provider_sam_function(self): + resources = { + "TestApi": { + "Type": "AWS::Serverless::Function", + "Properties": { + "StageName": "dev", + "DefinitionBody": { + "paths": { + "/path": { + "get": { + "x-amazon-apigateway-integration": { + "httpMethod": "POST", + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } + } + + } + } + } + } + } + + provider = ApiProvider.find_api_provider(resources) + + self.assertTrue(isinstance(provider, SamApiProvider)) + + def test_api_provider_cloud_formation(self): + resources = { + "TestApi": { + "Type": "AWS::ApiGateway::RestApi", + "Properties": { + "StageName": "dev", + "Body": { + "paths": { + "/path": { + "get": { + "x-amazon-apigateway-integration": { + "httpMethod": "POST", + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } + } + + } + } + } + } + } + + provider = ApiProvider.find_api_provider(resources) + self.assertTrue(isinstance(provider, CfnApiProvider)) + + def test_multiple_api_provider_cloud_formation(self): + resources = OrderedDict() + resources["TestApi"] = { + "Type": "AWS::ApiGateway::RestApi", + "Properties": { + "StageName": "dev", + "Body": { + "paths": { + "/path": { + "get": { + "x-amazon-apigateway-integration": { + "httpMethod": "POST", + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } + } + + } + } + } + } + resources["OtherApi"] = { + "Type": "AWS::Serverless::Api", + "Properties": { + "StageName": "dev", + "DefinitionBody": { + "paths": { + "/path": { + "get": { + "x-amazon-apigateway-integration": { + "httpMethod": "POST", + "type": "aws_proxy", + "uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" + "/functions/${NoApiEventFunction.Arn}/invocations", + }, + "responses": {}, + }, + } + } + + } + } + } + } + + provider = ApiProvider.find_api_provider(resources) + self.assertTrue(isinstance(provider, CfnApiProvider)) diff --git a/tests/unit/commands/local/lib/test_cfn_api_provider.py b/tests/unit/commands/local/lib/test_cfn_api_provider.py new file mode 100644 index 0000000000..723951eb11 --- /dev/null +++ b/tests/unit/commands/local/lib/test_cfn_api_provider.py @@ -0,0 +1,215 @@ +import json +import tempfile +from unittest import TestCase + +from mock import patch +from six import assertCountEqual + +from samcli.commands.local.lib.api_provider import ApiProvider +from samcli.commands.local.lib.provider import Api +from tests.unit.commands.local.lib.test_sam_api_provider import make_swagger + + +class TestApiProviderWithApiGatewayRestApi(TestCase): + + def setUp(self): + self.binary_types = ["image/png", "image/jpg"] + self.input_apis = [ + Api(path="/path1", method="GET", function_name="SamFunc1", cors=None), + Api(path="/path1", method="POST", function_name="SamFunc1", cors=None), + + Api(path="/path2", method="PUT", function_name="SamFunc1", cors=None), + Api(path="/path2", method="GET", function_name="SamFunc1", cors=None), + + Api(path="/path3", method="DELETE", function_name="SamFunc1", cors=None) + ] + + def test_with_no_apis(self): + template = { + "Resources": { + + "Api1": { + "Type": "AWS::ApiGateway::RestApi", + "Properties": { + }, + + } + } + } + + provider = ApiProvider(template) + + self.assertEquals(provider.apis, []) + + def test_with_inline_swagger_apis(self): + template = { + "Resources": { + + "Api1": { + "Type": "AWS::ApiGateway::RestApi", + "Properties": { + "Body": make_swagger(self.input_apis) + } + } + } + } + + provider = ApiProvider(template) + assertCountEqual(self, self.input_apis, provider.apis) + + def test_with_swagger_as_local_file(self): + with tempfile.NamedTemporaryFile(mode='w') as fp: + filename = fp.name + + swagger = make_swagger(self.input_apis) + json.dump(swagger, fp) + fp.flush() + + template = { + "Resources": { + + "Api1": { + "Type": "AWS::ApiGateway::RestApi", + "Properties": { + "BodyS3Location": filename + } + } + } + } + + provider = ApiProvider(template) + assertCountEqual(self, self.input_apis, provider.apis) + + def test_body_with_swagger_as_local_file_expect_fail(self): + with tempfile.NamedTemporaryFile(mode='w') as fp: + filename = fp.name + + swagger = make_swagger(self.input_apis) + json.dump(swagger, fp) + fp.flush() + + template = { + "Resources": { + + "Api1": { + "Type": "AWS::ApiGateway::RestApi", + "Properties": { + "Body": filename + } + } + } + } + self.assertRaises(Exception, ApiProvider, template) + + @patch("samcli.commands.local.lib.cfn_base_api_provider.SamSwaggerReader") + def test_with_swagger_as_both_body_and_uri_called(self, SamSwaggerReaderMock): + body = {"some": "body"} + filename = "somefile.txt" + + template = { + "Resources": { + + "Api1": { + "Type": "AWS::ApiGateway::RestApi", + "Properties": { + "BodyS3Location": filename, + "Body": body + } + } + } + } + + SamSwaggerReaderMock.return_value.read.return_value = make_swagger(self.input_apis) + + cwd = "foo" + provider = ApiProvider(template, cwd=cwd) + assertCountEqual(self, self.input_apis, provider.apis) + SamSwaggerReaderMock.assert_called_with(definition_body=body, definition_uri=filename, working_dir=cwd) + + def test_swagger_with_any_method(self): + apis = [ + Api(path="/path", method="any", function_name="SamFunc1", cors=None) + ] + + expected_apis = [ + Api(path="/path", method="GET", function_name="SamFunc1", cors=None), + Api(path="/path", method="POST", function_name="SamFunc1", cors=None), + Api(path="/path", method="PUT", function_name="SamFunc1", cors=None), + Api(path="/path", method="DELETE", function_name="SamFunc1", cors=None), + Api(path="/path", method="HEAD", function_name="SamFunc1", cors=None), + Api(path="/path", method="OPTIONS", function_name="SamFunc1", cors=None), + Api(path="/path", method="PATCH", function_name="SamFunc1", cors=None) + ] + + template = { + "Resources": { + "Api1": { + "Type": "AWS::ApiGateway::RestApi", + "Properties": { + "Body": make_swagger(apis) + } + } + } + } + + provider = ApiProvider(template) + assertCountEqual(self, expected_apis, provider.apis) + + def test_with_binary_media_types(self): + template = { + "Resources": { + + "Api1": { + "Type": "AWS::ApiGateway::RestApi", + "Properties": { + "Body": make_swagger(self.input_apis, binary_media_types=self.binary_types) + } + } + } + } + + expected_binary_types = sorted(self.binary_types) + expected_apis = [ + Api(path="/path1", method="GET", function_name="SamFunc1", cors=None, + binary_media_types=expected_binary_types), + Api(path="/path1", method="POST", function_name="SamFunc1", cors=None, + binary_media_types=expected_binary_types), + + Api(path="/path2", method="PUT", function_name="SamFunc1", cors=None, + binary_media_types=expected_binary_types), + Api(path="/path2", method="GET", function_name="SamFunc1", cors=None, + binary_media_types=expected_binary_types), + + Api(path="/path3", method="DELETE", function_name="SamFunc1", cors=None, + binary_media_types=expected_binary_types) + ] + + provider = ApiProvider(template) + assertCountEqual(self, expected_apis, provider.apis) + + def test_with_binary_media_types_in_swagger_and_on_resource(self): + input_apis = [ + Api(path="/path", method="OPTIONS", function_name="SamFunc1"), + ] + extra_binary_types = ["text/html"] + + template = { + "Resources": { + + "Api1": { + "Type": "AWS::ApiGateway::RestApi", + "Properties": { + "BinaryMediaTypes": extra_binary_types, + "Body": make_swagger(input_apis, binary_media_types=self.binary_types) + } + } + } + } + + expected_binary_types = sorted(self.binary_types + extra_binary_types) + expected_apis = [ + Api(path="/path", method="OPTIONS", function_name="SamFunc1", binary_media_types=expected_binary_types), + ] + + provider = ApiProvider(template) + assertCountEqual(self, expected_apis, provider.apis) diff --git a/tests/unit/commands/local/lib/test_local_api_service.py b/tests/unit/commands/local/lib/test_local_api_service.py index 3cc5d2c4c3..cfa35af954 100644 --- a/tests/unit/commands/local/lib/test_local_api_service.py +++ b/tests/unit/commands/local/lib/test_local_api_service.py @@ -35,7 +35,7 @@ def setUp(self): self.lambda_invoke_context_mock.stderr = self.stderr_mock @patch("samcli.commands.local.lib.local_api_service.LocalApigwService") - @patch("samcli.commands.local.lib.local_api_service.SamApiProvider") + @patch("samcli.commands.local.lib.local_api_service.ApiProvider") @patch.object(LocalApiService, "_make_static_dir_path") @patch.object(LocalApiService, "_print_routes") @patch.object(LocalApiService, "_make_routing_list") @@ -77,7 +77,7 @@ def test_must_start_service(self, self.apigw_service.run.assert_called_with() @patch("samcli.commands.local.lib.local_api_service.LocalApigwService") - @patch("samcli.commands.local.lib.local_api_service.SamApiProvider") + @patch("samcli.commands.local.lib.local_api_service.ApiProvider") @patch.object(LocalApiService, "_make_static_dir_path") @patch.object(LocalApiService, "_print_routes") @patch.object(LocalApiService, "_make_routing_list") diff --git a/tests/unit/commands/local/lib/test_sam_api_provider.py b/tests/unit/commands/local/lib/test_sam_api_provider.py index 3ac01956be..fa5f342e49 100644 --- a/tests/unit/commands/local/lib/test_sam_api_provider.py +++ b/tests/unit/commands/local/lib/test_sam_api_provider.py @@ -7,29 +7,11 @@ from six import assertCountEqual -from samcli.commands.local.lib.sam_api_provider import SamApiProvider +from samcli.commands.local.lib.api_provider import ApiProvider, SamApiProvider from samcli.commands.local.lib.provider import Api from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException -class TestSamApiProvider_init(TestCase): - - @patch.object(SamApiProvider, "_extract_apis") - @patch("samcli.commands.local.lib.sam_api_provider.SamBaseProvider") - def test_provider_with_valid_template(self, SamBaseProviderMock, extract_api_mock): - extract_api_mock.return_value = {"set", "of", "values"} - - template = {"Resources": {"a": "b"}} - SamBaseProviderMock.get_template.return_value = template - - provider = SamApiProvider(template) - - self.assertEquals(len(provider.apis), 3) - self.assertEquals(provider.apis, set(["set", "of", "values"])) - self.assertEquals(provider.template_dict, {"Resources": {"a": "b"}}) - self.assertEquals(provider.resources, {"a": "b"}) - - class TestSamApiProviderWithImplicitApis(TestCase): def test_provider_with_no_resource_properties(self): @@ -42,9 +24,8 @@ def test_provider_with_no_resource_properties(self): } } - provider = SamApiProvider(template) + provider = ApiProvider(template) - self.assertEquals(len(provider.apis), 0) self.assertEquals(provider.apis, []) @parameterized.expand([("GET"), ("get")]) @@ -72,7 +53,7 @@ def test_provider_has_correct_api(self, method): } } - provider = SamApiProvider(template) + provider = ApiProvider(template) self.assertEquals(len(provider.apis), 1) self.assertEquals(list(provider.apis)[0], Api(path="/path", method="GET", function_name="SamFunc1", cors=None, @@ -109,7 +90,7 @@ def test_provider_creates_api_for_all_events(self): } } - provider = SamApiProvider(template) + provider = ApiProvider(template) api_event1 = Api(path="/path", method="GET", function_name="SamFunc1", cors=None, stage_name="Prod") api_event2 = Api(path="/path", method="POST", function_name="SamFunc1", cors=None, stage_name="Prod") @@ -159,7 +140,7 @@ def test_provider_has_correct_template(self): } } - provider = SamApiProvider(template) + provider = ApiProvider(template) api1 = Api(path="/path", method="GET", function_name="SamFunc1", cors=None, stage_name="Prod") api2 = Api(path="/path", method="POST", function_name="SamFunc2", cors=None, stage_name="Prod") @@ -190,7 +171,7 @@ def test_provider_with_no_api_events(self): } } - provider = SamApiProvider(template) + provider = ApiProvider(template) self.assertEquals(provider.apis, []) @@ -209,7 +190,7 @@ def test_provider_with_no_serverless_function(self): } } - provider = SamApiProvider(template) + provider = ApiProvider(template) self.assertEquals(provider.apis, []) @@ -254,7 +235,7 @@ def test_provider_get_all(self): } } - provider = SamApiProvider(template) + provider = ApiProvider(template) result = [f for f in provider.get_all()] @@ -267,7 +248,7 @@ def test_provider_get_all(self): def test_provider_get_all_with_no_apis(self): template = {} - provider = SamApiProvider(template) + provider = ApiProvider(template) result = [f for f in provider.get_all()] @@ -298,7 +279,7 @@ def test_provider_with_any_method(self, method): } } - provider = SamApiProvider(template) + provider = ApiProvider(template) api_get = Api(path="/path", method="GET", function_name="SamFunc1", cors=None, stage_name="Prod") api_post = Api(path="/path", method="POST", function_name="SamFunc1", cors=None, stage_name="Prod") @@ -351,7 +332,7 @@ def test_provider_must_support_binary_media_types(self): } } - provider = SamApiProvider(template) + provider = ApiProvider(template) self.assertEquals(len(provider.apis), 1) self.assertEquals(list(provider.apis)[0], Api(path="/path", method="GET", function_name="SamFunc1", @@ -403,7 +384,7 @@ def test_provider_must_support_binary_media_types_with_any_method(self): Api(path="/path", method="PATCH", function_name="SamFunc1", binary_media_types=binary, stage_name="Prod") ] - provider = SamApiProvider(template) + provider = ApiProvider(template) assertCountEqual(self, provider.apis, expected_apis) @@ -448,9 +429,8 @@ def test_with_no_apis(self): } } - provider = SamApiProvider(template) + provider = ApiProvider(template) - self.assertEquals(len(provider.apis), 0) self.assertEquals(provider.apis, []) def test_with_inline_swagger_apis(self): @@ -467,7 +447,7 @@ def test_with_inline_swagger_apis(self): } } - provider = SamApiProvider(template) + provider = ApiProvider(template) assertCountEqual(self, self.input_apis, provider.apis) def test_with_swagger_as_local_file(self): @@ -491,11 +471,11 @@ def test_with_swagger_as_local_file(self): } } - provider = SamApiProvider(template) + provider = ApiProvider(template) assertCountEqual(self, self.input_apis, provider.apis) - @patch("samcli.commands.local.lib.sam_api_provider.SamSwaggerReader") - def test_with_swagger_as_both_body_and_uri(self, SamSwaggerReaderMock): + @patch("samcli.commands.local.lib.cfn_base_api_provider.SamSwaggerReader") + def test_with_swagger_as_both_body_and_uri_called(self, SamSwaggerReaderMock): body = {"some": "body"} filename = "somefile.txt" @@ -516,7 +496,7 @@ def test_with_swagger_as_both_body_and_uri(self, SamSwaggerReaderMock): SamSwaggerReaderMock.return_value.read.return_value = make_swagger(self.input_apis) cwd = "foo" - provider = SamApiProvider(template, cwd=cwd) + provider = ApiProvider(template, cwd=cwd) assertCountEqual(self, self.input_apis, provider.apis) SamSwaggerReaderMock.assert_called_with(definition_body=body, definition_uri=filename, working_dir=cwd) @@ -547,7 +527,7 @@ def test_swagger_with_any_method(self): } } - provider = SamApiProvider(template) + provider = ApiProvider(template) assertCountEqual(self, expected_apis, provider.apis) def test_with_binary_media_types(self): @@ -580,7 +560,7 @@ def test_with_binary_media_types(self): binary_media_types=expected_binary_types, stage_name="Prod") ] - provider = SamApiProvider(template) + provider = ApiProvider(template) assertCountEqual(self, expected_apis, provider.apis) def test_with_binary_media_types_in_swagger_and_on_resource(self): @@ -609,7 +589,7 @@ def test_with_binary_media_types_in_swagger_and_on_resource(self): stage_name="Prod"), ] - provider = SamApiProvider(template) + provider = ApiProvider(template) assertCountEqual(self, expected_apis, provider.apis) @@ -686,7 +666,7 @@ def test_must_union_implicit_and_explicit(self): Api(path="/path3", method="POST", function_name="ImplicitFunc", cors=None, stage_name="Prod") ] - provider = SamApiProvider(self.template) + provider = ApiProvider(self.template) assertCountEqual(self, expected_apis, provider.apis) def test_must_prefer_implicit_api_over_explicit(self): @@ -723,7 +703,7 @@ def test_must_prefer_implicit_api_over_explicit(self): Api(path="/path3", method="GET", function_name="explicitfunction", cors=None, stage_name="Prod"), ] - provider = SamApiProvider(self.template) + provider = ApiProvider(self.template) assertCountEqual(self, expected_apis, provider.apis) def test_must_prefer_implicit_with_any_method(self): @@ -757,7 +737,7 @@ def test_must_prefer_implicit_with_any_method(self): Api(path="/path", method="PATCH", function_name="ImplicitFunc", cors=None, stage_name="Prod") ] - provider = SamApiProvider(self.template) + provider = ApiProvider(self.template) assertCountEqual(self, expected_apis, provider.apis) def test_with_any_method_on_both(self): @@ -802,8 +782,7 @@ def test_with_any_method_on_both(self): Api(path="/path2", method="POST", function_name="explicitfunction", cors=None, stage_name="Prod") ] - provider = SamApiProvider(self.template) - print(provider.apis) + provider = ApiProvider(self.template) assertCountEqual(self, expected_apis, provider.apis) def test_must_add_explicit_api_when_ref_with_rest_api_id(self): @@ -840,7 +819,7 @@ def test_must_add_explicit_api_when_ref_with_rest_api_id(self): Api(path="/newpath2", method="POST", function_name="ImplicitFunc", cors=None, stage_name="Prod") ] - provider = SamApiProvider(self.template) + provider = ApiProvider(self.template) assertCountEqual(self, expected_apis, provider.apis) def test_both_apis_must_get_binary_media_types(self): @@ -895,7 +874,7 @@ def test_both_apis_must_get_binary_media_types(self): stage_name="Prod") ] - provider = SamApiProvider(self.template) + provider = ApiProvider(self.template) assertCountEqual(self, expected_apis, provider.apis) def test_binary_media_types_with_rest_api_id_reference(self): @@ -955,7 +934,7 @@ def test_binary_media_types_with_rest_api_id_reference(self): stage_name="Prod") ] - provider = SamApiProvider(self.template) + provider = ApiProvider(self.template) assertCountEqual(self, expected_apis, provider.apis) @@ -991,7 +970,7 @@ def test_provider_parse_stage_name(self): } } } - provider = SamApiProvider(template) + provider = ApiProvider(template) api1 = Api(path='/path', method='GET', function_name='NoApiEventFunction', cors=None, binary_media_types=[], stage_name='dev', stage_variables=None) @@ -1033,7 +1012,7 @@ def test_provider_stage_variables(self): } } } - provider = SamApiProvider(template) + provider = ApiProvider(template) api1 = Api(path='/path', method='GET', function_name='NoApiEventFunction', cors=None, binary_media_types=[], stage_name='dev', stage_variables={ @@ -1119,8 +1098,7 @@ def test_multi_stage_get_all(self): } } } - - provider = SamApiProvider(template) + provider = ApiProvider(template) result = [f for f in provider.get_all()] From 9306b9ca39ec546cee2e46a50a14eb33c112801b Mon Sep 17 00:00:00 2001 From: Sanath Kumar Ramesh Date: Mon, 22 Jul 2019 14:30:29 -0700 Subject: [PATCH 2/7] Revert "feat(start-api): CloudFormation AWS::ApiGateway::RestApi support (#1238)" (#1282) This reverts commit 7cb50b0174f85684e0e0e8ab15e056a309b6504b. --- samcli/commands/local/lib/api_collector.py | 215 -------- samcli/commands/local/lib/api_provider.py | 94 ---- samcli/commands/local/lib/cfn_api_provider.py | 69 --- .../local/lib/cfn_base_api_provider.py | 70 --- .../commands/local/lib/local_api_service.py | 22 +- samcli/commands/local/lib/provider.py | 57 +-- samcli/commands/local/lib/sam_api_provider.py | 462 +++++++++++++++--- .../commands/local/lib/sam_base_provider.py | 3 +- .../local/start_api/test_start_api.py | 106 ---- .../start_api/swagger-rest-api-template.yaml | 69 --- .../commands/local/lib/test_api_provider.py | 207 -------- .../local/lib/test_cfn_api_provider.py | 215 -------- .../local/lib/test_local_api_service.py | 4 +- .../local/lib/test_sam_api_provider.py | 84 ++-- 14 files changed, 459 insertions(+), 1218 deletions(-) delete mode 100644 samcli/commands/local/lib/api_collector.py delete mode 100644 samcli/commands/local/lib/api_provider.py delete mode 100644 samcli/commands/local/lib/cfn_api_provider.py delete mode 100644 samcli/commands/local/lib/cfn_base_api_provider.py delete mode 100644 tests/integration/testdata/start_api/swagger-rest-api-template.yaml delete mode 100644 tests/unit/commands/local/lib/test_api_provider.py delete mode 100644 tests/unit/commands/local/lib/test_cfn_api_provider.py diff --git a/samcli/commands/local/lib/api_collector.py b/samcli/commands/local/lib/api_collector.py deleted file mode 100644 index cbd198c6b7..0000000000 --- a/samcli/commands/local/lib/api_collector.py +++ /dev/null @@ -1,215 +0,0 @@ -""" -Class to store the API configurations in the SAM Template. This class helps store both implicit and explicit -APIs in a standardized format -""" - -import logging -from collections import namedtuple - -from six import string_types - -LOG = logging.getLogger(__name__) - - -class ApiCollector(object): - # Properties of each API. The structure is quite similar to the properties of AWS::Serverless::Api resource. - # This is intentional because it allows us to easily extend this class to support future properties on the API. - # We will store properties of Implicit APIs also in this format which converges the handling of implicit & explicit - # APIs. - Properties = namedtuple("Properties", ["apis", "binary_media_types", "cors", "stage_name", "stage_variables"]) - - def __init__(self): - # API properties stored per resource. Key is the LogicalId of the AWS::Serverless::Api resource and - # value is the properties - self.by_resource = {} - - def __iter__(self): - """ - Iterator to iterate through all the APIs stored in the collector. In each iteration, this yields the - LogicalId of the API resource and a list of APIs available in this resource. - - Yields - ------- - str - LogicalID of the AWS::Serverless::Api resource - list samcli.commands.local.lib.provider.Api - List of the API available in this resource along with additional configuration like binary media types. - """ - - for logical_id, _ in self.by_resource.items(): - yield logical_id, self._get_apis_with_config(logical_id) - - def add_apis(self, logical_id, apis): - """ - Stores the given APIs tagged under the given logicalId - - Parameters - ---------- - logical_id : str - LogicalId of the AWS::Serverless::Api resource - - apis : list of samcli.commands.local.lib.provider.Api - List of APIs available in this resource - """ - properties = self._get_properties(logical_id) - properties.apis.extend(apis) - - def add_binary_media_types(self, logical_id, binary_media_types): - """ - Stores the binary media type configuration for the API with given logical ID - - Parameters - ---------- - logical_id : str - LogicalId of the AWS::Serverless::Api resource - - binary_media_types : list of str - List of binary media types supported by this resource - - """ - properties = self._get_properties(logical_id) - - binary_media_types = binary_media_types or [] - for value in binary_media_types: - normalized_value = self._normalize_binary_media_type(value) - - # If the value is not supported, then just skip it. - if normalized_value: - properties.binary_media_types.add(normalized_value) - else: - LOG.debug("Unsupported data type of binary media type value of resource '%s'", logical_id) - - def add_stage_name(self, logical_id, stage_name): - """ - Stores the stage name for the API with the given local ID - - Parameters - ---------- - logical_id : str - LogicalId of the AWS::Serverless::Api resource - - stage_name : str - The stage_name string - - """ - properties = self._get_properties(logical_id) - properties = properties._replace(stage_name=stage_name) - self._set_properties(logical_id, properties) - - def add_stage_variables(self, logical_id, stage_variables): - """ - Stores the stage variables for the API with the given local ID - - Parameters - ---------- - logical_id : str - LogicalId of the AWS::Serverless::Api resource - - stage_variables : dict - A dictionary containing stage variables. - - """ - properties = self._get_properties(logical_id) - properties = properties._replace(stage_variables=stage_variables) - self._set_properties(logical_id, properties) - - def _get_apis_with_config(self, logical_id): - """ - Returns the list of APIs in this resource along with other extra configuration such as binary media types, - cors etc. Additional configuration is merged directly into the API data because these properties, although - defined globally, actually apply to each API. - - Parameters - ---------- - logical_id : str - Logical ID of the resource to fetch data for - - Returns - ------- - list of samcli.commands.local.lib.provider.Api - List of APIs with additional configurations for the resource with given logicalId. If there are no APIs, - then it returns an empty list - """ - - properties = self._get_properties(logical_id) - - # These configs need to be applied to each API - binary_media = sorted(list(properties.binary_media_types)) # Also sort the list to keep the ordering stable - cors = properties.cors - stage_name = properties.stage_name - stage_variables = properties.stage_variables - - result = [] - for api in properties.apis: - # Create a copy of the API with updated configuration - updated_api = api._replace(binary_media_types=binary_media, - cors=cors, - stage_name=stage_name, - stage_variables=stage_variables) - result.append(updated_api) - - return result - - def _get_properties(self, logical_id): - """ - Returns the properties of resource with given logical ID. If a resource is not found, then it returns an - empty data. - - Parameters - ---------- - logical_id : str - Logical ID of the resource - - Returns - ------- - samcli.commands.local.lib.sam_api_provider.ApiCollector.Properties - Properties object for this resource. - """ - - if logical_id not in self.by_resource: - self.by_resource[logical_id] = self.Properties(apis=[], - # Use a set() to be able to easily de-dupe - binary_media_types=set(), - cors=None, - stage_name=None, - stage_variables=None) - - return self.by_resource[logical_id] - - def _set_properties(self, logical_id, properties): - """ - Sets the properties of resource with given logical ID. If a resource is not found, it does nothing - - Parameters - ---------- - logical_id : str - Logical ID of the resource - properties : samcli.commands.local.lib.sam_api_provider.ApiCollector.Properties - Properties object for this resource. - """ - - if logical_id in self.by_resource: - self.by_resource[logical_id] = properties - - @staticmethod - def _normalize_binary_media_type(value): - """ - Converts binary media types values to the canonical format. Ex: image~1gif -> image/gif. If the value is not - a string, then this method just returns None - - Parameters - ---------- - value : str - Value to be normalized - - Returns - ------- - str or None - Normalized value. If the input was not a string, then None is returned - """ - - if not isinstance(value, string_types): - # It is possible that user specified a dict value for one of the binary media types. We just skip them - return None - - return value.replace("~1", "/") diff --git a/samcli/commands/local/lib/api_provider.py b/samcli/commands/local/lib/api_provider.py deleted file mode 100644 index afc686e166..0000000000 --- a/samcli/commands/local/lib/api_provider.py +++ /dev/null @@ -1,94 +0,0 @@ -"""Class that provides Apis from a SAM Template""" - -import logging - -from samcli.commands.local.lib.cfn_base_api_provider import CfnBaseApiProvider -from samcli.commands.local.lib.api_collector import ApiCollector -from samcli.commands.local.lib.provider import AbstractApiProvider -from samcli.commands.local.lib.sam_base_provider import SamBaseProvider -from samcli.commands.local.lib.sam_api_provider import SamApiProvider -from samcli.commands.local.lib.cfn_api_provider import CfnApiProvider - -LOG = logging.getLogger(__name__) - - -class ApiProvider(AbstractApiProvider): - - def __init__(self, template_dict, parameter_overrides=None, cwd=None): - """ - Initialize the class with SAM template data. The template_dict (SAM Templated) is assumed - to be valid, normalized and a dictionary. template_dict should be normalized by running any and all - pre-processing before passing to this class. - This class does not perform any syntactic validation of the template. - - After the class is initialized, changes to ``template_dict`` will not be reflected in here. - You will need to explicitly update the class with new template, if necessary. - - Parameters - ---------- - template_dict : dict - SAM Template as a dictionary - - cwd : str - Optional working directory with respect to which we will resolve relative path to Swagger file - """ - self.template_dict = SamBaseProvider.get_template(template_dict, parameter_overrides) - self.resources = self.template_dict.get("Resources", {}) - - LOG.debug("%d resources found in the template", len(self.resources)) - - # Store a set of apis - self.cwd = cwd - self.apis = self._extract_apis(self.resources) - - LOG.debug("%d APIs found in the template", len(self.apis)) - - def get_all(self): - """ - Yields all the Lambda functions with Api Events available in the SAM Template. - - :yields Api: namedtuple containing the Api information - """ - - for api in self.apis: - yield api - - def _extract_apis(self, resources): - """ - Extracts all the Apis by running through the one providers. The provider that has the first type matched - will be run across all the resources - - Parameters - ---------- - resources: dict - The dictionary containing the different resources within the template - Returns - --------- - list of Apis extracted from the resources - """ - collector = ApiCollector() - provider = self.find_api_provider(resources) - apis = provider.extract_resource_api(resources, collector, cwd=self.cwd) - return self.normalize_apis(apis) - - @staticmethod - def find_api_provider(resources): - """ - Finds the ApiProvider given the first api type of the resource - - Parameters - ----------- - resources: dict - The dictionary containing the different resources within the template - - Return - ---------- - Instance of the ApiProvider that will be run on the template with a default of SamApiProvider - """ - for _, resource in resources.items(): - if resource.get(CfnBaseApiProvider.RESOURCE_TYPE) in SamApiProvider.TYPES: - return SamApiProvider() - elif resource.get(CfnBaseApiProvider.RESOURCE_TYPE) in CfnApiProvider.TYPES: - return CfnApiProvider() - - return SamApiProvider() diff --git a/samcli/commands/local/lib/cfn_api_provider.py b/samcli/commands/local/lib/cfn_api_provider.py deleted file mode 100644 index 0e3919611c..0000000000 --- a/samcli/commands/local/lib/cfn_api_provider.py +++ /dev/null @@ -1,69 +0,0 @@ -"""Parses SAM given a template""" -import logging - -from samcli.commands.local.lib.cfn_base_api_provider import CfnBaseApiProvider - -LOG = logging.getLogger(__name__) - - -class CfnApiProvider(CfnBaseApiProvider): - APIGATEWAY_RESTAPI = "AWS::ApiGateway::RestApi" - TYPES = [ - APIGATEWAY_RESTAPI - ] - - def extract_resource_api(self, resources, collector, cwd=None): - """ - Extract the Api Object from a given resource and adds it to the ApiCollector. - - Parameters - ---------- - resources: dict - The dictionary containing the different resources within the template - - collector: ApiCollector - Instance of the API collector that where we will save the API information - - cwd : str - Optional working directory with respect to which we will resolve relative path to Swagger file - - Return - ------- - Returns a list of Apis - """ - for logical_id, resource in resources.items(): - resource_type = resource.get(CfnBaseApiProvider.RESOURCE_TYPE) - if resource_type == CfnApiProvider.APIGATEWAY_RESTAPI: - self._extract_cloud_formation_api(logical_id, resource, collector, cwd) - all_apis = [] - for _, apis in collector: - all_apis.extend(apis) - return all_apis - - def _extract_cloud_formation_api(self, logical_id, api_resource, collector, cwd=None): - """ - Extract APIs from AWS::ApiGateway::RestApi resource by reading and parsing Swagger documents. The result is - added to the collector. - - Parameters - ---------- - logical_id : str - Logical ID of the resource - - api_resource : dict - Resource definition, including its properties - - collector : ApiCollector - Instance of the API collector that where we will save the API information - """ - properties = api_resource.get("Properties", {}) - body = properties.get("Body") - body_s3_location = properties.get("BodyS3Location") - binary_media = properties.get("BinaryMediaTypes", []) - - if not body and not body_s3_location: - # Swagger is not found anywhere. - LOG.debug("Skipping resource '%s'. Swagger document not found in Body and BodyS3Location", - logical_id) - return - self.extract_swagger_api(logical_id, body, body_s3_location, binary_media, collector, cwd) diff --git a/samcli/commands/local/lib/cfn_base_api_provider.py b/samcli/commands/local/lib/cfn_base_api_provider.py deleted file mode 100644 index 79bc6d8f1d..0000000000 --- a/samcli/commands/local/lib/cfn_base_api_provider.py +++ /dev/null @@ -1,70 +0,0 @@ -"""Class that parses the CloudFormation Api Template""" - -import logging - -from samcli.commands.local.lib.swagger.parser import SwaggerParser -from samcli.commands.local.lib.swagger.reader import SamSwaggerReader - -LOG = logging.getLogger(__name__) - - -class CfnBaseApiProvider(object): - RESOURCE_TYPE = "Type" - - def extract_resource_api(self, resources, collector, cwd=None): - """ - Extract the Api Object from a given resource and adds it to the ApiCollector. - - Parameters - ---------- - resources: dict - The dictionary containing the different resources within the template - - collector: ApiCollector - Instance of the API collector that where we will save the API information - - cwd : str - Optional working directory with respect to which we will resolve relative path to Swagger file - - Return - ------- - Returns a list of Apis - """ - raise NotImplementedError("not implemented") - - @staticmethod - def extract_swagger_api(logical_id, body, uri, binary_media, collector, cwd=None): - """ - Parse the Swagger documents and adds it to the ApiCollector. - - Parameters - ---------- - logical_id : str - Logical ID of the resource - - body : dict - The body of the RestApi - - uri : str or dict - The url to location of the RestApi - - binary_media: list - The link to the binary media - - collector: ApiCollector - Instance of the API collector that where we will save the API information - - cwd : str - Optional working directory with respect to which we will resolve relative path to Swagger file - """ - reader = SamSwaggerReader(definition_body=body, - definition_uri=uri, - working_dir=cwd) - swagger = reader.read() - parser = SwaggerParser(swagger) - apis = parser.get_apis() - LOG.debug("Found '%s' APIs in resource '%s'", len(apis), logical_id) - - collector.add_apis(logical_id, apis) - collector.add_binary_media_types(logical_id, parser.get_binary_media_types()) # Binary media from swagger - collector.add_binary_media_types(logical_id, binary_media) # Binary media specified on resource in template diff --git a/samcli/commands/local/lib/local_api_service.py b/samcli/commands/local/lib/local_api_service.py index d456e67a83..d0ebbbd975 100644 --- a/samcli/commands/local/lib/local_api_service.py +++ b/samcli/commands/local/lib/local_api_service.py @@ -6,7 +6,7 @@ import logging from samcli.local.apigw.local_apigw_service import LocalApigwService, Route -from samcli.commands.local.lib.api_provider import ApiProvider +from samcli.commands.local.lib.sam_api_provider import SamApiProvider from samcli.commands.local.lib.exceptions import NoApisDefined LOG = logging.getLogger(__name__) @@ -38,9 +38,9 @@ def __init__(self, self.static_dir = static_dir self.cwd = lambda_invoke_context.get_cwd() - self.api_provider = ApiProvider(lambda_invoke_context.template, - parameter_overrides=lambda_invoke_context.parameter_overrides, - cwd=self.cwd) + self.api_provider = SamApiProvider(lambda_invoke_context.template, + parameter_overrides=lambda_invoke_context.parameter_overrides, + cwd=self.cwd) self.lambda_runner = lambda_invoke_context.local_lambda_runner self.stderr_stream = lambda_invoke_context.stderr @@ -89,7 +89,7 @@ def _make_routing_list(api_provider): Parameters ---------- - api_provider : samcli.commands.local.lib.api_provider.ApiProvider + api_provider : samcli.commands.local.lib.sam_api_provider.SamApiProvider Returns ------- @@ -116,14 +116,10 @@ def _print_routes(api_provider, host, port): Mounting Product at http://127.0.0.1:3000/path1/bar [GET, POST, DELETE] Mounting Product at http://127.0.0.1:3000/path2/bar [HEAD] - :param samcli.commands.local.lib.provider.AbstractApiProvider api_provider: - API Provider that can return a list of APIs - :param string host: - Host name where the service is running - :param int port: - Port number where the service is running - :returns list(string): - List of lines that were printed to the console. Helps with testing + :param samcli.commands.local.lib.provider.ApiProvider api_provider: API Provider that can return a list of APIs + :param string host: Host name where the service is running + :param int port: Port number where the service is running + :returns list(string): List of lines that were printed to the console. Helps with testing """ grouped_api_configs = {} diff --git a/samcli/commands/local/lib/provider.py b/samcli/commands/local/lib/provider.py index 959166e814..eead981089 100644 --- a/samcli/commands/local/lib/provider.py +++ b/samcli/commands/local/lib/provider.py @@ -222,10 +222,10 @@ def get_all(self): # The variables for that stage "stage_variables" ]) -_ApiTuple.__new__.__defaults__ = (None, # Cors is optional and defaults to None - [], # binary_media_types is optional and defaults to empty, - None, # Stage name is optional with default None - None # Stage variables is optional with default None +_ApiTuple.__new__.__defaults__ = (None, # Cors is optional and defaults to None + [], # binary_media_types is optional and defaults to empty, + None, # Stage name is optional with default None + None # Stage variables is optional with default None ) @@ -238,17 +238,10 @@ def __hash__(self): Cors = namedtuple("Cors", ["AllowOrigin", "AllowMethods", "AllowHeaders"]) -class AbstractApiProvider(object): +class ApiProvider(object): """ Abstract base class to return APIs and the functions they route to """ - _ANY_HTTP_METHODS = ["GET", - "DELETE", - "PUT", - "POST", - "HEAD", - "OPTIONS", - "PATCH"] def get_all(self): """ @@ -257,43 +250,3 @@ def get_all(self): :yields Api: namedtuple containing the API information """ raise NotImplementedError("not implemented") - - @staticmethod - def normalize_http_methods(http_method): - """ - Normalizes Http Methods. Api Gateway allows a Http Methods of ANY. This is a special verb to denote all - supported Http Methods on Api Gateway. - - :param str http_method: Http method - :yield str: Either the input http_method or one of the _ANY_HTTP_METHODS (normalized Http Methods) - """ - - if http_method.upper() == 'ANY': - for method in AbstractApiProvider._ANY_HTTP_METHODS: - yield method.upper() - else: - yield http_method.upper() - - @staticmethod - def normalize_apis(apis): - """ - Normalize the APIs to use standard method name - - Parameters - ---------- - apis : list of samcli.commands.local.lib.provider.Api - List of APIs to replace normalize - - Returns - ------- - list of samcli.commands.local.lib.provider.Api - List of normalized APIs - """ - - result = list() - for api in apis: - for normalized_method in AbstractApiProvider.normalize_http_methods(api.method): - # _replace returns a copy of the namedtuple. This is the official way of creating copies of namedtuple - result.append(api._replace(method=normalized_method)) - - return result diff --git a/samcli/commands/local/lib/sam_api_provider.py b/samcli/commands/local/lib/sam_api_provider.py index f0ec57b823..84336a8d2d 100644 --- a/samcli/commands/local/lib/sam_api_provider.py +++ b/samcli/commands/local/lib/sam_api_provider.py @@ -1,60 +1,111 @@ -"""Parses SAM given the template""" +"""Class that provides Apis from a SAM Template""" import logging +from collections import namedtuple -from samcli.commands.local.lib.provider import Api, AbstractApiProvider +from six import string_types + +from samcli.commands.local.lib.swagger.parser import SwaggerParser +from samcli.commands.local.lib.provider import ApiProvider, Api +from samcli.commands.local.lib.sam_base_provider import SamBaseProvider +from samcli.commands.local.lib.swagger.reader import SamSwaggerReader from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException -from samcli.commands.local.lib.cfn_base_api_provider import CfnBaseApiProvider LOG = logging.getLogger(__name__) -class SamApiProvider(CfnBaseApiProvider): - SERVERLESS_FUNCTION = "AWS::Serverless::Function" - SERVERLESS_API = "AWS::Serverless::Api" - TYPES = [ - SERVERLESS_FUNCTION, - SERVERLESS_API - ] +class SamApiProvider(ApiProvider): + _IMPLICIT_API_RESOURCE_ID = "ServerlessRestApi" + _SERVERLESS_FUNCTION = "AWS::Serverless::Function" + _SERVERLESS_API = "AWS::Serverless::Api" + _TYPE = "Type" + _FUNCTION_EVENT_TYPE_API = "Api" _FUNCTION_EVENT = "Events" _EVENT_PATH = "Path" _EVENT_METHOD = "Method" - _EVENT_TYPE = "Type" - IMPLICIT_API_RESOURCE_ID = "ServerlessRestApi" - def extract_resource_api(self, resources, collector, cwd=None): + _ANY_HTTP_METHODS = ["GET", + "DELETE", + "PUT", + "POST", + "HEAD", + "OPTIONS", + "PATCH"] + + def __init__(self, template_dict, parameter_overrides=None, cwd=None): """ - Extract the Api Object from a given resource and adds it to the ApiCollector. + Initialize the class with SAM template data. The template_dict (SAM Templated) is assumed + to be valid, normalized and a dictionary. template_dict should be normalized by running any and all + pre-processing before passing to this class. + This class does not perform any syntactic validation of the template. + + After the class is initialized, changes to ``template_dict`` will not be reflected in here. + You will need to explicitly update the class with new template, if necessary. Parameters ---------- - resources: dict - The dictionary containing the different resources within the template - - collector: ApiCollector - Instance of the API collector that where we will save the API information - + template_dict : dict + SAM Template as a dictionary cwd : str Optional working directory with respect to which we will resolve relative path to Swagger file + """ - Return - ------- - Returns a list of Apis + self.template_dict = SamBaseProvider.get_template(template_dict, parameter_overrides) + self.resources = self.template_dict.get("Resources", {}) + + LOG.debug("%d resources found in the template", len(self.resources)) + + # Store a set of apis + self.cwd = cwd + self.apis = self._extract_apis(self.resources) + + LOG.debug("%d APIs found in the template", len(self.apis)) + + def get_all(self): """ - # AWS::Serverless::Function is currently included when parsing of Apis because when SamBaseProvider is run on - # the template we are creating the implicit apis due to plugins that translate it in the SAM repo, - # which we later merge with the explicit ones in SamApiProvider.merge_apis. This requires the code to be - # parsed here and in InvokeContext. + Yields all the Lambda functions with Api Events available in the SAM Template. + + :yields Api: namedtuple containing the Api information + """ + + for api in self.apis: + yield api + + def _extract_apis(self, resources): + """ + Extract all Implicit Apis (Apis defined through Serverless Function with an Api Event + + :param dict resources: Dictionary of SAM/CloudFormation resources + :return: List of nametuple Api + """ + + # Some properties like BinaryMediaTypes, Cors are set once on the resource but need to be applied to each API. + # For Implicit APIs, which are defined on the Function resource, these properties + # are defined on a AWS::Serverless::Api resource with logical ID "ServerlessRestApi". Therefore, no matter + # if it is an implicit API or an explicit API, there is a corresponding resource of type AWS::Serverless::Api + # that contains these additional configurations. + # + # We use this assumption in the following loop to collect information from resources of type + # AWS::Serverless::Api. We also extract API from Serverless::Function resource and add them to the + # corresponding Serverless::Api resource. This is all done using the ``collector``. + + collector = ApiCollector() + for logical_id, resource in resources.items(): - resource_type = resource.get(CfnBaseApiProvider.RESOURCE_TYPE) - if resource_type == SamApiProvider.SERVERLESS_FUNCTION: + + resource_type = resource.get(SamApiProvider._TYPE) + + if resource_type == SamApiProvider._SERVERLESS_FUNCTION: self._extract_apis_from_function(logical_id, resource, collector) - if resource_type == SamApiProvider.SERVERLESS_API: - self._extract_from_serverless_api(logical_id, resource, collector, cwd) - return self.merge_apis(collector) - def _extract_from_serverless_api(self, logical_id, api_resource, collector, cwd=None): + if resource_type == SamApiProvider._SERVERLESS_API: + self._extract_from_serverless_api(logical_id, resource, collector) + + apis = SamApiProvider._merge_apis(collector) + return self._normalize_apis(apis) + + def _extract_from_serverless_api(self, logical_id, api_resource, collector): """ Extract APIs from AWS::Serverless::Api resource by reading and parsing Swagger documents. The result is added to the collector. @@ -83,11 +134,99 @@ def _extract_from_serverless_api(self, logical_id, api_resource, collector, cwd= LOG.debug("Skipping resource '%s'. Swagger document not found in DefinitionBody and DefinitionUri", logical_id) return - self.extract_swagger_api(logical_id, body, uri, binary_media, collector, cwd) + + reader = SamSwaggerReader(definition_body=body, + definition_uri=uri, + working_dir=self.cwd) + swagger = reader.read() + parser = SwaggerParser(swagger) + apis = parser.get_apis() + LOG.debug("Found '%s' APIs in resource '%s'", len(apis), logical_id) + + collector.add_apis(logical_id, apis) + collector.add_binary_media_types(logical_id, parser.get_binary_media_types()) # Binary media from swagger + collector.add_binary_media_types(logical_id, binary_media) # Binary media specified on resource in template + collector.add_stage_name(logical_id, stage_name) collector.add_stage_variables(logical_id, stage_variables) - def _extract_apis_from_function(self, logical_id, function_resource, collector): + @staticmethod + def _merge_apis(collector): + """ + Quite often, an API is defined both in Implicit and Explicit API definitions. In such cases, Implicit API + definition wins because that conveys clear intent that the API is backed by a function. This method will + merge two such list of Apis with the right order of precedence. If a Path+Method combination is defined + in both the places, only one wins. + + Parameters + ---------- + collector : ApiCollector + Collector object that holds all the APIs specified in the template + + Returns + ------- + list of samcli.commands.local.lib.provider.Api + List of APIs obtained by combining both the input lists. + """ + + implicit_apis = [] + explicit_apis = [] + + # Store implicit and explicit APIs separately in order to merge them later in the correct order + # Implicit APIs are defined on a resource with logicalID ServerlessRestApi + for logical_id, apis in collector: + if logical_id == SamApiProvider._IMPLICIT_API_RESOURCE_ID: + implicit_apis.extend(apis) + else: + explicit_apis.extend(apis) + + # We will use "path+method" combination as key to this dictionary and store the Api config for this combination. + # If an path+method combo already exists, then overwrite it if and only if this is an implicit API + all_apis = {} + + # By adding implicit APIs to the end of the list, they will be iterated last. If a configuration was already + # written by explicit API, it will be overriden by implicit API, just by virtue of order of iteration. + all_configs = explicit_apis + implicit_apis + + for config in all_configs: + # Normalize the methods before de-duping to allow an ANY method in implicit API to override a regular HTTP + # method on explicit API. + for normalized_method in SamApiProvider._normalize_http_methods(config.method): + key = config.path + normalized_method + all_apis[key] = config + + result = set(all_apis.values()) # Assign to a set() to de-dupe + LOG.debug("Removed duplicates from '%d' Explicit APIs and '%d' Implicit APIs to produce '%d' APIs", + len(explicit_apis), len(implicit_apis), len(result)) + + return list(result) + + @staticmethod + def _normalize_apis(apis): + """ + Normalize the APIs to use standard method name + + Parameters + ---------- + apis : list of samcli.commands.local.lib.provider.Api + List of APIs to replace normalize + + Returns + ------- + list of samcli.commands.local.lib.provider.Api + List of normalized APIs + """ + + result = list() + for api in apis: + for normalized_method in SamApiProvider._normalize_http_methods(api.method): + # _replace returns a copy of the namedtuple. This is the official way of creating copies of namedtuple + result.append(api._replace(method=normalized_method)) + + return result + + @staticmethod + def _extract_apis_from_function(logical_id, function_resource, collector): """ Fetches a list of APIs configured for this SAM Function resource. @@ -104,10 +243,11 @@ def _extract_apis_from_function(self, logical_id, function_resource, collector): """ resource_properties = function_resource.get("Properties", {}) - serverless_function_events = resource_properties.get(self._FUNCTION_EVENT, {}) - self.extract_apis_from_events(logical_id, serverless_function_events, collector) + serverless_function_events = resource_properties.get(SamApiProvider._FUNCTION_EVENT, {}) + SamApiProvider._extract_apis_from_events(logical_id, serverless_function_events, collector) - def extract_apis_from_events(self, function_logical_id, serverless_function_events, collector): + @staticmethod + def _extract_apis_from_events(function_logical_id, serverless_function_events, collector): """ Given an AWS::Serverless::Function Event Dictionary, extract out all 'Api' events and store within the collector @@ -126,8 +266,8 @@ def extract_apis_from_events(self, function_logical_id, serverless_function_even count = 0 for _, event in serverless_function_events.items(): - if self._FUNCTION_EVENT_TYPE_API == event.get(self._EVENT_TYPE): - api_resource_id, api = self._convert_event_api(function_logical_id, event.get("Properties")) + if SamApiProvider._FUNCTION_EVENT_TYPE_API == event.get(SamApiProvider._TYPE): + api_resource_id, api = SamApiProvider._convert_event_api(function_logical_id, event.get("Properties")) collector.add_apis(api_resource_id, [api]) count += 1 @@ -148,7 +288,7 @@ def _convert_event_api(lambda_logical_id, event_properties): # An API Event, can have RestApiId property which designates the resource that owns this API. If omitted, # the API is owned by Implicit API resource. This could either be a direct resource logical ID or a # "Ref" of the logicalID - api_resource_id = event_properties.get("RestApiId", SamApiProvider.IMPLICIT_API_RESOURCE_ID) + api_resource_id = event_properties.get("RestApiId", SamApiProvider._IMPLICIT_API_RESOURCE_ID) if isinstance(api_resource_id, dict) and "Ref" in api_resource_id: api_resource_id = api_resource_id["Ref"] @@ -162,52 +302,226 @@ def _convert_event_api(lambda_logical_id, event_properties): return api_resource_id, Api(path=path, method=method, function_name=lambda_logical_id) @staticmethod - def merge_apis(collector): + def _normalize_http_methods(http_method): """ - Quite often, an API is defined both in Implicit and Explicit API definitions. In such cases, Implicit API - definition wins because that conveys clear intent that the API is backed by a function. This method will - merge two such list of Apis with the right order of precedence. If a Path+Method combination is defined - in both the places, only one wins. + Normalizes Http Methods. Api Gateway allows a Http Methods of ANY. This is a special verb to denote all + supported Http Methods on Api Gateway. + + :param str http_method: Http method + :yield str: Either the input http_method or one of the _ANY_HTTP_METHODS (normalized Http Methods) + """ + + if http_method.upper() == 'ANY': + for method in SamApiProvider._ANY_HTTP_METHODS: + yield method.upper() + else: + yield http_method.upper() + + +class ApiCollector(object): + """ + Class to store the API configurations in the SAM Template. This class helps store both implicit and explicit + APIs in a standardized format + """ + + # Properties of each API. The structure is quite similar to the properties of AWS::Serverless::Api resource. + # This is intentional because it allows us to easily extend this class to support future properties on the API. + # We will store properties of Implicit APIs also in this format which converges the handling of implicit & explicit + # APIs. + Properties = namedtuple("Properties", ["apis", "binary_media_types", "cors", "stage_name", "stage_variables"]) + + def __init__(self): + # API properties stored per resource. Key is the LogicalId of the AWS::Serverless::Api resource and + # value is the properties + self.by_resource = {} + + def __iter__(self): + """ + Iterator to iterate through all the APIs stored in the collector. In each iteration, this yields the + LogicalId of the API resource and a list of APIs available in this resource. + + Yields + ------- + str + LogicalID of the AWS::Serverless::Api resource + list samcli.commands.local.lib.provider.Api + List of the API available in this resource along with additional configuration like binary media types. + """ + + for logical_id, _ in self.by_resource.items(): + yield logical_id, self._get_apis_with_config(logical_id) + + def add_apis(self, logical_id, apis): + """ + Stores the given APIs tagged under the given logicalId Parameters ---------- - collector : ApiCollector - Collector object that holds all the APIs specified in the template + logical_id : str + LogicalId of the AWS::Serverless::Api resource + + apis : list of samcli.commands.local.lib.provider.Api + List of APIs available in this resource + """ + properties = self._get_properties(logical_id) + properties.apis.extend(apis) + + def add_binary_media_types(self, logical_id, binary_media_types): + """ + Stores the binary media type configuration for the API with given logical ID + + Parameters + ---------- + logical_id : str + LogicalId of the AWS::Serverless::Api resource + + binary_media_types : list of str + List of binary media types supported by this resource + + """ + properties = self._get_properties(logical_id) + + binary_media_types = binary_media_types or [] + for value in binary_media_types: + normalized_value = self._normalize_binary_media_type(value) + + # If the value is not supported, then just skip it. + if normalized_value: + properties.binary_media_types.add(normalized_value) + else: + LOG.debug("Unsupported data type of binary media type value of resource '%s'", logical_id) + + def add_stage_name(self, logical_id, stage_name): + """ + Stores the stage name for the API with the given local ID + + Parameters + ---------- + logical_id : str + LogicalId of the AWS::Serverless::Api resource + + stage_name : str + The stage_name string + + """ + properties = self._get_properties(logical_id) + properties = properties._replace(stage_name=stage_name) + self._set_properties(logical_id, properties) + + def add_stage_variables(self, logical_id, stage_variables): + """ + Stores the stage variables for the API with the given local ID + + Parameters + ---------- + logical_id : str + LogicalId of the AWS::Serverless::Api resource + + stage_variables : dict + A dictionary containing stage variables. + + """ + properties = self._get_properties(logical_id) + properties = properties._replace(stage_variables=stage_variables) + self._set_properties(logical_id, properties) + + def _get_apis_with_config(self, logical_id): + """ + Returns the list of APIs in this resource along with other extra configuration such as binary media types, + cors etc. Additional configuration is merged directly into the API data because these properties, although + defined globally, actually apply to each API. + + Parameters + ---------- + logical_id : str + Logical ID of the resource to fetch data for Returns ------- list of samcli.commands.local.lib.provider.Api - List of APIs obtained by combining both the input lists. + List of APIs with additional configurations for the resource with given logicalId. If there are no APIs, + then it returns an empty list """ - implicit_apis = [] - explicit_apis = [] + properties = self._get_properties(logical_id) - # Store implicit and explicit APIs separately in order to merge them later in the correct order - # Implicit APIs are defined on a resource with logicalID ServerlessRestApi - for logical_id, apis in collector: - if logical_id == SamApiProvider.IMPLICIT_API_RESOURCE_ID: - implicit_apis.extend(apis) - else: - explicit_apis.extend(apis) + # These configs need to be applied to each API + binary_media = sorted(list(properties.binary_media_types)) # Also sort the list to keep the ordering stable + cors = properties.cors + stage_name = properties.stage_name + stage_variables = properties.stage_variables - # We will use "path+method" combination as key to this dictionary and store the Api config for this combination. - # If an path+method combo already exists, then overwrite it if and only if this is an implicit API - all_apis = {} + result = [] + for api in properties.apis: + # Create a copy of the API with updated configuration + updated_api = api._replace(binary_media_types=binary_media, + cors=cors, + stage_name=stage_name, + stage_variables=stage_variables) + result.append(updated_api) - # By adding implicit APIs to the end of the list, they will be iterated last. If a configuration was already - # written by explicit API, it will be overriden by implicit API, just by virtue of order of iteration. - all_configs = explicit_apis + implicit_apis + return result - for config in all_configs: - # Normalize the methods before de-duping to allow an ANY method in implicit API to override a regular HTTP - # method on explicit API. - for normalized_method in AbstractApiProvider.normalize_http_methods(config.method): - key = config.path + normalized_method - all_apis[key] = config + def _get_properties(self, logical_id): + """ + Returns the properties of resource with given logical ID. If a resource is not found, then it returns an + empty data. - result = set(all_apis.values()) # Assign to a set() to de-dupe - LOG.debug("Removed duplicates from '%d' Explicit APIs and '%d' Implicit APIs to produce '%d' APIs", - len(explicit_apis), len(implicit_apis), len(result)) + Parameters + ---------- + logical_id : str + Logical ID of the resource - return list(result) + Returns + ------- + samcli.commands.local.lib.sam_api_provider.ApiCollector.Properties + Properties object for this resource. + """ + + if logical_id not in self.by_resource: + self.by_resource[logical_id] = self.Properties(apis=[], + # Use a set() to be able to easily de-dupe + binary_media_types=set(), + cors=None, + stage_name=None, + stage_variables=None) + + return self.by_resource[logical_id] + + def _set_properties(self, logical_id, properties): + """ + Sets the properties of resource with given logical ID. If a resource is not found, it does nothing + + Parameters + ---------- + logical_id : str + Logical ID of the resource + properties : samcli.commands.local.lib.sam_api_provider.ApiCollector.Properties + Properties object for this resource. + """ + + if logical_id in self.by_resource: + self.by_resource[logical_id] = properties + + @staticmethod + def _normalize_binary_media_type(value): + """ + Converts binary media types values to the canonical format. Ex: image~1gif -> image/gif. If the value is not + a string, then this method just returns None + + Parameters + ---------- + value : str + Value to be normalized + + Returns + ------- + str or None + Normalized value. If the input was not a string, then None is returned + """ + + if not isinstance(value, string_types): + # It is possible that user specified a dict value for one of the binary media types. We just skip them + return None + + return value.replace("~1", "/") diff --git a/samcli/commands/local/lib/sam_base_provider.py b/samcli/commands/local/lib/sam_base_provider.py index 861e1fd47a..bbf4d6381b 100644 --- a/samcli/commands/local/lib/sam_base_provider.py +++ b/samcli/commands/local/lib/sam_base_provider.py @@ -10,6 +10,7 @@ from samcli.lib.samlib.wrapper import SamTranslatorWrapper from samcli.lib.samlib.resource_metadata_normalizer import ResourceMetadataNormalizer + LOG = logging.getLogger(__name__) @@ -88,7 +89,7 @@ def _resolve_parameters(template_dict, parameter_overrides): supported_intrinsics = {action.intrinsic_name: action() for action in SamBaseProvider._SUPPORTED_INTRINSICS} # Intrinsics resolver will mutate the original template - return IntrinsicsResolver(parameters=parameter_values, supported_intrinsics=supported_intrinsics) \ + return IntrinsicsResolver(parameters=parameter_values, supported_intrinsics=supported_intrinsics)\ .resolve_parameter_refs(template_dict) @staticmethod diff --git a/tests/integration/local/start_api/test_start_api.py b/tests/integration/local/start_api/test_start_api.py index 321741e0bf..700491260d 100644 --- a/tests/integration/local/start_api/test_start_api.py +++ b/tests/integration/local/start_api/test_start_api.py @@ -288,112 +288,6 @@ def test_binary_response(self): self.assertEquals(response.content, expected) -class TestStartApiWithSwaggerRestApis(StartApiIntegBaseClass): - template_path = "/testdata/start_api/swagger-rest-api-template.yaml" - binary_data_file = "testdata/start_api/binarydata.gif" - - def setUp(self): - self.url = "http://127.0.0.1:{}".format(self.port) - - def test_get_call_with_path_setup_with_any_swagger(self): - """ - Get Request to a path that was defined as ANY in SAM through Swagger - """ - response = requests.get(self.url + "/anyandall") - - self.assertEquals(response.status_code, 200) - self.assertEquals(response.json(), {'hello': 'world'}) - - def test_post_call_with_path_setup_with_any_swagger(self): - """ - Post Request to a path that was defined as ANY in SAM through Swagger - """ - response = requests.post(self.url + "/anyandall", json={}) - - self.assertEquals(response.status_code, 200) - self.assertEquals(response.json(), {'hello': 'world'}) - - def test_put_call_with_path_setup_with_any_swagger(self): - """ - Put Request to a path that was defined as ANY in SAM through Swagger - """ - response = requests.put(self.url + "/anyandall", json={}) - - self.assertEquals(response.status_code, 200) - self.assertEquals(response.json(), {'hello': 'world'}) - - def test_head_call_with_path_setup_with_any_swagger(self): - """ - Head Request to a path that was defined as ANY in SAM through Swagger - """ - response = requests.head(self.url + "/anyandall") - - self.assertEquals(response.status_code, 200) - - def test_delete_call_with_path_setup_with_any_swagger(self): - """ - Delete Request to a path that was defined as ANY in SAM through Swagger - """ - response = requests.delete(self.url + "/anyandall") - - self.assertEquals(response.status_code, 200) - self.assertEquals(response.json(), {'hello': 'world'}) - - def test_options_call_with_path_setup_with_any_swagger(self): - """ - Options Request to a path that was defined as ANY in SAM through Swagger - """ - response = requests.options(self.url + "/anyandall") - - self.assertEquals(response.status_code, 200) - - def test_patch_call_with_path_setup_with_any_swagger(self): - """ - Patch Request to a path that was defined as ANY in SAM through Swagger - """ - response = requests.patch(self.url + "/anyandall") - - self.assertEquals(response.status_code, 200) - self.assertEquals(response.json(), {'hello': 'world'}) - - def test_function_not_defined_in_template(self): - response = requests.get(self.url + "/nofunctionfound") - - self.assertEquals(response.status_code, 502) - self.assertEquals(response.json(), {"message": "No function defined for resource method"}) - - def test_lambda_function_resource_is_reachable(self): - response = requests.get(self.url + "/nonserverlessfunction") - - self.assertEquals(response.status_code, 200) - self.assertEquals(response.json(), {'hello': 'world'}) - - def test_binary_request(self): - """ - This tests that the service can accept and invoke a lambda when given binary data in a request - """ - input_data = self.get_binary_data(self.binary_data_file) - response = requests.post(self.url + '/echobase64eventbody', - headers={"Content-Type": "image/gif"}, - data=input_data) - - self.assertEquals(response.status_code, 200) - self.assertEquals(response.headers.get("Content-Type"), "image/gif") - self.assertEquals(response.content, input_data) - - def test_binary_response(self): - """ - Binary data is returned correctly - """ - expected = self.get_binary_data(self.binary_data_file) - - response = requests.get(self.url + '/base64response') - - self.assertEquals(response.status_code, 200) - self.assertEquals(response.headers.get("Content-Type"), "image/gif") - self.assertEquals(response.content, expected) - - class TestServiceResponses(StartApiIntegBaseClass): """ Test Class centered around the different responses that can happen in Lambda and pass through start-api diff --git a/tests/integration/testdata/start_api/swagger-rest-api-template.yaml b/tests/integration/testdata/start_api/swagger-rest-api-template.yaml deleted file mode 100644 index 5edeb8717f..0000000000 --- a/tests/integration/testdata/start_api/swagger-rest-api-template.yaml +++ /dev/null @@ -1,69 +0,0 @@ -AWSTemplateFormatVersion: '2010-09-09' - -Resources: - Base64ResponseFunction: - Properties: - Code: "." - Handler: main.base64_response - Runtime: python3.6 - Type: AWS::Lambda::Function - EchoBase64EventBodyFunction: - Properties: - Code: "." - Handler: main.echo_base64_event_body - Runtime: python3.6 - Type: AWS::Lambda::Function - MyApi: - Properties: - Body: - info: - title: - Ref: AWS::StackName - paths: - "/anyandall": - x-amazon-apigateway-any-method: - x-amazon-apigateway-integration: - httpMethod: POST - responses: {} - type: aws_proxy - uri: - Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${MyNonServerlessLambdaFunction.Arn}/invocations - "/base64response": - get: - x-amazon-apigateway-integration: - httpMethod: POST - type: aws_proxy - uri: - Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${Base64ResponseFunction.Arn}/invocations - "/echobase64eventbody": - post: - x-amazon-apigateway-integration: - httpMethod: POST - type: aws_proxy - uri: - Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${EchoBase64EventBodyFunction.Arn}/invocations - "/nofunctionfound": - get: - x-amazon-apigateway-integration: - httpMethod: POST - type: aws_proxy - uri: - Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${WhatFunction.Arn}/invocations - "/nonserverlessfunction": - get: - x-amazon-apigateway-integration: - httpMethod: POST - type: aws_proxy - uri: - Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${MyNonServerlessLambdaFunction.Arn}/invocations - swagger: '2.0' - x-amazon-apigateway-binary-media-types: - - image/gif - StageName: prod - Type: AWS::ApiGateway::RestApi - MyNonServerlessLambdaFunction: - Properties: - Code: "." - Handler: main.handler - Runtime: python3.6 - Type: AWS::Lambda::Function diff --git a/tests/unit/commands/local/lib/test_api_provider.py b/tests/unit/commands/local/lib/test_api_provider.py deleted file mode 100644 index 50b8d073d4..0000000000 --- a/tests/unit/commands/local/lib/test_api_provider.py +++ /dev/null @@ -1,207 +0,0 @@ -from collections import OrderedDict -from unittest import TestCase - -from mock import patch - -from samcli.commands.local.lib.api_provider import ApiProvider -from samcli.commands.local.lib.sam_api_provider import SamApiProvider -from samcli.commands.local.lib.cfn_api_provider import CfnApiProvider - - -class TestApiProvider_init(TestCase): - - @patch.object(ApiProvider, "_extract_apis") - @patch("samcli.commands.local.lib.api_provider.SamBaseProvider") - def test_provider_with_valid_template(self, SamBaseProviderMock, extract_api_mock): - extract_api_mock.return_value = {"set", "of", "values"} - - template = {"Resources": {"a": "b"}} - SamBaseProviderMock.get_template.return_value = template - - provider = ApiProvider(template) - - self.assertEquals(len(provider.apis), 3) - self.assertEquals(provider.apis, set(["set", "of", "values"])) - self.assertEquals(provider.template_dict, {"Resources": {"a": "b"}}) - self.assertEquals(provider.resources, {"a": "b"}) - - -class TestApiProviderSelection(TestCase): - def test_default_provider(self): - resources = { - "TestApi": { - "Type": "AWS::UNKNOWN_TYPE", - "Properties": { - "StageName": "dev", - "DefinitionBody": { - "paths": { - "/path": { - "get": { - "x-amazon-apigateway-integration": { - "httpMethod": "POST", - "type": "aws_proxy", - "uri": { - "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", - }, - "responses": {}, - }, - } - } - - } - } - } - } - } - - provider = ApiProvider.find_api_provider(resources) - self.assertTrue(isinstance(provider, SamApiProvider)) - - def test_api_provider_sam_api(self): - resources = { - "TestApi": { - "Type": "AWS::Serverless::Api", - "Properties": { - "StageName": "dev", - "DefinitionBody": { - "paths": { - "/path": { - "get": { - "x-amazon-apigateway-integration": { - "httpMethod": "POST", - "type": "aws_proxy", - "uri": { - "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", - }, - "responses": {}, - }, - } - } - - } - } - } - } - } - - provider = ApiProvider.find_api_provider(resources) - self.assertTrue(isinstance(provider, SamApiProvider)) - - def test_api_provider_sam_function(self): - resources = { - "TestApi": { - "Type": "AWS::Serverless::Function", - "Properties": { - "StageName": "dev", - "DefinitionBody": { - "paths": { - "/path": { - "get": { - "x-amazon-apigateway-integration": { - "httpMethod": "POST", - "type": "aws_proxy", - "uri": { - "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", - }, - "responses": {}, - }, - } - } - - } - } - } - } - } - - provider = ApiProvider.find_api_provider(resources) - - self.assertTrue(isinstance(provider, SamApiProvider)) - - def test_api_provider_cloud_formation(self): - resources = { - "TestApi": { - "Type": "AWS::ApiGateway::RestApi", - "Properties": { - "StageName": "dev", - "Body": { - "paths": { - "/path": { - "get": { - "x-amazon-apigateway-integration": { - "httpMethod": "POST", - "type": "aws_proxy", - "uri": { - "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", - }, - "responses": {}, - }, - } - } - - } - } - } - } - } - - provider = ApiProvider.find_api_provider(resources) - self.assertTrue(isinstance(provider, CfnApiProvider)) - - def test_multiple_api_provider_cloud_formation(self): - resources = OrderedDict() - resources["TestApi"] = { - "Type": "AWS::ApiGateway::RestApi", - "Properties": { - "StageName": "dev", - "Body": { - "paths": { - "/path": { - "get": { - "x-amazon-apigateway-integration": { - "httpMethod": "POST", - "type": "aws_proxy", - "uri": { - "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", - }, - "responses": {}, - }, - } - } - - } - } - } - } - resources["OtherApi"] = { - "Type": "AWS::Serverless::Api", - "Properties": { - "StageName": "dev", - "DefinitionBody": { - "paths": { - "/path": { - "get": { - "x-amazon-apigateway-integration": { - "httpMethod": "POST", - "type": "aws_proxy", - "uri": { - "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31" - "/functions/${NoApiEventFunction.Arn}/invocations", - }, - "responses": {}, - }, - } - } - - } - } - } - } - - provider = ApiProvider.find_api_provider(resources) - self.assertTrue(isinstance(provider, CfnApiProvider)) diff --git a/tests/unit/commands/local/lib/test_cfn_api_provider.py b/tests/unit/commands/local/lib/test_cfn_api_provider.py deleted file mode 100644 index 723951eb11..0000000000 --- a/tests/unit/commands/local/lib/test_cfn_api_provider.py +++ /dev/null @@ -1,215 +0,0 @@ -import json -import tempfile -from unittest import TestCase - -from mock import patch -from six import assertCountEqual - -from samcli.commands.local.lib.api_provider import ApiProvider -from samcli.commands.local.lib.provider import Api -from tests.unit.commands.local.lib.test_sam_api_provider import make_swagger - - -class TestApiProviderWithApiGatewayRestApi(TestCase): - - def setUp(self): - self.binary_types = ["image/png", "image/jpg"] - self.input_apis = [ - Api(path="/path1", method="GET", function_name="SamFunc1", cors=None), - Api(path="/path1", method="POST", function_name="SamFunc1", cors=None), - - Api(path="/path2", method="PUT", function_name="SamFunc1", cors=None), - Api(path="/path2", method="GET", function_name="SamFunc1", cors=None), - - Api(path="/path3", method="DELETE", function_name="SamFunc1", cors=None) - ] - - def test_with_no_apis(self): - template = { - "Resources": { - - "Api1": { - "Type": "AWS::ApiGateway::RestApi", - "Properties": { - }, - - } - } - } - - provider = ApiProvider(template) - - self.assertEquals(provider.apis, []) - - def test_with_inline_swagger_apis(self): - template = { - "Resources": { - - "Api1": { - "Type": "AWS::ApiGateway::RestApi", - "Properties": { - "Body": make_swagger(self.input_apis) - } - } - } - } - - provider = ApiProvider(template) - assertCountEqual(self, self.input_apis, provider.apis) - - def test_with_swagger_as_local_file(self): - with tempfile.NamedTemporaryFile(mode='w') as fp: - filename = fp.name - - swagger = make_swagger(self.input_apis) - json.dump(swagger, fp) - fp.flush() - - template = { - "Resources": { - - "Api1": { - "Type": "AWS::ApiGateway::RestApi", - "Properties": { - "BodyS3Location": filename - } - } - } - } - - provider = ApiProvider(template) - assertCountEqual(self, self.input_apis, provider.apis) - - def test_body_with_swagger_as_local_file_expect_fail(self): - with tempfile.NamedTemporaryFile(mode='w') as fp: - filename = fp.name - - swagger = make_swagger(self.input_apis) - json.dump(swagger, fp) - fp.flush() - - template = { - "Resources": { - - "Api1": { - "Type": "AWS::ApiGateway::RestApi", - "Properties": { - "Body": filename - } - } - } - } - self.assertRaises(Exception, ApiProvider, template) - - @patch("samcli.commands.local.lib.cfn_base_api_provider.SamSwaggerReader") - def test_with_swagger_as_both_body_and_uri_called(self, SamSwaggerReaderMock): - body = {"some": "body"} - filename = "somefile.txt" - - template = { - "Resources": { - - "Api1": { - "Type": "AWS::ApiGateway::RestApi", - "Properties": { - "BodyS3Location": filename, - "Body": body - } - } - } - } - - SamSwaggerReaderMock.return_value.read.return_value = make_swagger(self.input_apis) - - cwd = "foo" - provider = ApiProvider(template, cwd=cwd) - assertCountEqual(self, self.input_apis, provider.apis) - SamSwaggerReaderMock.assert_called_with(definition_body=body, definition_uri=filename, working_dir=cwd) - - def test_swagger_with_any_method(self): - apis = [ - Api(path="/path", method="any", function_name="SamFunc1", cors=None) - ] - - expected_apis = [ - Api(path="/path", method="GET", function_name="SamFunc1", cors=None), - Api(path="/path", method="POST", function_name="SamFunc1", cors=None), - Api(path="/path", method="PUT", function_name="SamFunc1", cors=None), - Api(path="/path", method="DELETE", function_name="SamFunc1", cors=None), - Api(path="/path", method="HEAD", function_name="SamFunc1", cors=None), - Api(path="/path", method="OPTIONS", function_name="SamFunc1", cors=None), - Api(path="/path", method="PATCH", function_name="SamFunc1", cors=None) - ] - - template = { - "Resources": { - "Api1": { - "Type": "AWS::ApiGateway::RestApi", - "Properties": { - "Body": make_swagger(apis) - } - } - } - } - - provider = ApiProvider(template) - assertCountEqual(self, expected_apis, provider.apis) - - def test_with_binary_media_types(self): - template = { - "Resources": { - - "Api1": { - "Type": "AWS::ApiGateway::RestApi", - "Properties": { - "Body": make_swagger(self.input_apis, binary_media_types=self.binary_types) - } - } - } - } - - expected_binary_types = sorted(self.binary_types) - expected_apis = [ - Api(path="/path1", method="GET", function_name="SamFunc1", cors=None, - binary_media_types=expected_binary_types), - Api(path="/path1", method="POST", function_name="SamFunc1", cors=None, - binary_media_types=expected_binary_types), - - Api(path="/path2", method="PUT", function_name="SamFunc1", cors=None, - binary_media_types=expected_binary_types), - Api(path="/path2", method="GET", function_name="SamFunc1", cors=None, - binary_media_types=expected_binary_types), - - Api(path="/path3", method="DELETE", function_name="SamFunc1", cors=None, - binary_media_types=expected_binary_types) - ] - - provider = ApiProvider(template) - assertCountEqual(self, expected_apis, provider.apis) - - def test_with_binary_media_types_in_swagger_and_on_resource(self): - input_apis = [ - Api(path="/path", method="OPTIONS", function_name="SamFunc1"), - ] - extra_binary_types = ["text/html"] - - template = { - "Resources": { - - "Api1": { - "Type": "AWS::ApiGateway::RestApi", - "Properties": { - "BinaryMediaTypes": extra_binary_types, - "Body": make_swagger(input_apis, binary_media_types=self.binary_types) - } - } - } - } - - expected_binary_types = sorted(self.binary_types + extra_binary_types) - expected_apis = [ - Api(path="/path", method="OPTIONS", function_name="SamFunc1", binary_media_types=expected_binary_types), - ] - - provider = ApiProvider(template) - assertCountEqual(self, expected_apis, provider.apis) diff --git a/tests/unit/commands/local/lib/test_local_api_service.py b/tests/unit/commands/local/lib/test_local_api_service.py index cfa35af954..3cc5d2c4c3 100644 --- a/tests/unit/commands/local/lib/test_local_api_service.py +++ b/tests/unit/commands/local/lib/test_local_api_service.py @@ -35,7 +35,7 @@ def setUp(self): self.lambda_invoke_context_mock.stderr = self.stderr_mock @patch("samcli.commands.local.lib.local_api_service.LocalApigwService") - @patch("samcli.commands.local.lib.local_api_service.ApiProvider") + @patch("samcli.commands.local.lib.local_api_service.SamApiProvider") @patch.object(LocalApiService, "_make_static_dir_path") @patch.object(LocalApiService, "_print_routes") @patch.object(LocalApiService, "_make_routing_list") @@ -77,7 +77,7 @@ def test_must_start_service(self, self.apigw_service.run.assert_called_with() @patch("samcli.commands.local.lib.local_api_service.LocalApigwService") - @patch("samcli.commands.local.lib.local_api_service.ApiProvider") + @patch("samcli.commands.local.lib.local_api_service.SamApiProvider") @patch.object(LocalApiService, "_make_static_dir_path") @patch.object(LocalApiService, "_print_routes") @patch.object(LocalApiService, "_make_routing_list") diff --git a/tests/unit/commands/local/lib/test_sam_api_provider.py b/tests/unit/commands/local/lib/test_sam_api_provider.py index fa5f342e49..3ac01956be 100644 --- a/tests/unit/commands/local/lib/test_sam_api_provider.py +++ b/tests/unit/commands/local/lib/test_sam_api_provider.py @@ -7,11 +7,29 @@ from six import assertCountEqual -from samcli.commands.local.lib.api_provider import ApiProvider, SamApiProvider +from samcli.commands.local.lib.sam_api_provider import SamApiProvider from samcli.commands.local.lib.provider import Api from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException +class TestSamApiProvider_init(TestCase): + + @patch.object(SamApiProvider, "_extract_apis") + @patch("samcli.commands.local.lib.sam_api_provider.SamBaseProvider") + def test_provider_with_valid_template(self, SamBaseProviderMock, extract_api_mock): + extract_api_mock.return_value = {"set", "of", "values"} + + template = {"Resources": {"a": "b"}} + SamBaseProviderMock.get_template.return_value = template + + provider = SamApiProvider(template) + + self.assertEquals(len(provider.apis), 3) + self.assertEquals(provider.apis, set(["set", "of", "values"])) + self.assertEquals(provider.template_dict, {"Resources": {"a": "b"}}) + self.assertEquals(provider.resources, {"a": "b"}) + + class TestSamApiProviderWithImplicitApis(TestCase): def test_provider_with_no_resource_properties(self): @@ -24,8 +42,9 @@ def test_provider_with_no_resource_properties(self): } } - provider = ApiProvider(template) + provider = SamApiProvider(template) + self.assertEquals(len(provider.apis), 0) self.assertEquals(provider.apis, []) @parameterized.expand([("GET"), ("get")]) @@ -53,7 +72,7 @@ def test_provider_has_correct_api(self, method): } } - provider = ApiProvider(template) + provider = SamApiProvider(template) self.assertEquals(len(provider.apis), 1) self.assertEquals(list(provider.apis)[0], Api(path="/path", method="GET", function_name="SamFunc1", cors=None, @@ -90,7 +109,7 @@ def test_provider_creates_api_for_all_events(self): } } - provider = ApiProvider(template) + provider = SamApiProvider(template) api_event1 = Api(path="/path", method="GET", function_name="SamFunc1", cors=None, stage_name="Prod") api_event2 = Api(path="/path", method="POST", function_name="SamFunc1", cors=None, stage_name="Prod") @@ -140,7 +159,7 @@ def test_provider_has_correct_template(self): } } - provider = ApiProvider(template) + provider = SamApiProvider(template) api1 = Api(path="/path", method="GET", function_name="SamFunc1", cors=None, stage_name="Prod") api2 = Api(path="/path", method="POST", function_name="SamFunc2", cors=None, stage_name="Prod") @@ -171,7 +190,7 @@ def test_provider_with_no_api_events(self): } } - provider = ApiProvider(template) + provider = SamApiProvider(template) self.assertEquals(provider.apis, []) @@ -190,7 +209,7 @@ def test_provider_with_no_serverless_function(self): } } - provider = ApiProvider(template) + provider = SamApiProvider(template) self.assertEquals(provider.apis, []) @@ -235,7 +254,7 @@ def test_provider_get_all(self): } } - provider = ApiProvider(template) + provider = SamApiProvider(template) result = [f for f in provider.get_all()] @@ -248,7 +267,7 @@ def test_provider_get_all(self): def test_provider_get_all_with_no_apis(self): template = {} - provider = ApiProvider(template) + provider = SamApiProvider(template) result = [f for f in provider.get_all()] @@ -279,7 +298,7 @@ def test_provider_with_any_method(self, method): } } - provider = ApiProvider(template) + provider = SamApiProvider(template) api_get = Api(path="/path", method="GET", function_name="SamFunc1", cors=None, stage_name="Prod") api_post = Api(path="/path", method="POST", function_name="SamFunc1", cors=None, stage_name="Prod") @@ -332,7 +351,7 @@ def test_provider_must_support_binary_media_types(self): } } - provider = ApiProvider(template) + provider = SamApiProvider(template) self.assertEquals(len(provider.apis), 1) self.assertEquals(list(provider.apis)[0], Api(path="/path", method="GET", function_name="SamFunc1", @@ -384,7 +403,7 @@ def test_provider_must_support_binary_media_types_with_any_method(self): Api(path="/path", method="PATCH", function_name="SamFunc1", binary_media_types=binary, stage_name="Prod") ] - provider = ApiProvider(template) + provider = SamApiProvider(template) assertCountEqual(self, provider.apis, expected_apis) @@ -429,8 +448,9 @@ def test_with_no_apis(self): } } - provider = ApiProvider(template) + provider = SamApiProvider(template) + self.assertEquals(len(provider.apis), 0) self.assertEquals(provider.apis, []) def test_with_inline_swagger_apis(self): @@ -447,7 +467,7 @@ def test_with_inline_swagger_apis(self): } } - provider = ApiProvider(template) + provider = SamApiProvider(template) assertCountEqual(self, self.input_apis, provider.apis) def test_with_swagger_as_local_file(self): @@ -471,11 +491,11 @@ def test_with_swagger_as_local_file(self): } } - provider = ApiProvider(template) + provider = SamApiProvider(template) assertCountEqual(self, self.input_apis, provider.apis) - @patch("samcli.commands.local.lib.cfn_base_api_provider.SamSwaggerReader") - def test_with_swagger_as_both_body_and_uri_called(self, SamSwaggerReaderMock): + @patch("samcli.commands.local.lib.sam_api_provider.SamSwaggerReader") + def test_with_swagger_as_both_body_and_uri(self, SamSwaggerReaderMock): body = {"some": "body"} filename = "somefile.txt" @@ -496,7 +516,7 @@ def test_with_swagger_as_both_body_and_uri_called(self, SamSwaggerReaderMock): SamSwaggerReaderMock.return_value.read.return_value = make_swagger(self.input_apis) cwd = "foo" - provider = ApiProvider(template, cwd=cwd) + provider = SamApiProvider(template, cwd=cwd) assertCountEqual(self, self.input_apis, provider.apis) SamSwaggerReaderMock.assert_called_with(definition_body=body, definition_uri=filename, working_dir=cwd) @@ -527,7 +547,7 @@ def test_swagger_with_any_method(self): } } - provider = ApiProvider(template) + provider = SamApiProvider(template) assertCountEqual(self, expected_apis, provider.apis) def test_with_binary_media_types(self): @@ -560,7 +580,7 @@ def test_with_binary_media_types(self): binary_media_types=expected_binary_types, stage_name="Prod") ] - provider = ApiProvider(template) + provider = SamApiProvider(template) assertCountEqual(self, expected_apis, provider.apis) def test_with_binary_media_types_in_swagger_and_on_resource(self): @@ -589,7 +609,7 @@ def test_with_binary_media_types_in_swagger_and_on_resource(self): stage_name="Prod"), ] - provider = ApiProvider(template) + provider = SamApiProvider(template) assertCountEqual(self, expected_apis, provider.apis) @@ -666,7 +686,7 @@ def test_must_union_implicit_and_explicit(self): Api(path="/path3", method="POST", function_name="ImplicitFunc", cors=None, stage_name="Prod") ] - provider = ApiProvider(self.template) + provider = SamApiProvider(self.template) assertCountEqual(self, expected_apis, provider.apis) def test_must_prefer_implicit_api_over_explicit(self): @@ -703,7 +723,7 @@ def test_must_prefer_implicit_api_over_explicit(self): Api(path="/path3", method="GET", function_name="explicitfunction", cors=None, stage_name="Prod"), ] - provider = ApiProvider(self.template) + provider = SamApiProvider(self.template) assertCountEqual(self, expected_apis, provider.apis) def test_must_prefer_implicit_with_any_method(self): @@ -737,7 +757,7 @@ def test_must_prefer_implicit_with_any_method(self): Api(path="/path", method="PATCH", function_name="ImplicitFunc", cors=None, stage_name="Prod") ] - provider = ApiProvider(self.template) + provider = SamApiProvider(self.template) assertCountEqual(self, expected_apis, provider.apis) def test_with_any_method_on_both(self): @@ -782,7 +802,8 @@ def test_with_any_method_on_both(self): Api(path="/path2", method="POST", function_name="explicitfunction", cors=None, stage_name="Prod") ] - provider = ApiProvider(self.template) + provider = SamApiProvider(self.template) + print(provider.apis) assertCountEqual(self, expected_apis, provider.apis) def test_must_add_explicit_api_when_ref_with_rest_api_id(self): @@ -819,7 +840,7 @@ def test_must_add_explicit_api_when_ref_with_rest_api_id(self): Api(path="/newpath2", method="POST", function_name="ImplicitFunc", cors=None, stage_name="Prod") ] - provider = ApiProvider(self.template) + provider = SamApiProvider(self.template) assertCountEqual(self, expected_apis, provider.apis) def test_both_apis_must_get_binary_media_types(self): @@ -874,7 +895,7 @@ def test_both_apis_must_get_binary_media_types(self): stage_name="Prod") ] - provider = ApiProvider(self.template) + provider = SamApiProvider(self.template) assertCountEqual(self, expected_apis, provider.apis) def test_binary_media_types_with_rest_api_id_reference(self): @@ -934,7 +955,7 @@ def test_binary_media_types_with_rest_api_id_reference(self): stage_name="Prod") ] - provider = ApiProvider(self.template) + provider = SamApiProvider(self.template) assertCountEqual(self, expected_apis, provider.apis) @@ -970,7 +991,7 @@ def test_provider_parse_stage_name(self): } } } - provider = ApiProvider(template) + provider = SamApiProvider(template) api1 = Api(path='/path', method='GET', function_name='NoApiEventFunction', cors=None, binary_media_types=[], stage_name='dev', stage_variables=None) @@ -1012,7 +1033,7 @@ def test_provider_stage_variables(self): } } } - provider = ApiProvider(template) + provider = SamApiProvider(template) api1 = Api(path='/path', method='GET', function_name='NoApiEventFunction', cors=None, binary_media_types=[], stage_name='dev', stage_variables={ @@ -1098,7 +1119,8 @@ def test_multi_stage_get_all(self): } } } - provider = ApiProvider(template) + + provider = SamApiProvider(template) result = [f for f in provider.get_all()] From dbd534a44edd10a4be1bb4d29101c45c1d18f9e7 Mon Sep 17 00:00:00 2001 From: Sanath Kumar Ramesh Date: Wed, 24 Jul 2019 16:39:00 -0700 Subject: [PATCH 3/7] feat: Telemetry Implementation (#1287) --- .gitignore | 1 + DESIGN.md | 12 + Makefile | 6 + samcli/cli/context.py | 60 +++++ samcli/cli/global_config.py | 205 +++++++++++++++ samcli/cli/main.py | 35 ++- samcli/commands/build/command.py | 2 + samcli/commands/deploy/__init__.py | 2 + samcli/commands/init/__init__.py | 3 + samcli/commands/local/cli_common/options.py | 1 + .../local/generate_event/event_generation.py | 2 + samcli/commands/local/invoke/cli.py | 4 +- samcli/commands/local/start_api/cli.py | 3 + samcli/commands/local/start_lambda/cli.py | 2 + samcli/commands/logs/command.py | 2 + samcli/commands/package/__init__.py | 2 + samcli/commands/publish/command.py | 2 + samcli/commands/validate/validate.py | 2 + samcli/lib/telemetry/__init__.py | 0 samcli/lib/telemetry/metrics.py | 132 ++++++++++ samcli/lib/telemetry/telemetry.py | 124 ++++++++++ samcli/settings/__init__.py | 24 ++ tests/conftest.py | 6 + tests/functional/commands/cli/__init__.py | 0 .../commands/cli/test_global_config.py | 152 ++++++++++++ tests/functional/commands/cli/test_main.py | 34 +++ tests/integration/telemetry/__init__.py | 0 tests/integration/telemetry/integ_base.py | 196 +++++++++++++++ .../telemetry/test_installed_metric.py | 117 +++++++++ tests/integration/telemetry/test_prompt.py | 53 ++++ tests/unit/cli/test_context.py | 36 +++ tests/unit/cli/test_global_config.py | 118 +++++++++ tests/unit/cli/test_main.py | 56 ++++- .../generate_event/test_event_generation.py | 9 + tests/unit/lib/telemetry/test_metrics.py | 234 ++++++++++++++++++ tests/unit/lib/telemetry/test_telemetry.py | 153 ++++++++++++ 36 files changed, 1776 insertions(+), 14 deletions(-) create mode 100644 samcli/cli/global_config.py create mode 100644 samcli/lib/telemetry/__init__.py create mode 100644 samcli/lib/telemetry/metrics.py create mode 100644 samcli/lib/telemetry/telemetry.py create mode 100644 samcli/settings/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/functional/commands/cli/__init__.py create mode 100644 tests/functional/commands/cli/test_global_config.py create mode 100644 tests/functional/commands/cli/test_main.py create mode 100644 tests/integration/telemetry/__init__.py create mode 100644 tests/integration/telemetry/integ_base.py create mode 100644 tests/integration/telemetry/test_installed_metric.py create mode 100644 tests/integration/telemetry/test_prompt.py create mode 100644 tests/unit/cli/test_global_config.py create mode 100644 tests/unit/lib/telemetry/test_metrics.py create mode 100644 tests/unit/lib/telemetry/test_telemetry.py diff --git a/.gitignore b/.gitignore index 93b3ed0745..77c7fe858c 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ .idea/**/tasks.xml .idea/dictionaries .idea +.vscode # Sensitive or high-churn files: .idea/**/dataSources/ diff --git a/DESIGN.md b/DESIGN.md index 4355c23d91..5ea8c23f6f 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -65,3 +65,15 @@ also forces commands implementations to be modular, reusable, and highly customizable. When RC files are implemented, new commands can be added or existing commands can be removed, with simple a configuration in the RC file. + +Internal Environment Variables +============================== + +SAM CLI uses the following internal, undocumented, environment variables +for development purposes. They should *not* be used by customers: + +- `__SAM_CLI_APP_DIR`: Path to application directory to be used in place + of `~/.aws-sam` directory. + +- `__SAM_CLI_TELEMETRY_ENDPOINT_URL`: HTTP Endpoint where the Telemetry + metrics will be published to diff --git a/Makefile b/Makefile index 80fafaaed7..26ada874f1 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,7 @@ +# Default value for environment variable. Can be overridden by setting the +# environment variable. +SAM_CLI_TELEMETRY ?= 0 + init: SAM_CLI_DEV=1 pip install -e '.[dev]' @@ -8,10 +12,12 @@ test: integ-test: # Integration tests don't need code coverage + @echo Telemetry Status: $(SAM_CLI_TELEMETRY) SAM_CLI_DEV=1 pytest tests/integration func-test: # Verify function test coverage only for `samcli.local` package + @echo Telemetry Status: $(SAM_CLI_TELEMETRY) pytest --cov samcli.local --cov samcli.commands.local --cov-report term-missing tests/functional flake: diff --git a/samcli/cli/context.py b/samcli/cli/context.py index 4c029ae17a..2a801bf774 100644 --- a/samcli/cli/context.py +++ b/samcli/cli/context.py @@ -2,8 +2,10 @@ Context information passed to each CLI command """ +import uuid import logging import boto3 +import click class Context(object): @@ -26,6 +28,7 @@ def __init__(self): self._debug = False self._aws_region = None self._aws_profile = None + self._session_id = str(uuid.uuid4()) @property def debug(self): @@ -68,6 +71,63 @@ def profile(self, value): self._aws_profile = value self._refresh_session() + @property + def session_id(self): + """ + Returns the ID of this command session. This is a randomly generated UUIDv4 which will not change until the + command terminates. + """ + return self._session_id + + @property + def command_path(self): + """ + Returns the full path of the command as invoked ex: "sam local generate-event s3 put". Wrapper to + https://click.palletsprojects.com/en/7.x/api/#click.Context.command_path + + Returns + ------- + str + Full path of the command invoked + """ + + # Uses Click's Core Context. Note, this is different from this class, also confusingly named `Context`. + # Click's Core Context object is the one that contains command path information. + click_core_ctx = click.get_current_context() + if click_core_ctx: + return click_core_ctx.command_path + + @staticmethod + def get_current_context(): + """ + Get the current Context object from Click's context stacks. This method is safe to run within the + actual command's handler that has a ``@pass_context`` annotation. Outside of the handler, you run + the risk of creating a new Context object which is entirely different from the Context object used by your + command. + .. code: + @pass_context + def my_command_handler(ctx): + # You will get the right context from within the command handler. This will also work from any + # downstream method invoked as part of the handler. + this_context = Context.get_current_context() + assert ctx == this_context + Returns + ------- + samcli.cli.context.Context + Instance of this object, if we are running in a Click command. None otherwise. + """ + + # Click has the concept of Context stacks. Think of them as linked list containing custom objects that are + # automatically accessible at different levels. We start from the Core Click context and discover the + # SAM CLI command-specific Context object which contains values for global options used by all commands. + # + # https://click.palletsprojects.com/en/7.x/complex/#ensuring-object-creation + # + + click_core_ctx = click.get_current_context() + if click_core_ctx: + return click_core_ctx.find_object(Context) or click_core_ctx.ensure_object(Context) + def _refresh_session(self): """ Update boto3's default session by creating a new session based on values set in the context. Some properties of diff --git a/samcli/cli/global_config.py b/samcli/cli/global_config.py new file mode 100644 index 0000000000..ef563f3c8a --- /dev/null +++ b/samcli/cli/global_config.py @@ -0,0 +1,205 @@ +""" +Provides global configuration helpers. +""" + +import json +import logging +import uuid +import os + +import click + +try: + from pathlib import Path +except ImportError: # pragma: no cover + from pathlib2 import Path # pragma: no cover + +LOG = logging.getLogger(__name__) + +CONFIG_FILENAME = "metadata.json" +INSTALLATION_ID_KEY = "installationId" +TELEMETRY_ENABLED_KEY = "telemetryEnabled" + + +class GlobalConfig(object): + """ + Contains helper methods for global configuration files and values. Handles + configuration file creation, updates, and fetching in a platform-neutral way. + + Generally uses '~/.aws-sam/' or 'C:\\Users\\\\AppData\\Roaming\\AWS SAM' as + the base directory, depending on platform. + """ + + def __init__(self, config_dir=None, installation_id=None, telemetry_enabled=None): + """ + Initializes the class, with options provided to assist with testing. + + :param config_dir: Optional, overrides the default config directory path. + :param installation_id: Optional, will use this installation id rather than checking config values. + """ + self._config_dir = config_dir + self._installation_id = installation_id + self._telemetry_enabled = telemetry_enabled + + @property + def config_dir(self): + if not self._config_dir: + # Internal Environment variable to customize SAM CLI App Dir. Currently used only by integ tests. + app_dir = os.getenv("__SAM_CLI_APP_DIR") + self._config_dir = Path(app_dir) if app_dir else Path(click.get_app_dir('AWS SAM', force_posix=True)) + + return Path(self._config_dir) + + @property + def installation_id(self): + """ + Returns the installation UUID for this AWS SAM CLI installation. If the + installation id has not yet been set, it will be set before returning. + + Examples + -------- + + >>> gc = GlobalConfig() + >>> gc.installation_id + "7b7d4db7-2f54-45ba-bf2f-a2cbc9e74a34" + + >>> gc = GlobalConfig() + >>> gc.installation_id + None + + Returns + ------- + A string containing the installation UUID, or None in case of an error. + """ + if self._installation_id: + return self._installation_id + try: + self._installation_id = self._get_or_set_uuid(INSTALLATION_ID_KEY) + return self._installation_id + except (ValueError, IOError): + return None + + @property + def telemetry_enabled(self): + """ + Check if telemetry is enabled for this installation. Default value of + False. It first tries to get value from SAM_CLI_TELEMETRY environment variable. If its not set, + then it fetches the value from config file. + + To enable telemetry, set SAM_CLI_TELEMETRY environment variable equal to integer 1 or string '1'. + All other values including words like 'True', 'true', 'false', 'False', 'abcd' etc will disable Telemetry + + Examples + -------- + + >>> gc = GlobalConfig() + >>> gc.telemetry_enabled + True + + Returns + ------- + Boolean flag value. True if telemetry is enabled for this installation, + False otherwise. + """ + if self._telemetry_enabled is not None: + return self._telemetry_enabled + + # If environment variable is set, its value takes precedence over the value from config file. + env_name = "SAM_CLI_TELEMETRY" + if env_name in os.environ: + return os.getenv(env_name) in ('1', 1) + + try: + self._telemetry_enabled = self._get_value(TELEMETRY_ENABLED_KEY) + return self._telemetry_enabled + except (ValueError, IOError) as ex: + LOG.debug("Error when retrieving telemetry_enabled flag", exc_info=ex) + return False + + @telemetry_enabled.setter + def telemetry_enabled(self, value): + """ + Sets the telemetry_enabled flag to the provided boolean value. + + Examples + -------- + >>> gc = GlobalConfig() + >>> gc.telemetry_enabled + False + >>> gc.telemetry_enabled = True + >>> gc.telemetry_enabled + True + + Raises + ------ + IOError + If there are errors opening or writing to the global config file. + + JSONDecodeError + If the config file exists, and is not valid JSON. + """ + self._set_value("telemetryEnabled", value) + self._telemetry_enabled = value + + def _get_value(self, key): + cfg_path = self._get_config_file_path(CONFIG_FILENAME) + if not cfg_path.exists(): + return None + with open(str(cfg_path)) as fp: + body = fp.read() + json_body = json.loads(body) + return json_body.get(key) + + def _set_value(self, key, value): + cfg_path = self._get_config_file_path(CONFIG_FILENAME) + if not cfg_path.exists(): + return self._set_json_cfg(cfg_path, key, value) + with open(str(cfg_path)) as fp: + body = fp.read() + try: + json_body = json.loads(body) + except ValueError as ex: + LOG.debug("Failed to decode JSON in {cfg_path}", exc_info=ex) + raise ex + return self._set_json_cfg(cfg_path, key, value, json_body) + + def _create_dir(self): + self.config_dir.mkdir(mode=0o700, parents=True, exist_ok=True) + + def _get_config_file_path(self, filename): + self._create_dir() + filepath = self.config_dir.joinpath(filename) + return filepath + + def _get_or_set_uuid(self, key): + """ + Special logic method for when we want a UUID to always be present, this + method behaves as a getter with side effects. Essentially, if the value + is not present, we will set it with a generated UUID. + + If we have multiple such values in the future, a possible refactor is + to just be _get_or_set_value, where we also take a default value as a + parameter. + """ + cfg_value = self._get_value(key) + if cfg_value is not None: + return cfg_value + return self._set_value(key, str(uuid.uuid4())) + + def _set_json_cfg(self, filepath, key, value, json_body=None): + """ + Special logic method to add a value to a JSON configuration file. This + method will write a new version of the file in question, so it will + either write a new file with only the first config value, or if a JSON + body is provided, it will upsert starting from that JSON body. + """ + json_body = json_body or {} + json_body[key] = value + file_body = json.dumps(json_body, indent=4) + "\n" + try: + with open(str(filepath), 'w') as f: + f.write(file_body) + except IOError as ex: + LOG.debug("Error writing to {filepath}", exc_info=ex) + raise ex + return value diff --git a/samcli/cli/main.py b/samcli/cli/main.py index 66127aa8b9..2b83f75d91 100644 --- a/samcli/cli/main.py +++ b/samcli/cli/main.py @@ -7,17 +7,22 @@ import click from samcli import __version__ +from samcli.lib.telemetry.metrics import send_installed_metric from .options import debug_option, region_option, profile_option from .context import Context from .command import BaseCommand +from .global_config import GlobalConfig -logger = logging.getLogger(__name__) +LOG = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s', datefmt='%Y-%m-%d %H:%M:%S') pass_context = click.make_pass_decorator(Context) +global_cfg = GlobalConfig() + + def common_options(f): """ Common CLI options used by all commands. Ex: --debug @@ -48,6 +53,17 @@ def print_info(ctx, param, value): ctx.exit() +# Keep the message to 80chars wide to it prints well on most terminals +TELEMETRY_PROMPT = """ +\tTelemetry has been enabled for SAM CLI. +\t +\tYou can OPT OUT of telemetry by setting the environment variable +\tSAM_CLI_TELEMETRY=0 in your shell. + +\tLearn More: http://docs.aws.amazon.com/serverless-application-model/latest/developerguide/telemetry-opt-out +""" + + @click.command(cls=BaseCommand) @common_options @click.version_option(version=__version__, prog_name="SAM CLI") @@ -62,4 +78,19 @@ def cli(ctx): You can find more in-depth guide about the SAM specification here: https://github.com/awslabs/serverless-application-model. """ - pass + + if global_cfg.telemetry_enabled is None: + enabled = True + + try: + global_cfg.telemetry_enabled = enabled + + if enabled: + click.secho(TELEMETRY_PROMPT, fg="yellow", err=True) + + # When the Telemetry prompt is printed, we can safely assume that this is the first time someone + # is installing SAM CLI on this computer. So go ahead and send the `installed` metric + send_installed_metric() + + except (IOError, ValueError) as ex: + LOG.debug("Unable to write telemetry flag", exc_info=ex) diff --git a/samcli/commands/build/command.py b/samcli/commands/build/command.py index f80dca0d83..9c331d2163 100644 --- a/samcli/commands/build/command.py +++ b/samcli/commands/build/command.py @@ -16,6 +16,7 @@ from samcli.lib.build.workflow_config import UnsupportedRuntimeException from samcli.local.lambdafn.exceptions import FunctionNotFound from samcli.commands._utils.template import move_template +from samcli.lib.telemetry.metrics import track_command LOG = logging.getLogger(__name__) @@ -84,6 +85,7 @@ @aws_creds_options @click.argument('function_identifier', required=False) @pass_context +@track_command def cli(ctx, function_identifier, template, diff --git a/samcli/commands/deploy/__init__.py b/samcli/commands/deploy/__init__.py index f18247ac2f..5939b2a39e 100644 --- a/samcli/commands/deploy/__init__.py +++ b/samcli/commands/deploy/__init__.py @@ -7,6 +7,7 @@ from samcli.cli.main import pass_context, common_options from samcli.lib.samlib.cloudformation_command import execute_command from samcli.commands.exceptions import UserException +from samcli.lib.telemetry.metrics import track_command SHORT_HELP = "Deploy an AWS SAM application. This is an alias for 'aws cloudformation deploy'." @@ -36,6 +37,7 @@ "If you specify a new stack, the command creates it.") @common_options @pass_context +@track_command def cli(ctx, args, template_file, stack_name): # All logic must be implemented in the ``do_cli`` method. This helps with easy unit testing diff --git a/samcli/commands/init/__init__.py b/samcli/commands/init/__init__.py index 90cf870900..7186978c0f 100644 --- a/samcli/commands/init/__init__.py +++ b/samcli/commands/init/__init__.py @@ -11,6 +11,8 @@ from samcli.local.common.runtime_template import INIT_RUNTIMES, SUPPORTED_DEP_MANAGERS from samcli.local.init import generate_project from samcli.local.init.exceptions import GenerateProjectFailedError +from samcli.lib.telemetry.metrics import track_command + LOG = logging.getLogger(__name__) @@ -27,6 +29,7 @@ help="Disable prompting and accept default values defined template config") @common_options @pass_context +@track_command def cli(ctx, location, runtime, dependency_manager, output_dir, name, no_input): """ \b Initialize a serverless application with a SAM template, folder diff --git a/samcli/commands/local/cli_common/options.py b/samcli/commands/local/cli_common/options.py index f8ad498f9b..459251acf3 100644 --- a/samcli/commands/local/cli_common/options.py +++ b/samcli/commands/local/cli_common/options.py @@ -19,6 +19,7 @@ def get_application_dir(): Path Path representing the application config directory """ + # TODO: Get the config directory directly from `GlobalConfig` return Path(click.get_app_dir('AWS SAM', force_posix=True)) diff --git a/samcli/commands/local/generate_event/event_generation.py b/samcli/commands/local/generate_event/event_generation.py index f9e7c08357..ca9180f9f0 100644 --- a/samcli/commands/local/generate_event/event_generation.py +++ b/samcli/commands/local/generate_event/event_generation.py @@ -7,6 +7,7 @@ from samcli.cli.options import debug_option import samcli.commands.local.lib.generated_sample_events.events as events +from samcli.lib.telemetry.metrics import track_command class ServiceCommand(click.MultiCommand): @@ -170,6 +171,7 @@ def list_commands(self, ctx): """ return sorted(self.subcmd_definition.keys()) + @track_command def cmd_implementation(self, events_lib, top_level_cmd_name, subcmd_name, *args, **kwargs): """ calls for value substitution in the event json and returns the diff --git a/samcli/commands/local/invoke/cli.py b/samcli/commands/local/invoke/cli.py index d8c875eee4..9272c30f30 100644 --- a/samcli/commands/local/invoke/cli.py +++ b/samcli/commands/local/invoke/cli.py @@ -15,6 +15,7 @@ from samcli.commands.local.lib.exceptions import OverridesNotWellDefinedError from samcli.local.docker.manager import DockerImagePullFailedException from samcli.local.docker.lambda_debug_entrypoint import DebuggingNotSupported +from samcli.lib.telemetry.metrics import track_command LOG = logging.getLogger(__name__) @@ -44,7 +45,8 @@ @cli_framework_options @aws_creds_options @click.argument('function_identifier', required=False) -@pass_context # pylint: disable=R0914 +@pass_context +@track_command # pylint: disable=R0914 def cli(ctx, function_identifier, template, event, no_event, env_vars, debug_port, debug_args, debugger_path, docker_volume_basedir, docker_network, log_file, layer_cache_basedir, skip_pull_image, force_image_build, parameter_overrides): diff --git a/samcli/commands/local/start_api/cli.py b/samcli/commands/local/start_api/cli.py index fb832efaaf..a2eaac6003 100644 --- a/samcli/commands/local/start_api/cli.py +++ b/samcli/commands/local/start_api/cli.py @@ -14,6 +14,8 @@ from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException from samcli.commands.local.lib.exceptions import OverridesNotWellDefinedError from samcli.local.docker.lambda_debug_entrypoint import DebuggingNotSupported +from samcli.lib.telemetry.metrics import track_command + LOG = logging.getLogger(__name__) @@ -43,6 +45,7 @@ @cli_framework_options @aws_creds_options # pylint: disable=R0914 @pass_context +@track_command def cli(ctx, # start-api Specific Options host, port, static_dir, diff --git a/samcli/commands/local/start_lambda/cli.py b/samcli/commands/local/start_lambda/cli.py index 4bd7ba1129..a5807a6b50 100644 --- a/samcli/commands/local/start_lambda/cli.py +++ b/samcli/commands/local/start_lambda/cli.py @@ -14,6 +14,7 @@ from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException from samcli.commands.local.lib.exceptions import OverridesNotWellDefinedError from samcli.local.docker.lambda_debug_entrypoint import DebuggingNotSupported +from samcli.lib.telemetry.metrics import track_command LOG = logging.getLogger(__name__) @@ -58,6 +59,7 @@ @cli_framework_options @aws_creds_options @pass_context +@track_command def cli(ctx, # pylint: disable=R0914 # start-lambda Specific Options host, port, diff --git a/samcli/commands/logs/command.py b/samcli/commands/logs/command.py index fa4f785219..6342fa0365 100644 --- a/samcli/commands/logs/command.py +++ b/samcli/commands/logs/command.py @@ -6,6 +6,7 @@ import click from samcli.cli.main import pass_context, common_options as cli_framework_options, aws_creds_options +from samcli.lib.telemetry.metrics import track_command from .logs_context import LogsCommandContext LOG = logging.getLogger(__name__) @@ -60,6 +61,7 @@ @cli_framework_options @aws_creds_options @pass_context +@track_command def cli(ctx, name, stack_name, diff --git a/samcli/commands/package/__init__.py b/samcli/commands/package/__init__.py index 9c45396e14..110ae5c16c 100644 --- a/samcli/commands/package/__init__.py +++ b/samcli/commands/package/__init__.py @@ -9,6 +9,7 @@ from samcli.commands._utils.options import get_or_default_template_file_name, _TEMPLATE_OPTION_DEFAULT_VALUE from samcli.lib.samlib.cloudformation_command import execute_command from samcli.commands.exceptions import UserException +from samcli.lib.telemetry.metrics import track_command SHORT_HELP = "Package an AWS SAM application. This is an alias for 'aws cloudformation package'." @@ -42,6 +43,7 @@ @click.argument("args", nargs=-1, type=click.UNPROCESSED) @common_options @pass_context +@track_command def cli(ctx, args, template_file, s3_bucket): # All logic must be implemented in the ``do_cli`` method. This helps with easy unit testing diff --git a/samcli/commands/publish/command.py b/samcli/commands/publish/command.py index 9587f0c863..9020b750b4 100644 --- a/samcli/commands/publish/command.py +++ b/samcli/commands/publish/command.py @@ -14,6 +14,7 @@ from samcli.commands._utils.options import template_common_option from samcli.commands._utils.template import get_template_data from samcli.commands.exceptions import UserException +from samcli.lib.telemetry.metrics import track_command LOG = logging.getLogger(__name__) @@ -46,6 +47,7 @@ @aws_creds_options @cli_framework_options @pass_context +@track_command def cli(ctx, template, semantic_version): # All logic must be implemented in the ``do_cli`` method. This helps with easy unit testing diff --git a/samcli/commands/validate/validate.py b/samcli/commands/validate/validate.py index 693658049b..7461ae5692 100644 --- a/samcli/commands/validate/validate.py +++ b/samcli/commands/validate/validate.py @@ -13,6 +13,7 @@ from samcli.commands._utils.options import template_option_without_build from samcli.commands.local.cli_common.user_exceptions import InvalidSamTemplateException, SamTemplateNotFoundException from samcli.yamlhelper import yaml_parse +from samcli.lib.telemetry.metrics import track_command from .lib.exceptions import InvalidSamDocumentException from .lib.sam_template_validator import SamTemplateValidator @@ -23,6 +24,7 @@ @aws_creds_options @cli_framework_options @pass_context +@track_command def cli(ctx, template): # All logic must be implemented in the ``do_cli`` method. This helps with easy unit testing diff --git a/samcli/lib/telemetry/__init__.py b/samcli/lib/telemetry/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/samcli/lib/telemetry/metrics.py b/samcli/lib/telemetry/metrics.py new file mode 100644 index 0000000000..ba2cc2271f --- /dev/null +++ b/samcli/lib/telemetry/metrics.py @@ -0,0 +1,132 @@ +""" +Provides methods to generate and send metrics +""" + + +import platform +import logging + +from timeit import default_timer + +from samcli.cli.context import Context +from samcli.commands.exceptions import UserException +from samcli.cli.global_config import GlobalConfig +from .telemetry import Telemetry + + +LOG = logging.getLogger(__name__) + + +def send_installed_metric(): + + LOG.debug("Sending Installed Metric") + + telemetry = Telemetry() + telemetry.emit("installed", { + "osPlatform": platform.system(), + "telemetryEnabled": _telemetry_enabled(), + }) + + +def track_command(func): + """ + Decorator to track execution of a command. This method executes the function, gathers all relevant metrics, + reports the metrics and returns. + + If you have a Click command, you can track as follows: + + .. code:: python + @click.command(...) + @click.options(...) + @track_command + def hello_command(): + print('hello') + + """ + + def wrapped(*args, **kwargs): + + if not _telemetry_enabled(): + # When Telemetry is disabled, call the function immediately and return. + return func(*args, **kwargs) + + telemetry = Telemetry() + + exception = None + return_value = None + exit_reason = "success" + exit_code = 0 + + duration_fn = _timer() + try: + + # Execute the function and capture return value. This is returned back by the wrapper + # First argument of all commands should be the Context + return_value = func(*args, **kwargs) + + except UserException as ex: + # Capture exception information and re-raise it later so we can first send metrics. + exception = ex + exit_code = ex.exit_code + exit_reason = type(ex).__name__ + + except Exception as ex: + exception = ex + # Standard Unix practice to return exit code 255 on fatal/unhandled exit. + exit_code = 255 + exit_reason = type(ex).__name__ + + ctx = Context.get_current_context() + telemetry.emit("commandRun", { + # Metric about command's general environment + "awsProfileProvided": bool(ctx.profile), + "debugFlagProvided": bool(ctx.debug), + "region": ctx.region or "", + "commandName": ctx.command_path, # Full command path. ex: sam local start-api + + # Metric about command's execution characteristics + "duration": duration_fn(), + "exitReason": exit_reason, + "exitCode": exit_code + }) + + if exception: + raise exception # pylint: disable=raising-bad-type + + return return_value + + return wrapped + + +def _timer(): + """ + Timer to measure the elapsed time between two calls in milliseconds. When you first call this method, + we will automatically start the timer. The return value is another method that, when called, will end the timer + and return the duration between the two calls. + + ..code: + >>> import time + >>> duration_fn = _timer() + >>> time.sleep(5) # Say, you sleep for 5 seconds in between calls + >>> duration_ms = duration_fn() + >>> print(duration_ms) + 5010 + + Returns + ------- + function + Call this method to end the timer and return duration in milliseconds + + """ + start = default_timer() + + def end(): + # time might go backwards in rare scenarios, hence the 'max' + return int(max(default_timer() - start, 0) * 1000) # milliseconds + + return end + + +def _telemetry_enabled(): + gc = GlobalConfig() + return bool(gc.telemetry_enabled) diff --git a/samcli/lib/telemetry/telemetry.py b/samcli/lib/telemetry/telemetry.py new file mode 100644 index 0000000000..2daae87513 --- /dev/null +++ b/samcli/lib/telemetry/telemetry.py @@ -0,0 +1,124 @@ +""" +Class to publish metrics +""" + +import platform +import uuid +import logging +import requests + +from samcli import __version__ as samcli_version +from samcli.cli.context import Context +from samcli.cli.global_config import GlobalConfig + +# Get the preconfigured endpoint URL +from samcli.settings import telemetry_endpoint_url as DEFAULT_ENDPOINT_URL + +LOG = logging.getLogger(__name__) + + +class Telemetry(object): + + def __init__(self, url=None): + """ + Initialize the Telemetry object. + + Parameters + ---------- + url : str + Optional, URL where the metrics should be published to + """ + self._session_id = self._default_session_id() + + if not self._session_id: + raise RuntimeError("Unable to retrieve session_id from Click Context") + + self._gc = GlobalConfig() + self._url = url or DEFAULT_ENDPOINT_URL + LOG.debug("Telemetry endpoint configured to be %s", self._url) + + def emit(self, metric_name, attrs): + """ + Emits the metric with given name and the attributes and send it immediately to the HTTP backend. This method + will return immediately without waiting for response from the backend. Before sending, this method will + also update ``attrs`` with some common attributes used by all metrics. + + Parameters + ---------- + metric_name : str + Name of the metric to publish + + attrs : dict + Attributes sent along with the metric + """ + attrs = self._add_common_metric_attributes(attrs) + + self._send({metric_name: attrs}) + + def _send(self, metric, wait_for_response=False): + """ + Serializes the metric data to JSON and sends to the backend. + + Parameters + ---------- + + metric : dict + Dictionary of metric data to send to backend. + + wait_for_response : bool + If set to True, this method will wait until the HTTP server returns a response. If not, it will return + immediately after the request is sent. + """ + + if not self._url: + # Endpoint not configured. So simply return + LOG.debug("Not sending telemetry. Endpoint URL not configured") + return + + payload = {"metrics": [metric]} + LOG.debug("Sending Telemetry: %s", payload) + + timeout_ms = 2000 if wait_for_response else 1 # 2 seconds to wait for response or 1ms + + timeout = (2, # connection timeout. Always set to 2 seconds + timeout_ms / 1000.0 # Read timeout. Tweaked based on input. + ) + try: + r = requests.post(self._url, json=payload, timeout=timeout) + LOG.debug("Telemetry response: %d", r.status_code) + except requests.exceptions.Timeout as ex: + # Expected if request times out. Just print debug log and ignore the exception. + LOG.debug(str(ex)) + + def _add_common_metric_attributes(self, attrs): + attrs["requestId"] = str(uuid.uuid4()) + attrs["installationId"] = self._gc.installation_id + attrs["sessionId"] = self._session_id + attrs["executionEnvironment"] = self._get_execution_environment() + attrs["pyversion"] = platform.python_version() + attrs["samcliVersion"] = samcli_version + + return attrs + + def _default_session_id(self): + """ + Get the default SessionId from Click Context. + """ + ctx = Context.get_current_context() + if ctx: + return ctx.session_id + + def _get_execution_environment(self): + """ + Returns the environment in which SAM CLI is running. Possible options are: + + CLI (default) - SAM CLI was executed from terminal or a script. + IDEToolkit - SAM CLI was executed by IDE Toolkit + CodeBuild - SAM CLI was executed from within CodeBuild + + Returns + ------- + str + Name of the environment where SAM CLI is executed in. + """ + return "CLI" diff --git a/samcli/settings/__init__.py b/samcli/settings/__init__.py new file mode 100644 index 0000000000..266d363ff6 --- /dev/null +++ b/samcli/settings/__init__.py @@ -0,0 +1,24 @@ +# flake8: noqa +""" +Default Settings used by the CLI. + +We will checkin the development.py file into source control. So by default only the dev configs +will be available. When preparing the CLI for production release, the release process will inject +production.py file into this folder and remove development.py. When customers install SAM CLI from +PyPi or any other official installation mechanism, they will get the production settings. + +Ensure the configuration variables defined in production.py and development.py have exact same names. + + +Following variables are exported by this module: + + ``telemetry_endpoint_url``: string URL where Telemetry data should be published to + +""" + +import os + +if "__SAM_CLI_TELEMETRY_ENDPOINT_URL" not in os.environ: + telemetry_endpoint_url = "https://aws-serverless-tools-telemetry.us-west-2.amazonaws.com/metrics" +else: + telemetry_endpoint_url = os.getenv("__SAM_CLI_TELEMETRY_ENDPOINT_URL") diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000000..9c4c78ad8d --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,6 @@ + +import os + + +if "__SAM_CLI_TELEMETRY_ENDPOINT_URL" not in os.environ: + os.environ["__SAM_CLI_TELEMETRY_ENDPOINT_URL"] = "" diff --git a/tests/functional/commands/cli/__init__.py b/tests/functional/commands/cli/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/functional/commands/cli/test_global_config.py b/tests/functional/commands/cli/test_global_config.py new file mode 100644 index 0000000000..78b84c73b1 --- /dev/null +++ b/tests/functional/commands/cli/test_global_config.py @@ -0,0 +1,152 @@ +import json +import tempfile +import shutil + +from mock import mock_open, patch +from unittest import TestCase +from json import JSONDecodeError +from samcli.cli.global_config import GlobalConfig + +try: + from pathlib import Path +except ImportError: + from pathlib2 import Path + + +class TestGlobalConfig(TestCase): + + def setUp(self): + self._cfg_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self._cfg_dir) + + def test_installation_id_with_side_effect(self): + gc = GlobalConfig(config_dir=self._cfg_dir) + installation_id = gc.installation_id + expected_path = Path(self._cfg_dir, "metadata.json") + json_body = json.loads(expected_path.read_text()) + self.assertIsNotNone(installation_id) + self.assertTrue(expected_path.exists()) + self.assertEquals(installation_id, json_body["installationId"]) + installation_id_refetch = gc.installation_id + self.assertEquals(installation_id, installation_id_refetch) + + def test_installation_id_on_existing_file(self): + path = Path(self._cfg_dir, "metadata.json") + with open(str(path), 'w') as f: + cfg = {"foo": "bar"} + f.write(json.dumps(cfg, indent=4) + "\n") + gc = GlobalConfig(config_dir=self._cfg_dir) + installation_id = gc.installation_id + json_body = json.loads(path.read_text()) + self.assertEquals(installation_id, json_body["installationId"]) + self.assertEquals("bar", json_body["foo"]) + + def test_installation_id_exists(self): + path = Path(self._cfg_dir, "metadata.json") + with open(str(path), 'w') as f: + cfg = {"installationId": "stub-uuid"} + f.write(json.dumps(cfg, indent=4) + "\n") + gc = GlobalConfig(config_dir=self._cfg_dir) + installation_id = gc.installation_id + self.assertEquals("stub-uuid", installation_id) + + def test_init_override(self): + gc = GlobalConfig(installation_id="foo") + installation_id = gc.installation_id + self.assertEquals("foo", installation_id) + + def test_invalid_json(self): + path = Path(self._cfg_dir, "metadata.json") + with open(str(path), 'w') as f: + f.write("NOT JSON, PROBABLY VALID YAML AM I RIGHT!?") + gc = GlobalConfig(config_dir=self._cfg_dir) + self.assertIsNone(gc.installation_id) + self.assertFalse(gc.telemetry_enabled) + + def test_telemetry_flag_provided(self): + gc = GlobalConfig(telemetry_enabled=True) + self.assertTrue(gc.telemetry_enabled) + + def test_telemetry_flag_from_cfg(self): + path = Path(self._cfg_dir, "metadata.json") + with open(str(path), 'w') as f: + cfg = {"telemetryEnabled": True} + f.write(json.dumps(cfg, indent=4) + "\n") + gc = GlobalConfig(config_dir=self._cfg_dir) + self.assertTrue(gc.telemetry_enabled) + + def test_telemetry_flag_no_file(self): + gc = GlobalConfig(config_dir=self._cfg_dir) + self.assertFalse(gc.telemetry_enabled) + + def test_telemetry_flag_not_in_cfg(self): + path = Path(self._cfg_dir, "metadata.json") + with open(str(path), 'w') as f: + cfg = {"installationId": "stub-uuid"} + f.write(json.dumps(cfg, indent=4) + "\n") + gc = GlobalConfig(config_dir=self._cfg_dir) + self.assertFalse(gc.telemetry_enabled) + + def test_set_telemetry_flag_no_file(self): + path = Path(self._cfg_dir, "metadata.json") + gc = GlobalConfig(config_dir=self._cfg_dir) + self.assertIsNone(gc.telemetry_enabled) # pre-state test + gc.telemetry_enabled = True + from_gc = gc.telemetry_enabled + json_body = json.loads(path.read_text()) + from_file = json_body["telemetryEnabled"] + self.assertTrue(from_gc) + self.assertTrue(from_file) + + def test_set_telemetry_flag_no_key(self): + path = Path(self._cfg_dir, "metadata.json") + with open(str(path), 'w') as f: + cfg = {"installationId": "stub-uuid"} + f.write(json.dumps(cfg, indent=4) + "\n") + gc = GlobalConfig(config_dir=self._cfg_dir) + gc.telemetry_enabled = True + json_body = json.loads(path.read_text()) + self.assertTrue(gc.telemetry_enabled) + self.assertTrue(json_body["telemetryEnabled"]) + + def test_set_telemetry_flag_overwrite(self): + path = Path(self._cfg_dir, "metadata.json") + with open(str(path), 'w') as f: + cfg = {"telemetryEnabled": True} + f.write(json.dumps(cfg, indent=4) + "\n") + gc = GlobalConfig(config_dir=self._cfg_dir) + self.assertTrue(gc.telemetry_enabled) + gc.telemetry_enabled = False + json_body = json.loads(path.read_text()) + self.assertFalse(gc.telemetry_enabled) + self.assertFalse(json_body["telemetryEnabled"]) + + def test_telemetry_flag_explicit_false(self): + path = Path(self._cfg_dir, "metadata.json") + with open(str(path), 'w') as f: + cfg = {"telemetryEnabled": True} + f.write(json.dumps(cfg, indent=4) + "\n") + gc = GlobalConfig(config_dir=self._cfg_dir, telemetry_enabled=False) + self.assertFalse(gc.telemetry_enabled) + + def test_setter_raises_on_invalid_json(self): + path = Path(self._cfg_dir, "metadata.json") + with open(str(path), 'w') as f: + f.write("NOT JSON, PROBABLY VALID YAML AM I RIGHT!?") + gc = GlobalConfig(config_dir=self._cfg_dir) + with self.assertRaises(JSONDecodeError): + gc.telemetry_enabled = True + + def test_setter_cannot_open_file(self): + path = Path(self._cfg_dir, "metadata.json") + with open(str(path), 'w') as f: + cfg = {"telemetryEnabled": True} + f.write(json.dumps(cfg, indent=4) + "\n") + m = mock_open() + m.side_effect = IOError("fail") + gc = GlobalConfig(config_dir=self._cfg_dir) + with patch('samcli.cli.global_config.open', m): + with self.assertRaises(IOError): + gc.telemetry_enabled = True diff --git a/tests/functional/commands/cli/test_main.py b/tests/functional/commands/cli/test_main.py new file mode 100644 index 0000000000..3abbd25f75 --- /dev/null +++ b/tests/functional/commands/cli/test_main.py @@ -0,0 +1,34 @@ +import mock +import tempfile +import shutil + +from unittest import TestCase +from click.testing import CliRunner +from samcli.cli.main import cli +from samcli.cli.global_config import GlobalConfig + + +class TestTelemetryPrompt(TestCase): + + def setUp(self): + self._cfg_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self._cfg_dir) + + def test_cli_prompt(self): + gc = GlobalConfig(config_dir=self._cfg_dir) + with mock.patch('samcli.cli.main.global_cfg', gc): + self.assertIsNone(gc.telemetry_enabled) # pre-state test + runner = CliRunner() + runner.invoke(cli, ["local", "generate-event", "s3"]) + # assertFalse is not appropriate, because None would also count + self.assertEqual(False, gc.telemetry_enabled) + + def test_cli_prompt_false(self): + gc = GlobalConfig(config_dir=self._cfg_dir) + with mock.patch('samcli.cli.main.global_cfg', gc): + self.assertIsNone(gc.telemetry_enabled) # pre-state test + runner = CliRunner() + runner.invoke(cli, ["local", "generate-event", "s3"], input="Y") + self.assertEqual(True, gc.telemetry_enabled) diff --git a/tests/integration/telemetry/__init__.py b/tests/integration/telemetry/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integration/telemetry/integ_base.py b/tests/integration/telemetry/integ_base.py new file mode 100644 index 0000000000..c7098eb18d --- /dev/null +++ b/tests/integration/telemetry/integ_base.py @@ -0,0 +1,196 @@ +import os +import shutil +import tempfile +import logging +import subprocess +import timeit +import time +import requests +import re + +from flask import Flask, request, Response +from threading import Thread +from collections import deque +from unittest import TestCase + +try: + from pathlib import Path +except ImportError: + from pathlib2 import Path + +from samcli.cli.global_config import GlobalConfig +from samcli.cli.main import TELEMETRY_PROMPT + + +LOG = logging.getLogger(__name__) +TELEMETRY_ENDPOINT_PORT = "18298" +TELEMETRY_ENDPOINT_HOST = "localhost" +TELEMETRY_ENDPOINT_URL = "http://{}:{}".format(TELEMETRY_ENDPOINT_HOST, TELEMETRY_ENDPOINT_PORT) + +# Convert line separators to work with Windows \r\n +EXPECTED_TELEMETRY_PROMPT = re.sub(r'\n', os.linesep, TELEMETRY_PROMPT) + + +class IntegBase(TestCase): + + @classmethod + def setUpClass(cls): + cls.cmd = cls.base_command() + + def setUp(self): + self.maxDiff = None # Show full JSON Diff + + self.config_dir = tempfile.mkdtemp() + self._gc = GlobalConfig(config_dir=self.config_dir) + + def tearDown(self): + self.config_dir and shutil.rmtree(self.config_dir) + + @classmethod + def base_command(cls): + command = "sam" + if os.getenv("SAM_CLI_DEV"): + command = "samdev" + + return command + + def run_cmd(self, stdin_data=""): + # Any command will work for this test suite + cmd_list = [self.cmd, "local", "generate-event", "s3", "put"] + + env = os.environ.copy() + + # remove the envvar which usually is set in Travis. This interferes with tests. + env.pop("SAM_CLI_TELEMETRY", None) + + env["__SAM_CLI_APP_DIR"] = self.config_dir + env["__SAM_CLI_TELEMETRY_ENDPOINT_URL"] = "{}/metrics".format(TELEMETRY_ENDPOINT_URL) + + process = subprocess.Popen(cmd_list, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, + env=env) + return process + + def unset_config(self): + config_file = Path(self.config_dir, "metadata.json") + if config_file.exists(): + config_file.unlink() + + def set_config(self, telemetry_enabled=None): + self._gc.telemetry_enabled = telemetry_enabled + + def get_global_config(self): + return self._gc + + @staticmethod + def wait_for_process_terminate(process, timeout_seconds=5): + """ + This is needed because Python2's wait() method does *not* have a timeout + + Returns + ------- + Return code if the process exited within the timout. None, if process is still executing + """ + + start = timeit.default_timer() + retcode = None + + while (timeit.default_timer() - start) < timeout_seconds: + retcode = process.poll() + + if retcode is not None: + # Process exited + break + + time.sleep(0.1) # 100ms + + return retcode + + +class TelemetryServer(Thread): + """ + HTTP Server that can receive and store Telemetry requests. Caller can later retrieve the responses for + assertion + + Examples + -------- + >>> with TelemetryServer() as server: + >>> # Server is running now + >>> # Set the Telemetry backend endpoint to the server's URL + >>> env = os.environ.copy().setdefault("__SAM_CLI_TELEMETRY_ENDPOINT_URL", server.url) + >>> # Run SAM CLI command + >>> p = subprocess.Popen(["samdev", "local", "generate-event", "s3", "put"], env=env) + >>> p.wait() # Wait for process to complete + >>> # Get the first metrics request that was sent + >>> r = server.get_request(0) + >>> assert r.method == 'POST' + >>> assert r.body == "{...}" + """ + + def __init__(self): + super(TelemetryServer, self).__init__() + + self.flask_app = Flask(__name__) + + self.flask_app.add_url_rule("/metrics", + endpoint="/metrics", + view_func=self._request_handler, + methods=["POST"], + provide_automatic_options=False) + + self.flask_app.add_url_rule("/_shutdown", + endpoint="/_shutdown", + view_func=self._shutdown_flask, + methods=["GET"]) + + # Thread-safe data structure to record requests sent to the server + self._requests = deque() + + def run(self): + """ + Method that runs when thread starts. This starts up Flask server as well + """ + # os.environ['WERKZEUG_RUN_MAIN'] = 'true' + self.flask_app.run(port=TELEMETRY_ENDPOINT_PORT, host=TELEMETRY_ENDPOINT_HOST, threaded=True) + + def __enter__(self): + self.daemon = True # When test completes, this thread will die automatically + self.start() # Start the thread + + return self + + def __exit__(self, *args, **kwargs): + shutdown_endpoint = "{}/_shutdown".format(TELEMETRY_ENDPOINT_URL) + requests.get(shutdown_endpoint) + + # Flask will start shutting down only *after* the above request completes. + # Just give the server a little bit of time to teardown finish + time.sleep(2) + + def get_request(self, index): + return self._requests[index] + + def get_all_requests(self): + return list(self._requests) + + def _request_handler(self, **kwargs): + """ + Handles Flask requests + """ + + # `request` is a variable populated by Flask automatically when handler method is called + request_data = { + "endpoint": request.endpoint, + "method": request.method, + "data": request.get_json(), + "headers": dict(request.headers) + } + + self._requests.append(request_data) + + return Response(response={}, status=200) + + def _shutdown_flask(self): + # Based on http://flask.pocoo.org/snippets/67/ + request.environ.get('werkzeug.server.shutdown')() + print('Server shutting down...') + return '' diff --git a/tests/integration/telemetry/test_installed_metric.py b/tests/integration/telemetry/test_installed_metric.py new file mode 100644 index 0000000000..b25fb1b19e --- /dev/null +++ b/tests/integration/telemetry/test_installed_metric.py @@ -0,0 +1,117 @@ +import platform + +from mock import ANY +from .integ_base import IntegBase, TelemetryServer, EXPECTED_TELEMETRY_PROMPT +from samcli import __version__ as SAM_CLI_VERSION + + +class TestSendInstalledMetric(IntegBase): + + def test_send_installed_metric_on_first_run(self): + """ + On the first run, send the installed metric + """ + self.unset_config() + + with TelemetryServer() as server: + # Start the CLI + process = self.run_cmd() + + (_, stderrdata) = process.communicate() + + retcode = process.poll() + self.assertEquals(retcode, 0, "Command should successfully complete") + + # Make sure the prompt was printed. Otherwise this test is not valid + self.assertIn(EXPECTED_TELEMETRY_PROMPT, stderrdata.decode()) + + all_requests = server.get_all_requests() + self.assertEquals(2, len(all_requests), "There should be exactly two metrics request") + + # First one is usually the installed metric + requests = filter_installed_metric_requests(all_requests) + self.assertEquals(1, len(requests), "There should be only one 'installed' metric") + request = requests[0] + self.assertIn("Content-Type", request["headers"]) + self.assertEquals(request["headers"]["Content-Type"], "application/json") + + expected_data = { + "metrics": [{ + "installed": { + "installationId": self.get_global_config().installation_id, + "samcliVersion": SAM_CLI_VERSION, + "osPlatform": platform.system(), + + "executionEnvironment": ANY, + "pyversion": ANY, + "sessionId": ANY, + "requestId": ANY, + "telemetryEnabled": True + } + }] + } + + self.assertEquals(request["data"], expected_data) + + def test_must_not_send_installed_metric_when_prompt_is_disabled(self): + """ + If the Telemetry Prompt is not displayed, we must *not* send installed metric, even if Telemetry is enabled. + This happens on all subsequent runs. + """ + + # Enable Telemetry. This will skip the Telemetry Prompt. + self.set_config(telemetry_enabled=True) + + with TelemetryServer() as server: + # Start the CLI + process = self.run_cmd() + + (stdoutdata, stderrdata) = process.communicate() + + retcode = process.poll() + self.assertEquals(retcode, 0, "Command should successfully complete") + self.assertNotIn(EXPECTED_TELEMETRY_PROMPT, stdoutdata.decode()) + self.assertNotIn(EXPECTED_TELEMETRY_PROMPT, stderrdata.decode()) + + requests = filter_installed_metric_requests(server.get_all_requests()) + self.assertEquals(0, len(requests), "'installed' metric should NOT be sent") + + def test_must_not_send_installed_metric_on_second_run(self): + """ + On first run, send installed metric. On second run, must *not* send installed metric + """ + + # Unset config to show the prompt + self.unset_config() + + with TelemetryServer() as server: + + # First Run + process1 = self.run_cmd() + (_, stderrdata) = process1.communicate() + retcode = process1.poll() + self.assertEquals(retcode, 0, "Command should successfully complete") + self.assertIn(EXPECTED_TELEMETRY_PROMPT, stderrdata.decode()) + self.assertEquals(1, len(filter_installed_metric_requests(server.get_all_requests())), + "'installed' metric should be sent") + + # Second Run + process2 = self.run_cmd() + (stdoutdata, stderrdata) = process2.communicate() + retcode = process2.poll() + self.assertEquals(retcode, 0) + self.assertNotIn(EXPECTED_TELEMETRY_PROMPT, stdoutdata.decode()) + self.assertNotIn(EXPECTED_TELEMETRY_PROMPT, stderrdata.decode()) + self.assertEquals(1, len(filter_installed_metric_requests(server.get_all_requests())), + "Only one 'installed' metric should be sent") + + +def filter_installed_metric_requests(all_requests): + + result = [] + for r in all_requests: + data = r["data"] + if "metrics" in data and data["metrics"] and "installed" in data["metrics"][0]: + result.append(r) + + return result diff --git a/tests/integration/telemetry/test_prompt.py b/tests/integration/telemetry/test_prompt.py new file mode 100644 index 0000000000..8dedf79473 --- /dev/null +++ b/tests/integration/telemetry/test_prompt.py @@ -0,0 +1,53 @@ + +from parameterized import parameterized +from .integ_base import IntegBase, EXPECTED_TELEMETRY_PROMPT + + +class TestTelemetryPrompt(IntegBase): + + def test_must_prompt_if_config_is_not_set(self): + """ + Must print prompt if Telemetry config is not set. + """ + self.unset_config() + + process = self.run_cmd() + (stdoutdata, stderrdata) = process.communicate() + + # Telemetry prompt should be printed to the terminal + self.assertIn(EXPECTED_TELEMETRY_PROMPT, stderrdata.decode()) + + @parameterized.expand([ + (True, "Enable Telemetry"), + (False, "Disalbe Telemetry") + ]) + def test_must_not_prompt_if_config_is_set(self, telemetry_enabled, msg): + """ + If telemetry config is already set, prompt must not be displayed + """ + + # Set the telemetry config + self.set_config(telemetry_enabled=telemetry_enabled) + + process = self.run_cmd() + (stdoutdata, stderrdata) = process.communicate() + + self.assertNotIn(EXPECTED_TELEMETRY_PROMPT, stdoutdata.decode()) + self.assertNotIn(EXPECTED_TELEMETRY_PROMPT, stderrdata.decode()) + + def test_prompt_must_not_display_on_second_run(self): + """ + On first run, display the prompt. Do *not* display prompt on subsequent runs. + """ + self.unset_config() + + # First Run + process = self.run_cmd() + (stdoutdata, stderrdata) = process.communicate() + self.assertIn(EXPECTED_TELEMETRY_PROMPT, stderrdata.decode()) + + # Second Run + process = self.run_cmd() + (stdoutdata, stderrdata) = process.communicate() + self.assertNotIn(EXPECTED_TELEMETRY_PROMPT, stdoutdata.decode()) + self.assertNotIn(EXPECTED_TELEMETRY_PROMPT, stderrdata.decode()) diff --git a/tests/unit/cli/test_context.py b/tests/unit/cli/test_context.py index 6e65fa96d1..d834d187d2 100644 --- a/tests/unit/cli/test_context.py +++ b/tests/unit/cli/test_context.py @@ -58,3 +58,39 @@ def test_must_set_all_aws_session_properties(self, boto_mock): ctx.profile = profile ctx.region = region boto_mock.setup_default_session.assert_called_with(region_name=region, profile_name=profile) + + @patch("samcli.cli.context.uuid") + def test_must_set_session_id_to_uuid(self, uuid_mock): + uuid_mock.uuid4.return_value = "abcd" + ctx = Context() + + self.assertEquals(ctx.session_id, "abcd") + + @patch("samcli.cli.context.click") + def test_must_find_context(self, click_mock): + + ctx = Context() + result = ctx.get_current_context() + + self.assertEquals(click_mock.get_current_context.return_value.find_object.return_value, result) + click_mock.get_current_context.return_value.find_object.assert_called_once_with(Context) + + @patch("samcli.cli.context.click") + def test_create_new_context_if_not_found(self, click_mock): + + # Context can't be found + click_mock.get_current_context.return_value.find_object.return_value = None + + ctx = Context() + result = ctx.get_current_context() + + self.assertEquals(click_mock.get_current_context.return_value.ensure_object.return_value, result) + click_mock.get_current_context.return_value.ensure_object.assert_called_once_with(Context) + + @patch("samcli.cli.context.click") + def test_get_current_context_from_outside_of_click(self, click_mock): + click_mock.get_current_context.return_value = None + ctx = Context() + + # Context can't be found + self.assertIsNone(ctx.get_current_context()) diff --git a/tests/unit/cli/test_global_config.py b/tests/unit/cli/test_global_config.py new file mode 100644 index 0000000000..48a0193221 --- /dev/null +++ b/tests/unit/cli/test_global_config.py @@ -0,0 +1,118 @@ +from mock import mock_open, patch, Mock +from unittest import TestCase +from parameterized import parameterized +from samcli.cli.global_config import GlobalConfig + +try: + from pathlib import Path +except ImportError: + from pathlib2 import Path + + +class TestGlobalConfig(TestCase): + + def test_config_write_error(self): + m = mock_open() + m.side_effect = IOError("fail") + gc = GlobalConfig() + with patch('samcli.cli.global_config.open', m): + installation_id = gc.installation_id + self.assertIsNone(installation_id) + + def test_setter_cannot_open_path(self): + m = mock_open() + m.side_effect = IOError("fail") + gc = GlobalConfig() + with patch('samcli.cli.global_config.open', m): + with self.assertRaises(IOError): + gc.telemetry_enabled = True + + @patch('samcli.cli.global_config.click') + def test_config_dir_default(self, mock_click): + mock_click.get_app_dir.return_value = "mock/folders" + gc = GlobalConfig() + self.assertEqual(Path("mock/folders"), gc.config_dir) + mock_click.get_app_dir.assert_called_once_with('AWS SAM', force_posix=True) + + def test_explicit_installation_id(self): + gc = GlobalConfig(installation_id="foobar") + self.assertEqual("foobar", gc.installation_id) + + @patch('samcli.cli.global_config.uuid') + @patch('samcli.cli.global_config.Path') + @patch('samcli.cli.global_config.click') + def test_setting_installation_id(self, mock_click, mock_path, mock_uuid): + gc = GlobalConfig() + mock_uuid.uuid4.return_value = "SevenLayerDipMock" + path_mock = Mock() + joinpath_mock = Mock() + joinpath_mock.exists.return_value = False + path_mock.joinpath.return_value = joinpath_mock + mock_path.return_value = path_mock + mock_click.get_app_dir.return_value = "mock/folders" + mock_io = mock_open(Mock()) + with patch("samcli.cli.global_config.open", mock_io): + self.assertEquals("SevenLayerDipMock", gc.installation_id) + + def test_explicit_telemetry_enabled(self): + gc = GlobalConfig(telemetry_enabled=True) + self.assertTrue(gc.telemetry_enabled) + + @patch('samcli.cli.global_config.Path') + @patch('samcli.cli.global_config.click') + @patch('samcli.cli.global_config.os') + def test_missing_telemetry_flag(self, mock_os, mock_click, mock_path): + gc = GlobalConfig() + mock_click.get_app_dir.return_value = "mock/folders" + path_mock = Mock() + joinpath_mock = Mock() + joinpath_mock.exists.return_value = False + path_mock.joinpath.return_value = joinpath_mock + mock_path.return_value = path_mock + mock_os.environ = {} # env var is not set + self.assertIsNone(gc.telemetry_enabled) + + @patch('samcli.cli.global_config.Path') + @patch('samcli.cli.global_config.click') + @patch('samcli.cli.global_config.os') + def test_error_reading_telemetry_flag(self, mock_os, mock_click, mock_path): + gc = GlobalConfig() + mock_click.get_app_dir.return_value = "mock/folders" + path_mock = Mock() + joinpath_mock = Mock() + joinpath_mock.exists.return_value = True + path_mock.joinpath.return_value = joinpath_mock + mock_path.return_value = path_mock + mock_os.environ = {} # env var is not set + + m = mock_open() + m.side_effect = IOError("fail") + with patch('samcli.cli.global_config.open', m): + self.assertFalse(gc.telemetry_enabled) + + @parameterized.expand([ + # Only values of '1' and 1 will enable Telemetry. Everything will disable. + (1, True), + ('1', True), + + (0, False), + ('0', False), + # words true, True, False, False etc will disable telemetry + ('true', False), + ('True', False), + ('False', False) + ]) + @patch('samcli.cli.global_config.os') + @patch('samcli.cli.global_config.click') + def test_set_telemetry_through_env_variable(self, env_value, expected_result, mock_click, mock_os): + gc = GlobalConfig() + + mock_os.environ = {"SAM_CLI_TELEMETRY": env_value} + mock_os.getenv.return_value = env_value + + self.assertEquals(gc.telemetry_enabled, expected_result) + + mock_os.getenv.assert_called_once_with("SAM_CLI_TELEMETRY") + + # When environment variable is set, we shouldn't be reading the real config file at all. + mock_click.get_app_dir.assert_not_called() diff --git a/tests/unit/cli/test_main.py b/tests/unit/cli/test_main.py index 0e3568fe39..827960a0f7 100644 --- a/tests/unit/cli/test_main.py +++ b/tests/unit/cli/test_main.py @@ -1,3 +1,5 @@ +import mock + from unittest import TestCase from click.testing import CliRunner from samcli.cli.main import cli @@ -10,20 +12,52 @@ def test_cli_base(self): Just invoke the CLI without any commands and assert that help text was printed :return: """ - runner = CliRunner() - result = runner.invoke(cli, []) - self.assertEquals(result.exit_code, 0) - self.assertTrue("--help" in result.output, "Help text must be printed") - self.assertTrue("--debug" in result.output, "--debug option must be present in help text") + mock_cfg = mock.Mock() + with mock.patch('samcli.cli.main.global_cfg', mock_cfg): + runner = CliRunner() + result = runner.invoke(cli, []) + self.assertEquals(result.exit_code, 0) + self.assertTrue("--help" in result.output, "Help text must be printed") + self.assertTrue("--debug" in result.output, "--debug option must be present in help text") def test_cli_some_command(self): - runner = CliRunner() - result = runner.invoke(cli, ["local", "generate-event", "s3"]) - self.assertEquals(result.exit_code, 0) + mock_cfg = mock.Mock() + with mock.patch('samcli.cli.main.global_cfg', mock_cfg): + runner = CliRunner() + result = runner.invoke(cli, ["local", "generate-event", "s3"]) + self.assertEquals(result.exit_code, 0) def test_cli_with_debug(self): - runner = CliRunner() - result = runner.invoke(cli, ["local", "generate-event", "s3", "put", "--debug"]) - self.assertEquals(result.exit_code, 0) + mock_cfg = mock.Mock() + with mock.patch('samcli.cli.main.global_cfg', mock_cfg): + runner = CliRunner() + result = runner.invoke(cli, ["local", "generate-event", "s3", "put", "--debug"]) + self.assertEquals(result.exit_code, 0) + + @mock.patch('samcli.cli.main.send_installed_metric') + def test_cli_enable_telemetry_with_prompt(self, send_installed_metric_mock): + with mock.patch( + 'samcli.cli.global_config.GlobalConfig.telemetry_enabled', new_callable=mock.PropertyMock + ) as mock_flag: + mock_flag.return_value = None + runner = CliRunner() + runner.invoke(cli, ["local", "generate-event", "s3"]) + mock_flag.assert_called_with(True) + + # If telemetry is enabled, this should be called + send_installed_metric_mock.assert_called_once() + + @mock.patch('samcli.cli.main.send_installed_metric') + def test_prompt_skipped_when_value_set(self, send_installed_metric_mock): + with mock.patch( + 'samcli.cli.global_config.GlobalConfig.telemetry_enabled', new_callable=mock.PropertyMock + ) as mock_flag: + mock_flag.return_value = True + runner = CliRunner() + runner.invoke(cli, ["local", "generate-event", "s3"]) + mock_flag.assert_called_once_with() + + # If prompt is skipped, this should be NOT called + send_installed_metric_mock.assert_not_called() diff --git a/tests/unit/commands/local/generate_event/test_event_generation.py b/tests/unit/commands/local/generate_event/test_event_generation.py index a06074a218..a017e1ba12 100644 --- a/tests/unit/commands/local/generate_event/test_event_generation.py +++ b/tests/unit/commands/local/generate_event/test_event_generation.py @@ -1,3 +1,5 @@ +import os + from unittest import TestCase from mock import Mock from mock import patch @@ -86,6 +88,13 @@ def setUp(self): self.events_lib_mock = Mock() self.s = EventTypeSubCommand(self.events_lib_mock, self.service_cmd_name, self.all_cmds) + # Disable telemetry + self.old_environ = os.environ.copy() + os.environ["SAM_CLI_TELEMETRY"] = 0 + + def tearDown(self): + os.environ = self.old_environ + def test_subcommand_accepts_events_lib(self): events_lib = Mock() events_lib.expose_event_metadata.return_value = self.all_cmds diff --git a/tests/unit/lib/telemetry/test_metrics.py b/tests/unit/lib/telemetry/test_metrics.py new file mode 100644 index 0000000000..36365bf8ff --- /dev/null +++ b/tests/unit/lib/telemetry/test_metrics.py @@ -0,0 +1,234 @@ +import platform +import time + +from unittest import TestCase +from mock import patch, Mock, ANY, call + +from samcli.lib.telemetry.metrics import send_installed_metric, track_command +from samcli.commands.exceptions import UserException + + +class TestSendInstalledMetric(TestCase): + + def setUp(self): + self.gc_mock = Mock() + self.global_config_patcher = patch("samcli.lib.telemetry.metrics.GlobalConfig", self.gc_mock) + self.global_config_patcher.start() + + def tearDown(self): + self.global_config_patcher.stop() + + @patch("samcli.lib.telemetry.metrics.Telemetry") + def test_must_send_installed_metric_with_attributes(self, TelemetryClassMock): + telemetry_mock = TelemetryClassMock.return_value = Mock() + + self.gc_mock.return_value.telemetry_enabled = False + send_installed_metric() + + telemetry_mock.emit.assert_called_with("installed", { + "osPlatform": platform.system(), + "telemetryEnabled": False + }) + + +class TestTrackCommand(TestCase): + + def setUp(self): + TelemetryClassMock = Mock() + GlobalConfigClassMock = Mock() + self.telemetry_instance = TelemetryClassMock.return_value = Mock() + self.gc_instance_mock = GlobalConfigClassMock.return_value = Mock() + + self.telemetry_class_patcher = patch("samcli.lib.telemetry.metrics.Telemetry", TelemetryClassMock) + self.gc_patcher = patch("samcli.lib.telemetry.metrics.GlobalConfig", GlobalConfigClassMock) + self.telemetry_class_patcher.start() + self.gc_patcher.start() + + self.context_mock = Mock() + self.context_mock.profile = False + self.context_mock.debug = False + self.context_mock.region = "myregion" + self.context_mock.command_path = "fakesam local invoke" + + # Enable telemetry so we can actually run the tests + self.gc_instance_mock.telemetry_enabled = True + + def tearDown(self): + self.telemetry_class_patcher.stop() + self.gc_patcher.stop() + + @patch("samcli.lib.telemetry.metrics.Context") + def test_must_emit_one_metric(self, ContextMock): + ContextMock.get_current_context.return_value = self.context_mock + + def real_fn(): + pass + + track_command(real_fn)() + + self.assertEquals(self.telemetry_instance.emit.mock_calls, [ + call("commandRun", ANY), + ], "The one command metric must be sent") + + @patch("samcli.lib.telemetry.metrics.Context") + def test_must_emit_command_run_metric(self, ContextMock): + ContextMock.get_current_context.return_value = self.context_mock + + def real_fn(): + pass + + track_command(real_fn)() + + expected_attrs = { + "awsProfileProvided": False, + "debugFlagProvided": False, + "region": "myregion", + "commandName": "fakesam local invoke", + + "duration": ANY, + "exitReason": "success", + "exitCode": 0 + } + self.telemetry_instance.emit.assert_has_calls([ + call("commandRun", expected_attrs) + ]) + + @patch("samcli.lib.telemetry.metrics.Context") + def test_must_emit_command_run_metric_with_sanitized_profile_value(self, ContextMock): + ContextMock.get_current_context.return_value = self.context_mock + self.context_mock.profile = "myprofilename" + + def real_fn(): + pass + + track_command(real_fn)() + + expected_attrs = _cmd_run_attrs({ + "awsProfileProvided": True + }) + self.telemetry_instance.emit.assert_has_calls([ + call("commandRun", expected_attrs) + ]) + + @patch("samcli.lib.telemetry.metrics.Context") + def test_must_record_function_duration(self, ContextMock): + ContextMock.get_current_context.return_value = self.context_mock + sleep_duration = 0.001 # 1 millisecond + + def real_fn(): + time.sleep(sleep_duration) + + track_command(real_fn)() + + # commandRun metric should be the only call to emit. + # And grab the second argument passed to this call, which are the attributes + args, kwargs = self.telemetry_instance.emit.call_args_list[0] + metric_name, actual_attrs = args + self.assertEquals("commandRun", metric_name) + self.assertGreater(actual_attrs["duration"], + sleep_duration, + "Measured duration must be in milliseconds and greater than the sleep duration") + + @patch("samcli.lib.telemetry.metrics.Context") + def test_must_record_user_exception(self, ContextMock): + ContextMock.get_current_context.return_value = self.context_mock + expected_exception = UserException("Something went wrong") + expected_exception.exit_code = 1235 + + def real_fn(): + raise expected_exception + + with self.assertRaises(UserException) as context: + track_command(real_fn)() + self.assertEquals(context.exception, expected_exception, "Must re-raise the original exception object " + "without modification") + + expected_attrs = _cmd_run_attrs({ + "exitReason": "UserException", + "exitCode": 1235 + }) + self.telemetry_instance.emit.assert_has_calls([ + call("commandRun", expected_attrs) + ]) + + @patch("samcli.lib.telemetry.metrics.Context") + def test_must_record_any_exceptions(self, ContextMock): + ContextMock.get_current_context.return_value = self.context_mock + expected_exception = KeyError("IO Error test") + + def real_fn(): + raise expected_exception + + with self.assertRaises(KeyError) as context: + track_command(real_fn)() + self.assertEquals(context.exception, expected_exception, "Must re-raise the original exception object " + "without modification") + + expected_attrs = _cmd_run_attrs({ + "exitReason": "KeyError", + "exitCode": 255 # Unhandled exceptions always use exit code 255 + }) + self.telemetry_instance.emit.assert_has_calls([ + call("commandRun", expected_attrs) + ]) + + @patch("samcli.lib.telemetry.metrics.Context") + def test_must_return_value_from_decorated_function(self, ContextMock): + expected_value = "some return value" + + def real_fn(): + return expected_value + + actual = track_command(real_fn)() + self.assertEquals(actual, "some return value") + + @patch("samcli.lib.telemetry.metrics.Context") + def test_must_pass_all_arguments_to_wrapped_function(self, ContextMock): + + def real_fn(*args, **kwargs): + # simply return the arguments to be able to examine & assert + return args, kwargs + + actual_args, actual_kwargs = track_command(real_fn)(1, 2, 3, a=1, b=2, c=3) + self.assertEquals(actual_args, (1, 2, 3)) + self.assertEquals(actual_kwargs, {"a": 1, "b": 2, "c": 3}) + + @patch("samcli.lib.telemetry.metrics.Context") + def test_must_decorate_functions(self, ContextMock): + + @track_command + def real_fn(a, b=None): + return "{} {}".format(a, b) + + actual = real_fn("hello", b="world") + self.assertEquals(actual, "hello world") + + self.assertEquals(self.telemetry_instance.emit.mock_calls, [ + call("commandRun", ANY), + ], "The command metrics be emitted when used as a decorator") + + def test_must_return_immediately_if_telemetry_is_disabled(self): + + def real_fn(): + return "hello" + + # Disable telemetry first + self.gc_instance_mock.telemetry_enabled = False + result = track_command(real_fn)() + + self.assertEquals(result, "hello") + self.telemetry_instance.emit.assert_not_called() + + +def _cmd_run_attrs(data): + common_attrs = ["awsProfileProvided", "debugFlagProvided", "region", "commandName", + "duration", "exitReason", "exitCode"] + return _ignore_other_attrs(data, common_attrs) + + +def _ignore_other_attrs(data, common_attrs): + for a in common_attrs: + if a not in data: + data[a] = ANY + + return data diff --git a/tests/unit/lib/telemetry/test_telemetry.py b/tests/unit/lib/telemetry/test_telemetry.py new file mode 100644 index 0000000000..3777bd1636 --- /dev/null +++ b/tests/unit/lib/telemetry/test_telemetry.py @@ -0,0 +1,153 @@ +import platform +import requests + +from mock import patch, Mock, ANY +from unittest import TestCase + +from samcli.lib.telemetry.telemetry import Telemetry +from samcli import __version__ as samcli_version + + +class TestTelemetry(TestCase): + + def setUp(self): + self.test_session_id = "TestSessionId" + self.test_installation_id = "TestInstallationId" + self.url = "some_test_url" + + self.gc_mock = Mock() + self.context_mock = Mock() + + self.global_config_patcher = patch("samcli.lib.telemetry.telemetry.GlobalConfig", self.gc_mock) + self.context_patcher = patch("samcli.lib.telemetry.telemetry.Context", self.context_mock) + + self.global_config_patcher.start() + self.context_patcher.start() + + self.context_mock.get_current_context.return_value.session_id = self.test_session_id + self.gc_mock.return_value.installation_id = self.test_installation_id + + def tearDown(self): + self.global_config_patcher.stop() + self.context_mock.stop() + + def test_must_raise_on_invalid_session_id(self): + self.context_mock.get_current_context.return_value = None + + with self.assertRaises(RuntimeError): + Telemetry() + + @patch("samcli.lib.telemetry.telemetry.requests") + def test_must_add_metric_with_attributes_to_registry(self, requests_mock): + telemetry = Telemetry(url=self.url) + metric_name = "mymetric" + attrs = {"a": 1, "b": 2} + + telemetry.emit(metric_name, attrs) + + expected = { + "metrics": [{ + metric_name: { + "a": 1, + "b": 2, + "requestId": ANY, + "installationId": self.test_installation_id, + "sessionId": self.test_session_id, + "executionEnvironment": "CLI", + "pyversion": platform.python_version(), + "samcliVersion": samcli_version + } + }] + } + requests_mock.post.assert_called_once_with(ANY, json=expected, timeout=ANY) + + @patch("samcli.lib.telemetry.telemetry.requests") + @patch('samcli.lib.telemetry.telemetry.uuid') + def test_must_add_request_id_as_uuid_v4(self, uuid_mock, requests_mock): + fake_uuid = uuid_mock.uuid4.return_value = "fake uuid" + + telemetry = Telemetry(url=self.url) + telemetry.emit("metric_name", {}) + + expected = { + "metrics": [{ + "metric_name": _ignore_other_attrs({ + "requestId": fake_uuid, + }) + }] + } + requests_mock.post.assert_called_once_with(ANY, json=expected, timeout=ANY) + + @patch("samcli.lib.telemetry.telemetry.requests") + def test_execution_environment_should_be_identified(self, requests_mock): + telemetry = Telemetry(url=self.url) + + telemetry.emit("metric_name", {}) + + expected_execution_environment = "CLI" + + expected = { + "metrics": [{ + "metric_name": _ignore_other_attrs({ + "executionEnvironment": expected_execution_environment + }) + }] + } + requests_mock.post.assert_called_once_with(ANY, json=expected, timeout=ANY) + + @patch("samcli.lib.telemetry.telemetry.requests") + def test_default_request_should_be_fire_and_forget(self, requests_mock): + telemetry = Telemetry(url=self.url) + + telemetry.emit("metric_name", {}) + requests_mock.post.assert_called_once_with(ANY, json=ANY, timeout=(2, 0.001)) # 1ms response timeout + + @patch("samcli.lib.telemetry.telemetry.requests") + def test_request_must_wait_for_2_seconds_for_response(self, requests_mock): + telemetry = Telemetry(url=self.url) + + telemetry._send({}, wait_for_response=True) + requests_mock.post.assert_called_once_with(ANY, json=ANY, timeout=(2, 2)) + + @patch("samcli.lib.telemetry.telemetry.requests") + def test_must_swallow_timeout_exception(self, requests_mock): + telemetry = Telemetry(url=self.url) + + # If we Mock the entire requests library, this statement will run into issues + # `except requests.exceptions.Timeout` + # https://stackoverflow.com/questions/31713054/cant-catch-mocked-exception-because-it-doesnt-inherit-baseexception + # + # Hence we save the original Timeout object to the Mock, so Python won't complain. + # + + requests_mock.exceptions.Timeout = requests.exceptions.Timeout + requests_mock.post.side_effect = requests.exceptions.Timeout() + + telemetry.emit("metric_name", {}) + + @patch("samcli.lib.telemetry.telemetry.requests") + def test_must_raise_on_other_requests_exception(self, requests_mock): + telemetry = Telemetry(url=self.url) + + requests_mock.exceptions.Timeout = requests.exceptions.Timeout + requests_mock.post.side_effect = IOError() + + with self.assertRaises(IOError): + telemetry.emit("metric_name", {}) + + @patch('samcli.lib.telemetry.telemetry.DEFAULT_ENDPOINT_URL') + def test_must_use_default_endpoint_url_if_not_customized(self, default_endpoint_url_mock): + telemetry = Telemetry() + + self.assertEquals(telemetry._url, default_endpoint_url_mock) + + +def _ignore_other_attrs(data): + + common_attrs = ["requestId", "installationId", "sessionId", "executionEnvironment", "pyversion", "samcliVersion"] + + for a in common_attrs: + if a not in data: + data[a] = ANY + + return data From 83a32d4f55a808868e651b24770229df3bc51099 Mon Sep 17 00:00:00 2001 From: Sanath Kumar Ramesh Date: Wed, 24 Jul 2019 17:04:15 -0700 Subject: [PATCH 4/7] chore: Bumping to v0.19.0 and updating Telemetry Opt-Out URL (#1288) --- samcli/__init__.py | 2 +- samcli/cli/main.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/samcli/__init__.py b/samcli/__init__.py index 8232907343..03c7b75894 100644 --- a/samcli/__init__.py +++ b/samcli/__init__.py @@ -2,4 +2,4 @@ SAM CLI version """ -__version__ = '0.18.0' +__version__ = '0.19.0' diff --git a/samcli/cli/main.py b/samcli/cli/main.py index 2b83f75d91..8bc3525cc7 100644 --- a/samcli/cli/main.py +++ b/samcli/cli/main.py @@ -60,8 +60,8 @@ def print_info(ctx, param, value): \tYou can OPT OUT of telemetry by setting the environment variable \tSAM_CLI_TELEMETRY=0 in your shell. -\tLearn More: http://docs.aws.amazon.com/serverless-application-model/latest/developerguide/telemetry-opt-out -""" +\tLearn More: https://docs.aws.amazon.com/serverless-application-model/latest/developerguide/serverless-sam-telemetry.html +""" # noqa @click.command(cls=BaseCommand) From f5a24de53a95490bbfddcdfd13be966b9627f6fa Mon Sep 17 00:00:00 2001 From: Sanath Kumar Ramesh Date: Thu, 25 Jul 2019 09:35:58 -0700 Subject: [PATCH 5/7] fix: More robust connections to telemetry backend (#1289) - Increase read timeout to 100ms. Backend needs a bit more time than 1ms to reliably receive & process data - Don't crash if offline - Integration tests to validate offline behavior & opt-in/opt-out behavior --- samcli/lib/telemetry/telemetry.py | 7 +- tests/integration/telemetry/integ_base.py | 7 +- .../telemetry/test_telemetry_contract.py | 73 +++++++++++++++++++ tests/unit/lib/telemetry/test_telemetry.py | 14 +++- 4 files changed, 95 insertions(+), 6 deletions(-) create mode 100644 tests/integration/telemetry/test_telemetry_contract.py diff --git a/samcli/lib/telemetry/telemetry.py b/samcli/lib/telemetry/telemetry.py index 2daae87513..95a0f70663 100644 --- a/samcli/lib/telemetry/telemetry.py +++ b/samcli/lib/telemetry/telemetry.py @@ -78,7 +78,7 @@ def _send(self, metric, wait_for_response=False): payload = {"metrics": [metric]} LOG.debug("Sending Telemetry: %s", payload) - timeout_ms = 2000 if wait_for_response else 1 # 2 seconds to wait for response or 1ms + timeout_ms = 2000 if wait_for_response else 100 # 2 seconds to wait for response or 100ms timeout = (2, # connection timeout. Always set to 2 seconds timeout_ms / 1000.0 # Read timeout. Tweaked based on input. @@ -86,8 +86,9 @@ def _send(self, metric, wait_for_response=False): try: r = requests.post(self._url, json=payload, timeout=timeout) LOG.debug("Telemetry response: %d", r.status_code) - except requests.exceptions.Timeout as ex: - # Expected if request times out. Just print debug log and ignore the exception. + except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as ex: + # Expected if request times out OR cannot connect to the backend (offline). + # Just print debug log and ignore the exception. LOG.debug(str(ex)) def _add_common_metric_attributes(self, attrs): diff --git a/tests/integration/telemetry/integ_base.py b/tests/integration/telemetry/integ_base.py index c7098eb18d..4acb1b661c 100644 --- a/tests/integration/telemetry/integ_base.py +++ b/tests/integration/telemetry/integ_base.py @@ -54,14 +54,17 @@ def base_command(cls): return command - def run_cmd(self, stdin_data=""): + def run_cmd(self, stdin_data="", optout_envvar_value=None): # Any command will work for this test suite cmd_list = [self.cmd, "local", "generate-event", "s3", "put"] env = os.environ.copy() - # remove the envvar which usually is set in Travis. This interferes with tests. + # remove the envvar which usually is set in Travis. This interferes with tests env.pop("SAM_CLI_TELEMETRY", None) + if optout_envvar_value: + # But if the caller explicitly asked us to opt-out via EnvVar, then set it here + env["SAM_CLI_TELEMETRY"] = optout_envvar_value env["__SAM_CLI_APP_DIR"] = self.config_dir env["__SAM_CLI_TELEMETRY_ENDPOINT_URL"] = "{}/metrics".format(TELEMETRY_ENDPOINT_URL) diff --git a/tests/integration/telemetry/test_telemetry_contract.py b/tests/integration/telemetry/test_telemetry_contract.py new file mode 100644 index 0000000000..f42f01effd --- /dev/null +++ b/tests/integration/telemetry/test_telemetry_contract.py @@ -0,0 +1,73 @@ + +from .integ_base import IntegBase, TelemetryServer + + +class TestTelemetryContract(IntegBase): + """ + Validates the basic tenets/contract Telemetry module needs to adhere to + """ + + def test_must_not_send_metrics_if_disabled_using_envvar(self): + """ + No metrics should be sent if "Enabled via Config file but Disabled via Envvar" + """ + # Enable it via configuration file + self.set_config(telemetry_enabled=True) + + with TelemetryServer() as server: + # Start the CLI, but opt-out of Telemetry using env var + process = self.run_cmd(optout_envvar_value="0") + (_, stderrdata) = process.communicate() + retcode = process.poll() + self.assertEquals(retcode, 0, "Command should successfully complete") + all_requests = server.get_all_requests() + self.assertEquals(0, len(all_requests), "No metrics should be sent") + + # Now run again without the Env Var Opt out + process = self.run_cmd() + (_, stderrdata) = process.communicate() + retcode = process.poll() + self.assertEquals(retcode, 0, "Command should successfully complete") + all_requests = server.get_all_requests() + self.assertEquals(1, len(all_requests), "Command run metric should be sent") + + def test_must_send_metrics_if_enabled_via_envvar(self): + """ + Metrics should be sent if "Disabled via config file but Enabled via Envvar" + """ + # Disable it via configuration file + self.set_config(telemetry_enabled=False) + + with TelemetryServer() as server: + # Run without any envvar.Should not publish metrics + process = self.run_cmd() + (_, stderrdata) = process.communicate() + retcode = process.poll() + self.assertEquals(retcode, 0, "Command should successfully complete") + all_requests = server.get_all_requests() + self.assertEquals(0, len(all_requests), "No metric should be sent") + + # Opt-in via env var + process = self.run_cmd(optout_envvar_value="1") + (_, stderrdata) = process.communicate() + retcode = process.poll() + self.assertEquals(retcode, 0, "Command should successfully complete") + all_requests = server.get_all_requests() + self.assertEquals(1, len(all_requests), "Command run metric must be sent") + + def test_must_not_crash_when_offline(self): + """ + Must not crash the process if internet is not available + """ + self.set_config(telemetry_enabled=True) + + # DO NOT START Telemetry Server here. + # Try to run the command without it. + + # Start the CLI + process = self.run_cmd() + + (_, stderrdata) = process.communicate() + + retcode = process.poll() + self.assertEquals(retcode, 0, "Command should successfully complete") diff --git a/tests/unit/lib/telemetry/test_telemetry.py b/tests/unit/lib/telemetry/test_telemetry.py index 3777bd1636..46c32ed13a 100644 --- a/tests/unit/lib/telemetry/test_telemetry.py +++ b/tests/unit/lib/telemetry/test_telemetry.py @@ -100,7 +100,7 @@ def test_default_request_should_be_fire_and_forget(self, requests_mock): telemetry = Telemetry(url=self.url) telemetry.emit("metric_name", {}) - requests_mock.post.assert_called_once_with(ANY, json=ANY, timeout=(2, 0.001)) # 1ms response timeout + requests_mock.post.assert_called_once_with(ANY, json=ANY, timeout=(2, 0.1)) # 100ms response timeout @patch("samcli.lib.telemetry.telemetry.requests") def test_request_must_wait_for_2_seconds_for_response(self, requests_mock): @@ -121,15 +121,27 @@ def test_must_swallow_timeout_exception(self, requests_mock): # requests_mock.exceptions.Timeout = requests.exceptions.Timeout + requests_mock.exceptions.ConnectionError = requests.exceptions.ConnectionError requests_mock.post.side_effect = requests.exceptions.Timeout() telemetry.emit("metric_name", {}) + @patch("samcli.lib.telemetry.telemetry.requests") + def test_must_swallow_connection_error_exception(self, requests_mock): + telemetry = Telemetry(url=self.url) + + requests_mock.exceptions.Timeout = requests.exceptions.Timeout + requests_mock.exceptions.ConnectionError = requests.exceptions.ConnectionError + requests_mock.post.side_effect = requests.exceptions.ConnectionError() + + telemetry.emit("metric_name", {}) + @patch("samcli.lib.telemetry.telemetry.requests") def test_must_raise_on_other_requests_exception(self, requests_mock): telemetry = Telemetry(url=self.url) requests_mock.exceptions.Timeout = requests.exceptions.Timeout + requests_mock.exceptions.ConnectionError = requests.exceptions.ConnectionError requests_mock.post.side_effect = IOError() with self.assertRaises(IOError): From 3370cbff3401f76a5aa23ef86f2d332e90386d7e Mon Sep 17 00:00:00 2001 From: Sanath Kumar Ramesh Date: Thu, 25 Jul 2019 09:41:54 -0700 Subject: [PATCH 6/7] fix: Update telemetry prompt wording (#1294) --- samcli/cli/main.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/samcli/cli/main.py b/samcli/cli/main.py index 8bc3525cc7..b480297e04 100644 --- a/samcli/cli/main.py +++ b/samcli/cli/main.py @@ -55,10 +55,11 @@ def print_info(ctx, param, value): # Keep the message to 80chars wide to it prints well on most terminals TELEMETRY_PROMPT = """ -\tTelemetry has been enabled for SAM CLI. -\t -\tYou can OPT OUT of telemetry by setting the environment variable -\tSAM_CLI_TELEMETRY=0 in your shell. +\tSAM CLI now collects telemetry to better understand customer needs. + +\tYou can OPT OUT and disable telemetry collection by setting the +\tenvironment variable SAM_CLI_TELEMETRY=0 in your shell. +\tThanks for your help! \tLearn More: https://docs.aws.amazon.com/serverless-application-model/latest/developerguide/serverless-sam-telemetry.html """ # noqa From 67edf91b49c54bd494683684e2ce78b64bb203be Mon Sep 17 00:00:00 2001 From: Sanath Kumar Ramesh Date: Thu, 25 Jul 2019 16:32:56 -0700 Subject: [PATCH 7/7] feat: Set execution environment when calling AWS CLI (#1297) --- samcli/lib/samlib/cloudformation_command.py | 11 ++++- .../lib/samlib/test_cloudformation_command.py | 44 ++++++++++++++++++- 2 files changed, 52 insertions(+), 3 deletions(-) diff --git a/samcli/lib/samlib/cloudformation_command.py b/samcli/lib/samlib/cloudformation_command.py index 93914827ad..c284e05e00 100644 --- a/samcli/lib/samlib/cloudformation_command.py +++ b/samcli/lib/samlib/cloudformation_command.py @@ -2,11 +2,14 @@ Utility to call cloudformation command with args """ +import os import logging import platform import subprocess import sys +from samcli.cli.global_config import GlobalConfig + LOG = logging.getLogger(__name__) @@ -15,12 +18,18 @@ def execute_command(command, args, template_file): try: aws_cmd = find_executable("aws") + # Add SAM CLI information for AWS CLI to know about the caller. + gc = GlobalConfig() + env = os.environ.copy() + if gc.telemetry_enabled: + env["AWS_EXECUTION_ENV"] = "SAM-" + gc.installation_id + args = list(args) if template_file: # Since --template-file was parsed separately, add it here manually args.extend(["--template-file", template_file]) - subprocess.check_call([aws_cmd, 'cloudformation', command] + args) + subprocess.check_call([aws_cmd, 'cloudformation', command] + args, env=env) LOG.debug("%s command successful", command) except subprocess.CalledProcessError as e: # Underlying aws command will print the exception to the user diff --git a/tests/unit/lib/samlib/test_cloudformation_command.py b/tests/unit/lib/samlib/test_cloudformation_command.py index 7aa3248529..f153530e9c 100644 --- a/tests/unit/lib/samlib/test_cloudformation_command.py +++ b/tests/unit/lib/samlib/test_cloudformation_command.py @@ -2,10 +2,11 @@ Tests Deploy CLI """ +import os from subprocess import CalledProcessError, PIPE from unittest import TestCase -from mock import patch, call +from mock import patch, call, ANY from samcli.lib.samlib.cloudformation_command import execute_command, find_executable @@ -24,7 +25,46 @@ def test_must_add_template_file(self, find_executable_mock, check_call_mock): check_call_mock.assert_called_with(["mycmd", "cloudformation", "command"] + ["--arg1", "value1", "different args", "more", - "--template-file", "/path/to/template"]) + "--template-file", "/path/to/template"], env=ANY) + + @patch("subprocess.check_call") + @patch("samcli.lib.samlib.cloudformation_command.find_executable") + @patch("samcli.lib.samlib.cloudformation_command.GlobalConfig") + def test_must_add_sam_cli_info_to_execution_env_var_if_telemetry_is_on(self, global_config_mock, + find_executable_mock, check_call_mock): + installation_id = "testtest" + global_config_mock.return_value.installation_id = installation_id + global_config_mock.return_value.telemetry_enabled = True + + expected_env = os.environ.copy() + expected_env["AWS_EXECUTION_ENV"] = "SAM-" + installation_id + + find_executable_mock.return_value = "mycmd" + check_call_mock.return_value = True + execute_command("command", self.args, "/path/to/template") + + check_call_mock.assert_called() + kwargs = check_call_mock.call_args[1] + self.assertIn("env", kwargs) + self.assertEquals(kwargs["env"], expected_env) + + @patch("subprocess.check_call") + @patch("samcli.lib.samlib.cloudformation_command.find_executable") + @patch("samcli.lib.samlib.cloudformation_command.GlobalConfig") + def test_must_not_set_exec_env(self, global_config_mock, find_executable_mock, check_call_mock): + global_config_mock.return_value.telemetry_enabled = False + + # Expected to pass just a copy of the environment variables without modification + expected_env = os.environ.copy() + + find_executable_mock.return_value = "mycmd" + check_call_mock.return_value = True + execute_command("command", self.args, "/path/to/template") + + check_call_mock.assert_called() + kwargs = check_call_mock.call_args[1] + self.assertIn("env", kwargs) + self.assertEquals(kwargs["env"], expected_env) @patch("sys.exit") @patch("subprocess.check_call")