Skip to content

Commit

Permalink
Add support for assuming an ARN role
Browse files Browse the repository at this point in the history
Fixes #29
  • Loading branch information
Dan Torrey committed Aug 13, 2019
1 parent ae287ae commit aabba07
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.graylog2.plugin.system.NodeId;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.cloudwatch.CloudWatchAsyncClient;
Expand Down Expand Up @@ -46,7 +47,7 @@ public class KinesisConsumer implements Runnable {
private final Integer recordBatchSize;
private final ObjectMapper objectMapper;
private final AWSMessageType awsMessageType;
private final StaticCredentialsProvider credentialsProvider;
private final AwsCredentialsProvider credentialsProvider;
private final Consumer<byte[]> handleMessageCallback;
private Scheduler kinesisScheduler;

Expand All @@ -57,12 +58,9 @@ public class KinesisConsumer implements Runnable {
String kinesisStreamName,
AWSMessageType awsMessageType,
Region region,
String awsKey,
String awsSecret,
AwsCredentialsProvider credentialsProvider,
int recordBatchSize) {
Preconditions.checkArgument(StringUtils.isNotBlank(kinesisStreamName), "A Kinesis stream name is required.");
Preconditions.checkArgument(StringUtils.isNotBlank(awsKey), "An AWS key is required.");
Preconditions.checkArgument(StringUtils.isNotBlank(awsSecret), "An AWS secret is required.");
Preconditions.checkNotNull(region, "A Region is required.");
Preconditions.checkNotNull(awsMessageType, "A AWSMessageType is required.");

Expand All @@ -73,7 +71,7 @@ public class KinesisConsumer implements Runnable {
this.region = requireNonNull(region, "region");
this.objectMapper = objectMapper;
this.awsMessageType = awsMessageType;
this.credentialsProvider = AWSService.buildCredentialProvider(awsKey, awsSecret);
this.credentialsProvider = credentialsProvider;
this.recordBatchSize = recordBatchSize;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

import com.codahale.metrics.MetricSet;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Preconditions;
import com.google.common.eventbus.EventBus;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.google.inject.assistedinject.Assisted;
import org.apache.commons.lang.StringUtils;
import org.graylog.integrations.aws.AWSMessageType;
import org.graylog.integrations.aws.codecs.AWSCodec;
import org.graylog.integrations.aws.inputs.AWSInput;
import org.graylog.integrations.aws.service.AWSService;
import org.graylog2.plugin.LocalMetricRegistry;
import org.graylog2.plugin.configuration.Configuration;
Expand All @@ -26,7 +29,10 @@
import org.graylog2.plugin.system.NodeId;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider;

import javax.inject.Inject;
import java.util.Objects;
Expand Down Expand Up @@ -86,11 +92,31 @@ public void handleChangedThrottledState(boolean isThrottled) {
@Override
public void doLaunch(MessageInput input) throws MisfireException {

final Region region = Region.of(Objects.requireNonNull(configuration.getString(CK_AWS_REGION)));
final String assumeRoleArn = configuration.getString(AWSInput.CK_ASSUME_ROLE_ARN);
final String key = configuration.getString(CK_ACCESS_KEY);
final String secret = configuration.getString(CK_SECRET_KEY);
Preconditions.checkArgument(StringUtils.isNotBlank(key), "An AWS key is required.");
AwsCredentialsProvider awsCredentialsProvider = AWSService.buildCredentialProvider(key, secret);
Preconditions.checkArgument(StringUtils.isNotBlank(secret), "An AWS secret is required.");

// Assume role ARN functionality only applies to the Kinesis runtime (not to the setup flows).
if (StringUtils.isNotBlank(assumeRoleArn)) {
StsClient stsClient = StsClient.builder()
.region(region)
.credentialsProvider(awsCredentialsProvider).build();
String roleSessionName = String.format("API_KEY_%s@ACCOUNT_%s",
key,
stsClient.getCallerIdentity().account());
awsCredentialsProvider = StsAssumeRoleCredentialsProvider.builder()
.stsClient(stsClient)
.refreshRequest(r -> r.roleArn(assumeRoleArn)
.roleSessionName(roleSessionName)).build();
}
this.kinesisConsumer = new KinesisConsumer(
nodeId, this, objectMapper, kinesisCallback(input), configuration.getString(CK_KINESIS_STREAM_NAME),
AWSMessageType.valueOf(configuration.getString(AWSCodec.CK_AWS_MESSAGE_TYPE)), Region.of(Objects.requireNonNull(configuration.getString(CK_AWS_REGION))),
configuration.getString(CK_ACCESS_KEY),
configuration.getString(CK_SECRET_KEY),
AWSMessageType.valueOf(configuration.getString(AWSCodec.CK_AWS_MESSAGE_TYPE)), region,
awsCredentialsProvider,
configuration.getInt(CK_KINESIS_RECORD_BATCH_SIZE, DEFAULT_BATCH_SIZE)
);

Expand Down

0 comments on commit aabba07

Please sign in to comment.