From 6c12571743beff6f496a1ae928b3d2ab464d5992 Mon Sep 17 00:00:00 2001 From: Raghuvansh Raj Date: Wed, 24 May 2023 20:18:20 +0530 Subject: [PATCH] Add S3 async upload utilities and models Signed-off-by: Raghuvansh Raj --- .../s3/AmazonAsyncS3Reference.java | 37 ++ .../s3/AmazonAsyncS3WithCredentials.java | 52 ++ .../repositories/s3/S3AsyncService.java | 476 ++++++++++++++++++ .../repositories/s3/S3ClientSettings.java | 123 +++++ .../repositories/s3/SocketAccess.java | 2 +- .../s3/async/AsyncExecutorBuilder.java | 46 ++ .../s3/async/AsyncUploadUtils.java | 398 +++++++++++++++ .../s3/async/OpenSearchThreadFactory.java | 37 ++ .../s3/async/TransferNIOGroup.java | 61 +++ .../repositories/s3/async/UploadRequest.java | 82 +++ .../plugin-metadata/plugin-security.policy | 7 + .../repositories/s3/S3AsyncServiceTests.java | 95 ++++ 12 files changed, 1415 insertions(+), 1 deletion(-) create mode 100644 plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/AmazonAsyncS3Reference.java create mode 100644 plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/AmazonAsyncS3WithCredentials.java create mode 100644 plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3AsyncService.java create mode 100644 plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/AsyncExecutorBuilder.java create mode 100644 plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/AsyncUploadUtils.java create mode 100644 plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/OpenSearchThreadFactory.java create mode 100644 plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/TransferNIOGroup.java create mode 100644 plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/UploadRequest.java create mode 100644 plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3AsyncServiceTests.java diff --git 
a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/AmazonAsyncS3Reference.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/AmazonAsyncS3Reference.java new file mode 100644 index 0000000000000..c1716d2a20d18 --- /dev/null +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/AmazonAsyncS3Reference.java @@ -0,0 +1,37 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.repositories.s3; + +import org.opensearch.common.concurrent.RefCountedReleasable; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; + +import java.io.Closeable; +import java.io.IOException; + +/** + * Handles the shutdown of the wrapped normal and priority {@link software.amazon.awssdk.services.s3.S3AsyncClient}s using reference + * counting; the credentials provider is closed as well when it implements {@link Closeable}. + */ +public class AmazonAsyncS3Reference extends RefCountedReleasable { + + AmazonAsyncS3Reference(AmazonAsyncS3WithCredentials client) { + super("AWS_S3_CLIENT", client, () -> { + client.client().close(); + client.priorityClient().close(); + AwsCredentialsProvider credentials = client.credentials(); + if (credentials instanceof Closeable) { + try { + ((Closeable) credentials).close(); + } catch (IOException e) { + /* Best effort: a failure to close the credentials provider is deliberately ignored */ + } + } + }); + } +} diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/AmazonAsyncS3WithCredentials.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/AmazonAsyncS3WithCredentials.java new file mode 100644 index 0000000000000..15f104f51a067 --- /dev/null +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/AmazonAsyncS3WithCredentials.java @@ -0,0 +1,52 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0
license or a + * compatible open source license. + */ + +package org.opensearch.repositories.s3; + +import org.opensearch.common.Nullable; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.services.s3.S3AsyncClient; + +/** + * Holder of the normal and priority {@link S3AsyncClient} instances and the (nullable) {@link AwsCredentialsProvider} + */ +final class AmazonAsyncS3WithCredentials { + private final S3AsyncClient client; + private final S3AsyncClient priorityClient; + private final AwsCredentialsProvider credentials; + + private AmazonAsyncS3WithCredentials( + final S3AsyncClient client, + final S3AsyncClient priorityClient, + @Nullable final AwsCredentialsProvider credentials + ) { + this.client = client; + this.credentials = credentials; + this.priorityClient = priorityClient; + } + + S3AsyncClient client() { + return client; + } + + S3AsyncClient priorityClient() { + return priorityClient; + } + + AwsCredentialsProvider credentials() { + return credentials; + } + + static AmazonAsyncS3WithCredentials create( + final S3AsyncClient client, + final S3AsyncClient priorityClient, + @Nullable final AwsCredentialsProvider credentials + ) { + return new AmazonAsyncS3WithCredentials(client, priorityClient, credentials); + } +} diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3AsyncService.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3AsyncService.java new file mode 100644 index 0000000000000..1ac45d9b8f175 --- /dev/null +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3AsyncService.java @@ -0,0 +1,476 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license.
+ */ + +package org.opensearch.repositories.s3; + +import com.amazonaws.auth.AWSSessionCredentials; +import com.amazonaws.auth.AWSSessionCredentialsProvider; +import com.amazonaws.auth.AWSStaticCredentialsProvider; +import com.amazonaws.auth.STSAssumeRoleSessionCredentialsProvider; +import com.amazonaws.auth.STSAssumeRoleWithWebIdentitySessionCredentialsProvider; +import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration; +import com.amazonaws.http.IdleConnectionReaper; +import com.amazonaws.services.s3.internal.Constants; +import com.amazonaws.services.securitytoken.AWSSecurityTokenService; +import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.cluster.metadata.RepositoryMetadata; +import org.opensearch.common.Nullable; +import org.opensearch.common.SuppressForbidden; +import org.opensearch.common.collect.MapBuilder; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.Strings; +import org.opensearch.repositories.s3.S3ClientSettings.IrsaCredentials; +import org.opensearch.repositories.s3.async.AsyncExecutorBuilder; +import org.opensearch.repositories.s3.async.TransferNIOGroup; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.AwsCredentials; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.auth.credentials.AwsSessionCredentials; +import software.amazon.awssdk.auth.credentials.ContainerCredentialsProvider; +import software.amazon.awssdk.auth.credentials.InstanceProfileCredentialsProvider; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.core.SdkSystemSetting; +import software.amazon.awssdk.core.client.config.ClientAsyncConfiguration; +import 
software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; +import software.amazon.awssdk.core.client.config.SdkAdvancedAsyncClientOption; +import software.amazon.awssdk.core.retry.RetryPolicy; +import software.amazon.awssdk.core.retry.backoff.BackoffStrategy; +import software.amazon.awssdk.http.async.SdkAsyncHttpClient; +import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; +import software.amazon.awssdk.http.nio.netty.ProxyConfiguration; +import software.amazon.awssdk.http.nio.netty.SdkEventLoopGroup; +import software.amazon.awssdk.profiles.ProfileFileSystemSetting; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.S3AsyncClientBuilder; + +import java.io.Closeable; +import java.io.IOException; +import java.net.URI; +import java.nio.file.Path; +import java.time.Duration; +import java.util.Map; + +import static com.amazonaws.SDKGlobalConfiguration.AWS_ROLE_ARN_ENV_VAR; +import static com.amazonaws.SDKGlobalConfiguration.AWS_ROLE_SESSION_NAME_ENV_VAR; +import static com.amazonaws.SDKGlobalConfiguration.AWS_WEB_IDENTITY_ENV_VAR; +import static java.util.Collections.emptyMap; + +class S3AsyncService implements Closeable { + private static final Logger logger = LogManager.getLogger(S3AsyncService.class); + + private static final String STS_ENDPOINT_OVERRIDE_SYSTEM_PROPERTY = "com.amazonaws.sdk.stsEndpointOverride"; + + private volatile Map clientsCache = emptyMap(); + + /** + * Client settings calculated from static configuration and settings in the keystore. + */ + private volatile Map staticClientSettings; + + /** + * Client settings derived from those in {@link #staticClientSettings} by combining them with settings + * in the {@link RepositoryMetadata}. 
+ */ + private volatile Map derivedClientSettings = emptyMap(); + + S3AsyncService(final Path configPath) { + staticClientSettings = MapBuilder.newMapBuilder() + .put("default", S3ClientSettings.getClientSettings(Settings.EMPTY, "default", configPath)) + .immutableMap(); + } + + /** + * Refreshes the settings for the S3 async clients and clears the cache of + * existing clients. New clients will be built using these new settings. Old + * clients are usable until released. On release they will be destroyed instead + * of being returned to the cache. + */ + public synchronized void refreshAndClearCache(Map clientsSettings) { + // shut down all unused clients + // others will shut down on their respective release + releaseCachedClients(); + this.staticClientSettings = MapBuilder.newMapBuilder(clientsSettings).immutableMap(); + derivedClientSettings = emptyMap(); + assert this.staticClientSettings.containsKey("default") : "always at least have 'default'"; + // clients are built lazily by {@link client} + } + + /** + * Attempts to retrieve a client by its repository metadata and settings from the cache. + * If the client does not exist it will be created.
+ */ + public AmazonAsyncS3Reference client( + RepositoryMetadata repositoryMetadata, + AsyncExecutorBuilder priorityExecutorBuilder, + AsyncExecutorBuilder normalExecutorBuilder + ) { + final S3ClientSettings clientSettings = settings(repositoryMetadata); + { /* first attempt: read the cache without taking the lock */ + final AmazonAsyncS3Reference clientReference = clientsCache.get(clientSettings); + if (clientReference != null && clientReference.tryIncRef()) { + return clientReference; + } + } + synchronized (this) { /* second attempt under the lock; a new client is built only if none is cached or the cached one can no longer be acquired */ + final AmazonAsyncS3Reference existing = clientsCache.get(clientSettings); + if (existing != null && existing.tryIncRef()) { + return existing; + } + final AmazonAsyncS3Reference clientReference = new AmazonAsyncS3Reference( + buildClient(clientSettings, priorityExecutorBuilder, normalExecutorBuilder) + ); + clientReference.incRef(); /* NOTE(review): presumably the reference handed to the caller, on top of the one kept for the cache - confirm against RefCountedReleasable */ + clientsCache = MapBuilder.newMapBuilder(clientsCache).put(clientSettings, clientReference).immutableMap(); /* replace the cache with a new immutable map rather than mutating the existing one */ + return clientReference; + } + } + + /** + * Either fetches {@link S3ClientSettings} for a given {@link RepositoryMetadata} from cached settings or creates them + * by overriding static client settings from {@link #staticClientSettings} with settings found in the repository metadata.
+ * @param repositoryMetadata Repository Metadata + * @return S3ClientSettings + */ + S3ClientSettings settings(RepositoryMetadata repositoryMetadata) { + final Settings settings = repositoryMetadata.settings(); + { + final S3ClientSettings existing = derivedClientSettings.get(settings); + if (existing != null) { + return existing; + } + } + final String clientName = S3Repository.CLIENT_NAME.get(settings); + final S3ClientSettings staticSettings = staticClientSettings.get(clientName); + if (staticSettings != null) { + synchronized (this) { + final S3ClientSettings existing = derivedClientSettings.get(settings); + if (existing != null) { + return existing; + } + final S3ClientSettings newSettings = staticSettings.refine(settings); + derivedClientSettings = MapBuilder.newMapBuilder(derivedClientSettings).put(settings, newSettings).immutableMap(); + return newSettings; + } + } + throw new IllegalArgumentException( + "Unknown s3 client name [" + + clientName + + "]. Existing client configs: " + + Strings.collectionToDelimitedString(staticClientSettings.keySet(), ",") + ); + } + + // proxy for testing + synchronized AmazonAsyncS3WithCredentials buildClient( + final S3ClientSettings clientSettings, + AsyncExecutorBuilder priorityExecutorBuilder, + AsyncExecutorBuilder normalExecutorBuilder + ) { + setDefaultAwsProfilePath(); + final S3AsyncClientBuilder builder = S3AsyncClient.builder(); + builder.overrideConfiguration(buildOverrideConfiguration(clientSettings)); + final AwsCredentialsProvider credentials = buildCredentials(logger, clientSettings); + builder.credentialsProvider(credentials); + + String endpoint = Strings.hasLength(clientSettings.endpoint) ? 
clientSettings.endpoint : Constants.S3_HOSTNAME; + if ((endpoint.startsWith("http://") || endpoint.startsWith("https://")) == false) { + // Manually add the schema to the endpoint to work around https://github.com/aws/aws-sdk-java/issues/2274 + endpoint = clientSettings.protocol.toString() + "://" + endpoint; + } + final String region = Strings.hasLength(clientSettings.region) ? clientSettings.region : null; + builder.region(region != null ? Region.of(region) : Region.US_EAST_1); + logger.debug("using endpoint [{}] and region [{}]", endpoint, region); + + // If the endpoint configuration isn't set on the builder then the default behaviour is to try + // and work out what region we are in and use an appropriate endpoint - see AwsClientBuilder#setRegion. + // In contrast, directly-constructed clients use s3.amazonaws.com unless otherwise instructed. We currently + // use a directly-constructed client, and need to keep the existing behaviour to avoid a breaking change, + // so to move to using the builder we must set it explicitly to keep the existing behaviour. + // + // We do this because directly constructing the client is deprecated (was already deprecated in 1.1.223 too) + // so this change removes that usage of a deprecated API. 
+ builder.endpointOverride(URI.create(endpoint)); + if (clientSettings.pathStyleAccess) { + builder.forcePathStyle(true); + } + + builder.httpClient(buildHttpClient(clientSettings, priorityExecutorBuilder.getTransferNIOGroup())); + builder.asyncConfiguration( + ClientAsyncConfiguration.builder() + .advancedOption( + SdkAdvancedAsyncClientOption.FUTURE_COMPLETION_EXECUTOR, + priorityExecutorBuilder.getFutureCompletionExecutor() + ) + .build() + ); + final S3AsyncClient priorityClient = SocketAccess.doPrivileged(builder::build); + + builder.httpClient(buildHttpClient(clientSettings, normalExecutorBuilder.getTransferNIOGroup())); + builder.asyncConfiguration( + ClientAsyncConfiguration.builder() + .advancedOption( + SdkAdvancedAsyncClientOption.FUTURE_COMPLETION_EXECUTOR, + normalExecutorBuilder.getFutureCompletionExecutor() + ) + .build() + ); + final S3AsyncClient client = SocketAccess.doPrivileged(builder::build); + + return AmazonAsyncS3WithCredentials.create(client, priorityClient, credentials); + } + + static ClientOverrideConfiguration buildOverrideConfiguration(final S3ClientSettings clientSettings) { + return ClientOverrideConfiguration.builder() + .retryPolicy( + RetryPolicy.builder() + .numRetries(clientSettings.maxRetries) + .throttlingBackoffStrategy( + clientSettings.throttleRetries ? BackoffStrategy.defaultThrottlingStrategy() : BackoffStrategy.none() + ) + .build() + ) + .apiCallAttemptTimeout(Duration.ofMillis(clientSettings.requestTimeoutMillis)) + .build(); + } + + // pkg private for tests + static SdkAsyncHttpClient buildHttpClient(S3ClientSettings clientSettings, TransferNIOGroup transferNIOGroup) { + // the response metadata cache is only there for diagnostics purposes, + // but can force objects from every response to the old generation. 
+ NettyNioAsyncHttpClient.Builder clientBuilder = NettyNioAsyncHttpClient.builder(); + + if (clientSettings.proxySettings.getType() != ProxySettings.ProxyType.DIRECT) { + ProxyConfiguration.Builder proxyConfiguration = ProxyConfiguration.builder(); + proxyConfiguration.scheme(clientSettings.proxySettings.getType().toProtocol().toString()); + proxyConfiguration.host(clientSettings.proxySettings.getHostName()); + proxyConfiguration.port(clientSettings.proxySettings.getPort()); + proxyConfiguration.username(clientSettings.proxySettings.getUsername()); + proxyConfiguration.password(clientSettings.proxySettings.getPassword()); + clientBuilder.proxyConfiguration(proxyConfiguration.build()); + } + + // TODO: add max retry and UseThrottleRetry. Replace values with settings and put these in default settings + clientBuilder.connectionTimeout(Duration.ofMillis(clientSettings.connectionTimeoutMillis)); + clientBuilder.connectionAcquisitionTimeout(Duration.ofMillis(clientSettings.connectionAcquisitionTimeoutMillis)); + clientBuilder.maxPendingConnectionAcquires(clientSettings.maxPendingConnectionAcquires); + clientBuilder.maxConcurrency(clientSettings.maxConnections); + clientBuilder.eventLoopGroup(SdkEventLoopGroup.create(transferNIOGroup.getEventLoopGroup())); + clientBuilder.tcpKeepAlive(true); + + return clientBuilder.build(); + } + + // pkg private for tests + static AwsCredentialsProvider buildCredentials(Logger logger, S3ClientSettings clientSettings) { + final S3BasicCredentials basicCredentials = clientSettings.credentials; + final IrsaCredentials irsaCredentials = buildFromEnvironment(clientSettings.irsaCredentials); + + // If IAM Roles for Service Accounts (IRSA) credentials are configured, start with them first + if (irsaCredentials != null) { + logger.debug("Using IRSA credentials"); + AWSSecurityTokenService securityTokenService = null; + final String region = Strings.hasLength(clientSettings.region) ? 
clientSettings.region : null; + + if (region != null || basicCredentials != null) { + securityTokenService = SocketAccess.doPrivileged(() -> { + AWSSecurityTokenServiceClientBuilder builder = AWSSecurityTokenServiceClientBuilder.standard(); + + // Use similar approach to override STS endpoint as SDKGlobalConfiguration.EC2_METADATA_SERVICE_OVERRIDE_SYSTEM_PROPERTY + final String stsEndpoint = System.getProperty(STS_ENDPOINT_OVERRIDE_SYSTEM_PROPERTY); + if (region != null && stsEndpoint != null) { + builder = builder.withEndpointConfiguration(new EndpointConfiguration(stsEndpoint, region)); + } else { + builder = builder.withRegion(region); + } + + if (basicCredentials != null) { + builder = builder.withCredentials(new AWSStaticCredentialsProvider(basicCredentials)); + } + + return builder.build(); + }); + } + + if (irsaCredentials.getIdentityTokenFile() == null) { + final STSAssumeRoleSessionCredentialsProvider.Builder stsCredentialsProviderBuilder = + new STSAssumeRoleSessionCredentialsProvider.Builder(irsaCredentials.getRoleArn(), irsaCredentials.getRoleSessionName()) + .withStsClient(securityTokenService); + + final STSAssumeRoleSessionCredentialsProvider stsCredentialsProvider = SocketAccess.doPrivileged( + stsCredentialsProviderBuilder::build + ); + + return new PrivilegedSTSAssumeRoleSessionCredentialsProvider<>( + securityTokenService, + new SessionsCredsWrapper<>(stsCredentialsProvider) + ); + } else { + final STSAssumeRoleWithWebIdentitySessionCredentialsProvider.Builder stsCredentialsProviderBuilder = + new STSAssumeRoleWithWebIdentitySessionCredentialsProvider.Builder( + irsaCredentials.getRoleArn(), + irsaCredentials.getRoleSessionName(), + irsaCredentials.getIdentityTokenFile() + ).withStsClient(securityTokenService); + + final STSAssumeRoleWithWebIdentitySessionCredentialsProvider stsCredentialsProvider = SocketAccess.doPrivileged( + stsCredentialsProviderBuilder::build + ); + + return new PrivilegedSTSAssumeRoleSessionCredentialsProvider<>( + 
securityTokenService, + new SessionsCredsWrapper<>(stsCredentialsProvider) + ); + } + } else if (basicCredentials != null) { + logger.debug("Using basic key/secret credentials"); + return StaticCredentialsProvider.create( + AwsBasicCredentials.create(basicCredentials.getAWSAccessKeyId(), basicCredentials.getAWSSecretKey()) + ); + } else { + logger.debug("Using instance profile credentials"); + return new PrivilegedInstanceProfileCredentialsProvider(); + } + } + + // The AWS v2 SDK tries to load a default profile from the home path, which is restricted. Hence, when they are unset, both + // profile file settings are pointed at the value of the opensearch.path.conf system property. + @SuppressForbidden(reason = "Need to provide this override to v2 SDK so that path does not default to home path") + private static void setDefaultAwsProfilePath() { + if (ProfileFileSystemSetting.AWS_SHARED_CREDENTIALS_FILE.getStringValue().isEmpty()) { + System.setProperty(ProfileFileSystemSetting.AWS_SHARED_CREDENTIALS_FILE.property(), System.getProperty("opensearch.path.conf")); + } + if (ProfileFileSystemSetting.AWS_CONFIG_FILE.getStringValue().isEmpty()) { + System.setProperty(ProfileFileSystemSetting.AWS_CONFIG_FILE.property(), System.getProperty("opensearch.path.conf")); + } + } + + static class SessionsCredsWrapper

implements AwsCredentialsProvider, Closeable { + + private final P sessionCredentialsProvider; + + public SessionsCredsWrapper(P sessionCredentialsProvider) { + this.sessionCredentialsProvider = sessionCredentialsProvider; + } + + @Override + public AwsCredentials resolveCredentials() { + AWSSessionCredentials sessionCredentials = sessionCredentialsProvider.getCredentials(); + return AwsSessionCredentials.create( + sessionCredentials.getAWSAccessKeyId(), + sessionCredentials.getAWSSecretKey(), + sessionCredentials.getSessionToken() + ); + } + + @Override + public void close() throws IOException { + sessionCredentialsProvider.close(); + } + } + + private static IrsaCredentials buildFromEnvironment(IrsaCredentials defaults) { /* returns null when no defaults are given; otherwise each field that is null in the defaults is read from its AWS environment variable */ + if (defaults == null) { + return null; + } + + String webIdentityTokenFile = defaults.getIdentityTokenFile(); + if (webIdentityTokenFile == null) { + webIdentityTokenFile = System.getenv(AWS_WEB_IDENTITY_ENV_VAR); + } + + String roleArn = defaults.getRoleArn(); + if (roleArn == null) { + roleArn = System.getenv(AWS_ROLE_ARN_ENV_VAR); + } + + String roleSessionName = defaults.getRoleSessionName(); + if (roleSessionName == null) { + roleSessionName = System.getenv(AWS_ROLE_SESSION_NAME_ENV_VAR); + } + + return new IrsaCredentials(webIdentityTokenFile, roleArn, roleSessionName); + } + + private synchronized void releaseCachedClients() { + // the clients will shut down once they are no longer used + for (final AmazonAsyncS3Reference clientReference : clientsCache.values()) { + clientReference.decRef(); + } + + // clear previously cached clients, they will be built lazily + clientsCache = emptyMap(); + derivedClientSettings = emptyMap(); + + // shut down the IdleConnectionReaper background thread + // it will be restarted on new client usage + IdleConnectionReaper.shutdown(); + } + + static class PrivilegedInstanceProfileCredentialsProvider implements AwsCredentialsProvider { + private final AwsCredentialsProvider credentials; + + private
PrivilegedInstanceProfileCredentialsProvider() { + this.credentials = initializeProvider(); + } + + private AwsCredentialsProvider initializeProvider() { + if (SdkSystemSetting.AWS_CONTAINER_CREDENTIALS_RELATIVE_URI.getStringValue().isPresent() + || SdkSystemSetting.AWS_CONTAINER_CREDENTIALS_FULL_URI.getStringValue().isPresent()) { + + return ContainerCredentialsProvider.builder().asyncCredentialUpdateEnabled(true).build(); + } + // InstanceProfileCredentialsProvider as last item of chain + return InstanceProfileCredentialsProvider.builder().asyncCredentialUpdateEnabled(true).build(); + } + + @Override + public AwsCredentials resolveCredentials() { + return SocketAccess.doPrivileged(credentials::resolveCredentials); + } + } + + static class PrivilegedSTSAssumeRoleSessionCredentialsProvider

+ implements + AwsCredentialsProvider, + Closeable { + private final P credentials; + private final AWSSecurityTokenService securityTokenService; + + private PrivilegedSTSAssumeRoleSessionCredentialsProvider( + @Nullable final AWSSecurityTokenService securityTokenService, + final P credentials + ) { + this.securityTokenService = securityTokenService; + this.credentials = credentials; + } + + @Override + public AwsCredentials resolveCredentials() { + return SocketAccess.doPrivileged(credentials::resolveCredentials); + } + + @Override + public void close() throws IOException { + SocketAccess.doPrivilegedIOException(() -> { + credentials.close(); + if (securityTokenService != null) { + securityTokenService.shutdown(); + } + return null; + }); + } + } + + @Override + public void close() { + releaseCachedClients(); + } +} diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3ClientSettings.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3ClientSettings.java index ba7535dd78f68..fbdbe77ba29ed 100644 --- a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3ClientSettings.java +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3ClientSettings.java @@ -170,6 +170,48 @@ final class S3ClientSettings { key -> Setting.timeSetting(key, TimeValue.timeValueMillis(ClientConfiguration.DEFAULT_SOCKET_TIMEOUT), Property.NodeScope) ); + /** The request timeout for connecting to s3. */ + static final Setting.AffixSetting REQUEST_TIMEOUT_SETTING = Setting.affixKeySetting( + PREFIX, + "request_timeout", + key -> Setting.timeSetting(key, TimeValue.timeValueMinutes(2), Property.NodeScope) + ); + + /** The connection timeout for connecting to s3. 
*/ + static final Setting.AffixSetting CONNECTION_TIMEOUT_SETTING = Setting.affixKeySetting( + PREFIX, + "connection_timeout", + key -> Setting.timeSetting(key, TimeValue.timeValueSeconds(10), Property.NodeScope) + ); + + /** The connection TTL for connecting to s3. */ + static final Setting.AffixSetting CONNECTION_TTL_SETTING = Setting.affixKeySetting( + PREFIX, + "connection_ttl", + key -> Setting.timeSetting(key, TimeValue.timeValueMillis(5000), Property.NodeScope) + ); + + /** The maximum connections to s3. */ + static final Setting.AffixSetting MAX_CONNECTIONS_SETTING = Setting.affixKeySetting( + PREFIX, + "max_connections", + key -> Setting.intSetting(key, 100, Property.NodeScope) + ); + + /** Connection acquisition timeout for new connections to S3. */ + static final Setting.AffixSetting CONNECTION_ACQUISITION_TIMEOUT = Setting.affixKeySetting( + PREFIX, + "connection_acquisition_timeout", + key -> Setting.timeSetting(key, TimeValue.timeValueMinutes(2), Property.NodeScope) + ); + + /** The maximum pending connections to S3. */ + static final Setting.AffixSetting MAX_PENDING_CONNECTION_ACQUIRES = Setting.affixKeySetting( + PREFIX, + "max_pending_connection_acquires", + key -> Setting.intSetting(key, 10_000, Property.NodeScope) + ); + /** The number of retries to use when an s3 request fails. */ static final Setting.AffixSetting MAX_RETRIES_SETTING = Setting.affixKeySetting( PREFIX, @@ -230,6 +272,24 @@ final class S3ClientSettings { /** The read timeout for the s3 client. 
*/ final int readTimeoutMillis; + /** The request timeout for the s3 client */ + final int requestTimeoutMillis; + + /** The connection timeout for the s3 client */ + final int connectionTimeoutMillis; + + /** The connection TTL for the s3 client */ + final int connectionTTLMillis; + + /** The max number of connections for the s3 client */ + final int maxConnections; + + /** The connection acquisition timeout for the s3 async client */ + final int connectionAcquisitionTimeoutMillis; + + /** The max number of requests pending to acquire connection for the s3 async client */ + final int maxPendingConnectionAcquires; + /** The number of retries to use for the s3 client. */ final int maxRetries; @@ -254,6 +314,12 @@ private S3ClientSettings( String endpoint, Protocol protocol, int readTimeoutMillis, + int requestTimeoutMillis, + int connectionTimeoutMillis, + int connectionTTLMillis, + int maxConnections, + int connectionAcquisitionTimeoutMillis, + int maxPendingConnectionAcquires, int maxRetries, boolean throttleRetries, boolean pathStyleAccess, @@ -267,6 +333,12 @@ private S3ClientSettings( this.endpoint = endpoint; this.protocol = protocol; this.readTimeoutMillis = readTimeoutMillis; + this.requestTimeoutMillis = requestTimeoutMillis; + this.connectionTimeoutMillis = connectionTimeoutMillis; + this.connectionTTLMillis = connectionTTLMillis; + this.maxConnections = maxConnections; + this.connectionAcquisitionTimeoutMillis = connectionAcquisitionTimeoutMillis; + this.maxPendingConnectionAcquires = maxPendingConnectionAcquires; this.maxRetries = maxRetries; this.throttleRetries = throttleRetries; this.pathStyleAccess = pathStyleAccess; @@ -298,6 +370,27 @@ S3ClientSettings refine(Settings repositorySettings) { final int newReadTimeoutMillis = Math.toIntExact( getRepoSettingOrDefault(READ_TIMEOUT_SETTING, normalizedSettings, TimeValue.timeValueMillis(readTimeoutMillis)).millis() ); + final int newRequestTimeoutMillis = Math.toIntExact( +
getRepoSettingOrDefault(REQUEST_TIMEOUT_SETTING, normalizedSettings, TimeValue.timeValueMillis(requestTimeoutMillis)).millis() + ); + final int newConnectionTimeoutMillis = Math.toIntExact( + getRepoSettingOrDefault(CONNECTION_TIMEOUT_SETTING, normalizedSettings, TimeValue.timeValueMillis(connectionTimeoutMillis)) + .millis() + ); + final int newConnectionTTLMillis = Math.toIntExact( + getRepoSettingOrDefault(CONNECTION_TTL_SETTING, normalizedSettings, TimeValue.timeValueMillis(connectionTTLMillis)).millis() + ); + final int newConnectionAcquisitionTimeoutMillis = Math.toIntExact( + getRepoSettingOrDefault( + CONNECTION_ACQUISITION_TIMEOUT, + normalizedSettings, + TimeValue.timeValueMillis(connectionAcquisitionTimeoutMillis) + ).millis() + ); + final int newMaxConnections = Math.toIntExact(getRepoSettingOrDefault(MAX_CONNECTIONS_SETTING, normalizedSettings, maxConnections)); + final int newMaxPendingConnectionAcquires = Math.toIntExact( + getRepoSettingOrDefault(MAX_PENDING_CONNECTION_ACQUIRES, normalizedSettings, maxPendingConnectionAcquires) + ); final int newMaxRetries = getRepoSettingOrDefault(MAX_RETRIES_SETTING, normalizedSettings, maxRetries); final boolean newThrottleRetries = getRepoSettingOrDefault(USE_THROTTLE_RETRIES_SETTING, normalizedSettings, throttleRetries); final boolean newPathStyleAccess = getRepoSettingOrDefault(USE_PATH_STYLE_ACCESS, normalizedSettings, pathStyleAccess); @@ -319,6 +412,12 @@ S3ClientSettings refine(Settings repositorySettings) { && Objects.equals(proxySettings.getHostName(), newProxyHost) && proxySettings.getPort() == newProxyPort && newReadTimeoutMillis == readTimeoutMillis + && newRequestTimeoutMillis == requestTimeoutMillis + && newConnectionTimeoutMillis == connectionTimeoutMillis + && newConnectionTTLMillis == connectionTTLMillis + && newMaxConnections == maxConnections + && newConnectionAcquisitionTimeoutMillis == connectionAcquisitionTimeoutMillis + && newMaxPendingConnectionAcquires == maxPendingConnectionAcquires && 
maxRetries == newMaxRetries && newThrottleRetries == throttleRetries && Objects.equals(credentials, newCredentials) @@ -336,6 +435,12 @@ S3ClientSettings refine(Settings repositorySettings) { newEndpoint, newProtocol, newReadTimeoutMillis, + newRequestTimeoutMillis, + newConnectionTimeoutMillis, + newConnectionTTLMillis, + newMaxConnections, + newConnectionAcquisitionTimeoutMillis, + newMaxPendingConnectionAcquires, newMaxRetries, newThrottleRetries, newPathStyleAccess, @@ -461,6 +566,12 @@ static S3ClientSettings getClientSettings(final Settings settings, final String getConfigValue(settings, clientName, ENDPOINT_SETTING), awsProtocol, Math.toIntExact(getConfigValue(settings, clientName, READ_TIMEOUT_SETTING).millis()), + Math.toIntExact(getConfigValue(settings, clientName, REQUEST_TIMEOUT_SETTING).millis()), + Math.toIntExact(getConfigValue(settings, clientName, CONNECTION_TIMEOUT_SETTING).millis()), + Math.toIntExact(getConfigValue(settings, clientName, CONNECTION_TTL_SETTING).millis()), + Math.toIntExact(getConfigValue(settings, clientName, MAX_CONNECTIONS_SETTING)), + Math.toIntExact(getConfigValue(settings, clientName, CONNECTION_ACQUISITION_TIMEOUT).millis()), + Math.toIntExact(getConfigValue(settings, clientName, MAX_PENDING_CONNECTION_ACQUIRES)), getConfigValue(settings, clientName, MAX_RETRIES_SETTING), getConfigValue(settings, clientName, USE_THROTTLE_RETRIES_SETTING), getConfigValue(settings, clientName, USE_PATH_STYLE_ACCESS), @@ -530,6 +641,12 @@ public boolean equals(final Object o) { } final S3ClientSettings that = (S3ClientSettings) o; return readTimeoutMillis == that.readTimeoutMillis + && requestTimeoutMillis == that.requestTimeoutMillis + && connectionTimeoutMillis == that.connectionTimeoutMillis + && connectionTTLMillis == that.connectionTTLMillis + && maxConnections == that.maxConnections + && connectionAcquisitionTimeoutMillis == that.connectionAcquisitionTimeoutMillis + && maxPendingConnectionAcquires == that.maxPendingConnectionAcquires && 
maxRetries == that.maxRetries && throttleRetries == that.throttleRetries && Objects.equals(credentials, that.credentials) @@ -550,6 +667,12 @@ public int hashCode() { protocol, proxySettings, readTimeoutMillis, + requestTimeoutMillis, + connectionTimeoutMillis, + connectionTTLMillis, + maxConnections, + connectionAcquisitionTimeoutMillis, + maxPendingConnectionAcquires, maxRetries, throttleRetries, disableChunkedEncoding, diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/SocketAccess.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/SocketAccess.java index 0a6408764aeeb..4888764dbc720 100644 --- a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/SocketAccess.java +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/SocketAccess.java @@ -46,7 +46,7 @@ * {@link SocketPermission} 'connect' to establish connections. This class wraps the operations requiring access in * {@link AccessController#doPrivileged(PrivilegedAction)} blocks. */ -final class SocketAccess { +public final class SocketAccess { private SocketAccess() {} diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/AsyncExecutorBuilder.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/AsyncExecutorBuilder.java new file mode 100644 index 0000000000000..0e2989ab6e747 --- /dev/null +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/AsyncExecutorBuilder.java @@ -0,0 +1,46 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */
+
+package org.opensearch.repositories.s3.async;
+
+import java.util.concurrent.ExecutorService;
+
+/**
+ * Immutable holder bundling the {@link TransferNIOGroup} with the stream reader and future completion executor services
+ */
+public class AsyncExecutorBuilder {
+
+    private final TransferNIOGroup transferNIOGroup;
+    private final ExecutorService streamReader;
+    private final ExecutorService futureCompletionExecutor;
+
+    /**
+     * Create a holder for the executors and the event loop group used by async uploads
+     *
+     * @param futureCompletionExecutor The {@link ExecutorService} passed to {@link software.amazon.awssdk.services.s3.S3AsyncClient} for future completion
+     * @param streamReader The {@link ExecutorService} on which upload streams are read
+     * @param transferNIOGroup The {@link TransferNIOGroup} wrapping the netty {@link io.netty.channel.EventLoopGroup} for async uploads
+     */
+    public AsyncExecutorBuilder(ExecutorService futureCompletionExecutor, ExecutorService streamReader, TransferNIOGroup transferNIOGroup) {
+        this.futureCompletionExecutor = futureCompletionExecutor;
+        this.streamReader = streamReader;
+        this.transferNIOGroup = transferNIOGroup;
+    }
+
+    public TransferNIOGroup getTransferNIOGroup() {
+        return this.transferNIOGroup;
+    }
+
+    public ExecutorService getStreamReader() {
+        return this.streamReader;
+    }
+
+    public ExecutorService getFutureCompletionExecutor() {
+        return this.futureCompletionExecutor;
+    }
+}
diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/AsyncUploadUtils.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/AsyncUploadUtils.java
new file mode 100644
index 0000000000000..f368de08fb27b
--- /dev/null
+++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/AsyncUploadUtils.java
@@ -0,0 +1,398 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source
license. + */ + +package org.opensearch.repositories.s3.async; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.common.io.InputStreamContainer; +import org.opensearch.common.StreamContext; +import org.opensearch.common.blobstore.stream.write.WritePriority; +import org.opensearch.common.unit.ByteSizeUnit; +import org.opensearch.common.util.ByteUtils; +import org.opensearch.repositories.s3.SocketAccess; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.ChecksumAlgorithm; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.CompletedMultipartUpload; +import software.amazon.awssdk.services.s3.model.CompletedPart; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.UploadPartRequest; +import software.amazon.awssdk.services.s3.model.UploadPartResponse; +import software.amazon.awssdk.utils.CompletableFutureUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Base64; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicReferenceArray; 
+import java.util.function.BiFunction;
+import java.util.function.Supplier;
+import java.util.stream.IntStream;
+
+/**
+ * A helper class that uploads with a single request or a multipart upload, depending on the {@link StreamContext} part count
+ */
+public final class AsyncUploadUtils {
+    private static final Logger log = LogManager.getLogger(AsyncUploadUtils.class);
+    private final ExecutorService executorService;
+    private final ExecutorService priorityExecutorService;
+    private final long minimumPartSize;
+
+    /**
+     * The max number of parts on S3 side is 10,000
+     */
+    private static final long MAX_UPLOAD_PARTS = 10_000;
+
+    /**
+     * Construct a new object of AsyncUploadUtils
+     *
+     * @param minimumPartSize The minimum part size for parallel multipart uploads
+     * @param executorService The stream reader {@link ExecutorService} for normal priority uploads
+     * @param priorityExecutorService The stream reader {@link ExecutorService} for high priority uploads
+     */
+    public AsyncUploadUtils(long minimumPartSize, ExecutorService executorService, ExecutorService priorityExecutorService) {
+        this.executorService = executorService;
+        this.priorityExecutorService = priorityExecutorService;
+        this.minimumPartSize = minimumPartSize;
+    }
+
+    /**
+     * Upload an object to S3 using the async client; one part is sent as a single request, more parts as a multipart upload
+     *
+     * @param s3AsyncClient The {@link S3AsyncClient} to use for uploads
+     * @param uploadRequest The {@link UploadRequest} object encapsulating all relevant details for upload
+     * @param streamContext The {@link StreamContext} to supply streams during upload
+     * @return A {@link CompletableFuture} to listen for upload completion; failures are reported through it, not thrown
+     */
+    public CompletableFuture<Void> uploadObject(S3AsyncClient s3AsyncClient, UploadRequest uploadRequest, StreamContext streamContext) {
+        // Synchronous failures below are reported through returnFuture instead of being thrown
+        CompletableFuture<Void> returnFuture = new CompletableFuture<>();
+        try {
+            if (streamContext.getNumberOfParts() == 1) {
+                log.debug(() -> "Starting the upload as a single upload part request");
+                uploadInOneChunk(s3AsyncClient, uploadRequest, streamContext.provideStream(0), returnFuture);
+            } else {
+                log.debug(() -> "Starting the upload as multipart upload request");
+                uploadInParts(s3AsyncClient, uploadRequest, streamContext, returnFuture);
+            }
+        } catch (Throwable throwable) {
+            returnFuture.completeExceptionally(throwable);
+        }
+
+        return returnFuture;
+    }
+
+    private void uploadInParts(
+        S3AsyncClient s3AsyncClient,
+        UploadRequest uploadRequest,
+        StreamContext streamContext,
+        CompletableFuture<Void> returnFuture
+    ) {
+        // Build the create-multipart request; the CRC32 checksum algorithm is set only when integrity checks are requested
+        CreateMultipartUploadRequest.Builder createMultipartUploadRequestBuilder = CreateMultipartUploadRequest.builder()
+            .bucket(uploadRequest.getBucket())
+            .key(uploadRequest.getKey());
+        if (uploadRequest.doRemoteDataIntegrityCheck()) {
+            createMultipartUploadRequestBuilder.checksumAlgorithm(ChecksumAlgorithm.CRC32);
+        }
+        CompletableFuture<CreateMultipartUploadResponse> createMultipartUploadFuture = SocketAccess.doPrivileged(
+            () -> s3AsyncClient.createMultipartUpload(createMultipartUploadRequestBuilder.build())
+        );
+
+        // Ensure cancellations are forwarded to the createMultipartUploadFuture future
+        CompletableFutureUtils.forwardExceptionTo(returnFuture, createMultipartUploadFuture);
+
+        createMultipartUploadFuture.whenComplete((createMultipartUploadResponse, throwable) -> {
+            if (throwable != null) {
+                handleException(returnFuture, () -> "Failed to initiate multipart upload", throwable);
+            } else {
+                log.debug(() -> "Initiated new multipart upload, uploadId: " + createMultipartUploadResponse.uploadId());
+                doUploadInParts(s3AsyncClient, uploadRequest, streamContext, returnFuture, createMultipartUploadResponse.uploadId());
+            }
+        });
+    }
+
+    private void doUploadInParts(
+        S3AsyncClient s3AsyncClient,
+        UploadRequest uploadRequest,
+        StreamContext streamContext,
+        CompletableFuture<Void> returnFuture,
+        String uploadId
+    ) {
+        // Slots are indexed by partNumber - 1, so parts that finish out of order still end up in order
+        // The list of completed parts must be sorted
+        AtomicReferenceArray<CompletedPart> completedParts = new AtomicReferenceArray<>(streamContext.getNumberOfParts());
+
+        List<CompletableFuture<CompletedPart>> futures;
+        try {
+            futures = sendUploadPartRequests(s3AsyncClient, uploadRequest, streamContext, uploadId, completedParts);
+        } catch (Exception ex) {
+            try {
+                cleanUpParts(s3AsyncClient, uploadRequest, uploadId);
+            } finally {
+                returnFuture.completeExceptionally(ex);
+            }
+            return;
+        }
+
+        CompletableFutureUtils.allOfExceptionForwarded(futures.toArray(CompletableFuture[]::new)).thenApply(resp -> {
+            uploadRequest.getUploadFinalizer().accept(true);
+            return resp;
+        })
+            .thenCompose(ignore -> completeMultipartUpload(s3AsyncClient, uploadRequest, uploadId, completedParts))
+            .handle(handleExceptionOrResponse(s3AsyncClient, uploadRequest, returnFuture, uploadId))
+            .exceptionally(throwable -> {
+                handleException(returnFuture, () -> "Unexpected exception occurred", throwable);
+                return null;
+            });
+    }
+
+    private BiFunction<CompleteMultipartUploadResponse, Throwable, Void> handleExceptionOrResponse(
+        S3AsyncClient s3AsyncClient,
+        UploadRequest uploadRequest,
+        CompletableFuture<Void> returnFuture,
+        String uploadId
+    ) {
+        // On failure: abort the multipart upload, then fail returnFuture; on success: complete returnFuture
+        return (response, throwable) -> {
+            if (throwable != null) {
+                cleanUpParts(s3AsyncClient, uploadRequest, uploadId);
+                handleException(returnFuture, () -> "Failed to send multipart upload requests.", throwable);
+            } else {
+                returnFuture.complete(null);
+            }
+
+            return null;
+        };
+    }
+
+    private CompletableFuture<CompleteMultipartUploadResponse> completeMultipartUpload(
+        S3AsyncClient s3AsyncClient,
+        UploadRequest uploadRequest,
+        String uploadId,
+        AtomicReferenceArray<CompletedPart> completedParts
+    ) {
+        // Parts are read from the array in index order, i.e. ascending part number
+        log.debug(() -> new ParameterizedMessage("Sending completeMultipartUploadRequest, uploadId: {}", uploadId));
+        CompletedPart[] parts = IntStream.range(0, completedParts.length()).mapToObj(completedParts::get).toArray(CompletedPart[]::new);
+        CompleteMultipartUploadRequest completeMultipartUploadRequest = CompleteMultipartUploadRequest.builder()
+            .bucket(uploadRequest.getBucket())
+            .key(uploadRequest.getKey())
+            .uploadId(uploadId)
+            .multipartUpload(CompletedMultipartUpload.builder().parts(parts).build())
+            .build();
+
+        return SocketAccess.doPrivileged(() -> s3AsyncClient.completeMultipartUpload(completeMultipartUploadRequest));
+    }
+
+    private void cleanUpParts(S3AsyncClient s3AsyncClient, UploadRequest uploadRequest, String uploadId) {
+        // Best effort: a failed abort is only logged, never propagated
+        AbortMultipartUploadRequest abortMultipartUploadRequest = AbortMultipartUploadRequest.builder()
+            .bucket(uploadRequest.getBucket())
+            .key(uploadRequest.getKey())
+            .uploadId(uploadId)
+            .build();
+        SocketAccess.doPrivileged(() -> s3AsyncClient.abortMultipartUpload(abortMultipartUploadRequest).exceptionally(throwable -> {
+            log.warn(
+                () -> new ParameterizedMessage(
+                    "Failed to abort previous multipart upload "
+                        + "(id: {})"
+                        + ". You may need to call "
+                        + "S3AsyncClient#abortMultiPartUpload to "
+                        + "free all storage consumed by"
+                        + " all parts. ",
+                    uploadId
+                ),
+                throwable
+            );
+            return null;
+        }));
+    }
+
+    private static void handleException(CompletableFuture<Void> returnFuture, Supplier<String> message, Throwable throwable) {
+        Throwable cause = throwable instanceof CompletionException ? throwable.getCause() : throwable;
+        // Errors are passed through untouched; everything else is wrapped with the supplied message
+        if (cause instanceof Error) {
+            returnFuture.completeExceptionally(cause);
+        } else {
+            SdkClientException exception = SdkClientException.create(message.get(), cause);
+            returnFuture.completeExceptionally(exception);
+        }
+    }
+
+    private List<CompletableFuture<CompletedPart>> sendUploadPartRequests(
+        S3AsyncClient s3AsyncClient,
+        UploadRequest uploadRequest,
+        StreamContext streamContext,
+        String uploadId,
+        AtomicReferenceArray<CompletedPart> completedParts
+    ) throws IOException {
+        List<CompletableFuture<CompletedPart>> futures = new ArrayList<>();
+        for (int partIdx = 0; partIdx < streamContext.getNumberOfParts(); partIdx++) {
+            InputStreamContainer inputStreamContainer = streamContext.provideStream(partIdx);
+            UploadPartRequest.Builder uploadPartRequestBuilder = UploadPartRequest.builder()
+                .bucket(uploadRequest.getBucket())
+                .partNumber(partIdx + 1)
+                .key(uploadRequest.getKey())
+                .uploadId(uploadId)
+                .contentLength(inputStreamContainer.getContentLength());
+            if (uploadRequest.doRemoteDataIntegrityCheck()) {
+                uploadPartRequestBuilder.checksumAlgorithm(ChecksumAlgorithm.CRC32);
+            }
+            sendIndividualUploadPart(
+                s3AsyncClient,
+                completedParts,
+                futures,
+                uploadPartRequestBuilder.build(),
+                inputStreamContainer,
+                uploadRequest
+            );
+        }
+
+        return futures;
+    }
+
+    private void sendIndividualUploadPart(
+        S3AsyncClient s3AsyncClient,
+        AtomicReferenceArray<CompletedPart> completedParts,
+        List<CompletableFuture<CompletedPart>> futures,
+        UploadPartRequest uploadPartRequest,
+        InputStreamContainer inputStreamContainer,
+        UploadRequest uploadRequest
+    ) {
+        Integer partNumber = uploadPartRequest.partNumber();
+        // High priority writes read their streams on the dedicated priority executor
+        ExecutorService streamReadExecutor = uploadRequest.getWritePriority() == WritePriority.HIGH
+            ? priorityExecutorService
+            : executorService;
+        CompletableFuture<UploadPartResponse> uploadPartResponseFuture = SocketAccess.doPrivileged(
+            () -> s3AsyncClient.uploadPart(
+                uploadPartRequest,
+                AsyncRequestBody.fromInputStream(
+                    inputStreamContainer.getInputStream(),
+                    inputStreamContainer.getContentLength(),
+                    streamReadExecutor
+                )
+            )
+        );
+
+        CompletableFuture<CompletedPart> convertFuture = uploadPartResponseFuture.thenApply(
+            uploadPartResponse -> convertUploadPartResponse(
+                completedParts,
+                uploadPartResponse,
+                partNumber,
+                uploadRequest.doRemoteDataIntegrityCheck()
+            )
+        );
+        futures.add(convertFuture);
+
+        CompletableFutureUtils.forwardExceptionTo(convertFuture, uploadPartResponseFuture);
+    }
+
+    private CompletedPart convertUploadPartResponse(
+        AtomicReferenceArray<CompletedPart> completedParts,
+        UploadPartResponse partResponse,
+        int partNumber,
+        boolean isRemoteDataIntegrityCheckEnabled
+    ) {
+        CompletedPart.Builder completedPartBuilder = CompletedPart.builder().eTag(partResponse.eTag()).partNumber(partNumber);
+        if (isRemoteDataIntegrityCheckEnabled) {
+            completedPartBuilder.checksumCRC32(partResponse.checksumCRC32());
+        }
+        CompletedPart completedPart = completedPartBuilder.build();
+        completedParts.set(partNumber - 1, completedPart);
+        return completedPart;
+    }
+
+    /**
+     * Calculates the optimal part size of each part request if the upload
operation is carried out as multipart upload.
+     */
+    public long calculateOptimalPartSize(long contentLengthOfSource) {
+        if (contentLengthOfSource < ByteSizeUnit.MB.toBytes(5)) { // below 5 MB the whole content is one part
+            return contentLengthOfSource;
+        }
+        double optimalPartSize = contentLengthOfSource / (double) MAX_UPLOAD_PARTS; // size that fits in MAX_UPLOAD_PARTS parts
+        optimalPartSize = Math.ceil(optimalPartSize);
+        return (long) Math.max(optimalPartSize, minimumPartSize); // never smaller than the configured minimum
+    }
+
+    private void uploadInOneChunk(
+        S3AsyncClient s3AsyncClient,
+        UploadRequest uploadRequest,
+        InputStreamContainer inputStreamContainer,
+        CompletableFuture<Void> returnFuture
+    ) {
+        PutObjectRequest.Builder putObjectRequestBuilder = PutObjectRequest.builder()
+            .bucket(uploadRequest.getBucket())
+            .key(uploadRequest.getKey())
+            .contentLength(uploadRequest.getContentLength());
+        if (uploadRequest.doRemoteDataIntegrityCheck()) {
+            putObjectRequestBuilder.checksumAlgorithm(ChecksumAlgorithm.CRC32);
+            putObjectRequestBuilder.checksumCRC32(
+                Base64.getEncoder().encodeToString(Arrays.copyOfRange(ByteUtils.toByteArrayBE(uploadRequest.getExpectedChecksum()), 4, 8))
+            );
+        }
+        ExecutorService streamReadExecutor = uploadRequest.getWritePriority() == WritePriority.HIGH
+            ? priorityExecutorService
+            : executorService;
+        CompletableFuture<Void> putObjectFuture = SocketAccess.doPrivileged(
+            () -> s3AsyncClient.putObject(
+                putObjectRequestBuilder.build(),
+                AsyncRequestBody.fromInputStream(
+                    inputStreamContainer.getInputStream(),
+                    inputStreamContainer.getContentLength(),
+                    streamReadExecutor
+                )
+            ).handle((resp, throwable) -> {
+                if (throwable != null) {
+                    returnFuture.completeExceptionally(throwable);
+                } else {
+                    uploadRequest.getUploadFinalizer().accept(true);
+                    returnFuture.complete(null);
+                }
+                // Anything thrown above (e.g. by the finalizer) surfaces as the throwable of the next stage
+                return null;
+            }).handle((resp, throwable) -> {
+                if (throwable != null) {
+                    deleteUploadedObject(s3AsyncClient, uploadRequest);
+                    returnFuture.completeExceptionally(throwable);
+                }
+                // No cleanup is needed when the previous stage did not throw
+                return null;
+            })
+        );
+        // Ensure cancellations are forwarded to the putObjectFuture future
+        CompletableFutureUtils.forwardExceptionTo(returnFuture, putObjectFuture);
+        CompletableFutureUtils.forwardResultTo(putObjectFuture, returnFuture);
+    }
+
+    private void deleteUploadedObject(S3AsyncClient s3AsyncClient, UploadRequest uploadRequest) {
+        DeleteObjectRequest deleteObjectRequest = DeleteObjectRequest.builder()
+            .bucket(uploadRequest.getBucket())
+            .key(uploadRequest.getKey())
+            .build();
+        // Best effort: a failed delete is only logged
+        SocketAccess.doPrivileged(() -> s3AsyncClient.deleteObject(deleteObjectRequest)).exceptionally(throwable -> {
+            log.warn("Failed to delete uploaded object", throwable);
+            return null;
+        });
+    }
+}
diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/OpenSearchThreadFactory.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/OpenSearchThreadFactory.java
new file mode 100644
index 0000000000000..aba8231d291cf
--- /dev/null
+++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/OpenSearchThreadFactory.java
@@ -0,0 +1,37 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.repositories.s3.async;
+
+import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.atomic.AtomicInteger;
+
+/**
+ * A {@link ThreadFactory} whose threads are daemon threads named with a
+ * caller-supplied prefix followed by a running thread number
+ */
+public class OpenSearchThreadFactory implements ThreadFactory {
+
+    final ThreadGroup group;
+    final AtomicInteger threadNumber = new AtomicInteger(1);
+    final String namePrefix;
+
+    OpenSearchThreadFactory(String namePrefix) {
+        final SecurityManager securityManager = System.getSecurityManager();
+        this.group = (securityManager == null) ? Thread.currentThread().getThreadGroup() : securityManager.getThreadGroup();
+        this.namePrefix = namePrefix;
+    }
+
+    @Override
+    public Thread newThread(Runnable r) {
+        final String threadName = namePrefix + "[T#" + threadNumber.getAndIncrement() + "]";
+        final Thread daemonThread = new Thread(group, r, threadName, 0);
+        daemonThread.setDaemon(true);
+        return daemonThread;
+    }
+}
diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/TransferNIOGroup.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/TransferNIOGroup.java
new file mode 100644
index 0000000000000..10859e6b155f5
--- /dev/null
+++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/TransferNIOGroup.java
@@ -0,0 +1,61 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.repositories.s3.async;
+
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.epoll.Epoll;
+import io.netty.channel.epoll.EpollEventLoopGroup;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.util.concurrent.Future;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.opensearch.repositories.s3.SocketAccess;
+
+import java.io.Closeable;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * TransferNIOGroup owns a netty {@link EventLoopGroup} (epoll when available, NIO otherwise) and shuts it down on close
+ */
+public class TransferNIOGroup implements Closeable {
+    private static final String THREAD_PREFIX = "aws-async-transfer-nio";
+    private static final Logger logger = LogManager.getLogger(TransferNIOGroup.class);
+
+    private final EventLoopGroup eventLoopGroup;
+
+    /**
+     * Construct a new TransferNIOGroup
+     *
+     * @param eventLoopThreads The number of event loop threads for this event loop group
+     */
+    public TransferNIOGroup(int eventLoopThreads) {
+        // Epoll event loop incurs less GC and provides better performance than Nio loop. Therefore,
+        // using epoll wherever available is preferred.
+        this.eventLoopGroup = SocketAccess.doPrivileged(
+            () -> Epoll.isAvailable()
+                ? new EpollEventLoopGroup(eventLoopThreads, new OpenSearchThreadFactory(THREAD_PREFIX))
+                : new NioEventLoopGroup(eventLoopThreads, new OpenSearchThreadFactory(THREAD_PREFIX))
+        );
+    }
+
+    public EventLoopGroup getEventLoopGroup() {
+        return eventLoopGroup;
+    }
+
+    @Override
+    public void close() {
+        Future<?> shutdownFuture = eventLoopGroup.shutdownGracefully(0, 5, TimeUnit.SECONDS); // no quiet period, 5s timeout
+        shutdownFuture.awaitUninterruptibly();
+        if (!shutdownFuture.isSuccess()) {
+            logger.warn(new ParameterizedMessage("Error closing {} netty event loop group", THREAD_PREFIX), shutdownFuture.cause());
+        }
+    }
+
+}
diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/UploadRequest.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/UploadRequest.java
new file mode 100644
index 0000000000000..fb3001dbd78a5
--- /dev/null
+++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/UploadRequest.java
@@ -0,0 +1,82 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.repositories.s3.async;
+
+import org.opensearch.common.blobstore.stream.write.WritePriority;
+import org.opensearch.common.blobstore.transfer.UploadFinalizer;
+
+/**
+ * Immutable value object carrying every detail needed for one upload to S3
+ */
+public class UploadRequest {
+    private final String bucket;
+    private final String key;
+    private final long contentLength;
+    private final WritePriority writePriority;
+    private final UploadFinalizer uploadFinalizer;
+    private final boolean doRemoteDataIntegrityCheck;
+    private final Long expectedChecksum;
+
+    /**
+     * Create an UploadRequest from its individual attributes
+     *
+     * @param bucket Name of the target S3 bucket
+     * @param key Key under which the file is uploaded
+     * @param contentLength Total length of the content to upload
+     * @param writePriority Priority of this upload
+     * @param uploadFinalizer Finalizer to invoke once all parts are uploaded
+     * @param doRemoteDataIntegrityCheck Tells vendor plugins whether remote data integrity checks need to be done
+     * @param expectedChecksum Checksum of the uploaded file, used for the remote data integrity check
+     */
+    public UploadRequest(
+        String bucket,
+        String key,
+        long contentLength,
+        WritePriority writePriority,
+        UploadFinalizer uploadFinalizer,
+        boolean doRemoteDataIntegrityCheck,
+        Long expectedChecksum
+    ) {
+        this.expectedChecksum = expectedChecksum;
+        this.doRemoteDataIntegrityCheck = doRemoteDataIntegrityCheck;
+        this.uploadFinalizer = uploadFinalizer;
+        this.writePriority = writePriority;
+        this.contentLength = contentLength;
+        this.key = key;
+        this.bucket = bucket;
+    }
+
+    public String getBucket() {
+        return this.bucket;
+    }
+
+    public String getKey() {
+        return this.key;
+    }
+
+    public long getContentLength() {
+        return this.contentLength;
+    }
+
+    public WritePriority getWritePriority() {
+        return this.writePriority;
+    }
+
+    public UploadFinalizer getUploadFinalizer() {
+        return this.uploadFinalizer;
+    }
+
+    public boolean doRemoteDataIntegrityCheck() {
+        return this.doRemoteDataIntegrityCheck;
+    }
+
+    public Long getExpectedChecksum() {
+        return this.expectedChecksum;
+    }
+}
diff --git a/plugins/repository-s3/src/main/plugin-metadata/plugin-security.policy b/plugins/repository-s3/src/main/plugin-metadata/plugin-security.policy
index f6c154bb3b14d..106103d45e7eb 100644
--- a/plugins/repository-s3/src/main/plugin-metadata/plugin-security.policy
+++ b/plugins/repository-s3/src/main/plugin-metadata/plugin-security.policy
@@ -35,6 +35,7 @@ grant {
   // TODO: get these fixed in aws sdk
   permission java.lang.RuntimePermission "accessDeclaredMembers";
   permission java.lang.RuntimePermission "getClassLoader";
+  permission java.lang.RuntimePermission "setContextClassLoader";
   // Needed because of problems in AmazonS3Client:
   // When no region is set on a AmazonS3Client instance, the
   // AWS SDK loads all known partitions from a JSON file and
@@ -56,4 +57,10 @@ grant {
 
   // only for tests : org.opensearch.repositories.s3.S3RepositoryPlugin
   permission java.util.PropertyPermission "opensearch.allow_insecure_settings", "read,write";
+  permission java.util.PropertyPermission "aws.sharedCredentialsFile", "read,write";
+  permission java.util.PropertyPermission "aws.configFile", "read,write";
+  permission java.util.PropertyPermission "opensearch.path.conf", "read,write";
+  permission java.io.FilePermission "config", "read";
+
+  permission java.lang.RuntimePermission "accessDeclaredMembers";
 };
diff --git a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3AsyncServiceTests.java b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3AsyncServiceTests.java
new file mode 100644
index 0000000000000..5b7362b85055a
--- /dev/null
+++ b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3AsyncServiceTests.java
@@ -0,0 +1,95 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the
Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.repositories.s3;
+
+import org.junit.Before;
+import org.opensearch.cli.SuppressForbidden;
+import org.opensearch.cluster.metadata.RepositoryMetadata;
+import org.opensearch.common.settings.MockSecureSettings;
+import org.opensearch.common.settings.Settings;
+import org.opensearch.repositories.s3.async.AsyncExecutorBuilder;
+import org.opensearch.repositories.s3.async.TransferNIOGroup;
+import org.opensearch.test.OpenSearchTestCase;
+
+import java.util.Map;
+import java.util.concurrent.Executors;
+
+public class S3AsyncServiceTests extends OpenSearchTestCase implements ConfigPathSupport {
+
+    @Override
+    @Before
+    @SuppressForbidden(reason = "Need to set opensearch.path.conf for async client")
+    public void setUp() throws Exception {
+        SocketAccess.doPrivileged(() -> System.setProperty("opensearch.path.conf", configPath().toString()));
+        super.setUp();
+    }
+
+    public void testCachedClientsAreReleased() {
+        final S3AsyncService s3AsyncService = new S3AsyncService(configPath());
+        final Settings settings = Settings.builder().put("endpoint", "http://first").build();
+        final RepositoryMetadata metadata1 = new RepositoryMetadata("first", "s3", settings);
+        final RepositoryMetadata metadata2 = new RepositoryMetadata("second", "s3", settings);
+        final AsyncExecutorBuilder asyncExecutorBuilder = new AsyncExecutorBuilder(
+            Executors.newSingleThreadExecutor(),
+            Executors.newSingleThreadExecutor(),
+            new TransferNIOGroup(1)
+        );
+        final S3ClientSettings clientSettings = s3AsyncService.settings(metadata2);
+        final S3ClientSettings otherClientSettings = s3AsyncService.settings(metadata2);
+        assertSame(clientSettings, otherClientSettings); // repeated lookups for the same metadata yield the same instance
+        final AmazonAsyncS3Reference reference = SocketAccess.doPrivileged(
+            () -> s3AsyncService.client(metadata1, asyncExecutorBuilder, asyncExecutorBuilder)
+        );
+        reference.close();
+        s3AsyncService.close();
+        final AmazonAsyncS3Reference referenceReloaded = SocketAccess.doPrivileged(
+            () -> s3AsyncService.client(metadata1, asyncExecutorBuilder, asyncExecutorBuilder)
+        );
+        assertNotSame(referenceReloaded, reference); // a client requested after close() must be a new reference
+        referenceReloaded.close();
+        s3AsyncService.close();
+        final S3ClientSettings clientSettingsReloaded = s3AsyncService.settings(metadata1);
+        assertNotSame(clientSettings, clientSettingsReloaded);
+    }
+
+    public void testCachedClientsWithCredentialsAreReleased() {
+        final MockSecureSettings secureSettings = new MockSecureSettings();
+        secureSettings.setString("s3.client.default.role_arn", "role");
+        final Map<String, S3ClientSettings> defaults = S3ClientSettings.load(
+            Settings.builder().setSecureSettings(secureSettings).put("s3.client.default.identity_token_file", "file").build(),
+            configPath()
+        );
+        final S3AsyncService s3AsyncService = new S3AsyncService(configPath());
+        s3AsyncService.refreshAndClearCache(defaults);
+        final Settings settings = Settings.builder().put("endpoint", "http://first").put("region", "us-east-2").build();
+        final RepositoryMetadata metadata1 = new RepositoryMetadata("first", "s3", settings);
+        final RepositoryMetadata metadata2 = new RepositoryMetadata("second", "s3", settings);
+        final AsyncExecutorBuilder asyncExecutorBuilder = new AsyncExecutorBuilder(
+            Executors.newSingleThreadExecutor(),
+            Executors.newSingleThreadExecutor(),
+            new TransferNIOGroup(1)
+        );
+        final S3ClientSettings clientSettings = s3AsyncService.settings(metadata2);
+        final S3ClientSettings otherClientSettings = s3AsyncService.settings(metadata2);
+        assertSame(clientSettings, otherClientSettings); // repeated lookups for the same metadata yield the same instance
+        final AmazonAsyncS3Reference reference = SocketAccess.doPrivileged(
+            () -> s3AsyncService.client(metadata1, asyncExecutorBuilder, asyncExecutorBuilder)
+        );
+        reference.close();
+        s3AsyncService.close();
+        final AmazonAsyncS3Reference referenceReloaded = SocketAccess.doPrivileged(
+            () -> s3AsyncService.client(metadata1, asyncExecutorBuilder, asyncExecutorBuilder)
+        );
+        assertNotSame(referenceReloaded, reference); // a client requested after close() must be a new reference
+        referenceReloaded.close();
+        s3AsyncService.close();
+        final S3ClientSettings clientSettingsReloaded = s3AsyncService.settings(metadata1);
+        assertNotSame(clientSettings, clientSettingsReloaded);
+    }
+}