Skip to content

Commit

Permalink
Refactor AWSSigV4 auth to support different AWSCredentialProviders
Browse files Browse the repository at this point in the history
Signed-off-by: Vamsi Manohar <reddyvam@amazon.com>
  • Loading branch information
vamsimanohar committed Mar 3, 2023
1 parent 8f6793b commit 93a82a8
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 18 deletions.
2 changes: 2 additions & 0 deletions plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ configurations.all {
resolutionStrategy.force "com.squareup.okhttp3:okhttp:4.9.3"
resolutionStrategy.force "joda-time:joda-time:2.10.12"
resolutionStrategy.force "org.slf4j:slf4j-api:1.7.36"
resolutionStrategy.force "org.apache.httpcomponents:httpcore:4.4.15"
resolutionStrategy.force "org.apache.httpcomponents:httpclient:4.5.13"
}
compileJava {
options.compilerArgs.addAll(["-processor", 'lombok.launch.AnnotationProcessorHider$AnnotationProcessor'])
Expand Down
2 changes: 2 additions & 0 deletions prometheus/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ dependencies {
implementation group: 'com.fasterxml.jackson.dataformat', name: 'jackson-dataformat-cbor', version: "${versions.jackson}"
implementation group: 'com.squareup.okhttp3', name: 'okhttp', version: '4.9.3'
implementation 'com.github.babbel:okhttp-aws-signer:1.0.2'
implementation group: 'com.amazonaws', name: 'aws-java-sdk-core', version: '1.12.1'
implementation group: 'com.amazonaws', name: 'aws-java-sdk-sts', version: '1.12.1'
implementation group: 'org.json', name: 'json', version: '20180813'

testImplementation('org.junit.jupiter:junit-jupiter:5.6.2')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

package org.opensearch.sql.prometheus.authinterceptors;

import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.STSAssumeRoleSessionCredentialsProvider;
import com.babbel.mobile.android.commons.okhttpawssigner.OkHttpAwsV4Signer;
import java.io.IOException;
import java.time.ZoneId;
Expand All @@ -16,29 +19,29 @@
import okhttp3.Interceptor;
import okhttp3.Request;
import okhttp3.Response;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

public class AwsSigningInterceptor implements Interceptor {

private OkHttpAwsV4Signer okHttpAwsV4Signer;

private String accessKey;
private AWSCredentialsProvider awsCredentialsProvider;

private String secretKey;
private static final Logger LOG = LogManager.getLogger();

/**
* AwsSigningInterceptor which intercepts http requests
* and adds required headers for sigv4 authentication.
*
* @param accessKey accessKey.
* @param secretKey secretKey.
* @param awsCredentialsProvider awsCredentialsProvider.
* @param region region.
* @param serviceName serviceName.
*/
public AwsSigningInterceptor(@NonNull String accessKey, @NonNull String secretKey,
public AwsSigningInterceptor(@NonNull AWSCredentialsProvider awsCredentialsProvider,
@NonNull String region, @NonNull String serviceName) {
this.okHttpAwsV4Signer = new OkHttpAwsV4Signer(region, serviceName);
this.accessKey = accessKey;
this.secretKey = secretKey;
this.awsCredentialsProvider = awsCredentialsProvider;
}

@Override
Expand All @@ -48,11 +51,21 @@ public Response intercept(Interceptor.Chain chain) throws IOException {
DateTimeFormatter timestampFormat = DateTimeFormatter.ofPattern("yyyyMMdd'T'HHmmss'Z'")
.withZone(ZoneId.of("GMT"));

Request newRequest = request.newBuilder()

Request.Builder newRequestBuilder = request.newBuilder()
.addHeader("x-amz-date", timestampFormat.format(ZonedDateTime.now()))
.addHeader("host", request.url().host())
.build();
Request signed = okHttpAwsV4Signer.sign(newRequest, accessKey, secretKey);
.addHeader("host", request.url().host());

AWSCredentials awsCredentials = awsCredentialsProvider.getCredentials();
if (awsCredentialsProvider instanceof STSAssumeRoleSessionCredentialsProvider) {
newRequestBuilder.addHeader("x-amz-security-token",
((STSAssumeRoleSessionCredentialsProvider) awsCredentialsProvider)
.getCredentials()
.getSessionToken());
}
Request newRequest = newRequestBuilder.build();
Request signed = okHttpAwsV4Signer.sign(newRequest,
awsCredentials.getAWSAccessKeyId(), awsCredentials.getAWSSecretKey());
return chain.proceed(signed);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

package org.opensearch.sql.prometheus.storage;

import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.BasicAWSCredentials;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.HashSet;
Expand Down Expand Up @@ -75,7 +77,8 @@ private OkHttpClient getHttpClient(Map<String, String> config) {
} else if (AuthenticationType.AWSSIGV4AUTH.equals(authenticationType)) {
validateFieldsInConfig(config, Set.of(REGION, ACCESS_KEY, SECRET_KEY));
okHttpClient.addInterceptor(new AwsSigningInterceptor(
config.get(ACCESS_KEY), config.get(SECRET_KEY),
new AWSStaticCredentialsProvider(
new BasicAWSCredentials(config.get(ACCESS_KEY), config.get(SECRET_KEY))),
config.get(REGION), "aps"));
} else {
throw new IllegalArgumentException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.BasicAWSCredentials;
import lombok.SneakyThrows;
import okhttp3.Interceptor;
import okhttp3.Request;
Expand All @@ -33,13 +36,13 @@ public class AwsSigningInterceptorTest {
@Test
void testConstructors() {
Assertions.assertThrows(NullPointerException.class, () ->
new AwsSigningInterceptor(null, "secretKey", "us-east-1", "aps"));
new AwsSigningInterceptor(null, "us-east-1", "aps"));
Assertions.assertThrows(NullPointerException.class, () ->
new AwsSigningInterceptor("accessKey", null, "us-east-1", "aps"));
new AwsSigningInterceptor(getStaticAWSCredentialsProvider("accessKey", "secretKey"), null,
"aps"));
Assertions.assertThrows(NullPointerException.class, () ->
new AwsSigningInterceptor("accessKey", "secretKey", null, "aps"));
Assertions.assertThrows(NullPointerException.class, () ->
new AwsSigningInterceptor("accessKey", "secretKey", "us-east-1", null));
new AwsSigningInterceptor(getStaticAWSCredentialsProvider("accessKey", "secretKey"),
"us-east-1", null));
}

@Test
Expand All @@ -49,7 +52,9 @@ void testIntercept() {
.url("http://localhost:9090")
.build());
AwsSigningInterceptor awsSigningInterceptor
= new AwsSigningInterceptor("testAccessKey", "testSecretKey", "us-east-1", "aps");
=
new AwsSigningInterceptor(getStaticAWSCredentialsProvider("testAccessKey", "testSecretKey"),
"us-east-1", "aps");
awsSigningInterceptor.intercept(chain);
verify(chain).proceed(requestArgumentCaptor.capture());
Request request = requestArgumentCaptor.getValue();
Expand All @@ -58,4 +63,10 @@ void testIntercept() {
Assertions.assertNotNull(request.headers("host"));
}


private AWSCredentialsProvider getStaticAWSCredentialsProvider(String accessKey,
String secretKey) {
return new AWSStaticCredentialsProvider(new BasicAWSCredentials(accessKey, secretKey));
}

}

0 comments on commit 93a82a8

Please sign in to comment.