diff --git a/CHANGELOG.md b/CHANGELOG.md
index d4a060cc16504..80ab0fb87e609 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -97,6 +97,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 - Add concurrent segment search related metrics to node and index stats ([#9622](https://github.com/opensearch-project/OpenSearch/issues/9622))
 - Decouple replication lag from logic to fail stale replicas ([#9507](https://github.com/opensearch-project/OpenSearch/pull/9507))
 - Expose DelimitedTermFrequencyTokenFilter to allow providing term frequencies along with terms ([#9479](https://github.com/opensearch-project/OpenSearch/pull/9479))
+- APIs for performing async blob reads and async downloads from the repository using multiple streams ([#9592](https://github.com/opensearch-project/OpenSearch/issues/9592))
 
 ### Dependencies
 - Bump `org.apache.logging.log4j:log4j-core` from 2.17.1 to 2.20.0 ([#8307](https://github.com/opensearch-project/OpenSearch/pull/8307))
diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java
index a97a509adce47..183b5f8fe7ac1 100644
--- a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java
+++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java
@@ -69,6 +69,7 @@
 import org.opensearch.common.blobstore.BlobStoreException;
 import org.opensearch.common.blobstore.DeleteResult;
 import org.opensearch.common.blobstore.VerifyingMultiStreamBlobContainer;
+import org.opensearch.common.blobstore.stream.read.ReadContext;
 import org.opensearch.common.blobstore.stream.write.WriteContext;
 import org.opensearch.common.blobstore.stream.write.WritePriority;
 import org.opensearch.common.blobstore.support.AbstractBlobContainer;
@@ -211,6 +212,11 @@ public void asyncBlobUpload(WriteContext writeContext, ActionListener<Void> comp
         }
     }
 
+    @Override
+    public void readBlobAsync(String blobName, ActionListener<ReadContext> listener) {
+        throw new UnsupportedOperationException();
+    }
+
     // package private for testing
     long getLargeBlobThresholdInBytes() {
         return blobStore.bufferSizeInBytes();
diff --git a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java
index 2438acaf7c1f2..1c4936cae7eba 100644
--- a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java
+++ b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java
@@ -61,6 +61,7 @@
 import software.amazon.awssdk.services.s3.model.UploadPartResponse;
 import software.amazon.awssdk.services.s3.paginators.ListObjectsV2Iterable;
 
+import org.opensearch.action.support.PlainActionFuture;
 import org.opensearch.common.blobstore.BlobContainer;
 import org.opensearch.common.blobstore.BlobMetadata;
 import org.opensearch.common.blobstore.BlobPath;
@@ -881,6 +882,17 @@ public void onFailure(Exception e) {}
         }
     }
 
+    public void testAsyncBlobDownload() {
+        final S3BlobStore blobStore = mock(S3BlobStore.class);
+        final BlobPath blobPath = mock(BlobPath.class);
+        final String blobName = "test-blob";
+
+        final UnsupportedOperationException e = expectThrows(UnsupportedOperationException.class, () -> {
+            final S3BlobContainer blobContainer = new S3BlobContainer(blobPath, blobStore);
+            blobContainer.readBlobAsync(blobName, new PlainActionFuture<>());
+        });
+    }
+
     public void testListBlobsByPrefixInLexicographicOrderWithNegativeLimit() throws IOException {
         testListBlobsByPrefixInLexicographicOrder(-5, 0, BlobContainer.BlobNameSortOrder.LEXICOGRAPHIC);
     }
diff --git a/server/src/internalClusterTest/java/org/opensearch/remotestore/multipart/mocks/MockFsVerifyingBlobContainer.java b/server/src/internalClusterTest/java/org/opensearch/remotestore/multipart/mocks/MockFsVerifyingBlobContainer.java
index d882220c9f4d7..887a4cc6ba9a8 100644
--- a/server/src/internalClusterTest/java/org/opensearch/remotestore/multipart/mocks/MockFsVerifyingBlobContainer.java
+++ b/server/src/internalClusterTest/java/org/opensearch/remotestore/multipart/mocks/MockFsVerifyingBlobContainer.java
@@ -14,6 +14,7 @@
 import org.opensearch.common.blobstore.VerifyingMultiStreamBlobContainer;
 import org.opensearch.common.blobstore.fs.FsBlobContainer;
 import org.opensearch.common.blobstore.fs.FsBlobStore;
+import org.opensearch.common.blobstore.stream.read.ReadContext;
 import org.opensearch.common.blobstore.stream.write.WriteContext;
 import org.opensearch.common.io.InputStreamContainer;
 import org.opensearch.core.action.ActionListener;
@@ -24,6 +25,8 @@
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.nio.file.StandardOpenOption;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicLong;
@@ -114,6 +117,27 @@ public void asyncBlobUpload(WriteContext writeContext, ActionListener<Void> comp
     }
 
+    @Override
+    public void readBlobAsync(String blobName, ActionListener<ReadContext> listener) {
+        new Thread(() -> {
+            try {
+                long contentLength = listBlobs().get(blobName).length();
+                long partSize = contentLength / 10;
+                int numberOfParts = (int) ((contentLength % partSize) == 0 ? contentLength / partSize : (contentLength / partSize) + 1);
+                List<InputStreamContainer> blobPartStreams = new ArrayList<>();
+                for (int partNumber = 0; partNumber < numberOfParts; partNumber++) {
+                    long offset = partNumber * partSize;
+                    InputStreamContainer blobPartStream = new InputStreamContainer(readBlob(blobName, offset, partSize), partSize, offset);
+                    blobPartStreams.add(blobPartStream);
+                }
+                ReadContext blobReadContext = new ReadContext(contentLength, blobPartStreams, null);
+                listener.onResponse(blobReadContext);
+            } catch (Exception e) {
+                listener.onFailure(e);
+            }
+        }).start();
+    }
+
     private boolean isSegmentFile(String filename) {
         return !filename.endsWith(".tlog") && !filename.endsWith(".ckp");
     }
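
The mock above fans the blob out into roughly ten parts: it derives a part size from the blob length, rounds the part count up when the length is not an exact multiple, and positions each part's stream at partNumber * partSize. The short, self-contained sketch below is illustrative only (the class name and the 95-byte blob length are invented for the walkthrough; it is not part of this change) and simply traces that arithmetic:

public final class PartitionMathSketch {
    public static void main(String[] args) {
        long contentLength = 95;            // hypothetical blob size in bytes
        long partSize = contentLength / 10; // 9 bytes per part, mirroring the mock container
        int numberOfParts = (int) ((contentLength % partSize) == 0
            ? contentLength / partSize
            : (contentLength / partSize) + 1); // 95 % 9 != 0, so 10 + 1 = 11 parts
        for (int partNumber = 0; partNumber < numberOfParts; partNumber++) {
            long offset = partNumber * partSize; // 0, 9, 18, ..., 90
            // The mock wraps readBlob(blobName, offset, partSize) for each offset;
            // the last part only carries the remaining 95 - 90 = 5 bytes.
            System.out.println("part " + partNumber + " starts at offset " + offset);
        }
    }
}
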
diff --git a/server/src/main/java/org/opensearch/common/blobstore/VerifyingMultiStreamBlobContainer.java b/server/src/main/java/org/opensearch/common/blobstore/VerifyingMultiStreamBlobContainer.java
index d10445ba14d76..1764c9e634781 100644
--- a/server/src/main/java/org/opensearch/common/blobstore/VerifyingMultiStreamBlobContainer.java
+++ b/server/src/main/java/org/opensearch/common/blobstore/VerifyingMultiStreamBlobContainer.java
@@ -8,10 +8,15 @@
 
 package org.opensearch.common.blobstore;
 
+import org.opensearch.common.annotation.ExperimentalApi;
+import org.opensearch.common.blobstore.stream.read.ReadContext;
+import org.opensearch.common.blobstore.stream.read.listener.ReadContextListener;
 import org.opensearch.common.blobstore.stream.write.WriteContext;
 import org.opensearch.core.action.ActionListener;
+import org.opensearch.threadpool.ThreadPool;
 
 import java.io.IOException;
+import java.nio.file.Path;
 
 /**
  * An extension of {@link BlobContainer} that adds {@link VerifyingMultiStreamBlobContainer#asyncBlobUpload} to allow
@@ -31,4 +36,25 @@ public interface VerifyingMultiStreamBlobContainer extends BlobContainer {
      * @throws IOException if any of the input streams could not be read, or the target blob could not be written to
      */
     void asyncBlobUpload(WriteContext writeContext, ActionListener<Void> completionListener) throws IOException;
+
+    /**
+     * Creates an async callback of a {@link ReadContext} containing the multipart streams for a specified blob within the container.
+     * @param blobName The name of the blob for which the {@link ReadContext} needs to be fetched.
+     * @param listener Async listener for {@link ReadContext} object which serves the input streams and other metadata for the blob
+     */
+    @ExperimentalApi
+    void readBlobAsync(String blobName, ActionListener<ReadContext> listener);
+
+    /**
+     * Asynchronously downloads the blob to the specified location using an executor from the thread pool.
+     * @param blobName The name of the blob which needs to be downloaded.
+     * @param fileLocation The path on local disk where the blob needs to be downloaded.
+     * @param threadPool The threadpool instance which will provide the executor for performing a multipart download.
+     * @param completionListener Listener which will be notified when the download is complete.
+     */
+    @ExperimentalApi
+    default void asyncBlobDownload(String blobName, Path fileLocation, ThreadPool threadPool, ActionListener<String> completionListener) {
+        ReadContextListener readContextListener = new ReadContextListener(blobName, fileLocation, threadPool, completionListener);
+        readBlobAsync(blobName, readContextListener);
+    }
 }
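
For callers, the new default asyncBlobDownload above is the convenience entry point: it builds a ReadContextListener and hands it to readBlobAsync. The following sketch is illustrative only (the helper class and method names are invented, and it assumes a VerifyingMultiStreamBlobContainer implementation that supports readBlobAsync plus an available ThreadPool); it shows one way a caller might drive the API and block on completion:

import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.common.blobstore.VerifyingMultiStreamBlobContainer;
import org.opensearch.threadpool.ThreadPool;

import java.nio.file.Path;

// Hypothetical helper; not part of this change.
public final class BlobDownloadExample {

    // Blocks until the blob named blobName has been written to fileLocation,
    // returning the file name reported by the completion listener.
    public static String downloadBlocking(
        VerifyingMultiStreamBlobContainer container,
        String blobName,
        Path fileLocation,
        ThreadPool threadPool
    ) {
        // PlainActionFuture<String> doubles as the ActionListener<String> that
        // asyncBlobDownload notifies once every part has been written.
        PlainActionFuture<String> completionFuture = new PlainActionFuture<>();
        container.asyncBlobDownload(blobName, fileLocation, threadPool, completionFuture);
        return completionFuture.actionGet(); // rethrows any failure passed to onFailure
    }

    private BlobDownloadExample() {}
}

The listener fires only after every part has been written at its own offset by the per-part writers introduced later in this change, so a successful return means the file at fileLocation is complete.
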
diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java
new file mode 100644
index 0000000000000..4ba17959f8040
--- /dev/null
+++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java
@@ -0,0 +1,46 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.common.blobstore.stream.read;
+
+import org.opensearch.common.annotation.ExperimentalApi;
+import org.opensearch.common.io.InputStreamContainer;
+
+import java.util.List;
+
+/**
+ * ReadContext is used to encapsulate all data needed by BlobContainer#readBlobAsync
+ */
+@ExperimentalApi
+public class ReadContext {
+    private final long blobSize;
+    private final List<InputStreamContainer> partStreams;
+    private final String blobChecksum;
+
+    public ReadContext(long blobSize, List<InputStreamContainer> partStreams, String blobChecksum) {
+        this.blobSize = blobSize;
+        this.partStreams = partStreams;
+        this.blobChecksum = blobChecksum;
+    }
+
+    public String getBlobChecksum() {
+        return blobChecksum;
+    }
+
+    public int getNumberOfParts() {
+        return partStreams.size();
+    }
+
+    public long getBlobSize() {
+        return blobSize;
+    }
+
+    public List<InputStreamContainer> getPartStreams() {
+        return partStreams;
+    }
+}
diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListener.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListener.java
new file mode 100644
index 0000000000000..aadd6e2ab304e
--- /dev/null
+++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListener.java
@@ -0,0 +1,47 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.common.blobstore.stream.read.listener;
+
+import org.opensearch.common.annotation.InternalApi;
+import org.opensearch.core.action.ActionListener;
+
+import java.util.concurrent.atomic.AtomicInteger;
+
+/**
+ * FileCompletionListener listens for completion of fetch on all the streams for a file, where
+ * individual streams are handled using {@link FilePartWriter}. The {@link FilePartWriter}(s)
+ * hold a reference to the file completion listener to be notified.
+ */
+@InternalApi
+class FileCompletionListener implements ActionListener<Integer> {
+
+    private final int numberOfParts;
+    private final String fileName;
+    private final AtomicInteger completedPartsCount;
+    private final ActionListener<String> completionListener;
+
+    public FileCompletionListener(int numberOfParts, String fileName, ActionListener<String> completionListener) {
+        this.completedPartsCount = new AtomicInteger();
+        this.numberOfParts = numberOfParts;
+        this.fileName = fileName;
+        this.completionListener = completionListener;
+    }
+
+    @Override
+    public void onResponse(Integer unused) {
+        if (completedPartsCount.incrementAndGet() == numberOfParts) {
+            completionListener.onResponse(fileName);
+        }
+    }
+
+    @Override
+    public void onFailure(Exception e) {
+        completionListener.onFailure(e);
+    }
+}
diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriter.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriter.java
new file mode 100644
index 0000000000000..84fd7ed9ffebf
--- /dev/null
+++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriter.java
@@ -0,0 +1,90 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.common.blobstore.stream.read.listener;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.opensearch.common.annotation.InternalApi;
+import org.opensearch.common.io.Channels;
+import org.opensearch.common.io.InputStreamContainer;
+import org.opensearch.core.action.ActionListener;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.channels.FileChannel;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.StandardOpenOption;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+/**
+ * FilePartWriter transfers the provided stream into the specified file path using a {@link FileChannel}
+ * instance. It performs offset based writes to the file and notifies the {@link FileCompletionListener} on completion.
+ */
+@InternalApi
+class FilePartWriter implements Runnable {
+
+    private final int partNumber;
+    private final InputStreamContainer blobPartStreamContainer;
+    private final Path fileLocation;
+    private final AtomicBoolean anyPartStreamFailed;
+    private final ActionListener<Integer> fileCompletionListener;
+    private static final Logger logger = LogManager.getLogger(FilePartWriter.class);
+
+    // 8 MB buffer for transfer
+    private static final int BUFFER_SIZE = 8 * 1024 * 1024;
+
+    public FilePartWriter(
+        int partNumber,
+        InputStreamContainer blobPartStreamContainer,
+        Path fileLocation,
+        AtomicBoolean anyPartStreamFailed,
+        ActionListener<Integer> fileCompletionListener
+    ) {
+        this.partNumber = partNumber;
+        this.blobPartStreamContainer = blobPartStreamContainer;
+        this.fileLocation = fileLocation;
+        this.anyPartStreamFailed = anyPartStreamFailed;
+        this.fileCompletionListener = fileCompletionListener;
+    }
+
+    @Override
+    public void run() {
+        // Ensures no writes to the file if any stream fails.
+        if (anyPartStreamFailed.get() == false) {
+            try (FileChannel outputFileChannel = FileChannel.open(fileLocation, StandardOpenOption.WRITE, StandardOpenOption.CREATE)) {
+                try (InputStream inputStream = blobPartStreamContainer.getInputStream()) {
+                    long streamOffset = blobPartStreamContainer.getOffset();
+                    final byte[] buffer = new byte[BUFFER_SIZE];
+                    int bytesRead;
+                    while ((bytesRead = inputStream.read(buffer)) != -1) {
+                        Channels.writeToChannel(buffer, 0, bytesRead, outputFileChannel, streamOffset);
+                        streamOffset += bytesRead;
+                    }
+                }
+            } catch (IOException e) {
+                processFailure(e);
+                return;
+            }
+            fileCompletionListener.onResponse(partNumber);
+        }
+    }
+
+    void processFailure(Exception e) {
+        try {
+            Files.deleteIfExists(fileLocation);
+        } catch (IOException ex) {
+            // Die silently
+            logger.info("Failed to delete file {} on stream failure: {}", fileLocation, ex);
+        }
+        if (anyPartStreamFailed.getAndSet(true) == false) {
+            fileCompletionListener.onFailure(e);
+        }
+    }
+}
diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java
new file mode 100644
index 0000000000000..4338bddb3fbe7
--- /dev/null
+++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java
@@ -0,0 +1,65 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.common.blobstore.stream.read.listener;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.opensearch.common.annotation.InternalApi;
+import org.opensearch.common.blobstore.stream.read.ReadContext;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.threadpool.ThreadPool;
+
+import java.nio.file.Path;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+/**
+ * ReadContextListener orchestrates the async file fetch from the {@link org.opensearch.common.blobstore.BlobContainer}
+ * using a {@link ReadContext} callback. On response, it spawns off the download using multiple streams which are
+ * spread across a {@link ThreadPool} executor.
+ */
+@InternalApi
+public class ReadContextListener implements ActionListener<ReadContext> {
+
+    private final String fileName;
+    private final Path fileLocation;
+    private final ThreadPool threadPool;
+    private final ActionListener<String> completionListener;
+    private static final Logger logger = LogManager.getLogger(ReadContextListener.class);
+
+    public ReadContextListener(String fileName, Path fileLocation, ThreadPool threadPool, ActionListener<String> completionListener) {
+        this.fileName = fileName;
+        this.fileLocation = fileLocation;
+        this.threadPool = threadPool;
+        this.completionListener = completionListener;
+    }
+
+    @Override
+    public void onResponse(ReadContext readContext) {
+        logger.trace("Streams received for blob {}", fileName);
+        final int numParts = readContext.getNumberOfParts();
+        final AtomicBoolean anyPartStreamFailed = new AtomicBoolean();
+        FileCompletionListener fileCompletionListener = new FileCompletionListener(numParts, fileName, completionListener);
+
+        for (int partNumber = 0; partNumber < numParts; partNumber++) {
+            FilePartWriter filePartWriter = new FilePartWriter(
+                partNumber,
+                readContext.getPartStreams().get(partNumber),
+                fileLocation,
+                anyPartStreamFailed,
+                fileCompletionListener
+            );
+            threadPool.executor(ThreadPool.Names.GENERIC).submit(filePartWriter);
+        }
+    }
+
+    @Override
+    public void onFailure(Exception e) {
+        completionListener.onFailure(e);
+    }
+}
diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/package-info.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/package-info.java
new file mode 100644
index 0000000000000..fe670fe3eb25c
--- /dev/null
+++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/package-info.java
@@ -0,0 +1,14 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+/**
+ * Abstractions for stream based file reads from the blob store.
+ * Provides listeners for performing the necessary async read operations to perform
+ * multi stream reads for blobs from the container.
+ * */
+package org.opensearch.common.blobstore.stream.read.listener;
diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/package-info.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/package-info.java
new file mode 100644
index 0000000000000..a9e2ca35c1fa6
--- /dev/null
+++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/package-info.java
@@ -0,0 +1,13 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+/**
+ * Abstractions for stream based file reads from the blob store.
+ * Provides support for async reads from the blob container.
+ * */
+package org.opensearch.common.blobstore.stream.read;
diff --git a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListenerTests.java b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListenerTests.java
new file mode 100644
index 0000000000000..fa13d90f42fa6
--- /dev/null
+++ b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListenerTests.java
@@ -0,0 +1,58 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.common.blobstore.stream.read.listener;
+
+import org.opensearch.test.OpenSearchTestCase;
+
+import java.io.IOException;
+
+import static org.opensearch.common.blobstore.stream.read.listener.ListenerTestUtils.CountingCompletionListener;
+
+public class FileCompletionListenerTests extends OpenSearchTestCase {
+
+    public void testFileCompletionListener() {
+        int numStreams = 10;
+        String fileName = "test_segment_file";
+        CountingCompletionListener<String> completionListener = new CountingCompletionListener<>();
+        FileCompletionListener fileCompletionListener = new FileCompletionListener(numStreams, fileName, completionListener);
+
+        for (int stream = 0; stream < numStreams; stream++) {
+            // Ensure completion listener called only when all streams are completed
+            assertEquals(0, completionListener.getResponseCount());
+            fileCompletionListener.onResponse(null);
+        }
+
+        assertEquals(1, completionListener.getResponseCount());
+        assertEquals(fileName, completionListener.getResponse());
+    }
+
+    public void testFileCompletionListenerFailure() {
+        int numStreams = 10;
+        String fileName = "test_segment_file";
+        CountingCompletionListener<String> completionListener = new CountingCompletionListener<>();
+        FileCompletionListener fileCompletionListener = new FileCompletionListener(numStreams, fileName, completionListener);
+
+        // Fail the listener initially
+        IOException exception = new IOException();
+        fileCompletionListener.onFailure(exception);
+
+        for (int stream = 0; stream < numStreams - 1; stream++) {
+            assertEquals(0, completionListener.getResponseCount());
+            fileCompletionListener.onResponse(null);
+        }
+
+        assertEquals(1, completionListener.getFailureCount());
+        assertEquals(exception, completionListener.getException());
+        assertEquals(0, completionListener.getResponseCount());
+
+        fileCompletionListener.onFailure(exception);
+        assertEquals(2, completionListener.getFailureCount());
+        assertEquals(exception, completionListener.getException());
+    }
+}
diff --git a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriterTests.java b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriterTests.java
new file mode 100644
index 0000000000000..811566eb5767b
--- /dev/null
+++ b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriterTests.java
@@ -0,0 +1,163 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.common.blobstore.stream.read.listener;
+
+import org.opensearch.common.io.InputStreamContainer;
+import org.opensearch.test.OpenSearchTestCase;
+import org.junit.Before;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.UUID;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import static org.opensearch.common.blobstore.stream.read.listener.ListenerTestUtils.CountingCompletionListener;
+
+public class FilePartWriterTests extends OpenSearchTestCase {
+
+    private Path path;
+
+    @Before
+    public void init() throws Exception {
+        path = createTempDir("FilePartWriterTests");
+    }
+
+    public void testFilePartWriter() throws Exception {
+        Path segmentFilePath = path.resolve(UUID.randomUUID().toString());
+        int contentLength = 100;
+        int partNumber = 1;
+        InputStream inputStream = new ByteArrayInputStream(randomByteArrayOfLength(contentLength));
+        InputStreamContainer inputStreamContainer = new InputStreamContainer(inputStream, inputStream.available(), 0);
+        AtomicBoolean anyStreamFailed = new AtomicBoolean();
+        CountingCompletionListener<Integer> fileCompletionListener = new CountingCompletionListener<>();
+
+        FilePartWriter filePartWriter = new FilePartWriter(
+            partNumber,
+            inputStreamContainer,
+            segmentFilePath,
+            anyStreamFailed,
+            fileCompletionListener
+        );
+        filePartWriter.run();
+
+        assertTrue(Files.exists(segmentFilePath));
+        assertEquals(contentLength, Files.size(segmentFilePath));
+        assertEquals(1, fileCompletionListener.getResponseCount());
+        assertEquals(Integer.valueOf(partNumber), fileCompletionListener.getResponse());
+    }
+
+    public void testFilePartWriterWithOffset() throws Exception {
+        Path segmentFilePath = path.resolve(UUID.randomUUID().toString());
+        int contentLength = 100;
+        int offset = 10;
+        int partNumber = 1;
+        InputStream inputStream = new ByteArrayInputStream(randomByteArrayOfLength(contentLength));
+        InputStreamContainer inputStreamContainer = new InputStreamContainer(inputStream, inputStream.available(), offset);
+        AtomicBoolean anyStreamFailed = new AtomicBoolean();
+        CountingCompletionListener<Integer> fileCompletionListener = new CountingCompletionListener<>();
+
+        FilePartWriter filePartWriter = new FilePartWriter(
+            partNumber,
+            inputStreamContainer,
+            segmentFilePath,
+            anyStreamFailed,
+            fileCompletionListener
+        );
+        filePartWriter.run();
+
+        assertTrue(Files.exists(segmentFilePath));
+        assertEquals(contentLength + offset, Files.size(segmentFilePath));
+        assertEquals(1, fileCompletionListener.getResponseCount());
+        assertEquals(Integer.valueOf(partNumber), fileCompletionListener.getResponse());
+    }
+
+    public void testFilePartWriterLargeInput() throws Exception {
+        Path segmentFilePath = path.resolve(UUID.randomUUID().toString());
+        int contentLength = 20 * 1024 * 1024;
+        int partNumber = 1;
+        InputStream inputStream = new ByteArrayInputStream(randomByteArrayOfLength(contentLength));
+        InputStreamContainer inputStreamContainer = new InputStreamContainer(inputStream, contentLength, 0);
+        AtomicBoolean anyStreamFailed = new AtomicBoolean();
+        CountingCompletionListener<Integer> fileCompletionListener = new CountingCompletionListener<>();
+
+        FilePartWriter filePartWriter = new FilePartWriter(
+            partNumber,
+            inputStreamContainer,
+            segmentFilePath,
+            anyStreamFailed,
+            fileCompletionListener
+        );
+        filePartWriter.run();
+
+        assertTrue(Files.exists(segmentFilePath));
+        assertEquals(contentLength, Files.size(segmentFilePath));
+
+        assertEquals(1, fileCompletionListener.getResponseCount());
+        assertEquals(Integer.valueOf(partNumber), fileCompletionListener.getResponse());
+    }
+
+    public void testFilePartWriterException() throws Exception {
+        Path segmentFilePath = path.resolve(UUID.randomUUID().toString());
+        int contentLength = 100;
+        int partNumber = 1;
+        InputStream inputStream = new ByteArrayInputStream(randomByteArrayOfLength(contentLength));
+        InputStreamContainer inputStreamContainer = new InputStreamContainer(inputStream, contentLength, 0);
+        AtomicBoolean anyStreamFailed = new AtomicBoolean();
+        CountingCompletionListener<Integer> fileCompletionListener = new CountingCompletionListener<>();
+
+        IOException ioException = new IOException();
+        FilePartWriter filePartWriter = new FilePartWriter(
+            partNumber,
+            inputStreamContainer,
+            segmentFilePath,
+            anyStreamFailed,
+            fileCompletionListener
+        );
+        assertFalse(anyStreamFailed.get());
+        filePartWriter.processFailure(ioException);
+
+        assertTrue(anyStreamFailed.get());
+        assertFalse(Files.exists(segmentFilePath));
+
+        // Fail stream again to simulate another stream failure for same file
+        filePartWriter.processFailure(ioException);
+
+        assertTrue(anyStreamFailed.get());
+        assertFalse(Files.exists(segmentFilePath));
+
+        assertEquals(0, fileCompletionListener.getResponseCount());
+        assertEquals(1, fileCompletionListener.getFailureCount());
+        assertEquals(ioException, fileCompletionListener.getException());
+    }
+
+    public void testFilePartWriterStreamFailed() throws Exception {
+        Path segmentFilePath = path.resolve(UUID.randomUUID().toString());
+        int contentLength = 100;
+        int partNumber = 1;
+        InputStream inputStream = new ByteArrayInputStream(randomByteArrayOfLength(contentLength));
+        InputStreamContainer inputStreamContainer = new InputStreamContainer(inputStream, inputStream.available(), 0);
+        AtomicBoolean anyStreamFailed = new AtomicBoolean(true);
+        CountingCompletionListener<Integer> fileCompletionListener = new CountingCompletionListener<>();
+
+        FilePartWriter filePartWriter = new FilePartWriter(
+            partNumber,
+            inputStreamContainer,
+            segmentFilePath,
+            anyStreamFailed,
+            fileCompletionListener
+        );
+        filePartWriter.run();
+
+        assertFalse(Files.exists(segmentFilePath));
+        assertEquals(0, fileCompletionListener.getResponseCount());
+    }
+}
diff --git a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ListenerTestUtils.java b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ListenerTestUtils.java
new file mode 100644
index 0000000000000..1e9450c83e3ab
--- /dev/null
+++ b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ListenerTestUtils.java
@@ -0,0 +1,56 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.common.blobstore.stream.read.listener;
+
+import org.opensearch.core.action.ActionListener;
+
+/**
+ * Utility class containing common functionality for read listener based tests
+ */
+public class ListenerTestUtils {
+
+    /**
+     * CountingCompletionListener acts as a verification instance for wrapping listener based calls.
+     * Keeps track of the last response, failure and count of response and failure invocations.
+     */
+    static class CountingCompletionListener<T> implements ActionListener<T> {
+        private int responseCount;
+        private int failureCount;
+        private T response;
+        private Exception exception;
+
+        @Override
+        public void onResponse(T response) {
+            this.response = response;
+            responseCount++;
+        }
+
+        @Override
+        public void onFailure(Exception e) {
+            exception = e;
+            failureCount++;
+        }
+
+        public int getResponseCount() {
+            return responseCount;
+        }
+
+        public int getFailureCount() {
+            return failureCount;
+        }
+
+        public T getResponse() {
+            return response;
+        }
+
+        public Exception getException() {
+            return exception;
+        }
+    }
+}
diff --git a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java
new file mode 100644
index 0000000000000..f785b5f1191b4
--- /dev/null
+++ b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java
@@ -0,0 +1,124 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.common.blobstore.stream.read.listener;
+
+import org.opensearch.action.LatchedActionListener;
+import org.opensearch.action.support.PlainActionFuture;
+import org.opensearch.common.blobstore.stream.read.ReadContext;
+import org.opensearch.common.io.InputStreamContainer;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.test.OpenSearchTestCase;
+import org.opensearch.threadpool.TestThreadPool;
+import org.opensearch.threadpool.ThreadPool;
+import org.junit.AfterClass;
+import org.junit.Before;
+import org.junit.BeforeClass;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.UUID;
+import java.util.concurrent.CountDownLatch;
+
+import static org.opensearch.common.blobstore.stream.read.listener.ListenerTestUtils.CountingCompletionListener;
+
+public class ReadContextListenerTests extends OpenSearchTestCase {
+
+    private Path path;
+    private static ThreadPool threadPool;
+    private static final int NUMBER_OF_PARTS = 5;
+    private static final int PART_SIZE = 10;
+    private static final String TEST_SEGMENT_FILE = "test_segment_file";
+
+    @BeforeClass
+    public static void setup() {
+        threadPool = new TestThreadPool(ReadContextListenerTests.class.getName());
+    }
+
+    @AfterClass
+    public static void cleanup() {
+        threadPool.shutdown();
+    }
+
+    @Before
+    public void init() throws Exception {
+        path = createTempDir("ReadContextListenerTests");
+    }
+
+    public void testReadContextListener() throws InterruptedException, IOException {
+        Path fileLocation = path.resolve(UUID.randomUUID().toString());
+        List<InputStreamContainer> blobPartStreams = initializeBlobPartStreams();
+        CountDownLatch countDownLatch = new CountDownLatch(1);
+        ActionListener<String> completionListener = new LatchedActionListener<>(new PlainActionFuture<>(), countDownLatch);
+        ReadContextListener readContextListener = new ReadContextListener(TEST_SEGMENT_FILE, fileLocation, threadPool, completionListener);
+        ReadContext readContext = new ReadContext((long) PART_SIZE * NUMBER_OF_PARTS, blobPartStreams, null);
+        readContextListener.onResponse(readContext);
+
+        countDownLatch.await();
+
+        assertTrue(Files.exists(fileLocation));
+        assertEquals(NUMBER_OF_PARTS * PART_SIZE, Files.size(fileLocation));
+    }
+
+    public void testReadContextListenerFailure() throws InterruptedException {
+        Path fileLocation = path.resolve(UUID.randomUUID().toString());
+        List<InputStreamContainer> blobPartStreams = initializeBlobPartStreams();
+        CountDownLatch countDownLatch = new CountDownLatch(1);
+        ActionListener<String> completionListener = new LatchedActionListener<>(new PlainActionFuture<>(), countDownLatch);
+        ReadContextListener readContextListener = new ReadContextListener(TEST_SEGMENT_FILE, fileLocation, threadPool, completionListener);
+        InputStream badInputStream = new InputStream() {
+
+            @Override
+            public int read(byte[] b, int off, int len) throws IOException {
+                return read();
+            }
+
+            @Override
+            public int read() throws IOException {
+                throw new IOException();
+            }
+
+            @Override
+            public int available() {
+                return PART_SIZE;
+            }
+        };
+
+        blobPartStreams.add(NUMBER_OF_PARTS, new InputStreamContainer(badInputStream, PART_SIZE, PART_SIZE * NUMBER_OF_PARTS));
+        ReadContext readContext = new ReadContext((long) (PART_SIZE + 1) * NUMBER_OF_PARTS, blobPartStreams, null);
+        readContextListener.onResponse(readContext);
+
+        countDownLatch.await();
+
+        assertFalse(Files.exists(fileLocation));
+    }
+
+    public void testReadContextListenerException() {
+        Path fileLocation = path.resolve(UUID.randomUUID().toString());
+        CountingCompletionListener<String> listener = new CountingCompletionListener<>();
+        ReadContextListener readContextListener = new ReadContextListener(TEST_SEGMENT_FILE, fileLocation, threadPool, listener);
+        IOException exception = new IOException();
+        readContextListener.onFailure(exception);
+        assertEquals(1, listener.getFailureCount());
+        assertEquals(exception, listener.getException());
+    }
+
+    private List<InputStreamContainer> initializeBlobPartStreams() {
+        List<InputStreamContainer> blobPartStreams = new ArrayList<>();
+        for (int partNumber = 0; partNumber < NUMBER_OF_PARTS; partNumber++) {
+            InputStream testStream = new ByteArrayInputStream(randomByteArrayOfLength(PART_SIZE));
+            blobPartStreams.add(new InputStreamContainer(testStream, PART_SIZE, (long) partNumber * PART_SIZE));
+        }
+        return blobPartStreams;
+    }
+}