Add async read support for S3 plugin
Signed-off-by: Kunal Kotwani <kkotwani@amazon.com>
kotwanikunal committed Sep 1, 2023
1 parent f85700b commit 03ddc8a
Showing 6 changed files with 374 additions and 8 deletions.
S3BlobContainer.java
@@ -44,6 +44,7 @@
import software.amazon.awssdk.services.s3.model.Delete;
import software.amazon.awssdk.services.s3.model.DeleteObjectsRequest;
import software.amazon.awssdk.services.s3.model.DeleteObjectsResponse;
import software.amazon.awssdk.services.s3.model.GetObjectAttributesResponse;
import software.amazon.awssdk.services.s3.model.HeadObjectRequest;
import software.amazon.awssdk.services.s3.model.ListObjectsV2Request;
import software.amazon.awssdk.services.s3.model.ListObjectsV2Response;
@@ -63,6 +64,7 @@
import org.opensearch.common.Nullable;
import org.opensearch.common.SetOnce;
import org.opensearch.common.StreamContext;
import org.opensearch.common.annotation.ExperimentalApi;
import org.opensearch.common.blobstore.BlobContainer;
import org.opensearch.common.blobstore.BlobMetadata;
import org.opensearch.common.blobstore.BlobPath;
@@ -75,10 +77,12 @@
import org.opensearch.common.blobstore.support.AbstractBlobContainer;
import org.opensearch.common.blobstore.support.PlainBlobMetadata;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.io.InputStreamContainer;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.unit.ByteSizeUnit;
import org.opensearch.core.common.unit.ByteSizeValue;
import org.opensearch.repositories.s3.async.AsyncTransferManager;
import org.opensearch.repositories.s3.async.UploadRequest;

import java.io.ByteArrayInputStream;
@@ -91,6 +95,7 @@
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import java.util.stream.Collectors;
@@ -212,9 +217,49 @@ public void asyncBlobUpload(WriteContext writeContext, ActionListener<Void> comp
}
}

@ExperimentalApi
@Override
public void readBlobAsync(String blobName, ActionListener<ReadContext> listener) {
-        throw new UnsupportedOperationException();
        try (AmazonAsyncS3Reference amazonS3Reference = SocketAccess.doPrivileged(blobStore::asyncClientReference)) {
            final S3AsyncClient s3AsyncClient = amazonS3Reference.get().client();
            final String bucketName = blobStore.bucket();
            final AsyncTransferManager transferManager = blobStore.getAsyncTransferManager();

            // Fetch the blob metadata (size, part count, checksum) before requesting the individual parts.
            // Note: objectParts() is populated only for blobs that were uploaded via multipart upload.
            final GetObjectAttributesResponse blobMetadata = transferManager.getBlobPartMetadata(blobName, bucketName, s3AsyncClient)
                .get();

            final long blobSize = blobMetadata.objectSize();
            final int numberOfParts = blobMetadata.objectParts().totalPartsCount();
            final String blobChecksum = blobMetadata.checksum().checksumCRC32();

            final List<CompletableFuture<InputStreamContainer>> blobPartInputStreamFutures = new ArrayList<>();
            // S3 part numbers are 1-based, so parts 1 through numberOfParts are requested.
            for (int partNumber = 1; partNumber <= numberOfParts; partNumber++) {
                blobPartInputStreamFutures.add(
                    transferManager.getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobName, partNumber)
                );
            }

            CompletableFuture.allOf(blobPartInputStreamFutures.toArray(CompletableFuture[]::new)).whenComplete((unused, throwable) -> {
                if (throwable == null) {
                    // Every future has completed at this point, so join() returns immediately and the
                    // resulting list preserves part order regardless of the order in which parts arrived.
                    final List<InputStreamContainer> blobPartStreams = blobPartInputStreamFutures.stream()
                        .map(CompletableFuture::join)
                        .collect(Collectors.toList());
                    listener.onResponse(new ReadContext(blobSize, blobPartStreams, blobChecksum));
                } else {
                    Exception ex = throwable instanceof Error ? new Exception(throwable) : (Exception) throwable;
                    listener.onFailure(ex);
                }
            });
        } catch (ExecutionException | InterruptedException ex) {
            listener.onFailure(SdkException.create("Error occurred while fetching blob parts from the repository", ex));
        }
    }

// package private for testing
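For orientation, here is a minimal caller-side sketch of the new API. It is not part of this diff: `blobContainer` is assumed to be an existing S3BlobContainer, and the ReadContext getters are assumed accessors for the fields the constructor above receives.

// Hypothetical usage sketch; names are illustrative.
final PlainActionFuture<ReadContext> future = new PlainActionFuture<>();
blobContainer.readBlobAsync("my-blob", future);

final ReadContext readContext = future.actionGet();   // blocks until the listener completes
final long blobSize = readContext.getBlobSize();      // size reported by GetObjectAttributes
final List<InputStreamContainer> parts = readContext.getPartStreams(); // part streams, in order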
AsyncTransferManager.java
@@ -8,8 +8,11 @@

package org.opensearch.repositories.s3.async;

import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.core.exception.SdkException;
import software.amazon.awssdk.http.HttpStatusCode;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.ChecksumAlgorithm;
@@ -20,6 +23,11 @@
import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest;
import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse;
import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
import software.amazon.awssdk.services.s3.model.GetObjectAttributesRequest;
import software.amazon.awssdk.services.s3.model.GetObjectAttributesResponse;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
import software.amazon.awssdk.services.s3.model.ObjectAttributes;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.services.s3.model.S3Exception;
import software.amazon.awssdk.utils.CompletableFutureUtils;
@@ -29,13 +37,16 @@
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.ExceptionsHelper;
import org.opensearch.common.StreamContext;
import org.opensearch.common.annotation.ExperimentalApi;
import org.opensearch.common.blobstore.exception.CorruptFileException;
import org.opensearch.common.blobstore.stream.write.WritePriority;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.io.InputStreamContainer;
import org.opensearch.common.util.ByteUtils;
import org.opensearch.core.common.unit.ByteSizeUnit;
import org.opensearch.repositories.s3.SocketAccess;
import org.opensearch.repositories.s3.io.CheckedContainer;
import org.opensearch.repositories.s3.utils.HttpRangeUtils;

import java.io.IOException;
import java.util.Arrays;
@@ -353,4 +364,71 @@ private void deleteUploadedObject(S3AsyncClient s3AsyncClient, UploadRequest upl
return null;
});
}

/**
* Fetches a part of the blob from the S3 bucket and transforms it to an {@link InputStreamContainer}, which holds
* the stream and its related metadata.
* @param s3AsyncClient Async client to be utilized to fetch the object part
* @param bucketName Name of the S3 bucket
* @param blobName Identifier of the blob for which the parts will be fetched
* @param partNumber Part number of the blob part to be retrieved (1-based, per the S3 multipart API)
* @return A future of {@link InputStreamContainer} containing the stream and stream metadata.
*/
@ExperimentalApi
public CompletableFuture<InputStreamContainer> getBlobPartInputStreamContainer(
S3AsyncClient s3AsyncClient,
String bucketName,
String blobName,
int partNumber
) {
final GetObjectRequest.Builder getObjectRequestBuilder = GetObjectRequest.builder()
.bucket(bucketName)
.key(blobName)
.partNumber(partNumber);

return SocketAccess.doPrivileged(
() -> s3AsyncClient.getObject(getObjectRequestBuilder.build(), AsyncResponseTransformer.toBlockingInputStream())
.thenApply(this::transformResponseToInputStreamContainer)
);
}

/**
* Transforms the stream response object from S3 into an {@link InputStreamContainer}
* @param streamResponse Response stream object from S3
* @return {@link InputStreamContainer} containing the stream and stream metadata
*/
// Package-Private for testing.
InputStreamContainer transformResponseToInputStreamContainer(ResponseInputStream<GetObjectResponse> streamResponse) {
final GetObjectResponse getObjectResponse = streamResponse.response();
final String contentRange = getObjectResponse.contentRange();
final Long contentLength = getObjectResponse.contentLength();
if (contentRange == null || contentLength == null) {
throw SdkException.builder().message("Failed to fetch required metadata for blob part").build();
}
        final Tuple<Long, Long> s3ResponseRange = HttpRangeUtils.fromHttpRangeHeader(contentRange);
        return new InputStreamContainer(streamResponse, contentLength, s3ResponseRange.v1());
}

/**
* Retrieves metadata for the provided blob within the S3 bucket: checksum, object size, and object parts.
* @param blobName Identifier of the blob for which the metadata will be fetched
* @param bucketName Name of the S3 bucket
* @param s3AsyncClient Async client to be utilized to fetch the metadata
* @return A future containing the metadata within {@link GetObjectAttributesResponse}
*/
@ExperimentalApi
public CompletableFuture<GetObjectAttributesResponse> getBlobPartMetadata(
String blobName,
String bucketName,
S3AsyncClient s3AsyncClient
) {
// Fetch blob metadata - part info, size, checksum
final GetObjectAttributesRequest getObjectAttributesRequest = GetObjectAttributesRequest.builder()
.bucket(bucketName)
.key(blobName)
.objectAttributes(ObjectAttributes.CHECKSUM, ObjectAttributes.OBJECT_SIZE, ObjectAttributes.OBJECT_PARTS)
.build();

return SocketAccess.doPrivileged(() -> s3AsyncClient.getObjectAttributes(getObjectAttributesRequest));
}
}
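Taken together, the two new public methods support a two-step read: fetch the blob's part metadata, then fetch each part's stream. Below is a minimal blocking sketch, not part of this diff, under stated assumptions: a configured S3AsyncClient and AsyncTransferManager, a multipart-uploaded blob (so objectParts() is non-null), InputStreamContainer exposing its stream and offset, and illustrative bucket/blob names.

// Hypothetical usage sketch; error handling elided.
final GetObjectAttributesResponse metadata = transferManager
    .getBlobPartMetadata("my-blob", "my-bucket", s3AsyncClient)
    .get();

final int numberOfParts = metadata.objectParts().totalPartsCount();
for (int partNumber = 1; partNumber <= numberOfParts; partNumber++) { // S3 part numbers are 1-based
    final InputStreamContainer part = transferManager
        .getBlobPartInputStreamContainer(s3AsyncClient, "my-bucket", "my-blob", partNumber)
        .get();
    try (InputStream stream = part.getInputStream()) {
        // Consume the part; part.getOffset() is its start position within the blob.
    }
}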
HttpRangeUtils.java
@@ -8,7 +8,31 @@

package org.opensearch.repositories.s3.utils;

import software.amazon.awssdk.core.exception.SdkException;

import org.opensearch.common.collect.Tuple;

import java.util.regex.Matcher;
import java.util.regex.Pattern;

public final class HttpRangeUtils {
private static final Pattern RANGE_PATTERN = Pattern.compile("^bytes\\s+(\\d+)-(\\d+)/(\\d+|.*)$");

/**
     * Parses a Content-Range header value to determine the start and end offsets of the stream.
     * The value is validated against the RFC 9110 grammar for content ranges.
     * Sample values: "bytes 0-10/200", "bytes 0-10/*"
     * <a href="https://www.rfc-editor.org/rfc/rfc9110.html#name-content-range">Details here</a>
     * @param headerValue Content-Range header value from the HTTP response
     * @return Pair of values where v1 is the inclusive lower bound and v2 the inclusive upper bound of the range
*/
public static Tuple<Long, Long> fromHttpRangeHeader(String headerValue) {
Matcher matcher = RANGE_PATTERN.matcher(headerValue);
if (!matcher.find()) {
            throw SdkException.builder().message("Regex match for Content-Range header {" + headerValue + "} failed").build();
}
return new Tuple<>(Long.parseLong(matcher.group(1)), Long.parseLong(matcher.group(2)));
}

/**
* Provides a byte range string per <a href="https://www.rfc-editor.org/rfc/rfc9110.html#name-byte-ranges">RFC 9110</a>
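A quick illustration of the parser above (hypothetical values, not part of this change): the Content-Range for the second 5 MiB part of a 15 MiB object parses to its inclusive start and end offsets.

// 5 MiB = 5,242,880 bytes; the second part spans bytes 5242880-10485759 of 15728640.
final Tuple<Long, Long> range = HttpRangeUtils.fromHttpRangeHeader("bytes 5242880-10485759/15728640");
assert range.v1() == 5242880L;  // inclusive start offset
assert range.v2() == 10485759L; // inclusive end offset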
S3BlobStoreContainerTests.java
@@ -32,18 +32,27 @@

package org.opensearch.repositories.s3;

import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
import software.amazon.awssdk.core.exception.SdkException;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest;
import software.amazon.awssdk.services.s3.model.AbortMultipartUploadResponse;
import software.amazon.awssdk.services.s3.model.Checksum;
import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest;
import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse;
import software.amazon.awssdk.services.s3.model.CompletedPart;
import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest;
import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse;
import software.amazon.awssdk.services.s3.model.DeleteObjectsRequest;
import software.amazon.awssdk.services.s3.model.DeleteObjectsResponse;
import software.amazon.awssdk.services.s3.model.GetObjectAttributesParts;
import software.amazon.awssdk.services.s3.model.GetObjectAttributesRequest;
import software.amazon.awssdk.services.s3.model.GetObjectAttributesResponse;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
import software.amazon.awssdk.services.s3.model.HeadObjectRequest;
import software.amazon.awssdk.services.s3.model.HeadObjectResponse;
import software.amazon.awssdk.services.s3.model.ListObjectsV2Request;
@@ -61,15 +70,18 @@
import software.amazon.awssdk.services.s3.model.UploadPartResponse;
import software.amazon.awssdk.services.s3.paginators.ListObjectsV2Iterable;

import org.opensearch.action.LatchedActionListener;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.common.blobstore.BlobContainer;
import org.opensearch.common.blobstore.BlobMetadata;
import org.opensearch.common.blobstore.BlobPath;
import org.opensearch.common.blobstore.BlobStoreException;
import org.opensearch.common.blobstore.DeleteResult;
import org.opensearch.common.blobstore.stream.read.ReadContext;
import org.opensearch.common.collect.Tuple;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.unit.ByteSizeUnit;
import org.opensearch.repositories.s3.async.AsyncTransferManager;
import org.opensearch.test.OpenSearchTestCase;

import java.io.ByteArrayInputStream;
@@ -86,6 +98,9 @@
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
@@ -94,6 +109,7 @@

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
@@ -882,15 +898,60 @@ public void onFailure(Exception e) {}
}
}

-   public void testAsyncBlobDownload() {
    public void testAsyncBlobDownload() throws InterruptedException {
final String bucketName = randomAlphaOfLengthBetween(1, 10);
final String blobName = randomAlphaOfLengthBetween(1, 10);
final String checksum = randomAlphaOfLength(10);

final long objectSize = 100L;
final int objectPartCount = 10;
final int partSize = 10;

final S3AsyncClient s3AsyncClient = mock(S3AsyncClient.class);
final AsyncTransferManager asyncTransferManager = new AsyncTransferManager(
10000L,
mock(ExecutorService.class),
mock(ExecutorService.class)
);
final S3BlobStore blobStore = mock(S3BlobStore.class);
-       final BlobPath blobPath = mock(BlobPath.class);
-       final String blobName = "test-blob";
        final BlobPath blobPath = new BlobPath();

when(blobStore.bucket()).thenReturn(bucketName);
when(blobStore.getStatsMetricPublisher()).thenReturn(new StatsMetricPublisher());
when(blobStore.serverSideEncryption()).thenReturn(false);
AmazonAsyncS3Reference reference = new AmazonAsyncS3Reference(
AmazonAsyncS3WithCredentials.create(s3AsyncClient, s3AsyncClient, null)
);
when(blobStore.asyncClientReference()).thenReturn(reference);

when(blobStore.getAsyncTransferManager()).thenReturn(asyncTransferManager);

CompletableFuture<GetObjectAttributesResponse> getObjectAttributesResponseCompletableFuture = new CompletableFuture<>();
getObjectAttributesResponseCompletableFuture.complete(
GetObjectAttributesResponse.builder()
.checksum(Checksum.builder().checksumCRC32(checksum).build())
.objectSize(objectSize)
.objectParts(GetObjectAttributesParts.builder().totalPartsCount(objectPartCount).build())
.build()
);
when(s3AsyncClient.getObjectAttributes(any(GetObjectAttributesRequest.class))).thenReturn(
getObjectAttributesResponseCompletableFuture
);

mockObjectPartResponse(s3AsyncClient, bucketName, blobName, objectPartCount, partSize, objectSize);

        final CountDownLatch countDownLatch = new CountDownLatch(1);
        final PlainActionFuture<ReadContext> readContextFuture = new PlainActionFuture<>();
        final LatchedActionListener<ReadContext> listener = new LatchedActionListener<>(readContextFuture, countDownLatch);

        final S3BlobContainer blobContainer = new S3BlobContainer(blobPath, blobStore);
        blobContainer.readBlobAsync(blobName, listener);
        countDownLatch.await();

-       final UnsupportedOperationException e = expectThrows(UnsupportedOperationException.class, () -> {
-           final S3BlobContainer blobContainer = new S3BlobContainer(blobPath, blobStore);
-           blobContainer.readBlobAsync(blobName, new PlainActionFuture<>());
-       });
        // Verify the assembled ReadContext (assumes the standard ReadContext accessors for its fields).
        final ReadContext readContext = readContextFuture.actionGet();
        assertEquals(objectSize, readContext.getBlobSize());
        assertEquals(checksum, readContext.getBlobChecksum());
        assertEquals(objectPartCount, readContext.getPartStreams().size());
}

public void testListBlobsByPrefixInLexicographicOrderWithNegativeLimit() throws IOException {
Expand All @@ -912,4 +973,33 @@ public void testListBlobsByPrefixInLexicographicOrderWithLimitGreaterThanPageSiz
public void testListBlobsByPrefixInLexicographicOrderWithLimitGreaterThanNumberOfRecords() throws IOException {
testListBlobsByPrefixInLexicographicOrder(12, 2, BlobContainer.BlobNameSortOrder.LEXICOGRAPHIC);
}

private void mockObjectPartResponse(
S3AsyncClient s3AsyncClient,
String bucketName,
String blobName,
int totalNumberOfParts,
int partSize,
long objectSize
) {
        // S3 part numbers are 1-based, and the Content-Range end offset is inclusive.
        for (int partNumber = 1; partNumber <= totalNumberOfParts; partNumber++) {
            final int start = (partNumber - 1) * partSize;
            final int end = start + partSize - 1;
            final String contentRange = "bytes " + start + "-" + end + "/" + objectSize;
final InputStream inputStream = new ByteArrayInputStream(randomByteArrayOfLength(partSize));

GetObjectResponse getObjectResponse = GetObjectResponse.builder()
.contentLength((long) partSize)
.contentRange(contentRange)
.build();

CompletableFuture<ResponseInputStream<GetObjectResponse>> getObjectPartResponse = new CompletableFuture<>();
ResponseInputStream<GetObjectResponse> responseInputStream = new ResponseInputStream<>(getObjectResponse, inputStream);
getObjectPartResponse.complete(responseInputStream);

GetObjectRequest getObjectRequest = GetObjectRequest.builder().bucket(bucketName).key(blobName).partNumber(partNumber).build();

when(s3AsyncClient.getObject(eq(getObjectRequest), any(AsyncResponseTransformer.class))).thenReturn(getObjectPartResponse);
}
}
}
