Replace multipart download with parallel file download
There are a few open issues with the multi-stream download approach:
 - Recovery stats are not being reported correctly
 - It is incompatible (short of reopening and re-reading the entire
   file) with the existing Lucene checksum validation logic (see the
   sketch after this list)
 - There are issues with integrating it with the pending client-side
   encryption work

Given this, I ran an experiment that replaced the
multi-stream-within-a-single-file approach with simply parallelizing
downloads across files (this is how snapshot restore works). I actually
got better results with this approach: recovering a ~52GiB shard took
about 4.7 minutes with the multi-stream code versus 3.9 minutes with the
parallel-file approach (r7g.4xlarge EC2 instance, 500MiB/s EBS volume,
S3 as the remote repository).
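
The idea, in a minimal sketch: each file is fetched as one sequential
stream, and parallelism comes from downloading several files at once.
BlobContainer#readBlob and Directory#createOutput are the real APIs; the
method shape and names like filesToRecover and maxConcurrentDownloads
are illustrative, not the actual OpenSearch recovery code:

    import org.apache.lucene.store.Directory;
    import org.apache.lucene.store.IOContext;
    import org.apache.lucene.store.IndexOutput;
    import org.opensearch.common.blobstore.BlobContainer;

    import java.io.InputStream;
    import java.util.ArrayList;
    import java.util.List;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    import java.util.concurrent.Future;

    static void parallelFileDownload(
        BlobContainer blobContainer,
        Directory directory,
        List<String> filesToRecover,
        int maxConcurrentDownloads
    ) throws Exception {
        ExecutorService executor = Executors.newFixedThreadPool(maxConcurrentDownloads);
        try {
            List<Future<?>> downloads = new ArrayList<>();
            for (String fileName : filesToRecover) {
                downloads.add(executor.submit(() -> {
                    try (
                        InputStream in = blobContainer.readBlob(fileName);
                        IndexOutput out = directory.createOutput(fileName, IOContext.DEFAULT)
                    ) {
                        // One sequential write per file is what keeps the existing
                        // Lucene checksum validation and recovery stats working.
                        byte[] buffer = new byte[8192];
                        int length;
                        while ((length = in.read(buffer)) != -1) {
                            out.writeBytes(buffer, 0, length);
                        }
                    }
                    return null;
                }));
            }
            for (Future<?> download : downloads) {
                download.get(); // propagate the first failure, if any
            }
        } finally {
            executor.shutdown();
        }
    }

Note the tradeoff: a single very large file is still limited to one
stream here, which is why the multi-stream approach retains some promise.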

I think this is the right approach, as it leverages the more
battle-tested code path and addresses the three issues listed above. The
multi-stream approach still has promise because it can download very
large files faster (whereas with this approach a single huge file can
become the long pole in the transfer operation). However, given that
segments top out around 5GB and are made up of multiple files in
practice, we generally aren't dealing with huge files.

Signed-off-by: Andrew Ross <andrross@amazon.com>
andrross committed Oct 10, 2023
1 parent 8bb11a6 commit 5959138
Showing 34 changed files with 330 additions and 1,497 deletions.
@@ -70,14 +70,12 @@
import org.opensearch.common.Nullable;
import org.opensearch.common.SetOnce;
import org.opensearch.common.StreamContext;
import org.opensearch.common.annotation.ExperimentalApi;
import org.opensearch.common.blobstore.AsyncMultiStreamBlobContainer;
import org.opensearch.common.blobstore.BlobContainer;
import org.opensearch.common.blobstore.BlobMetadata;
import org.opensearch.common.blobstore.BlobPath;
import org.opensearch.common.blobstore.BlobStoreException;
import org.opensearch.common.blobstore.DeleteResult;
import org.opensearch.common.blobstore.stream.read.ReadContext;
import org.opensearch.common.blobstore.stream.write.WriteContext;
import org.opensearch.common.blobstore.stream.write.WritePriority;
import org.opensearch.common.blobstore.support.AbstractBlobContainer;
@@ -222,52 +220,6 @@ public void asyncBlobUpload(WriteContext writeContext, ActionListener<Void> comp
}
}

@ExperimentalApi
@Override
public void readBlobAsync(String blobName, ActionListener<ReadContext> listener) {
try (AmazonAsyncS3Reference amazonS3Reference = SocketAccess.doPrivileged(blobStore::asyncClientReference)) {
final S3AsyncClient s3AsyncClient = amazonS3Reference.get().client();
final String bucketName = blobStore.bucket();
final String blobKey = buildKey(blobName);

final CompletableFuture<GetObjectAttributesResponse> blobMetadataFuture = getBlobMetadata(s3AsyncClient, bucketName, blobKey);

blobMetadataFuture.whenComplete((blobMetadata, throwable) -> {
if (throwable != null) {
Exception ex = throwable.getCause() instanceof Exception
? (Exception) throwable.getCause()
: new Exception(throwable.getCause());
listener.onFailure(ex);
return;
}

try {
final List<ReadContext.StreamPartCreator> blobPartInputStreamFutures = new ArrayList<>();
final long blobSize = blobMetadata.objectSize();
final Integer numberOfParts = blobMetadata.objectParts() == null ? null : blobMetadata.objectParts().totalPartsCount();
final String blobChecksum = blobMetadata.checksum() == null ? null : blobMetadata.checksum().checksumCRC32();

if (numberOfParts == null) {
blobPartInputStreamFutures.add(() -> getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, null));
} else {
// S3 multipart files use 1 to n indexing
for (int partNumber = 1; partNumber <= numberOfParts; partNumber++) {
final int innerPartNumber = partNumber;
blobPartInputStreamFutures.add(
() -> getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, innerPartNumber)
);
}
}
listener.onResponse(new ReadContext(blobSize, blobPartInputStreamFutures, blobChecksum));
} catch (Exception ex) {
listener.onFailure(ex);
}
});
} catch (Exception ex) {
listener.onFailure(SdkException.create("Error occurred while fetching blob parts from the repository", ex));
}
}

public boolean remoteIntegrityCheckSupported() {
return true;
}
@@ -70,13 +70,11 @@
import software.amazon.awssdk.services.s3.model.UploadPartResponse;
import software.amazon.awssdk.services.s3.paginators.ListObjectsV2Iterable;

import org.opensearch.action.LatchedActionListener;
import org.opensearch.common.blobstore.BlobContainer;
import org.opensearch.common.blobstore.BlobMetadata;
import org.opensearch.common.blobstore.BlobPath;
import org.opensearch.common.blobstore.BlobStoreException;
import org.opensearch.common.blobstore.DeleteResult;
import org.opensearch.common.blobstore.stream.read.ReadContext;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.io.InputStreamContainer;
import org.opensearch.core.action.ActionListener;
@@ -98,7 +96,6 @@
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
@@ -917,208 +914,6 @@ public void testListBlobsByPrefixInLexicographicOrderWithLimitGreaterThanNumberO
testListBlobsByPrefixInLexicographicOrder(12, 2, BlobContainer.BlobNameSortOrder.LEXICOGRAPHIC);
}

public void testReadBlobAsyncMultiPart() throws Exception {
final String bucketName = randomAlphaOfLengthBetween(1, 10);
final String blobName = randomAlphaOfLengthBetween(1, 10);
final String checksum = randomAlphaOfLength(10);

final long objectSize = 100L;
final int objectPartCount = 10;
final int partSize = 10;

final S3AsyncClient s3AsyncClient = mock(S3AsyncClient.class);
final AmazonAsyncS3Reference amazonAsyncS3Reference = new AmazonAsyncS3Reference(
AmazonAsyncS3WithCredentials.create(s3AsyncClient, s3AsyncClient, null)
);

final S3BlobStore blobStore = mock(S3BlobStore.class);
final BlobPath blobPath = new BlobPath();

when(blobStore.bucket()).thenReturn(bucketName);
when(blobStore.getStatsMetricPublisher()).thenReturn(new StatsMetricPublisher());
when(blobStore.serverSideEncryption()).thenReturn(false);
when(blobStore.asyncClientReference()).thenReturn(amazonAsyncS3Reference);

CompletableFuture<GetObjectAttributesResponse> getObjectAttributesResponseCompletableFuture = new CompletableFuture<>();
getObjectAttributesResponseCompletableFuture.complete(
GetObjectAttributesResponse.builder()
.checksum(Checksum.builder().checksumCRC32(checksum).build())
.objectSize(objectSize)
.objectParts(GetObjectAttributesParts.builder().totalPartsCount(objectPartCount).build())
.build()
);
when(s3AsyncClient.getObjectAttributes(any(GetObjectAttributesRequest.class))).thenReturn(
getObjectAttributesResponseCompletableFuture
);

mockObjectPartResponse(s3AsyncClient, bucketName, blobName, objectPartCount, partSize, objectSize);

CountDownLatch countDownLatch = new CountDownLatch(1);
CountingCompletionListener<ReadContext> readContextActionListener = new CountingCompletionListener<>();
LatchedActionListener<ReadContext> listener = new LatchedActionListener<>(readContextActionListener, countDownLatch);

final S3BlobContainer blobContainer = new S3BlobContainer(blobPath, blobStore);
blobContainer.readBlobAsync(blobName, listener);
countDownLatch.await();

assertEquals(1, readContextActionListener.getResponseCount());
assertEquals(0, readContextActionListener.getFailureCount());
ReadContext readContext = readContextActionListener.getResponse();
assertEquals(objectPartCount, readContext.getNumberOfParts());
assertEquals(checksum, readContext.getBlobChecksum());
assertEquals(objectSize, readContext.getBlobSize());

for (int partNumber = 1; partNumber < objectPartCount; partNumber++) {
InputStreamContainer inputStreamContainer = readContext.getPartStreams().get(partNumber).get().join();
final int offset = partNumber * partSize;
assertEquals(partSize, inputStreamContainer.getContentLength());
assertEquals(offset, inputStreamContainer.getOffset());
assertEquals(partSize, inputStreamContainer.getInputStream().readAllBytes().length);
}
}

public void testReadBlobAsyncSinglePart() throws Exception {
final String bucketName = randomAlphaOfLengthBetween(1, 10);
final String blobName = randomAlphaOfLengthBetween(1, 10);
final String checksum = randomAlphaOfLength(10);

final int objectSize = 100;

final S3AsyncClient s3AsyncClient = mock(S3AsyncClient.class);
final AmazonAsyncS3Reference amazonAsyncS3Reference = new AmazonAsyncS3Reference(
AmazonAsyncS3WithCredentials.create(s3AsyncClient, s3AsyncClient, null)
);
final S3BlobStore blobStore = mock(S3BlobStore.class);
final BlobPath blobPath = new BlobPath();

when(blobStore.bucket()).thenReturn(bucketName);
when(blobStore.getStatsMetricPublisher()).thenReturn(new StatsMetricPublisher());
when(blobStore.serverSideEncryption()).thenReturn(false);
when(blobStore.asyncClientReference()).thenReturn(amazonAsyncS3Reference);

CompletableFuture<GetObjectAttributesResponse> getObjectAttributesResponseCompletableFuture = new CompletableFuture<>();
getObjectAttributesResponseCompletableFuture.complete(
GetObjectAttributesResponse.builder()
.checksum(Checksum.builder().checksumCRC32(checksum).build())
.objectSize((long) objectSize)
.build()
);
when(s3AsyncClient.getObjectAttributes(any(GetObjectAttributesRequest.class))).thenReturn(
getObjectAttributesResponseCompletableFuture
);

mockObjectResponse(s3AsyncClient, bucketName, blobName, objectSize);

CountDownLatch countDownLatch = new CountDownLatch(1);
CountingCompletionListener<ReadContext> readContextActionListener = new CountingCompletionListener<>();
LatchedActionListener<ReadContext> listener = new LatchedActionListener<>(readContextActionListener, countDownLatch);

final S3BlobContainer blobContainer = new S3BlobContainer(blobPath, blobStore);
blobContainer.readBlobAsync(blobName, listener);
countDownLatch.await();

assertEquals(1, readContextActionListener.getResponseCount());
assertEquals(0, readContextActionListener.getFailureCount());
ReadContext readContext = readContextActionListener.getResponse();
assertEquals(1, readContext.getNumberOfParts());
assertEquals(checksum, readContext.getBlobChecksum());
assertEquals(objectSize, readContext.getBlobSize());

InputStreamContainer inputStreamContainer = readContext.getPartStreams().stream().findFirst().get().get().join();
assertEquals(objectSize, inputStreamContainer.getContentLength());
assertEquals(0, inputStreamContainer.getOffset());
assertEquals(objectSize, inputStreamContainer.getInputStream().readAllBytes().length);

}

public void testReadBlobAsyncFailure() throws Exception {
final String bucketName = randomAlphaOfLengthBetween(1, 10);
final String blobName = randomAlphaOfLengthBetween(1, 10);
final String checksum = randomAlphaOfLength(10);

final long objectSize = 100L;
final int objectPartCount = 10;

final S3AsyncClient s3AsyncClient = mock(S3AsyncClient.class);
final AmazonAsyncS3Reference amazonAsyncS3Reference = new AmazonAsyncS3Reference(
AmazonAsyncS3WithCredentials.create(s3AsyncClient, s3AsyncClient, null)
);

final S3BlobStore blobStore = mock(S3BlobStore.class);
final BlobPath blobPath = new BlobPath();

when(blobStore.bucket()).thenReturn(bucketName);
when(blobStore.getStatsMetricPublisher()).thenReturn(new StatsMetricPublisher());
when(blobStore.serverSideEncryption()).thenReturn(false);
when(blobStore.asyncClientReference()).thenReturn(amazonAsyncS3Reference);

CompletableFuture<GetObjectAttributesResponse> getObjectAttributesResponseCompletableFuture = new CompletableFuture<>();
getObjectAttributesResponseCompletableFuture.complete(
GetObjectAttributesResponse.builder()
.checksum(Checksum.builder().checksumCRC32(checksum).build())
.objectSize(objectSize)
.objectParts(GetObjectAttributesParts.builder().totalPartsCount(objectPartCount).build())
.build()
);
when(s3AsyncClient.getObjectAttributes(any(GetObjectAttributesRequest.class))).thenThrow(new RuntimeException());

CountDownLatch countDownLatch = new CountDownLatch(1);
CountingCompletionListener<ReadContext> readContextActionListener = new CountingCompletionListener<>();
LatchedActionListener<ReadContext> listener = new LatchedActionListener<>(readContextActionListener, countDownLatch);

final S3BlobContainer blobContainer = new S3BlobContainer(blobPath, blobStore);
blobContainer.readBlobAsync(blobName, listener);
countDownLatch.await();

assertEquals(0, readContextActionListener.getResponseCount());
assertEquals(1, readContextActionListener.getFailureCount());
}

public void testReadBlobAsyncOnCompleteFailureMissingData() throws Exception {
final String bucketName = randomAlphaOfLengthBetween(1, 10);
final String blobName = randomAlphaOfLengthBetween(1, 10);
final String checksum = randomAlphaOfLength(10);

final long objectSize = 100L;
final int objectPartCount = 10;

final S3AsyncClient s3AsyncClient = mock(S3AsyncClient.class);
final AmazonAsyncS3Reference amazonAsyncS3Reference = new AmazonAsyncS3Reference(
AmazonAsyncS3WithCredentials.create(s3AsyncClient, s3AsyncClient, null)
);

final S3BlobStore blobStore = mock(S3BlobStore.class);
final BlobPath blobPath = new BlobPath();

when(blobStore.bucket()).thenReturn(bucketName);
when(blobStore.getStatsMetricPublisher()).thenReturn(new StatsMetricPublisher());
when(blobStore.serverSideEncryption()).thenReturn(false);
when(blobStore.asyncClientReference()).thenReturn(amazonAsyncS3Reference);

CompletableFuture<GetObjectAttributesResponse> getObjectAttributesResponseCompletableFuture = new CompletableFuture<>();
getObjectAttributesResponseCompletableFuture.complete(
GetObjectAttributesResponse.builder()
.checksum(Checksum.builder().build())
.objectSize(null)
.objectParts(GetObjectAttributesParts.builder().totalPartsCount(objectPartCount).build())
.build()
);
when(s3AsyncClient.getObjectAttributes(any(GetObjectAttributesRequest.class))).thenReturn(
getObjectAttributesResponseCompletableFuture
);

CountDownLatch countDownLatch = new CountDownLatch(1);
CountingCompletionListener<ReadContext> readContextActionListener = new CountingCompletionListener<>();
LatchedActionListener<ReadContext> listener = new LatchedActionListener<>(readContextActionListener, countDownLatch);

final S3BlobContainer blobContainer = new S3BlobContainer(blobPath, blobStore);
blobContainer.readBlobAsync(blobName, listener);
countDownLatch.await();

assertEquals(0, readContextActionListener.getResponseCount());
assertEquals(1, readContextActionListener.getFailureCount());
}

public void testGetBlobMetadata() throws Exception {
final String checksum = randomAlphaOfLengthBetween(1, 10);
final long objectSize = 100L;
@@ -14,7 +14,6 @@
import org.opensearch.common.blobstore.BlobPath;
import org.opensearch.common.blobstore.fs.FsBlobContainer;
import org.opensearch.common.blobstore.fs.FsBlobStore;
import org.opensearch.common.blobstore.stream.read.ReadContext;
import org.opensearch.common.blobstore.stream.write.WriteContext;
import org.opensearch.common.io.InputStreamContainer;
import org.opensearch.core.action.ActionListener;
@@ -25,9 +24,6 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
@@ -118,27 +114,6 @@ public void asyncBlobUpload(WriteContext writeContext, ActionListener<Void> comp

}

@Override
public void readBlobAsync(String blobName, ActionListener<ReadContext> listener) {
new Thread(() -> {
try {
long contentLength = listBlobs().get(blobName).length();
long partSize = contentLength / 10;
int numberOfParts = (int) ((contentLength % partSize) == 0 ? contentLength / partSize : (contentLength / partSize) + 1);
List<ReadContext.StreamPartCreator> blobPartStreams = new ArrayList<>();
for (int partNumber = 0; partNumber < numberOfParts; partNumber++) {
long offset = partNumber * partSize;
InputStreamContainer blobPartStream = new InputStreamContainer(readBlob(blobName, offset, partSize), partSize, offset);
blobPartStreams.add(() -> CompletableFuture.completedFuture(blobPartStream));
}
ReadContext blobReadContext = new ReadContext(contentLength, blobPartStreams, null);
listener.onResponse(blobReadContext);
} catch (Exception e) {
listener.onFailure(e);
}
}).start();
}

public boolean remoteIntegrityCheckSupported() {
return true;
}
@@ -8,8 +8,6 @@

package org.opensearch.common.blobstore;

import org.opensearch.common.annotation.ExperimentalApi;
import org.opensearch.common.blobstore.stream.read.ReadContext;
import org.opensearch.common.blobstore.stream.write.WriteContext;
import org.opensearch.core.action.ActionListener;

@@ -34,14 +32,6 @@ public interface AsyncMultiStreamBlobContainer extends BlobContainer {
*/
void asyncBlobUpload(WriteContext writeContext, ActionListener<Void> completionListener) throws IOException;

/**
* Creates an async callback of a {@link ReadContext} containing the multipart streams for a specified blob within the container.
* @param blobName The name of the blob for which the {@link ReadContext} needs to be fetched.
* @param listener Async listener for {@link ReadContext} object which serves the input streams and other metadata for the blob
*/
@ExperimentalApi
void readBlobAsync(String blobName, ActionListener<ReadContext> listener);

/*
* Whether underlying blobContainer can verify integrity of data after transfer. If true and if expected
* checksum is provided in WriteContext, then the checksum of transferred data is compared with expected checksum