diff --git a/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java index 137060595d..bbbdb2c989 100644 --- a/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java +++ b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java @@ -58,7 +58,6 @@ import io.grpc.Deadline; import io.grpc.Metadata; import io.grpc.Status; -import io.grpc.auth.MoreCallCredentials; import java.io.IOException; import java.util.List; import java.util.Map; @@ -186,11 +185,10 @@ public GrpcCallContext nullToSelf(ApiCallContext inputContext) { @Override public GrpcCallContext withCredentials(Credentials newCredentials) { Preconditions.checkNotNull(newCredentials); - CallCredentials callCredentials = MoreCallCredentials.from(newCredentials); return new GrpcCallContext( channel, newCredentials, - callOptions.withCallCredentials(callCredentials), + callOptions, timeout, streamWaitTimeout, streamIdleTimeout, diff --git a/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallCredentialsInterceptor.java b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallCredentialsInterceptor.java new file mode 100644 index 0000000000..b047ab34d9 --- /dev/null +++ b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallCredentialsInterceptor.java @@ -0,0 +1,72 @@ +/* + * Copyright 2025 Google LLC + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google LLC nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ +package com.google.api.gax.grpc; + +import com.google.auth.Credentials; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.ForwardingClientCall; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.auth.MoreCallCredentials; + +/** + * This interceptor is package-private and intended only for client library usage. It will be added + * to inject a CallCredentials to the RPC's CallOptions for two use cases: 1. Local Testing: + * NoCredentialsProvider or null Credentials is passed in 2. Non-Oauth2 Credentials (i.e. + * ApiKeyCredentials) which do not have an Access Token + */ +class GrpcCallCredentialsInterceptor implements ClientInterceptor { + + private final Credentials credentials; + + GrpcCallCredentialsInterceptor(Credentials credentials) { + this.credentials = credentials; + } + + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + // If user manually sets a CallCredentials, do not override + if (callOptions.getCredentials() == null) { + callOptions = callOptions.withCallCredentials(MoreCallCredentials.from(credentials)); + } + + ClientCall call = next.newCall(method, callOptions); + return new ForwardingClientCall.SimpleForwardingClientCall(call) { + @Override + public void start(ClientCall.Listener responseListener, Metadata headers) { + super.start(responseListener, headers); + } + }; + } +} diff --git a/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java index 88bfd91601..6544d10bab 100644 --- a/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java +++ b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java @@ -614,9 +614,17 @@ private ManagedChannel createSingleChannel() throws IOException { ManagedChannelBuilder builder; + // CallCredentials will be attached to newly created channel if there are valid Credentials + // provided. Valid Credentials that can be used as CallCredentials are non-null instances + // of valid GoogleCredentials. There are certain Credentials that cannot be used as + // CallCredentials + // (i.e. ApiKeyCredentials). + boolean needsCallCredentialsInterceptor = credentials != null; + // Check DirectPath traffic. boolean useDirectPathXds = false; if (canUseDirectPath()) { + needsCallCredentialsInterceptor = false; CallCredentials callCreds = MoreCallCredentials.from(credentials); // altsCallCredentials may be null and GoogleDefaultChannelCredentials // will solely use callCreds. Otherwise it uses altsCallCredentials @@ -665,10 +673,10 @@ private ManagedChannel createSingleChannel() throws IOException { // which will be used to fetch MTLS_S2A hard bound tokens from the metdata server. channelCredentials = CompositeChannelCredentials.create(channelCredentials, mtlsS2ACallCredentials); + needsCallCredentialsInterceptor = false; } builder = Grpc.newChannelBuilder(endpoint, channelCredentials); } else { - // Use default if we cannot initialize channel credentials via DCA or S2A. builder = ManagedChannelBuilder.forAddress(serviceAddress, port); } } @@ -678,6 +686,11 @@ private ManagedChannel createSingleChannel() throws IOException { // See https://github.com/googleapis/gapic-generator/issues/2816 builder.disableServiceConfigLookUp(); } + + // This is intercepted first to ensure that CallCredentials is added to CallOptions + if (needsCallCredentialsInterceptor) { + builder = builder.intercept(new GrpcCallCredentialsInterceptor(credentials)); + } builder = builder .intercept(new GrpcChannelUUIDInterceptor()) diff --git a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcCallContextTest.java b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcCallContextTest.java index 63de03a88a..eff74f55e4 100644 --- a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcCallContextTest.java +++ b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcCallContextTest.java @@ -78,7 +78,9 @@ void testWithCredentials() { GrpcCallContext emptyContext = GrpcCallContext.createDefault(); assertNull(emptyContext.getCallOptions().getCredentials()); GrpcCallContext context = emptyContext.withCredentials(credentials); - assertNotNull(context.getCallOptions().getCredentials()); + // The gRPC call credentials will be embedded into the channel credentials + // and not attached to call options. + assertNull(context.getCallOptions().getCredentials()); } @Test diff --git a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java index 86203ce47d..6f6b739b57 100644 --- a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java +++ b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java @@ -206,8 +206,11 @@ void testWithPoolSize() throws IOException { ScheduledExecutorService executor = new ScheduledThreadPoolExecutor(1); executor.shutdown(); + Credentials credentials = Mockito.mock(Credentials.class); + TransportChannelProvider provider = InstantiatingGrpcChannelProvider.newBuilder() + .setCredentials(credentials) .build() .withExecutor((Executor) executor) .withHeaders(Collections.emptyMap()) @@ -234,6 +237,7 @@ void testToBuilder() { new ArrayList<>(); hardBoundTokenTypes.add(InstantiatingGrpcChannelProvider.HardBoundTokenTypes.ALTS); hardBoundTokenTypes.add(InstantiatingGrpcChannelProvider.HardBoundTokenTypes.MTLS_S2A); + Credentials credentials = Mockito.mock(Credentials.class); InstantiatingGrpcChannelProvider provider = InstantiatingGrpcChannelProvider.newBuilder() @@ -244,6 +248,7 @@ void testToBuilder() { .setKeepAliveTimeDuration(keepaliveTime) .setKeepAliveTimeoutDuration(keepaliveTimeout) .setKeepAliveWithoutCalls(Boolean.TRUE) + .setCredentials(credentials) .setChannelConfigurator(channelConfigurator) .setChannelsPerCpu(2.5) .setDirectPathServiceConfig(directPathServiceConfig) @@ -274,6 +279,7 @@ void testWithInterceptorsAndMultipleChannels() throws Exception { private void testWithInterceptors(int numChannels) throws Exception { final GrpcInterceptorProvider interceptorProvider = Mockito.mock(GrpcInterceptorProvider.class); + Credentials credentials = Mockito.mock(Credentials.class); InstantiatingGrpcChannelProvider channelProvider = InstantiatingGrpcChannelProvider.newBuilder() @@ -282,6 +288,7 @@ private void testWithInterceptors(int numChannels) throws Exception { .setHeaderProvider(Mockito.mock(HeaderProvider.class)) .setExecutor(Mockito.mock(Executor.class)) .setInterceptorProvider(interceptorProvider) + .setCredentials(credentials) .build(); Mockito.verify(interceptorProvider, Mockito.never()).getInterceptors(); @@ -303,6 +310,7 @@ void testChannelConfigurator() throws IOException { ManagedChannelBuilder swappedBuilder = Mockito.mock(ManagedChannelBuilder.class); ManagedChannel fakeChannel = Mockito.mock(ManagedChannel.class); + Credentials credentials = Mockito.mock(Credentials.class); Mockito.when(swappedBuilder.build()).thenReturn(fakeChannel); Mockito.when(channelConfigurator.apply(channelBuilderCaptor.capture())) @@ -315,6 +323,7 @@ void testChannelConfigurator() throws IOException { .setExecutor(Mockito.mock(Executor.class)) .setChannelConfigurator(channelConfigurator) .setPoolSize(numChannels) + .setCredentials(credentials) .build() .getTransportChannel(); @@ -486,8 +495,11 @@ void testWithIPv6Address() throws IOException { ScheduledExecutorService executor = new ScheduledThreadPoolExecutor(1); executor.shutdown(); + Credentials credentials = Mockito.mock(Credentials.class); + TransportChannelProvider provider = InstantiatingGrpcChannelProvider.newBuilder() + .setCredentials(credentials) .build() .withExecutor((Executor) executor) .withHeaders(Collections.emptyMap()) @@ -501,6 +513,8 @@ void testWithIPv6Address() throws IOException { // Test that if ChannelPrimer is provided, it is called during creation @Test void testWithPrimeChannel() throws IOException { + Credentials credentials = Mockito.mock(Credentials.class); + // create channelProvider with different pool sizes to verify ChannelPrimer is called the // correct number of times for (int poolSize = 1; poolSize < 5; poolSize++) { @@ -512,6 +526,7 @@ void testWithPrimeChannel() throws IOException { .setPoolSize(poolSize) .setHeaderProvider(Mockito.mock(HeaderProvider.class)) .setExecutor(Mockito.mock(Executor.class)) + .setCredentials(credentials) .setChannelPrimer(mockChannelPrimer) .build(); @@ -600,10 +615,14 @@ void testWithCustomDirectPathServiceConfig() { @Override protected Object getMtlsObjectFromTransportChannel(MtlsProvider provider) throws IOException, GeneralSecurityException { + + Credentials credentials = Mockito.mock(Credentials.class); + InstantiatingGrpcChannelProvider channelProvider = InstantiatingGrpcChannelProvider.newBuilder() .setEndpoint("localhost:8080") .setMtlsProvider(provider) + .setCredentials(credentials) .setHeaderProvider(Mockito.mock(HeaderProvider.class)) .setExecutor(Mockito.mock(Executor.class)) .build(); @@ -630,9 +649,14 @@ private void createAndCloseTransportChannel(InstantiatingGrpcChannelProvider pro testLogDirectPathMisconfig_AttemptDirectPathNotSetAndAttemptDirectPathXdsSetViaBuilder_warns() throws Exception { FakeLogHandler logHandler = new FakeLogHandler(); + Credentials credentials = Mockito.mock(Credentials.class); + InstantiatingGrpcChannelProvider.LOG.addHandler(logHandler); InstantiatingGrpcChannelProvider provider = - createChannelProviderBuilderForDirectPathLogTests().setAttemptDirectPathXds().build(); + createChannelProviderBuilderForDirectPathLogTests() + .setAttemptDirectPathXds() + .setCredentials(credentials) + .build(); createAndCloseTransportChannel(provider); assertThat(logHandler.getAllMessages()) .contains( @@ -645,9 +669,10 @@ void testLogDirectPathMisconfig_AttemptDirectPathNotSetAndAttemptDirectPathXdsSe throws Exception { FakeLogHandler logHandler = new FakeLogHandler(); InstantiatingGrpcChannelProvider.LOG.addHandler(logHandler); + Credentials credentials = Mockito.mock(Credentials.class); InstantiatingGrpcChannelProvider provider = - createChannelProviderBuilderForDirectPathLogTests().build(); + createChannelProviderBuilderForDirectPathLogTests().setCredentials(credentials).build(); createAndCloseTransportChannel(provider); assertThat(logHandler.getAllMessages()) .contains( @@ -672,11 +697,14 @@ void testLogDirectPathMisconfig_shouldNotLogInTheBuilder() { @Test void testLogDirectPathMisconfigWrongCredential() throws Exception { FakeLogHandler logHandler = new FakeLogHandler(); + Credentials credentials = Mockito.mock(Credentials.class); + InstantiatingGrpcChannelProvider.LOG.addHandler(logHandler); InstantiatingGrpcChannelProvider provider = InstantiatingGrpcChannelProvider.newBuilder() .setAttemptDirectPathXds() .setAttemptDirectPath(true) + .setCredentials(credentials) .setHeaderProvider(Mockito.mock(HeaderProvider.class)) .setExecutor(Mockito.mock(Executor.class)) .setEndpoint(DEFAULT_ENDPOINT) @@ -697,11 +725,14 @@ void testLogDirectPathMisconfigWrongCredential() throws Exception { @Test void testLogDirectPathMisconfigNotOnGCE() throws Exception { FakeLogHandler logHandler = new FakeLogHandler(); + Credentials credentials = Mockito.mock(Credentials.class); + InstantiatingGrpcChannelProvider.LOG.addHandler(logHandler); InstantiatingGrpcChannelProvider provider = InstantiatingGrpcChannelProvider.newBuilder() .setAttemptDirectPathXds() .setAttemptDirectPath(true) + .setCredentials(credentials) .setAllowNonDefaultServiceAccount(true) .setHeaderProvider(Mockito.mock(HeaderProvider.class)) .setExecutor(Mockito.mock(Executor.class))