Skip to content

Commit

Permalink
remote: Proactively close the ZstdInputStream in ZstdDecompressingOut…
Browse files Browse the repository at this point in the history
…putStream.

ZstdInputStream hangs onto some native memory, which should be released as soon as ZstdDecompressingOutputStream is done being used rather than when the finalizer runs.

Closes #15061.

PiperOrigin-RevId: 438521302
  • Loading branch information
benjaminp authored and copybara-github committed Mar 31, 2022
1 parent 5b95286 commit 299022c
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 60 deletions.
1 change: 0 additions & 1 deletion src/main/java/com/google/devtools/build/lib/remote/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ java_library(
"//src/main/java/com/google/devtools/build/lib/vfs:pathfragment",
"//src/main/java/com/google/devtools/common/options",
"//src/main/protobuf:failure_details_java_proto",
"//third_party:apache_commons_compress",
"//third_party:auth",
"//third_party:caffeine",
"//third_party:flogger",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.flogger.GoogleLogger;
import com.google.common.io.CountingOutputStream;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
Expand Down Expand Up @@ -67,10 +68,8 @@
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import org.apache.commons.compress.utils.CountingOutputStream;

/** A RemoteActionCache implementation that uses gRPC calls to a remote cache server. */
@ThreadSafe
Expand Down Expand Up @@ -303,7 +302,7 @@ public ListenableFuture<Void> uploadActionResult(
public ListenableFuture<Void> downloadBlob(
RemoteActionExecutionContext context, Digest digest, OutputStream out) {
if (digest.getSizeBytes() == 0) {
return Futures.immediateFuture(null);
return Futures.immediateVoidFuture();
}

@Nullable Supplier<Digest> digestSupplier = null;
Expand All @@ -313,26 +312,14 @@ public ListenableFuture<Void> downloadBlob(
out = digestOut;
}

CountingOutputStream outputStream;
if (options.cacheCompression) {
try {
outputStream = new ZstdDecompressingOutputStream(out);
} catch (IOException e) {
return Futures.immediateFailedFuture(e);
}
} else {
outputStream = new CountingOutputStream(out);
}

return downloadBlob(context, digest, outputStream, digestSupplier);
return downloadBlob(context, digest, new CountingOutputStream(out), digestSupplier);
}

private ListenableFuture<Void> downloadBlob(
RemoteActionExecutionContext context,
Digest digest,
CountingOutputStream out,
@Nullable Supplier<Digest> digestSupplier) {
AtomicLong offset = new AtomicLong(0);
ProgressiveBackoff progressiveBackoff = new ProgressiveBackoff(retrier::newBackoff);
ListenableFuture<Long> downloadFuture =
Utils.refreshIfUnauthenticatedAsync(
Expand All @@ -343,7 +330,6 @@ private ListenableFuture<Void> downloadBlob(
channel ->
requestRead(
context,
offset,
progressiveBackoff,
digest,
out,
Expand All @@ -370,20 +356,25 @@ public static String getResourceName(String instanceName, Digest digest, boolean

private ListenableFuture<Long> requestRead(
RemoteActionExecutionContext context,
AtomicLong offset,
ProgressiveBackoff progressiveBackoff,
Digest digest,
CountingOutputStream out,
CountingOutputStream rawOut,
@Nullable Supplier<Digest> digestSupplier,
Channel channel) {
String resourceName =
getResourceName(options.remoteInstanceName, digest, options.cacheCompression);
SettableFuture<Long> future = SettableFuture.create();
OutputStream out;
try {
out = options.cacheCompression ? new ZstdDecompressingOutputStream(rawOut) : rawOut;
} catch (IOException e) {
return Futures.immediateFailedFuture(e);
}
bsAsyncStub(context, channel)
.read(
ReadRequest.newBuilder()
.setResourceName(resourceName)
.setReadOffset(offset.get())
.setReadOffset(rawOut.getCount())
.build(),
new StreamObserver<ReadResponse>() {

Expand All @@ -392,7 +383,6 @@ public void onNext(ReadResponse readResponse) {
ByteString data = readResponse.getData();
try {
data.writeTo(out);
offset.set(out.getBytesWritten());
} catch (IOException e) {
// Cancel the call.
throw new RuntimeException(e);
Expand All @@ -403,14 +393,15 @@ public void onNext(ReadResponse readResponse) {

@Override
public void onError(Throwable t) {
if (offset.get() == digest.getSizeBytes()) {
if (rawOut.getCount() == digest.getSizeBytes()) {
// If the file was fully downloaded, it doesn't matter if there was an error at
// the end of the stream.
logger.atInfo().withCause(t).log(
"ignoring error because file was fully received");
onCompleted();
return;
}
releaseOut();
Status status = Status.fromThrowable(t);
if (status.getCode() == Status.Code.NOT_FOUND) {
future.setException(new CacheNotFoundException(digest));
Expand All @@ -426,12 +417,24 @@ public void onCompleted() {
Utils.verifyBlobContents(digest, digestSupplier.get());
}
out.flush();
future.set(offset.get());
future.set(rawOut.getCount());
} catch (IOException e) {
future.setException(e);
} catch (RuntimeException e) {
logger.atWarning().withCause(e).log("Unexpected exception");
future.setException(e);
} finally {
releaseOut();
}
}

private void releaseOut() {
if (out instanceof ZstdDecompressingOutputStream) {
try {
((ZstdDecompressingOutputStream) out).closeShallow();
} catch (IOException e) {
logger.atWarning().withCause(e).log("failed to cleanly close output stream");
}
}
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ java_library(
name = "zstd",
srcs = glob(["*.java"]),
deps = [
"//third_party:apache_commons_compress",
"//third_party:guava",
"//third_party/protobuf:protobuf_java",
"@zstd-jni",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,35 +13,35 @@
// limitations under the License.
package com.google.devtools.build.lib.remote.zstd;

import com.github.luben.zstd.ZstdInputStream;
import com.github.luben.zstd.ZstdInputStreamNoFinalizer;
import com.google.protobuf.ByteString;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import org.apache.commons.compress.utils.CountingOutputStream;

/** A {@link CountingOutputStream} that use zstd to decompress the content. */
public class ZstdDecompressingOutputStream extends CountingOutputStream {
/** An {@link OutputStream} that use zstd to decompress the content. */
public final class ZstdDecompressingOutputStream extends OutputStream {
private final OutputStream out;
private ByteArrayInputStream inner;
private final ZstdInputStream zis;
private final ZstdInputStreamNoFinalizer zis;

public ZstdDecompressingOutputStream(OutputStream out) throws IOException {
super(out);
this.out = out;
zis =
new ZstdInputStream(
new InputStream() {
@Override
public int read() {
return inner.read();
}

@Override
public int read(byte[] b, int off, int len) {
return inner.read(b, off, len);
}
});
zis.setContinuous(true);
new ZstdInputStreamNoFinalizer(
new InputStream() {
@Override
public int read() {
return inner.read();
}

@Override
public int read(byte[] b, int off, int len) {
return inner.read(b, off, len);
}
})
.setContinuous(true);
}

@Override
Expand All @@ -58,6 +58,19 @@ public void write(byte[] b) throws IOException {
public void write(byte[] b, int off, int len) throws IOException {
inner = new ByteArrayInputStream(b, off, len);
byte[] data = ByteString.readFrom(zis).toByteArray();
super.write(data, 0, data.length);
out.write(data, 0, data.length);
}

@Override
public void close() throws IOException {
closeShallow();
out.close();
}

/**
* Free resources related to decompression without closing the underlying {@link OutputStream}.
*/
public void closeShallow() throws IOException {
zis.close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,12 @@

import static com.google.common.truth.Truth.assertThat;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.mockito.ArgumentMatchers.any;

import build.bazel.remote.execution.v2.Digest;
import com.github.luben.zstd.Zstd;
import com.google.bytestream.ByteStreamGrpc.ByteStreamImplBase;
import com.google.bytestream.ByteStreamProto.ReadRequest;
import com.google.bytestream.ByteStreamProto.ReadResponse;
import com.google.devtools.build.lib.remote.Retrier.Backoff;
import com.google.devtools.build.lib.remote.options.RemoteOptions;
import com.google.devtools.common.options.Options;
import com.google.protobuf.ByteString;
Expand All @@ -31,38 +29,50 @@
import java.io.IOException;
import java.util.Arrays;
import org.junit.Test;
import org.mockito.Mockito;

/** Extra tests for {@link GrpcCacheClient} that are not tested internally. */
public class GrpcCacheClientTestExtra extends GrpcCacheClientTest {

@Test
public void compressedDownloadBlobIsRetriedWithProgress()
throws IOException, InterruptedException {
Backoff mockBackoff = Mockito.mock(Backoff.class);
RemoteOptions options = Options.getDefaults(RemoteOptions.class);
options.cacheCompression = true;
final GrpcCacheClient client = newClient(options, () -> mockBackoff);
final GrpcCacheClient client = newClient(options);
final Digest digest = DIGEST_UTIL.computeAsUtf8("abcdefg");
ByteString blob = ByteString.copyFrom(Zstd.compress("abcdefg".getBytes(UTF_8)));
ByteString chunk1 = ByteString.copyFrom(Zstd.compress("abc".getBytes(UTF_8)));
ByteString chunk2 = ByteString.copyFrom(Zstd.compress("def".getBytes(UTF_8)));
ByteString chunk3 = ByteString.copyFrom(Zstd.compress("g".getBytes(UTF_8)));
serviceRegistry.addService(
new ByteStreamImplBase() {
private boolean first = true;

@Override
public void read(ReadRequest request, StreamObserver<ReadResponse> responseObserver) {
assertThat(request.getResourceName().contains(digest.getHash())).isTrue();
int off = (int) request.getReadOffset();
// Zstd header size is 9 bytes
ByteString data = off == 0 ? blob.substring(0, 9 + 1) : blob.substring(9 + off);
responseObserver.onNext(ReadResponse.newBuilder().setData(data).build());
if (off == 0) {
if (first) {
first = false;
responseObserver.onError(Status.DEADLINE_EXCEEDED.asException());
} else {
responseObserver.onCompleted();
return;
}
switch (Math.toIntExact(request.getReadOffset())) {
case 0:
responseObserver.onNext(ReadResponse.newBuilder().setData(chunk1).build());
break;
case 3:
responseObserver.onNext(ReadResponse.newBuilder().setData(chunk2).build());
break;
case 6:
responseObserver.onNext(ReadResponse.newBuilder().setData(chunk3).build());
responseObserver.onCompleted();
return;
default:
throw new IllegalStateException("unexpected offset " + request.getReadOffset());
}
responseObserver.onError(Status.DEADLINE_EXCEEDED.asException());
}
});
assertThat(new String(downloadBlob(context, client, digest), UTF_8)).isEqualTo("abcdefg");
Mockito.verify(mockBackoff, Mockito.never()).nextDelayMillis(any(Exception.class));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ public void bytesWrittenMatchesDecompressedBytes() throws IOException {
for (byte b : compressed.toByteArray()) {
zdos.write(b);
zdos.flush();
assertThat(zdos.getBytesWritten()).isEqualTo(decompressed.toByteArray().length);
}
assertThat(decompressed.toByteArray()).isEqualTo(data);
}
Expand Down

0 comments on commit 299022c

Please sign in to comment.