diff --git a/awslimitchecker/connectable.py b/awslimitchecker/connectable.py index acabd1bf..07c73194 100644 --- a/awslimitchecker/connectable.py +++ b/awslimitchecker/connectable.py @@ -38,11 +38,30 @@ """ import logging -import boto.sts +import boto # @TODO boto3 migration - remove this when done +import boto.sts # @TODO boto3 migration - remove this when done +import boto3 logger = logging.getLogger(__name__) +class ConnectableCredentials(object): + """ + boto's (2.x) :py:meth:`boto.sts.STSConnection.assume_role` returns a + :py:class:`boto.sts.credentials.Credentials` object, but boto3's + :py:meth:`boto3.STS.Client.assume_role` just returns a dict. This class + provides a compatible interface for boto3. + """ + + def __init__(self, creds_dict): + self.access_key = creds_dict['Credentials']['AccessKeyId'] + self.secret_key = creds_dict['Credentials']['SecretAccessKey'] + self.session_token = creds_dict['Credentials']['SessionToken'] + self.expiration = creds_dict['Credentials']['Expiration'] + self.assumed_role_id = creds_dict['AssumedRoleUser']['AssumedRoleId'] + self.assumed_role_arn = creds_dict['AssumedRoleUser']['Arn'] + + class Connectable(object): """ @@ -66,6 +85,7 @@ def connect_via(self, driver): :type driver: :py:obj:`function` :returns: connected boto service class instance """ + # @TODO boto3 migration - remove this when done if self.account_id is not None: if Connectable.credentials is None: logger.debug("Connecting to %s for account %s (STS; %s)", @@ -87,6 +107,43 @@ def connect_via(self, driver): logger.info("Connected to %s", self.service_name) return conn + def connect_client(self, service_name): + """ + Connect to an AWS API and return the connected boto3 client object. If + ``self.account_id`` is None, call :py:meth:`boto3.client` with + ``region_name=self.region``. Otherwise, call :py:meth:`~._get_sts_token` + to get STS token credentials using + :py:meth:`boto.sts.STSConnection.assume_role` and call + :py:meth:`boto3.client` with those credentials to use an assumed role. + + This method returns a low-level boto3 client object. + + :param service_name: name of the AWS service API to connect to (passed + to :py:meth:`boto3.client` as the ``service_name`` parameter.) + :type driver: str + :returns: connected :py:meth:`boto3.client` class instance + """ + if self.account_id is not None: + if Connectable.credentials is None: + logger.debug("Connecting to %s for account %s (STS; %s)", + service_name, self.account_id, self.region) + Connectable.credentials = self._get_sts_token_boto3() + else: + logger.debug("Reusing previous STS credentials for account %s", + self.account_id) + conn = boto3.client( + service_name, + region_name=self.region, + aws_access_key_id=Connectable.credentials.access_key, + aws_secret_access_key=Connectable.credentials.secret_key, + aws_session_token=Connectable.credentials.session_token) + else: + logger.debug("Connecting to %s (%s)", + service_name, self.region) + conn = boto3.client(service_name, region_name=self.region) + logger.info("Connected to %s", service_name) + return conn + def _get_sts_token(self): """ Assume a role via STS and return the credentials. @@ -101,6 +158,7 @@ def _get_sts_token(self): :returns: STS assumed role credentials :rtype: :py:class:`boto.sts.credentials.Credentials` """ + # @TODO boto3 migration - remove this when done logger.debug("Connecting to STS in region %s", self.region) sts = boto.sts.connect_to_region(self.region) arn = "arn:aws:iam::%s:role/%s" % (self.account_id, self.account_role) @@ -112,3 +170,31 @@ def _get_sts_token(self): logger.debug("Got STS credentials for role; access_key_id=%s", role.credentials.access_key) return role.credentials + + def _get_sts_token_boto3(self): + """ + Assume a role via STS and return the credentials. + + First connect to STS via :py:func:`boto3.client`, then + assume a role using :py:meth:`boto3.STS.Client.assume_role` + using ``self.account_id`` and ``self.account_role`` (and optionally + ``self.external_id``, ``self.mfa_serial_number``, ``self.mfa_token``). + Return the resulting :py:class:`~.ConnectableCredentials` + object. + + :returns: STS assumed role credentials + :rtype: :py:class:`~.ConnectableCredentials` + """ + logger.debug("Connecting to STS in region %s", self.region) + sts = boto3.client('sts', region_name=self.region) + arn = "arn:aws:iam::%s:role/%s" % (self.account_id, self.account_role) + logger.debug("STS assume role for %s", arn) + role = sts.assume_role(RoleArn=arn, + RoleSessionName="awslimitchecker", + ExternalId=self.external_id, + SerialNumber=self.mfa_serial_number, + TokenCode=self.mfa_token) + creds = ConnectableCredentials(role) + logger.debug("Got STS credentials for role; access_key_id=%s", + creds.access_key) + return creds diff --git a/awslimitchecker/tests/test_connectable.py b/awslimitchecker/tests/test_connectable.py index 7ea892a6..716911e0 100644 --- a/awslimitchecker/tests/test_connectable.py +++ b/awslimitchecker/tests/test_connectable.py @@ -37,7 +37,8 @@ ################################################################################ """ -from awslimitchecker.connectable import Connectable +from awslimitchecker.connectable import Connectable, ConnectableCredentials +from datetime import datetime import sys # https://code.google.com/p/mock/issues/detail?id=249 @@ -51,6 +52,10 @@ from unittest.mock import patch, call, Mock +pbm = 'awslimitchecker.connectable' +pb = '%s.Connectable' % pbm + + class ConnectableTester(Connectable): """example class to test Connectable""" @@ -81,8 +86,7 @@ def test_connect_via_no_region(self): def test_connect_via_with_region(self): cls = ConnectableTester(region='foo') mock_driver = Mock() - with patch('awslimitchecker.connectable.Connectable._get_sts_token' - '') as mock_get_sts: + with patch('%s._get_sts_token' % pb) as mock_get_sts: res = cls.connect_via(mock_driver) assert mock_get_sts.mock_calls == [] assert mock_driver.mock_calls == [ @@ -99,8 +103,7 @@ def test_connect_via_sts(self): type(mock_creds).secret_key = 'sts_sk' type(mock_creds).session_token = 'sts_token' - with patch('awslimitchecker.connectable.Connectable._get_sts_token' - '') as mock_get_sts: + with patch('%s._get_sts_token' % pb) as mock_get_sts: mock_get_sts.return_value = mock_creds Connectable.credentials = None res = cls.connect_via(mock_driver) @@ -124,8 +127,7 @@ def test_connect_via_sts_again(self): type(mock_creds).secret_key = 'sts_sk' type(mock_creds).session_token = 'sts_token' - with patch('awslimitchecker.connectable.Connectable._get_sts_token' - '') as mock_get_sts: + with patch('%s._get_sts_token' % pb) as mock_get_sts: Connectable.credentials = mock_creds res = cls.connect_via(mock_driver) assert mock_get_sts.mock_calls == [] @@ -139,11 +141,79 @@ def test_connect_via_sts_again(self): ] assert res == mock_driver.return_value + def test_connect_client_no_region(self): + cls = ConnectableTester() + with patch('%s.boto3.client' % pbm) as mock_client: + res = cls.connect_client('foo') + assert mock_client.mock_calls == [ + call('foo', region_name=None) + ] + assert res == mock_client.return_value + + def test_connect_client_with_region(self): + cls = ConnectableTester(region='myregion') + with patch('%s._get_sts_token_boto3' % pb) as mock_get_sts: + with patch('%s.boto3.client' % pbm) as mock_client: + res = cls.connect_client('foo') + assert mock_get_sts.mock_calls == [] + assert mock_client.mock_calls == [ + call('foo', region_name='myregion') + ] + assert res == mock_client.return_value + + def test_connect_client_sts(self): + cls = ConnectableTester(account_id='123', account_role='myrole', + region='myregion') + mock_creds = Mock() + type(mock_creds).access_key = 'sts_ak' + type(mock_creds).secret_key = 'sts_sk' + type(mock_creds).session_token = 'sts_token' + + with patch('%s._get_sts_token_boto3' % pb) as mock_get_sts: + mock_get_sts.return_value = mock_creds + Connectable.credentials = None + with patch('%s.boto3.client' % pbm) as mock_client: + res = cls.connect_client('foo') + assert mock_get_sts.mock_calls == [call()] + assert mock_client.mock_calls == [ + call( + 'foo', + region_name='myregion', + aws_access_key_id='sts_ak', + aws_secret_access_key='sts_sk', + aws_session_token='sts_token' + ) + ] + assert res == mock_client.return_value + + def test_connect_client_sts_again(self): + cls = ConnectableTester(account_id='123', account_role='myrole', + region='myregion') + mock_creds = Mock() + type(mock_creds).access_key = 'sts_ak' + type(mock_creds).secret_key = 'sts_sk' + type(mock_creds).session_token = 'sts_token' + + with patch('%s._get_sts_token_boto3' % pb) as mock_get_sts: + Connectable.credentials = mock_creds + with patch('%s.boto3.client' % pbm) as mock_client: + res = cls.connect_client('foo') + assert mock_get_sts.mock_calls == [] + assert mock_client.mock_calls == [ + call( + 'foo', + region_name='myregion', + aws_access_key_id='sts_ak', + aws_secret_access_key='sts_sk', + aws_session_token='sts_token' + ) + ] + assert res == mock_client.return_value + def test_get_sts_token(self): cls = ConnectableTester(account_id='789', account_role='myr', region='foobar') - with patch('awslimitchecker.connectable.boto.sts.connect_to_region' - '') as mock_connect: + with patch('%s.boto.sts.connect_to_region' % pbm) as mock_connect: res = cls._get_sts_token() arn = 'arn:aws:iam::789:role/myr' assert mock_connect.mock_calls == [ @@ -158,8 +228,7 @@ def test_get_sts_token_external_id(self): cls = ConnectableTester(account_id='789', account_role='myr', region='foobar', external_id='myextid') - with patch('awslimitchecker.connectable.boto.sts.connect_to_region' - '') as mock_connect: + with patch('%s.boto.sts.connect_to_region' % pbm) as mock_connect: res = cls._get_sts_token() arn = 'arn:aws:iam::789:role/myr' assert mock_connect.mock_calls == [ @@ -176,8 +245,7 @@ def test_get_sts_token_mfa(self): external_id='myextid', mfa_serial_number='arn:aws:iam::456:mfa/me', mfa_token='123456') - with patch('awslimitchecker.connectable.boto.sts.connect_to_region' - '') as mock_connect: + with patch('%s.boto.sts.connect_to_region' % pbm) as mock_connect: res = cls._get_sts_token() arn = 'arn:aws:iam::789:role/myr' assert mock_connect.mock_calls == [ @@ -188,3 +256,52 @@ def test_get_sts_token_mfa(self): ] assume_role_ret = mock_connect.return_value.assume_role.return_value assert res == assume_role_ret.credentials + + def test_get_sts_token_boto3(self): + ret_dict = Mock() + cls = ConnectableTester(account_id='789', + account_role='myr', region='foobar') + with patch('%s.boto3.client' % pbm) as mock_connect: + with patch('%s.ConnectableCredentials' % pbm, + create=True) as mock_creds: + mock_connect.return_value.assume_role.return_value = ret_dict + res = cls._get_sts_token_boto3() + arn = 'arn:aws:iam::789:role/myr' + assert mock_connect.mock_calls == [ + call('sts', region_name='foobar'), + call().assume_role( + RoleArn=arn, + RoleSessionName='awslimitchecker', + ExternalId=None, + SerialNumber=None, + TokenCode=None), + ] + assert mock_creds.mock_calls == [ + call(ret_dict) + ] + assert res == mock_creds.return_value + + +class TestConnectableCredentials(object): + + def test_connectable_credentials(self): + result = { + 'Credentials': { + 'AccessKeyId': 'akid', + 'SecretAccessKey': 'secret', + 'SessionToken': 'token', + 'Expiration': datetime(2015, 1, 1) + }, + 'AssumedRoleUser': { + 'AssumedRoleId': 'roleid', + 'Arn': 'arn' + }, + 'PackedPolicySize': 123 + } + c = ConnectableCredentials(result) + assert c.access_key == 'akid' + assert c.secret_key == 'secret' + assert c.session_token == 'token' + assert c.expiration == datetime(2015, 1, 1) + assert c.assumed_role_id == 'roleid' + assert c.assumed_role_arn == 'arn' diff --git a/docs/source/conf.py b/docs/source/conf.py index 35aca52f..b93269c9 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -282,6 +282,7 @@ intersphinx_mapping = { 'https://docs.python.org/': None, 'boto': ('http://boto.readthedocs.org/en/latest/', None) + 'boto3': ('http://boto3.readthedocs.org/en/latest/', None) } autoclass_content = 'init' diff --git a/setup.py b/setup.py index c37a66ab..f72104b7 100644 --- a/setup.py +++ b/setup.py @@ -46,6 +46,7 @@ requires = [ 'boto>=2.32.0', + 'boto3>=1.2.3', 'termcolor>=1.1.0', 'python-dateutil>=2.4.2', ] diff --git a/tox.ini b/tox.ini index 0cbdf923..abd2a625 100644 --- a/tox.ini +++ b/tox.ini @@ -16,6 +16,7 @@ deps = mock freezegun boto==2.32.0 + boto3==1.2.3 pytest-blockage virtualenv