Skip to content

Commit

Permalink
Enable support for STS regional endpoints when configured to assume a…
Browse files Browse the repository at this point in the history
… role (#37)

In AWS VPCs without internet access which have the STS VPC endpoint created the region needs to be explicitly set, along with the environment variable `AWS_STS_REGIONAL_ENDPOINTS=regional` otherwise the AWS Java SDK will use the global STS endpoint.
  • Loading branch information
dvulpe authored Aug 10, 2021
1 parent 1b6a469 commit b318611
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import com.amazonaws.auth.STSAssumeRoleSessionCredentialsProvider;
import com.amazonaws.auth.SystemPropertiesCredentialsProvider;
import com.amazonaws.auth.WebIdentityTokenCredentialsProvider;
import com.amazonaws.services.securitytoken.AWSSecurityTokenService;
import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -49,6 +51,7 @@ public class MSKCredentialProvider implements AWSCredentialsProvider, AutoClosea
private static final String AWS_PROFILE_NAME_KEY = "awsProfileName";
private static final String AWS_ROLE_ARN_KEY = "awsRoleArn";
private static final String AWS_ROLE_SESSION_KEY = "awsRoleSessionName";
private static final String AWS_STS_REGION = "awsStsRegion";

private final List<AutoCloseable> closeableProviders;
private final AWSCredentialsProvider compositeDelegate;
Expand Down Expand Up @@ -136,13 +139,20 @@ private Optional<STSAssumeRoleSessionCredentialsProvider> getStsRoleProvider() {
}
String sessionName = Optional.ofNullable((String) optionsMap.get(AWS_ROLE_SESSION_KEY))
.orElse("aws-msk-iam-auth");
return createSTSRoleCredentialProvider((String) p, sessionName);
String stsRegion = Optional.ofNullable((String)optionsMap.get(AWS_STS_REGION))
.orElse("aws-global");
return createSTSRoleCredentialProvider((String) p, sessionName, stsRegion);
});
}

STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String roleArn,
String sessionName) {
return new STSAssumeRoleSessionCredentialsProvider.Builder(roleArn, sessionName).build();
String sessionName, String stsRegion) {
AWSSecurityTokenService stsClient = AWSSecurityTokenServiceClientBuilder.standard()
.withRegion(stsRegion)
.build();
return new STSAssumeRoleSessionCredentialsProvider.Builder(roleArn, sessionName)
.withStsClient(stsClient)
.build();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ public void testAwsRoleArn() {

MSKCredentialProvider.ProviderBuilder providerBuilder = new MSKCredentialProvider.ProviderBuilder(optionsMap) {
STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String roleArn,
String sessionName) {
String sessionName, String stsRegion) {
assertEquals(TEST_ROLE_ARN, roleArn);
assertEquals("aws-msk-iam-auth", sessionName);
return mockStsRoleProvider;
Expand Down Expand Up @@ -183,7 +183,7 @@ public void testAwsRoleArnAndSessionName() {

MSKCredentialProvider.ProviderBuilder providerBuilder = new MSKCredentialProvider.ProviderBuilder(optionsMap) {
STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String roleArn,
String sessionName) {
String sessionName, String stsRegion) {
assertEquals(TEST_ROLE_ARN, roleArn);
assertEquals(TEST_ROLE_SESSION_NAME, sessionName);
return mockStsRoleProvider;
Expand All @@ -202,6 +202,40 @@ STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String r
Mockito.verify(mockStsRoleProvider, times(1)).close();
}

@Test
public void testAwsRoleArnSessionNameAndStsRegion() {
STSAssumeRoleSessionCredentialsProvider mockStsRoleProvider = Mockito
.mock(STSAssumeRoleSessionCredentialsProvider.class);
Mockito.when(mockStsRoleProvider.getCredentials())
.thenReturn(new BasicSessionCredentials(ACCESS_KEY_VALUE, SECRET_KEY_VALUE, SESSION_TOKEN));

Map<String, String> optionsMap = new HashMap<>();
optionsMap.put(AWS_ROLE_ARN, TEST_ROLE_ARN);
optionsMap.put("awsRoleSessionName", TEST_ROLE_SESSION_NAME);
optionsMap.put("awsStsRegion", "eu-west-1");

MSKCredentialProvider.ProviderBuilder providerBuilder = new MSKCredentialProvider.ProviderBuilder(optionsMap) {
STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String roleArn,
String sessionName, String stsRegion) {
assertEquals(TEST_ROLE_ARN, roleArn);
assertEquals(TEST_ROLE_SESSION_NAME, sessionName);
assertEquals("eu-west-1", stsRegion);
return mockStsRoleProvider;
}
};
MSKCredentialProvider provider = new MSKCredentialProvider(providerBuilder);

AWSCredentials credentials = provider.getCredentials();
assertTrue(credentials instanceof BasicSessionCredentials);
BasicSessionCredentials sessionCredentials = (BasicSessionCredentials) credentials;
assertEquals(ACCESS_KEY_VALUE, sessionCredentials.getAWSAccessKeyId());
assertEquals(SECRET_KEY_VALUE, sessionCredentials.getAWSSecretKey());
assertEquals(SESSION_TOKEN, sessionCredentials.getSessionToken());

provider.close();
Mockito.verify(mockStsRoleProvider, times(1)).close();
}

@Test
public void testProfileNameAndRoleArn() {
ProfileFile profileFile = getProfileFile();
Expand All @@ -219,7 +253,7 @@ EnhancedProfileCredentialsProvider createEnhancedProfileCredentialsProvider(Stri
return new EnhancedProfileCredentialsProvider(profileFile, TEST_PROFILE_NAME);
}
STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String roleArn,
String sessionName) {
String sessionName, String stsRegion) {
assertEquals(TEST_ROLE_ARN, roleArn);
assertEquals("aws-msk-iam-auth", sessionName);
return mockStsRoleProvider;
Expand Down

0 comments on commit b318611

Please sign in to comment.