Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce CopyOnReadInputStream with CachedBlobContainer #59872

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/

package org.elasticsearch.blobstore.cache;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.CheckedSupplier;
import org.elasticsearch.common.blobstore.BlobContainer;
import org.elasticsearch.common.blobstore.support.FilterBlobContainer;
import org.elasticsearch.common.bytes.PagedBytesReference;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.util.ByteArray;

import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;

public class CachedBlobContainer extends FilterBlobContainer {

    /** Default size (16 KiB) of the {@link ByteArray} used to cache the leading bytes of a blob. */
    protected static final int DEFAULT_BYTE_ARRAY_SIZE = 1 << 14;

    public CachedBlobContainer(BlobContainer delegate) {
        super(delegate);
    }

    @Override
    protected BlobContainer wrapChild(BlobContainer child) {
        return new CachedBlobContainer(child);
    }

    /**
     * A {@link FilterInputStream} that copies over all the bytes read from the original input stream to a given {@link ByteArray}. The
     * number of bytes copied cannot exceed the size of the {@link ByteArray}.
     *
     * On {@link #close()} the listener is notified exactly once: with a {@link ReleasableBytesReference} over the copied bytes on
     * success (or when the {@link ByteArray} was fully filled before a failure occurred), otherwise with the first {@link IOException}
     * recorded while reading.
     */
    static class CopyOnReadInputStream extends FilterInputStream {

        /** Notified once on close, with either the copied bytes or the recorded failure. */
        private final ActionListener<ReleasableBytesReference> listener;
        /** Ensures the listener is notified at most once even if close() is called repeatedly. */
        private final AtomicBoolean closed;
        /** Destination buffer; only the first {@code bytes.size()} bytes read are copied. */
        private final ByteArray bytes;

        /** First IOException thrown by the delegate stream, if any; reported to the listener on close. */
        private IOException failure;
        /** Number of bytes consumed from the delegate so far (via reads and skips). */
        private long count;
        /** Value of {@link #count} when mark() was last called; restored by reset(). */
        private long mark;

        protected CopyOnReadInputStream(InputStream in, ByteArray byteArray, ActionListener<ReleasableBytesReference> listener) {
            super(in);
            this.listener = Objects.requireNonNull(listener);
            this.bytes = Objects.requireNonNull(byteArray);
            this.closed = new AtomicBoolean(false);
        }

        /**
         * Runs the supplier, recording the first {@link IOException} it throws so that {@link #close()} can report it to the listener.
         */
        private <T> T handleFailure(CheckedSupplier<T, IOException> supplier) throws IOException {
            try {
                return supplier.get();
            } catch (IOException e) {
                assert failure == null;
                failure = e;
                throw e;
            }
        }

        @Override
        public int read() throws IOException {
            final int result = handleFailure(super::read);
            if (result != -1) {
                if (count < bytes.size()) {
                    bytes.set(count, (byte) result);
                }
                count++;
            }
            return result;
        }

        @Override
        public int read(byte[] b, int off, int len) throws IOException {
            final int result = handleFailure(() -> super.read(b, off, len));
            if (result != -1) {
                if (count < bytes.size()) {
                    // copy only up to the remaining capacity of the byte array
                    bytes.set(count, b, off, Math.toIntExact(Math.min(bytes.size() - count, result)));
                }
                count += result;
            }
            return result;
        }

        @Override
        public long skip(long n) throws IOException {
            final long skip = handleFailure(() -> super.skip(n));
            if (skip > 0L) {
                // NOTE(review): skipped bytes advance count but are never written into the byte array, so a skip that lands
                // inside the cached region would leave a gap of uninitialized bytes — confirm callers never skip within it.
                count += skip;
            }
            return skip;
        }

        @Override
        public synchronized void mark(int readlimit) {
            super.mark(readlimit);
            mark = count;
        }

        @Override
        public synchronized void reset() throws IOException {
            handleFailure(() -> {
                super.reset();
                return null;
            });
            count = mark;
        }

        @Override
        public final void close() throws IOException {
            if (closed.compareAndSet(false, true)) {
                boolean success = false;
                try {
                    super.close();
                    // succeed when no failure occurred, or when the byte array was already fully filled before the failure
                    if (failure == null || bytes.size() <= count) {
                        PagedBytesReference reference = new PagedBytesReference(bytes, Math.toIntExact(Math.min(count, bytes.size())));
                        listener.onResponse(new ReleasableBytesReference(reference, bytes));
                        success = true;
                    } else {
                        listener.onFailure(failure);
                    }
                } finally {
                    if (success == false) {
                        // the byte array was not handed off to the listener; release it here to avoid leaking it
                        bytes.close();
                    }
                }
            }
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/

package org.elasticsearch.blobstore.cache;

import org.apache.lucene.util.SetOnce;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.blobstore.cache.CachedBlobContainer.CopyOnReadInputStream;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.ByteArray;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.common.util.MockPageCacheRecycler;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.test.ESTestCase;

import java.io.ByteArrayInputStream;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;

import static org.elasticsearch.blobstore.cache.CachedBlobContainer.DEFAULT_BYTE_ARRAY_SIZE;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;

public class CachedBlobContainerTests extends ESTestCase {

    private final MockBigArrays bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());

    /**
     * Verifies that the stream never copies more bytes into the {@link ByteArray} than the array can hold,
     * even when the blob content is larger than the array.
     */
    public void testCopyOnReadInputStreamDoesNotCopyMoreThanByteArraySize() throws Exception {
        final SetOnce<ReleasableBytesReference> onSuccess = new SetOnce<>();
        final SetOnce<Exception> onFailure = new SetOnce<>();
        final ActionListener<ReleasableBytesReference> listener = ActionListener.wrap(onSuccess::set, onFailure::set);

        final byte[] blobContent = randomByteArray();

        final ByteArray byteArray = bigArrays.newByteArray(randomIntBetween(0, DEFAULT_BYTE_ARRAY_SIZE));
        final InputStream stream = new CopyOnReadInputStream(new ByteArrayInputStream(blobContent), byteArray, listener);
        randomReads(stream, blobContent.length);
        stream.close();

        final ReleasableBytesReference releasable = onSuccess.get();
        assertThat(releasable, notNullValue());
        assertThat(releasable.length(), equalTo(Math.toIntExact(Math.min(blobContent.length, byteArray.size()))));
        assertArrayEquals(Arrays.copyOfRange(blobContent, 0, releasable.length()), BytesReference.toBytes(releasable));
        releasable.close();

        final Exception failure = onFailure.get();
        assertThat(failure, nullValue());
    }

    /**
     * Verifies that the bytes copied on close match exactly the bytes that were read from the stream.
     */
    public void testCopyOnReadInputStream() throws Exception {
        final SetOnce<ReleasableBytesReference> onSuccess = new SetOnce<>();
        final SetOnce<Exception> onFailure = new SetOnce<>();
        final ActionListener<ReleasableBytesReference> listener = ActionListener.wrap(onSuccess::set, onFailure::set);

        final byte[] blobContent = randomByteArray();
        final ByteArray byteArray = bigArrays.newByteArray(DEFAULT_BYTE_ARRAY_SIZE);

        final int maxBytesToRead = randomIntBetween(0, blobContent.length);
        final InputStream stream = new CopyOnReadInputStream(new ByteArrayInputStream(blobContent), byteArray, listener);
        randomReads(stream, maxBytesToRead);
        stream.close();

        final ReleasableBytesReference releasable = onSuccess.get();
        assertThat(releasable, notNullValue());
        assertThat(releasable.length(), equalTo((int) Math.min(maxBytesToRead, byteArray.size())));
        assertArrayEquals(Arrays.copyOfRange(blobContent, 0, releasable.length()), BytesReference.toBytes(releasable));
        releasable.close();

        final Exception failure = onFailure.get();
        assertThat(failure, nullValue());
    }

    /**
     * Verifies the close-time notification when the underlying stream fails: the listener gets the failure if the
     * byte array was not yet full, and gets the (complete) copied bytes if it was.
     */
    public void testCopyOnReadWithFailure() throws Exception {
        final SetOnce<ReleasableBytesReference> onSuccess = new SetOnce<>();
        final SetOnce<Exception> onFailure = new SetOnce<>();
        final ActionListener<ReleasableBytesReference> listener = ActionListener.wrap(onSuccess::set, onFailure::set);

        // fix: the random content was previously discarded (blobContent was always empty), making this test vacuous
        final byte[] blobContent = randomByteArray();

        // InputStream that throws an IOException once byte at position N is read/skipped
        final int failAfterNBytesRead = randomIntBetween(0, Math.max(0, blobContent.length - 1));
        final InputStream erroneousStream = new FilterInputStream(new ByteArrayInputStream(blobContent)) {

            long bytesRead;
            long mark;

            void canReadMoreBytes() throws IOException {
                if (failAfterNBytesRead <= bytesRead) {
                    throw new IOException("Cannot read more bytes");
                }
            }

            @Override
            public int read() throws IOException {
                canReadMoreBytes();
                final int read = super.read();
                if (read != -1) {
                    bytesRead++;
                }
                return read;
            }

            @Override
            public int read(byte[] b, int off, int len) throws IOException {
                canReadMoreBytes();
                final int read = super.read(b, off, Math.min(len, Math.toIntExact(failAfterNBytesRead - bytesRead)));
                if (read != -1) {
                    bytesRead += read;
                }
                return read;
            }

            @Override
            public long skip(long n) throws IOException {
                canReadMoreBytes();
                final long skipped = super.skip(Math.min(n, Math.toIntExact(failAfterNBytesRead - bytesRead)));
                if (skipped > 0L) {
                    bytesRead += skipped;
                }
                return skipped;
            }

            @Override
            public synchronized void reset() throws IOException {
                super.reset();
                bytesRead = mark;
            }

            @Override
            public synchronized void mark(int readlimit) {
                super.mark(readlimit);
                mark = bytesRead;
            }
        };

        final int byteSize = randomIntBetween(0, DEFAULT_BYTE_ARRAY_SIZE);
        try (InputStream stream = new CopyOnReadInputStream(erroneousStream, bigArrays.newByteArray(byteSize), listener)) {
            IOException exception = expectThrows(IOException.class, () -> randomReads(stream, Math.max(1, blobContent.length)));
            assertThat(exception.getMessage(), containsString("Cannot read more bytes"));
        }

        if (failAfterNBytesRead < byteSize) {
            // the failure happened before the byte array could be filled: the listener must be notified of the failure
            final Exception failure = onFailure.get();
            assertThat(failure, notNullValue());
            assertThat(failure.getMessage(), containsString("Cannot read more bytes"));
            assertThat(onSuccess.get(), nullValue());

        } else {
            // the byte array was fully filled before the failure: the copied bytes must still be delivered
            final ReleasableBytesReference releasable = onSuccess.get();
            assertThat(releasable, notNullValue());
            assertArrayEquals(Arrays.copyOfRange(blobContent, 0, byteSize), BytesReference.toBytes(releasable));
            assertThat(onFailure.get(), nullValue());
            releasable.close();
        }
    }

    /** Random blob content, usually up to {@link CachedBlobContainer#DEFAULT_BYTE_ARRAY_SIZE}, rarely up to 1 MiB. */
    private static byte[] randomByteArray() {
        return randomByteArrayOfLength(randomIntBetween(0, frequently() ? DEFAULT_BYTE_ARRAY_SIZE : 1 << 20)); // rarely up to 1mb;
    }

    /**
     * Reads at most {@code maxBytesToRead} bytes from the stream using a random mix of single-byte reads, buffered reads
     * and mark/skip/reset sequences. Returns early if the stream reaches end-of-stream.
     */
    private void randomReads(final InputStream stream, final int maxBytesToRead) throws IOException {
        int remaining = maxBytesToRead;
        while (remaining > 0) {
            int read;
            switch (randomInt(3)) {
                case 0: // single byte read
                    read = stream.read();
                    if (read != -1) {
                        remaining--;
                    } else {
                        return; // end of stream; avoid looping forever
                    }
                    break;
                case 1: // buffered read with fixed buffer offset/length
                    read = stream.read(new byte[randomIntBetween(1, remaining)]);
                    if (read != -1) {
                        remaining -= read;
                    } else {
                        return; // end of stream; avoid looping forever
                    }
                    break;
                case 2: // buffered read with random buffer offset/length
                    final byte[] tmp = new byte[randomIntBetween(1, remaining)];
                    final int off = randomIntBetween(0, tmp.length - 1);
                    // fix: was Math.min(1, tmp.length - off) which always produced length 1, defeating this case;
                    // off <= tmp.length - 1 guarantees tmp.length - off >= 1
                    read = stream.read(tmp, off, randomIntBetween(1, tmp.length - off));
                    if (read != -1) {
                        remaining -= read;
                    } else {
                        return; // end of stream; avoid looping forever
                    }
                    break;

                case 3: // mark & reset with intermediate skip()
                    final int toSkip = randomIntBetween(1, remaining);
                    stream.mark(toSkip);
                    stream.skip(toSkip);
                    stream.reset();
                    break;
                default:
                    fail("Unsupported test condition in " + getTestName());
            }
        }
    }
}