From 001737c71982c58589524d0dca8aa547132691b4 Mon Sep 17 00:00:00 2001 From: Chi Wang Date: Tue, 14 Dec 2021 06:02:04 -0800 Subject: [PATCH] Remote: Don't blocking-get when acquiring gRPC connections. With the recent change to limit the max number of gRPC connections by default, acquiring a connection could suspend a thread if there is no available connection. gRPC calls are scheduled to a dedicated background thread pool. Workers in the thread pool are responsible for acquiring the connection before starting the RPC call. There could be a race condition in which a worker thread handles some gRPC calls and then switches to a new call which will acquire new connections. If the number of connections reaches the max, the worker thread is suspended and doesn't have a chance to switch back to previous calls. The connections held by previous calls are, hence, never released. This PR changes the code to not use a blocking get when acquiring gRPC connections. Fixes #14363. Closes #14416. PiperOrigin-RevId: 416282883 --- .../google/devtools/build/lib/remote/BUILD | 3 +- .../build/lib/remote/ByteStreamUploader.java | 30 ++-- .../ExperimentalGrpcRemoteExecutor.java | 44 ++++-- .../build/lib/remote/GrpcCacheClient.java | 63 ++++++--- .../build/lib/remote/GrpcRemoteExecutor.java | 15 +- .../lib/remote/ReferenceCountedChannel.java | 129 +++++------------- .../build/lib/remote/RemoteModule.java | 11 +- .../lib/remote/RemoteServerCapabilities.java | 9 +- .../build/lib/remote/UploadManifest.java | 7 +- .../downloader/GrpcRemoteDownloader.java | 11 +- .../build/lib/remote/util/RxFutures.java | 32 ++--- 11 files changed, 188 insertions(+), 166 deletions(-) diff --git a/src/main/java/com/google/devtools/build/lib/remote/BUILD b/src/main/java/com/google/devtools/build/lib/remote/BUILD index 1eaa3fdf618efe..a5745bf9b616b8 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/BUILD +++ b/src/main/java/com/google/devtools/build/lib/remote/BUILD @@ -138,9 +138,10 @@ java_library( ], deps = [ 
"//src/main/java/com/google/devtools/build/lib/remote/grpc", + "//src/main/java/com/google/devtools/build/lib/remote/util", "//third_party:guava", - "//third_party:jsr305", "//third_party:netty", + "//third_party:rxjava3", "//third_party/grpc:grpc-jar", ], ) diff --git a/src/main/java/com/google/devtools/build/lib/remote/ByteStreamUploader.java b/src/main/java/com/google/devtools/build/lib/remote/ByteStreamUploader.java index c488f14f397d07..cc31b5bf070705 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/ByteStreamUploader.java +++ b/src/main/java/com/google/devtools/build/lib/remote/ByteStreamUploader.java @@ -24,6 +24,7 @@ import com.google.bytestream.ByteStreamGrpc; import com.google.bytestream.ByteStreamGrpc.ByteStreamFutureStub; import com.google.bytestream.ByteStreamProto.QueryWriteStatusRequest; +import com.google.bytestream.ByteStreamProto.QueryWriteStatusResponse; import com.google.bytestream.ByteStreamProto.WriteRequest; import com.google.bytestream.ByteStreamProto.WriteResponse; import com.google.common.annotations.VisibleForTesting; @@ -374,7 +375,7 @@ public ReferenceCounted touch(Object o) { private static class AsyncUpload { private final RemoteActionExecutionContext context; - private final Channel channel; + private final ReferenceCountedChannel channel; private final CallCredentialsProvider callCredentialsProvider; private final long callTimeoutSecs; private final Retrier retrier; @@ -385,7 +386,7 @@ private static class AsyncUpload { AsyncUpload( RemoteActionExecutionContext context, - Channel channel, + ReferenceCountedChannel channel, CallCredentialsProvider callCredentialsProvider, long callTimeoutSecs, Retrier retrier, @@ -452,7 +453,7 @@ ListenableFuture start() { MoreExecutors.directExecutor()); } - private ByteStreamFutureStub bsFutureStub() { + private ByteStreamFutureStub bsFutureStub(Channel channel) { return ByteStreamGrpc.newFutureStub(channel) .withInterceptors( 
TracingMetadataUtils.attachMetadataInterceptor(context.getRequestMetadata())) @@ -463,7 +464,10 @@ private ByteStreamFutureStub bsFutureStub() { private ListenableFuture callAndQueryOnFailure( AtomicLong committedOffset, ProgressiveBackoff progressiveBackoff) { return Futures.catchingAsync( - call(committedOffset), + Futures.transform( + channel.withChannelFuture(channel -> call(committedOffset, channel)), + written -> null, + MoreExecutors.directExecutor()), Exception.class, (e) -> guardQueryWithSuppression(e, committedOffset, progressiveBackoff), MoreExecutors.directExecutor()); @@ -500,10 +504,14 @@ private ListenableFuture query( AtomicLong committedOffset, ProgressiveBackoff progressiveBackoff) { ListenableFuture committedSizeFuture = Futures.transform( - bsFutureStub() - .queryWriteStatus( - QueryWriteStatusRequest.newBuilder().setResourceName(resourceName).build()), - (response) -> response.getCommittedSize(), + channel.withChannelFuture( + channel -> + bsFutureStub(channel) + .queryWriteStatus( + QueryWriteStatusRequest.newBuilder() + .setResourceName(resourceName) + .build())), + QueryWriteStatusResponse::getCommittedSize, MoreExecutors.directExecutor()); ListenableFuture guardedCommittedSizeFuture = Futures.catchingAsync( @@ -533,14 +541,14 @@ private ListenableFuture query( MoreExecutors.directExecutor()); } - private ListenableFuture call(AtomicLong committedOffset) { + private ListenableFuture call(AtomicLong committedOffset, Channel channel) { CallOptions callOptions = CallOptions.DEFAULT .withCallCredentials(callCredentialsProvider.getCallCredentials()) .withDeadlineAfter(callTimeoutSecs, SECONDS); call = channel.newCall(ByteStreamGrpc.getWriteMethod(), callOptions); - SettableFuture uploadResult = SettableFuture.create(); + SettableFuture uploadResult = SettableFuture.create(); ClientCall.Listener callListener = new ClientCall.Listener() { @@ -568,7 +576,7 @@ public void onMessage(WriteResponse response) { @Override public void onClose(Status 
status, Metadata trailers) { if (status.isOk()) { - uploadResult.set(null); + uploadResult.set(committedOffset.get()); } else { uploadResult.setException(status.asRuntimeException()); } diff --git a/src/main/java/com/google/devtools/build/lib/remote/ExperimentalGrpcRemoteExecutor.java b/src/main/java/com/google/devtools/build/lib/remote/ExperimentalGrpcRemoteExecutor.java index 41f5306624d29e..d50a77cd1c32b2 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/ExperimentalGrpcRemoteExecutor.java +++ b/src/main/java/com/google/devtools/build/lib/remote/ExperimentalGrpcRemoteExecutor.java @@ -35,12 +35,13 @@ import com.google.longrunning.Operation; import com.google.longrunning.Operation.ResultCase; import com.google.rpc.Status; +import io.grpc.Channel; import io.grpc.Status.Code; import io.grpc.StatusRuntimeException; +import io.reactivex.rxjava3.functions.Function; import java.io.IOException; import java.util.Iterator; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Supplier; import javax.annotation.Nullable; /** @@ -73,7 +74,7 @@ public ExperimentalGrpcRemoteExecutor( this.retrier = retrier; } - private ExecutionBlockingStub executionBlockingStub(RequestMetadata metadata) { + private ExecutionBlockingStub executionBlockingStub(RequestMetadata metadata, Channel channel) { return ExecutionGrpc.newBlockingStub(channel) .withInterceptors(TracingMetadataUtils.attachMetadataInterceptor(metadata)) .withCallCredentials(callCredentialsProvider.getCallCredentials()) @@ -90,7 +91,8 @@ private static class Execution { // Count retry times for WaitExecution() calls and is reset when we receive any response from // the server that is not an error. private final ProgressiveBackoff waitExecutionBackoff; - private final Supplier executionBlockingStubSupplier; + private final Function> executeFunction; + private final Function> waitExecutionFunction; // Last response (without error) we received from server. 
private Operation lastOperation; @@ -100,14 +102,16 @@ private static class Execution { OperationObserver observer, RemoteRetrier retrier, CallCredentialsProvider callCredentialsProvider, - Supplier executionBlockingStubSupplier) { + Function> executeFunction, + Function> waitExecutionFunction) { this.request = request; this.observer = observer; this.retrier = retrier; this.callCredentialsProvider = callCredentialsProvider; this.executeBackoff = this.retrier.newBackoff(); this.waitExecutionBackoff = new ProgressiveBackoff(this.retrier::newBackoff); - this.executionBlockingStubSupplier = executionBlockingStubSupplier; + this.executeFunction = executeFunction; + this.waitExecutionFunction = waitExecutionFunction; } ExecuteResponse start() throws IOException, InterruptedException { @@ -168,9 +172,9 @@ ExecuteResponse execute() throws IOException { Preconditions.checkState(lastOperation == null); try { - Iterator operationStream = executionBlockingStubSupplier.get().execute(request); + Iterator operationStream = executeFunction.apply(request); return handleOperationStream(operationStream); - } catch (StatusRuntimeException e) { + } catch (Throwable e) { // If lastOperation is not null, we know the execution request is accepted by the server. In // this case, we will fallback to WaitExecution() loop when the stream is broken. if (lastOperation != null) { @@ -188,17 +192,20 @@ ExecuteResponse waitExecution() throws IOException { WaitExecutionRequest request = WaitExecutionRequest.newBuilder().setName(lastOperation.getName()).build(); try { - Iterator operationStream = - executionBlockingStubSupplier.get().waitExecution(request); + Iterator operationStream = waitExecutionFunction.apply(request); return handleOperationStream(operationStream); - } catch (StatusRuntimeException e) { + } catch (Throwable e) { // A NOT_FOUND error means Operation was lost on the server, retry Execute(). // // However, we only retry Execute() if executeBackoff should retry. 
Also increase the retry // counter at the same time (done by nextDelayMillis()). - if (e.getStatus().getCode() == Code.NOT_FOUND && executeBackoff.nextDelayMillis(e) >= 0) { - lastOperation = null; - return null; + if (e instanceof StatusRuntimeException) { + StatusRuntimeException sre = (StatusRuntimeException) e; + if (sre.getStatus().getCode() == Code.NOT_FOUND + && executeBackoff.nextDelayMillis(sre) >= 0) { + lastOperation = null; + return null; + } } throw new IOException(e); } @@ -321,7 +328,16 @@ public ExecuteResponse executeRemotely( observer, retrier, callCredentialsProvider, - () -> this.executionBlockingStub(context.getRequestMetadata())); + (req) -> + channel.withChannelBlocking( + channel -> + this.executionBlockingStub(context.getRequestMetadata(), channel) + .execute(req)), + (req) -> + channel.withChannelBlocking( + channel -> + this.executionBlockingStub(context.getRequestMetadata(), channel) + .waitExecution(req))); return execution.start(); } diff --git a/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java b/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java index e35d4c6f32ce94..717504be39f401 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java +++ b/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java @@ -56,6 +56,7 @@ import com.google.devtools.build.lib.remote.zstd.ZstdDecompressingOutputStream; import com.google.devtools.build.lib.vfs.Path; import com.google.protobuf.ByteString; +import io.grpc.Channel; import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.StatusRuntimeException; @@ -122,7 +123,8 @@ private int computeMaxMissingBlobsDigestsPerMessage() { return (options.maxOutboundMessageSize - overhead) / digestSize; } - private ContentAddressableStorageFutureStub casFutureStub(RemoteActionExecutionContext context) { + private ContentAddressableStorageFutureStub casFutureStub( + RemoteActionExecutionContext context, Channel channel) { 
return ContentAddressableStorageGrpc.newFutureStub(channel) .withInterceptors( TracingMetadataUtils.attachMetadataInterceptor(context.getRequestMetadata()), @@ -131,7 +133,7 @@ private ContentAddressableStorageFutureStub casFutureStub(RemoteActionExecutionC .withDeadlineAfter(options.remoteTimeout.getSeconds(), TimeUnit.SECONDS); } - private ByteStreamStub bsAsyncStub(RemoteActionExecutionContext context) { + private ByteStreamStub bsAsyncStub(RemoteActionExecutionContext context, Channel channel) { return ByteStreamGrpc.newStub(channel) .withInterceptors( TracingMetadataUtils.attachMetadataInterceptor(context.getRequestMetadata()), @@ -140,7 +142,8 @@ private ByteStreamStub bsAsyncStub(RemoteActionExecutionContext context) { .withDeadlineAfter(options.remoteTimeout.getSeconds(), TimeUnit.SECONDS); } - private ActionCacheFutureStub acFutureStub(RemoteActionExecutionContext context) { + private ActionCacheFutureStub acFutureStub( + RemoteActionExecutionContext context, Channel channel) { return ActionCacheGrpc.newFutureStub(channel) .withInterceptors( TracingMetadataUtils.attachMetadataInterceptor(context.getRequestMetadata()), @@ -222,7 +225,11 @@ public ListenableFuture> findMissingDigests( private ListenableFuture getMissingDigests( RemoteActionExecutionContext context, FindMissingBlobsRequest request) { return Utils.refreshIfUnauthenticatedAsync( - () -> retrier.executeAsync(() -> casFutureStub(context).findMissingBlobs(request)), + () -> + retrier.executeAsync( + () -> + channel.withChannelFuture( + channel -> casFutureStub(context, channel).findMissingBlobs(request))), callCredentialsProvider); } @@ -254,7 +261,10 @@ public ListenableFuture downloadActionResult( return Utils.refreshIfUnauthenticatedAsync( () -> retrier.executeAsync( - () -> handleStatus(acFutureStub(context).getActionResult(request))), + () -> + handleStatus( + channel.withChannelFuture( + channel -> acFutureStub(context, channel).getActionResult(request)))), callCredentialsProvider); } @@ 
-267,13 +277,15 @@ public ListenableFuture uploadActionResult( retrier.executeAsync( () -> Futures.catchingAsync( - acFutureStub(context) - .updateActionResult( - UpdateActionResultRequest.newBuilder() - .setInstanceName(options.remoteInstanceName) - .setActionDigest(actionKey.getDigest()) - .setActionResult(actionResult) - .build()), + channel.withChannelFuture( + channel -> + acFutureStub(context, channel) + .updateActionResult( + UpdateActionResultRequest.newBuilder() + .setInstanceName(options.remoteInstanceName) + .setActionDigest(actionKey.getDigest()) + .setActionResult(actionResult) + .build())), StatusRuntimeException.class, (sre) -> Futures.immediateFailedFuture(new IOException(sre)), MoreExecutors.directExecutor())), @@ -317,18 +329,26 @@ private ListenableFuture downloadBlob( @Nullable Supplier digestSupplier) { AtomicLong offset = new AtomicLong(0); ProgressiveBackoff progressiveBackoff = new ProgressiveBackoff(retrier::newBackoff); - ListenableFuture downloadFuture = + ListenableFuture downloadFuture = Utils.refreshIfUnauthenticatedAsync( () -> retrier.executeAsync( () -> - requestRead( - context, offset, progressiveBackoff, digest, out, digestSupplier), + channel.withChannelFuture( + channel -> + requestRead( + context, + offset, + progressiveBackoff, + digest, + out, + digestSupplier, + channel)), progressiveBackoff), callCredentialsProvider); return Futures.catchingAsync( - downloadFuture, + Futures.transform(downloadFuture, bytesWritten -> null, MoreExecutors.directExecutor()), StatusRuntimeException.class, (e) -> Futures.immediateFailedFuture(new IOException(e)), MoreExecutors.directExecutor()); @@ -343,17 +363,18 @@ public static String getResourceName(String instanceName, Digest digest, boolean return resourceName + DigestUtil.toString(digest); } - private ListenableFuture requestRead( + private ListenableFuture requestRead( RemoteActionExecutionContext context, AtomicLong offset, ProgressiveBackoff progressiveBackoff, Digest digest, 
CountingOutputStream out, - @Nullable Supplier digestSupplier) { + @Nullable Supplier digestSupplier, + Channel channel) { String resourceName = getResourceName(options.remoteInstanceName, digest, options.cacheCompression); - SettableFuture future = SettableFuture.create(); - bsAsyncStub(context) + SettableFuture future = SettableFuture.create(); + bsAsyncStub(context, channel) .read( ReadRequest.newBuilder() .setResourceName(resourceName) @@ -400,7 +421,7 @@ public void onCompleted() { Utils.verifyBlobContents(digest, digestSupplier.get()); } out.flush(); - future.set(null); + future.set(offset.get()); } catch (IOException e) { future.setException(e); } catch (RuntimeException e) { diff --git a/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java b/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java index 0b8c3fa312585e..df3872ebfaeb46 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java +++ b/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java @@ -30,6 +30,7 @@ import com.google.devtools.build.lib.remote.util.Utils; import com.google.longrunning.Operation; import com.google.rpc.Status; +import io.grpc.Channel; import io.grpc.Status.Code; import io.grpc.StatusRuntimeException; import java.io.IOException; @@ -57,7 +58,7 @@ public GrpcRemoteExecutor( this.retrier = retrier; } - private ExecutionBlockingStub execBlockingStub(RequestMetadata metadata) { + private ExecutionBlockingStub execBlockingStub(RequestMetadata metadata, Channel channel) { return ExecutionGrpc.newBlockingStub(channel) .withInterceptors(TracingMetadataUtils.attachMetadataInterceptor(metadata)) .withCallCredentials(callCredentialsProvider.getCallCredentials()); @@ -152,9 +153,17 @@ public ExecuteResponse executeRemotely( WaitExecutionRequest.newBuilder() .setName(operation.get().getName()) .build(); - replies = execBlockingStub(context.getRequestMetadata()).waitExecution(wr); + replies = + 
channel.withChannelBlocking( + channel -> + execBlockingStub(context.getRequestMetadata(), channel) + .waitExecution(wr)); } else { - replies = execBlockingStub(context.getRequestMetadata()).execute(request); + replies = + channel.withChannelBlocking( + channel -> + execBlockingStub(context.getRequestMetadata(), channel) + .execute(request)); } try { while (replies.hasNext()) { diff --git a/src/main/java/com/google/devtools/build/lib/remote/ReferenceCountedChannel.java b/src/main/java/com/google/devtools/build/lib/remote/ReferenceCountedChannel.java index 36df5e77a1a78f..ee67160c372d54 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/ReferenceCountedChannel.java +++ b/src/main/java/com/google/devtools/build/lib/remote/ReferenceCountedChannel.java @@ -13,26 +13,23 @@ // limitations under the License. package com.google.devtools.build.lib.remote; -import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Throwables.throwIfInstanceOf; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; -import com.google.common.base.Throwables; +import com.google.common.util.concurrent.ListenableFuture; import com.google.devtools.build.lib.remote.grpc.ChannelConnectionFactory; import com.google.devtools.build.lib.remote.grpc.ChannelConnectionFactory.ChannelConnection; import com.google.devtools.build.lib.remote.grpc.DynamicConnectionPool; import com.google.devtools.build.lib.remote.grpc.SharedConnectionFactory.SharedConnection; -import io.grpc.CallOptions; +import com.google.devtools.build.lib.remote.util.RxFutures; import io.grpc.Channel; -import io.grpc.ClientCall; -import io.grpc.ForwardingClientCall; -import io.grpc.ForwardingClientCallListener; -import io.grpc.Metadata; -import io.grpc.MethodDescriptor; -import io.grpc.Status; import io.netty.util.AbstractReferenceCounted; import io.netty.util.ReferenceCounted; +import io.reactivex.rxjava3.annotations.CheckReturnValue; +import 
io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.core.SingleSource; +import io.reactivex.rxjava3.functions.Function; import java.io.IOException; -import java.util.concurrent.atomic.AtomicReference; -import javax.annotation.Nullable; /** * A wrapper around a {@link DynamicConnectionPool} exposing {@link Channel} and a reference count. @@ -41,7 +38,7 @@ * *

See {@link ReferenceCounted} for more information about reference counting. */ -public class ReferenceCountedChannel extends Channel implements ReferenceCounted { +public class ReferenceCountedChannel implements ReferenceCounted { private final DynamicConnectionPool dynamicConnectionPool; private final AbstractReferenceCounted referenceCounted = new AbstractReferenceCounted() { @@ -59,7 +56,6 @@ public ReferenceCounted touch(Object o) { return this; } }; - private final AtomicReference authorityRef = new AtomicReference<>(); public ReferenceCountedChannel(ChannelConnectionFactory connectionFactory) { this(connectionFactory, /*maxConnections=*/ 0); @@ -75,93 +71,42 @@ public boolean isShutdown() { return dynamicConnectionPool.isClosed(); } - /** A {@link ClientCall} which call {@link SharedConnection#close()} after the RPC is closed. */ - static class ConnectionCleanupCall - extends ForwardingClientCall.SimpleForwardingClientCall { - private final SharedConnection connection; - - protected ConnectionCleanupCall(ClientCall delegate, SharedConnection connection) { - super(delegate); - this.connection = connection; - } - - @Override - public void start(Listener responseListener, Metadata headers) { - super.start( - new ForwardingClientCallListener.SimpleForwardingClientCallListener( - responseListener) { - @Override - public void onClose(Status status, Metadata trailers) { - try { - connection.close(); - } catch (IOException e) { - throw new AssertionError(e.getMessage(), e); - } finally { - super.onClose(status, trailers); - } - } - }, - headers); - } - } - - private static class CloseOnStartClientCall extends ClientCall { - private final Status status; - - CloseOnStartClientCall(Status status) { - this.status = status; - } - - @Override - public void start(Listener responseListener, Metadata headers) { - responseListener.onClose(status, new Metadata()); - } - - @Override - public void request(int numMessages) {} - - @Override - public void cancel(@Nullable String 
message, @Nullable Throwable cause) {} - - @Override - public void halfClose() {} - - @Override - public void sendMessage(ReqT message) {} + @CheckReturnValue + public ListenableFuture withChannelFuture( + Function> source) { + return RxFutures.toListenableFuture( + withChannel(channel -> RxFutures.toSingle(() -> source.apply(channel), directExecutor()))); } - private SharedConnection acquireSharedConnection() throws IOException, InterruptedException { + public T withChannelBlocking(Function source) + throws IOException, InterruptedException { try { - SharedConnection sharedConnection = dynamicConnectionPool.create().blockingGet(); - ChannelConnection connection = (ChannelConnection) sharedConnection.getUnderlyingConnection(); - authorityRef.compareAndSet(null, connection.getChannel().authority()); - return sharedConnection; + return withChannel(channel -> Single.just(source.apply(channel))).blockingGet(); } catch (RuntimeException e) { - Throwables.throwIfInstanceOf(e.getCause(), IOException.class); - Throwables.throwIfInstanceOf(e.getCause(), InterruptedException.class); + Throwable cause = e.getCause(); + if (cause != null) { + throwIfInstanceOf(cause, IOException.class); + throwIfInstanceOf(cause, InterruptedException.class); + } throw e; } } - @Override - public ClientCall newCall( - MethodDescriptor methodDescriptor, CallOptions callOptions) { - try { - SharedConnection sharedConnection = acquireSharedConnection(); - return new ConnectionCleanupCall<>( - sharedConnection.call(methodDescriptor, callOptions), sharedConnection); - } catch (IOException e) { - return new CloseOnStartClientCall<>(Status.UNKNOWN.withCause(e)); - } catch (InterruptedException e) { - return new CloseOnStartClientCall<>(Status.CANCELLED.withCause(e)); - } - } - - @Override - public String authority() { - String authority = authorityRef.get(); - checkNotNull(authority, "create a connection first to get the authority"); - return authority; + @CheckReturnValue + public Single 
withChannel(Function> source) { + return dynamicConnectionPool + .create() + .flatMap( + sharedConnection -> + Single.using( + () -> sharedConnection, + conn -> { + ChannelConnection connection = + (ChannelConnection) sharedConnection.getUnderlyingConnection(); + Channel channel = connection.getChannel(); + return source.apply(channel); + }, + SharedConnection::close)); } @Override diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java index 252abc94622141..df4e38cf5df2de 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java +++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java @@ -96,6 +96,7 @@ import com.google.devtools.common.options.OptionsBase; import com.google.devtools.common.options.OptionsParsingResult; import io.grpc.CallCredentials; +import io.grpc.Channel; import io.grpc.ClientInterceptor; import io.grpc.ManagedChannel; import io.reactivex.rxjava3.plugins.RxJavaPlugins; @@ -516,7 +517,15 @@ public void beforeCommand(CommandEnvironment env) throws AbruptExitException { String remoteBytestreamUriPrefix = remoteOptions.remoteBytestreamUriPrefix; if (Strings.isNullOrEmpty(remoteBytestreamUriPrefix)) { - remoteBytestreamUriPrefix = cacheChannel.authority(); + try { + remoteBytestreamUriPrefix = cacheChannel.withChannelBlocking(Channel::authority); + } catch (IOException e) { + handleInitFailure(env, e, Code.CACHE_INIT_FAILURE); + return; + } catch (InterruptedException e) { + handleInitFailure(env, new IOException(e), Code.CACHE_INIT_FAILURE); + return; + } if (!Strings.isNullOrEmpty(remoteOptions.remoteInstanceName)) { remoteBytestreamUriPrefix += "/" + remoteOptions.remoteInstanceName; } diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteServerCapabilities.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteServerCapabilities.java index 6eb03ceb559b87..6d486480e30b5d 100644 --- 
a/src/main/java/com/google/devtools/build/lib/remote/RemoteServerCapabilities.java +++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteServerCapabilities.java @@ -31,6 +31,7 @@ import com.google.devtools.build.lib.remote.options.RemoteOptions; import com.google.devtools.build.lib.remote.util.TracingMetadataUtils; import io.grpc.CallCredentials; +import io.grpc.Channel; import io.grpc.StatusRuntimeException; import java.io.IOException; import java.util.List; @@ -59,7 +60,8 @@ public RemoteServerCapabilities( this.retrier = retrier; } - private CapabilitiesBlockingStub capabilitiesBlockingStub(RemoteActionExecutionContext context) { + private CapabilitiesBlockingStub capabilitiesBlockingStub( + RemoteActionExecutionContext context, Channel channel) { return CapabilitiesGrpc.newBlockingStub(channel) .withInterceptors( TracingMetadataUtils.attachMetadataInterceptor(context.getRequestMetadata())) @@ -77,7 +79,10 @@ public ServerCapabilities get(String buildRequestId, String commandId) instanceName == null ? 
GetCapabilitiesRequest.getDefaultInstance() : GetCapabilitiesRequest.newBuilder().setInstanceName(instanceName).build(); - return retrier.execute(() -> capabilitiesBlockingStub(context).getCapabilities(request)); + return retrier.execute( + () -> + channel.withChannelBlocking( + channel -> capabilitiesBlockingStub(context, channel).getCapabilities(request))); } catch (StatusRuntimeException e) { if (e.getCause() instanceof IOException) { throw (IOException) e.getCause(); diff --git a/src/main/java/com/google/devtools/build/lib/remote/UploadManifest.java b/src/main/java/com/google/devtools/build/lib/remote/UploadManifest.java index 5dbbb0721c1dd2..b9b391227306d4 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/UploadManifest.java +++ b/src/main/java/com/google/devtools/build/lib/remote/UploadManifest.java @@ -354,8 +354,11 @@ public ActionResult upload( try { return uploadAsync(context, remoteCache, reporter).blockingGet(); } catch (RuntimeException e) { - throwIfInstanceOf(e.getCause(), InterruptedException.class); - throwIfInstanceOf(e.getCause(), IOException.class); + Throwable cause = e.getCause(); + if (cause != null) { + throwIfInstanceOf(cause, InterruptedException.class); + throwIfInstanceOf(cause, IOException.class); + } throw e; } } diff --git a/src/main/java/com/google/devtools/build/lib/remote/downloader/GrpcRemoteDownloader.java b/src/main/java/com/google/devtools/build/lib/remote/downloader/GrpcRemoteDownloader.java index a0bc56b0b12d6e..c3456eb687968c 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/downloader/GrpcRemoteDownloader.java +++ b/src/main/java/com/google/devtools/build/lib/remote/downloader/GrpcRemoteDownloader.java @@ -38,6 +38,7 @@ import com.google.gson.Gson; import com.google.gson.JsonObject; import io.grpc.CallCredentials; +import io.grpc.Channel; import io.grpc.StatusRuntimeException; import java.io.IOException; import java.io.OutputStream; @@ -122,7 +123,12 @@ public void download( 
newFetchBlobRequest(options.remoteInstanceName, urls, authHeaders, checksum, canonicalId); try { FetchBlobResponse response = - retrier.execute(() -> fetchBlockingStub(remoteActionExecutionContext).fetchBlob(request)); + retrier.execute( + () -> + channel.withChannelBlocking( + channel -> + fetchBlockingStub(remoteActionExecutionContext, channel) + .fetchBlob(request))); final Digest blobDigest = response.getBlobDigest(); retrier.execute( @@ -172,7 +178,8 @@ static FetchBlobRequest newFetchBlobRequest( return requestBuilder.build(); } - private FetchBlockingStub fetchBlockingStub(RemoteActionExecutionContext context) { + private FetchBlockingStub fetchBlockingStub( + RemoteActionExecutionContext context, Channel channel) { return FetchGrpc.newBlockingStub(channel) .withInterceptors( TracingMetadataUtils.attachMetadataInterceptor(context.getRequestMetadata())) diff --git a/src/main/java/com/google/devtools/build/lib/remote/util/RxFutures.java b/src/main/java/com/google/devtools/build/lib/remote/util/RxFutures.java index 7eb07d4d95e05d..d86cfd8bcfdd8a 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/util/RxFutures.java +++ b/src/main/java/com/google/devtools/build/lib/remote/util/RxFutures.java @@ -13,7 +13,6 @@ // limitations under the License. package com.google.devtools.build.lib.remote.util; -import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; import com.google.common.util.concurrent.AbstractFuture; @@ -31,7 +30,7 @@ import io.reactivex.rxjava3.core.SingleOnSubscribe; import io.reactivex.rxjava3.disposables.Disposable; import io.reactivex.rxjava3.exceptions.Exceptions; -import java.util.concurrent.Callable; +import io.reactivex.rxjava3.functions.Supplier; import java.util.concurrent.CancellationException; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicBoolean; @@ -48,7 +47,7 @@ private RxFutures() {} * completed. * *

A {@link ListenableFuture} represents some computation that is already in progress. We use - * {@link Callable} here to defer the execution of the thing that produces ListenableFuture until + * {@link Supplier} here to defer the execution of the thing that produces ListenableFuture until * there is subscriber. * *

Errors are also propagated except for certain "fatal" exceptions defined by rxjava. Multiple @@ -57,19 +56,19 @@ private RxFutures() {} *

Disposes the Completable to cancel the underlying ListenableFuture. */ public static Completable toCompletable( - Callable> callable, Executor executor) { - return Completable.create(new OnceCompletableOnSubscribe(callable, executor)); + Supplier> supplier, Executor executor) { + return Completable.create(new OnceCompletableOnSubscribe(supplier, executor)); } private static class OnceCompletableOnSubscribe implements CompletableOnSubscribe { private final AtomicBoolean subscribed = new AtomicBoolean(false); - private final Callable> callable; + private final Supplier> supplier; private final Executor executor; private OnceCompletableOnSubscribe( - Callable> callable, Executor executor) { - this.callable = callable; + Supplier> supplier, Executor executor) { + this.supplier = supplier; this.executor = executor; } @@ -77,7 +76,7 @@ private OnceCompletableOnSubscribe( public void subscribe(@NonNull CompletableEmitter emitter) throws Throwable { try { checkState(!subscribed.getAndSet(true), "This completable cannot be subscribed to twice"); - ListenableFuture future = callable.call(); + ListenableFuture future = supplier.get(); Futures.addCallback( future, new FutureCallback() { @@ -120,7 +119,7 @@ public void onFailure(Throwable throwable) { * completed. * *

A {@link ListenableFuture} represents some computation that is already in progress. We use - * {@link Callable} here to defer the execution of the thing that produces ListenableFuture until + * {@link Supplier} here to defer the execution of the thing that produces ListenableFuture until * there is subscriber. * *

Errors are also propagated except for certain "fatal" exceptions defined by rxjava. Multiple @@ -128,18 +127,18 @@ public void onFailure(Throwable throwable) { * *

Disposes the Single to cancel the underlying ListenableFuture. */ - public static Single toSingle(Callable> callable, Executor executor) { - return Single.create(new OnceSingleOnSubscribe<>(callable, executor)); + public static Single toSingle(Supplier> supplier, Executor executor) { + return Single.create(new OnceSingleOnSubscribe<>(supplier, executor)); } private static class OnceSingleOnSubscribe implements SingleOnSubscribe { private final AtomicBoolean subscribed = new AtomicBoolean(false); - private final Callable> callable; + private final Supplier> supplier; private final Executor executor; - private OnceSingleOnSubscribe(Callable> callable, Executor executor) { - this.callable = callable; + private OnceSingleOnSubscribe(Supplier> supplier, Executor executor) { + this.supplier = supplier; this.executor = executor; } @@ -147,13 +146,12 @@ private OnceSingleOnSubscribe(Callable> callable, Executor e public void subscribe(@NonNull SingleEmitter emitter) throws Throwable { try { checkState(!subscribed.getAndSet(true), "This single cannot be subscribed to twice"); - ListenableFuture future = callable.call(); + ListenableFuture future = supplier.get(); Futures.addCallback( future, new FutureCallback() { @Override public void onSuccess(@Nullable T t) { - checkNotNull(t, "value in future onSuccess callback is null"); emitter.onSuccess(t); }