diff --git a/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java index 137060595d..200b601135 100644 --- a/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java +++ b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java @@ -85,6 +85,7 @@ public final class GrpcCallContext implements ApiCallContext { public static final CallOptions.Key TRACER_KEY = CallOptions.Key.create("gax.tracer"); private final Channel channel; + private final boolean isCallCredentialAttachedToChannel; @Nullable private final Credentials credentials; private final CallOptions callOptions; @Nullable private final java.time.Duration timeout; @@ -101,6 +102,7 @@ public final class GrpcCallContext implements ApiCallContext { public static GrpcCallContext createDefault() { return new GrpcCallContext( null, + false, null, CallOptions.DEFAULT, null, @@ -118,6 +120,7 @@ public static GrpcCallContext createDefault() { public static GrpcCallContext of(Channel channel, CallOptions callOptions) { return new GrpcCallContext( channel, + false, null, callOptions, null, @@ -133,6 +136,7 @@ public static GrpcCallContext of(Channel channel, CallOptions callOptions) { private GrpcCallContext( Channel channel, + boolean isCallCredentialAttached, @Nullable Credentials credentials, CallOptions callOptions, @Nullable java.time.Duration timeout, @@ -145,6 +149,7 @@ private GrpcCallContext( @Nullable Set retryableCodes, @Nullable EndpointContext endpointContext) { this.channel = channel; + this.isCallCredentialAttachedToChannel = isCallCredentialAttached; this.credentials = credentials; this.callOptions = Preconditions.checkNotNull(callOptions); this.timeout = timeout; @@ -186,11 +191,15 @@ public GrpcCallContext nullToSelf(ApiCallContext inputContext) { @Override public GrpcCallContext withCredentials(Credentials newCredentials) { Preconditions.checkNotNull(newCredentials); - CallCredentials callCredentials = MoreCallCredentials.from(newCredentials); + CallOptions newCallOptions = callOptions; + if (!isCallCredentialAttachedToChannel) { + newCallOptions = callOptions.withCallCredentials(MoreCallCredentials.from(newCredentials)); + } return new GrpcCallContext( channel, + isCallCredentialAttachedToChannel, newCredentials, - callOptions.withCallCredentials(callCredentials), + newCallOptions, timeout, streamWaitTimeout, streamIdleTimeout, @@ -210,7 +219,20 @@ public GrpcCallContext withTransportChannel(TransportChannel inputChannel) { "Expected GrpcTransportChannel, got " + inputChannel.getClass().getName()); } GrpcTransportChannel transportChannel = (GrpcTransportChannel) inputChannel; - return withChannel(transportChannel.getChannel()); + return new GrpcCallContext( + transportChannel.getChannel(), + transportChannel.getIsCallCredentialAttachedToChannel(), + credentials, + callOptions, + timeout, + streamWaitTimeout, + streamIdleTimeout, + channelAffinity, + extraHeaders, + options, + retrySettings, + retryableCodes, + endpointContext); } @Override @@ -218,6 +240,7 @@ public GrpcCallContext withEndpointContext(EndpointContext endpointContext) { Preconditions.checkNotNull(endpointContext); return new GrpcCallContext( channel, + isCallCredentialAttachedToChannel, credentials, callOptions, timeout, @@ -252,6 +275,7 @@ public GrpcCallContext withTimeoutDuration(@Nullable java.time.Duration timeout) return new GrpcCallContext( channel, + isCallCredentialAttachedToChannel, credentials, callOptions, timeout, @@ -300,6 +324,7 @@ public GrpcCallContext withStreamWaitTimeoutDuration( return new GrpcCallContext( channel, + isCallCredentialAttachedToChannel, credentials, callOptions, timeout, @@ -334,6 +359,7 @@ public GrpcCallContext withStreamIdleTimeoutDuration( return new GrpcCallContext( channel, + isCallCredentialAttachedToChannel, credentials, callOptions, timeout, @@ -351,6 +377,7 @@ public GrpcCallContext withStreamIdleTimeoutDuration( public GrpcCallContext withChannelAffinity(@Nullable Integer affinity) { return new GrpcCallContext( channel, + isCallCredentialAttachedToChannel, credentials, callOptions, timeout, @@ -372,6 +399,7 @@ public GrpcCallContext withExtraHeaders(Map> extraHeaders) Headers.mergeHeaders(this.extraHeaders, extraHeaders); return new GrpcCallContext( channel, + isCallCredentialAttachedToChannel, credentials, callOptions, timeout, @@ -394,6 +422,7 @@ public RetrySettings getRetrySettings() { public GrpcCallContext withRetrySettings(RetrySettings retrySettings) { return new GrpcCallContext( channel, + isCallCredentialAttachedToChannel, credentials, callOptions, timeout, @@ -416,6 +445,7 @@ public Set getRetryableCodes() { public GrpcCallContext withRetryableCodes(Set retryableCodes) { return new GrpcCallContext( channel, + isCallCredentialAttachedToChannel, credentials, callOptions, timeout, @@ -441,6 +471,8 @@ public ApiCallContext merge(ApiCallContext inputCallContext) { } GrpcCallContext grpcCallContext = (GrpcCallContext) inputCallContext; + boolean isCallCredentialAttached = grpcCallContext.isCallCredentialAttachedToChannel; + Credentials newCredentials = grpcCallContext.credentials; if (newCredentials == null) { newCredentials = credentials; @@ -515,6 +547,7 @@ public ApiCallContext merge(ApiCallContext inputCallContext) { // to update this. return new GrpcCallContext( newChannel, + isCallCredentialAttached, newCredentials, newCallOptions, newTimeout, @@ -592,6 +625,7 @@ public Map> getExtraHeaders() { public GrpcCallContext withChannel(Channel newChannel) { return new GrpcCallContext( newChannel, + isCallCredentialAttachedToChannel, credentials, callOptions, timeout, @@ -609,6 +643,7 @@ public GrpcCallContext withChannel(Channel newChannel) { public GrpcCallContext withCallOptions(CallOptions newCallOptions) { return new GrpcCallContext( channel, + isCallCredentialAttachedToChannel, credentials, newCallOptions, timeout, @@ -653,6 +688,7 @@ public GrpcCallContext withOption(Key key, T value) { ApiCallContextOptions newOptions = options.withOption(key, value); return new GrpcCallContext( channel, + isCallCredentialAttachedToChannel, credentials, callOptions, timeout, diff --git a/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcTransportChannel.java b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcTransportChannel.java index 2fa0908f17..b98ba54577 100644 --- a/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcTransportChannel.java +++ b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcTransportChannel.java @@ -62,6 +62,8 @@ public GrpcCallContext getEmptyCallContext() { public abstract boolean isDirectPath(); + abstract boolean getIsCallCredentialAttachedToChannel(); + public Channel getChannel() { return getManagedChannel(); } @@ -102,11 +104,16 @@ public void close() { } public static Builder newBuilder() { - return new AutoValue_GrpcTransportChannel.Builder().setDirectPath(false); + return new AutoValue_GrpcTransportChannel.Builder() + .setDirectPath(false) + .setIsCallCredentialAttachedToChannel(false); } public static GrpcTransportChannel create(ManagedChannel channel) { - return newBuilder().setManagedChannel(channel).build(); + return newBuilder() + .setManagedChannel(channel) + .setIsCallCredentialAttachedToChannel(false) + .build(); } @AutoValue.Builder @@ -115,6 +122,9 @@ public abstract static class Builder { abstract Builder setDirectPath(boolean value); + abstract Builder setIsCallCredentialAttachedToChannel( + boolean isCallCredentialAttachedToChannel); + public abstract GrpcTransportChannel build(); } } diff --git a/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java index 88bfd91601..5af1df1816 100644 --- a/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java +++ b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java @@ -156,6 +156,12 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP @Nullable private final ApiFunction channelConfigurator; + // This is an internal flag to determine if a CallCredential has been passed to a gRPC Channel + // This is intended only for internal use cases and determines if the client library should + // attach the Credentials to the CallOptions. If this flag is true (i.e. DirectPaht and MTLS_S2A), + // the client library will skip attaching the Credentials to the CallOptions + private boolean isCallCredentialAttachedToChannel = false; + /* * Experimental feature * @@ -325,6 +331,12 @@ private TransportChannel createChannel() throws IOException { ChannelPool.create( channelPoolSettings, InstantiatingGrpcChannelProvider.this::createSingleChannel)) .setDirectPath(this.canUseDirectPath()) + // `createSingleChannel` must be invoked first as `isCallCredentialAttachedToChannel` + // is based off the logic in there. The initial channel count for a ChannelPool must be + // greater than 0, which means that it will be invoked at least once. Multiple invocations + // of `createSingleChannel` does not change the value of + // `isCallCredentialAttachedToChannel`. + .setIsCallCredentialAttachedToChannel(this.isCallCredentialAttachedToChannel) .build(); } @@ -626,6 +638,7 @@ private ManagedChannel createSingleChannel() throws IOException { .callCredentials(callCreds) .altsCallCredentials(altsCallCredentials) .build(); + isCallCredentialAttachedToChannel = true; useDirectPathXds = isDirectPathXdsEnabled(); if (useDirectPathXds) { // google-c2p: CloudToProd(C2P) Directpath. This scheme is defined in @@ -665,6 +678,7 @@ private ManagedChannel createSingleChannel() throws IOException { // which will be used to fetch MTLS_S2A hard bound tokens from the metdata server. channelCredentials = CompositeChannelCredentials.create(channelCredentials, mtlsS2ACallCredentials); + isCallCredentialAttachedToChannel = true; } builder = Grpc.newChannelBuilder(endpoint, channelCredentials); } else {