Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support MFA device for cross account role switching #100

Merged
merged 5 commits into from
Dec 16, 2015
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions awslimitchecker/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class AwsLimitChecker(object):

def __init__(self, warning_threshold=80, critical_threshold=99,
account_id=None, account_role=None, region=None,
external_id=None):
external_id=None, mfa_serial_number=None, mfa_token=None):
"""
Main AwsLimitChecker class - this should be the only externally-used
portion of awslimitchecker.
Expand Down Expand Up @@ -83,6 +83,12 @@ def __init__(self, warning_threshold=80, critical_threshold=99,
com/IAM/latest/UserGuide/id_roles_create_for-user_externalid.html>`_
string to use when assuming a role via STS.
:type external_id: str
:param mfa_serial_number: (optional) the `MFA Serial Number` string to_
use when assuming a role via STS.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The continuation line here and on line 89 needs to be indented two spaces. That's why the Travis doc test is failing.

:type mfa_serial_number: str
:param mfa_token: (optional) the `MFA Token` string to use when_
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The continuation line here and on line 87 needs to be indented two spaces. That's why the Travis doc test is failing.

assuming a role via STS.
:type mfa_token: str
"""
# ###### IMPORTANT license notice ##########
# Pursuant to Sections 5(b) and 13 of the GNU Affero General Public
Expand Down Expand Up @@ -112,18 +118,23 @@ def __init__(self, warning_threshold=80, critical_threshold=99,
self.account_id = account_id
self.account_role = account_role
self.external_id = external_id
self.mfa_serial_number = mfa_serial_number
self.mfa_token = mfa_token
self.region = region
self.services = {}
self.ta = TrustedAdvisor(
account_id=account_id,
account_role=account_role,
region=region,
external_id=external_id
external_id=external_id,
mfa_serial_number=mfa_serial_number,
mfa_token=mfa_token
)
for sname, cls in _services.items():
self.services[sname] = cls(warning_threshold, critical_threshold,
account_id, account_role, region,
external_id)
external_id, mfa_serial_number,
mfa_token)

def get_version(self):
"""
Expand Down
29 changes: 20 additions & 9 deletions awslimitchecker/connectable.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ class Connectable(object):
connecting via regions and/or STS.
"""

# Class attribute to reuse credentials between calls
credentials = None

def connect_via(self, driver):
"""
Connect to an AWS API and return the connection object. If
Expand All @@ -64,14 +67,19 @@ def connect_via(self, driver):
:returns: connected boto service class instance
"""
if self.account_id is not None:
logger.debug("Connecting to %s for account %s (STS; %s)",
self.service_name, self.account_id, self.region)
self.credentials = self._get_sts_token()
if Connectable.credentials is None:
logger.debug("Connecting to %s for account %s (STS; %s)",
self.service_name, self.account_id, self.region)
Connectable.credentials = self._get_sts_token()
else:
logger.debug("Reusing previous STS credentials for account %s",
self.account_id)

conn = driver(
self.region,
aws_access_key_id=self.credentials.access_key,
aws_secret_access_key=self.credentials.secret_key,
security_token=self.credentials.session_token)
aws_access_key_id=Connectable.credentials.access_key,
aws_secret_access_key=Connectable.credentials.secret_key,
security_token=Connectable.credentials.session_token)
else:
logger.debug("Connecting to %s (%s)",
self.service_name, self.region)
Expand All @@ -86,8 +94,9 @@ def _get_sts_token(self):
First connect to STS via :py:func:`boto.sts.connect_to_region`, then
assume a role using :py:meth:`boto.sts.STSConnection.assume_role`
using ``self.account_id`` and ``self.account_role`` (and optionally
``self.external_id``). Return the resulting
:py:class:`boto.sts.credentials.Credentials` object.
``self.external_id``, ``self.mfa_serial_number``, ``self.mfa_token``).
Return the resulting :py:class:`boto.sts.credentials.Credentials`
object.

:returns: STS assumed role credentials
:rtype: :py:class:`boto.sts.credentials.Credentials`
Expand All @@ -97,7 +106,9 @@ def _get_sts_token(self):
arn = "arn:aws:iam::%s:role/%s" % (self.account_id, self.account_role)
logger.debug("STS assume role for %s", arn)
role = sts.assume_role(arn, "awslimitchecker",
external_id=self.external_id)
external_id=self.external_id,
mfa_serial_number=self.mfa_serial_number,
mfa_token=self.mfa_token)
logger.debug("Got STS credentials for role; access_key_id=%s",
role.credentials.access_key)
return role.credentials
10 changes: 9 additions & 1 deletion awslimitchecker/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,12 @@ def parse_args(self, argv):
p.add_argument('-E', '--external-id', action='store', type=str,
default=None, help='External ID to use when assuming '
'a role via STS')
p.add_argument('-M', '--mfa-serial-number', action='store', type=str,
default=None, help='MFA Serial Number to use when '
'assuming a role via STS')
p.add_argument('-T', '--mfa-token', action='store', type=str,
default=None, help='MFA Token to use when assuming '
'a role via STS')
p.add_argument('-r', '--region', action='store',
type=str, default=None,
help='AWS region name to connect to; required for STS')
Expand Down Expand Up @@ -301,7 +307,9 @@ def console_entry_point(self):
account_id=args.sts_account_id,
account_role=args.sts_account_role,
region=args.region,
external_id=args.external_id
external_id=args.external_id,
mfa_serial_number=args.mfa_serial_number,
mfa_token=args.mfa_token
)

if args.version:
Expand Down
11 changes: 10 additions & 1 deletion awslimitchecker/services/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ class _AwsService(Connectable):
service_name = 'baseclass'

def __init__(self, warning_threshold, critical_threshold, account_id=None,
account_role=None, region=None, external_id=None):
account_role=None, region=None, external_id=None,
mfa_serial_number=None, mfa_token=None):
"""
Describes an AWS service and its limits, and provides methods to
query current utilization.
Expand Down Expand Up @@ -84,13 +85,21 @@ def __init__(self, warning_threshold, critical_threshold, account_id=None,
com/IAM/latest/UserGuide/id_roles_create_for-user_externalid.html>`_
string to use when assuming a role via STS.
:type external_id: str
:param mfa_serial_number: (optional) the `MFA Serial Number` string to_
use when assuming a role via STS.
:type mfa_serial_number: str
:param mfa_token: (optional) the `MFA Token` string to use when_
assuming a role via STS.
:type mfa_token: str
"""
self.warning_threshold = warning_threshold
self.critical_threshold = critical_threshold
self.account_id = account_id
self.account_role = account_role
self.region = region
self.external_id = external_id
self.mfa_serial_number = mfa_serial_number
self.mfa_token = mfa_token

self.limits = {}
self.limits = self.get_limits()
Expand Down
46 changes: 30 additions & 16 deletions awslimitchecker/tests/test_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,14 @@ def test_init(self):
}
# _AwsService instances should exist, but have no other calls
assert self.mock_foo.mock_calls == [
call(80, 99, None, None, None, None)
call(80, 99, None, None, None, None, None, None)
]
assert self.mock_bar.mock_calls == [
call(80, 99, None, None, None, None)
call(80, 99, None, None, None, None, None, None)
]
assert self.mock_ta_constr.mock_calls == [
call(account_id=None, account_role=None, region=None,
external_id=None)
external_id=None, mfa_serial_number=None, mfa_token=None)
]
assert self.mock_svc1.mock_calls == []
assert self.mock_svc2.mock_calls == []
Expand Down Expand Up @@ -174,11 +174,15 @@ def test_init_thresholds(self):
'SvcBar': mock_svc2
}
# _AwsService instances should exist, but have no other calls
assert mock_foo.mock_calls == [call(5, 22, None, None, None, None)]
assert mock_bar.mock_calls == [call(5, 22, None, None, None, None)]
assert mock_foo.mock_calls == [
call(5, 22, None, None, None, None, None, None)
]
assert mock_bar.mock_calls == [
call(5, 22, None, None, None, None, None, None)
]
assert mock_ta_constr.mock_calls == [
call(account_id=None, account_role=None, region=None,
external_id=None)
external_id=None, mfa_serial_number=None, mfa_token=None)
]
assert mock_svc1.mock_calls == []
assert mock_svc2.mock_calls == []
Expand Down Expand Up @@ -215,14 +219,14 @@ def test_init_region(self):
}
# _AwsService instances should exist, but have no other calls
assert mock_foo.mock_calls == [
call(80, 99, None, None, 'myregion', None)
call(80, 99, None, None, 'myregion', None, None, None)
]
assert mock_bar.mock_calls == [
call(80, 99, None, None, 'myregion', None)
call(80, 99, None, None, 'myregion', None, None, None)
]
assert mock_ta_constr.mock_calls == [
call(account_id=None, account_role=None, region='myregion',
external_id=None)
external_id=None, mfa_serial_number=None, mfa_token=None)
]
assert mock_svc1.mock_calls == []
assert mock_svc2.mock_calls == []
Expand Down Expand Up @@ -263,17 +267,21 @@ def test_init_sts(self):
}
# _AwsService instances should exist, but have no other calls
assert mock_foo.mock_calls == [
call(80, 99, '123456789012', 'myrole', 'myregion', None)
call(80, 99, '123456789012', 'myrole', 'myregion', None,
None, None)
]
assert mock_bar.mock_calls == [
call(80, 99, '123456789012', 'myrole', 'myregion', None)
call(80, 99, '123456789012', 'myrole', 'myregion', None,
None, None)
]
assert mock_ta_constr.mock_calls == [
call(
account_id='123456789012',
account_role='myrole',
region='myregion',
external_id=None
external_id=None,
mfa_serial_number=None,
mfa_token=None
)
]
assert mock_svc1.mock_calls == []
Expand Down Expand Up @@ -307,7 +315,9 @@ def test_init_sts_external_id(self):
account_id='123456789012',
account_role='myrole',
region='myregion',
external_id='myextid'
external_id='myextid',
mfa_serial_number=None,
mfa_token=None
)
# dict should be of _AwsService instances
assert cls.services == {
Expand All @@ -316,17 +326,21 @@ def test_init_sts_external_id(self):
}
# _AwsService instances should exist, but have no other calls
assert mock_foo.mock_calls == [
call(80, 99, '123456789012', 'myrole', 'myregion', 'myextid')
call(80, 99, '123456789012', 'myrole', 'myregion', 'myextid',
None, None)
]
assert mock_bar.mock_calls == [
call(80, 99, '123456789012', 'myrole', 'myregion', 'myextid')
call(80, 99, '123456789012', 'myrole', 'myregion', 'myextid',
None, None)
]
assert mock_ta_constr.mock_calls == [
call(
account_id='123456789012',
account_role='myrole',
region='myregion',
external_id='myextid'
external_id='myextid',
mfa_serial_number=None,
mfa_token=None
)
]
assert mock_svc1.mock_calls == []
Expand Down
29 changes: 26 additions & 3 deletions awslimitchecker/tests/test_connectable.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,14 @@ class ConnectableTester(Connectable):
service_name = 'connectable_tester'

def __init__(self, account_id=None, account_role=None, region=None,
external_id=None):
external_id=None, mfa_serial_number=None, mfa_token=None):
self.account_id = account_id
self.account_role = account_role
self.region = region
self.conn = None
self.external_id = external_id
self.mfa_serial_number = mfa_serial_number
self.mfa_token = mfa_token


class Test_Connectable(object):
Expand Down Expand Up @@ -121,7 +123,8 @@ def test_get_sts_token(self):
arn = 'arn:aws:iam::789:role/myr'
assert mock_connect.mock_calls == [
call('foobar'),
call().assume_role(arn, 'awslimitchecker', external_id=None),
call().assume_role(arn, 'awslimitchecker', external_id=None,
mfa_serial_number=None, mfa_token=None),
]
assume_role_ret = mock_connect.return_value.assume_role.return_value
assert res == assume_role_ret.credentials
Expand All @@ -136,7 +139,27 @@ def test_get_sts_token_external_id(self):
arn = 'arn:aws:iam::789:role/myr'
assert mock_connect.mock_calls == [
call('foobar'),
call().assume_role(arn, 'awslimitchecker', external_id='myextid'),
call().assume_role(arn, 'awslimitchecker', external_id='myextid',
mfa_serial_number=None, mfa_token=None),
]
assume_role_ret = mock_connect.return_value.assume_role.return_value
assert res == assume_role_ret.credentials

def test_get_sts_token_mfa(self):
cls = ConnectableTester(account_id='789',
account_role='myr', region='foobar',
external_id='myextid',
mfa_serial_number='arn:aws:iam::456:mfa/me',
mfa_token='123456')
with patch('awslimitchecker.connectable.boto.sts.connect_to_region'
'') as mock_connect:
res = cls._get_sts_token()
arn = 'arn:aws:iam::789:role/myr'
assert mock_connect.mock_calls == [
call('foobar'),
call().assume_role(arn, 'awslimitchecker', external_id='myextid',
mfa_serial_number='arn:aws:iam::456:mfa/me',
mfa_token='123456'),
]
assume_role_ret = mock_connect.return_value.assume_role.return_value
assert res == assume_role_ret.credentials
Loading