From 1fa4f21f5781242f725d6418cae1b4de79267c64 Mon Sep 17 00:00:00 2001 From: William Usher Date: Tue, 3 Jan 2017 17:16:26 -0500 Subject: [PATCH] [WIP] GH220 --- awslimitchecker/checker.py | 104 ++++++++++++++++++---- awslimitchecker/connectable.py | 91 +------------------ awslimitchecker/services/base.py | 13 +-- awslimitchecker/services/firehose.py | 2 +- awslimitchecker/tests/test_checker.py | 20 ++--- awslimitchecker/tests/test_integration.py | 4 +- awslimitchecker/trustedadvisor.py | 20 ++--- 7 files changed, 107 insertions(+), 147 deletions(-) diff --git a/awslimitchecker/checker.py b/awslimitchecker/checker.py index 23dacb02..cb186da9 100644 --- a/awslimitchecker/checker.py +++ b/awslimitchecker/checker.py @@ -37,9 +37,11 @@ ################################################################################ """ +from .connectable import ConnectableCredentials from .services import _services -from .version import _get_version_info from .trustedadvisor import TrustedAdvisor +from .version import _get_version_info +import boto3 import sys import logging import warnings @@ -153,24 +155,54 @@ def __init__(self, warning_threshold=80, critical_threshold=99, self.mfa_serial_number = mfa_serial_number self.mfa_token = mfa_token self.region = region + self.services = {} + + boto_connection_kwargs = self._boto3_connection_kwargs for sname, cls in _services.items(): - self.services[sname] = cls(warning_threshold, critical_threshold, - profile_name, account_id, account_role, - region, external_id, mfa_serial_number, - mfa_token) - self.ta = TrustedAdvisor( - self.services, - profile_name=profile_name, - account_id=account_id, - account_role=account_role, - region=region, - external_id=external_id, - mfa_serial_number=mfa_serial_number, - mfa_token=mfa_token, - ta_refresh_mode=ta_refresh_mode, - ta_refresh_timeout=ta_refresh_timeout - ) + self.services[sname] = cls(warning_threshold, + critical_threshold, + boto_connection_kwargs) + + self.ta = TrustedAdvisor(self.services, + boto_connection_kwargs, + ta_refresh_mode=ta_refresh_mode, + ta_refresh_timeout=ta_refresh_timeout) + + @property + def _boto3_connection_kwargs(self): + """ + Generate keyword arguments for boto3 connection functions. + If ``self.account_id`` is None, this will just include + ``region_name=self.region``. Otherwise, call + :py:meth:`~._get_sts_token` to get STS token credentials using + `boto3.STS.Client.assume_role `_ and include + those credentials in the return value. + + :return: keyword arguments for boto3 connection functions + :rtype: dict + """ + kwargs = {'region_name': self.region} + if self.account_id is not None: + logger.debug("Connecting for account %s role '%s' with STS " + "(region: %s)", self.account_id, self.account_role, + self.region) + credentials = self._get_sts_token() + kwargs['aws_access_key_id'] = credentials.access_key + kwargs['aws_secret_access_key'] = credentials.secret_key + kwargs['aws_session_token'] = credentials.session_token + elif self.profile_name is not None: + # use boto3.Session to get credentials from the named profile + logger.debug("Using credentials profile: %s", self.profile_name) + session = boto3.Session(profile_name=self.profile_name) + credentials = session._session.get_credentials() + kwargs['aws_access_key_id'] = credentials.access_key + kwargs['aws_secret_access_key'] = credentials.secret_key + kwargs['aws_session_token'] = credentials.token + else: + logger.debug("Connecting to region: %s", self.region) + return kwargs def get_version(self): """ @@ -227,6 +259,44 @@ def get_service_names(self): """ return sorted(self.services.keys()) + def _get_sts_token(self): + """ + Assume a role via STS and return the credentials. + + First connect to STS via :py:func:`boto3.client`, then + assume a role using `boto3.STS.Client.assume_role `_ + using ``self.account_id`` and ``self.account_role`` (and optionally + ``self.external_id``, ``self.mfa_serial_number``, ``self.mfa_token``). + Return the resulting :py:class:`~.ConnectableCredentials` + object. + + :returns: STS assumed role credentials + :rtype: :py:class:`~.ConnectableCredentials` + """ + logger.debug("Connecting to STS in region %s", self.region) + sts = boto3.client('sts', region_name=self.region) + arn = "arn:aws:iam::%s:role/%s" % (self.account_id, self.account_role) + logger.debug("STS assume role for %s", arn) + assume_kwargs = { + 'RoleArn': arn, + 'RoleSessionName': 'awslimitchecker' + } + if self.external_id is not None: + assume_kwargs['ExternalId'] = self.external_id + if self.mfa_serial_number is not None: + assume_kwargs['SerialNumber'] = self.mfa_serial_number + if self.mfa_token is not None: + assume_kwargs['TokenCode'] = self.mfa_token + role = sts.assume_role(**assume_kwargs) + + creds = ConnectableCredentials(role) + creds.account_id = self.account_id + + logger.debug("Got STS credentials for role; access_key_id=%s " + "(account_id=%s)", creds.access_key, creds.account_id) + return creds + def find_usage(self, service=None, use_ta=True): """ For each limit in the specified service (or all services if diff --git a/awslimitchecker/connectable.py b/awslimitchecker/connectable.py index 7c5c300e..6ab29ec4 100644 --- a/awslimitchecker/connectable.py +++ b/awslimitchecker/connectable.py @@ -72,55 +72,6 @@ class Connectable(object): connecting via regions and/or STS. """ - # Class attribute to reuse credentials between calls - credentials = None - - @property - def _boto3_connection_kwargs(self): - """ - Generate keyword arguments for boto3 connection functions. - If ``self.account_id`` is None, this will just include - ``region_name=self.region``. Otherwise, call - :py:meth:`~._get_sts_token` to get STS token credentials using - `boto3.STS.Client.assume_role `_ and include - those credentials in the return value. - - :return: keyword arguments for boto3 connection functions - :rtype: dict - """ - kwargs = {'region_name': self.region} - if self.account_id is not None: - if Connectable.credentials is None: - logger.debug("Connecting for account %s role '%s' with STS " - "(region: %s)", self.account_id, self.account_role, - self.region) - Connectable.credentials = self._get_sts_token() - else: - if self.account_id == Connectable.credentials.account_id: - logger.debug("Reusing previous STS credentials for " - "account %s", self.account_id) - else: - logger.debug("Previous STS credentials are for account %s; " - "getting new credentials for current account " - "(%s)", Connectable.credentials.account_id, - self.account_id) - Connectable.credentials = self._get_sts_token() - kwargs['aws_access_key_id'] = Connectable.credentials.access_key - kwargs['aws_secret_access_key'] = Connectable.credentials.secret_key - kwargs['aws_session_token'] = Connectable.credentials.session_token - elif self.profile_name is not None: - # use boto3.Session to get credentials from the named profile - logger.debug("Using credentials profile: %s", self.profile_name) - session = boto3.Session(profile_name=self.profile_name) - credentials = session._session.get_credentials() - kwargs['aws_access_key_id'] = credentials.access_key - kwargs['aws_secret_access_key'] = credentials.secret_key - kwargs['aws_session_token'] = credentials.token - else: - logger.debug("Connecting to region %s", self.region) - return kwargs - def connect(self): """ Connect to an AWS API via boto3 low-level client and set ``self.conn`` @@ -135,8 +86,8 @@ def connect(self): return kwargs = self._boto3_connection_kwargs self.conn = boto3.client(self.api_name, **kwargs) - logger.info("Connected to %s in region %s", self.api_name, - self.conn._client_config.region_name) + logger.info("Connected to %s in region %s", + self.api_name, self.conn._client_config.region_name) def connect_resource(self): """ @@ -155,41 +106,3 @@ def connect_resource(self): self.resource_conn = boto3.resource(self.api_name, **kwargs) logger.info("Connected to %s (resource) in region %s", self.api_name, self.resource_conn.meta.client._client_config.region_name) - - def _get_sts_token(self): - """ - Assume a role via STS and return the credentials. - - First connect to STS via :py:func:`boto3.client`, then - assume a role using `boto3.STS.Client.assume_role `_ - using ``self.account_id`` and ``self.account_role`` (and optionally - ``self.external_id``, ``self.mfa_serial_number``, ``self.mfa_token``). - Return the resulting :py:class:`~.ConnectableCredentials` - object. - - :returns: STS assumed role credentials - :rtype: :py:class:`~.ConnectableCredentials` - """ - logger.debug("Connecting to STS in region %s", self.region) - sts = boto3.client('sts', region_name=self.region) - arn = "arn:aws:iam::%s:role/%s" % (self.account_id, self.account_role) - logger.debug("STS assume role for %s", arn) - assume_kwargs = { - 'RoleArn': arn, - 'RoleSessionName': 'awslimitchecker' - } - if self.external_id is not None: - assume_kwargs['ExternalId'] = self.external_id - if self.mfa_serial_number is not None: - assume_kwargs['SerialNumber'] = self.mfa_serial_number - if self.mfa_token is not None: - assume_kwargs['TokenCode'] = self.mfa_token - role = sts.assume_role(**assume_kwargs) - - creds = ConnectableCredentials(role) - creds.account_id = self.account_id - - logger.debug("Got STS credentials for role; access_key_id=%s " - "(account_id=%s)", creds.access_key, creds.account_id) - return creds diff --git a/awslimitchecker/services/base.py b/awslimitchecker/services/base.py index b375f1c1..5ca98d6f 100644 --- a/awslimitchecker/services/base.py +++ b/awslimitchecker/services/base.py @@ -51,9 +51,7 @@ class _AwsService(Connectable): api_name = 'baseclass' def __init__(self, warning_threshold, critical_threshold, - profile_name=None, account_id=None, account_role=None, - region=None, external_id=None, mfa_serial_number=None, - mfa_token=None): + boto_connection_kwargs): """ Describes an AWS service and its limits, and provides methods to query current utilization. @@ -101,14 +99,7 @@ def __init__(self, warning_threshold, critical_threshold, """ self.warning_threshold = warning_threshold self.critical_threshold = critical_threshold - self.profile_name = profile_name - self.account_id = account_id - self.account_role = account_role - self.region = region - self.external_id = external_id - self.mfa_serial_number = mfa_serial_number - self.mfa_token = mfa_token - + self._boto3_connection_kwargs = boto_connection_kwargs self.limits = {} self.limits = self.get_limits() self.conn = None diff --git a/awslimitchecker/services/firehose.py b/awslimitchecker/services/firehose.py index c9ca6d38..7de98a18 100644 --- a/awslimitchecker/services/firehose.py +++ b/awslimitchecker/services/firehose.py @@ -84,7 +84,7 @@ def _find_delivery_streams(self): usage += len(streams['DeliveryStreamNames']) self.limits['Delivery streams per region']._add_current_usage( usage, - resource_id=self.region, + resource_id=self._boto3_connection_kwargs['region_name'], aws_type='AWS::KinesisFirehose::DeliveryStream', ) diff --git a/awslimitchecker/tests/test_checker.py b/awslimitchecker/tests/test_checker.py index 0b8e2cf6..b1a7f88f 100644 --- a/awslimitchecker/tests/test_checker.py +++ b/awslimitchecker/tests/test_checker.py @@ -112,16 +112,14 @@ def test_init(self): assert self.cls.services == services # _AwsService instances should exist, but have no other calls assert self.mock_foo.mock_calls == [ - call(80, 99, None, None, None, None, None, None, None) + call(80, 99, {'region_name': None}) ] assert self.mock_bar.mock_calls == [ - call(80, 99, None, None, None, None, None, None, None) + call(80, 99, {'region_name': None}) ] assert self.mock_ta_constr.mock_calls == [ - call(services, account_id=None, account_role=None, region=None, - external_id=None, mfa_serial_number=None, mfa_token=None, - profile_name=None, ta_refresh_mode=None, - ta_refresh_timeout=None) + call(services, {'region_name': None}, + ta_refresh_mode=None, ta_refresh_timeout=None) ] assert self.mock_svc1.mock_calls == [] assert self.mock_svc2.mock_calls == [] @@ -180,16 +178,14 @@ def test_init_thresholds(self): assert cls.services == services # _AwsService instances should exist, but have no other calls assert mock_foo.mock_calls == [ - call(5, 22, None, None, None, None, None, None, None) + call(5, 22, {'region_name': None}) ] assert mock_bar.mock_calls == [ - call(5, 22, None, None, None, None, None, None, None) + call(5, 22, {'region_name': None}) ] assert mock_ta_constr.mock_calls == [ - call(services, account_id=None, account_role=None, region=None, - external_id=None, mfa_serial_number=None, mfa_token=None, - profile_name=None, ta_refresh_mode=None, - ta_refresh_timeout=None) + call(services, {'region_name': None}, + ta_refresh_mode=None, ta_refresh_timeout=None) ] assert mock_svc1.mock_calls == [] assert mock_svc2.mock_calls == [] diff --git a/awslimitchecker/tests/test_integration.py b/awslimitchecker/tests/test_integration.py index 2f15684b..2ebd04a6 100644 --- a/awslimitchecker/tests/test_integration.py +++ b/awslimitchecker/tests/test_integration.py @@ -125,7 +125,7 @@ def test_verify_limits(self, checker_args, creds_type, service_name, use_ta, :type allow_endpoint_error: bool """ # clear the Connectable credentials - Connectable.credentials = None + Connectable._boto3_connection_kwargs = None # destroy boto3's session, so it creates a new one boto3.DEFAULT_SESSION = None # set the env vars to the creds we want @@ -215,8 +215,6 @@ def test_verify_usage(self, checker_args, creds_type, service_name, :py:meth:`~.support.LogRecordHelper.unexpected_logs` :type allow_endpoint_error: bool """ - # clear the Connectable credentials - Connectable.credentials = None # destroy boto3's session, so it creates a new one boto3.DEFAULT_SESSION = None # set the env vars to the creds we want diff --git a/awslimitchecker/trustedadvisor.py b/awslimitchecker/trustedadvisor.py index f0a43210..7eb1f3cc 100644 --- a/awslimitchecker/trustedadvisor.py +++ b/awslimitchecker/trustedadvisor.py @@ -49,7 +49,6 @@ class TrustedAdvisor(Connectable): - """ Class to handle interaction with TrustedAdvisor API, polling TA and updating limits from TA information. @@ -58,10 +57,8 @@ class TrustedAdvisor(Connectable): service_name = 'TrustedAdvisor' api_name = 'support' - def __init__(self, all_services, profile_name=None, account_id=None, - account_role=None, region=None, external_id=None, - mfa_serial_number=None, mfa_token=None, ta_refresh_mode=None, - ta_refresh_timeout=None): + def __init__(self, all_services, boto_connection_kwargs, + ta_refresh_mode=None, ta_refresh_timeout=None): """ Class to contain all TrustedAdvisor-related logic. @@ -115,14 +112,9 @@ def __init__(self, all_services, profile_name=None, account_id=None, """ self.conn = None self.have_ta = True - self.profile_name = profile_name - self.account_id = account_id - self.account_role = account_role self.region = 'us-east-1' - self.ta_region = region - self.external_id = external_id - self.mfa_serial_number = mfa_serial_number - self.mfa_token = mfa_token + self.ta_region = boto_connection_kwargs['region_name'] + self._boto3_connection_kwargs = boto_connection_kwargs self.refresh_mode = ta_refresh_mode self.refresh_timeout = ta_refresh_timeout self.all_services = all_services @@ -226,8 +218,8 @@ def _get_limit_check_id(self): raise ex for check in checks: if ( - check['category'] == 'performance' and - check['name'] == 'Service Limits' + check['category'] == 'performance' and + check['name'] == 'Service Limits' ): logger.debug("Found TA check; id=%s", check['id']) return (