Skip to content

Commit

Permalink
Merge pull request #137 from Yelp/spark_run_pod_identity
Browse files Browse the repository at this point in the history
Allow get_aws_credentials to assume_role_with_web_identity
  • Loading branch information
nurdann authored Apr 19, 2024
2 parents 9512b01 + 4c27ce2 commit 4f859db
Showing 1 changed file with 25 additions and 3 deletions.
28 changes: 25 additions & 3 deletions service_configuration_lib/spark_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def get_aws_credentials(
assume_aws_role_arn: Optional[str] = None,
session_duration: int = 3600,
assume_role_user_creds_file: str = '/nail/etc/spark_role_assumer/spark_role_assumer.yaml',
use_web_identity=False,
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
"""load aws creds using different method/file"""
if aws_credentials_yaml:
Expand All @@ -127,6 +128,26 @@ def get_aws_credentials(
log.warning(
'Tried to assume role with web identity but something went wrong ',
)
elif use_web_identity:
token_path = os.environ.get('AWS_WEB_IDENTITY_TOKEN_FILE')
role_arn = os.environ.get('AWS_ROLE_ARN')
if not token_path or not role_arn:
raise Exception('Expected AWS_WEB_IDENTITY_TOKEN_FILE and AWS_ROLE_ARN to be set.')
with open(token_path) as token_file:
token = token_file.read()
sts_client = boto3.client('sts')
timestamp = int(time.time())
session = sts_client.assume_role_with_web_identity(
RoleArn=role_arn,
RoleSessionName=f'{service}-session-{timestamp}',
WebIdentityToken=token,
DurationSeconds=session_duration,
)
return (
session['Credentials']['AccessKeyId'],
session['Credentials']['SecretAccessKey'],
session['Credentials']['SessionToken'],
)
elif service != DEFAULT_SPARK_SERVICE:
service_credentials_path = os.path.join(AWS_CREDENTIALS_DIR, f'{service}.yaml')
if os.path.exists(service_credentials_path):
Expand Down Expand Up @@ -160,10 +181,11 @@ def assume_aws_role(
creds_dict = yaml.load(creds_file.read(), Loader=yaml.SafeLoader)
access_key = creds_dict['AccessKeyId']
secret_key = creds_dict['SecretAccessKey']
except PermissionError:
except (PermissionError, FileNotFoundError):
log.warning(
'If using spark-run as a human, you must manually export '
'AWS session credentials first. See y/spark-run-aws-role',
f'Tried to use {key_file} but it is not available. --assume-aws-role '
'can only be used with ssh executor. If using spark-run as a human, '
'you must manually export AWS session credentials first. See y/spark-run-aws-role',
)
raise
timestamp = int(time.time())
Expand Down

0 comments on commit 4f859db

Please sign in to comment.