Skip to content

Commit

Permalink
each access-key get their own session-token to make switching profile…
Browse files Browse the repository at this point in the history
…s and rotating keys faster an easier
  • Loading branch information
redvox committed Apr 5, 2024
1 parent 2763912 commit db89ea4
Show file tree
Hide file tree
Showing 17 changed files with 333 additions and 101 deletions.
29 changes: 15 additions & 14 deletions app/aws/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from app.core.config import ProfileGroup
from app.core.result import Result
from app.util import util

logger = logging.getLogger('logsmith')

Expand Down Expand Up @@ -100,15 +101,16 @@ def check_access_key(access_key: str) -> Result:
return result


def check_session() -> Result:
def check_session(access_key: str) -> Result:
result = Result()
credentials_file = _load_credentials_file()
if not credentials_file.has_section('session-token'):
session_token_profile_name = util.generate_session_name(access_key)
if not credentials_file.has_section(session_token_profile_name):
logger.warning('no session token found')
return result

try:
client = _get_client('session-token', 'sts', timeout=2, retries=2)
client = _get_client(session_token_profile_name, 'sts', timeout=2, retries=2)
client.get_caller_identity()
except ClientError:
# this is the normal case when the session token is not valid. Proceed then to fetch a new one
Expand All @@ -126,8 +128,8 @@ def check_session() -> Result:
def fetch_session_token(access_key: str, mfa_token: str) -> Result:
result = Result()
credentials_file = _load_credentials_file()
logger.info('fetch session-token')
profile = 'session-token'
logger.info(f'fetch session-token for {access_key}')
session_token_profile_name = util.generate_session_name(access_key)

try:
secrets = _get_session_token(access_key=access_key, mfa_token=mfa_token)
Expand All @@ -147,9 +149,9 @@ def fetch_session_token(access_key: str, mfa_token: str) -> Result:
logger.error(error_text, exc_info=True)
return result

_add_profile_credentials(credentials_file, profile, secrets)
_add_profile_credentials(credentials_file, session_token_profile_name, secrets)
_write_credentials_file(credentials_file)
logger.info('session-token successfully fetched')
logger.info(f'{session_token_profile_name} successfully fetched')
result.set_success()
return result

Expand All @@ -162,7 +164,7 @@ def fetch_role_credentials(user_name: str, profile_group: ProfileGroup) -> Resul
try:
for profile in profile_group.profiles:
logger.info(f'fetch {profile.profile}')
source_profile = profile.source or 'session-token'
source_profile = profile.source or util.generate_session_name(profile_group.get_access_key())
secrets = _assume_role(source_profile, user_name, profile.account, profile.role)
_add_profile_credentials(credentials_file, profile.profile, secrets)
if profile.default:
Expand All @@ -183,10 +185,10 @@ def fetch_role_credentials(user_name: str, profile_group: ProfileGroup) -> Resul

def _remove_unused_profiles(credentials_file, profile_group: ProfileGroup):
used_profiles = profile_group.list_profile_names()
used_profiles.append('session-token')

for profile in credentials_file.sections():
if profile not in used_profiles and not profile.startswith('access-key'):
if profile not in used_profiles and \
not profile.startswith('access-key') and \
not profile.startswith('session-token'):
credentials_file.remove_section(profile)
return credentials_file

Expand Down Expand Up @@ -215,7 +217,6 @@ def write_profile_config(profile_group: ProfileGroup, region: str):

def _remove_unused_configs(config_file: configparser, profile_group: ProfileGroup):
used_profiles = profile_group.list_profile_names()
used_profiles.append('access-key')

for config_name in config_file.sections():
profile = config_name.replace('profile ', '')
Expand All @@ -242,9 +243,9 @@ def get_access_key_list() -> list:
return access_key_list


def get_access_key_id():
def get_access_key_id(key_name: str) -> str:
credentials_file = _load_credentials_file()
return credentials_file.get('access-key', 'aws_access_key_id')
return credentials_file.get(key_name, 'aws_access_key_id')


def _add_profile_credentials(credentials_file: configparser, profile: str, secrets: dict) -> None:
Expand Down
9 changes: 5 additions & 4 deletions app/aws/iam.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import boto3

from app.core.result import Result
from app.util import util

logger = logging.getLogger('logsmith')


def create_access_key(user_name) -> Result:
def create_access_key(user_name, key_name) -> Result:
result = Result()
session = boto3.Session(profile_name='session-token')
session = boto3.Session(profile_name=util.generate_session_name(key_name))
client = session.client('iam')

try:
Expand All @@ -34,9 +35,9 @@ def create_access_key(user_name) -> Result:
return result


def delete_iam_access_key(user_name, key_id):
def delete_iam_access_key(user_name, key_name, key_id):
result = Result()
session = boto3.Session(profile_name='session-token')
session = boto3.Session(profile_name=util.generate_session_name(key_name))
client = session.client('iam')

try:
Expand Down
6 changes: 3 additions & 3 deletions app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self):
def load_from_disk(self):
config = files.load_config()
self.mfa_shell_command = config.get('mfa_shell_command', None)
access_key = config.get('default_access_key', None)
access_key = config.get('default_access_key', _default_access_key)

accounts = files.load_accounts()
self.set_accounts(accounts, access_key)
Expand Down Expand Up @@ -87,7 +87,7 @@ def __init__(self, name, group: dict, default_access_key: str):
self.default_access_key = default_access_key
self.access_key: str = group.get('access_key', None)
self.profiles: List[Profile] = []
self.type: str = group.get("type", "aws") # only aws (default) & gcp as values are allowed
self.type: str = group.get('type', 'aws') # only aws (default) & gcp as values are allowed

for profile in group.get('profiles', []):
self.profiles.append(Profile(self, profile))
Expand Down Expand Up @@ -132,7 +132,7 @@ def to_dict(self):
'region': self.region,
'profiles': [profile.to_dict() for profile in self.profiles],
}
if self.access_key != self.default_access_key:
if self.access_key and self.access_key != self.default_access_key:
result_dict['access_key'] = self.access_key
if self.type != "aws":
result_dict["type"] = self.type
Expand Down
18 changes: 11 additions & 7 deletions app/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def login(self, profile_group: ProfileGroup, mfa_callback: Callable) -> Result:
if not access_key_result.was_success:
return access_key_result

session_result = credentials.check_session()
session_result = credentials.check_session(access_key=access_key)
if session_result.was_error:
return session_result
if not session_result.was_success:
Expand Down Expand Up @@ -142,19 +142,23 @@ def rotate_access_key(self, key_name: str, mfa_callback: Callable) -> Result:
return access_key_result

logger.info('fetch session')
renew_session_result = self._renew_session(access_key=key_name, mfa_callback=mfa_callback)
if not renew_session_result.was_success:
return renew_session_result
session_result = credentials.check_session(access_key=key_name)
if session_result.was_error:
return session_result
if not session_result.was_success:
renew_session_result = self._renew_session(access_key=key_name, mfa_callback=mfa_callback)
if not renew_session_result.was_success:
return renew_session_result

logger.info('create key')
user = credentials.get_user_name(key_name)
create_access_key_result = iam.create_access_key(user)
create_access_key_result = iam.create_access_key(user, key_name)
if not create_access_key_result.was_success:
return create_access_key_result

logger.info('delete key')
previous_access_key_id = credentials.get_access_key_id()
delete_access_key_result = iam.delete_iam_access_key(user, previous_access_key_id)
previous_access_key_id = credentials.get_access_key_id(key_name)
delete_access_key_result = iam.delete_iam_access_key(user, key_name, previous_access_key_id)
if not delete_access_key_result.was_success:
return delete_access_key_result

Expand Down
28 changes: 15 additions & 13 deletions app/gui/access_key_dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ def __init__(self, parent=None):
self.key_id_input = QLineEdit(self)
self.key_id_input.setStyleSheet("color: black; background-color: white;")

self.access_key_text = QLabel("Key secret:", self)
self.access_key_input = QLineEdit(self)
self.access_key_input.setStyleSheet("color: black; background-color: white;")
self.access_key_input.setEchoMode(QLineEdit.EchoMode.Password)
self.key_secret_text = QLabel("Key secret:", self)
self.key_secret_input = QLineEdit(self)
self.key_secret_input.setStyleSheet("color: black; background-color: white;")
self.key_secret_input.setEchoMode(QLineEdit.EchoMode.Password)

self.ok_button = QPushButton("OK")
self.ok_button.clicked.connect(self.ok)
Expand All @@ -62,8 +62,8 @@ def __init__(self, parent=None):
vbox.addWidget(self.key_name_input)
vbox.addWidget(self.key_id_text)
vbox.addWidget(self.key_id_input)
vbox.addWidget(self.access_key_text)
vbox.addWidget(self.access_key_input)
vbox.addWidget(self.key_secret_text)
vbox.addWidget(self.key_secret_input)
vbox.addLayout(hbox)

self.setLayout(vbox)
Expand Down Expand Up @@ -96,23 +96,25 @@ def ok(self):

key_id = self.key_id_input.text()
key_id = key_id.strip()
access_key = self.access_key_input.text()
access_key = access_key.strip()
key_secret = self.key_secret_input.text()
key_secret = key_secret.strip()

if not key_name:
self.set_error_text('missing key name')
return
if not key_id:
self.set_error_text('missing key id')
return
if not access_key:
if not key_secret:
self.set_error_text('missing access key')
return
if key_name != '' and not key_name.startswith('access-key'):
self.set_error_text('new key names must start with \'access-key\'')
return
print(f'key_name={key_name}, key_id={key_id}, access_key={access_key}')
# self.gui.set_access_key(key_name=key_name, key_id=key_id, access_key=access_key)
if key_name != '' and key_name.startswith('session-token'):
self.set_error_text('new key names must not start with \'session-token\'')
return
self.gui.set_access_key(key_name=key_name, key_id=key_id, key_secret=key_secret)
self.hide()

def cancel(self):
Expand Down Expand Up @@ -141,8 +143,8 @@ def show_dialog(self, access_key_list: List[str]):
self.key_name_input.repaint()
self.key_id_input.setText('')
self.key_id_input.repaint()
self.access_key_input.setText('')
self.access_key_input.repaint()
self.key_secret_input.setText('')
self.key_secret_input.repaint()
self.set_error_text('')

self.existing_access_key_list = access_key_list
Expand Down
4 changes: 3 additions & 1 deletion app/gui/mfa_dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ def cancel(self):
self.hide()

def get_value(self):
return self.input_field.text()
value = self.input_field.text()
value = value.strip()
return value

def closeEvent(self, event):
self.pressed_cancel = True
Expand Down
File renamed without changes.
2 changes: 2 additions & 0 deletions app/util/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def generate_session_name(key_name: str) -> str:
return f'session-token-{key_name}'
Loading

0 comments on commit db89ea4

Please sign in to comment.