diff --git a/app/aws/credentials.py b/app/aws/credentials.py index 9406337..47eadf4 100644 --- a/app/aws/credentials.py +++ b/app/aws/credentials.py @@ -10,6 +10,7 @@ from app.core.config import ProfileGroup from app.core.result import Result +from app.util import util logger = logging.getLogger('logsmith') @@ -100,15 +101,16 @@ def check_access_key(access_key: str) -> Result: return result -def check_session() -> Result: +def check_session(access_key: str) -> Result: result = Result() credentials_file = _load_credentials_file() - if not credentials_file.has_section('session-token'): + session_token_profile_name = util.generate_session_name(access_key) + if not credentials_file.has_section(session_token_profile_name): logger.warning('no session token found') return result try: - client = _get_client('session-token', 'sts', timeout=2, retries=2) + client = _get_client(session_token_profile_name, 'sts', timeout=2, retries=2) client.get_caller_identity() except ClientError: # this is the normal case when the session token is not valid. Proceed then to fetch a new one @@ -126,8 +128,8 @@ def check_session() -> Result: def fetch_session_token(access_key: str, mfa_token: str) -> Result: result = Result() credentials_file = _load_credentials_file() - logger.info('fetch session-token') - profile = 'session-token' + logger.info(f'fetch session-token for {access_key}') + session_token_profile_name = util.generate_session_name(access_key) try: secrets = _get_session_token(access_key=access_key, mfa_token=mfa_token) @@ -147,9 +149,9 @@ def fetch_session_token(access_key: str, mfa_token: str) -> Result: logger.error(error_text, exc_info=True) return result - _add_profile_credentials(credentials_file, profile, secrets) + _add_profile_credentials(credentials_file, session_token_profile_name, secrets) _write_credentials_file(credentials_file) - logger.info('session-token successfully fetched') + logger.info(f'{session_token_profile_name} successfully fetched') result.set_success() return result @@ -162,7 +164,7 @@ def fetch_role_credentials(user_name: str, profile_group: ProfileGroup) -> Resul try: for profile in profile_group.profiles: logger.info(f'fetch {profile.profile}') - source_profile = profile.source or 'session-token' + source_profile = profile.source or util.generate_session_name(profile_group.get_access_key()) secrets = _assume_role(source_profile, user_name, profile.account, profile.role) _add_profile_credentials(credentials_file, profile.profile, secrets) if profile.default: @@ -183,10 +185,10 @@ def fetch_role_credentials(user_name: str, profile_group: ProfileGroup) -> Resul def _remove_unused_profiles(credentials_file, profile_group: ProfileGroup): used_profiles = profile_group.list_profile_names() - used_profiles.append('session-token') - for profile in credentials_file.sections(): - if profile not in used_profiles and not profile.startswith('access-key'): + if profile not in used_profiles and \ + not profile.startswith('access-key') and \ + not profile.startswith('session-token'): credentials_file.remove_section(profile) return credentials_file @@ -215,7 +217,6 @@ def write_profile_config(profile_group: ProfileGroup, region: str): def _remove_unused_configs(config_file: configparser, profile_group: ProfileGroup): used_profiles = profile_group.list_profile_names() - used_profiles.append('access-key') for config_name in config_file.sections(): profile = config_name.replace('profile ', '') @@ -242,9 +243,9 @@ def get_access_key_list() -> list: return access_key_list -def get_access_key_id(): +def get_access_key_id(key_name: str) -> str: credentials_file = _load_credentials_file() - return credentials_file.get('access-key', 'aws_access_key_id') + return credentials_file.get(key_name, 'aws_access_key_id') def _add_profile_credentials(credentials_file: configparser, profile: str, secrets: dict) -> None: diff --git a/app/aws/iam.py b/app/aws/iam.py index 87c0eb6..7710cc0 100644 --- a/app/aws/iam.py +++ b/app/aws/iam.py @@ -3,13 +3,14 @@ import boto3 from app.core.result import Result +from app.util import util logger = logging.getLogger('logsmith') -def create_access_key(user_name) -> Result: +def create_access_key(user_name, key_name) -> Result: result = Result() - session = boto3.Session(profile_name='session-token') + session = boto3.Session(profile_name=util.generate_session_name(key_name)) client = session.client('iam') try: @@ -34,9 +35,9 @@ def create_access_key(user_name) -> Result: return result -def delete_iam_access_key(user_name, key_id): +def delete_iam_access_key(user_name, key_name, key_id): result = Result() - session = boto3.Session(profile_name='session-token') + session = boto3.Session(profile_name=util.generate_session_name(key_name)) client = session.client('iam') try: diff --git a/app/core/config.py b/app/core/config.py index 7799534..cdf8567 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -18,7 +18,7 @@ def __init__(self): def load_from_disk(self): config = files.load_config() self.mfa_shell_command = config.get('mfa_shell_command', None) - access_key = config.get('default_access_key', None) + access_key = config.get('default_access_key', _default_access_key) accounts = files.load_accounts() self.set_accounts(accounts, access_key) @@ -87,7 +87,7 @@ def __init__(self, name, group: dict, default_access_key: str): self.default_access_key = default_access_key self.access_key: str = group.get('access_key', None) self.profiles: List[Profile] = [] - self.type: str = group.get("type", "aws") # only aws (default) & gcp as values are allowed + self.type: str = group.get('type', 'aws') # only aws (default) & gcp as values are allowed for profile in group.get('profiles', []): self.profiles.append(Profile(self, profile)) @@ -132,7 +132,7 @@ def to_dict(self): 'region': self.region, 'profiles': [profile.to_dict() for profile in self.profiles], } - if self.access_key != self.default_access_key: + if self.access_key and self.access_key != self.default_access_key: result_dict['access_key'] = self.access_key if self.type != "aws": result_dict["type"] = self.type diff --git a/app/core/core.py b/app/core/core.py index bbe8ac7..962eaf8 100644 --- a/app/core/core.py +++ b/app/core/core.py @@ -29,7 +29,7 @@ def login(self, profile_group: ProfileGroup, mfa_callback: Callable) -> Result: if not access_key_result.was_success: return access_key_result - session_result = credentials.check_session() + session_result = credentials.check_session(access_key=access_key) if session_result.was_error: return session_result if not session_result.was_success: @@ -142,19 +142,23 @@ def rotate_access_key(self, key_name: str, mfa_callback: Callable) -> Result: return access_key_result logger.info('fetch session') - renew_session_result = self._renew_session(access_key=key_name, mfa_callback=mfa_callback) - if not renew_session_result.was_success: - return renew_session_result + session_result = credentials.check_session(access_key=key_name) + if session_result.was_error: + return session_result + if not session_result.was_success: + renew_session_result = self._renew_session(access_key=key_name, mfa_callback=mfa_callback) + if not renew_session_result.was_success: + return renew_session_result logger.info('create key') user = credentials.get_user_name(key_name) - create_access_key_result = iam.create_access_key(user) + create_access_key_result = iam.create_access_key(user, key_name) if not create_access_key_result.was_success: return create_access_key_result logger.info('delete key') - previous_access_key_id = credentials.get_access_key_id() - delete_access_key_result = iam.delete_iam_access_key(user, previous_access_key_id) + previous_access_key_id = credentials.get_access_key_id(key_name) + delete_access_key_result = iam.delete_iam_access_key(user, key_name, previous_access_key_id) if not delete_access_key_result.was_success: return delete_access_key_result diff --git a/app/gui/access_key_dialog.py b/app/gui/access_key_dialog.py index 23c885c..076f275 100644 --- a/app/gui/access_key_dialog.py +++ b/app/gui/access_key_dialog.py @@ -35,10 +35,10 @@ def __init__(self, parent=None): self.key_id_input = QLineEdit(self) self.key_id_input.setStyleSheet("color: black; background-color: white;") - self.access_key_text = QLabel("Key secret:", self) - self.access_key_input = QLineEdit(self) - self.access_key_input.setStyleSheet("color: black; background-color: white;") - self.access_key_input.setEchoMode(QLineEdit.EchoMode.Password) + self.key_secret_text = QLabel("Key secret:", self) + self.key_secret_input = QLineEdit(self) + self.key_secret_input.setStyleSheet("color: black; background-color: white;") + self.key_secret_input.setEchoMode(QLineEdit.EchoMode.Password) self.ok_button = QPushButton("OK") self.ok_button.clicked.connect(self.ok) @@ -62,8 +62,8 @@ def __init__(self, parent=None): vbox.addWidget(self.key_name_input) vbox.addWidget(self.key_id_text) vbox.addWidget(self.key_id_input) - vbox.addWidget(self.access_key_text) - vbox.addWidget(self.access_key_input) + vbox.addWidget(self.key_secret_text) + vbox.addWidget(self.key_secret_input) vbox.addLayout(hbox) self.setLayout(vbox) @@ -96,8 +96,8 @@ def ok(self): key_id = self.key_id_input.text() key_id = key_id.strip() - access_key = self.access_key_input.text() - access_key = access_key.strip() + key_secret = self.key_secret_input.text() + key_secret = key_secret.strip() if not key_name: self.set_error_text('missing key name') @@ -105,14 +105,16 @@ def ok(self): if not key_id: self.set_error_text('missing key id') return - if not access_key: + if not key_secret: self.set_error_text('missing access key') return if key_name != '' and not key_name.startswith('access-key'): self.set_error_text('new key names must start with \'access-key\'') return - print(f'key_name={key_name}, key_id={key_id}, access_key={access_key}') - # self.gui.set_access_key(key_name=key_name, key_id=key_id, access_key=access_key) + if key_name != '' and key_name.startswith('session-token'): + self.set_error_text('new key names must not start with \'session-token\'') + return + self.gui.set_access_key(key_name=key_name, key_id=key_id, key_secret=key_secret) self.hide() def cancel(self): @@ -141,8 +143,8 @@ def show_dialog(self, access_key_list: List[str]): self.key_name_input.repaint() self.key_id_input.setText('') self.key_id_input.repaint() - self.access_key_input.setText('') - self.access_key_input.repaint() + self.key_secret_input.setText('') + self.key_secret_input.repaint() self.set_error_text('') self.existing_access_key_list = access_key_list diff --git a/app/gui/mfa_dialog.py b/app/gui/mfa_dialog.py index 9def104..9915b63 100644 --- a/app/gui/mfa_dialog.py +++ b/app/gui/mfa_dialog.py @@ -46,7 +46,9 @@ def cancel(self): self.hide() def get_value(self): - return self.input_field.text() + value = self.input_field.text() + value = value.strip() + return value def closeEvent(self, event): self.pressed_cancel = True diff --git a/tests/test_aws/test_iam.py b/app/util/__init__.py similarity index 100% rename from tests/test_aws/test_iam.py rename to app/util/__init__.py diff --git a/app/util/util.py b/app/util/util.py new file mode 100644 index 0000000..8c01925 --- /dev/null +++ b/app/util/util.py @@ -0,0 +1,2 @@ +def generate_session_name(key_name: str) -> str: + return f'session-token-{key_name}' diff --git a/tests/test_aws/test_credentials.py b/tests/test_aws/test_credentials.py index 226492d..d11b9cb 100644 --- a/tests/test_aws/test_credentials.py +++ b/tests/test_aws/test_credentials.py @@ -61,19 +61,25 @@ def test_has_access_key__no_access_key(self, mock_path): self.assertEqual('could not find access-key \'access-key\' in .aws/credentials', result.error_message) @mock.patch('app.aws.credentials._get_credentials_path') - def test_check_session__no_session(self, mock_path): + def test_check_session__no_session_found(self, mock_path): mock_path.return_value = self.test_credentials_file_path_without_keys - result = credentials.check_session() + + result = credentials.check_session('access-key') self.assertEqual(False, result.was_success) self.assertEqual(False, result.was_error) @mock.patch('app.aws.credentials._get_client') @mock.patch('app.aws.credentials._get_credentials_path') - def test_check_session(self, mock_path, _): + def test_check_session(self, mock_path, mock_get_client): mock_path.return_value = self.test_credentials_file_path - result = credentials.check_session() + result = credentials.check_session('access-key') + expected_calls = [ + call('session-token-access-key', 'sts', timeout=2, retries=2), + call().get_caller_identity() + ] + self.assertEqual(expected_calls, mock_get_client.mock_calls) self.assertEqual(True, result.was_success) self.assertEqual(False, result.was_error) @@ -86,7 +92,7 @@ def test_check_session__invalid_session(self, mock_path, mock_get_client): mock_client.get_caller_identity.side_effect = self.client_error mock_get_client.return_value = mock_client - result = credentials.check_session() + result = credentials.check_session('access-key') self.assertEqual(False, result.was_success) self.assertEqual(False, result.was_error) @@ -99,7 +105,7 @@ def test_check_session__connection_timeout(self, mock_path, mock_get_client): mock_client.get_caller_identity.side_effect = self.timeout_error mock_get_client.return_value = mock_client - result = credentials.check_session() + result = credentials.check_session('access-key') self.assertEqual(False, result.was_success) self.assertEqual(True, result.was_error) @@ -111,12 +117,12 @@ def test_check_session__connection_timeout(self, mock_path, mock_get_client): def test_fetch_session_token(self, mock_credentials, mock_session, mock_add_profile, mock_write): mock_session.return_value = self.test_secrets - result = credentials.fetch_session_token('some-token', 'mfa-token') + result = credentials.fetch_session_token('some-access-key', 'mfa-token') self.assertEqual(True, result.was_success) self.assertEqual(False, result.was_error) expected = call(mock_credentials.return_value, - 'session-token', + 'session-token-some-access-key', {'AccessKeyId': 'test-key-id', 'SecretAccessKey': 'test-access-key', 'SessionToken': 'test-session-token'}) @@ -131,7 +137,7 @@ def test_fetch_session_token__client_error(self, _, mock_session): mock_session.return_value = {} mock_session.side_effect = self.client_error - result = credentials.fetch_session_token('some-token', 'mfa-token') + result = credentials.fetch_session_token('some-access-key', 'mfa-token') self.assertEqual(False, result.was_success) self.assertEqual(True, result.was_error) @@ -143,7 +149,7 @@ def test_fetch_session_token__client_error(self, _, mock_session): mock_session.return_value = {} mock_session.side_effect = self.param_validation_error - result = credentials.fetch_session_token('some-token', 'mfa-token') + result = credentials.fetch_session_token('some-access-key', 'mfa-token') self.assertEqual(False, result.was_success) self.assertEqual(True, result.was_error) @@ -155,7 +161,7 @@ def test_fetch_session_token__no_credentials_error(self, _, mock_session): mock_session.return_value = {} mock_session.side_effect = self.no_credentials_error - result = credentials.fetch_session_token('some-token', 'mfa-token') + result = credentials.fetch_session_token('some-access-key', 'mfa-token') self.assertEqual(False, result.was_success) self.assertEqual(True, result.was_error) @@ -172,15 +178,58 @@ def test_fetch_role_credentials(self, mock_credentials, mock_assume, mock_add_pr mock_credentials.return_value = mock_config_parser mock_assume.return_value = self.test_secrets - profile_group = ProfileGroup('test', test_accounts.get_test_group(), 'default') + profile_group = ProfileGroup('test', test_accounts.get_test_group(), 'default-access-key') result = credentials.fetch_role_credentials('test_user', profile_group) self.assertEqual(True, result.was_success) self.assertEqual(False, result.was_error) self.assertEqual(3, mock_write_credentials.call_count) - expected_mock_assume_calls = [call('session-token', 'test_user', '123456789012', 'developer'), - call('session-token', 'test_user', '012345678901', 'readonly')] + expected_mock_assume_calls = [ + call('session-token-default-access-key', 'test_user', '123456789012', 'developer'), + call('session-token-default-access-key', 'test_user', '012345678901', 'readonly') + ] + self.assertEqual(expected_mock_assume_calls, mock_assume.call_args_list) + + expected_mock_add_profile_calls = [ + call(mock_config_parser, 'developer', {'AccessKeyId': 'test-key-id', + 'SecretAccessKey': 'test-access-key', + 'SessionToken': 'test-session-token'}), + call(mock_config_parser, 'readonly', {'AccessKeyId': 'test-key-id', + 'SecretAccessKey': 'test-access-key', + 'SessionToken': 'test-session-token'}), + call(mock_config_parser, 'default', {'AccessKeyId': 'test-key-id', + 'SecretAccessKey': 'test-access-key', + 'SessionToken': 'test-session-token'})] + self.assertEqual(expected_mock_add_profile_calls, mock_add_profile.call_args_list) + expected_mock_remove_profile_calls = [call(mock_config_parser, profile_group)] + self.assertEqual(expected_mock_remove_profile_calls, mock_remove_profile.call_args_list) + self.assertEqual(expected_mock_remove_profile_calls, mock_remove_profile.call_args_list) + + @mock.patch('app.aws.credentials._write_credentials_file') + @mock.patch('app.aws.credentials._remove_unused_profiles') + @mock.patch('app.aws.credentials._add_profile_credentials') + @mock.patch('app.aws.credentials._assume_role') + @mock.patch('app.aws.credentials._load_credentials_file') + def test_fetch_role_credentials_with_specific_access_key(self, mock_credentials, mock_assume, mock_add_profile, + mock_remove_profile, + mock_write_credentials): + mock_config_parser = Mock() + mock_credentials.return_value = mock_config_parser + mock_assume.return_value = self.test_secrets + + profile_group = ProfileGroup('test', test_accounts.get_test_group_with_specific_access_key(), + 'default-access-key') + result = credentials.fetch_role_credentials('test_user', profile_group) + self.assertEqual(True, result.was_success) + self.assertEqual(False, result.was_error) + + self.assertEqual(3, mock_write_credentials.call_count) + + expected_mock_assume_calls = [ + call('session-token-specific-access-key', 'test_user', '123456789012', 'developer'), + call('session-token-specific-access-key', 'test_user', '012345678901', 'readonly') + ] self.assertEqual(expected_mock_assume_calls, mock_assume.call_args_list) expected_mock_add_profile_calls = [ @@ -209,15 +258,17 @@ def test_fetch_role_credentials__no_default(self, mock_credentials, mock_assume, mock_credentials.return_value = mock_config_parser mock_assume.return_value = self.test_secrets - profile_group = ProfileGroup('test', test_accounts.get_test_group_no_default(), 'default') + profile_group = ProfileGroup('test', test_accounts.get_test_group_no_default(), 'default-access-key') result = credentials.fetch_role_credentials('test-user', profile_group) self.assertEqual(True, result.was_success) self.assertEqual(False, result.was_error) self.assertEqual(3, mock_write_credentials.call_count) - expected_mock_assume_calls = [call('session-token', 'test-user', '123456789012', 'developer'), - call('session-token', 'test-user', '012345678901', 'readonly')] + expected_mock_assume_calls = [ + call('session-token-default-access-key', 'test-user', '123456789012', 'developer'), + call('session-token-default-access-key', 'test-user', '012345678901', 'readonly') + ] self.assertEqual(expected_mock_assume_calls, mock_assume.call_args_list) expected_mock_add_profile_calls = [ @@ -243,15 +294,17 @@ def test_fetch_role_credentials__chain_assume(self, mock_credentials, mock_assum mock_credentials.return_value = mock_config_parser mock_assume.return_value = self.test_secrets - profile_group = ProfileGroup('test', test_accounts.get_test_group_chain_assume(), 'default') + profile_group = ProfileGroup('test', test_accounts.get_test_group_chain_assume(), 'default-access-key') result = credentials.fetch_role_credentials('test-user', profile_group) self.assertEqual(True, result.was_success) self.assertEqual(False, result.was_error) self.assertEqual(3, mock_write_credentials.call_count) - expected_mock_assume_calls = [call('session-token', 'test-user', '123456789012', 'developer'), - call('developer', 'test-user', '012345678901', 'service')] + expected_mock_assume_calls = [ + call('session-token-default-access-key', 'test-user', '123456789012', 'developer'), + call('developer', 'test-user', '012345678901', 'service') + ] self.assertEqual(expected_mock_assume_calls, mock_assume.call_args_list) expected_mock_add_profile_calls = [ @@ -267,9 +320,13 @@ def test_fetch_role_credentials__chain_assume(self, mock_credentials, mock_assum self.assertEqual(expected_mock_remove_profile_calls, mock_remove_profile.call_args_list) self.assertEqual(expected_mock_remove_profile_calls, mock_remove_profile.call_args_list) - def test___remove_unused_profiles(self): + def test_remove_unused_profiles(self): mock_config_parser = Mock() - mock_config_parser.sections.return_value = ['developer', 'unused-profile', 'access-key', 'session-token'] + mock_config_parser.sections.return_value = [ + 'developer', 'unused-profile', + 'access-key', 'session-token-access-key', + 'access-key-2', 'session-token-access-key-2' + ] mock_profile_group = Mock() mock_profile_group.list_profile_names.return_value = ['developer'] @@ -335,6 +392,7 @@ def test___remove_unused_configs(self): credentials._remove_unused_configs(mock_config_parser, mock_profile_group) expected = [call('profile unused-profile'), + call('profile access-key'), call('profile session-token')] self.assertEqual(expected, mock_config_parser.remove_section.call_args_list) diff --git a/tests/test_core/test_config.py b/tests/test_core/test_config.py index 7d04647..9ab6384 100644 --- a/tests/test_core/test_config.py +++ b/tests/test_core/test_config.py @@ -30,7 +30,7 @@ def test_load_from_disk(self, mock_load_accounts, mock_load_config): @mock.patch('app.core.config.files.save_config_file') @mock.patch('app.core.config.files.save_accounts_file') def test_save_to_disk(self, mock_save_accounts_file, mock_save_config_file): - self.config.set_accounts(get_test_accounts(), 'some-access-key') + self.config.set_accounts(get_test_accounts(), 'default-access-key') self.config.save_to_disk() expected_accounts = [call( { @@ -38,7 +38,6 @@ def test_save_to_disk(self, mock_save_accounts_file, mock_save_config_file): 'color': '#388E3C', 'team': 'awesome-team', 'region': 'us-east-1', - 'access_key': None, 'profiles': [ { 'profile': 'developer', @@ -57,13 +56,14 @@ def test_save_to_disk(self, mock_save_accounts_file, mock_save_config_file): 'color': '#388E3C', 'team': 'awesome-team', 'region': 'us-east-1', - 'access_key': None, + 'access_key': 'access-key-123', 'profiles': [ { 'profile': 'developer', 'account': '123456789012', 'role': 'developer', - 'default': True}, + 'default': True + }, { 'profile': 'readonly', 'account': '012345678901', @@ -76,7 +76,6 @@ def test_save_to_disk(self, mock_save_accounts_file, mock_save_config_file): 'team': 'another-team', 'region': 'europe-west1', 'type': 'gcp', - 'access_key': None, 'profiles': [], # this will be automatically added } } @@ -84,14 +83,14 @@ def test_save_to_disk(self, mock_save_accounts_file, mock_save_config_file): expected_config = [call({ 'mfa_shell_command': None, - 'default_access_key': 'some-access-key' + 'default_access_key': 'default-access-key' })] self.assertEqual(expected_accounts, mock_save_accounts_file.mock_calls) self.assertEqual(expected_config, mock_save_config_file.mock_calls) def test_set_accounts(self): - accounts = self.config.set_accounts(get_test_accounts(), 'some-access-key') + self.config.set_accounts(get_test_accounts(), 'default-access-key') groups = ['development', 'live', 'gcp-project-dev'] self.assertEqual(groups, list(self.config.profile_groups.keys())) @@ -102,6 +101,7 @@ def test_set_accounts(self): self.assertEqual('us-east-1', development_group.region) self.assertEqual('#388E3C', development_group.color) self.assertEqual('aws', development_group.type) + self.assertEqual('default-access-key', development_group.get_access_key()) profile = development_group.profiles[0] self.assertEqual(development_group, profile.group) @@ -117,11 +117,14 @@ def test_set_accounts(self): self.assertEqual('readonly', profile.role) self.assertEqual(False, profile.default) + live_group = self.config.get_group('live') + self.assertEqual('access-key-123', live_group.get_access_key()) + def test_validate(self): - self.config.set_accounts(get_test_accounts(), 'some-access-key') + self.config.set_accounts(get_test_accounts(), 'default-access-key') self.config.validate() - self.assertEqual(True, self.config.valid) self.assertEqual('', self.config.error) + self.assertEqual(True, self.config.valid) def test_validate_empty_config(self): self.config.validate() diff --git a/tests/test_core/test_core.py b/tests/test_core/test_core.py index d3eef0e..b3c1c4d 100644 --- a/tests/test_core/test_core.py +++ b/tests/test_core/test_core.py @@ -44,7 +44,8 @@ def test_login__session_token_error(self, mock_credentials): result = self.core.login(self.config.get_group('development'), None) - expected = [call.check_access_key(access_key='some-access-key'), call.check_session()] + expected = [call.check_access_key(access_key='some-access-key'), + call.check_session(access_key='some-access-key')] self.assertEqual(expected, mock_credentials.mock_calls) self.assertEqual(self.error_result, result) @@ -57,7 +58,8 @@ def test_login__mfa_error(self, mock_credentials): result = self.core.login(self.config.get_group('development'), None) - expected = [call.check_access_key(access_key='some-access-key'), call.check_session()] + expected = [call.check_access_key(access_key='some-access-key'), + call.check_session(access_key='some-access-key')] self.assertEqual(expected, mock_credentials.mock_calls) expected = [call(access_key='some-access-key', mfa_callback=None)] @@ -81,7 +83,7 @@ def test_login__successful_login(self, mock_credentials, _): result = self.core.login(profile_group, mock_mfa_callback) expected = [call.check_access_key(access_key='some-access-key'), - call.check_session(), + call.check_session(access_key='some-access-key'), call.get_user_name(access_key='some-access-key'), call.fetch_role_credentials('test-user', profile_group), call.write_profile_config(profile_group, 'us-east-1')] @@ -137,7 +139,8 @@ def test_rotate_access_key__no_access_key(self, mock_credentials, mock_logout): @mock.patch('app.core.core.Core.logout') @mock.patch('app.core.core.iam') @mock.patch('app.core.core.credentials') - def test_rotate_access_key__successful_rotate(self, mock_credentials, mock_iam, mock_logout, mock_renew_session): + def test_rotate_access_key__successful_rotate_with_valid_session(self, mock_credentials, mock_iam, mock_logout, + mock_renew_session): mock_credentials.check_access_key.return_value = self.success_result mock_credentials.check_session.return_value = self.success_result mock_credentials.get_user_name.return_value = 'test-user' @@ -155,17 +158,56 @@ def test_rotate_access_key__successful_rotate(self, mock_credentials, mock_iam, result = self.core.rotate_access_key('some-access-key', mock_mfa_callback) expected_credential_calls = [call.check_access_key(access_key='some-access-key'), - # call.check_session(), # TODO can't make sure if the session is valid because there is only one "session" + call.check_session(access_key='some-access-key'), + call.get_user_name('some-access-key'), + call.get_access_key_id('some-access-key'), + call.set_access_key(key_name='some-access-key', key_id=12345, key_secret=67890)] + self.assertEqual(expected_credential_calls, mock_credentials.mock_calls) + + self.assertEqual(0, mock_renew_session.call_count) + + expected_iam_calls = [call.create_access_key('test-user', 'some-access-key'), + call.delete_iam_access_key('test-user', 'some-access-key', '12345')] + self.assertEqual(expected_iam_calls, mock_iam.mock_calls) + + self.assertEqual(True, result.was_success) + self.assertEqual(False, result.was_error) + self.assertEqual(2, mock_logout.call_count) + + @mock.patch('app.core.core.Core._renew_session') + @mock.patch('app.core.core.Core.logout') + @mock.patch('app.core.core.iam') + @mock.patch('app.core.core.credentials') + def test_rotate_access_key__successful_rotate_with_new_session(self, mock_credentials, mock_iam, mock_logout, + mock_renew_session): + mock_credentials.check_access_key.return_value = self.success_result + mock_credentials.check_session.return_value = self.fail_result + mock_credentials.get_user_name.return_value = 'test-user' + mock_credentials.get_access_key_id.return_value = '12345' + mock_renew_session.return_value = self.success_result + + access_key_result = Result() + access_key_result.add_payload({'AccessKeyId': 12345, 'SecretAccessKey': 67890}) + access_key_result.set_success() + + mock_iam.create_access_key.return_value = access_key_result + mock_iam.delete_iam_access_key.return_value = self.success_result + + mock_mfa_callback = Mock() + result = self.core.rotate_access_key('some-access-key', mock_mfa_callback) + + expected_credential_calls = [call.check_access_key(access_key='some-access-key'), + call.check_session(access_key='some-access-key'), call.get_user_name('some-access-key'), - call.get_access_key_id(), + call.get_access_key_id('some-access-key'), call.set_access_key(key_name='some-access-key', key_id=12345, key_secret=67890)] self.assertEqual(expected_credential_calls, mock_credentials.mock_calls) renew_session_calls = [call(access_key='some-access-key', mfa_callback=mock_mfa_callback)] self.assertEqual(renew_session_calls, mock_renew_session.mock_calls) - expected_iam_calls = [call.create_access_key('test-user'), - call.delete_iam_access_key('test-user', '12345')] + expected_iam_calls = [call.create_access_key('test-user', 'some-access-key'), + call.delete_iam_access_key('test-user', 'some-access-key', '12345')] self.assertEqual(expected_iam_calls, mock_iam.mock_calls) self.assertEqual(True, result.was_success) diff --git a/tests/test_core/test_profile.py b/tests/test_core/test_profile.py index 8879e0e..6ec1718 100644 --- a/tests/test_core/test_profile.py +++ b/tests/test_core/test_profile.py @@ -2,7 +2,7 @@ from unittest.mock import Mock from app.core.config import Profile -from tests.test_data.test_accounts import get_test_profile, get_test_profile_no_default +from tests.test_data.test_accounts import get_test_profile, get_test_profile_no_default, get_test_profile_with_source class TestProfile(TestCase): @@ -11,40 +11,41 @@ def setUp(self): self.group_mock.name = 'test' self.profile = Profile(self.group_mock, get_test_profile()) - def test_init(self): + def test_profile(self): self.assertEqual(self.group_mock, self.profile.group) self.assertEqual('readonly', self.profile.profile) self.assertEqual('123456789012', self.profile.account) self.assertEqual('readonly-role', self.profile.role) self.assertEqual(True, self.profile.default) + self.assertEqual(None, self.profile.source) - def test_validate(self): + def test_profile_validate(self): result = self.profile.validate() expected = (True, '') self.assertEqual(expected, result) - def test_validate_no_profile(self): + def test_profile_validate__no_profile(self): self.profile.profile = None result = self.profile.validate() expected = (False, 'a profile in test has no profile') self.assertEqual(expected, result) - def test_validate_no_account(self): + def test_profile_validate__no_account(self): self.profile.account = None result = self.profile.validate() expected = (False, 'a profile in test has no account') self.assertEqual(expected, result) - def test_validate_no_role(self): + def test_profile_validate__no_role(self): self.profile.role = None result = self.profile.validate() expected = (False, 'a profile in test has no role') self.assertEqual(expected, result) - def test_to_dict(self): + def test_profile_to_dict(self): result = self.profile.to_dict() expected = {'account': '123456789012', 'default': True, @@ -52,8 +53,30 @@ def test_to_dict(self): 'role': 'readonly-role'} self.assertEqual(expected, result) - def test_to_dict_no_default(self): + def test_profile__no_default(self): + profile = Profile('test', get_test_profile_no_default()) + self.assertEqual('readonly', profile.profile) + self.assertEqual('123456789012', profile.account) + self.assertEqual('readonly-role', profile.role) + self.assertEqual(False, profile.default) + self.assertEqual(None, profile.source) + + def test_profile_to_dict__no_default(self): profile = Profile('test', get_test_profile_no_default()) result = profile.to_dict() expected = {'account': '123456789012', 'profile': 'readonly', 'role': 'readonly-role'} self.assertEqual(expected, result) + + def test_profile__with_source(self): + profile = Profile('test', get_test_profile_with_source()) + self.assertEqual('readonly', profile.profile) + self.assertEqual('123456789012', profile.account) + self.assertEqual('readonly-role', profile.role) + self.assertEqual(False, profile.default) + self.assertEqual('some-source', profile.source) + + def test_to_dict__with_source(self): + profile = Profile('test', get_test_profile_with_source()) + result = profile.to_dict() + expected = {'account': '123456789012', 'profile': 'readonly', 'role': 'readonly-role', 'source': 'some-source'} + self.assertEqual(expected, result) diff --git a/tests/test_core/test_profile_group.py b/tests/test_core/test_profile_group.py index c8609ab..31e1e47 100644 --- a/tests/test_core/test_profile_group.py +++ b/tests/test_core/test_profile_group.py @@ -2,7 +2,8 @@ from unittest.mock import Mock from app.core.config import ProfileGroup -from tests.test_data.test_accounts import get_test_group, get_test_group_no_default +from tests.test_data.test_accounts import get_test_group, get_test_group_no_default, \ + get_test_group_with_specific_access_key class TestProfileGroup(TestCase): @@ -14,36 +15,52 @@ def test_init(self): self.assertEqual('awesome-team', self.profile_group.team) self.assertEqual('us-east-1', self.profile_group.region) self.assertEqual('#388E3C', self.profile_group.color) + self.assertEqual('default', self.profile_group.default_access_key) + self.assertEqual(None, self.profile_group.access_key) + self.assertEqual('aws', self.profile_group.type) self.assertEqual(2, len(self.profile_group.profiles)) def test_validate(self): result = self.profile_group.validate() - expected = (True, '') self.assertEqual(expected, result) - def test_validate_no_team(self): + def test_validate__no_team(self): self.profile_group.team = None result = self.profile_group.validate() expected = (False, 'test has no team') self.assertEqual(expected, result) - def test_validate_no_region(self): + def test_validate__no_region(self): self.profile_group.region = None result = self.profile_group.validate() expected = (False, 'test has no region') self.assertEqual(expected, result) - def test_validate_no_color(self): + def test_validate__no_color(self): self.profile_group.color = None result = self.profile_group.validate() expected = (False, 'test has no color') self.assertEqual(expected, result) - def test_validate_calls_profile_validate(self): + def test_validate__access_key_malformed(self): + self.profile_group.access_key = 'no-key' + result = self.profile_group.validate() + + expected = (False, 'access-key no-key must have the prefix \"access-key\"') + self.assertEqual(expected, result) + + def test_validate__aws_type_must_have_profiled(self): + self.profile_group.profiles = [] + result = self.profile_group.validate() + + expected = (False, 'aws \"test\" has no profiles') + self.assertEqual(expected, result) + + def test_validate__calls_profile_validate(self): mock_profile1 = Mock() mock_profile1.validate.return_value = True, 'no error' mock_profile2 = Mock() @@ -64,7 +81,7 @@ def test_list_profile_names(self): expected = ['developer', 'readonly', 'default'] self.assertEqual(expected, self.profile_group.list_profile_names()) - def test_list_profile_names_no_default(self): + def test_list_profile_names__no_default(self): profile_group = ProfileGroup('test', get_test_group_no_default(), 'some-access-key') expected = ['developer', 'readonly'] self.assertEqual(expected, profile_group.list_profile_names()) @@ -73,21 +90,55 @@ def test_get_default_profile(self): result = self.profile_group.get_default_profile() self.assertEqual('readonly', result.profile) - def test_get_default_profile_no_default(self): + def test_get_default_profile__no_default(self): profile_group = ProfileGroup('test', get_test_group_no_default(), 'some-access-key') result = profile_group.get_default_profile() self.assertEqual(None, result) + def test_get_access_key(self): + profile_group = ProfileGroup('test', get_test_group(), 'some-access-key') + result = profile_group.get_access_key() + self.assertEqual('some-access-key', result) + + def test_get_access_key__with_specific_access_key(self): + profile_group = ProfileGroup('test', get_test_group_with_specific_access_key(), 'some-access-key') + result = profile_group.get_access_key() + self.assertEqual('specific-access-key', result) + def test_to_dict(self): + profile_group = ProfileGroup('test', get_test_group(), 'some-access-key') mock_profile1 = Mock() mock_profile1.to_dict.return_value = 'profile 1' mock_profile2 = Mock() mock_profile2.to_dict.return_value = 'profile 2' mock_profile3 = Mock() mock_profile3.to_dict.return_value = 'profile 3' - self.profile_group.profiles = [mock_profile1, mock_profile2, mock_profile3] + profile_group.profiles = [mock_profile1, mock_profile2, mock_profile3] + + result = profile_group.to_dict() + self.assertEqual(1, mock_profile1.to_dict.call_count) + self.assertEqual(1, mock_profile2.to_dict.call_count) + self.assertEqual(1, mock_profile3.to_dict.call_count) + + expected = { + 'color': '#388E3C', + 'profiles': ['profile 1', 'profile 2', 'profile 3'], + 'region': 'us-east-1', + 'team': 'awesome-team', + } + self.assertEqual(expected, result) + + def test_to_dict__with_specific_access_key(self): + profile_group = ProfileGroup('test', get_test_group_with_specific_access_key(), 'some-access-key') + mock_profile1 = Mock() + mock_profile1.to_dict.return_value = 'profile 1' + mock_profile2 = Mock() + mock_profile2.to_dict.return_value = 'profile 2' + mock_profile3 = Mock() + mock_profile3.to_dict.return_value = 'profile 3' + profile_group.profiles = [mock_profile1, mock_profile2, mock_profile3] - result = self.profile_group.to_dict() + result = profile_group.to_dict() self.assertEqual(1, mock_profile1.to_dict.call_count) self.assertEqual(1, mock_profile2.to_dict.call_count) self.assertEqual(1, mock_profile3.to_dict.call_count) @@ -97,7 +148,7 @@ def test_to_dict(self): 'profiles': ['profile 1', 'profile 2', 'profile 3'], 'region': 'us-east-1', 'team': 'awesome-team', - 'access_key': None, + 'access_key': 'specific-access-key', } self.assertEqual(expected, result) diff --git a/tests/test_data/test_accounts.py b/tests/test_data/test_accounts.py index 272d4b3..6e89c90 100644 --- a/tests/test_data/test_accounts.py +++ b/tests/test_data/test_accounts.py @@ -22,6 +22,7 @@ def get_test_accounts() -> dict: 'color': '#388E3C', 'team': 'awesome-team', 'region': 'us-east-1', + 'access_key': 'access-key-123', 'profiles': [ { 'profile': 'developer', @@ -67,6 +68,28 @@ def get_test_group(): } +def get_test_group_with_specific_access_key(): + return { + 'color': '#388E3C', + 'team': 'awesome-team', + 'region': 'us-east-1', + 'access_key': 'specific-access-key', + 'profiles': [ + { + 'profile': 'developer', + 'account': '123456789012', + 'role': 'developer', + }, + { + 'profile': 'readonly', + 'account': '012345678901', + 'role': 'readonly', + 'default': 'true', + } + ] + } + + def get_test_group_no_default(): return { 'color': '#388E3C', @@ -123,3 +146,12 @@ def get_test_profile_no_default(): 'account': '123456789012', 'role': 'readonly-role', } + + +def get_test_profile_with_source(): + return { + 'profile': 'readonly', + 'account': '123456789012', + 'role': 'readonly-role', + 'source': 'some-source' + } diff --git a/tests/test_resources/credential_file b/tests/test_resources/credential_file index fb72b26..14fa16f 100644 --- a/tests/test_resources/credential_file +++ b/tests/test_resources/credential_file @@ -2,7 +2,7 @@ aws_access_key_id = some_key_id aws_secret_access_key = some_access_key -[session-token] +[session-token-access-key] aws_access_key_id = some_key_id aws_secret_access_key = some_access_key aws_session_token = some_token diff --git a/tests/test_util/__init__.py b/tests/test_util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_util/test_util.py b/tests/test_util/test_util.py new file mode 100644 index 0000000..958211d --- /dev/null +++ b/tests/test_util/test_util.py @@ -0,0 +1,11 @@ +from unittest import TestCase + +from app.util.util import generate_session_name + + +class TestUtil(TestCase): + + def test_generate_session_name(self): + result = generate_session_name('key-name') + expected = 'session-token-key-name' + self.assertEqual(expected, result)