Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

Handle ExpiredTokenException correctly and restructure flow paths #150

Merged
merged 8 commits into from
Dec 4, 2019
13 changes: 12 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -327,4 +327,15 @@ mozilla-aws-cli-yoyodyne/
├── mozilla_aws_cli_config
│   └── __init__.py
└── setup.py
```
```

## Other projects in this space

* https://github.com/aidan-/aws-cli-federator
* https://github.com/Nike-Inc/gimme-aws-creds
* https://github.com/sportradar/aws-azure-login
* https://github.com/oktadeveloper/okta-aws-cli-assume-role
* https://github.com/jmhale/okta-awscli
* https://github.com/prolane/samltoawsstskeys
* https://github.com/physera/onelogin-aws-cli
* https://github.com/kxseven/axe/blob/master/bin/subcommands/axe-token-krb5formauth-create
10 changes: 6 additions & 4 deletions mozilla_aws_cli/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ def timestamp(dt):
return (dt - epoch).total_seconds()

# TODO: move to config
CLOCK_SKEW_ALLOWANCE = 300 # 5 minutes
GROUP_ROLE_MAP_CACHE_TIME = 3600 # 1 hour
CLOCK_SKEW_ALLOWANCE = 300 # 5 minutes
UNDOCUMENTED_AWS_LIMIT_MAX_ID_TOKEN_AGE = 86400 # 1 day
GROUP_ROLE_MAP_CACHE_TIME = 3600 # 1 hour
CREDENTIALS_TO_AWS_MAP = {
"AccessKeyId": "aws_access_key_id",
"SecretAccessKey": "aws_secret_access_key",
Expand Down Expand Up @@ -289,7 +290,8 @@ def read_id_token(issuer, client_id, key=None):
except jose.exceptions.JOSEError:
return None

if id_token_dict.get('exp') - time.time() > CLOCK_SKEW_ALLOWANCE:
if (id_token_dict.get('exp') - time.time() > CLOCK_SKEW_ALLOWANCE
and time.time() - id_token_dict.get('iat') < UNDOCUMENTED_AWS_LIMIT_MAX_ID_TOKEN_AGE):
logger.debug("Successfully read cached id token at: {}".format(path))
return token
else:
Expand Down Expand Up @@ -347,7 +349,7 @@ def read_sts_credentials(role_arn):
time.time(),
timestamp(exp) - time.time()))
if timestamp(exp) - time.time() > CLOCK_SKEW_ALLOWANCE:
logger.debug("Using STS credentials at: {}, expiring in: {}".format(path, timestamp(exp) - time.time()))
logger.debug("Using STS credentials at: {} expiring in: {}".format(path, timestamp(exp) - time.time()))
return sts
else:
logger.debug(
Expand Down
54 changes: 54 additions & 0 deletions mozilla_aws_cli/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@

from .utils import exit_sigint

try:
# P3
from urllib.parse import urlencode
except ImportError:
# P2 Compat
from urllib import urlencode

# These ports must be configured in the IdP's allowed callback URL list
# TODO: Move this to the CLI / config section
Expand All @@ -24,6 +30,7 @@
"last_state_check": None,
"role_map": {},
}
STSWarning = type('STSWarning', (Warning,), dict())


def get_available_port():
Expand Down Expand Up @@ -63,6 +70,8 @@ def catch_all(filename):
@app.route("/api/roles", methods=["POST"])
def set_role():
login.role_arn = request.json.get("arn")
logger.debug('IAM Role ARN selected from role picker : {}'.format(
login.role_arn))

return jsonify({
"result": "set_role_arn",
Expand All @@ -73,6 +82,8 @@ def set_role():
@app.route("/api/roles", methods=["GET"])
def get_roles():
roles = {}
if login.role_map is None and login.token is not None:
login.get_role_map()
for arn in login.role_map["roles"]:
account_id = arn.split(":")[4]
alias = login.role_map.get("aliases", {}).get(account_id, [account_id])[0]
Expand Down Expand Up @@ -101,6 +112,11 @@ def get_roles():

@app.route("/api/state")
def get_state():
logger.debug('Call received to /api/state with id of {}. Returning state {} and web_state {}'.format(
request.args.get("id"),
login.state,
login.web_state
))
if request.args.get("id") != login.id:
return jsonify({
"result": "invalid_id",
Expand Down Expand Up @@ -144,7 +160,45 @@ def handle_oidc_redirect_callback():
})

# callback into the login.callback() function in login.py
logger.debug("redirect_callback : request is {}".format(request.json))
login.get_id_token(**request.json)
login.validate_id_token()
logger.debug("id_token_dict : {}".format(login.id_token_dict))
if login.id_token_dict is None:
logger.debug('Validation of token failed : {}'.format(login.token))
# TODO : What should we do in this case? How should the UI handle this?
return jsonify({
"result": "id_token_validation_failed",
"status_code": 400,
})

login.get_role_map()
try:
login.exchange_token_for_credentials()
except STSWarning as e:
if e.args[1] == 'ExpiredTokenException':
gene1wood marked this conversation as resolved.
Show resolved Hide resolved
logger.debug('AWS says that the ID token is expired : {}'.format(e[2]))
login.token = None
url_parameters = {
"scope": login.oidc_scope,
"response_type": "code",
"redirect_uri": login.redirect_uri,
"client_id": login.client_id,
"code_challenge": login.code_challenge,
"code_challenge_method": "S256",
"state": login.oidc_state,
}
url = "{}?{}".format(login.authorization_endpoint,
urlencode(url_parameters))
logger.debug('Setting state to restart_auth and idpUrl to {}'.format(url))
login.state = "restart_auth"
login.web_state["idpUrl"] = url
return jsonify({
"result": "restart_auth",
"status_code": 200,
})

login.print_output()

# Send the signal to kill the application
return jsonify({
Expand Down
142 changes: 102 additions & 40 deletions mozilla_aws_cli/login.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import absolute_import
from jose import jwt
from jose import jwt, JWTError
import json
import logging
import os
Expand Down Expand Up @@ -43,6 +43,8 @@
"SecretAccessKey": "AWS_SECRET_ACCESS_KEY",
"SessionToken": "AWS_SESSION_TOKEN",
}
STSWarning = type('STSWarning', (Warning,), dict())


class Login:
# Maybe this would be better to unroll from config?
Expand Down Expand Up @@ -94,6 +96,10 @@ def __init__(
# Whether or not we have opened a browser tab
self.opened_tab = False

# The ID Token returned from the identity provider
self.token = None
self.id_token_dict = None

# Whether we've gotten credentials via STS
self.credentials = None

Expand Down Expand Up @@ -145,11 +151,54 @@ def login(self):
self.state = "starting"
self.redirect_uri = "http://localhost:{}/redirect_uri".format(port)

token = read_id_token(self.openid_configuration.get("issuer"),
self.token = read_id_token(self.openid_configuration.get("issuer"),
self.client_id,
self.jwks)

if token is None or self.role_arn is None:
if self.token is not None and self.role_arn is not None:
logger.debug(
"We have a cached ID token and the role was passed as an "
"argument")
self.validate_id_token()
if self.id_token_dict is None:
# If validation failed, set token back to None
self.token = None
else:
# The ID Token verifies
self.get_role_map()
try:
self.exchange_token_for_credentials()
except STSWarning as e:
if e.args[1] == 'ExpiredTokenException':
logger.debug('Looks like that cached ID token is expired, setting self.token to None')
self.token = None
else:
raise
if self.token is not None and self.role_arn is not None:
self.print_output()

if self.token is not None and self.role_arn is None:
logger.debug(
"We have a cached ID token but the role passed on the command "
"line wasn't valid. Show the role picker")
self.state = 'redirecting'
url_parameters = {
"state": self.oidc_state,
"code": "foo"
}
webbrowser.get().open_new_tab(
"{}?{}".format(
self.redirect_uri,
urlencode(url_parameters)
))
self.opened_tab = True
logger.debug("About to start listener running on port {}".format(port))
listen(self)
elif self.token is None or self.role_arn is None:
logger.debug(
"Either the cached ID token was invalid or missing and we need "
"to get a new one, or the user passed no role_arn on the "
"command line so we need to spawn the role picker")
self.state = "redirecting"
url_parameters = {
"scope": self.oidc_scope,
Expand Down Expand Up @@ -187,10 +236,9 @@ def login(self):
# start up the listener, figuring out which port it ran on
logger.debug("About to start listener running on port {}".format(port))
listen(self)
else:
self.get_id_token(None, None, token=token)

def get_id_token(self, code, state, token=None, **kwargs):

def get_id_token(self, code=None, state=None, token=None, **kwargs):
"""
:param code: code GET paramater as sent by IdP
:param state: state GET parameter as sent by IdP
Expand All @@ -206,7 +254,7 @@ def get_id_token(self, code, state, token=None, **kwargs):
kwargs.get('error_description'))
))

if token is None: # Callback from web listener
if self.token is None and token is None: # Callback from web listener
self.state = "getting_id_token"

if code is None:
Expand All @@ -233,28 +281,37 @@ def get_id_token(self, code, state, token=None, **kwargs):
logger.debug(
"POSTing to token endpoint to exchange code for id_token: "
"{}".format(body))
token = requests.post(
self.token = requests.post(
self.token_endpoint, headers=headers, json=body).json()

# attempt to cache the id token
write_id_token(self.openid_configuration.get("issuer"),
self.client_id,
token)
self.token)
elif self.token is None:
self.token = token

def validate_id_token(self):
# decode the token for logging purposes
logger.debug("Validating response from endpoint: {}".format(token))
id_token_dict = jwt.decode(
token=token["id_token"],
key=self.jwks,
audience=self.client_id)
logger.debug("ID token dict : {}".format(id_token_dict))

logger.debug("Validating response from endpoint: {}".format(self.token))
try:
self.id_token_dict = jwt.decode(
token=self.token["id_token"],
key=self.jwks,
audience=self.client_id)
except JWTError as e:
logger.error('ID Token failed validation : {}'.format(e))
return None
logger.debug("ID token dict : {}".format(self.id_token_dict))


def get_role_map(self):
# get the role map, either from cache or from the endpoint
self.state = "getting_role_map"

self.role_map = get_roles_and_aliases(
endpoint=self.idtoken_for_roles_url,
token=token["id_token"],
token=self.token["id_token"],
key=self.jwks
)

Expand All @@ -264,6 +321,8 @@ def get_id_token(self, code, state, token=None, **kwargs):
logger.debug(
'Roles and aliases are {}'.format(self.role_map))


def exchange_token_for_credentials(self):
# TODO: Consider whether this needs to loop forever
while self.credentials is None:
# If we don't have a role ARN on the command line, we need to show
Expand All @@ -277,36 +336,39 @@ def get_id_token(self, code, state, token=None, **kwargs):

# Use the cached credentials or retrieve them from STS
self.state = "getting_sts_credentials"
self.credentials = sts_conn.get_credentials(
token["id_token"],
id_token_dict,
role_arn=self.role_arn
)

if self.credentials is None:
token_vals = ([
id_token_dict[x] for x in id_token_dict
if x in ['amr', 'iss', 'aud']]
if jwt else ['unknown'] * 3)
logger.error(
'AWS STS Call failed when attempting to assume role {} '
'with amr {} iss {} and aud {}'.format(
self.role_arn, *token_vals))
logger.error(
'Unable to assume role {}. Please select a different '
'role.'.format(self.role_arn))

if len(self.role_map.get("roles", [])) <= 1:
self.exit("Sorry, no valid roles available. Shutting down.")
else:
try:
self.credentials = sts_conn.get_credentials(
self.token["id_token"],
self.id_token_dict,
role_arn=self.role_arn
)
except STSWarning as e:
if e.args[1] == 'AccessDenied':
# Not authorized to perform sts:AssumeRoleWithWebIdentity
# Either that role doesn't exist or it exists but doesn't
# permit the user because of the conditions
logger.error(
'Unable to assume role {}. Please select a different '
'role.'.format(self.role_arn))
self.role_map.get("roles", []).remove(self.role_arn)
self.role_arn = None
if len(self.role_map.get("roles", [])) <= 1:
self.exit(
"Sorry, no valid roles available. Shutting down.")
elif e.args[1] == 'ExpiredTokenException':
# The ID token is expired
raise
else:
raise

if self.batch:
break

logger.debug(self.credentials)
logger.debug("ID token : {}".format(token["id_token"]))
logger.debug("ID token : {}".format(self.token["id_token"]))


def print_output(self):
# TODO: Create a global config object?
if self.credentials is not None:
profile_name = role_arn_to_profile_name(
Expand Down
5 changes: 4 additions & 1 deletion mozilla_aws_cli/static/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,17 @@ const pollState = setInterval(async () => {
// show the roles
const roles = await response.json();
showRoles(roles);
} else if (remoteState.state === "restart_auth") {
setMessage("Redirecting to identity provider...");
gene1wood marked this conversation as resolved.
Show resolved Hide resolved
window.location.replace(remoteState.value.idpUrl);
gene1wood marked this conversation as resolved.
Show resolved Hide resolved
} else if (remoteState.state === "aws_federate") {
setMessage("Redirecting to AWS...");
await shutdown();

// insert the image to log out of AWS and then redirect there once
// it has loaded
$("#aws-federation-logout").on("load error", () => {
document.location = remoteState.value.awsFederationUrl;
window.location.replace = remoteState.value.awsFederationUrl;
}).attr("src", "https://signin.aws.amazon.com/oauth?Action=logout");
} else if (remoteState.state === "invalid_id") {
setMessage("Another federation session has been detected. Shutting down.");
Expand Down
Loading