Skip to content

Commit

Permalink
Add default role and region configuration to the data-prepper-config.…
Browse files Browse the repository at this point in the history
…yaml via extensions (#4559)

Signed-off-by: Taylor Gray <tylgry@amazon.com>
  • Loading branch information
graytaylor0 authored May 22, 2024
1 parent c44dfbe commit 81e4058
Show file tree
Hide file tree
Showing 14 changed files with 288 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
package org.opensearch.dataprepper.aws.api;

import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.regions.Region;

import java.util.Optional;

/**
* An interface available to plugins via the AWS Plugin Extension which supplies
Expand All @@ -19,4 +22,10 @@ public interface AwsCredentialsSupplier {
* @return An {@link AwsCredentialsProvider} to use.
*/
AwsCredentialsProvider getProvider(AwsCredentialsOptions options);

/**
* Gets the default region if it is configured. Otherwise returns null
* @return Default {@link Region}
*/
Optional<Region> getDefaultRegion();
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.dataprepper.plugins.aws;

import org.opensearch.dataprepper.model.annotations.DataPrepperExtensionPlugin;
import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor;
import org.opensearch.dataprepper.model.plugin.ExtensionPlugin;
import org.opensearch.dataprepper.model.plugin.ExtensionPoints;
Expand All @@ -13,12 +14,18 @@
* The {@link ExtensionPlugin} class which adds the AWS Plugin to
* Data Prepper as an extension plugin. Everything starts from here.
*/
@DataPrepperExtensionPlugin(modelType = AwsPluginConfig.class, rootKeyJsonPath = "/aws/configurations")
public class AwsPlugin implements ExtensionPlugin {
private final DefaultAwsCredentialsSupplier defaultAwsCredentialsSupplier;

private final AwsPluginConfig awsPluginConfig;

@DataPrepperPluginConstructor
public AwsPlugin() {
final CredentialsProviderFactory credentialsProviderFactory = new CredentialsProviderFactory();
public AwsPlugin(final AwsPluginConfig awsPluginConfig) {

this.awsPluginConfig = awsPluginConfig;

final CredentialsProviderFactory credentialsProviderFactory = new CredentialsProviderFactory(awsPluginConfig != null ? awsPluginConfig.getDefaultStsConfiguration() : new AwsStsConfiguration());
final CredentialsCache credentialsCache = new CredentialsCache();
defaultAwsCredentialsSupplier = new DefaultAwsCredentialsSupplier(credentialsProviderFactory, credentialsCache);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.aws;

import com.fasterxml.jackson.annotation.JsonProperty;

public class AwsPluginConfig {

@JsonProperty("default")
private AwsStsConfiguration defaultStsConfiguration = new AwsStsConfiguration();

public AwsStsConfiguration getDefaultStsConfiguration() {
return defaultStsConfiguration;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

package org.opensearch.dataprepper.plugins.aws;

import org.opensearch.dataprepper.model.plugin.PluginConfigPublisher;
import org.opensearch.dataprepper.model.plugin.PluginConfigObservable;
import org.opensearch.dataprepper.model.plugin.PluginConfigPublisher;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.aws;

import com.fasterxml.jackson.annotation.JsonProperty;
import jakarta.validation.constraints.Size;
import software.amazon.awssdk.regions.Region;

public class AwsStsConfiguration {

@JsonProperty("region")
@Size(min = 1, message = "Region cannot be empty string")
private String awsRegion;

@JsonProperty("sts_role_arn")
@Size(min = 20, max = 2048, message = "awsStsRoleArn length should be between 1 and 2048 characters")
private String awsStsRoleArn;

public Region getAwsRegion() {
return awsRegion != null ? Region.of(awsRegion) : null;
}

public String getAwsStsRoleArn() {
return awsStsRoleArn;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,21 @@ class CredentialsProviderFactory {
static final long STS_CLIENT_BASE_BACKOFF_MILLIS = 1000L;
static final long STS_CLIENT_MAX_BACKOFF_MILLIS = 60000L;

private final AwsStsConfiguration defaultStsConfiguration;

public CredentialsProviderFactory(final AwsStsConfiguration defaultStsConfiguration) {
Objects.requireNonNull(defaultStsConfiguration);
this.defaultStsConfiguration = defaultStsConfiguration;
}

Region getDefaultRegion() {
return defaultStsConfiguration.getAwsRegion();
}

AwsCredentialsProvider providerFromOptions(final AwsCredentialsOptions credentialsOptions) {
Objects.requireNonNull(credentialsOptions);

if(credentialsOptions.getStsRoleArn() != null) {
if(credentialsOptions.getStsRoleArn() != null || defaultStsConfiguration.getAwsStsRoleArn() != null) {
return createStsCredentials(credentialsOptions);
}

Expand All @@ -48,13 +59,15 @@ AwsCredentialsProvider providerFromOptions(final AwsCredentialsOptions credentia

private AwsCredentialsProvider createStsCredentials(final AwsCredentialsOptions credentialsOptions) {

final String stsRoleArn = credentialsOptions.getStsRoleArn();
final String stsRoleArn = credentialsOptions.getStsRoleArn() == null ? defaultStsConfiguration.getAwsStsRoleArn() : credentialsOptions.getStsRoleArn();

validateStsRoleArn(stsRoleArn);

LOG.debug("Creating new AwsCredentialsProvider with role {}.", stsRoleArn);

final StsClient stsClient = createStsClient(credentialsOptions.getRegion());
final Region region = credentialsOptions.getRegion() == null ? defaultStsConfiguration.getAwsRegion() : credentialsOptions.getRegion();

final StsClient stsClient = createStsClient(region);

AssumeRoleRequest.Builder assumeRoleRequestBuilder = AssumeRoleRequest.builder()
.roleSessionName("Data-Prepper-" + UUID.randomUUID())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.regions.Region;

import java.util.Optional;

class DefaultAwsCredentialsSupplier implements AwsCredentialsSupplier {
private final CredentialsProviderFactory credentialsProviderFactory;
Expand All @@ -22,4 +25,9 @@ class DefaultAwsCredentialsSupplier implements AwsCredentialsSupplier {
public AwsCredentialsProvider getProvider(final AwsCredentialsOptions options) {
return credentialsCache.getOrCreate(options, () -> credentialsProviderFactory.providerFromOptions(options));
}

@Override
public Optional<Region> getDefaultRegion() {
return Optional.ofNullable(credentialsProviderFactory.getDefaultRegion());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.aws;

import org.junit.jupiter.api.Test;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;

public class AwsPluginConfigTest {

@Test
void testDefault() {
final AwsPluginConfig objectUnderTest = new AwsPluginConfig();

assertThat(objectUnderTest, notNullValue());
assertThat(objectUnderTest.getDefaultStsConfiguration(), notNullValue());
assertThat(objectUnderTest.getDefaultStsConfiguration().getAwsRegion(), nullValue());
assertThat(objectUnderTest.getDefaultStsConfiguration().getAwsStsRoleArn(), nullValue());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.dataprepper.plugins.aws;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
Expand All @@ -28,17 +29,29 @@
import static org.hamcrest.CoreMatchers.sameInstance;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
public class AwsPluginIT {
@Mock
private AwsPluginConfig awsPluginConfig;

@Mock
private ExtensionPoints extensionPoints;

@Mock
private ExtensionProvider.Context context;

@Mock
private AwsStsConfiguration awsDefaultStsConfiguration;

@BeforeEach
void setup() {
when(awsPluginConfig.getDefaultStsConfiguration()).thenReturn(awsDefaultStsConfiguration);
}

private AwsPlugin createObjectUnderTest() {
return new AwsPlugin();
return new AwsPlugin(awsPluginConfig);
}

@Test
Expand Down Expand Up @@ -78,6 +91,8 @@ void test_AwsPlugin_with_STS_role() {

@Test
void test_AwsPlugin_without_STS_role() {
when(awsDefaultStsConfiguration.getAwsStsRoleArn()).thenReturn(null);

createObjectUnderTest().apply(extensionPoints);

final ArgumentCaptor<ExtensionProvider<AwsCredentialsSupplier>> extensionProviderArgumentCaptor = ArgumentCaptor.forClass(ExtensionProvider.class);
Expand Down Expand Up @@ -108,6 +123,40 @@ void test_AwsPlugin_without_STS_role() {
assertThat(awsCredentialsProvider2, sameInstance(awsCredentialsProvider1));
}

@Test
void test_AwsPlugin_without_STS_role_and_with_default_role_uses_default_role() {
when(awsDefaultStsConfiguration.getAwsStsRoleArn()).thenReturn(createStsRole());

createObjectUnderTest().apply(extensionPoints);

final ArgumentCaptor<ExtensionProvider<AwsCredentialsSupplier>> extensionProviderArgumentCaptor = ArgumentCaptor.forClass(ExtensionProvider.class);
verify(extensionPoints).addExtensionProvider(extensionProviderArgumentCaptor.capture());

final ExtensionProvider<AwsCredentialsSupplier> extensionProvider = extensionProviderArgumentCaptor.getValue();

final Optional<AwsCredentialsSupplier> optionalSupplier = extensionProvider.provideInstance(context);
assertThat(optionalSupplier, notNullValue());
assertThat(optionalSupplier.isPresent(), equalTo(true));

final AwsCredentialsSupplier awsCredentialsSupplier = optionalSupplier.get();

final AwsCredentialsOptions awsCredentialsOptions1 = AwsCredentialsOptions.builder()
.withRegion(Region.US_EAST_1)
.build();

final AwsCredentialsProvider awsCredentialsProvider1 = awsCredentialsSupplier.getProvider(awsCredentialsOptions1);

assertThat(awsCredentialsProvider1, instanceOf(StsAssumeRoleCredentialsProvider.class));

final AwsCredentialsOptions awsCredentialsOptions2 = AwsCredentialsOptions.builder()
.withRegion(Region.US_EAST_1)
.build();

final AwsCredentialsProvider awsCredentialsProvider2 = awsCredentialsSupplier.getProvider(awsCredentialsOptions2);

assertThat(awsCredentialsProvider2, sameInstance(awsCredentialsProvider1));
}

private String createStsRole() {
return String.format("arn:aws:iam::123456789012:role/%s", UUID.randomUUID());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,25 @@
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
class AwsPluginTest {

@Mock
private AwsPluginConfig awsPluginConfig;

@Mock
private ExtensionPoints extensionPoints;

private AwsPlugin createObjectUnderTest() {
return new AwsPlugin();
return new AwsPlugin(awsPluginConfig);
}

@Test
void apply_should_addExtensionProvider() {
when(awsPluginConfig.getDefaultStsConfiguration()).thenReturn(new AwsStsConfiguration());

createObjectUnderTest().apply(extensionPoints);

final ArgumentCaptor<ExtensionProvider> extensionProviderArgumentCaptor =
Expand All @@ -39,4 +46,20 @@ void apply_should_addExtensionProvider() {

assertThat(actualExtensionProvider, instanceOf(AwsExtensionProvider.class));
}

@Test
void null_aws_plugin_config_applies_extensions_correctly() {
final AwsPlugin objectUnderTest = new AwsPlugin(null);

objectUnderTest.apply(extensionPoints);

final ArgumentCaptor<ExtensionProvider> extensionProviderArgumentCaptor =
ArgumentCaptor.forClass(ExtensionProvider.class);

verify(extensionPoints).addExtensionProvider(extensionProviderArgumentCaptor.capture());

final ExtensionProvider actualExtensionProvider = extensionProviderArgumentCaptor.getValue();

assertThat(actualExtensionProvider, instanceOf(AwsExtensionProvider.class));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.aws;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import software.amazon.awssdk.regions.Region;

import java.util.List;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.notNullValue;

public class AwsStsConfigurationTest {

private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

@ParameterizedTest
@MethodSource("getRegions")
void testStsConfiguration(final Region region) throws JsonProcessingException {

final String defaultConfigurationAsString = "{\"region\": \"" + region.toString() + "\", \"sts_role_arn\": \"arn:aws:iam::123456789012:role/test-role\"}";

final AwsStsConfiguration objectUnderTest = OBJECT_MAPPER.readValue(defaultConfigurationAsString, AwsStsConfiguration.class);

assertThat(objectUnderTest, notNullValue());
assertThat(objectUnderTest.getAwsStsRoleArn(), equalTo("arn:aws:iam::123456789012:role/test-role"));
assertThat(objectUnderTest.getAwsRegion(), equalTo(region));
}

private static List<Region> getRegions() {
return Region.regions();
}
}
Loading

0 comments on commit 81e4058

Please sign in to comment.