diff --git a/.gitignore b/.gitignore index fe24b660..23b09927 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ test/*.zip *.pyc *.zip* *.log +.settings diff --git a/Dockerfile b/Dockerfile index 91eba90f..93810b0a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,13 +2,15 @@ FROM python:3.7.0-stretch MAINTAINER @iMerica -RUN apt-get update && apt-get install -y zip +RUN apt-get update && apt-get install -y \ + zip \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* RUN addgroup --system scar && adduser --system --group scar RUN git clone --branch master --depth 1 https://github.com/grycap/scar.git /usr/bin/scar && \ - pip install -r /usr/bin/scar/requirements.txt && \ - pip install pyyaml + pip install -r /usr/bin/scar/requirements.txt RUN touch /scar.log && chown scar /scar.log diff --git a/docs/source/installation.rst b/docs/source/installation.rst index aff72691..e7913eec 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -1,28 +1,23 @@ Installation ============ -1) SCAR requires python3, first make sure you have python3 available in your system +1) SCAR requires python3, pip3 and a configured AWS credentials file in your system. + More info about the AWS credentials file can be found `here `_. 2) Clone the GitHub repository:: git clone https://github.com/grycap/scar.git + cd scar -3) Install the required dependencies: +3) Install the python required dependencies automatically with the command:: - * `zip `_ (linux package) - * `Boto3 `_ (v1.4.4+ is required) - * `Tabulate `_ - * `Requests `_ + sudo pip3 install -r requirements.txt - You can automatically install the python dependencies by issuing the following command:: + The last dependency needs to be installed using the apt manager:: + + sudo apt install zip - sudo pip install -r requirements.txt - The zip package can be installed using apt:: - - sudo apt install zip - - -4) (Optional) Define an alias for increased usability:: +4) (Optional) Define an alias for easier usability:: alias scar='python3 `pwd`/scar.py' diff --git a/examples/cowsay/README.md b/examples/cowsay/README.md index 2f62bed9..0742c144 100644 --- a/examples/cowsay/README.md +++ b/examples/cowsay/README.md @@ -53,13 +53,13 @@ docker save grycap/minicow > minicow.tar.gz 2. Create the Lambda function using the 'scar-minicow.yaml' configuration file: ```sh -scar init -f scar-cowsay.yaml +scar init -f scar-minicow.yaml ``` 3. Execute the Lambda function ```sh -scar run -f scar-cowsay.yaml +scar run -f scar-minicow.yaml ``` From the user perspective nothing changed in comparison with the previous execution, but the main difference with the 'standard' lambda deployment is that the container is already available when the function is launched for the first time. Moreover, the function doesn't need to connect to any external repository to download the container, so this is also useful to execute small binaries or containers that you don't want to upload to a public repository. \ No newline at end of file diff --git a/examples/darknet/README.md b/examples/darknet/README.md index c7dfadf2..5e3a4c6d 100644 --- a/examples/darknet/README.md +++ b/examples/darknet/README.md @@ -75,15 +75,8 @@ darknet/output/68f5c9d5-5826-44gr-basc-8f8b23f44cdf/result.out The files are created in the output folder following the `s3://scar-darknet-bucket/scar-darknet-s3/output/$REQUEST_ID/*.*` structure. 
-To download the created files you can also use SCAR: -Download an specific file with : - -```sh -scar get -b scar-darknet-bucket -bf scar-darknet-s3/output/68f5c9d5-5826-44gr-basc-8f8b23f44cdf/image-result.png -p /tmp/result.png -``` - -Download a folder with: +To download the created files you can also use SCAR. Download a folder with: ```sh scar get -b scar-darknet-bucket -bf scar-darknet-s3/output -p /tmp/lambda/ diff --git a/requirements.txt b/requirements.txt index 8f447fa9..d9f032ab 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ boto3 tabulate configparser requests +pyyaml diff --git a/scar.py b/scar.py index ff7b91f4..30c0d1d6 100755 --- a/scar.py +++ b/scar.py @@ -18,8 +18,11 @@ from src.providers.aws.controller import AWS from src.parser.cli import CommandParser from src.parser.yaml import YamlParser +from src.parser.cfgfile import ConfigFileParser from src.cmdtemplate import Commands -import src.logger as logger +import src.logger as logger +import src.exceptions as excp +import src.utils as utils class Scar(Commands): @@ -52,24 +55,30 @@ def put(self): def get(self): self.cloud_provider.get() + + @excp.exception(logger) + def parse_arguments(self): + ''' + Merge the scar.conf parameters, the cmd parameters and the yaml file parameters in a single dictionary. - def parse_command_arguments(self): - args = CommandParser(self).parse_arguments() - if hasattr(args, 'func'): - if hasattr(args, 'conf_file') and args.conf_file: - # Update the arguments with the values extracted from the configuration file - args.__dict__.update(YamlParser(args).parse_arguments()) - self.cloud_provider.parse_command_arguments(args) - args.func() - else: - logger.error("Incorrect arguments: use scar -h to see the options available") + The precedence of parameters is CMD >> YAML >> SCAR.CONF + That is, the CMD parameter will override any other configuration, + and the YAML parameters will override the SCAR.CONF settings + ''' + merged_args = ConfigFileParser().get_properties() + cmd_args = CommandParser(self).parse_arguments() + if 'conf_file' in cmd_args['scar'] and cmd_args['scar']['conf_file']: + yaml_args = YamlParser(cmd_args['scar']).parse_arguments() + merged_args = utils.merge_dicts(yaml_args, merged_args) + merged_args = utils.merge_dicts(cmd_args, merged_args) + self.cloud_provider.parse_arguments(**merged_args) + merged_args['scar']['func']() if __name__ == "__main__": logger.init_execution_trace() try: - Scar().parse_command_arguments() + Scar().parse_arguments() logger.end_execution_trace() - except Exception as ex: - logger.exception(ex) + except: logger.end_execution_trace_with_errors() diff --git a/src/cmdtemplate.py b/src/cmdtemplate.py index 486bd9ec..6e8c1dcd 100644 --- a/src/cmdtemplate.py +++ b/src/cmdtemplate.py @@ -15,6 +15,18 @@ # along with this program. If not, see . import abc +from enum import Enum + +class CallType(Enum): + INIT = "init" + INVOKE = "invoke" + RUN = "run" + UPDATE = "update" + LS = "ls" + RM = "rm" + LOG = "log" + PUT = "put" + GET = "get" class Commands(metaclass=abc.ABCMeta): ''' All the different cloud provider controllers must inherit @@ -57,5 +69,5 @@ def get(self): pass @abc.abstractmethod - def parse_command_arguments(self, args): + def parse_arguments(self, args): pass diff --git a/src/exceptions.py b/src/exceptions.py index f544ab0b..c1158398 100644 --- a/src/exceptions.py +++ b/src/exceptions.py @@ -13,6 +13,38 @@ # See the License for the specific language governing permissions and # limitations under the License. 
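The new `Scar.parse_arguments` above merges three configuration sources with the precedence CMD >> YAML >> SCAR.CONF through `utils.merge_dicts`, which is referenced but not shown in this diff. A minimal sketch, assuming a recursive merge where keys from the first argument win; the example values are invented:

```python
def merge_dicts(preferred, base):
    """Recursively merge two dicts; on key conflicts the value from 'preferred' wins."""
    merged = dict(base)
    for key, value in preferred.items():
        if key in merged and isinstance(merged[key], dict) and isinstance(value, dict):
            merged[key] = merge_dicts(value, merged[key])
        elif value is not None:
            merged[key] = value
    return merged

# Same call order as the new Scar.parse_arguments():
# scar.cfg defaults < yaml file < command line
scar_conf = {'aws': {'region': 'us-east-1', 'lambda': {'memory': 512}}}
yaml_args = {'aws': {'lambda': {'memory': 1024, 'name': 'my-function'}}}
cmd_args  = {'aws': {'lambda': {'memory': 2048}}}
merged = merge_dicts(cmd_args, merge_dicts(yaml_args, scar_conf))
# merged['aws']['lambda'] == {'memory': 2048, 'name': 'my-function'}
```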
+import functools +from botocore.exceptions import ClientError +import sys + +def exception(logger): + ''' + A decorator that wraps the passed in function and logs exceptions + @param logger: The logging object + ''' + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except ClientError as ce: + print("There was an exception in {0}".format(func.__name__)) + print(ce.response['Error']['Message']) + logger.exception(ce) + sys.exit(1) + except ScarError as se: + print(se.args[0]) + logger.exception(se) + # Finish the execution if it's an error + if 'Error' in se.__class__.__name__: + sys.exit(1) + except Exception as ex: + print("There was an unmanaged exception in {0}".format(func.__name__)) + logger.exception(ex) + sys.exit(1) + return wrapper + return decorator + class ScarError(Exception): """ The base exception class for ScarError exceptions. @@ -26,11 +58,77 @@ def __init__(self, **kwargs): Exception.__init__(self, msg) self.kwargs = kwargs +################################################ +## GENERAL EXCEPTIONS ## +################################################ +class MissingCommandError(ScarError): + """ + SCAR was launched without a command + + """ + fmt = "Please use one of the scar available commands (init,invoke,run,update,rm,ls,log,put,get)" + +class ScarConfigFileError(ScarError): + """ + The SCAR configuration file does not exist and it has been created + + :ivar file_path: Path of the file + """ + fmt = "Config file '{file_path}' created.\n" + fmt += "Please, set a valid iam role in the file field 'role' before the first execution." + +class YamlFileNotFoundError(ScarError): + """ + The yaml configuration file does not exist + + :ivar file_path: Path of the file + """ + fmt = "Unable to find the yaml file '{file_path}'" + +class ValidatorError(ScarError): + """ + An error occurred when validating a parameter + + :ivar parameter: Name of the parameter evaluated + :ivar parameter_value: Current value of the validated parameter + :ivar error_msg: General error message + """ + fmt = "Error validating '{parameter}'.\nValue '{parameter_value}' incorrect.\n{error_msg}" + +class ScarFunctionNotFoundError(ScarError): + """ + The called function was not found + + :ivar func_name: Name of the function called + """ + fmt = "Unable to find the function '{func_name}'" + +class FunctionCodeSizeError(ScarError): + """ + Function code size exceeds AWS limits + + :ivar code_size: Name of the parameter evaluated + """ + fmt = "Payload size greater than {code_size}.\nPlease reduce the payload size or use an S3 bucket and try again." + +class S3CodeSizeError(ScarError): + """ + Function code uploaded to S3 exceeds AWS limits + + :ivar code_size: Name of the parameter evaluated + """ + + fmt = "Uncompressed image size greater than {code_size}.\nPlease reduce the uncompressed image and try again." + +################################################ +## LAMBDA EXCEPTIONS ## +################################################ class FunctionCreationError(ScarError): """ An error occurred when creating the lambda function. - :ivar name: Name of the function + :ivar function_name: Name of the function + :ivar error_msg: General error message """ fmt = "Unable to create the function '{function_name}' : {error_msg}" @@ -38,8 +136,94 @@ class FunctionNotFoundError(ScarError): """ The requested function does not exist. 
- :ivar name: Name of the function + :ivar function_name: Name of the function """ - fmt = "Unable to find the function '{function_name}' : {error_msg}" + fmt = "Unable to find the function '{function_name}'" +class FunctionExistsError(ScarError): + """ + The requested function exists. + + :ivar function_name: Name of the function + """ + fmt = "Function '{function_name}' already exists" + +################################################ +## S3 EXCEPTIONS ## +################################################ +class BucketNotFoundError(ScarError): + """ + The requested bucket does not exist. + + :ivar bucket_name: Name of the bucket + """ + fmt = "Unable to find the bucket '{bucket_name}'." + +class ExistentBucketWarning(ScarError): + """ + The bucket already exists + + :ivar bucket_name: Name of the bucket + """ + fmt = "Using existent bucket '{bucket_name}'." + +################################################ +## CLOUDWATCH LOGS EXCEPTIONS ## +################################################ +class ExistentLogGroupWarning(ScarError): + """ + The requested log group already exists + + :ivar log_group_name: Name of the log group + """ + fmt = "Using existent log group '{logGroupName}'." + +class NotExistentLogGroupWarning(ScarError): + """ + The requested log group does not exist + + :ivar log_group_name: Name of the log group + """ + fmt = "The requested log group '{logGroupName}' does not exist." + +################################################ +## API GATEWAY EXCEPTIONS ## +################################################ +class ApiEndpointNotFoundError(ScarError): + """ + The requested function does not have an associated API. + + :ivar function_name: Name of the function + """ + fmt = "Error retrieving API ID for lambda function '{function_name}'\n" + fmt += "Looks like the requested function does not have an associated API." + +class ApiCreationError(ScarError): + """ + Error creating the API endpoint. + + :ivar api_name: Name of the API + """ + fmt = "Error creating the API '{api_name}'" + +class InvocationPayloadError(ScarError): + """ + Error invoking the API endpoint. + + :ivar file_size: Size of the passed file + :ivar max_size: Maximum allowed size of the file + """ + fmt = "Invalid request: Payload size {file_size} greater than {max_size}\n" + fmt += "Check AWS Lambda invocation limits in : https://docs.aws.amazon.com/lambda/latest/dg/limits.html" + +################################################ +## IAM EXCEPTIONS ## +################################################ +class GetUserInfoError(ScarError): + """ + There was an error getting the IAM user info + + :ivar error_msg: General error message + """ + fmt = "Error getting the AWS user information.\n{error_msg}." \ No newline at end of file diff --git a/src/http/__init__.py b/src/http/__init__.py index d2e6d896..528ca22c 100644 --- a/src/http/__init__.py +++ b/src/http/__init__.py @@ -15,4 +15,4 @@ # along with this program. If not, see .
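The `@excp.exception(logger)` decorator added in `src/exceptions.py` centralizes error handling: it logs the exception, prints a user-facing message built from the class `fmt` template and, for ScarError subclasses, exits only when the class name contains 'Error' (warnings just report). A usage sketch; `show_bucket` and the bucket names are hypothetical and only illustrate the decorator:

```python
import src.logger as logger
import src.exceptions as excp

@excp.exception(logger)
def show_bucket(bucket_name):
    # Hypothetical helper, not part of this diff.
    if bucket_name != "scar-test-bucket":
        raise excp.BucketNotFoundError(bucket_name=bucket_name)
    return "ok"

show_bucket("missing-bucket")
# Prints "Unable to find the bucket 'missing-bucket'.", logs the traceback
# and calls sys.exit(1) because the exception class name contains 'Error'.
```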
-__all__ = ['invoke'] \ No newline at end of file +__all__ = ['request'] \ No newline at end of file diff --git a/src/http/invoke.py b/src/http/request.py similarity index 51% rename from src/http/invoke.py rename to src/http/request.py index 40a3db44..ffdbde73 100644 --- a/src/http/invoke.py +++ b/src/http/request.py @@ -15,9 +15,17 @@ import requests -def invoke_function(url, parameters=None, data=None, headers=None): - if data is None: - response = requests.get(url, headers=headers, params=parameters) +def invoke_http_endpoint(url, **kwargs): + """ + Makes a 'POST' request if the parameter 'data' is provided, otherwise makes a 'GET' request + + :param url: URL for the request. + :param data: (optional) Dictionary (will be form-encoded), bytes, or file-like object to send in the body of the request. + :param headers: (optional) Dictionary of HTTP Headers to send with the request. + :param parameters: (optional) Dictionary or bytes to be sent in the query string. + """ + if 'data' in kwargs and kwargs['data']: + response = requests.post(url, **kwargs) else: - response = requests.post(url, headers=headers, data=data, params=parameters) + response = requests.get(url, **kwargs) return response \ No newline at end of file diff --git a/src/logger.py b/src/logger.py index cca4d6bc..3c51a667 100644 --- a/src/logger.py +++ b/src/logger.py @@ -15,10 +15,16 @@ import logging import json +import os + +log_folder_name = ".scar" +log_file_folder = os.path.join(os.path.expanduser("~"), log_folder_name) +log_file_name = "scar.log" +log_file_path = os.path.join(log_file_folder, log_file_name) loglevel = logging.INFO FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' -logging.basicConfig(filename='scar.log', level=loglevel, format=FORMAT) +logging.basicConfig(filename=log_file_path, level=loglevel, format=FORMAT) def init_execution_trace(): logging.info('----------------------------------------------------') diff --git a/src/parser/cfgfile.py b/src/parser/cfgfile.py index 2535be63..928526aa 100644 --- a/src/parser/cfgfile.py +++ b/src/parser/cfgfile.py @@ -14,35 +14,77 @@ # limitations under the License.
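The renamed `src/http/request.py` now exposes a single `invoke_http_endpoint` helper that issues a POST when a non-empty `data` keyword is passed and a GET otherwise. A short usage sketch; the endpoint URL and file name are placeholders:

```python
from src.http.request import invoke_http_endpoint

url = "https://abc123.execute-api.us-east-1.amazonaws.com/scar/launch"  # placeholder endpoint

# No 'data' keyword, so the helper falls through to requests.get()
response = invoke_http_endpoint(url, headers={"X-Amz-Invocation-Type": "Event"})

# A non-empty 'data' keyword switches the helper to requests.post()
with open("input.jpg", "rb") as payload:  # placeholder file
    response = invoke_http_endpoint(url, data=payload.read())
print(response.status_code)
```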
import os -import src.logger as logger import shutil import json +import src.utils as utils +import src.exceptions as excp +import src.logger as logger -config_file_folder = os.path.expanduser("~") + "/.scar" -config_file_name = "scar.cfg" -config_file_path = config_file_folder + '/' + config_file_name -default_file_path = os.path.dirname(os.path.realpath(__file__)) +default_cfg = { + "aws" : { + "boto_profile" : "default", + "region" : "us-east-1", + "iam" : {"role" : ""}, + "lambda" : { + "time" : 300, + "memory" : 512, + "description" : "Automatically generated lambda function", + "timeout_threshold" : 10 + }, + "cloudwatch" : { "log_retention_policy_in_days" : 30 } + } +} -class ConfigFile(object): +class ConfigFileParser(object): + + config_file_name = "scar.cfg" + backup_config_file_name = "scar.cfg_old" + config_folder_name = ".scar" + config_file_folder = utils.join_paths(os.path.expanduser("~"), config_folder_name) + config_file_path = utils.join_paths(config_file_folder, config_file_name) + backup_file_path = utils.join_paths(config_file_folder, backup_config_file_name) + + @excp.exception(logger) def __init__(self): # Check if the config file exists - if os.path.isfile(config_file_path): - with open(config_file_path) as cfg_file: - self.__setattr__("cfg_data", json.load(cfg_file)) + if os.path.isfile(self.config_file_path): + with open(self.config_file_path) as cfg_file: + self.__setattr__("cfg_data", json.load(cfg_file)) + if 'region' not in self.cfg_data['aws'] or 'boto_profile' not in self.cfg_data['aws']: + self.add_missing_attributes() else: # Create scar config dir - os.makedirs(config_file_folder, exist_ok=True) + os.makedirs(self.config_file_folder, exist_ok=True) self.create_default_config_file() + raise excp.ScarConfigFileError(file_path=self.config_file_path) def create_default_config_file(self): - shutil.copy(default_file_path + "/default_config_file.json", config_file_path) - message = "Config file '%s' created.\n" % config_file_path - message += "Please, set a valid iam role in the file field 'role' before the first execution." 
- logger.warning(message) - - def get_aws_props(self): - return self.cfg_data['aws'] + with open(self.config_file_path, mode='w') as cfg_file: + cfg_file.write(json.dumps(default_cfg, indent=2)) + def get_properties(self): + return self.cfg_data - \ No newline at end of file + def add_missing_attributes(self): + logger.info("Updating old scar config file '{0}'.\n".format(self.config_file_path)) + shutil.copy(self.config_file_path, self.backup_file_path) + logger.info("Old scar config file saved in '{0}'.\n".format(self.backup_file_path)) + self.merge_files(self.cfg_data, default_cfg) + self.delete_unused_data() + with open(self.config_file_path, mode='w') as cfg_file: + cfg_file.write(json.dumps(self.cfg_data, indent=2)) + + def merge_files(self, cfg_data, default_data): + for k, v in default_data.items(): + if k not in cfg_data: + cfg_data[k] = v + elif type(cfg_data[k]) is dict: + self.merge_files(cfg_data[k], default_data[k]) + + def delete_unused_data(self): + if 'region' in self.cfg_data['aws']['lambda']: + region = self.cfg_data['aws']['lambda'].pop('region', None) + if region: + self.cfg_data['aws']['region'] = region + diff --git a/src/parser/cli.py b/src/parser/cli.py index 15abbd90..3cb42834 100644 --- a/src/parser/cli.py +++ b/src/parser/cli.py @@ -16,6 +16,8 @@ import argparse import src.logger as logger +import src.utils as utils +import src.exceptions as excp class CommandParser(object): @@ -45,30 +47,36 @@ def create_init_parser(self): parser_init = self.subparsers.add_parser('init', help="Create lambda function") # Set default function parser_init.set_defaults(func=self.scar.init) + # Lambda conf group = parser_init.add_mutually_exclusive_group(required=True) - group.add_argument("-i", "--image_id", help="Container image id (i.e. centos:7)") + group.add_argument("-i", "--image", help="Container image id (i.e. centos:7)") group.add_argument("-if", "--image_file", help="Container image file created with 'docker save' (i.e. centos.tar.gz)") group.add_argument("-f", "--conf_file", help="Yaml file with the function configuration") parser_init.add_argument("-d", "--description", help="Lambda function description.") - parser_init.add_argument("-db", "--deployment_bucket", help="Bucket where the deployment package is going to be uploaded.") - parser_init.add_argument("-ib", "--input_bucket", help="Bucket name where the input files will be stored.") - parser_init.add_argument("-inf", "--input_folder", help="Folder name where the input files will be stored (Only works when an input bucket is defined).") - parser_init.add_argument("-ob", "--output_bucket", help="Bucket name where the output files are saved.") - parser_init.add_argument("-outf", "--output_folder", help="Folder name where the output files are saved (Only works when an input bucket is defined).") - # parser_init.add_argument("-out-func", "--output_function", help="Function name where the output will be redirected") parser_init.add_argument("-n", "--name", help="Lambda function name") - parser_init.add_argument("-e", "--environment_variables", action='append', help="Pass environment variable to the container (VAR=val). Can be defined multiple times.") + parser_init.add_argument("-e", "--environment", action='append', help="Pass environment variable to the container (VAR=val). Can be defined multiple times.") parser_init.add_argument("-m", "--memory", type=int, help="Lambda function memory in megabytes. 
Range from 128 to 1536 in increments of 64") parser_init.add_argument("-t", "--time", type=int, help="Lambda function maximum execution time in seconds. Max 300.") parser_init.add_argument("-tt", "--timeout_threshold", type=int, help="Extra time used to postprocess the data. This time is extracted from the total time of the lambda function.") - parser_init.add_argument("-j", "--json", help="Return data in JSON format", action="store_true") - parser_init.add_argument("-v", "--verbose", help="Show the complete aws output in json format", action="store_true") parser_init.add_argument("-s", "--init_script", help="Path to the input file passed to the function") - parser_init.add_argument("-lr", "--lambda_role", help="Lambda role used in the management of the functions") parser_init.add_argument("-p", "--preheat", help="Preheats the function running it once and downloading the necessary container", action="store_true") parser_init.add_argument("-ep", "--extra_payload", help="Folder containing files that are going to be added to the lambda function") parser_init.add_argument("-ll", "--log_level", help="Set the log level of the lambda function. Accepted values are: 'CRITICAL','ERROR','WARNING','INFO','DEBUG'", default="INFO") + # S3 conf + parser_init.add_argument("-db", "--deployment_bucket", help="Bucket where the deployment package is going to be uploaded.") + parser_init.add_argument("-ib", "--input_bucket", help="Bucket name where the input files will be stored.") + parser_init.add_argument("-inf", "--input_folder", help="Folder name where the input files will be stored (Only works when an input bucket is defined).") + parser_init.add_argument("-ob", "--output_bucket", help="Bucket name where the output files are saved.") + parser_init.add_argument("-outf", "--output_folder", help="Folder name where the output files are saved (Only works when an input bucket is defined).") + # IAM conf + parser_init.add_argument("-r", "--iam_role", help="IAM role used in the management of the functions") + # SCAR conf + parser_init.add_argument("-j", "--json", help="Return data in JSON format", action="store_true") + parser_init.add_argument("-v", "--verbose", help="Show the complete aws output in json format", action="store_true") + # API Gateway conf parser_init.add_argument("-api", "--api_gateway_name", help="API Gateway name created to launch the lambda function") + # General AWS conf + parser_init.add_argument("-pf", "--profile", help="AWS profile to use") def create_invoke_parser(self): parser_invoke = self.subparsers.add_parser('invoke', help="Call a lambda function using an HTTP request") @@ -79,7 +87,9 @@ def create_invoke_parser(self): group.add_argument("-f", "--conf_file", help="Yaml file with the function configuration") parser_invoke.add_argument("-db", "--data_binary", help="File path of the HTTP data to POST.") parser_invoke.add_argument("-a", "--asynchronous", help="Launch an asynchronous function.", action="store_true") - parser_invoke.add_argument("-p", "--parameters", help="In addition to passing the parameters in the URL, you can pass the parameters here (i.e. '{\"key1\": \"value1\", \"key2\": [\"value2\", \"value3\"]}').") + parser_invoke.add_argument("-p", "--parameters", help="In addition to passing the parameters in the URL, you can pass the parameters here (i.e. 
'{\"key1\": \"value1\", \"key2\": [\"value2\", \"value3\"]}').") + # General AWS conf + parser_invoke.add_argument("-pf", "--profile", help="AWS profile to use") def create_update_parser(self): parser_update = self.subparsers.add_parser('update', help="Update function properties") @@ -89,14 +99,12 @@ def create_update_parser(self): group.add_argument("-f", "--conf_file", help="Yaml file with the function configuration") parser_update.add_argument("-m", "--memory", type=int, help="Lambda function memory in megabytes. Range from 128 to 1536 in increments of 64") parser_update.add_argument("-t", "--time", type=int, help="Lambda function maximum execution time in seconds. Max 300.") - parser_update.add_argument("-e", "--environment_variables", action='append', help="Pass environment variable to the container (VAR=val). Can be defined multiple times.") + parser_update.add_argument("-e", "--environment", action='append', help="Pass environment variable to the container (VAR=val). Can be defined multiple times.") parser_update.add_argument("-tt", "--timeout_threshold", type=int, help="Extra time used to postprocess the data. This time is extracted from the total time of the lambda function.") - parser_update.add_argument("-ll", "--log_level", help="Set the log level of the lambda function. Accepted values are: 'CRITICAL','ERROR','WARNING','INFO','DEBUG'", default="INFO") - #parser_update.add_argument("-s", "--script", nargs='?', type=argparse.FileType('r'), help="Path to the input file passed to the function") - #parser_update.add_argument("-j", "--json", help="Return data in JSON format", action="store_true") - #parser_update.add_argument("-v", "--verbose", help="Show the complete aws output in json format", action="store_true") - #parser_update.add_argument("-es", "--event_source", help="Name specifying the source of the events that will launch the lambda function. Only supporting buckets right now.") - + parser_update.add_argument("-ll", "--log_level", help="Set the log level of the lambda function. 
Accepted values are: 'CRITICAL','ERROR','WARNING','INFO','DEBUG'", default="INFO") + # General AWS conf + parser_update.add_argument("-pf", "--profile", help="AWS profile to use") + def create_run_parser(self): parser_run = self.subparsers.add_parser('run', help="Deploy function") parser_run.set_defaults(func=self.scar.run) @@ -108,6 +116,8 @@ def create_run_parser(self): parser_run.add_argument("-j", "--json", help="Return data in JSON format", action="store_true") parser_run.add_argument("-v", "--verbose", help="Show the complete aws output in json format", action="store_true") parser_run.add_argument('c_args', nargs=argparse.REMAINDER, help="Arguments passed to the container.") + # General AWS conf + parser_run.add_argument("-pf", "--profile", help="AWS profile to use") def create_rm_parser(self): parser_rm = self.subparsers.add_parser('rm', help="Delete function") @@ -117,44 +127,115 @@ def create_rm_parser(self): group.add_argument("-a", "--all", help="Delete all lambda functions", action="store_true") group.add_argument("-f", "--conf_file", help="Yaml file with the function configuration") parser_rm.add_argument("-j", "--json", help="Return data in JSON format", action="store_true") - parser_rm.add_argument("-v", "--verbose", help="Show the complete aws output in json format", action="store_true") - - def create_ls_parser(self): - parser_ls = self.subparsers.add_parser('ls', help="List lambda functions") - parser_ls.set_defaults(func=self.scar.ls) - parser_ls.add_argument("-j", "--json", help="Return data in JSON format", action="store_true") - parser_ls.add_argument("-v", "--verbose", help="Show the complete aws output in json format", action="store_true") - parser_ls.add_argument("-b", "--bucket", help="Show bucket files") - parser_ls.add_argument("-bf", "--bucket_folder", help="Show bucket files") - + parser_rm.add_argument("-v", "--verbose", help="Show the complete aws output in json format", action="store_true") + # General AWS conf + parser_rm.add_argument("-pf", "--profile", help="AWS profile to use") + def create_log_parser(self): parser_log = self.subparsers.add_parser('log', help="Show the logs for the lambda function") parser_log.set_defaults(func=self.scar.log) group = parser_log.add_mutually_exclusive_group(required=True) group.add_argument("-n", "--name", help="Lambda function name") - group.add_argument("-f", "--conf_file", help="Yaml file with the function configuration") + group.add_argument("-f", "--conf_file", help="Yaml file with the function configuration") + # CloudWatch args parser_log.add_argument("-ls", "--log_stream_name", help="Return the output for the log stream specified.") - parser_log.add_argument("-ri", "--request_id", help="Return the output for the request id specified.") + parser_log.add_argument("-ri", "--request_id", help="Return the output for the request id specified.") + # General AWS conf + parser_log.add_argument("-pf", "--profile", help="AWS profile to use") + + def create_ls_parser(self): + parser_ls = self.subparsers.add_parser('ls', help="List lambda functions") + parser_ls.set_defaults(func=self.scar.ls) + parser_ls.add_argument("-j", "--json", help="Return data in JSON format", action="store_true") + parser_ls.add_argument("-v", "--verbose", help="Show the complete aws output in json format", action="store_true") + # S3 args + parser_ls.add_argument("-b", "--bucket", help="Show bucket files") + parser_ls.add_argument("-bf", "--bucket_folder", help="Show bucket files") + # General AWS conf + parser_ls.add_argument("-pf", "--profile", 
help="AWS profile to use") def create_put_parser(self): parser_put = self.subparsers.add_parser('put', help="Upload file(s) to bucket") parser_put.set_defaults(func=self.scar.put) + # S3 args parser_put.add_argument("-b", "--bucket", help="Bucket to use as storage", required=True) - parser_put.add_argument("-bf", "--bucket_folder", help="Folder used to store the file(s) in the bucket", default="") + parser_put.add_argument("-bf", "--bucket_folder", help="Folder used to store the file(s) in the bucket") + # Local info args parser_put.add_argument("-p", "--path", help="Path of the file or folder to upload", required=True) + # General AWS conf + parser_put.add_argument("-pf", "--profile", help="AWS profile to use") def create_get_parser(self): parser_get = self.subparsers.add_parser('get', help="Download file(s) from bucket") parser_get.set_defaults(func=self.scar.get) + # S3 args parser_get.add_argument("-b", "--bucket", help="Bucket to use as storage", required=True) - parser_get.add_argument("-bf", "--bucket_folder", help="Path of the file or folder to download", required=True) + parser_get.add_argument("-bf", "--bucket_folder", help="Path of the file or folder to download") + # Local info args parser_get.add_argument("-p", "--path", help="Path to store the downloaded file or folder") + # General AWS conf + parser_get.add_argument("-pf", "--profile", help="AWS profile to use") + @excp.exception(logger) def parse_arguments(self): '''Command parsing and selection''' try: - return self.parser.parse_args() + cmd_args = vars(self.parser.parse_args()) + if 'func' not in cmd_args: + raise excp.MissingCommandError() + scar_args = self.parse_scar_args(cmd_args) + aws_args = self.parse_aws_args(cmd_args) + return utils.merge_dicts(scar_args, aws_args) except AttributeError as ae: logger.error("Incorrect arguments: use scar -h to see the options available", "Error parsing arguments: %s" % ae) + else: raise + + def set_args(self, args, key, val): + if key and val: + args[key] = val + + def parse_aws_args(self, cmd_args): + aws_args = {} + other_args = [('profile','boto_profile'),'region'] + self.set_args(aws_args, 'iam', self.parse_iam_args(cmd_args)) + self.set_args(aws_args, 'lambda', self.parse_lambda_args(cmd_args)) + self.set_args(aws_args, 'cloudwatch', self.parse_cloudwatchlogs_args(cmd_args)) + self.set_args(aws_args, 's3', self.parse_s3_args(cmd_args)) + self.set_args(aws_args, 'api_gateway', self.parse_api_gateway_args(cmd_args)) + aws_args.update(utils.parse_arg_list(other_args, cmd_args)) + return {'aws' : aws_args } + + def parse_scar_args(self, cmd_args): + scar_args = ['func', 'conf_file', 'json', 'verbose', 'path', ('all', 'delete_all'), 'preheat'] + return {'scar' : utils.parse_arg_list(scar_args, cmd_args)} + + def parse_lambda_args(self, cmd_args): + lambda_args = ['name', 'asynchronous', 'init_script', 'run_script', 'c_args', 'memory', 'time', + 'timeout_threshold', 'log_level', 'image', 'image_file', 'description', + 'lambda_role', 'extra_payload', ('environment', 'environment_variables')] + return utils.parse_arg_list(lambda_args, cmd_args) + + def parse_iam_args(self, cmd_args): + iam_args = [('iam_role', 'role')] + return utils.parse_arg_list(iam_args, cmd_args) + + def parse_cloudwatchlogs_args(self, cmd_args): + cw_log_args = ['log_stream_name', 'request_id'] + return utils.parse_arg_list(cw_log_args, cmd_args) + + def parse_api_gateway_args(self, cmd_args): + api_gtw_args = [('api_gateway_name', 'name'), 'parameters', 'data_binary'] + return 
utils.parse_arg_list(api_gtw_args, cmd_args) + + def parse_s3_args(self, cmd_args): + s3_args = ['deployment_bucket', + 'input_bucket', + 'input_folder', + 'output_bucket', + 'output_folder', + ('bucket', 'input_bucket'), + ('bucket_folder', 'input_folder')] + return utils.parse_arg_list(s3_args, cmd_args) + \ No newline at end of file diff --git a/src/parser/default_config_file.json b/src/parser/default_config_file.json deleted file mode 100644 index 14245d22..00000000 --- a/src/parser/default_config_file.json +++ /dev/null @@ -1,10 +0,0 @@ -{ "aws" : { - "iam" : {"role" : ""}, - "lambda" : { - "region" : "us-east-1", - "time" : 300, - "memory" : 512, - "description" : "Automatically generated lambda function", - "timeout_threshold" : 10 }, - "cloudwatch" : { "log_retention_policy_in_days" : 30 }} -} \ No newline at end of file diff --git a/src/parser/yaml.py b/src/parser/yaml.py index bb47ebf8..30c869f9 100644 --- a/src/parser/yaml.py +++ b/src/parser/yaml.py @@ -15,85 +15,47 @@ import yaml import os - -class Function: - def __init__(self, name, image): - self.name = name - self.image_id = image +from src.exceptions import YamlFileNotFoundError +import src.utils as utils class YamlParser(object): - def __init__(self, args): - file_path = args.conf_file - self.func = args.func + def __init__(self, scar_args): + file_path = scar_args['conf_file'] if os.path.isfile(file_path): with open(file_path) as cfg_file: self.__setattr__("yaml_data", yaml.safe_load(cfg_file)) + else: + raise YamlFileNotFoundError(file_path=file_path) def parse_arguments(self): - functions = [] + functions = [] for function in self.yaml_data['functions']: - functions.append(self.parse_function(function, self.yaml_data['functions'][function])) + functions.append(self.parse_aws_function(function, self.yaml_data['functions'][function])) return functions[0] - def parse_function(self, function_name, function_data): - args = {'func' : self.func } + def parse_aws_function(self, function_name, function_data): + aws_args = {} # Get function name - args['name'] = function_name - # Parse function information - if 'image' in function_data: - args['image_id'] = function_data['image'] - if 'image_file' in function_data: - args['image_file'] = function_data['image_file'] - if 'time' in function_data: - args['time'] = function_data['time'] - if 'memory' in function_data: - args['memory'] = function_data['memory'] - if 'timeout_threshold' in function_data: - args['timeout_threshold'] = function_data['timeout_threshold'] - if 'lambda_role' in function_data: - args['lambda_role'] = function_data['lambda_role'] - if 'description' in function_data: - args['description'] = function_data['description'] - if 'init_script' in function_data: - args['init_script'] = function_data['init_script'] - if 'run_script' in function_data: - args['run_script'] = function_data['run_script'] - if 'extra_payload' in function_data: - args['extra_payload'] = function_data['extra_payload'] - if 'log_level' in function_data: - args['log_level'] = function_data['log_level'] - if 'environment' in function_data: - variables = [] - for k,v in function_data['environment'].items(): - variables.append(str(k) + '=' + str(v)) - args['environment_variables'] = variables - # LOG COMMANDS - if 'log_stream_name' in function_data: - args['log_stream_name'] = function_data['log_stream_name'] - if 'request_id' in function_data: - args['request_id'] = function_data['request_id'] - - if 'data_binary' in function_data: - args['data_binary'] = function_data['data_binary'] - - if 
's3' in function_data: - s3_data = function_data['s3'] - if 'deployment_bucket' in s3_data: - args['deployment_bucket'] = s3_data['deployment_bucket'] - if 'input_bucket' in s3_data: - args['input_bucket'] = s3_data['input_bucket'] - if 'input_folder' in s3_data: - args['input_folder'] = s3_data['input_folder'] - if 'output_bucket' in s3_data: - args['output_bucket'] = s3_data['output_bucket'] - if 'output_folder' in s3_data: - args['output_folder'] = s3_data['output_folder'] - if 'api_gateway' in function_data: - api_data = function_data['api_gateway'] - if 'name' in api_data: - args['api_gateway_name'] = api_data['name'] - if 'parameters' in api_data: - args['parameters'] = api_data['parameters'] - return args + aws_args['lambda'] = self.parse_lambda_args(function_data) + aws_args['lambda']['name'] = function_name + if 'iam' in function_data: + aws_args['iam'] = function_data['iam'] + if 'cloudwatch' in function_data: + aws_args['cloudwatch'] = function_data['cloudwatch'] + if 's3' in function_data: + aws_args['s3'] = function_data['s3'] + if 'api_gateway' in function_data: + aws_args['api_gateway'] = function_data['api_gateway'] + other_args = [('profile','boto_profile'),'region'] + aws_args.update(utils.parse_arg_list(other_args, function_data)) + aws = {} + aws['aws'] = aws_args + return aws + + def parse_lambda_args(self, cmd_args): + lambda_args = ['asynchronous', 'init_script', 'run_script', 'c_args', 'memory', 'time', + 'timeout_threshold', 'log_level', 'image', 'image_file', 'description', + 'lambda_role', 'extra_payload', ('environment', 'environment_variables')] + return utils.parse_arg_list(lambda_args, cmd_args) \ No newline at end of file diff --git a/src/providers/aws/apigateway.py b/src/providers/aws/apigateway.py index dc96be39..7e002c06 100644 --- a/src/providers/aws/apigateway.py +++ b/src/providers/aws/apigateway.py @@ -15,32 +15,38 @@ # along with this program. If not, see . 
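The reworked `YamlParser.parse_aws_function` above groups the yaml values by service into a nested dictionary instead of a flat argument namespace. A rough illustration with invented values (the exact keys kept for the `lambda` block depend on `utils.parse_arg_list`, which is not shown in this diff):

```python
import yaml

# Invented yaml definition, only to illustrate the resulting structure.
cfg = yaml.safe_load("""
functions:
  scar-example:
    image: grycap/cowsay
    memory: 1024
    s3:
      input_bucket: scar-example-bucket
""")

function_name = 'scar-example'
function_data = cfg['functions'][function_name]
# parse_aws_function(function_name, function_data) would return roughly:
aws_args = {'aws': {'lambda': {'name': function_name,
                               'image': function_data['image'],
                               'memory': function_data['memory']},
                    's3': function_data['s3']}}
```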
import src.logger as logger -import src.utils as utils -import src.providers.aws.response as response_parser -from src.providers.aws.clientfactory import GenericClient +from src.providers.aws.botoclientfactory import GenericClient class APIGateway(GenericClient): - def __init__(self, aws_lambda): - # Get all the log related attributes - self.function_name = aws_lambda.get_property("name") - self.api_gateway_name = aws_lambda.get_property("api_gateway_name") - self.lambda_role = aws_lambda.get_property("iam","role") - # ANY, DELETE, GET, HEAD, OPTIONS, PATCH, POST, PUT - self.default_http_method = "ANY" - # NONE, AWS_IAM, CUSTOM, COGNITO_USER_POOLS - self.default_authorization_type = "NONE" - # 'HTTP'|'AWS'|'MOCK'|'HTTP_PROXY'|'AWS_PROXY' - self.default_type = "AWS_PROXY" + # ANY, DELETE, GET, HEAD, OPTIONS, PATCH, POST, PUT + default_http_method = "ANY" + # NONE, AWS_IAM, CUSTOM, COGNITO_USER_POOLS + default_authorization_type = "NONE" + # 'HTTP'|'AWS'|'MOCK'|'HTTP_PROXY'|'AWS_PROXY' + default_type = "AWS_PROXY" + # Used in the lambda-proxy integration + default_request_parameters = { 'integration.request.header.X-Amz-Invocation-Type' : 'method.request.header.X-Amz-Invocation-Type' } + # {0}: api_region + generic_api_gateway_uri = 'arn:aws:apigateway:{0}:lambda:path/2015-03-31/functions/' + # {0}: lambda function region, {1}: aws account id, {2}: lambda function name + generic_lambda_uri = 'arn:aws:lambda:{0}:{1}:function:{2}/invocations' + # {0}: api_id, {1}: api_region + generic_endpoint = 'https://{0}.execute-api.{1}.amazonaws.com/scar/launch' - def get_api_lambda_uri(self): - self.aws_acc_id = utils.find_expression(self.lambda_role, '\d{12}') - api_gateway_uri = 'arn:aws:apigateway:us-east-1:lambda:path/2015-03-31/functions/' - lambda_uri = 'arn:aws:lambda:us-east-1:{0}:function:{1}/invocations'.format(self.aws_acc_id, self.function_name) - return api_gateway_uri + lambda_uri + def __init__(self, aws_properties): + GenericClient.__init__(self, aws_properties) + self.properties = aws_properties['api_gateway'] + self.set_api_lambda_uri() + + def set_api_lambda_uri(self): + self.properties['lambda_uri'] = self.generic_lambda_uri.format(self.aws_properties['region'], + self.aws_properties['account_id'], + self.aws_properties['lambda']['name']) + self.properties['uri'] = self.generic_api_gateway_uri.format(self.aws_properties['region']) + self.properties['lambda_uri'] def get_common_args(self, resource_info): - return {'restApiId' : self.api_id, + return {'restApiId' : self.properties['id'], 'resourceId' : resource_info['id'], 'httpMethod' : self.default_http_method} @@ -54,33 +60,30 @@ def get_method_args(self, resource_info): def get_integration_args(self, resource_info): args = {'type' : self.default_type, 'integrationHttpMethod' : 'POST', - 'uri' : self.get_api_lambda_uri(), - 'requestParameters' : - { 'integration.request.header.X-Amz-Invocation-Type' : 'method.request.header.X-Amz-Invocation-Type' } + 'uri' : self.properties['uri'], + 'requestParameters' : self.default_request_parameters } integration = self.get_common_args(resource_info) integration.update(args) return integration def create_api_gateway(self): - api_info = self.client.create_rest_api(self.api_gateway_name) - self.set_api_resources(api_info) - resource_info = self.client.create_resource(self.api_id, self.root_resource_id, "{proxy+}") + api_info = self.client.create_rest_api(self.properties['name']) + self.set_api_ids(api_info) + resource_info = self.client.create_resource(self.properties['id'],
self.properties['root_resource_id'], "{proxy+}") self.client.create_method(**self.get_method_args(resource_info)) self.client.set_integration(**self.get_integration_args(resource_info)) - self.client.create_deployment(self.api_id, 'scar') - self.endpoint = 'https://{0}.execute-api.{1}.amazonaws.com/scar/launch'.format(self.api_id, 'us-east-1') + self.client.create_deployment(self.properties['id'], 'scar') + self.endpoint = self.generic_endpoint.format(self.properties['id'], self.aws_properties['region']) logger.info('API Gateway endpoint: {0}'.format(self.endpoint)) - return self.api_id, self.aws_acc_id - def delete_api_gateway(self, api_id, output_type): - response = self.client.delete_rest_api(api_id) - response_parser.parse_delete_api_response(response, api_id, output_type) + def delete_api_gateway(self): + return self.client.delete_rest_api(self.properties['id']) - def set_api_resources(self, api_info): - self.api_id = api_info['id'] + def set_api_ids(self, api_info): + self.properties['id'] = api_info['id'] resources_info = self.client.get_resources(api_info['id']) for resource in resources_info['items']: if resource['path'] == '/': - self.root_resource_id = resource['id'] + self.properties['root_resource_id'] = resource['id'] diff --git a/src/providers/aws/clientfactory.py b/src/providers/aws/botoclientfactory.py similarity index 79% rename from src/providers/aws/clientfactory.py rename to src/providers/aws/botoclientfactory.py index bde5300f..e1f9031a 100644 --- a/src/providers/aws/clientfactory.py +++ b/src/providers/aws/botoclientfactory.py @@ -19,18 +19,21 @@ from src.providers.aws.clients.apigateway import APIGatewayClient from src.providers.aws.clients.cloudwatchlogs import CloudWatchLogsClient from src.providers.aws.clients.iam import IAMClient -from src.providers.aws.clients.lambdafunction import LambdaClient from src.providers.aws.clients.resourcegroups import ResourceGroupsClient from src.providers.aws.clients.s3 import S3Client class GenericClient(object): + + def __init__(self, aws_properties): + self.aws_properties = aws_properties + + def get_client_args(self): + return {'client' : {'region_name' : self.aws_properties['region'] } , + 'session' : { 'profile_name' : self.aws_properties['boto_profile'] }} @utils.lazy_property def client(self): client_name = self.__class__.__name__ + 'Client' - if hasattr(self, 'region'): - client = globals()[client_name](self.region) - else: - client = globals()[client_name]() + client = globals()[client_name](**self.get_client_args()) return client diff --git a/src/providers/aws/clients/apigateway.py b/src/providers/aws/clients/apigateway.py index 8a62f9bd..a4445aba 100644 --- a/src/providers/aws/clients/apigateway.py +++ b/src/providers/aws/clients/apigateway.py @@ -18,56 +18,61 @@ import src.logger as logger from botocore.exceptions import ClientError import time -import src.utils as utils - -API_DESCRIPTION="API created automatically with SCAR" -MAX_NUMBER_OF_RETRIES = 5 -WAIT_BETWEEN_RETIRES = 5 +import src.exceptions as excp class APIGatewayClient(BotoClient): '''A low-level client representing Amazon API Gateway. 
https://boto3.readthedocs.io/en/latest/reference/services/apigateway.html''' + # Parameter used by the parent to create the appropriate boto3 client boto_client_name = 'apigateway' + endpoint_configuration = {'types': ['REGIONAL']} + API_DESCRIPTION="API created automatically with SCAR" + MAX_NUMBER_OF_RETRIES = 5 + WAIT_BETWEEN_RETIRES = 5 - @utils.exception(logger) - def create_rest_api(self, api_name, count=MAX_NUMBER_OF_RETRIES): - ''' Creates a new RestApi resource. - https://boto3.readthedocs.io/en/latest/reference/services/apigateway.html#APIGateway.Client.create_rest_api - ''' + @excp.exception(logger) + def create_rest_api(self, name, count=MAX_NUMBER_OF_RETRIES): + """ + Creates a new RestApi resource. + https://boto3.readthedocs.io/en/latest/reference/services/apigateway.html#APIGateway.Client.create_rest_api + + :param str name: The name of the RestApi. + :param int count: (Optional) The maximum number of retries to create the API + """ try: - return self.client.create_rest_api(name=api_name, - description=API_DESCRIPTION, - endpointConfiguration={'types': ['REGIONAL']}) + return self.client.create_rest_api(name=name, + description=self.API_DESCRIPTION, + endpointConfiguration=self.endpoint_configuration) except ClientError as ce: - if (ce.response['Error']['Code'] == 'TooManyRequestsException'): - time.sleep(WAIT_BETWEEN_RETIRES) - return self.create_rest_api(api_name, count-1) - else: - raise + if (ce.response['Error']['Code'] == 'TooManyRequestsException') and (self.MAX_NUMBER_OF_RETRIES > 0): + time.sleep(self.WAIT_BETWEEN_RETIRES) + return self.create_rest_api(name, count-1) + except: + raise excp.ApiCreationError(api_name=name) - @utils.exception(logger) + @excp.exception(logger) def get_resources(self, api_id): ''' Lists information about a collection of Resource resources. https://boto3.readthedocs.io/en/latest/reference/services/apigateway.html#APIGateway.Client.get_resources ''' return self.client.get_resources(restApiId=api_id) - @utils.exception(logger) + @excp.exception(logger) def create_resource(self, api_id, parent_id, path_part): ''' Creates a new RestApi resource. https://boto3.readthedocs.io/en/latest/reference/services/apigateway.html#APIGateway.Client.create_rest_api ''' return self.client.create_resource(restApiId=api_id, parentId=parent_id, pathPart=path_part) - @utils.exception(logger) + @excp.exception(logger) def create_method(self, **kwargs): ''' Add a method to an existing Resource resource. https://boto3.readthedocs.io/en/latest/reference/services/apigateway.html#APIGateway.Client.put_method ''' return self.client.put_method(**kwargs) - @utils.exception(logger) + @excp.exception(logger) def set_integration(self, **kwargs): ''' Sets up a method's integration. https://boto3.readthedocs.io/en/latest/reference/services/apigateway.html#APIGateway.Client.put_integration @@ -75,14 +80,14 @@ def set_integration(self, **kwargs): ''' return self.client.put_integration(**kwargs) - @utils.exception(logger) + @excp.exception(logger) def create_deployment(self, api_id, stage_name): ''' Creates a Deployment resource, which makes a specified RestApi callable over the internet. https://boto3.readthedocs.io/en/latest/reference/services/apigateway.html#APIGateway.Client.create_deployment ''' return self.client.create_deployment(restApiId=api_id, stageName=stage_name) - @utils.exception(logger) + @excp.exception(logger) def delete_rest_api(self, api_id, count=MAX_NUMBER_OF_RETRIES): ''' Deletes the specified API. 
https://boto3.readthedocs.io/en/latest/reference/services/apigateway.html#APIGateway.Client.delete_rest_api @@ -90,8 +95,8 @@ def delete_rest_api(self, api_id, count=MAX_NUMBER_OF_RETRIES): try: return self.client.delete_rest_api(restApiId=api_id) except ClientError as ce: - if (ce.response['Error']['Code'] == 'TooManyRequestsException'): - time.sleep(WAIT_BETWEEN_RETIRES) + if (ce.response['Error']['Code'] == 'TooManyRequestsException') and (self.MAX_NUMBER_OF_RETRIES > 0): + time.sleep(self.WAIT_BETWEEN_RETIRES) return self.delete_rest_api(api_id, count-1) else: raise diff --git a/src/providers/aws/clients/boto.py b/src/providers/aws/clients/boto.py index f1beeb9f..f8ab82c9 100644 --- a/src/providers/aws/clients/boto.py +++ b/src/providers/aws/clients/boto.py @@ -20,22 +20,20 @@ # Default values botocore_client_read_timeout = 360 -default_aws_region = "us-east-1" class BotoClient(object): - def __init__(self, region=None): - self.region = region + def __init__(self, **kwargs): + self.session_args = kwargs['session'] + self.client_args = kwargs['client'] @utils.lazy_property def client(self): - if self.region is None: - self.region = default_aws_region - boto_config = botocore.config.Config(read_timeout=botocore_client_read_timeout) - client = boto3.client(self.boto_client_name, region_name=self.region, config=boto_config) - return client + session = boto3.Session(**self.session_args) + self.client_args['config'] = botocore.config.Config(read_timeout=botocore_client_read_timeout) + return session.client(self.boto_client_name, **self.client_args) def get_access_key(self): - session = boto3.Session() + session = boto3.Session(**self.session_args) credentials = session.get_credentials() return credentials.access_key diff --git a/src/providers/aws/clients/cloudwatchlogs.py b/src/providers/aws/clients/cloudwatchlogs.py index 50198df3..dcc4d0ff 100644 --- a/src/providers/aws/clients/cloudwatchlogs.py +++ b/src/providers/aws/clients/cloudwatchlogs.py @@ -17,24 +17,22 @@ from src.providers.aws.clients.boto import BotoClient import src.logger as logger from botocore.exceptions import ClientError -import src.utils as utils +import src.exceptions as excp class CloudWatchLogsClient(BotoClient): '''A low-level client representing Amazon CloudWatch Logs. https://boto3.readthedocs.io/en/latest/reference/services/logs.html''' + # Parameter used by the parent to create the appropriate boto3 client boto_client_name = 'logs' - @utils.exception(logger) - def get_log_events(self, log_group_name, log_stream_name=None): + @excp.exception(logger) + def get_log_events(self, **kwargs): ''' Lists log events from the specified log group. https://boto3.readthedocs.io/en/latest/reference/services/logs.html#CloudWatchLogs.Client.filter_log_events ''' logs = [] - kwargs = {"logGroupName" : log_group_name} - if log_stream_name: - kwargs["logStreamNames"] = [log_stream_name] response = self.client.filter_log_events(**kwargs) logs.append(response) while ('nextToken' in response) and (response['nextToken']): @@ -43,39 +41,38 @@ def get_log_events(self, log_group_name, log_stream_name=None): logs.append(response) return logs - @utils.exception(logger) - def create_log_group(self, log_group_name, tags): + @excp.exception(logger) + def create_log_group(self, **kwargs): ''' Creates a log group with the specified name. 
https://boto3.readthedocs.io/en/latest/reference/services/logs.html#CloudWatchLogs.Client.create_log_group ''' try: - return self.client.create_log_group(logGroupName=log_group_name, tags=tags) + return self.client.create_log_group(**kwargs) except ClientError as ce: if ce.response['Error']['Code'] == 'ResourceAlreadyExistsException': - logger.warning("Using existent log group '{0}'".format(log_group_name)) - pass + raise excp.ExistentLogGroupWarning(logGroupName=kwargs['logGroupName']) else: raise - @utils.exception(logger) - def set_log_retention_policy(self, log_group_name, log_retention_policy_in_days): + @excp.exception(logger) + def set_log_retention_policy(self, **kwargs): ''' Sets the retention of the specified log group. https://boto3.readthedocs.io/en/latest/reference/services/logs.html#CloudWatchLogs.Client.put_retention_policy ''' - return self.client.put_retention_policy(logGroupName=log_group_name, retentionInDays=log_retention_policy_in_days) + return self.client.put_retention_policy(**kwargs) - @utils.exception(logger) - def delete_log_group(self, log_group_name): + @excp.exception(logger) + def delete_log_group(self, **kwargs): ''' Deletes the specified log group and permanently deletes all the archived log events associated with the log group. https://boto3.readthedocs.io/en/latest/reference/services/logs.html#CloudWatchLogs.Client.delete_log_group ''' try: - return self.client.delete_log_group(logGroupName=log_group_name) + return self.client.delete_log_group(**kwargs) except ClientError as ce: if ce.response['Error']['Code'] == 'ResourceNotFoundException': - logger.warning("Cannot delete log group '%s'. Group not found." % log_group_name) + raise excp.NotExistentLogGroupWarning(**kwargs) else: raise diff --git a/src/providers/aws/clients/iam.py b/src/providers/aws/clients/iam.py index f347d834..94ae12ec 100644 --- a/src/providers/aws/clients/iam.py +++ b/src/providers/aws/clients/iam.py @@ -18,14 +18,16 @@ from botocore.exceptions import ClientError import src.logger as logger import src.utils as utils +import src.exceptions as excp class IAMClient(BotoClient): '''A low-level client representing aws Identity and Access Management (IAMClient). https://boto3.readthedocs.io/en/latest/reference/services/iam.html''' + # Parameter used by the parent to create the appropriate boto3 client boto_client_name = 'iam' - @utils.exception(logger) + @excp.exception(logger) def get_user_info(self): ''' Retrieves information about the specified IAM user, including the user's creation date, path, unique ID, and ARN. @@ -39,4 +41,8 @@ def get_user_info(self): # we can find the user name in the error response user_name = utils.find_expression(str(ce), '(?<=user\/)(\S+)') return {'UserName' : user_name, - 'User' : {'UserName' : user_name, 'UserId' : ''}} + 'User' : {'UserName' : user_name, 'UserId' : ''}} + else: + raise + except Exception as ex: + raise excp.GetUserInfoError(error_msg=ex) diff --git a/src/providers/aws/clients/lambdafunction.py b/src/providers/aws/clients/lambdafunction.py index 3eec59b6..4d29e068 100644 --- a/src/providers/aws/clients/lambdafunction.py +++ b/src/providers/aws/clients/lambdafunction.py @@ -15,14 +15,15 @@ # along with this program. If not, see . from src.providers.aws.clients.boto import BotoClient -from botocore.exceptions import ClientError import src.logger as logger import src.utils as utils +import src.exceptions as excp class LambdaClient(BotoClient): '''A low-level client representing aws LambdaClient. 
https://boto3.readthedocs.io/en/latest/reference/services/lambda.htmll''' + # Parameter used by the parent to create the appropriate boto3 client boto_client_name = 'lambda' def create_function(self, **kwargs): @@ -38,19 +39,9 @@ def get_function_info(self, function_name_or_arn): Returns the configuration information of the Lambda function. http://boto3.readthedocs.io/en/latest/reference/services/lambda.html#Lambda.Client.get_function_configuration ''' - try: - return self.client.get_function_configuration(FunctionName=function_name_or_arn) - except ClientError as ce: - if ce.response['Error']['Code'] == 'ResourceNotFoundException': - raise ce - else: - error_msg = "Error getting function data" - logger.error(error_msg, error_msg + ": %s" % ce) + return self.client.get_function_configuration(FunctionName=function_name_or_arn) - def get_function_environment_variables(self, function_name): - return self.get_function_info(function_name)['Environment'] - - @utils.exception(logger) + @excp.exception(logger) def update_function(self, **kwargs): ''' Updates the configuration parameters for the specified Lambda function by using the values provided in the request. @@ -59,7 +50,7 @@ def update_function(self, **kwargs): # Retrieve the global variables already defined return self.client.update_function_configuration(**kwargs) - @utils.exception(logger) + @excp.exception(logger) def list_functions(self): ''' Returns a list of your Lambda functions. @@ -75,7 +66,7 @@ def list_functions(self): functions.extend(result['Functions']) return functions - @utils.exception(logger) + @excp.exception(logger) def delete_function(self, function_name): ''' Deletes the specified Lambda function code and configuration. @@ -84,7 +75,7 @@ def delete_function(self, function_name): # Delete the lambda function return self.client.delete_function(FunctionName=function_name) - @utils.exception(logger) + @excp.exception(logger) def invoke_function(self, **kwargs): ''' Invokes a specific Lambda function. @@ -93,7 +84,7 @@ def invoke_function(self, **kwargs): response = self.client.invoke(**kwargs) return response - @utils.exception(logger) + @excp.exception(logger) def add_invocation_permission(self, **kwargs): ''' Adds a permission to the resource policy associated with the specified AWS Lambda function. @@ -103,4 +94,4 @@ def add_invocation_permission(self, **kwargs): kwargs['Action'] = "lambda:InvokeFunction" return self.client.add_permission(**kwargs) - \ No newline at end of file + diff --git a/src/providers/aws/clients/resourcegroups.py b/src/providers/aws/clients/resourcegroups.py index 2ff9f8f5..58ff3378 100644 --- a/src/providers/aws/clients/resourcegroups.py +++ b/src/providers/aws/clients/resourcegroups.py @@ -16,15 +16,16 @@ from src.providers.aws.clients.boto import BotoClient import src.logger as logger -import src.utils as utils +import src.exceptions as ex class ResourceGroupsClient(BotoClient): '''A low-level client representing aws Resource Groups Tagging API. https://boto3.readthedocs.io/en/latest/reference/services/resourcegroupstaggingapi.html''' + # Parameter used by the parent to create the appropriate boto3 client boto_client_name = 'resourcegroupstaggingapi' - @utils.exception(logger) + @ex.exception(logger) def get_tagged_resources(self, tag_filters, resource_type_filters): '''Returns all the tagged resources that are associated with the specified tags (keys and values) located in the specified region for the AWS account. 
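The client classes in this diff now obtain their boto3 clients through the reworked `GenericClient`/`BotoClient` pair shown earlier, which threads the configured profile and region into an explicit `boto3.Session`. A minimal sketch of that flow, assuming the values from the new `default_cfg` and a locally configured 'default' AWS profile:

```python
import boto3
import botocore.config

session_args = {'profile_name': 'default'}   # from aws.boto_profile in default_cfg
client_args = {'region_name': 'us-east-1'}   # from aws.region in default_cfg

# Mirrors BotoClient.client: explicit session plus a long read timeout
session = boto3.Session(**session_args)
client_args['config'] = botocore.config.Config(read_timeout=360)
lambda_client = session.client('lambda', **client_args)

# Mirrors BotoClient.get_access_key()
credentials = session.get_credentials()
print(credentials.access_key)
```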
diff --git a/src/providers/aws/clients/s3.py b/src/providers/aws/clients/s3.py index 54d8b282..0f44f1ad 100644 --- a/src/providers/aws/clients/s3.py +++ b/src/providers/aws/clients/s3.py @@ -17,21 +17,22 @@ from src.providers.aws.clients.boto import BotoClient from botocore.exceptions import ClientError import src.logger as logger -import src.utils as utils +import src.exceptions as excp class S3Client(BotoClient): '''A low-level client representing Amazon Simple Storage Service (S3Client). https://boto3.readthedocs.io/en/latest/reference/services/s3.html''' + # Parameter used by the parent to create the appropriate boto3 client boto_client_name = 's3' - @utils.exception(logger) + @excp.exception(logger) def create_bucket(self, bucket_name): '''Creates a new S3 bucket. https://boto3.readthedocs.io/en/latest/reference/services/s3.html#S3.ServiceResource.create_bucket''' self.client.create_bucket(ACL='private', Bucket=bucket_name) - @utils.exception(logger) + @excp.exception(logger) def find_bucket(self, bucket_name): '''Checks bucket existence. https://boto3.readthedocs.io/en/latest/reference/services/s3.html#S3.Client.get_bucket_location''' @@ -46,46 +47,42 @@ def find_bucket(self, bucket_name): else: raise - @utils.exception(logger) + @excp.exception(logger) def put_bucket_notification_configuration(self, bucket_name, notification): '''Enables notifications of specified events for a bucket. https://boto3.readthedocs.io/en/latest/reference/services/s3.html#S3.Client.put_bucket_notification_configuration''' return self.client.put_bucket_notification_configuration(Bucket=bucket_name, NotificationConfiguration=notification) - @utils.exception(logger) + @excp.exception(logger) def get_bucket_notification_configuration(self, bucket_name): '''Returns the notification configuration of a bucket. https://boto3.readthedocs.io/en/latest/reference/services/s3.html#S3.Client.get_bucket_notification_configuration''' return self.client.get_bucket_notification_configuration(Bucket=bucket_name) - @utils.exception(logger) - def upload_file(self, bucket_name, file_key, file_data=None): + @excp.exception(logger) + def upload_file(self, **kwargs): '''Adds an object to a bucket. https://boto3.readthedocs.io/en/latest/reference/services/s3.html#S3.Client.put_object''' - kwargs = {'Bucket' : bucket_name, 'Key' : file_key} - if file_data: - kwargs['Body'] = file_data return self.client.put_object(**kwargs) - @utils.exception(logger) - def download_file(self, bucket_name, file_key, file): + @excp.exception(logger) + def download_file(self, **kwargs): '''Download an object from S3 to a file-like object. https://boto3.readthedocs.io/en/latest/reference/services/s3.html#S3.Client.download_fileobj''' - return self.client.download_fileobj(Bucket=bucket_name, Key=file_key, Fileobj=file) + return self.client.download_fileobj(**kwargs) - @utils.exception(logger) - def list_files(self, bucket_name, key=''): + @excp.exception(logger) + def list_files(self, **kwargs): '''Returns all of the objects in a bucket. 
https://boto3.readthedocs.io/en/latest/reference/services/s3.html#S3.Client.list_objects_v2''' file_list = [] - kwargs = {"Bucket" : bucket_name, "Prefix" : key} response = self.client.list_objects_v2(**kwargs) file_list.append(response) while ('IsTruncated' in response) and (response['IsTruncated']): kwargs['ContinuationToken'] = response['NextContinuationToken'] response = self.client.list_objects_v2(**kwargs) file_list.append(response) - return response + return file_list \ No newline at end of file diff --git a/src/providers/aws/cloud/lambda/__init__.py b/src/providers/aws/cloud/lambda/__init__.py index 977f3700..13f60e3f 100644 --- a/src/providers/aws/cloud/lambda/__init__.py +++ b/src/providers/aws/cloud/lambda/__init__.py @@ -15,4 +15,4 @@ # along with this program. If not, see . -__all__ = ['utils','logger'] \ No newline at end of file +__all__ = ['utils','logger','exceptions'] \ No newline at end of file diff --git a/src/providers/aws/cloud/lambda/scarsupervisor.py b/src/providers/aws/cloud/lambda/scarsupervisor.py index 10e8d337..30a4d51e 100644 --- a/src/providers/aws/cloud/lambda/scarsupervisor.py +++ b/src/providers/aws/cloud/lambda/scarsupervisor.py @@ -32,8 +32,8 @@ import src.utils as utils logger = logging.getLogger() -if os.environ['LOG_LEVEL']: - logger.setLevel(os.environ['LOG_LEVEL']) +if utils.is_variable_in_environment('LOG_LEVEL'): + logger.setLevel(utils.get_environment_variable('LOG_LEVEL')) else: logger.setLevel('INFO') logger.info('SCAR: Loading lambda function') @@ -50,7 +50,7 @@ def client(self): return client def __init__(self): - if utils.check_key_in_dictionary('Records', lambda_instance.event): + if utils.is_value_in_dict(lambda_instance.event, 'Records'): self.record = self.get_s3_record() self.input_bucket = self.record['bucket']['name'] self.file_key = unquote_plus(self.record['object']['key']) @@ -62,7 +62,7 @@ def get_s3_record(self): logger.warning("Multiple records detected. Only processing the first one.") record = lambda_instance.event['Records'][0] - if utils.check_key_in_dictionary('s3', record): + if utils.is_value_in_dict(record, 's3'): return record['s3'] def download_input(self): @@ -238,10 +238,10 @@ def create_command(self): self.add_container_volumes() self.add_container_environment_variables() # Container running script - if utils.check_key_in_dictionary('script', lambda_instance.event): + if utils.is_value_in_dict(lambda_instance.event, 'script'): self.add_script_as_entrypoint() # Container with args - elif utils.check_key_in_dictionary('cmd_args', lambda_instance.event): + elif utils.is_value_in_dict(lambda_instance.event,'cmd_args'): self.add_args() # Script to be executed every time (if defined) elif utils.is_variable_in_environment('INIT_SCRIPT_PATH'): @@ -384,7 +384,7 @@ def __init__(self): self.create_event_file() def is_s3_event(self): - if utils.check_key_in_dictionary('Records', lambda_instance.event): + if utils.is_value_in_dict(lambda_instance.event, 'Records'): # Check if the event is an S3 event return lambda_instance.event['Records'][0]['eventSource'] == "aws:s3" return False diff --git a/src/providers/aws/cloudwatchlogs.py b/src/providers/aws/cloudwatchlogs.py index 75cf3d16..2e64f9c2 100644 --- a/src/providers/aws/cloudwatchlogs.py +++ b/src/providers/aws/cloudwatchlogs.py @@ -15,79 +15,76 @@ # along with this program. If not, see . 
from botocore.exceptions import ClientError -import src.providers.aws.response as response_parser -from src.providers.aws.clientfactory import GenericClient +from src.providers.aws.botoclientfactory import GenericClient class CloudWatchLogs(GenericClient): - def __init__(self, aws_lambda): - # Get all the log related attributes - self.log_group_name = aws_lambda.get_property("log_group_name") - self.tags = aws_lambda.get_property("tags") - self.output_type = aws_lambda.get_property("output") - self.log_retention_policy_in_days = aws_lambda.get_property("cloudwatch", "log_retention_policy_in_days") - self.log_stream_name = aws_lambda.get_property("log_stream_name") - self.request_id = aws_lambda.get_property("request_id") + def __init__(self, aws_properties): + GenericClient.__init__(self, aws_properties) + self.properties = self.aws_properties['cloudwatch'] + + def get_log_group_name(self): + return '/aws/lambda/{0}'.format(self.aws_properties['lambda']['name']) + + def get_log_group_name_arg(self): + return { 'logGroupName' : self.get_log_group_name() } def create_log_group(self): - # lambda_validator.validate_log_creation_values(self.aws_lambda) - response = self.client.create_log_group(self.log_group_name, self.tags) - response_parser.parse_log_group_creation_response(response, - self.log_group_name, - self.output_type) + creation_args = self.get_log_group_name_arg() + creation_args['tags'] = self.aws_properties['tags'] + response = self.client.create_log_group(**creation_args) # Set retention policy into the log group - self.client.set_log_retention_policy(self.log_group_name, - self.log_retention_policy_in_days) - - def set_log_group_name(self, function_name=None): - self.log_group_name = '/aws/lambda/' + function_name - - def delete_log_group(self, func_name=None): - if func_name: - self.set_log_group_name(func_name) - cw_response = self.client.delete_log_group(self.log_group_name) - response_parser.parse_delete_log_response(cw_response, self.log_group_name, self.output_type) + retention_args = self.get_log_group_name_arg() + retention_args['retentionInDays'] = self.properties['log_retention_policy_in_days'] + self.client.set_log_retention_policy(**retention_args) + return response + def delete_log_group(self): + return self.client.delete_log_group(**self.get_log_group_name_arg()) + def get_aws_log(self): - function_log = "" + function_logs = "" try: - full_msg = "" - result = self.client.get_log_events(self.log_group_name, self.log_stream_name) - data = [] - for response in result: - for event in response['events']: - data.append((event['message'], event['timestamp'])) - sorted_data = sorted(data, key=lambda time: time[1]) - for sdata in sorted_data: - full_msg += sdata[0] - response['completeMessage'] = full_msg - if self.request_id: - function_log = self.parse_aws_logs(full_msg) - else: - function_log = full_msg + kwargs = self.get_log_group_name_arg() + if 'log_stream_name' in self.properties: + kwargs["logStreamNames"] = [self.properties['log_stream_name']] + response = self.client.get_log_events(**kwargs) + function_logs = self.sort_events_in_message(response) + if 'request_id' in self.properties and self.properties['request_id']: + function_logs = self.parse_logs_with_requestid(function_logs) except ClientError as ce: print ("Error getting the function logs: %s" % ce) - - return function_log + return function_logs + + def sort_events_in_message(self, response): + sorted_msg = "" + data = [] + for elem in response: + for event in elem['events']: + data.append((event['message'], 
event['timestamp'])) + sorted_data = sorted(data, key=lambda time: time[1]) + for sdata in sorted_data: + sorted_msg += sdata[0] + return sorted_msg def is_end_line(self, line): - return line.startswith('REPORT') and self.request_id in line + return line.startswith('REPORT') and self.properties['request_id'] in line def is_start_line(self, line): - return line.startswith('START') and self.request_id in line + return line.startswith('START') and self.properties['request_id'] in line - def parse_aws_logs(self, logs): - if logs and self.request_id: - full_msg = "" - logging = False - for line in logs.split('\n'): + def parse_logs_with_requestid(self, function_logs): + if function_logs: + parsed_msg = "" + in_reqid_logs = False + for line in function_logs.split('\n'): if self.is_start_line(line): - full_msg += line + '\n' - logging = True + parsed_msg += line + '\n' + in_reqid_logs = True elif self.is_end_line(line): - full_msg += line + '\n' - return full_msg - elif logging: - full_msg += line + '\n' - + parsed_msg += line + break + elif in_reqid_logs: + parsed_msg += line + '\n' + return parsed_msg \ No newline at end of file diff --git a/src/providers/aws/controller.py b/src/providers/aws/controller.py index d46de74b..8b1d0221 100644 --- a/src/providers/aws/controller.py +++ b/src/providers/aws/controller.py @@ -19,194 +19,270 @@ from src.providers.aws.s3 import S3 from src.providers.aws.iam import IAM from src.providers.aws.resourcegroups import ResourceGroups -from botocore.exceptions import ClientError -from src.cmdtemplate import Commands +from src.cmdtemplate import Commands, CallType +from src.providers.aws.validators import AWSValidator import src.logger as logger import src.providers.aws.response as response_parser import src.utils as utils import os +import src.exceptions as excp class AWS(Commands): + properties = {} + @utils.lazy_property def _lambda(self): '''It's called _lambda because 'lambda' it's a restricted word in python''' - _lambda = Lambda() - return _lambda + _lambda = Lambda(self.properties) + return _lambda @utils.lazy_property def cloudwatch_logs(self): - cloudwatch_logs = CloudWatchLogs(self._lambda) + cloudwatch_logs = CloudWatchLogs(self.properties) return cloudwatch_logs @utils.lazy_property def api_gateway(self): - api_gateway = APIGateway(self._lambda) + api_gateway = APIGateway(self.properties) return api_gateway @utils.lazy_property def s3(self): - s3 = S3(self._lambda) + s3 = S3(self.properties) return s3 @utils.lazy_property def resource_groups(self): - resource_groups = ResourceGroups() + resource_groups = ResourceGroups(self.properties) return resource_groups @utils.lazy_property def iam(self): - iam = IAM() + iam = IAM(self.properties) return iam - @utils.exception(logger) + @excp.exception(logger) def init(self): - if self._lambda.has_api_defined(): - api_id, aws_acc_id = self.api_gateway.create_api_gateway() - self._lambda.set_api_gateway_id(api_id, aws_acc_id) - - # Call the aws services - self._lambda.create_function() - self.cloudwatch_logs.create_log_group() + if self._lambda.find_function(): + raise excp.FunctionExistsError(function_name=self._lambda.properties['name']) - if self._lambda.has_input_bucket(): - self.create_input_source() - - if self._lambda.has_output_bucket(): - self.s3.create_bucket(self._lambda.get_output_bucket()) - - if self._lambda.has_api_defined(): + if 'api_gateway' in self.properties: + self.api_gateway.create_api_gateway() + + response = self._lambda.create_function() + if response: + 
response_parser.parse_lambda_function_creation_response(response, + self._lambda.properties['name'], + self._lambda.client.get_access_key(), + self.properties['output']) + response = self.cloudwatch_logs.create_log_group() + if response: + response_parser.parse_log_group_creation_response(response, + self.cloudwatch_logs.get_log_group_name(), + self.properties['output']) + + if 's3' in self.properties: + self.manage_s3_init() + + if 'api_gateway' in self.properties: self._lambda.add_invocation_permission_from_api_gateway() # If preheat is activated, the function is launched at the init step - if self._lambda.need_preheat(): + if 'preheat' in self.scar_properties: self._lambda.preheat_function() + @excp.exception(logger) def invoke(self): - function_name = self._lambda.get_function_name() - response = self._lambda.invoke_function_http(function_name) + response = self._lambda.invoke_http_endpoint() response_parser.parse_http_response(response, - function_name, - self._lambda.get_property("asynchronous")) + self._lambda.properties['name'], + self._lambda.is_asynchronous()) + @excp.exception(logger) def run(self): - if self._lambda.has_input_bucket(): + if 's3' in self.properties and 'input_bucket' in self.properties['s3']: self.process_input_bucket_calls() else: if self._lambda.is_asynchronous(): self._lambda.set_asynchronous_call_parameters() self._lambda.launch_lambda_instance() - + + @excp.exception(logger) def update(self): self._lambda.update_function_attributes() + @excp.exception(logger) def ls(self): - bucket_name = self._lambda.get_property("bucket") - bucket_folder = self._lambda.get_property("bucket_folder") - if bucket_name: - file_list = self.s3.get_bucket_files(bucket_name, bucket_folder) + if 's3' in self.properties: + file_list = self.s3.get_bucket_file_list() for file_info in file_list: print(file_info) else: lambda_functions = self.get_all_functions() response_parser.parse_ls_response(lambda_functions, - self._lambda.get_output_type()) + self.properties['output']) + @excp.exception(logger) def rm(self): - if self._lambda.delete_all(): - self.delete_all_resources(self.get_all_functions()) + if 'delete_all' in self.scar_properties and self.scar_properties['delete_all']: + self.delete_all_resources(self.get_all_functions()) else: - self.delete_resources(self._lambda.get_function_name()) - + self.delete_resources() + + @excp.exception(logger) def log(self): aws_log = self.cloudwatch_logs.get_aws_log() print(aws_log) + @excp.exception(logger) def put(self): - bucket_name = self._lambda.get_property("bucket") - bucket_folder = self._lambda.get_property("bucket_folder") - path_to_upload = self._lambda.get_property("path") - self.upload_to_s3(bucket_name, bucket_folder, path_to_upload) - + self.upload_file_or_folder_to_s3() + + @excp.exception(logger) def get(self): - bucket_name = self._lambda.get_property("bucket") - file_prefix = self._lambda.get_property("bucket_folder") - output_path = self._lambda.get_property("path") - self.s3.download_bucket_files(bucket_name, file_prefix, output_path) + self.download_file_or_folder_from_s3() - def parse_command_arguments(self, args): - self._lambda.set_properties(args) + @AWSValidator.validate() + @excp.exception(logger) + def parse_arguments(self, **kwargs): + self.properties = kwargs['aws'] + self.scar_properties = kwargs['scar'] + self.add_extra_aws_properties() + def add_extra_aws_properties(self): + self.add_tags() + self.add_output() +# self.add_call_type() + self.add_account_id() + + def add_tags(self): + self.properties["tags"] = {} 
+ self.properties["tags"]['createdby'] = 'scar' + self.properties["tags"]['owner'] = self.iam.get_user_name_or_id() + + def add_output(self): + self.properties["output"] = response_parser.OutputType.PLAIN_TEXT + if 'json' in self.properties and self.properties['json']: + self.properties["output"] = response_parser.OutputType.JSON + # Override json ouput if both of them are defined + if 'verbose' in self.properties and self.properties['verbose']: + self.properties["output"] = response_parser.OutputType.VERBOSE + + def add_account_id(self): + self.properties['account_id'] = utils.find_expression(self.properties['iam']['role'], '\d{12}') + def get_all_functions(self): - functions_arn_list = self.get_functions_arn_list() + user_id = self.iam.get_user_name_or_id() + functions_arn_list = self.resource_groups.get_lambda_functions_arn_list(user_id) return self._lambda.get_all_functions(functions_arn_list) - def get_functions_arn_list(self): - user_id = self.iam.get_user_name_or_id() - return self.resource_groups.get_lambda_functions_arn_list(user_id) + def manage_s3_init(self): + if 'input_bucket' in self.properties['s3']: + self.create_s3_source() + if 'output_bucket' in self.properties['s3']: + self.s3.create_output_bucket() + + @excp.exception(logger) + def create_s3_source(self): + self.s3.create_input_bucket(create_input_folder=True) + self._lambda.link_function_and_input_bucket() + self.s3.set_input_bucket_notification() def process_input_bucket_calls(self): - s3_file_list = self.s3.get_processed_bucket_file_list() - logger.info("Files found: '%s'" % s3_file_list) + s3_file_list = self.s3.get_bucket_file_list() + logger.info("Files found: '{0}'".format(s3_file_list)) # First do a request response invocation to prepare the lambda environment if s3_file_list: - s3_file = s3_file_list.pop(0) - self._lambda.launch_request_response_event(s3_file) + s3_event = self.s3.get_s3_event(s3_file_list.pop(0)) + self._lambda.launch_request_response_event(s3_event) # If the list has more elements, invoke functions asynchronously if s3_file_list: - self._lambda.process_asynchronous_lambda_invocations(s3_file_list) + s3_event_list = self.s3.get_s3_event_list(s3_file_list) + self._lambda.process_asynchronous_lambda_invocations(s3_event_list) - def upload_to_s3(self, bucket_name, bucket_folder, path_to_upload): - self.s3.create_bucket(bucket_name) - if(os.path.isdir(path_to_upload)): - files = utils.get_all_files_in_directory(path_to_upload) - else: - files = [path_to_upload] - for file in files: - self.upload_file_to_s3(bucket_name, bucket_folder, file) - - def upload_file_to_s3(self, bucket_name, bucket_folder, file_path): - file_data = utils.read_file(file_path, 'rb') - file_name = os.path.basename(file_path) - file_key = "{0}".format(file_name) - if bucket_folder and bucket_folder != "" and bucket_folder.endswith("/"): - file_key = "{0}{1}".format(bucket_folder, file_name) - else: - file_key = "{0}/{1}".format(bucket_folder, file_name) - logger.info("Uploading file '{0}' to bucket '{1}' with key '{2}'".format(file_path, bucket_name, file_key)) - self.s3.upload_file(bucket_name, file_key, file_data) + def upload_file_or_folder_to_s3(self): + path_to_upload = self.scar_properties['path'] + bucket_folder = self.s3.properties['input_folder'] + self.s3.create_input_bucket() + files = utils.get_all_files_in_directory(path_to_upload) if os.path.isdir(path_to_upload) else [path_to_upload] + for file_path in files: + self.s3.upload_file(folder_name=bucket_folder, file_path=file_path) - def create_input_source(self): 
- try: - self.s3.create_input_bucket() - self._lambda.link_function_and_input_bucket() - self.s3.set_input_bucket_notification() - except ClientError as ce: - error_msg = "Error creating the event source" - logger.error(error_msg, error_msg + ": %s" % ce) + def get_download_file_path(self, s3_file_key, file_prefix): + file_path = s3_file_key + # Parse file path + if file_prefix: + # Get folder name + dir_name_to_add = os.path.basename(os.path.dirname(file_prefix)) + # Don't replace last '/' + file_path = s3_file_key.replace(file_prefix[:-1], dir_name_to_add) + if 'path' in self.scar_properties and self.scar_properties['path']: + path_to_download = self.scar_properties['path'] + file_path = utils.join_paths(path_to_download, file_path) + return file_path + def download_file_or_folder_from_s3(self): + bucket_name = self.s3.properties['input_bucket'] + file_prefix = self.s3.properties['input_folder'] + s3_file_list = self.s3.get_bucket_file_list() + for s3_file in s3_file_list: + # Avoid download s3 'folders' + if not s3_file.endswith('/'): + file_path = self.get_download_file_path(s3_file, file_prefix) + # make sure the path folders are created + dir_path = os.path.dirname(file_path) + if dir_path and not os.path.isdir(dir_path): + os.makedirs(dir_path, exist_ok=True) + self.s3.download_file(bucket_name, s3_file, file_path) + def delete_all_resources(self, lambda_functions): for function in lambda_functions: - self.delete_resources(function['FunctionName']) + self._lambda.properties['name'] = function['FunctionName'] + self.delete_resources() - def delete_resources(self, function_name): + def delete_resources(self): + if not self._lambda.find_function(): + raise excp.FunctionNotFoundError(self.properties['lambda']['name']) # Delete associated api - api_id = self._lambda.get_api_gateway_id(function_name) - output_type = self._lambda.get_output_type() - if api_id: - self.api_gateway.delete_api_gateway(api_id, output_type) + self.delete_api_gateway() # Delete associated log - self.cloudwatch_logs.delete_log_group(function_name) + self.delete_logs() # Delete associated notifications - func_info = self._lambda.get_function_info(function_name) - function_arn = func_info['FunctionArn'] - variables = func_info['Environment']['Variables'] - if 'INPUT_BUCKET' in variables: - bucket_name = variables['INPUT_BUCKET'] - self.s3.delete_bucket_notification(bucket_name, function_arn) + self.delete_bucket_notifications() # Delete function - self._lambda.delete_function(function_name) + self.delete_lambda_function() + + def delete_lambda_function(self): + response = self._lambda.delete_function() + if response: + response_parser.parse_delete_function_response(response, + self.properties['lambda']['name'], + self.properties['output']) + + def delete_bucket_notifications(self): + func_info = self._lambda.get_function_info() + self.properties['lambda']['arn'] = func_info['FunctionArn'] + self.properties['lambda']['environment'] = {'Variables' : func_info['Environment']['Variables']} + if 'INPUT_BUCKET' in self.properties['lambda']['environment']['Variables']: + self.properties['s3'] = {'input_bucket' : self.properties['lambda']['environment']['Variables']['INPUT_BUCKET']} + self.s3.delete_bucket_notification() + + def delete_logs(self): + response = self.cloudwatch_logs.delete_log_group() + if response: + response_parser.parse_delete_log_response(response, + self.cloudwatch_logs.get_log_group_name(), + self.properties['output']) + def delete_api_gateway(self): + self.properties['api_gateway'] = {'id' : 
self._lambda.get_api_gateway_id() } + if self.properties['api_gateway']['id']: + response = self.api_gateway.delete_api_gateway() + if response: + response_parser.parse_delete_api_response(response, + self.properties['api_gateway']['id'], + self.properties['output']) diff --git a/src/providers/aws/functioncode.py b/src/providers/aws/functioncode.py new file mode 100644 index 00000000..893ad176 --- /dev/null +++ b/src/providers/aws/functioncode.py @@ -0,0 +1,189 @@ +# SCAR - Serverless Container-aware ARchitectures +# Copyright (C) 2011 - GRyCAP - Universitat Politecnica de Valencia +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import os +import shutil +import src.logger as logger +import src.utils as utils +import subprocess +import tempfile +from distutils import dir_util +import src.exceptions as excp +from src.providers.aws.validators import AWSValidator + +MAX_PAYLOAD_SIZE = 50 * 1024 * 1024 +MAX_S3_PAYLOAD_SIZE = 250 * 1024 * 1024 + +def udocker_env(func): + ''' + Decorator used to avoid losing the definition of the udocker + environment variables (if any) + ''' + def wrapper(*args, **kwargs): + FunctionPackageCreator.save_tmp_udocker_env() + func(*args, **kwargs) + FunctionPackageCreator.restore_udocker_env() + return wrapper + +class FunctionPackageCreator(): + + src_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + aws_src_path = os.path.dirname(os.path.abspath(__file__)) + lambda_code_files_path = utils.join_paths(aws_src_path, "cloud/lambda/") + os_tmp_folder = tempfile.gettempdir() + scar_temporal_folder = utils.join_paths(os_tmp_folder, "scar/") + + supervisor_source = utils.join_paths(lambda_code_files_path, "scarsupervisor.py") + + udocker_source = utils.join_paths(lambda_code_files_path, "udockerb") + udocker_dest = utils.join_paths(scar_temporal_folder, "udockerb") + + udocker_exec = ["python3", udocker_dest] + udocker_tarball = "" + udocker_dir = "" + init_script_name = "init_script.sh" + init_script_path = "/var/task/{0}".format(init_script_name) + extra_payload_path = "/var/task" + + def __init__(self, package_props): + self.properties = package_props + self.lambda_code_name = "{0}.py".format(self.properties['FunctionName']) + self.supervisor_dest = utils.join_paths(self.scar_temporal_folder, self.lambda_code_name) + + @excp.exception(logger) + def prepare_lambda_code(self): + self.clean_tmp_folders() + self.add_mandatory_files() + + if 'DeploymentBucket' in self.properties and 'ImageId' in self.properties: + self.download_udocker_image() + if 'ImageFile' in self.properties: + self.prepare_udocker_image() + + self.add_init_script() + self.add_extra_payload() + self.zip_scar_folder() + self.check_code_size() + + def add_mandatory_files(self): + os.makedirs(self.scar_temporal_folder, exist_ok=True) + shutil.copy(self.supervisor_source, self.supervisor_dest) + shutil.copy(self.udocker_source, self.udocker_dest) + + os.makedirs(utils.join_paths(self.scar_temporal_folder, 
"src"), exist_ok=True) + shutil.copy(utils.join_paths(self.lambda_code_files_path, "__init__.py"), + utils.join_paths(self.scar_temporal_folder, "src/__init__.py")) + shutil.copy(utils.join_paths(self.src_path, "utils.py"), + utils.join_paths(self.scar_temporal_folder, "src/utils.py")) + shutil.copy(utils.join_paths(self.src_path, "exceptions.py"), + utils.join_paths(self.scar_temporal_folder, "src/exceptions.py")) + + self.set_environment_variable('UDOCKER_DIR', "/tmp/home/udocker") + self.set_environment_variable('UDOCKER_LIB', "/var/task/udocker/lib/") + self.set_environment_variable('UDOCKER_BIN', "/var/task/udocker/bin/") + self.create_udocker_files() + + @udocker_env + def create_udocker_files(self): + self.execute_command(self.udocker_exec + ["help"], cli_msg="Packing udocker files") + + def add_init_script(self): + if 'Script' in self.properties: + shutil.copy(self.properties['Script'], utils.join_paths(self.scar_temporal_folder, self.init_script_name)) + self.properties['EnvironmentVariables']['INIT_SCRIPT_PATH'] = self.init_script_path + + def add_extra_payload(self): + if 'ExtraPayload' in self.properties: + logger.info("Adding extra payload from {0}".format(self.properties['ExtraPayload'])) + dir_util.copy_tree(self.properties['ExtraPayload'], self.scar_temporal_folder) + self.set_environment_variable('EXTRA_PAYLOAD', self.extra_payload_path) + + def check_code_size(self): + # Check if the code size fits within the aws limits + if 'DeploymentBucket' in self.properties: + AWSValidator.validate_s3_code_size(self.scar_temporal_folder, MAX_S3_PAYLOAD_SIZE) + else: + AWSValidator.validate_function_code_size(self.scar_temporal_folder, MAX_PAYLOAD_SIZE) + + def clean_tmp_folders(self): + if os.path.isfile(self.properties['ZipFilePath']): + utils.delete_file(self.properties['ZipFilePath']) + # Delete created temporal files + if os.path.isdir(self.scar_temporal_folder): + shutil.rmtree(self.scar_temporal_folder, ignore_errors=True) + + def zip_scar_folder(self): + self.execute_command(["zip", "-r9y", self.properties['ZipFilePath'], "."], + cmd_wd=self.scar_temporal_folder, + cli_msg="Creating function package") + + @classmethod + def save_tmp_udocker_env(cls): + #Avoid override global variables + if utils.is_value_in_dict(os.environ, 'UDOCKER_TARBALL'): + cls.udocker_tarball = os.environ['UDOCKER_TARBALL'] + if utils.is_value_in_dict(os.environ, 'UDOCKER_DIR'): + cls.udocker_dir = os.environ['UDOCKER_DIR'] + # Set temporal global vars + utils.set_environment_variable('UDOCKER_TARBALL', utils.join_paths(cls.lambda_code_files_path, "udocker-1.1.0-RC2.tar.gz")) + utils.set_environment_variable('UDOCKER_DIR', utils.join_paths(cls.scar_temporal_folder, "udocker")) + + @classmethod + def restore_udocker_env(cls): + cls.restore_environ_var('UDOCKER_TARBALL', cls.udocker_tarball) + cls.restore_environ_var('UDOCKER_DIR', cls.udocker_dir) + + @classmethod + def restore_environ_var(cls, key, var): + if var: + utils.set_environment_variable(key, var) + else: + del os.environ[key] + + def execute_command(self, command, cmd_wd=None, cli_msg=None): + cmd_out = subprocess.check_output(command, cwd=cmd_wd).decode("utf-8") + logger.info(cli_msg, cmd_out) + return cmd_out[:-1] + + @udocker_env + def prepare_udocker_image(self): + image_path = utils.join_paths(self.os_tmp_folder, "udocker_image.tar.gz") + shutil.copy(self.properties['ImageFile'], image_path) + cmd_out = self.execute_command(self.udocker_exec + ["load", "-i", image_path], cli_msg="Loading image file") + 
self.create_udocker_container(cmd_out) + self.set_environment_variable('IMAGE_ID', cmd_out) + self.set_udocker_local_registry() + + @udocker_env + def download_udocker_image(self): + self.execute_command(self.udocker_exec + ["pull", self.properties['ImageId']], cli_msg="Downloading container image") + self.create_udocker_container(self.properties['ImageId']) + self.set_udocker_local_registry() + + def create_udocker_container(self, image_id): + if(utils.get_tree_size(self.scar_temporal_folder) < MAX_S3_PAYLOAD_SIZE/2): + self.execute_command(self.udocker_exec + ["create", "--name=lambda_cont", image_id], cli_msg="Creating container structure") + if(utils.get_tree_size(self.scar_temporal_folder) > MAX_S3_PAYLOAD_SIZE): + shutil.rmtree(utils.join_paths(self.scar_temporal_folder, "udocker/containers/")) + + def set_udocker_local_registry(self): + self.set_environment_variable('UDOCKER_REPOS', '/var/task/udocker/repos/') + self.set_environment_variable('UDOCKER_LAYERS', '/var/task/udocker/layers/') + + def set_environment_variable(self, key, val): + if key and val: + self.properties['EnvironmentVariables'][key] = val + \ No newline at end of file diff --git a/src/providers/aws/iam.py b/src/providers/aws/iam.py index d108c647..2296f793 100644 --- a/src/providers/aws/iam.py +++ b/src/providers/aws/iam.py @@ -14,11 +14,12 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -from src.providers.aws.clientfactory import GenericClient +from src.providers.aws.botoclientfactory import GenericClient class IAM(GenericClient): def get_user_name_or_id(self): user = self.client.get_user_info() - return user.get('UserName', user['User']['UserId']) + if user: + return user.get('UserName', user['User']['UserId']) \ No newline at end of file diff --git a/src/providers/aws/lambdafunction.py b/src/providers/aws/lambdafunction.py index ddd3f33d..d8dee64c 100644 --- a/src/providers/aws/lambdafunction.py +++ b/src/providers/aws/lambdafunction.py @@ -14,179 +14,141 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . 
-from src.providers.aws.iam import IAM from botocore.exceptions import ClientError -from enum import Enum from multiprocessing.pool import ThreadPool -from src.parser.cfgfile import ConfigFile -from src.providers.aws.response import OutputType +from src.providers.aws.botoclientfactory import GenericClient +from src.providers.aws.functioncode import FunctionPackageCreator +from src.providers.aws.s3 import S3 +import base64 import json -import os -import src.http.invoke as invoke +import src.exceptions as excp +import src.http.request as request import src.logger as logger -import src.providers.aws.payload as codezip -import src.providers.aws.validators as validators import src.providers.aws.response as response_parser import src.utils as utils -import tempfile -import base64 -from src.providers.aws.clientfactory import GenericClient -import src.exceptions as scar_excp MAX_CONCURRENT_INVOCATIONS = 1000 -MAX_POST_BODY_SIZE = 1024*1024*6 -MAX_POST_BODY_SIZE_ASYNC = 1024*95 - -class CallType(Enum): - INIT = "init" - RUN = "run" - UPDATE = "update" - LS = "ls" - RM = "rm" - LOG = "log" - INVOKE = "invoke" - PUT = "put" - GET = "get" - -def get_call_type(value): - for call_type in CallType: - if call_type.value == value: - return call_type +MB = 1024*1024 +KB = 1024 +MAX_POST_BODY_SIZE = MB*6 +MAX_POST_BODY_SIZE_ASYNC = KB*95 class Lambda(GenericClient): - properties = { - "runtime" : "python3.6", - "invocation_type" : "RequestResponse", - "log_type" : "Tail", - "output" : OutputType.PLAIN_TEXT, - "payload" : {}, - "tags" : {}, - "environment" : { 'Variables' : {} }, - "environment_variables" : {}, - "name_regex" : "(arn:(aws[a-zA-Z-]*)?:lambda:)?([a-z]{2}(-gov)?-[a-z]+-\d{1}:)?(\d{12}:)?(function:)?([a-zA-Z0-9-_]+)(:(\$LATEST|[a-zA-Z0-9-_]+))?", - "s3_event" : { "Records" : [ - {"eventSource" : "aws:s3", - "s3" : {"bucket" : { "name" : "" }, - "object" : { "key" : "" } } - }]}, - "zip_file_path" : os.path.join(tempfile.gettempdir(), 'function.zip') - } - - def __init__(self): - self.set_config_file_properties() - validators.validate_iam_role(self.properties["iam"]) - - def set_config_file_properties(self): - config_file_props = ConfigFile().get_aws_props() - self.properties = utils.merge_dicts(self.properties, config_file_props['lambda']) - self.properties['iam'] = config_file_props['iam'] - self.properties['cloudwatch'] = config_file_props['cloudwatch'] - - def get_property(self, value, nested_value=None): - if value in self.properties: - if nested_value and nested_value in self.properties[value]: - return self.properties[value][nested_value] - else: - return self.properties[value] - - def set_property(self, key, value): - self.properties[key] = value + def __init__(self, aws_properties): + GenericClient.__init__(self, aws_properties) + self.properties = aws_properties['lambda'] + self.properties['environment'] = {'Variables' : {}} + self.properties['zip_file_path'] = utils.join_paths(utils.get_temp_dir(), 'function.zip') + self.properties['invocation_type'] = 'RequestResponse' + self.properties['log_type'] = 'Tail' + if 'name' in self.properties: + self.properties['handler'] = "{0}.lambda_handler".format(self.properties['name']) + if 'asynchronous' not in self.properties: + self.properties['asynchronous'] = False - def delete_all(self): - return self.get_property("all") - - def get_output_type(self): - return self.get_property("output") - - def get_function_name(self): - return self.get_property("name") - - def get_function_arn(self): - return self.get_property("function_arn") - - def 
need_preheat(self): - return self.get_property("preheat") - - def get_input_bucket(self): - return self.get_property("input_bucket") - - def get_output_bucket(self): - return self.get_property("output_bucket") - def get_creations_args(self): - return {'FunctionName' : self.get_property("name"), - 'Runtime' : self.get_property("runtime"), - 'Role' : self.get_property("iam", "role"), - 'Handler' : self.get_property("handler"), - 'Code' : self.get_property("code"), - 'Environment' : self.get_property("environment"), - 'Description':self.get_property("description"), - 'Timeout': self.get_property("time"), - 'MemorySize':self.get_property("memory"), - 'Tags':self.get_property("tags") } + return {'FunctionName' : self.properties['name'], + 'Runtime' : self.properties['runtime'], + 'Role' : self.aws_properties['iam']['role'], + 'Handler' : self.properties['handler'], + 'Code' : self.properties['code'], + 'Environment' : self.properties['environment'], + 'Description': self.properties['description'], + 'Timeout': self.properties['time'], + 'MemorySize': self.properties['memory'], + 'Tags': self.aws_properties['tags'], + } + @excp.exception(logger) def create_function(self): - try: - response = self.client.create_function(**self.get_creations_args()) - if response and 'FunctionArn' in response: - self.properties["function_arn"] = response['FunctionArn'] - response_parser.parse_lambda_function_creation_response(response, - self.get_function_name(), - self.client.get_access_key(), - self.get_output_type()) - except ClientError as ce: - error_msg = "Error initializing lambda function." - logger.error(error_msg, error_msg + ": %s" % ce) - finally: - # Remove the files created in the operation - utils.delete_file(self.properties["zip_file_path"]) + self.set_environment_variables() + self.set_function_code() + creation_args = self.get_creations_args() + response = self.client.create_function(**creation_args) + if response and 'FunctionArn' in response: + self.properties["function_arn"] = response['FunctionArn'] + return response - def delete_function(self, func_name=None): - if func_name: - function_name = func_name - else: - function_name = self.get_function_name() - self.check_function_name(function_name) - # Delete lambda function - response = self.client.delete_function(function_name) - response_parser.parse_delete_function_response(response, - function_name, - self.get_output_type()) + def set_environment_variables(self): + # Add required variables + self.set_required_environment_variables() + # Add explicitly user defined variables + if 'environment_variables' in self.properties: + if type(self.properties['environment_variables']) is dict: + for key, val in self.properties['environment_variables'].items(): + # Add an specific prefix to be able to find the variables defined by the user + self.add_lambda_environment_variable('CONT_VAR_{0}'.format(key), val) + else: + for env_var in self.properties['environment_variables']: + key_val = env_var.split("=") + # Add an specific prefix to be able to find the variables defined by the user + self.add_lambda_environment_variable('CONT_VAR_{0}'.format(key_val[0]), key_val[1]) + + def set_required_environment_variables(self): + self.add_lambda_environment_variable('TIMEOUT_THRESHOLD', str(self.properties['timeout_threshold'])) + self.add_lambda_environment_variable('LOG_LEVEL', self.properties['log_level']) + if utils.is_value_in_dict(self.properties, 'image'): + self.add_lambda_environment_variable('IMAGE_ID', self.properties['image']) + 
self.add_s3_environment_vars() + if 'api_gateway' in self.aws_properties: + self.add_lambda_environment_variable('API_GATEWAY_ID', self.aws_properties['api_gateway']['id']) + + def add_s3_environment_vars(self): + if utils.is_value_in_dict(self.aws_properties, 's3'): + s3_props = self.aws_properties['s3'] + if utils.is_value_in_dict(self.aws_properties, 'input_bucket'): + self.add_lambda_environment_variable('INPUT_BUCKET', s3_props['input_bucket']) + if utils.is_value_in_dict(self.aws_properties, 'output_bucket'): + self.add_lambda_environment_variable('OUTPUT_BUCKET', s3_props['output_bucket']) + if utils.is_value_in_dict(self.aws_properties, 'output_folder'): + self.add_lambda_environment_variable('OUTPUT_FOLDER', s3_props['output_folder']) - def create_function_name(self, image_id_or_path): - parsed_id_or_path = image_id_or_path.replace('/', ',,,').replace(':', ',,,').replace('.', ',,,').split(',,,') - name = "scar-{0}".format('-'.join(parsed_id_or_path)) - i = 1 - while self.find_function(name): - name = "scar-{0}-{1}".format('-'.join(parsed_id_or_path), str(i)) - i += 1 - return name + + def add_lambda_environment_variable(self, key, value): + if key and value: + self.properties['environment']['Variables'][key] = value - @utils.exception(logger) - def check_function_name(self, func_name=None): - call_type = self.get_property("call_type") - if func_name: - function_name = func_name + @excp.exception(logger) + def set_function_code(self): + package_props = self.get_function_payload_props() + # Zip all the files and folders needed + FunctionPackageCreator(package_props).prepare_lambda_code() + if 'DeploymentBucket' in package_props: + self.aws_properties['s3']['input_bucket'] = package_props['DeploymentBucket'] + S3(self.aws_properties).upload_file(file_path=package_props['ZipFilePath'], file_key=package_props['FileKey']) + self.properties['code'] = {"S3Bucket": package_props['DeploymentBucket'], + "S3Key" : package_props['FileKey'],} else: - function_name = self.get_property("name") - function_found = self.find_function(function_name) - error_msg = None - if function_found and (call_type == CallType.INIT): - error_msg = "Function name '{0}' already used.".format(function_name) - raise scar_excp.FunctionCreationError(function_name=function_name, error_msg=error_msg) - elif (not function_found) and ((call_type == CallType.RM) or - (call_type == CallType.RUN) or - (call_type == CallType.INVOKE)): - error_msg = "Function '{0}' doesn't exist.".format(function_name) - raise scar_excp.FunctionNotFoundError(function_name=function_name, error_msg=error_msg) - if error_msg: - logger.error(error_msg) + self.properties['code'] = {"ZipFile": utils.read_file(self.properties['zip_file_path'], mode="rb")} + + def get_function_payload_props(self): + package_args = {'FunctionName' : self.properties['name'], + 'EnvironmentVariables' : self.properties['environment']['Variables'], + 'ZipFilePath' : self.properties['zip_file_path'], + } + if 'init_script' in self.properties: + package_args['Script'] = self.properties['init_script'] + if 'extra_payload' in self.properties: + package_args['ExtraPayload'] = self.properties['extra_payload'] + if 'image_id' in self.properties: + package_args['ImageId'] = self.properties['image_id'] + if 'image_file' in self.properties: + package_args['ImageFile'] = self.properties['image_file'] + if 's3' in self.aws_properties: + if 'deployment_bucket' in self.aws_properties['s3']: + package_args['DeploymentBucket'] = self.aws_properties['s3']['deployment_bucket'] + if 
'DeploymentBucket' in package_args: + package_args['FileKey'] = 'lambda/{0}.zip'.format(self.properties['name']) + return package_args + + def delete_function(self): + return self.client.delete_function(self.properties['name']) def link_function_and_input_bucket(self): - kwargs = {'FunctionName' : self.get_function_name(), + kwargs = {'FunctionName' : self.properties['name'], 'Principal' : "s3.amazonaws.com", - 'SourceArn' : 'arn:aws:s3:::{0}'.format(self.get_input_bucket())} + 'SourceArn' : 'arn:aws:s3:::{0}'.format(self.aws_properties['s3']['input_bucket'])} self.client.add_invocation_permission(**kwargs) def preheat_function(self): @@ -194,219 +156,94 @@ def preheat_function(self): self.set_request_response_call_parameters() return self.launch_lambda_instance() - def launch_async_event(self, s3_file): + def launch_async_event(self, s3_event): self.set_asynchronous_call_parameters() - return self.launch_s3_event(s3_file) + return self.launch_s3_event(s3_event) - def launch_request_response_event(self, s3_file): + def launch_request_response_event(self, s3_event): self.set_request_response_call_parameters() - return self.launch_s3_event(s3_file) + return self.launch_s3_event(s3_event) - def launch_s3_event(self, s3_file): - self.set_s3_event_source(s3_file) - self.set_property('payload', self.get_property("s3_event")) - logger.info("Sending event for file '%s'" % s3_file) + def launch_s3_event(self, s3_event): + self.properties['payload'] = s3_event + logger.info("Sending event for file '{0}'".format(s3_event['Records'][0]['s3']['object']['key'])) return self.launch_lambda_instance() - def process_asynchronous_lambda_invocations(self, s3_file_list): - if (len(s3_file_list) > MAX_CONCURRENT_INVOCATIONS): - s3_file_chunk_list = utils.divide_list_in_chunks(s3_file_list, MAX_CONCURRENT_INVOCATIONS) + def process_asynchronous_lambda_invocations(self, s3_event_list): + if (len(s3_event_list) > MAX_CONCURRENT_INVOCATIONS): + s3_file_chunk_list = utils.divide_list_in_chunks(s3_event_list, MAX_CONCURRENT_INVOCATIONS) for s3_file_chunk in s3_file_chunk_list: self.launch_concurrent_lambda_invocations(s3_file_chunk) else: - self.launch_concurrent_lambda_invocations(s3_file_list) + self.launch_concurrent_lambda_invocations(s3_event_list) - def launch_concurrent_lambda_invocations(self, s3_file_list): - pool = ThreadPool(processes=len(s3_file_list)) - pool.map( - lambda s3_file: self.launch_async_event(s3_file), s3_file_list - ) + def launch_concurrent_lambda_invocations(self, s3_event_list): + pool = ThreadPool(processes=len(s3_event_list)) + pool.map(lambda s3_event: self.launch_async_event(s3_event), s3_event_list) pool.close() def launch_lambda_instance(self): response = self.invoke_lambda_function() response_args = {'Response' : response, - 'FunctionName' : self.get_function_name(), - 'OutputType' : self.get_property("output"), - 'IsAsynchronous' : self.is_asynchronous()} + 'FunctionName' : self.properties['name'], + 'OutputType' : self.aws_properties['output'], + 'IsAsynchronous' : self.properties['asynchronous']} response_parser.parse_invocation_response(**response_args) + def get_payload(self): + # Default payload + payload = {} + + if 'run_script' in self.properties: + file_content = utils.read_file(self.properties['run_script'], 'rb') + # We first code to base64 in bytes and then decode those bytes to allow the json lib to parse the data + # https://stackoverflow.com/questions/37225035/serialize-in-json-a-base64-encoded-data#37239382 + payload = { "script" : 
utils.utf8_to_base64_string(file_content) } + + if 'c_args' in self.properties: + payload = { "cmd_args" : json.dumps(self.properties['c_args']) } + + return json.dumps(payload) + def invoke_lambda_function(self): - invoke_args = {'FunctionName' : self.get_function_name(), - 'InvocationType' : self.get_property("invocation_type"), - 'LogType' : self.get_property("log_type"), - 'Payload' : json.dumps(self.get_property("payload"))} + invoke_args = {'FunctionName' : self.properties['name'], + 'InvocationType' : self.properties['invocation_type'], + 'LogType' : self.properties['log_type'], + 'Payload' : self.get_payload() } return self.client.invoke_function(**invoke_args) - def is_asynchronous(self): - return self.get_property('asynchronous') - def set_asynchronous_call_parameters(self): - self.set_property('invocation_type', "Event") - self.set_property('log_type', 'None') - self.set_property('asynchronous', 'True') + self.properties['invocation_type'] = "Event" + self.properties['log_type'] = "None" + self.properties['asynchronous'] = "True" - def set_api_gateway_id(self, api_id, acc_id): - self.set_property('api_gateway_id', api_id) - self.set_property('aws_acc_id', acc_id) - self.add_lambda_environment_variable('API_GATEWAY_ID', api_id) - def set_request_response_call_parameters(self): - self.set_property('invocation_type', "RequestResponse") - self.set_property('log_type', "Tail") - self.set_property('asynchronous', 'False') - - def set_s3_event_source(self, file_name): - self.properties['s3_event']['Records'][0]['s3']['bucket']['name'] = self.get_property('input_bucket') - self.properties['s3_event']['Records'][0]['s3']['object']['key'] = file_name - - def set_property_if_has_value(self, dictio, key, prop): - prop_val = self.get_property(prop) - if prop_val and prop_val != "": - dictio[key] = prop_val + self.properties['invocation_type'] = "RequestResponse" + self.properties['log_type'] = "Tail" + self.properties['asynchronous'] = "False" - def get_function_code_args(self): - package_args = {'FunctionName' : self.get_property("name"), - 'EnvironmentVariables' : self.get_property("environment", "Variables")} - self.set_property_if_has_value(package_args, 'Script', "init_script") - self.set_property_if_has_value(package_args, 'ExtraPayload', "extra_payload") - self.set_property_if_has_value(package_args, 'ImageId', "image_id") - self.set_property_if_has_value(package_args, 'ImageFile', "image_file") - self.set_property_if_has_value(package_args, 'DeploymentBucket', "deployment_bucket") - if 'DeploymentBucket' in package_args: - package_args['FileKey'] = 'lambda/' + self.get_property("name") + '.zip' - return package_args - - def set_function_code(self): - package_args = self.get_function_code_args() - # Zip all the files and folders needed - codezip.prepare_lambda_payload(**package_args) - - if 'DeploymentBucket' in package_args: - self.properties['code'] = { "S3Bucket": package_args['DeploymentBucket'], "S3Key" : package_args['FileKey'] } - else: - self.properties['code'] = { "ZipFile": utils.read_file(self.get_property("zip_file_path"), mode="rb")} - - def has_image_file(self): - return utils.has_dict_prop_value(self.properties, 'image_file') - - def has_api_defined(self): - return utils.has_dict_prop_value(self.properties, 'api_gateway_name') - - def has_deployment_bucket(self): - return utils.has_dict_prop_value(self.properties, 'deployment_bucket') - - def has_input_bucket(self): - return utils.has_dict_prop_value(self.properties, 'input_bucket') - - def has_output_bucket(self): - 
return utils.has_dict_prop_value(self.properties, 'output_bucket') - - def has_output_folder(self): - return utils.has_dict_prop_value(self.properties, 'output_folder') - - def set_required_environment_variables(self): - self.add_lambda_environment_variable('TIMEOUT_THRESHOLD', str(self.get_property("timeout_threshold"))) - self.add_lambda_environment_variable('LOG_LEVEL', self.get_property("log_level")) - self.add_lambda_environment_variable('IMAGE_ID', self.get_property("image_id")) - if self.has_input_bucket(): - self.add_lambda_environment_variable('INPUT_BUCKET', self.get_property("input_bucket")) - if self.has_output_bucket(): - self.add_lambda_environment_variable('OUTPUT_BUCKET', self.get_property("output_bucket")) - if self.has_output_folder(): - self.add_lambda_environment_variable('OUTPUT_FOLDER', self.get_property("output_folder")) - - def add_lambda_environment_variable(self, key, value): - if (key is not None or key != "") and (value is not None): - self.get_property("environment", "Variables")[key] = value - - def set_environment_variables(self): - if isinstance(self.get_property("environment_variables"), list): - for env_var in self.get_property("environment_variables"): - parsed_env_var = env_var.split("=") - # Add an specific prefix to be able to find the variables defined by the user - key = 'CONT_VAR_' + parsed_env_var[0] - self.add_lambda_environment_variable(key, parsed_env_var[1]) - if (self.get_property("call_type") == CallType.INIT): - self.set_required_environment_variables() - - def set_tags(self): - self.properties["tags"]['createdby'] = 'scar' - self.properties["tags"]['owner'] = IAM().get_user_name_or_id() - - def get_argument_value(self, args, attr): - if attr in args.__dict__.keys(): - return args.__dict__[attr] - def update_function_attributes(self): - update_args = {'FunctionName' : self.get_property("name") } - self.set_property_if_has_value(update_args, 'MemorySize', "memory") - self.set_property_if_has_value(update_args, 'Timeout', "time") + update_args = {'FunctionName' : self.properties['name'] } + if "memory" in self.properties and self.properties['memory']: + update_args['MemorySize'] = self.properties['memory'] + if "time" in self.properties and self.properties['time']: + update_args['Timeout'] = self.properties['time'] # To update the environment variables we need to retrieve the # variables defined in lambda and update them with the new values - env_vars = self.get_property("environment") - if self.get_property('timeout_threshold'): - env_vars['Variables']['TIMEOUT_THRESHOLD'] = str(self.get_property('timeout_threshold')) - if self.get_property('log_level'): - env_vars['Variables']['LOG_LEVEL'] = self.get_property('log_level') - defined_lambda_env_variables = self.client.get_function_environment_variables(self.get_property("name")) + env_vars = self.properties['environment'] + if "timeout_threshold" in self.properties and self.properties['timeout_threshold']: + env_vars['Variables']['TIMEOUT_THRESHOLD'] = str(self.properties['timeout_threshold']) + if "log_level" in self.properties and self.properties['log_level']: + env_vars['Variables']['LOG_LEVEL'] = self.properties['log_level'] + defined_lambda_env_variables = self.get_function_environment_variables() defined_lambda_env_variables['Variables'].update(env_vars['Variables']) update_args['Environment'] = defined_lambda_env_variables - validators.validate(**update_args) self.client.update_function(**update_args) + logger.info("Function updated successfully.") - def set_call_type(self, call_type): 
- self.set_property("call_type", get_call_type(call_type)) - return self.properties["call_type"] - - def set_output_type(self): - if self.get_property("json"): - self.set_property("output", OutputType.JSON) - elif self.get_property("verbose"): - self.set_property("output", OutputType.VERBOSE) - - def set_properties(self, args): - # Set the command line parsed properties - self.properties = utils.merge_dicts(self.properties, vars(args)) - call_type = self.set_call_type(args.func.__name__) - self.set_output_type() - if ((call_type != CallType.LS) and - (not self.delete_all()) and - (call_type != CallType.PUT) and - (call_type != CallType.GET)): - if (call_type == CallType.INIT): - if (not self.get_property("name")) or (self.get_property("name") == ""): - func_name = "function" - if self.get_property("image_id") != "": - func_name = self.get_property("image_id") - elif self.get_property("image_file") != "": - func_name = self.get_property("image_file").split('.')[0] - self.properties["name"] = self.create_function_name(func_name) - self.set_tags() - - self.check_function_name() - function_name = self.get_property("name") - validators.validate_function_name(function_name, self.get_property("name_regex")) - - self.set_environment_variables() - self.properties["handler"] = function_name + ".lambda_handler" - self.properties["log_group_name"] = '/aws/lambda/' + function_name - - if (call_type == CallType.INIT): - self.set_function_code() - - if (call_type == CallType.RUN): - if self.get_argument_value(args, 'run_script'): - file_content = utils.read_file(self.get_property("run_script"), 'rb') - # We first code to base64 in bytes and then decode those bytes to allow json to work - # https://stackoverflow.com/questions/37225035/serialize-in-json-a-base64-encoded-data#37239382 - parsed_script = utils.utf8_to_base64_string(file_content) - self.set_property('payload', { "script" : parsed_script }) - - if self.get_argument_value(args, 'c_args'): - parsed_cont_args = json.dumps(self.get_property("c_args")) - self.set_property('payload', { "cmd_args" : parsed_cont_args }) + def get_function_environment_variables(self): + return self.get_function_info()['Environment'] def get_all_functions(self, arn_list): function_info_list = [] @@ -417,92 +254,89 @@ def get_all_functions(self, arn_list): print ("Error getting function info by arn: %s" % ce) return function_info_list - def get_function_info(self, function_name_or_arn): - try: - # If this call works the function exists - return self.client.get_function_info(function_name_or_arn) - except ClientError as ce: - error_msg = "Error while looking for the lambda function" - logger.error(error_msg, error_msg + ": %s" % ce) + def get_function_info(self): + return self.client.get_function_info(self.properties['name']) - def find_function(self, function_name_or_arn): - validators.validate_function_name(function_name_or_arn, self.get_property("name_regex")) + @excp.exception(logger) + def find_function(self, function_name_or_arn=None): try: # If this call works the function exists - self.client.get_function_info(function_name_or_arn) + if function_name_or_arn: + name_arn = function_name_or_arn + else: + name_arn = self.properties['name'] + self.client.get_function_info(name_arn) return True except ClientError as ce: # Function not found if ce.response['Error']['Code'] == 'ResourceNotFoundException': return False else: - error_msg = "Error while looking for the lambda function" - logger.error(error_msg, error_msg + ": %s" % ce) + raise def 
add_invocation_permission_from_api_gateway(self): - api_gateway_id = self.get_property('api_gateway_id') - aws_acc_id = self.get_property('aws_acc_id') - kwargs = {'FunctionName' : self.get_function_name(), + api_gateway_id = self.aws_properties['api_gateway']['id'] + aws_acc_id = self.aws_properties['account_id'] + aws_region = self.aws_properties['region'] + kwargs = {'FunctionName' : self.properties['name'], 'Principal' : 'apigateway.amazonaws.com', - 'SourceArn' : 'arn:aws:execute-api:us-east-1:{0}:{1}/*'.format(aws_acc_id, api_gateway_id)} + 'SourceArn' : 'arn:aws:execute-api:{0}:{1}:{2}/*'.format(aws_region, aws_acc_id, api_gateway_id), + } # Testing permission self.client.add_invocation_permission(**kwargs) # Invocation permission - kwargs['SourceArn'] = 'arn:aws:execute-api:us-east-1:{0}:{1}/scar/ANY'.format(aws_acc_id, api_gateway_id) + kwargs['SourceArn'] = 'arn:aws:execute-api:{0}:{1}:{2}/scar/ANY'.format(aws_region, aws_acc_id, api_gateway_id) self.client.add_invocation_permission(**kwargs) - def get_api_gateway_id(self, function_name): - self.check_function_name(function_name) - env_vars = self.client.get_function_environment_variables(function_name) + def get_api_gateway_id(self): + env_vars = self.get_function_environment_variables() if ('API_GATEWAY_ID' in env_vars['Variables']): return env_vars['Variables']['API_GATEWAY_ID'] - def get_api_gateway_url(self, function_name): - api_id = self.get_api_gateway_id(function_name) - if api_id is None or api_id == "": - error_msg = "Error retrieving API ID for lambda function {0}".format(function_name) - logger.error(error_msg) - return 'https://{0}.execute-api.{1}.amazonaws.com/scar/launch'.format(api_id, self.get_property("region")) + def get_api_gateway_url(self): + api_id = self.get_api_gateway_id() + if not api_id: + raise excp.ApiEndpointNotFoundError(self.properties['name']) + return 'https://{0}.execute-api.{1}.amazonaws.com/scar/launch'.format(api_id, self.aws_properties["region"]) def get_http_invocation_headers(self): - asynch = self.get_property("asynchronous") - if asynch: + if self.is_asynchronous(): return {'X-Amz-Invocation-Type':'Event'} - def get_encoded_binary_data(self): - data = self.get_property("data_binary") - if data: - self.check_file_size(data) - with open(data, 'rb') as f: + def parse_http_parameters(self, parameters): + if type(parameters) is dict: + return parameters + return json.loads(parameters) + + def get_encoded_binary_data(self, data_path): + if data_path: + self.check_file_size(data_path) + with open(data_path, 'rb') as f: data = f.read() return base64.b64encode(data) - def get_http_parameters(self): - params = self.get_property("parameters") - if params: - if type(params) is dict: - return params - return json.loads(params) - - def invoke_function_http(self, function_name): - function_url = self.get_api_gateway_url(function_name) - headers = self.get_http_invocation_headers() - params = self.get_http_parameters() - data = self.get_encoded_binary_data() - - return invoke.invoke_function(function_url, - parameters=params, - data=data, - headers=headers) + def invoke_http_endpoint(self): + invoke_args = {'headers' : self.get_http_invocation_headers()} + print(invoke_args) + if 'api_gateway' in self.aws_properties: + api_props = self.aws_properties['api_gateway'] + if 'data_binary' in api_props and api_props['data_binary']: + invoke_args['data'] = self.get_encoded_binary_data(api_props['data_binary']) + if 'parameters' in api_props and api_props['parameters']: + invoke_args['parameters'] = 
self.parse_http_parameters(api_props['parameters']) + return request.invoke_http_endpoint(self.get_api_gateway_url(), **invoke_args) + @excp.exception(logger) def check_file_size(self, file_path): - asynch = self.get_property("asynchronous") file_size = utils.get_file_size(file_path) - error_msg = None if file_size > MAX_POST_BODY_SIZE: - error_msg = "Invalid request: Payload size {0:.2f} MB greater than 6 MB".format((file_size/(1024*1024))) - elif asynch and file_size > MAX_POST_BODY_SIZE_ASYNC: - error_msg = "Invalid request: Payload size {0:.2f} KB greater than 128 KB".format((file_size/(1024))) - if error_msg: - error_msg += "\nCheck AWS Lambda invocation limits in : https://docs.aws.amazon.com/lambda/latest/dg/limits.html" - logger.error(error_msg) + filesize = '{0:.2f}MB'.format(file_size/MB) + maxsize = '{0:.2f}MB'.format(MAX_POST_BODY_SIZE/MB) + raise excp.InvocationPayloadError(file_size=filesize, max_size=maxsize) + elif self.is_asynchronous() and file_size > MAX_POST_BODY_SIZE_ASYNC: + filesize = '{0:.2f}KB'.format(file_size/KB) + maxsize = '{0:.2f}KB'.format(MAX_POST_BODY_SIZE_ASYNC/KB) + raise excp.InvocationPayloadError(file_size=filesize, max_size=maxsize) + + def is_asynchronous(self): + return "asynchronous" in self.properties and self.properties['asynchronous'] diff --git a/src/providers/aws/payload.py b/src/providers/aws/payload.py deleted file mode 100644 index c8133c18..00000000 --- a/src/providers/aws/payload.py +++ /dev/null @@ -1,157 +0,0 @@ -# SCAR - Serverless Container-aware ARchitectures -# Copyright (C) 2011 - GRyCAP - Universitat Politecnica de Valencia -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see .
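The new check_file_size turns the former logged error strings into typed exceptions. A minimal, self-contained sketch of the same limit check; the 6 MB and 128 KB figures mirror the values quoted in the removed SCAR messages above, and the exception class is a stand-in for the real src.exceptions type:

```python
import os

KB = 1024
MB = 1024 * KB
MAX_POST_BODY_SIZE = 6 * MB          # synchronous invocation limit used in the SCAR messages
MAX_POST_BODY_SIZE_ASYNC = 128 * KB  # asynchronous invocation limit used in the SCAR messages

class InvocationPayloadError(Exception):
    """Stand-in for src.exceptions.InvocationPayloadError."""

def check_file_size(file_path, asynchronous=False):
    file_size = os.path.getsize(file_path)
    if file_size > MAX_POST_BODY_SIZE:
        raise InvocationPayloadError(
            "Payload is {0:.2f}MB, maximum allowed is {1:.2f}MB".format(
                file_size / MB, MAX_POST_BODY_SIZE / MB))
    if asynchronous and file_size > MAX_POST_BODY_SIZE_ASYNC:
        raise InvocationPayloadError(
            "Payload is {0:.2f}KB, maximum allowed is {1:.2f}KB".format(
                file_size / KB, MAX_POST_BODY_SIZE_ASYNC / KB))
```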
- -from .s3 import S3 -import os -import shutil -import src.logger as logger -import src.utils as utils -import subprocess -import tempfile -from distutils import dir_util - -MAX_PAYLOAD_SIZE = 50 * 1024 * 1024 -MAX_S3_PAYLOAD_SIZE = 250 * 1024 * 1024 -src_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -aws_src_path = os.path.dirname(os.path.abspath(__file__)) -lambda_code_files_path = aws_src_path + "/cloud/lambda/" -os_tmp_folder = tempfile.gettempdir() -scar_temporal_folder = os_tmp_folder + "/scar" -udocker_exec = scar_temporal_folder +"/udockerb" -udocker_tarball = "" -udocker_dir = "" -zip_file_path = os_tmp_folder +"/function.zip" - -def add_mandatory_files(function_name, env_vars): - os.makedirs(scar_temporal_folder, exist_ok=True) - shutil.copy(lambda_code_files_path + 'scarsupervisor.py', "{0}/{1}.py".format(scar_temporal_folder, function_name)) - shutil.copy(lambda_code_files_path + 'udockerb', udocker_exec) - - os.makedirs(scar_temporal_folder + "/src", exist_ok=True) - shutil.copy(lambda_code_files_path + '__init__.py', scar_temporal_folder + '/src/__init__.py') - shutil.copy(src_path + '/utils.py', scar_temporal_folder + '/src/utils.py') - shutil.copy(src_path + '/exceptions.py', scar_temporal_folder + '/src/exceptions.py') - - env_vars['UDOCKER_DIR'] = "/tmp/home/udocker" - env_vars['UDOCKER_LIB'] = "/var/task/udocker/lib/" - env_vars['UDOCKER_BIN'] = "/var/task/udocker/bin/" - create_udocker_files() - -def prepare_lambda_payload(**kwargs): - clean_tmp_folders() - add_mandatory_files(kwargs['FunctionName'], kwargs['EnvironmentVariables']) - - if 'DeploymentBucket' in kwargs: - if 'ImageId' in kwargs: - download_udocker_image(kwargs['ImageId'], kwargs['EnvironmentVariables']) - - if 'ImageFile' in kwargs: - prepare_udocker_image(kwargs['ImageFile'], kwargs['EnvironmentVariables']) - - if 'Script' in kwargs: - shutil.copy(kwargs['Script'], "{0}/init_script.sh".format(scar_temporal_folder)) - kwargs['EnvironmentVariables']['INIT_SCRIPT_PATH'] = "/var/task/init_script.sh" - - if 'ExtraPayload' in kwargs: - logger.info("Adding extra payload from %s" % kwargs['ExtraPayload']) - kwargs['EnvironmentVariables']['EXTRA_PAYLOAD'] = "/var/task" - dir_util.copy_tree(kwargs['ExtraPayload'], scar_temporal_folder) - - zip_scar_folder() - - # Check if the payload size fits within the aws limits - if((not ('DeploymentBucket' in kwargs)) and (os.path.getsize(zip_file_path) > MAX_PAYLOAD_SIZE)): - error_msg = "Payload size greater than 50MB.\nPlease reduce the payload size or use an S3 bucket and try again." 
- payload_size_error(zip_file_path, error_msg) - - if 'DeploymentBucket' in kwargs: - upload_file_to_S3_bucket(zip_file_path, kwargs['DeploymentBucket'], kwargs['FileKey']) - - clean_tmp_folders() - -def clean_tmp_folders(): - # Delete created temporal files - if os.path.isdir(scar_temporal_folder): - shutil.rmtree(scar_temporal_folder, ignore_errors=True) - -def zip_scar_folder(): - execute_command(["zip", "-r9y", zip_file_path, "."], cmd_wd=scar_temporal_folder, cli_msg="Creating function package") - -def set_tmp_udocker_env(): - #Avoid override global variables - if utils.has_dict_prop_value(os.environ, 'UDOCKER_TARBALL'): - udocker_tarball = os.environ['UDOCKER_TARBALL'] - if utils.has_dict_prop_value(os.environ, 'UDOCKER_DIR'): - udocker_dir = os.environ['UDOCKER_DIR'] - # Set temporal global vars - os.environ['UDOCKER_TARBALL'] = lambda_code_files_path + "udocker-1.1.0-RC2.tar.gz" - os.environ['UDOCKER_DIR'] = scar_temporal_folder + "/udocker" - -def restore_udocker_env(): - if udocker_tarball != "": - os.environ['UDOCKER_TARBALL'] = udocker_tarball - if udocker_dir != "": - os.environ['UDOCKER_DIR'] = udocker_dir - -def execute_command(command, cmd_wd=None, cli_msg=None): - cmd_out = subprocess.check_output(command, cwd=cmd_wd).decode("utf-8") - logger.info(cli_msg, cmd_out) - return cmd_out[:-1] - -def create_udocker_files(): - set_tmp_udocker_env() - execute_command(["python3", udocker_exec, "help"], cli_msg="Packing udocker files") - restore_udocker_env() - -def prepare_udocker_image(image_file, env_vars): - set_tmp_udocker_env() - shutil.copy(image_file, os_tmp_folder + "/udocker_image.tar.gz") - cmd_out = execute_command(["python3", udocker_exec, "load", "-i", os_tmp_folder + "/udocker_image.tar.gz"], cli_msg="Loading image file") - create_udocker_container(cmd_out) - env_vars['IMAGE_ID'] = cmd_out - env_vars['UDOCKER_REPOS'] = "/var/task/udocker/repos/" - env_vars['UDOCKER_LAYERS'] = "/var/task/udocker/layers/" - restore_udocker_env() - -def create_udocker_container(image_id): - if(utils.get_tree_size(scar_temporal_folder) < MAX_S3_PAYLOAD_SIZE/2): - execute_command(["python3", udocker_exec, "create", "--name=lambda_cont", image_id], cli_msg="Creating container structure") - if(utils.get_tree_size(scar_temporal_folder) > MAX_S3_PAYLOAD_SIZE): - shutil.rmtree(scar_temporal_folder + "/udocker/containers/") - -def download_udocker_image(image_id, env_vars): - set_tmp_udocker_env() - execute_command(["python3", udocker_exec, '--debug', "pull", image_id], cli_msg="Downloading container image") - create_udocker_container(image_id) - env_vars['UDOCKER_REPOS'] = "/var/task/udocker/repos/" - env_vars['UDOCKER_LAYERS'] = "/var/task/udocker/layers/" - restore_udocker_env() - -def upload_file_to_S3_bucket(image_file, deployment_bucket, file_key): - if(utils.get_tree_size(scar_temporal_folder) > MAX_S3_PAYLOAD_SIZE): - error_msg = "Uncompressed image size greater than 250MB.\nPlease reduce the uncompressed image and try again." 
- logger.error(error_msg) - utils.delete_file(zip_file_path) - exit(1) - - logger.info("Uploading '%s' to the '%s' S3 bucket" % (image_file, deployment_bucket)) - file_data = utils.read_file(image_file, 'rb') - S3().upload_file(deployment_bucket, file_key, file_data) - -def payload_size_error(zip_file_path, message): - logger.error(message) - utils.delete_file(zip_file_path) - exit(1) diff --git a/src/providers/aws/resourcegroups.py b/src/providers/aws/resourcegroups.py index 0fcc86b7..45c7507c 100644 --- a/src/providers/aws/resourcegroups.py +++ b/src/providers/aws/resourcegroups.py @@ -16,7 +16,7 @@ from botocore.exceptions import ClientError import src.logger as logger -from src.providers.aws.clientfactory import GenericClient +from src.providers.aws.botoclientfactory import GenericClient class ResourceGroups(GenericClient): diff --git a/src/providers/aws/response.py b/src/providers/aws/response.py index 731814a5..d6b8ac30 100644 --- a/src/providers/aws/response.py +++ b/src/providers/aws/response.py @@ -23,7 +23,7 @@ class OutputType(Enum): PLAIN_TEXT = 1 JSON = 2 VERBOSE = 3 - + def parse_http_response(response, function_name, asynch): if response.ok: text_message = "Request Id: {0}".format(response.headers['amz-lambda-request-id']) @@ -86,7 +86,7 @@ def parse_delete_log_response(response, log_group_name, output_type): print_generic_response(response, output_type, 'CloudWatchOutput', text_message) def parse_delete_api_response(response, api_id, output_type): - text_message = "REST API '%s' successfully deleted." % api_id + text_message = "API Endpoint '%s' successfully deleted." % api_id print_generic_response(response, output_type, 'APIGateway', text_message) def parse_ls_response(lambda_functions, output_type): diff --git a/src/providers/aws/s3.py b/src/providers/aws/s3.py index ce175cac..48a0b8c5 100644 --- a/src/providers/aws/s3.py +++ b/src/providers/aws/s3.py @@ -14,134 +14,127 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . 
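The removed payload.py above saved and restored UDOCKER_TARBALL/UDOCKER_DIR around each udocker call with the set_tmp_udocker_env/restore_udocker_env pair; as written, the saved values were assigned to local variables (there is no global statement), so the restore step never saw them. A context manager is a compact way to express the same save-override-restore pattern. This is an illustrative sketch, not code from the repository, and the paths in the commented usage are hypothetical:

```python
import os
from contextlib import contextmanager

@contextmanager
def temporary_env(**overrides):
    """Temporarily override environment variables and restore the previous
    values (or remove the keys) on exit, even if the body raises."""
    saved = {key: os.environ.get(key) for key in overrides}
    os.environ.update(overrides)
    try:
        yield
    finally:
        for key, value in saved.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value

# Hypothetical usage mirroring the removed helpers:
# with temporary_env(UDOCKER_TARBALL='/path/to/udocker-1.1.0-RC2.tar.gz',
#                    UDOCKER_DIR='/tmp/scar/udocker'):
#     run_udocker_command(['help'])
```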
-from botocore.exceptions import ClientError import src.logger as logger import os -from src.providers.aws.clientfactory import GenericClient +from src.providers.aws.botoclientfactory import GenericClient +import src.exceptions as excp +import src.utils as utils class S3(GenericClient): - - def __init__(self, aws_lambda=None): - if aws_lambda: - self.input_bucket = aws_lambda.get_property("input_bucket") - self.input_folder = aws_lambda.get_property("input_folder") - if self.input_folder and not self.input_folder.endswith("/"): - self.input_folder = self.input_folder + "/" - if self.input_folder is None: - self.input_folder = "{0}/input/".format(aws_lambda.get_function_name()) - self.function_arn = aws_lambda.get_property("function_arn") - self.region = aws_lambda.get_property("region") - self.trigger_configuration = { "LambdaFunctionArn": "", - "Events": [ "s3:ObjectCreated:*" ], - "Filter": { - "Key": { - "FilterRules": [{ "Name": "prefix", - "Value": "" }] - } - } - } + def __init__(self, aws_properties): + GenericClient.__init__(self, aws_properties) + self.properties = aws_properties['s3'] + self.parse_input_folder() + + def parse_input_folder(self): + if not 'input_folder' in self.properties: + if 'name' in self.aws_properties['lambda']: + self.properties['input_folder'] = "{0}/input/".format(self.aws_properties['lambda']['name']) + else: + self.properties['input_folder'] = '' + elif not self.properties['input_folder'].endswith("/"): + self.properties['input_folder'] = "{0}/".format(self.properties['input_folder']) + + @excp.exception(logger) def create_bucket(self, bucket_name): - try: - if not self.client.find_bucket(bucket_name): - # Create the bucket if not found - self.client.create_bucket(bucket_name) - except ClientError as ce: - error_msg = "Error creating the bucket '{0}'".format(self.input_bucket) - logger.log_exception(error_msg, ce) + if not self.client.find_bucket(bucket_name): +# raise excp.ExistentBucketWarning(bucket_name=bucket_name) +# else: + self.client.create_bucket(bucket_name) + def create_output_bucket(self): + self.create_bucket(self.properties['output_bucket']) + + @excp.exception(logger) def add_bucket_folder(self): - try: - self.client.upload_file(self.input_bucket, self.input_folder) - except ClientError as ce: - error_msg = "Error creating the folder '{0}' in the bucket '{1}'".format(self.input_bucket, self.input_folder) - logger.log_exception(error_msg, ce) + if self.properties['input_folder']: + self.upload_file(folder_name=self.properties['input_folder']) - def create_input_bucket(self): - self.create_bucket(self.input_bucket) - self.add_bucket_folder() + def create_input_bucket(self, create_input_folder=False): + self.create_bucket(self.properties['input_bucket']) + if create_input_folder: + self.add_bucket_folder() def set_input_bucket_notification(self): # First check that the function doesn't have other configurations - bucket_conf = self.client.get_bucket_notification_configuration(self.input_bucket) - trigger_conf = self.get_trigger_configuration(self.function_arn, self.input_folder) + input_bucket = self.properties['input_bucket'] + bucket_conf = self.client.get_bucket_notification_configuration(input_bucket) + trigger_conf = self.get_trigger_configuration() lambda_conf = [trigger_conf] if "LambdaFunctionConfigurations" in bucket_conf: lambda_conf = bucket_conf["LambdaFunctionConfigurations"] lambda_conf.append(trigger_conf) notification = { "LambdaFunctionConfigurations": lambda_conf } - 
self.client.put_bucket_notification_configuration(self.input_bucket, notification) + self.client.put_bucket_notification_configuration(input_bucket, notification) - def delete_bucket_notification(self, bucket_name, function_arn): - bucket_conf = self.client.get_bucket_notification_configuration(bucket_name) + def delete_bucket_notification(self): + bucket_conf = self.client.get_bucket_notification_configuration(self.properties['input_bucket']) if bucket_conf and "LambdaFunctionConfigurations" in bucket_conf: lambda_conf = bucket_conf["LambdaFunctionConfigurations"] - filter_conf = [x for x in lambda_conf if x['LambdaFunctionArn'] != function_arn] + filter_conf = [x for x in lambda_conf if x['LambdaFunctionArn'] != self.aws_properties['lambda']['arn']] notification = { "LambdaFunctionConfigurations": filter_conf } - self.client.put_bucket_notification_configuration(bucket_name, notification) - - def get_trigger_configuration(self, function_arn, folder_name): - self.trigger_configuration["LambdaFunctionArn"] = function_arn - self.trigger_configuration["Filter"]["Key"]["FilterRules"][0]["Value"] = folder_name - return self.trigger_configuration + self.client.put_bucket_notification_configuration(self.properties['input_bucket'], notification) - def get_processed_bucket_file_list(self): - file_list = [] - result = self.client.list_files(self.input_bucket, self.input_folder) - if 'Contents' in result: - for content in result['Contents']: - if content['Key'] and content['Key'] != self.input_folder: - file_list.append(content['Key']) - return file_list - - def upload_file(self, bucket_name, file_key, file_data): - try: - self.client.upload_file(bucket_name, file_key, file_data) - except ClientError as ce: - error_msg = "Error uploading the file '{0}' to the S3 bucket '{1}'".format(file_key, bucket_name) - logger.log_exception(error_msg, ce) + def get_trigger_configuration(self): + return {"LambdaFunctionArn": self.aws_properties['lambda']['function_arn'], + "Events": [ "s3:ObjectCreated:*" ], + "Filter": { "Key": { "FilterRules": [{ "Name": "prefix", "Value": self.properties['input_folder'] }]}} + } - def get_bucket_files(self, bucket_name, prefix_key): - file_list = [] + def get_file_key(self, folder_name=None, file_path=None, file_key=None): + if file_key: + return file_key + file_key = '' + if file_path: + file_key = os.path.basename(file_path) + if folder_name: + file_key = utils.join_paths(folder_name, file_key) + elif folder_name: + file_key = folder_name if folder_name.endswith('/') else '{0}/'.format(folder_name) + return file_key + + @excp.exception(logger) + def upload_file(self, folder_name=None, file_path=None, file_key=None): + kwargs = {'Bucket' : self.properties['input_bucket']} + kwargs['Key'] = self.get_file_key(folder_name, file_path, file_key) + if file_path: + kwargs['Body'] = utils.read_file(file_path, 'rb') + if folder_name and not file_path: + logger.info("Folder '{0}' created in bucket '{1}'".format(kwargs['Key'], kwargs['Bucket'])) + else: + logger.info("Uploading file '{0}' to bucket '{1}' from '{2}'".format(kwargs['Key'], kwargs['Bucket'], file_path)) + self.client.upload_file(**kwargs) + + @excp.exception(logger) + def get_bucket_file_list(self): + bucket_name = self.properties['input_bucket'] if self.client.find_bucket(bucket_name): - if prefix_key is None: - prefix_key = '' - result = self.client.list_files(bucket_name, key=prefix_key) - if 'Contents' in result: - for info in result['Contents']: - file_list += [info['Key']] + kwargs = {"Bucket" : bucket_name} + if 
('input_folder' in self.properties) and self.properties['input_folder']: + kwargs["Prefix"] = self.properties['input_folder'] + response = self.client.list_files(**kwargs) + return self.parse_file_keys(response) else: - logger.warning("Bucket '{0}' not found".format(bucket_name)) - return file_list + raise excp.BucketNotFoundError(bucket_name=bucket_name) - def download_bucket_files(self, bucket_name, file_prefix, output): - file_key_list = self.get_bucket_files(bucket_name, file_prefix) - for file_key in file_key_list: - # Avoid download s3 'folders' - if not file_key.endswith('/'): - # Parse file path - file_name = os.path.basename(file_key) - file_dir = file_key.replace(file_name, "") - dir_name = os.path.dirname(file_prefix) - if dir_name != '': - local_path = file_dir.replace(os.path.dirname(file_prefix)+"/", "") - else: - local_path = file_prefix + "/" - # Modify file path if there is an output defined - if output: - if not output.endswith('/') and len(file_key_list) == 1: - file_path = output - else: - local_path = output + local_path - file_path = local_path + file_name - else: - file_path = local_path + file_name - # make sure the folders are created - if os.path.dirname(local_path) != '' and not os.path.isdir(local_path): - os.makedirs(local_path, exist_ok=True) - self.download_file(bucket_name, file_key, file_path) + def parse_file_keys(self, response): + return [info['Key'] for elem in response if 'Contents' in elem for info in elem['Contents'] if not info['Key'].endswith('/')] + + def get_s3_event(self, s3_file_key): + return { "Records" : [ {"eventSource" : "aws:s3", + "s3" : {"bucket" : { "name" : self.properties['input_bucket'] }, + "object" : { "key" : s3_file_key } } + }]} + + def get_s3_event_list(self, s3_file_keys): + s3_events = [] + for s3_key in s3_file_keys: + s3_events.append(self.get_s3_event(s3_key)) def download_file(self, bucket_name, file_key, file_path): + kwargs = {'Bucket' : bucket_name, 'Key' : file_key} logger.info("Downloading file '{0}' from bucket '{1}' in path '{2}'".format(file_key, bucket_name, file_path)) - with open(file_path, 'wb') as f: - self.client.download_file(bucket_name, file_key, f) + with open(file_path, 'wb') as file: + kwargs['Fileobj'] = file + self.client.download_file(**kwargs) diff --git a/src/providers/aws/validators.py b/src/providers/aws/validators.py index e936d55e..038ac707 100644 --- a/src/providers/aws/validators.py +++ b/src/providers/aws/validators.py @@ -15,35 +15,62 @@ # along with this program. If not, see . import src.utils as utils -from botocore.exceptions import ClientError +from src.exceptions import ValidatorError, S3CodeSizeError, FunctionCodeSizeError +from src.validator import GenericValidator +import os -def create_clienterror(error_msg, operation_name): - error = {'Error' : {'Message' : error_msg}} - return ClientError(error, operation_name) +valid_lambda_name_regex = "(arn:(aws[a-zA-Z-]*)?:lambda:)?([a-z]{2}(-gov)?-[a-z]+-\d{1}:)?(\d{12}:)?(function:)?([a-zA-Z0-9-_]+)(:(\$LATEST|[a-zA-Z0-9-_]+))?" -def validate_iam_role(iam_props): - if (("role" not in iam_props) or (iam_props["role"] == "")): - error_msg = "Please, specify a valid iam role in the configuration file (usually located in ~/.scar/scar.cfg)." - raise create_clienterror(error_msg, 'validate_iam_role') - -def validate_time(lambda_time): - if (lambda_time <= 0) or (lambda_time > 300): - error_msg = 'Incorrect time specified\nPlease, set a value between 0 and 300.' 
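get_s3_event builds the same record structure that S3 places in the Lambda event when a bucket notification fires; note that get_s3_event_list as shown fills a list but does not return it, so callers presumably expect a final return s3_events. A small sketch of building those event payloads outside the class, with made-up bucket and key names:

```python
import json

def make_s3_event(bucket_name, file_key):
    """Minimal S3 notification record, shaped like the events Lambda
    receives from a real bucket trigger."""
    return {"Records": [{"eventSource": "aws:s3",
                         "s3": {"bucket": {"name": bucket_name},
                                "object": {"key": file_key}}}]}

def make_s3_event_list(bucket_name, file_keys):
    return [make_s3_event(bucket_name, key) for key in file_keys]

# Hypothetical usage: one event per input file, serialized as an invocation payload
events = make_s3_event_list("my-input-bucket", ["my-function/input/image.jpg"])
payload = json.dumps(events[0]).encode("utf-8")
```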
- raise create_clienterror(error_msg, 'validate_time') - -def validate_memory(lambda_memory): - if (lambda_memory < 128) or (lambda_memory > 3008): - error_msg = 'Incorrect memory size specified\nPlease, set a value between 128 and 3008.' - raise create_clienterror(error_msg, 'validate_memory') - -def validate_function_name(function_name, name_regex): - if not utils.find_expression(function_name, name_regex): - raise Exception("'{0}' is an invalid lambda function name.".format(function_name)) +class AWSValidator(GenericValidator): + + @classmethod + def validate_kwargs(cls, **kwargs): + prov_args = kwargs['aws'] + if 'iam' in prov_args: + cls.validate_iam(prov_args['iam']) + if 'lambda' in prov_args: + cls.validate_lambda(prov_args['lambda']) + + @staticmethod + def validate_iam(iam_properties): + if ("role" not in iam_properties) or (iam_properties["role"] == ""): + error_msg="Please, specify a valid iam role in the configuration file (usually located in ~/.scar/scar.cfg)." + raise ValidatorError(parameter='iam_role', parameter_value=iam_properties, error_msg=error_msg) + + @classmethod + def validate_lambda(cls, lambda_properties): + if 'name' in lambda_properties: + cls.validate_function_name(lambda_properties['name']) + if 'memory' in lambda_properties: + cls.validate_memory(lambda_properties['memory']) + if 'time' in lambda_properties: + cls.validate_time(lambda_properties['time']) + + @staticmethod + def validate_time(lambda_time): + if (lambda_time <= 0) or (lambda_time > 300): + error_msg = 'Please, set a value between 0 and 300.' + raise ValidatorError(parameter='lambda_time', parameter_value=lambda_time, error_msg=error_msg) + + @staticmethod + def validate_memory(lambda_memory): + if (lambda_memory < 128) or (lambda_memory > 3008): + error_msg = 'Please, set a value between 128 and 3008.' 
+ raise ValidatorError(parameter='lambda_memory', parameter_value=lambda_memory, error_msg=error_msg) -def validate(**kwargs): - if 'MemorySize' in kwargs: - validate_memory(kwargs['MemorySize']) - if 'Timeout' in kwargs: - validate_time(kwargs['Timeout']) + @staticmethod + def validate_function_name(function_name): + if not utils.find_expression(function_name, valid_lambda_name_regex): + error_msg = 'Find name restrictions in: https://docs.aws.amazon.com/lambda/latest/dg/API_CreateFunction.html#SSS-CreateFunction-request-FunctionName' + raise ValidatorError(parameter='function_name', parameter_value=function_name, error_msg=error_msg) + @staticmethod + def validate_function_code_size(code_file_path, MAX_PAYLOAD_SIZE): + if os.path.getsize(code_file_path) > MAX_PAYLOAD_SIZE: + raise FunctionCodeSizeError(code_size='50MB') + + @staticmethod + def validate_s3_code_size(scar_folder, MAX_S3_PAYLOAD_SIZE): + if utils.get_tree_size(scar_folder) > MAX_S3_PAYLOAD_SIZE: + raise S3CodeSizeError(code_size='250MB') \ No newline at end of file diff --git a/src/utils.py b/src/utils.py index 9b6a0d14..fba8931d 100644 --- a/src/utils.py +++ b/src/utils.py @@ -18,12 +18,16 @@ import json import os import re -import uuid -import functools import subprocess import tarfile -from botocore.exceptions import ClientError -import src.exceptions as scar_excp +import tempfile +import uuid + +def join_paths(*paths): + return os.path.join(*paths) + +def get_temp_dir(): + return tempfile.gettempdir() def lazy_property(func): ''' A decorator that makes a property lazy-evaluated.''' @@ -36,33 +40,6 @@ def _lazy_property(self): return getattr(self, attr_name) return _lazy_property -def exception(logger): - ''' - A decorator that wraps the passed in function and logs exceptions - @param logger: The logging object - ''' - def decorator(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - try: - return func(*args, **kwargs) - except ClientError as ce: - print("There was an exception in {0}".format(func.__name__)) - print(ce.response['Error']['Message']) - logger.exception(ce) - except scar_excp.ScarError as se: - #print("There was an exception in {0}".format(func.__name__)) - print(se.args[0]) - logger.exception(se) - raise - except Exception as ex: - print("There was an unmanaged exception in {0}".format(func.__name__)) - logger.exception(ex) - # re-raise the exception - raise - return wrapper - return decorator - def find_expression(string_to_search, rgx_pattern): '''Returns the first group that matches the rgx_pattern in the string_to_search''' if string_to_search: @@ -80,9 +57,6 @@ def utf8_to_base64_string(value): def dict_to_base64_string(value): return base64.b64encode(json.dumps(value)).decode("utf-8") -def print_json(value): - print(json.dumps(value)) - def divide_list_in_chunks(elements, chunk_size): """Yield successive n-sized chunks from th elements list.""" if len(elements) == 0: @@ -93,22 +67,23 @@ def divide_list_in_chunks(elements, chunk_size): def get_random_uuid4_str(): return str(uuid.uuid4()) -def has_dict_prop_value(dictionary, value): - return (value in dictionary) and dictionary[value] and (dictionary[value] != "") - -def load_json_file(file_path): - if os.path.isfile(file_path): - with open(file_path) as f: - return json.load(f) - def merge_dicts(d1, d2): + ''' + Merge 'd1' and 'd2' dicts into 'd1'. 
+ 'd2' has precedence over 'd1' + ''' for k,v in d2.items(): if v: - d1[k] = v + if k not in d1: + d1[k] = v + elif type(v) is dict: + d1[k] = merge_dicts(d1[k], v) + elif type(v) is list: + d1[k] += v return d1 -def check_key_in_dictionary(key, dictionary): - return (key in dictionary) and dictionary[key] and dictionary[key] != "" +def is_value_in_dict(dictionary, value): + return value in dictionary and dictionary[value] def get_tree_size(path): """Return total size of files in given path and subdirs.""" @@ -144,14 +119,15 @@ def read_file(file_path, mode="r"): return content_file.read() def delete_file(path): - os.remove(path) + if os.path.isfile(path): + os.remove(path) def create_tar_gz(files_to_archive, destination_tar_path): with tarfile.open(destination_tar_path, "w:gz") as tar: for file_path in files_to_archive: tar.add(file_path, arcname=os.path.basename(file_path)) return destination_tar_path - + def extract_tar_gz(tar_path, destination_path): with tarfile.open(tar_path, "r:gz") as tar: tar.extractall(path=destination_path) @@ -167,13 +143,23 @@ def execute_command_and_return_output(command): return subprocess.check_output(command).decode("utf-8") def is_variable_in_environment(variable): - return check_key_in_dictionary(variable, os.environ) + return is_value_in_dict(os.environ, variable) def set_environment_variable(key, variable): - if key and variable and key != "" and variable != "": + if key and variable: os.environ[key] = variable def get_environment_variable(variable): - if check_key_in_dictionary(variable, os.environ): + if is_variable_in_environment(variable): return os.environ[variable] +def parse_arg_list(arg_keys, cmd_args): + result = {} + for key in arg_keys: + if type(key) is tuple: + if key[0] in cmd_args and cmd_args[key[0]]: + result[key[1]] = cmd_args[key[0]] + else: + if key in cmd_args and cmd_args[key]: + result[key] = cmd_args[key] + return result diff --git a/src/validator.py b/src/validator.py new file mode 100644 index 00000000..fc2e652a --- /dev/null +++ b/src/validator.py @@ -0,0 +1,38 @@ +# SCAR - Serverless Container-aware ARchitectures +# Copyright (C) 2011 - GRyCAP - Universitat Politecnica de Valencia +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import abc + +class GenericValidator(metaclass=abc.ABCMeta): + ''' All the different cloud provider validators must inherit + from this class to ensure that the commands are defined consistently''' + + @classmethod + def validate(cls): + ''' + A decorator that wraps the passed in function and validates the dictionary parameters passed + ''' + def decorator(func): + def wrapper(*args, **kwargs): + cls.validate_kwargs(**kwargs) + return func(*args, **kwargs) + return wrapper + return decorator + + @classmethod + @abc.abstractmethod + def validate_kwargs(**kwargs): + pass
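The new GenericValidator turns per-provider validation into a decorator: a subclass implements validate_kwargs, and the wrapper runs it on the keyword arguments before the decorated call executes. A hedged usage sketch with a toy subclass follows; since validate_kwargs is declared as a classmethod, the sketch includes cls in the abstract signature, and the function and parameter names are illustrative rather than part of SCAR:

```python
import abc

class GenericValidator(metaclass=abc.ABCMeta):
    @classmethod
    def validate(cls):
        def decorator(func):
            def wrapper(*args, **kwargs):
                cls.validate_kwargs(**kwargs)
                return func(*args, **kwargs)
            return wrapper
        return decorator

    @classmethod
    @abc.abstractmethod
    def validate_kwargs(cls, **kwargs):
        pass

class ToyValidator(GenericValidator):
    @classmethod
    def validate_kwargs(cls, **kwargs):
        lambda_props = kwargs.get('aws', {}).get('lambda', {})
        if not 128 <= lambda_props.get('memory', 128) <= 3008:
            raise ValueError("Please, set a memory value between 128 and 3008.")

@ToyValidator.validate()
def create_function(**kwargs):
    return "creating {0}".format(kwargs['aws']['lambda']['name'])

print(create_function(aws={'lambda': {'name': 'scar-func', 'memory': 512}}))
```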