diff --git a/src/main/java/org/graylog/integrations/aws/transports/KinesisConsumer.java b/src/main/java/org/graylog/integrations/aws/transports/KinesisConsumer.java index 4ee4eea39..0ade3a3dc 100644 --- a/src/main/java/org/graylog/integrations/aws/transports/KinesisConsumer.java +++ b/src/main/java/org/graylog/integrations/aws/transports/KinesisConsumer.java @@ -9,6 +9,7 @@ import org.graylog2.plugin.system.NodeId; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.cloudwatch.CloudWatchAsyncClient; @@ -46,7 +47,7 @@ public class KinesisConsumer implements Runnable { private final Integer recordBatchSize; private final ObjectMapper objectMapper; private final AWSMessageType awsMessageType; - private final StaticCredentialsProvider credentialsProvider; + private final AwsCredentialsProvider credentialsProvider; private final Consumer handleMessageCallback; private Scheduler kinesisScheduler; @@ -57,12 +58,9 @@ public class KinesisConsumer implements Runnable { String kinesisStreamName, AWSMessageType awsMessageType, Region region, - String awsKey, - String awsSecret, + AwsCredentialsProvider credentialsProvider, int recordBatchSize) { Preconditions.checkArgument(StringUtils.isNotBlank(kinesisStreamName), "A Kinesis stream name is required."); - Preconditions.checkArgument(StringUtils.isNotBlank(awsKey), "An AWS key is required."); - Preconditions.checkArgument(StringUtils.isNotBlank(awsSecret), "An AWS secret is required."); Preconditions.checkNotNull(region, "A Region is required."); Preconditions.checkNotNull(awsMessageType, "A AWSMessageType is required."); @@ -73,7 +71,7 @@ public class KinesisConsumer implements Runnable { this.region = requireNonNull(region, "region"); this.objectMapper = objectMapper; this.awsMessageType = awsMessageType; - this.credentialsProvider = AWSService.buildCredentialProvider(awsKey, awsSecret); + this.credentialsProvider = credentialsProvider; this.recordBatchSize = recordBatchSize; } diff --git a/src/main/java/org/graylog/integrations/aws/transports/KinesisTransport.java b/src/main/java/org/graylog/integrations/aws/transports/KinesisTransport.java index e21f7a17c..c262e9eb9 100644 --- a/src/main/java/org/graylog/integrations/aws/transports/KinesisTransport.java +++ b/src/main/java/org/graylog/integrations/aws/transports/KinesisTransport.java @@ -2,11 +2,14 @@ import com.codahale.metrics.MetricSet; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.base.Preconditions; import com.google.common.eventbus.EventBus; import com.google.common.util.concurrent.ThreadFactoryBuilder; import com.google.inject.assistedinject.Assisted; +import org.apache.commons.lang.StringUtils; import org.graylog.integrations.aws.AWSMessageType; import org.graylog.integrations.aws.codecs.AWSCodec; +import org.graylog.integrations.aws.inputs.AWSInput; import org.graylog.integrations.aws.service.AWSService; import org.graylog2.plugin.LocalMetricRegistry; import org.graylog2.plugin.configuration.Configuration; @@ -26,7 +29,10 @@ import org.graylog2.plugin.system.NodeId; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.sts.StsClient; +import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider; import javax.inject.Inject; import java.util.Objects; @@ -86,11 +92,31 @@ public void handleChangedThrottledState(boolean isThrottled) { @Override public void doLaunch(MessageInput input) throws MisfireException { + final Region region = Region.of(Objects.requireNonNull(configuration.getString(CK_AWS_REGION))); + final String assumeRoleArn = configuration.getString(AWSInput.CK_ASSUME_ROLE_ARN); + final String key = configuration.getString(CK_ACCESS_KEY); + final String secret = configuration.getString(CK_SECRET_KEY); + Preconditions.checkArgument(StringUtils.isNotBlank(key), "An AWS key is required."); + AwsCredentialsProvider awsCredentialsProvider = AWSService.buildCredentialProvider(key, secret); + Preconditions.checkArgument(StringUtils.isNotBlank(secret), "An AWS secret is required."); + + // Assume role ARN functionality only applies to the Kinesis runtime (not to the setup flows). + if (StringUtils.isNotBlank(assumeRoleArn)) { + StsClient stsClient = StsClient.builder() + .region(region) + .credentialsProvider(awsCredentialsProvider).build(); + String roleSessionName = String.format("API_KEY_%s@ACCOUNT_%s", + key, + stsClient.getCallerIdentity().account()); + awsCredentialsProvider = StsAssumeRoleCredentialsProvider.builder() + .stsClient(stsClient) + .refreshRequest(r -> r.roleArn(assumeRoleArn) + .roleSessionName(roleSessionName)).build(); + } this.kinesisConsumer = new KinesisConsumer( nodeId, this, objectMapper, kinesisCallback(input), configuration.getString(CK_KINESIS_STREAM_NAME), - AWSMessageType.valueOf(configuration.getString(AWSCodec.CK_AWS_MESSAGE_TYPE)), Region.of(Objects.requireNonNull(configuration.getString(CK_AWS_REGION))), - configuration.getString(CK_ACCESS_KEY), - configuration.getString(CK_SECRET_KEY), + AWSMessageType.valueOf(configuration.getString(AWSCodec.CK_AWS_MESSAGE_TYPE)), region, + awsCredentialsProvider, configuration.getInt(CK_KINESIS_RECORD_BATCH_SIZE, DEFAULT_BATCH_SIZE) );