Skip to content

Commit

Permalink
fix mfa dialog issue where the mfa dialog is opened from another thre…
Browse files Browse the repository at this point in the history
…ad than the main one. It will first try to log in without a token and only ask if that fails.
  • Loading branch information
redvox committed Jan 17, 2025
1 parent e65d347 commit 301402b
Show file tree
Hide file tree
Showing 5 changed files with 427 additions and 211 deletions.
22 changes: 18 additions & 4 deletions app/aws/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,26 @@ def check_access_key(access_key: str) -> Result:
return result


def check_session(access_key: str) -> Result:
def has_session(session_profile_name: str) -> Result:
logger.info(f'has session {session_profile_name}')
result = Result()
credentials_file = _load_credentials_file()

if credentials_file.has_section(session_profile_name):
result.set_success()
else:
logger.warning('no session found')
return result


def check_session(access_key: str) -> Result:
session_token_profile_name = util.generate_session_name(access_key)
if not credentials_file.has_section(session_token_profile_name):
logger.warning('no session token found')
return result
session_result = has_session(session_profile_name=session_token_profile_name)
if not session_result.was_success:
return session_result

logger.info(f'check session {session_token_profile_name}')
result = Result()
try:
client = _get_client(session_token_profile_name, 'sts', timeout=2, retries=2)
client.get_caller_identity()
Expand All @@ -121,6 +133,7 @@ def check_session(access_key: str) -> Result:
logger.error(error_text, exc_info=True)
return result

logger.info('check session - valid')
result.set_success()
return result

Expand Down Expand Up @@ -265,6 +278,7 @@ def _add_profile_config(option_file: configparser, profile: str, region: str) ->


def get_user_name(access_key) -> str:
logger.info('fetch user name')
client = _get_client(access_key, 'sts')
identity = client.get_caller_identity()
return _extract_user_from_identity(identity)
Expand Down
33 changes: 19 additions & 14 deletions app/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,12 @@ def login(self, profile_group_name, region, oneshot=False):
self._print_regions()
sys.exit(1)

login_result = self.core.login(profile_group=profile_group,
mfa_callback=self.ask_for_mfa_token)
login_result = self.core.login(profile_group=profile_group, mfa_token=None)
self._check_and_signal_error(login_result)
if not login_result.was_success:
mfa_token = self.ask_for_mfa_token()
login_with_mfa_result = self.core.login(profile_group=profile_group, mfa_token=mfa_token)
self._check_and_signal_error(login_with_mfa_result)

if region:
region_result = self.core.set_region(region=region)
Expand All @@ -62,9 +65,13 @@ def logout(self):
self._check_and_signal_error(logout_result)

def rotate_access_key(self, key_name):
rotate_result = self.core.rotate_access_key(key_name=key_name, mfa_callback=self.ask_for_mfa_token)
if not self._check_and_signal_error(rotate_result):
return
rotate_result = self.core.rotate_access_key(access_key=key_name, mfa_token=None)
self._check_and_signal_error(rotate_result)

if not rotate_result.was_success:
mfa_token = self.ask_for_mfa_token()
rotate_with_mfa_result = self.core.rotate_access_key(access_key=key_name, mfa_token=mfa_token)
self._check_and_signal_error(rotate_with_mfa_result)
self._info('key was successfully rotated')

def set_access_key(self):
Expand Down Expand Up @@ -110,13 +117,11 @@ def _info(s):
def _error(s):
print(f'{CR}{s}{CC}')

@staticmethod
def _warning(s):
print(f'{CY}{s}{CC}')

@staticmethod
def _warning(s):
print(f'{CY}{s}{CC}')


@staticmethod
def _print_regions():
for region in region_list:
print(region)
@staticmethod
def _print_regions():
for region in region_list:
print(region)
54 changes: 33 additions & 21 deletions app/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self):
self.empty_profile_group: ProfileGroup = ProfileGroup('logout', {}, '')
self.region_override: str = None

def login(self, profile_group: ProfileGroup, mfa_token: str) -> Result:
def login(self, profile_group: ProfileGroup, mfa_token: Optional[str]) -> Result:
result = Result()
logger.info(f'start login {profile_group.name} with token {mfa_token}')
self.active_profile_group = profile_group
Expand All @@ -29,13 +29,9 @@ def login(self, profile_group: ProfileGroup, mfa_token: str) -> Result:
if not access_key_result.was_success:
return access_key_result

session_check_result = credentials.check_session(access_key=access_key)
if not session_check_result.was_success and not mfa_token:
return session_check_result
if not session_check_result.was_success and mfa_token:
session_fetch_result = credentials.fetch_session_token(access_key=access_key, mfa_token=mfa_token)
if not session_fetch_result.was_success:
return session_fetch_result
session_result = self._ensure_session(access_key=access_key, mfa_token=mfa_token)
if not session_result.was_success:
return session_result

user_name = credentials.get_user_name(access_key=access_key)
role_result = credentials.fetch_role_credentials(user_name, profile_group)
Expand Down Expand Up @@ -104,6 +100,7 @@ def login_gcp(self, profile_group: ProfileGroup) -> Result:
def logout(self):
result = Result()
logger.info(f'start logout')
self.active_profile_group = None

role_result = credentials.fetch_role_credentials(user_name='none', profile_group=self.empty_profile_group)
if not role_result.was_success:
Expand All @@ -117,7 +114,7 @@ def logout(self):
result.set_success()
return result

def set_region(self, region: str) -> Result:
def set_region(self, region: Optional[str]) -> Result:
self.region_override = region
if not self.active_profile_group:
result = Result()
Expand All @@ -138,7 +135,7 @@ def get_profile_group_list(self):
def get_active_profile_color(self):
return self.active_profile_group.color

def rotate_access_key(self, access_key: str, mfa_token: str) -> Result:
def rotate_access_key(self, access_key: str, mfa_token: Optional[str]) -> Result:
result = Result()
logger.info(f'initiate key rotation for {access_key} with token {mfa_token}')

Expand All @@ -149,13 +146,9 @@ def rotate_access_key(self, access_key: str, mfa_token: str) -> Result:
if not access_key_result.was_success:
return access_key_result

session_check_result = credentials.check_session(access_key=access_key)
if not session_check_result.was_success and not mfa_token:
return session_check_result
if not session_check_result.was_success and mfa_token:
session_fetch_result = credentials.fetch_session_token(access_key=access_key, mfa_token=mfa_token)
if not session_fetch_result.was_success:
return session_fetch_result
session_result = self._ensure_session(access_key=access_key, mfa_token=mfa_token)
if not session_result.was_success:
return session_result

logger.info('create key')
user = credentials.get_user_name(access_key)
Expand Down Expand Up @@ -195,22 +188,26 @@ def edit_config(self, new_config: Config) -> Result:
def set_service_role(self, profile_name: str, role_name: str) -> Result:
result = Result()
logger.info('set service role')
self.config.save_selected_service_role(group_name=self.active_profile_group.name, profile_name=profile_name,
self.config.save_selected_service_role(group_name=self.active_profile_group.name,
profile_name=profile_name,
role_name=role_name)
self.active_profile_group.set_service_role_profile(source_profile_name=profile_name, role_name=role_name)
self.active_profile_group.set_service_role_profile(source_profile_name=profile_name,
role_name=role_name)

result.set_success()
return result

def set_available_service_roles(self, profile, role_list: List[str]):
result = Result()
logger.info('set available service roles')
self.config.save_available_service_roles(group_name=self.active_profile_group.name, profile_name=profile,
self.config.save_available_service_roles(group_name=self.active_profile_group.name,
profile_name=profile,
role_list=role_list)
result.set_success()
return result

def run_script(self, profile_group: ProfileGroup):
@staticmethod
def run_script(profile_group: ProfileGroup):
result = Result()
if not profile_group or not profile_group.script:
result.set_success()
Expand All @@ -230,6 +227,21 @@ def run_script(self, profile_group: ProfileGroup):
result.set_success()
return result

@staticmethod
def _ensure_session(access_key: str, mfa_token: Optional[str]) -> Result:
result = Result()

session_check_result = credentials.check_session(access_key=access_key)
if not session_check_result.was_success and not mfa_token:
return session_check_result
if not session_check_result.was_success and mfa_token:
session_fetch_result = credentials.fetch_session_token(access_key=access_key, mfa_token=mfa_token)
if not session_fetch_result.was_success:
return session_fetch_result

result.set_success()
return result

@staticmethod
def _handle_support_files(profile_group: ProfileGroup):
logger.info('handle support files')
Expand Down
46 changes: 31 additions & 15 deletions app/gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,12 @@ def __init__(self, app):
self.tray_icon.show()

def login(self, profile_group: ProfileGroup, mfa_token: Optional[str] = None):
# check access key
# check session
# if not: ask for mfa token
# if not: get mfa from shell
# proceed with login

self._to_busy_state()
self.task = BackgroundTask(
func=self.core.login,
func_kwargs={'profile_group': profile_group, 'mfa_token': mfa_token},
on_success=self._on_login_success,
on_failure=partial(self._on_login_failure, profile_group=profile_group),
# on_failure=self._on_login_failure,
on_error=self._on_error
)
self.task.start()
Expand All @@ -84,7 +77,7 @@ def _on_login_success(self):
self._to_login_state()

def _on_login_failure(self, profile_group: ProfileGroup):
logger.info(f'login failure')
logger.info('login failure')

mfa_token = mfa.fetch_mfa_token_from_shell(self.core.config.mfa_shell_command)
if not mfa_token:
Expand Down Expand Up @@ -115,6 +108,7 @@ def _on_login_gcp_success(self):
self._to_login_state()

def logout(self):
self._to_busy_state()
self.task = BackgroundTask(
func=self.core.logout,
func_kwargs={},
Expand All @@ -130,6 +124,7 @@ def _on_logout_success(self):
self.tray_icon.reset_copy_menus()

def set_region(self, region: str) -> None:
self._to_busy_state()
self.task = BackgroundTask(
func=self.core.set_region,
func_kwargs={'region': region},
Expand All @@ -144,6 +139,7 @@ def _on_set_region_success(self) -> None:
if not region:
region = 'not logged in'
self.tray_icon.update_region_text(region)
self._to_login_state()

def edit_config(self, config: Config):
self._to_busy_state()
Expand All @@ -161,6 +157,7 @@ def _on_edit_config_success(self):
self._to_reset_state()

def set_access_key(self, key_name, key_id, key_secret):
self._to_busy_state()
self.task = BackgroundTask(
func=self.core.edit_config,
func_kwargs={'key_name': key_name, 'key_id': key_id, 'key_secret': key_secret},
Expand All @@ -173,21 +170,37 @@ def set_access_key(self, key_name, key_id, key_secret):
def _on_set_access_key_success(self):
logger.info('access key set')
self._signal('Success', 'access key set')
self._to_login_state()

def rotate_access_key(self, key_name: str):
def rotate_access_key(self, key_name: str, mfa_token: Optional[str] = None):
self._to_busy_state()
logger.info('initiate key rotation')
self.task = BackgroundTask(
func=self.core.rotate_access_key,
func_kwargs={'key_name': key_name, 'mfa_callback': self.show_mfa_token_fetch_dialog},
func_kwargs={'access_key': key_name, 'mfa_token': mfa_token},
on_success=self._on_rotate_access_key_success,
on_failure=self._on_error,
on_failure=partial(self._on_rotate_access_key_failure, key_name=key_name),
on_error=self._on_error
)
self.task.start()

def _on_rotate_access_key_success(self):
logger.info('key was rotated')
self._signal('Success', 'key was rotated')
self._to_login_state()

def _on_rotate_access_key_failure(self, key_name: str):
logger.info('rotation failure')

mfa_token = mfa.fetch_mfa_token_from_shell(self.core.config.mfa_shell_command)
if not mfa_token:
mfa_token = self.show_mfa_token_fetch_dialog()
if not mfa_token:
logger.warning('no mfa token provided')
self._to_error_state()
return

self.rotate_access_key(key_name=key_name, mfa_token=mfa_token)

def set_service_role(self, profile: str, role: str):
self._to_busy_state()
Expand Down Expand Up @@ -244,10 +257,13 @@ def show_logs(self):
self.log_dialog.show_dialog(logs_as_text)

def _to_login_state(self):
style = ICON_STYLE_FULL if self.core.active_profile_group.type == "aws" else ICON_STYLE_GCP
self.tray_icon.setIcon(self.assets.get_icon(style=style, color_code=self.core.get_active_profile_color()))
self.tray_icon.disable_actions(False)
self.tray_icon.update_last_login(self.get_timestamp())
if self.core.active_profile_group:
style = ICON_STYLE_FULL if self.core.active_profile_group.type == "aws" else ICON_STYLE_GCP
self.tray_icon.setIcon(self.assets.get_icon(style=style, color_code=self.core.get_active_profile_color()))
self.tray_icon.disable_actions(False)
self.tray_icon.update_last_login(self.get_timestamp())
else:
self._to_reset_state()

def _to_busy_state(self):
self.tray_icon.setIcon(self.assets.get_icon(ICON_STYLE_BUSY))
Expand Down
Loading

0 comments on commit 301402b

Please sign in to comment.