Skip to content

Commit

Permalink
[WIP] GH220
Browse files Browse the repository at this point in the history
  • Loading branch information
William Usher committed Jan 3, 2017
1 parent d6ebdc2 commit 1fa4f21
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 147 deletions.
104 changes: 87 additions & 17 deletions awslimitchecker/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,11 @@
################################################################################
"""

from .connectable import ConnectableCredentials
from .services import _services
from .version import _get_version_info
from .trustedadvisor import TrustedAdvisor
from .version import _get_version_info
import boto3
import sys
import logging
import warnings
Expand Down Expand Up @@ -153,24 +155,54 @@ def __init__(self, warning_threshold=80, critical_threshold=99,
self.mfa_serial_number = mfa_serial_number
self.mfa_token = mfa_token
self.region = region

self.services = {}

boto_connection_kwargs = self._boto3_connection_kwargs
for sname, cls in _services.items():
self.services[sname] = cls(warning_threshold, critical_threshold,
profile_name, account_id, account_role,
region, external_id, mfa_serial_number,
mfa_token)
self.ta = TrustedAdvisor(
self.services,
profile_name=profile_name,
account_id=account_id,
account_role=account_role,
region=region,
external_id=external_id,
mfa_serial_number=mfa_serial_number,
mfa_token=mfa_token,
ta_refresh_mode=ta_refresh_mode,
ta_refresh_timeout=ta_refresh_timeout
)
self.services[sname] = cls(warning_threshold,
critical_threshold,
boto_connection_kwargs)

self.ta = TrustedAdvisor(self.services,
boto_connection_kwargs,
ta_refresh_mode=ta_refresh_mode,
ta_refresh_timeout=ta_refresh_timeout)

@property
def _boto3_connection_kwargs(self):
"""
Generate keyword arguments for boto3 connection functions.
If ``self.account_id`` is None, this will just include
``region_name=self.region``. Otherwise, call
:py:meth:`~._get_sts_token` to get STS token credentials using
`boto3.STS.Client.assume_role <https://boto3.readthedocs.org/en/
latest/reference/services/sts.html#STS.Client.assume_role>`_ and include
those credentials in the return value.
:return: keyword arguments for boto3 connection functions
:rtype: dict
"""
kwargs = {'region_name': self.region}
if self.account_id is not None:
logger.debug("Connecting for account %s role '%s' with STS "
"(region: %s)", self.account_id, self.account_role,
self.region)
credentials = self._get_sts_token()
kwargs['aws_access_key_id'] = credentials.access_key
kwargs['aws_secret_access_key'] = credentials.secret_key
kwargs['aws_session_token'] = credentials.session_token
elif self.profile_name is not None:
# use boto3.Session to get credentials from the named profile
logger.debug("Using credentials profile: %s", self.profile_name)
session = boto3.Session(profile_name=self.profile_name)
credentials = session._session.get_credentials()
kwargs['aws_access_key_id'] = credentials.access_key
kwargs['aws_secret_access_key'] = credentials.secret_key
kwargs['aws_session_token'] = credentials.token
else:
logger.debug("Connecting to region: %s", self.region)
return kwargs

def get_version(self):
"""
Expand Down Expand Up @@ -227,6 +259,44 @@ def get_service_names(self):
"""
return sorted(self.services.keys())

def _get_sts_token(self):
"""
Assume a role via STS and return the credentials.
First connect to STS via :py:func:`boto3.client`, then
assume a role using `boto3.STS.Client.assume_role <https://boto3.readthe
docs.org/en/latest/reference/services/sts.html#STS.Client.assume_role>`_
using ``self.account_id`` and ``self.account_role`` (and optionally
``self.external_id``, ``self.mfa_serial_number``, ``self.mfa_token``).
Return the resulting :py:class:`~.ConnectableCredentials`
object.
:returns: STS assumed role credentials
:rtype: :py:class:`~.ConnectableCredentials`
"""
logger.debug("Connecting to STS in region %s", self.region)
sts = boto3.client('sts', region_name=self.region)
arn = "arn:aws:iam::%s:role/%s" % (self.account_id, self.account_role)
logger.debug("STS assume role for %s", arn)
assume_kwargs = {
'RoleArn': arn,
'RoleSessionName': 'awslimitchecker'
}
if self.external_id is not None:
assume_kwargs['ExternalId'] = self.external_id
if self.mfa_serial_number is not None:
assume_kwargs['SerialNumber'] = self.mfa_serial_number
if self.mfa_token is not None:
assume_kwargs['TokenCode'] = self.mfa_token
role = sts.assume_role(**assume_kwargs)

creds = ConnectableCredentials(role)
creds.account_id = self.account_id

logger.debug("Got STS credentials for role; access_key_id=%s "
"(account_id=%s)", creds.access_key, creds.account_id)
return creds

def find_usage(self, service=None, use_ta=True):
"""
For each limit in the specified service (or all services if
Expand Down
91 changes: 2 additions & 89 deletions awslimitchecker/connectable.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,55 +72,6 @@ class Connectable(object):
connecting via regions and/or STS.
"""

# Class attribute to reuse credentials between calls
credentials = None

@property
def _boto3_connection_kwargs(self):
"""
Generate keyword arguments for boto3 connection functions.
If ``self.account_id`` is None, this will just include
``region_name=self.region``. Otherwise, call
:py:meth:`~._get_sts_token` to get STS token credentials using
`boto3.STS.Client.assume_role <https://boto3.readthedocs.org/en/
latest/reference/services/sts.html#STS.Client.assume_role>`_ and include
those credentials in the return value.
:return: keyword arguments for boto3 connection functions
:rtype: dict
"""
kwargs = {'region_name': self.region}
if self.account_id is not None:
if Connectable.credentials is None:
logger.debug("Connecting for account %s role '%s' with STS "
"(region: %s)", self.account_id, self.account_role,
self.region)
Connectable.credentials = self._get_sts_token()
else:
if self.account_id == Connectable.credentials.account_id:
logger.debug("Reusing previous STS credentials for "
"account %s", self.account_id)
else:
logger.debug("Previous STS credentials are for account %s; "
"getting new credentials for current account "
"(%s)", Connectable.credentials.account_id,
self.account_id)
Connectable.credentials = self._get_sts_token()
kwargs['aws_access_key_id'] = Connectable.credentials.access_key
kwargs['aws_secret_access_key'] = Connectable.credentials.secret_key
kwargs['aws_session_token'] = Connectable.credentials.session_token
elif self.profile_name is not None:
# use boto3.Session to get credentials from the named profile
logger.debug("Using credentials profile: %s", self.profile_name)
session = boto3.Session(profile_name=self.profile_name)
credentials = session._session.get_credentials()
kwargs['aws_access_key_id'] = credentials.access_key
kwargs['aws_secret_access_key'] = credentials.secret_key
kwargs['aws_session_token'] = credentials.token
else:
logger.debug("Connecting to region %s", self.region)
return kwargs

def connect(self):
"""
Connect to an AWS API via boto3 low-level client and set ``self.conn``
Expand All @@ -135,8 +86,8 @@ def connect(self):
return
kwargs = self._boto3_connection_kwargs
self.conn = boto3.client(self.api_name, **kwargs)
logger.info("Connected to %s in region %s", self.api_name,
self.conn._client_config.region_name)
logger.info("Connected to %s in region %s",
self.api_name, self.conn._client_config.region_name)

def connect_resource(self):
"""
Expand All @@ -155,41 +106,3 @@ def connect_resource(self):
self.resource_conn = boto3.resource(self.api_name, **kwargs)
logger.info("Connected to %s (resource) in region %s", self.api_name,
self.resource_conn.meta.client._client_config.region_name)

def _get_sts_token(self):
"""
Assume a role via STS and return the credentials.
First connect to STS via :py:func:`boto3.client`, then
assume a role using `boto3.STS.Client.assume_role <https://boto3.readthe
docs.org/en/latest/reference/services/sts.html#STS.Client.assume_role>`_
using ``self.account_id`` and ``self.account_role`` (and optionally
``self.external_id``, ``self.mfa_serial_number``, ``self.mfa_token``).
Return the resulting :py:class:`~.ConnectableCredentials`
object.
:returns: STS assumed role credentials
:rtype: :py:class:`~.ConnectableCredentials`
"""
logger.debug("Connecting to STS in region %s", self.region)
sts = boto3.client('sts', region_name=self.region)
arn = "arn:aws:iam::%s:role/%s" % (self.account_id, self.account_role)
logger.debug("STS assume role for %s", arn)
assume_kwargs = {
'RoleArn': arn,
'RoleSessionName': 'awslimitchecker'
}
if self.external_id is not None:
assume_kwargs['ExternalId'] = self.external_id
if self.mfa_serial_number is not None:
assume_kwargs['SerialNumber'] = self.mfa_serial_number
if self.mfa_token is not None:
assume_kwargs['TokenCode'] = self.mfa_token
role = sts.assume_role(**assume_kwargs)

creds = ConnectableCredentials(role)
creds.account_id = self.account_id

logger.debug("Got STS credentials for role; access_key_id=%s "
"(account_id=%s)", creds.access_key, creds.account_id)
return creds
13 changes: 2 additions & 11 deletions awslimitchecker/services/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ class _AwsService(Connectable):
api_name = 'baseclass'

def __init__(self, warning_threshold, critical_threshold,
profile_name=None, account_id=None, account_role=None,
region=None, external_id=None, mfa_serial_number=None,
mfa_token=None):
boto_connection_kwargs):
"""
Describes an AWS service and its limits, and provides methods to
query current utilization.
Expand Down Expand Up @@ -101,14 +99,7 @@ def __init__(self, warning_threshold, critical_threshold,
"""
self.warning_threshold = warning_threshold
self.critical_threshold = critical_threshold
self.profile_name = profile_name
self.account_id = account_id
self.account_role = account_role
self.region = region
self.external_id = external_id
self.mfa_serial_number = mfa_serial_number
self.mfa_token = mfa_token

self._boto3_connection_kwargs = boto_connection_kwargs
self.limits = {}
self.limits = self.get_limits()
self.conn = None
Expand Down
2 changes: 1 addition & 1 deletion awslimitchecker/services/firehose.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def _find_delivery_streams(self):
usage += len(streams['DeliveryStreamNames'])
self.limits['Delivery streams per region']._add_current_usage(
usage,
resource_id=self.region,
resource_id=self._boto3_connection_kwargs['region_name'],
aws_type='AWS::KinesisFirehose::DeliveryStream',
)

Expand Down
20 changes: 8 additions & 12 deletions awslimitchecker/tests/test_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,16 +112,14 @@ def test_init(self):
assert self.cls.services == services
# _AwsService instances should exist, but have no other calls
assert self.mock_foo.mock_calls == [
call(80, 99, None, None, None, None, None, None, None)
call(80, 99, {'region_name': None})
]
assert self.mock_bar.mock_calls == [
call(80, 99, None, None, None, None, None, None, None)
call(80, 99, {'region_name': None})
]
assert self.mock_ta_constr.mock_calls == [
call(services, account_id=None, account_role=None, region=None,
external_id=None, mfa_serial_number=None, mfa_token=None,
profile_name=None, ta_refresh_mode=None,
ta_refresh_timeout=None)
call(services, {'region_name': None},
ta_refresh_mode=None, ta_refresh_timeout=None)
]
assert self.mock_svc1.mock_calls == []
assert self.mock_svc2.mock_calls == []
Expand Down Expand Up @@ -180,16 +178,14 @@ def test_init_thresholds(self):
assert cls.services == services
# _AwsService instances should exist, but have no other calls
assert mock_foo.mock_calls == [
call(5, 22, None, None, None, None, None, None, None)
call(5, 22, {'region_name': None})
]
assert mock_bar.mock_calls == [
call(5, 22, None, None, None, None, None, None, None)
call(5, 22, {'region_name': None})
]
assert mock_ta_constr.mock_calls == [
call(services, account_id=None, account_role=None, region=None,
external_id=None, mfa_serial_number=None, mfa_token=None,
profile_name=None, ta_refresh_mode=None,
ta_refresh_timeout=None)
call(services, {'region_name': None},
ta_refresh_mode=None, ta_refresh_timeout=None)
]
assert mock_svc1.mock_calls == []
assert mock_svc2.mock_calls == []
Expand Down
4 changes: 1 addition & 3 deletions awslimitchecker/tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def test_verify_limits(self, checker_args, creds_type, service_name, use_ta,
:type allow_endpoint_error: bool
"""
# clear the Connectable credentials
Connectable.credentials = None
Connectable._boto3_connection_kwargs = None
# destroy boto3's session, so it creates a new one
boto3.DEFAULT_SESSION = None
# set the env vars to the creds we want
Expand Down Expand Up @@ -215,8 +215,6 @@ def test_verify_usage(self, checker_args, creds_type, service_name,
:py:meth:`~.support.LogRecordHelper.unexpected_logs`
:type allow_endpoint_error: bool
"""
# clear the Connectable credentials
Connectable.credentials = None
# destroy boto3's session, so it creates a new one
boto3.DEFAULT_SESSION = None
# set the env vars to the creds we want
Expand Down
20 changes: 6 additions & 14 deletions awslimitchecker/trustedadvisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@


class TrustedAdvisor(Connectable):

"""
Class to handle interaction with TrustedAdvisor API, polling TA and updating
limits from TA information.
Expand All @@ -58,10 +57,8 @@ class TrustedAdvisor(Connectable):
service_name = 'TrustedAdvisor'
api_name = 'support'

def __init__(self, all_services, profile_name=None, account_id=None,
account_role=None, region=None, external_id=None,
mfa_serial_number=None, mfa_token=None, ta_refresh_mode=None,
ta_refresh_timeout=None):
def __init__(self, all_services, boto_connection_kwargs,
ta_refresh_mode=None, ta_refresh_timeout=None):
"""
Class to contain all TrustedAdvisor-related logic.
Expand Down Expand Up @@ -115,14 +112,9 @@ def __init__(self, all_services, profile_name=None, account_id=None,
"""
self.conn = None
self.have_ta = True
self.profile_name = profile_name
self.account_id = account_id
self.account_role = account_role
self.region = 'us-east-1'
self.ta_region = region
self.external_id = external_id
self.mfa_serial_number = mfa_serial_number
self.mfa_token = mfa_token
self.ta_region = boto_connection_kwargs['region_name']
self._boto3_connection_kwargs = boto_connection_kwargs
self.refresh_mode = ta_refresh_mode
self.refresh_timeout = ta_refresh_timeout
self.all_services = all_services
Expand Down Expand Up @@ -226,8 +218,8 @@ def _get_limit_check_id(self):
raise ex
for check in checks:
if (
check['category'] == 'performance' and
check['name'] == 'Service Limits'
check['category'] == 'performance' and
check['name'] == 'Service Limits'
):
logger.debug("Found TA check; id=%s", check['id'])
return (
Expand Down

0 comments on commit 1fa4f21

Please sign in to comment.