diff --git a/service_configuration_lib/spark_config.py b/service_configuration_lib/spark_config.py index 5a27769..746cdfc 100644 --- a/service_configuration_lib/spark_config.py +++ b/service_configuration_lib/spark_config.py @@ -114,6 +114,7 @@ def get_aws_credentials( assume_aws_role_arn: Optional[str] = None, session_duration: int = 3600, assume_role_user_creds_file: str = '/nail/etc/spark_role_assumer/spark_role_assumer.yaml', + use_web_identity=False, ) -> Tuple[Optional[str], Optional[str], Optional[str]]: """load aws creds using different method/file""" if no_aws_credentials: @@ -132,6 +133,26 @@ def get_aws_credentials( log.warning( 'Tried to assume role with web identity but something went wrong ', ) + elif use_web_identity: + token_path = os.environ.get('AWS_WEB_IDENTITY_TOKEN_FILE') + role_arn = os.environ.get('AWS_ROLE_ARN') + if not token_path or not role_arn: + raise Exception('Expected AWS_WEB_IDENTITY_TOKEN_FILE and AWS_ROLE_ARN to be set.') + with open(token_path) as token_file: + token = token_file.read() + sts_client = boto3.client('sts') + timestamp = int(time.time()) + session = sts_client.assume_role_with_web_identity( + RoleArn=role_arn, + RoleSessionName=f'{service}-session-{timestamp}', + WebIdentityToken=token, + DurationSeconds=session_duration, + ) + return ( + session['Credentials']['AccessKeyId'], + session['Credentials']['SecretAccessKey'], + session['Credentials']['SessionToken'], + ) elif service != DEFAULT_SPARK_SERVICE: service_credentials_path = os.path.join(AWS_CREDENTIALS_DIR, f'{service}.yaml') if os.path.exists(service_credentials_path): @@ -165,10 +186,11 @@ def assume_aws_role( creds_dict = yaml.load(creds_file.read(), Loader=yaml.SafeLoader) access_key = creds_dict['AccessKeyId'] secret_key = creds_dict['SecretAccessKey'] - except PermissionError: + except (PermissionError, FileNotFoundError): log.warning( - 'If using spark-run as a human, you must manually export ' - 'AWS session credentials first. See y/spark-run-aws-role', + f'Tried to use {key_file} but it is not available. --assume-aws-role ' + 'can only be used with ssh executor. If using spark-run as a human, ' + 'you must manually export AWS session credentials first. See y/spark-run-aws-role', ) raise timestamp = int(time.time())