diff --git a/authentication/src/main/java/io/deephaven/auth/AnonymousAuthenticationHandler.java b/authentication/src/main/java/io/deephaven/auth/AnonymousAuthenticationHandler.java index 59e307e1371..448cbb91337 100644 --- a/authentication/src/main/java/io/deephaven/auth/AnonymousAuthenticationHandler.java +++ b/authentication/src/main/java/io/deephaven/auth/AnonymousAuthenticationHandler.java @@ -50,7 +50,7 @@ public Optional login(long protocolVersion, ByteBuffer payload, Han @Override public Optional login(String payload, MetadataResponseListener listener) { - if (payload.length() == 0) { + if (payload.isEmpty()) { return Optional.of(new AuthContext.Anonymous()); } return Optional.empty(); diff --git a/java-client/barrage-dagger/src/main/java/io/deephaven/client/impl/BarrageSessionModule.java b/java-client/barrage-dagger/src/main/java/io/deephaven/client/impl/BarrageSessionModule.java index e2762d7ce80..bae3be898c7 100644 --- a/java-client/barrage-dagger/src/main/java/io/deephaven/client/impl/BarrageSessionModule.java +++ b/java-client/barrage-dagger/src/main/java/io/deephaven/client/impl/BarrageSessionModule.java @@ -8,9 +8,6 @@ import io.grpc.ManagedChannel; import org.apache.arrow.memory.BufferAllocator; -import java.util.concurrent.CompletableFuture; -import java.util.function.Function; - @Module public class BarrageSessionModule { @Provides @@ -18,12 +15,4 @@ public static BarrageSession newDeephavenClientSession( SessionImpl session, BufferAllocator allocator, ManagedChannel managedChannel) { return BarrageSession.of(session, allocator, managedChannel); } - - @Provides - public static CompletableFuture newDeephavenClientSessionFuture( - CompletableFuture sessionFuture, BufferAllocator allocator, - ManagedChannel managedChannel) { - return sessionFuture.thenApply((Function) session -> BarrageSession - .of(session, allocator, managedChannel)); - } } diff --git a/java-client/barrage-dagger/src/main/java/io/deephaven/client/impl/BarrageSubcomponent.java b/java-client/barrage-dagger/src/main/java/io/deephaven/client/impl/BarrageSubcomponent.java index 023e40fa023..75ed86e9535 100644 --- a/java-client/barrage-dagger/src/main/java/io/deephaven/client/impl/BarrageSubcomponent.java +++ b/java-client/barrage-dagger/src/main/java/io/deephaven/client/impl/BarrageSubcomponent.java @@ -10,7 +10,8 @@ import io.grpc.ManagedChannel; import org.apache.arrow.memory.BufferAllocator; -import java.util.concurrent.CompletableFuture; +import javax.annotation.Nullable; +import javax.inject.Named; import java.util.concurrent.ScheduledExecutorService; @Subcomponent(modules = {SessionImplModule.class, FlightSessionModule.class, BarrageSessionModule.class}) @@ -18,8 +19,6 @@ public interface BarrageSubcomponent extends BarrageSessionFactory { BarrageSession newBarrageSession(); - CompletableFuture newBarrageSessionFuture(); - @Module(subcomponents = {BarrageSubcomponent.class}) interface DeephavenClientSubcomponentModule { @@ -33,6 +32,9 @@ interface Builder extends BarrageSessionFactoryBuilder { Builder allocator(@BindsInstance BufferAllocator bufferAllocator); + Builder authenticationTypeAndValue( + @BindsInstance @Nullable @Named("authenticationTypeAndValue") String authenticationTypeAndValue); + BarrageSubcomponent build(); } } diff --git a/java-client/barrage-examples/src/main/java/io/deephaven/client/examples/BarrageClientExampleBase.java b/java-client/barrage-examples/src/main/java/io/deephaven/client/examples/BarrageClientExampleBase.java index 103d022a9f8..54bbe6c6190 100644 --- a/java-client/barrage-examples/src/main/java/io/deephaven/client/examples/BarrageClientExampleBase.java +++ b/java-client/barrage-examples/src/main/java/io/deephaven/client/examples/BarrageClientExampleBase.java @@ -5,6 +5,7 @@ import io.deephaven.client.impl.BarrageSession; import io.deephaven.client.impl.BarrageSessionFactory; +import io.deephaven.client.impl.BarrageSubcomponent.Builder; import io.deephaven.client.impl.DaggerDeephavenBarrageRoot; import io.grpc.ManagedChannel; import org.apache.arrow.memory.BufferAllocator; @@ -21,6 +22,9 @@ abstract class BarrageClientExampleBase implements Callable { @ArgGroup(exclusive = false) ConnectOptions connectOptions; + @ArgGroup(exclusive = true) + AuthenticationOptions authenticationOptions; + protected abstract void execute(BarrageSession session) throws Exception; @Override @@ -32,15 +36,15 @@ public final Void call() throws Exception { Runtime.getRuntime() .addShutdownHook(new Thread(() -> onShutdown(scheduler, managedChannel))); - final BarrageSessionFactory barrageFactory = - DaggerDeephavenBarrageRoot.create().factoryBuilder() - .managedChannel(managedChannel) - .scheduler(scheduler) - .allocator(bufferAllocator) - .build(); - + final Builder builder = DaggerDeephavenBarrageRoot.create().factoryBuilder() + .managedChannel(managedChannel) + .scheduler(scheduler) + .allocator(bufferAllocator); + if (authenticationOptions != null) { + authenticationOptions.ifPresent(builder::authenticationTypeAndValue); + } + final BarrageSessionFactory barrageFactory = builder.build(); final BarrageSession deephavenSession = barrageFactory.newBarrageSession(); - try { try { execute(deephavenSession); diff --git a/java-client/barrage/src/main/java/io/deephaven/client/impl/BarrageSession.java b/java-client/barrage/src/main/java/io/deephaven/client/impl/BarrageSession.java index 1e369147522..600ba2389a7 100644 --- a/java-client/barrage/src/main/java/io/deephaven/client/impl/BarrageSession.java +++ b/java-client/barrage/src/main/java/io/deephaven/client/impl/BarrageSession.java @@ -5,6 +5,7 @@ import io.deephaven.extensions.barrage.BarrageSnapshotOptions; import io.deephaven.extensions.barrage.BarrageSubscriptionOptions; +import io.deephaven.proto.DeephavenChannel; import io.deephaven.qst.table.TableSpec; import io.grpc.CallOptions; import io.grpc.Channel; @@ -30,12 +31,9 @@ public static BarrageSession of( return new BarrageSession(session, client, channel); } - private final Channel interceptedChannel; - protected BarrageSession( final SessionImpl session, final FlightClient client, final ManagedChannel channel) { super(session, client); - this.interceptedChannel = ClientInterceptors.intercept(channel, new AuthInterceptor()); } @Override @@ -64,25 +62,12 @@ public BarrageSnapshot snapshot(final TableHandle tableHandle, final BarrageSnap return new BarrageSnapshotImpl(this, session.executor(), tableHandle.newRef(), options); } - public Channel channel() { - return interceptedChannel; - } - - private class AuthInterceptor implements ClientInterceptor { - @Override - public ClientCall interceptCall( - final MethodDescriptor methodDescriptor, final CallOptions callOptions, - final Channel channel) { - return new ForwardingClientCall.SimpleForwardingClientCall( - channel.newCall(methodDescriptor, callOptions)) { - @Override - public void start(final Listener responseListener, final Metadata headers) { - final AuthenticationInfo localAuth = ((SessionImpl) session()).auth(); - headers.put(Metadata.Key.of(localAuth.sessionHeaderKey(), Metadata.ASCII_STRING_MARSHALLER), - localAuth.session()); - super.start(responseListener, headers); - } - }; - } + /** + * The authenticated channel. + * + * @return the authenticated channel + */ + public DeephavenChannel channel() { + return session.channel(); } } diff --git a/java-client/barrage/src/main/java/io/deephaven/client/impl/BarrageSessionFactory.java b/java-client/barrage/src/main/java/io/deephaven/client/impl/BarrageSessionFactory.java index 78bd292c0c3..54fe0c0e7dd 100644 --- a/java-client/barrage/src/main/java/io/deephaven/client/impl/BarrageSessionFactory.java +++ b/java-client/barrage/src/main/java/io/deephaven/client/impl/BarrageSessionFactory.java @@ -3,10 +3,6 @@ */ package io.deephaven.client.impl; -import java.util.concurrent.CompletableFuture; - public interface BarrageSessionFactory { BarrageSession newBarrageSession(); - - CompletableFuture newBarrageSessionFuture(); } diff --git a/java-client/barrage/src/main/java/io/deephaven/client/impl/BarrageSessionFactoryBuilder.java b/java-client/barrage/src/main/java/io/deephaven/client/impl/BarrageSessionFactoryBuilder.java index e34222288f6..3a2a0c70994 100644 --- a/java-client/barrage/src/main/java/io/deephaven/client/impl/BarrageSessionFactoryBuilder.java +++ b/java-client/barrage/src/main/java/io/deephaven/client/impl/BarrageSessionFactoryBuilder.java @@ -6,6 +6,7 @@ import io.grpc.ManagedChannel; import org.apache.arrow.memory.BufferAllocator; +import javax.annotation.Nullable; import java.util.concurrent.ScheduledExecutorService; public interface BarrageSessionFactoryBuilder { @@ -15,5 +16,7 @@ public interface BarrageSessionFactoryBuilder { BarrageSessionFactoryBuilder allocator(BufferAllocator bufferAllocator); + BarrageSessionFactoryBuilder authenticationTypeAndValue(@Nullable String authenticationTypeAndValue); + BarrageSessionFactory build(); } diff --git a/java-client/barrage/src/main/java/io/deephaven/client/impl/BarrageSnapshotImpl.java b/java-client/barrage/src/main/java/io/deephaven/client/impl/BarrageSnapshotImpl.java index b15055fda11..ccba5047f4e 100644 --- a/java-client/barrage/src/main/java/io/deephaven/client/impl/BarrageSnapshotImpl.java +++ b/java-client/barrage/src/main/java/io/deephaven/client/impl/BarrageSnapshotImpl.java @@ -94,7 +94,7 @@ public BarrageSnapshotImpl( final ClientCall call; final Context previous = Context.ROOT.attach(); try { - call = session.channel().newCall(snapshotDescriptor, CallOptions.DEFAULT); + call = session.channel().channel().newCall(snapshotDescriptor, CallOptions.DEFAULT); } finally { Context.ROOT.detach(previous); } diff --git a/java-client/barrage/src/main/java/io/deephaven/client/impl/BarrageSubscriptionImpl.java b/java-client/barrage/src/main/java/io/deephaven/client/impl/BarrageSubscriptionImpl.java index b1e6a535789..2e9341f914f 100644 --- a/java-client/barrage/src/main/java/io/deephaven/client/impl/BarrageSubscriptionImpl.java +++ b/java-client/barrage/src/main/java/io/deephaven/client/impl/BarrageSubscriptionImpl.java @@ -97,7 +97,7 @@ public BarrageSubscriptionImpl( final ClientCall call; final Context previous = Context.ROOT.attach(); try { - call = session.channel().newCall(subscribeDescriptor, CallOptions.DEFAULT); + call = session.channel().channel().newCall(subscribeDescriptor, CallOptions.DEFAULT); } finally { Context.ROOT.detach(previous); } diff --git a/java-client/example-utilities/src/main/java/io/deephaven/client/examples/AuthenticationOptions.java b/java-client/example-utilities/src/main/java/io/deephaven/client/examples/AuthenticationOptions.java new file mode 100644 index 00000000000..0fd92cf3ed3 --- /dev/null +++ b/java-client/example-utilities/src/main/java/io/deephaven/client/examples/AuthenticationOptions.java @@ -0,0 +1,36 @@ +package io.deephaven.client.examples; + +import picocli.CommandLine.Option; + +import java.util.function.Consumer; + +public class AuthenticationOptions { + @Option(names = {"--mtls"}, description = "Use the connect mTLS") + Boolean mtls; + + @Option(names = {"--psk"}, description = "The pre-shared key") + String psk; + + @Option(names = {"--explicit"}, description = "The explicit authentication type and value") + String explicit; + + public String toAuthenticationTypeAndValue() { + if (mtls != null && mtls) { + return "io.deephaven.authentication.mtls.MTlsAuthenticationHandler"; + } + if (psk != null) { + return "psk " + psk; + } + if (explicit != null) { + return explicit; + } + return null; + } + + public void ifPresent(Consumer consumer) { + final String authenticationTypeAndValue = toAuthenticationTypeAndValue(); + if (authenticationTypeAndValue != null) { + consumer.accept(authenticationTypeAndValue); + } + } +} diff --git a/java-client/flight-dagger/src/main/java/io/deephaven/client/impl/FlightSessionModule.java b/java-client/flight-dagger/src/main/java/io/deephaven/client/impl/FlightSessionModule.java index 853afc19253..bab46124548 100644 --- a/java-client/flight-dagger/src/main/java/io/deephaven/client/impl/FlightSessionModule.java +++ b/java-client/flight-dagger/src/main/java/io/deephaven/client/impl/FlightSessionModule.java @@ -8,9 +8,6 @@ import io.grpc.ManagedChannel; import org.apache.arrow.memory.BufferAllocator; -import java.util.concurrent.CompletableFuture; -import java.util.function.Function; - @Module public class FlightSessionModule { @@ -19,13 +16,4 @@ public static FlightSession newFlightSession(SessionImpl session, BufferAllocato ManagedChannel managedChannel) { return FlightSession.of(session, allocator, managedChannel); } - - @Provides - public static CompletableFuture newFlightSessionFuture( - CompletableFuture sessionFuture, BufferAllocator allocator, - ManagedChannel managedChannel) { - return sessionFuture - .thenApply((Function) session -> FlightSession.of(session, - allocator, managedChannel)); - } } diff --git a/java-client/flight-dagger/src/main/java/io/deephaven/client/impl/FlightSubcomponent.java b/java-client/flight-dagger/src/main/java/io/deephaven/client/impl/FlightSubcomponent.java index 72ffe311dd5..d9aff3a0cbc 100644 --- a/java-client/flight-dagger/src/main/java/io/deephaven/client/impl/FlightSubcomponent.java +++ b/java-client/flight-dagger/src/main/java/io/deephaven/client/impl/FlightSubcomponent.java @@ -10,7 +10,8 @@ import io.grpc.ManagedChannel; import org.apache.arrow.memory.BufferAllocator; -import java.util.concurrent.CompletableFuture; +import javax.annotation.Nullable; +import javax.inject.Named; import java.util.concurrent.ScheduledExecutorService; @Subcomponent(modules = {SessionImplModule.class, FlightSessionModule.class}) @@ -18,8 +19,6 @@ public interface FlightSubcomponent extends FlightSessionFactory { FlightSession newFlightSession(); - CompletableFuture newFlightSessionFuture(); - @Module(subcomponents = {FlightSubcomponent.class}) interface FlightSubcomponentModule { @@ -33,6 +32,9 @@ interface Builder { Builder allocator(@BindsInstance BufferAllocator bufferAllocator); + Builder authenticationTypeAndValue( + @BindsInstance @Nullable @Named("authenticationTypeAndValue") String authenticationTypeAndValue); + FlightSubcomponent build(); } } diff --git a/java-client/flight-examples/src/main/java/io/deephaven/client/examples/FlightExampleBase.java b/java-client/flight-examples/src/main/java/io/deephaven/client/examples/FlightExampleBase.java index 3c200d61656..8d07eb78258 100644 --- a/java-client/flight-examples/src/main/java/io/deephaven/client/examples/FlightExampleBase.java +++ b/java-client/flight-examples/src/main/java/io/deephaven/client/examples/FlightExampleBase.java @@ -6,6 +6,7 @@ import io.deephaven.client.impl.DaggerDeephavenFlightRoot; import io.deephaven.client.impl.FlightSession; import io.deephaven.client.impl.FlightSessionFactory; +import io.deephaven.client.impl.FlightSubcomponent.Builder; import io.grpc.ManagedChannel; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; @@ -21,6 +22,9 @@ abstract class FlightExampleBase implements Callable { @ArgGroup(exclusive = false) ConnectOptions connectOptions; + @ArgGroup(exclusive = true) + AuthenticationOptions authenticationOptions; + BufferAllocator bufferAllocator = new RootAllocator(); protected abstract void execute(FlightSession flight) throws Exception; @@ -32,15 +36,15 @@ public final Void call() throws Exception { Runtime.getRuntime() .addShutdownHook(new Thread(() -> onShutdown(scheduler, managedChannel))); - FlightSessionFactory flightSessionFactory = - DaggerDeephavenFlightRoot.create().factoryBuilder() - .managedChannel(managedChannel) - .scheduler(scheduler) - .allocator(bufferAllocator) - .build(); - + final Builder builder = DaggerDeephavenFlightRoot.create().factoryBuilder() + .managedChannel(managedChannel) + .scheduler(scheduler) + .allocator(bufferAllocator); + if (authenticationOptions != null) { + authenticationOptions.ifPresent(builder::authenticationTypeAndValue); + } + FlightSessionFactory flightSessionFactory = builder.build(); FlightSession flightSession = flightSessionFactory.newFlightSession(); - try { try { execute(flightSession); diff --git a/java-client/flight/src/main/java/io/deephaven/client/impl/AuthenticationMiddleware.java b/java-client/flight/src/main/java/io/deephaven/client/impl/AuthenticationMiddleware.java deleted file mode 100644 index c12775df391..00000000000 --- a/java-client/flight/src/main/java/io/deephaven/client/impl/AuthenticationMiddleware.java +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright (c) 2016-2022 Deephaven Data Labs and Patent Pending - */ -package io.deephaven.client.impl; - -import org.apache.arrow.flight.CallHeaders; -import org.apache.arrow.flight.CallStatus; -import org.apache.arrow.flight.FlightClientMiddleware; - -import java.util.Objects; - -public class AuthenticationMiddleware implements FlightClientMiddleware { - private final AuthenticationInfo auth; - - public AuthenticationMiddleware(AuthenticationInfo auth) { - this.auth = Objects.requireNonNull(auth); - } - - @Override - public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) { - outgoingHeaders.insert(auth.sessionHeaderKey(), auth.session()); - } - - @Override - public void onHeadersReceived(CallHeaders incomingHeaders) { - - } - - @Override - public void onCallCompleted(CallStatus status) { - - } -} diff --git a/java-client/flight/src/main/java/io/deephaven/client/impl/BearerMiddlewear.java b/java-client/flight/src/main/java/io/deephaven/client/impl/BearerMiddlewear.java new file mode 100644 index 00000000000..6d0b33b17b4 --- /dev/null +++ b/java-client/flight/src/main/java/io/deephaven/client/impl/BearerMiddlewear.java @@ -0,0 +1,41 @@ +/** + * Copyright (c) 2016-2022 Deephaven Data Labs and Patent Pending + */ +package io.deephaven.client.impl; + +import org.apache.arrow.flight.CallHeaders; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightClientMiddleware; + +import java.util.Objects; + +class BearerMiddlewear implements FlightClientMiddleware { + private final BearerHandler bearerHandler; + + BearerMiddlewear(BearerHandler bearerHandler) { + this.bearerHandler = Objects.requireNonNull(bearerHandler); + } + + @Override + public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) { + outgoingHeaders.insert(Authentication.AUTHORIZATION_HEADER.name(), bearerHandler.authenticationValue()); + } + + @Override + public void onHeadersReceived(CallHeaders incomingHeaders) { + String lastBearerValue = null; + for (String authenticationValue : incomingHeaders.getAll(Authentication.AUTHORIZATION_HEADER.name())) { + if (authenticationValue.startsWith(BearerHandler.BEARER_PREFIX)) { + lastBearerValue = authenticationValue; + } + } + if (lastBearerValue != null) { + bearerHandler.setBearerToken(lastBearerValue.substring(BearerHandler.BEARER_PREFIX.length())); + } + } + + @Override + public void onCallCompleted(CallStatus status) { + + } +} diff --git a/java-client/flight/src/main/java/io/deephaven/client/impl/FlightSession.java b/java-client/flight/src/main/java/io/deephaven/client/impl/FlightSession.java index 5d08b920b49..459dbf08373 100644 --- a/java-client/flight/src/main/java/io/deephaven/client/impl/FlightSession.java +++ b/java-client/flight/src/main/java/io/deephaven/client/impl/FlightSession.java @@ -20,6 +20,10 @@ public class FlightSession implements AutoCloseable { public static FlightSession of(SessionImpl session, BufferAllocator incomingAllocator, ManagedChannel channel) { + // Note: this pattern of FlightClient owning the ManagedChannel does not mesh well with the idea that some + // other entity may be managing the authentication lifecycle. We'd prefer to pass in the stubs or "intercepted" + // channel directly, but that's not supported. So, we need to create the specific middleware interfaces so + // flight can do its own shims. final FlightClient client = FlightGrpcUtilsExtension.createFlightClientWithSharedChannel( incomingAllocator, channel, Collections.singletonList(new SessionMiddleware(session))); return new FlightSession(session, client); diff --git a/java-client/flight/src/main/java/io/deephaven/client/impl/FlightSessionFactory.java b/java-client/flight/src/main/java/io/deephaven/client/impl/FlightSessionFactory.java index 5a75ddb71c7..938e2367ad9 100644 --- a/java-client/flight/src/main/java/io/deephaven/client/impl/FlightSessionFactory.java +++ b/java-client/flight/src/main/java/io/deephaven/client/impl/FlightSessionFactory.java @@ -3,10 +3,6 @@ */ package io.deephaven.client.impl; -import java.util.concurrent.CompletableFuture; - public interface FlightSessionFactory { FlightSession newFlightSession(); - - CompletableFuture newFlightSessionFuture(); } diff --git a/java-client/flight/src/main/java/io/deephaven/client/impl/SessionMiddleware.java b/java-client/flight/src/main/java/io/deephaven/client/impl/SessionMiddleware.java index 79fa9eda1ad..441e661e96e 100644 --- a/java-client/flight/src/main/java/io/deephaven/client/impl/SessionMiddleware.java +++ b/java-client/flight/src/main/java/io/deephaven/client/impl/SessionMiddleware.java @@ -18,6 +18,6 @@ public SessionMiddleware(SessionImpl session) { @Override public final FlightClientMiddleware onCallStarted(CallInfo info) { - return new AuthenticationMiddleware(session.auth()); + return new BearerMiddlewear(session._hackBearerHandler()); } } diff --git a/java-client/session-dagger/src/main/java/io/deephaven/client/SessionImplModule.java b/java-client/session-dagger/src/main/java/io/deephaven/client/SessionImplModule.java index 033f3bef2a8..8c2f3a2dc7a 100644 --- a/java-client/session-dagger/src/main/java/io/deephaven/client/SessionImplModule.java +++ b/java-client/session-dagger/src/main/java/io/deephaven/client/SessionImplModule.java @@ -8,12 +8,14 @@ import dagger.Provides; import io.deephaven.client.impl.SessionImpl; import io.deephaven.client.impl.SessionImplConfig; +import io.deephaven.client.impl.SessionImplConfig.Builder; import io.deephaven.proto.DeephavenChannel; import io.deephaven.proto.DeephavenChannelImpl; import io.grpc.Channel; import io.grpc.ManagedChannel; -import java.util.concurrent.CompletableFuture; +import javax.annotation.Nullable; +import javax.inject.Named; import java.util.concurrent.ScheduledExecutorService; @Module @@ -26,21 +28,19 @@ public interface SessionImplModule { DeephavenChannel bindsDeephavenChannelImpl(DeephavenChannelImpl deephavenChannelImpl); @Provides - static SessionImpl session(DeephavenChannel channel, ScheduledExecutorService scheduler) { - return SessionImplConfig.builder() + static SessionImpl session(DeephavenChannel channel, ScheduledExecutorService scheduler, + @Nullable @Named("authenticationTypeAndValue") String authenticationTypeAndValue) { + final Builder builder = SessionImplConfig.builder() .executor(scheduler) - .channel(channel) - .build() - .createSession(); - } - - @Provides - static CompletableFuture sessionFuture(DeephavenChannel channel, - ScheduledExecutorService scheduler) { - return SessionImplConfig.builder() - .executor(scheduler) - .channel(channel) - .build() - .createSessionFuture(); + .channel(channel); + if (authenticationTypeAndValue != null) { + builder.authenticationTypeAndValue(authenticationTypeAndValue); + } + final SessionImplConfig config = builder.build(); + try { + return config.createSession(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } } } diff --git a/java-client/session-dagger/src/main/java/io/deephaven/client/SessionSubcomponent.java b/java-client/session-dagger/src/main/java/io/deephaven/client/SessionSubcomponent.java index 7ce6d21b856..369d3146725 100644 --- a/java-client/session-dagger/src/main/java/io/deephaven/client/SessionSubcomponent.java +++ b/java-client/session-dagger/src/main/java/io/deephaven/client/SessionSubcomponent.java @@ -10,6 +10,8 @@ import io.deephaven.client.impl.SessionImpl; import io.grpc.ManagedChannel; +import javax.annotation.Nullable; +import javax.inject.Named; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ScheduledExecutorService; @@ -18,8 +20,6 @@ public interface SessionSubcomponent extends SessionFactory { SessionImpl newSession(); - CompletableFuture newSessionFuture(); - @Module(subcomponents = SessionSubcomponent.class) interface SessionFactorySubcomponentModule { @@ -31,6 +31,9 @@ interface Builder { Builder scheduler(@BindsInstance ScheduledExecutorService scheduler); + Builder authenticationTypeAndValue( + @BindsInstance @Nullable @Named("authenticationTypeAndValue") String authenticationTypeAndValue); + // TODO(deephaven-core#1157): Plumb SessionImplConfig.Builder options through dagger SessionSubcomponent build(); diff --git a/java-client/session-examples/build.gradle b/java-client/session-examples/build.gradle index 9e0aa0306de..a0b8d5c36c7 100644 --- a/java-client/session-examples/build.gradle +++ b/java-client/session-examples/build.gradle @@ -36,7 +36,6 @@ def createApplication = { String name, String mainClass -> } applicationDistribution.into('bin') { - from(createApplication('hammer-sessions', 'io.deephaven.client.examples.HammerSessions')) from(createApplication('publish-tables', 'io.deephaven.client.examples.PublishTables')) from(createApplication('write-qsts', 'io.deephaven.client.examples.WriteExampleQsts')) from(createApplication('table-manager', 'io.deephaven.client.examples.TableManagerExample')) diff --git a/java-client/session-examples/src/main/java/io/deephaven/client/examples/HammerSessions.java b/java-client/session-examples/src/main/java/io/deephaven/client/examples/HammerSessions.java deleted file mode 100644 index caef327cceb..00000000000 --- a/java-client/session-examples/src/main/java/io/deephaven/client/examples/HammerSessions.java +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright (c) 2016-2022 Deephaven Data Labs and Patent Pending - */ -package io.deephaven.client.examples; - -import io.deephaven.client.impl.Session; -import io.deephaven.client.impl.SessionFactory; -import picocli.CommandLine; -import picocli.CommandLine.Command; -import picocli.CommandLine.Parameters; - -import java.time.Duration; -import java.util.concurrent.Semaphore; -import java.util.concurrent.atomic.LongAdder; -import java.util.function.BiConsumer; - -@Command(name = "hammer-sessions", mixinStandardHelpOptions = true, - description = "Create a lot of sessions as fast as possible", version = "0.1.0") -public class HammerSessions extends SessionExampleBase { - - @Parameters(arity = "1", paramLabel = "OUTSTANDING", - description = "Maximum number of outstanding sessions.") - int outstanding; - - @Parameters(arity = "1", paramLabel = "TOTAL", - description = "The total number of sessions to create.") - int total; - - @Override - protected void execute(SessionFactory factory) throws Exception { - final LongAdder failed = new LongAdder(); - Semaphore semaphore = new Semaphore(outstanding, false); - - final long start = System.nanoTime(); - for (long i = 0; i < total; ++i) { - if (i % (outstanding / 3) == 0) { - System.out.printf("Started %d, outstanding %d%n", i, semaphore.availablePermits()); - } - semaphore.acquire(); - factory.newSessionFuture() - .whenComplete((BiConsumer) (session, throwable) -> { - if (throwable != null) { - throwable.printStackTrace(); - failed.add(1); - } - if (session != null) { - session.closeFuture().whenComplete((x, y) -> semaphore.release()); - } else { - semaphore.release(); - } - }); - } - - semaphore.acquire(outstanding); - final long end = System.nanoTime(); - - long f = failed.longValue(); - System.out.printf("%d succeeded, %d failed, avg duration %s%n", total - f, f, - Duration.ofNanos(end - start).dividedBy(total)); - } - - public static void main(String[] args) { - int execute = new CommandLine(new HammerSessions()).execute(args); - System.exit(execute); - } -} diff --git a/java-client/session-examples/src/main/java/io/deephaven/client/examples/SessionExampleBase.java b/java-client/session-examples/src/main/java/io/deephaven/client/examples/SessionExampleBase.java index dbd52103d39..bfc3970bbcd 100644 --- a/java-client/session-examples/src/main/java/io/deephaven/client/examples/SessionExampleBase.java +++ b/java-client/session-examples/src/main/java/io/deephaven/client/examples/SessionExampleBase.java @@ -4,6 +4,7 @@ package io.deephaven.client.examples; import io.deephaven.client.DaggerDeephavenSessionRoot; +import io.deephaven.client.SessionSubcomponent.Builder; import io.deephaven.client.impl.SessionFactory; import io.grpc.ManagedChannel; import picocli.CommandLine.ArgGroup; @@ -18,6 +19,9 @@ abstract class SessionExampleBase implements Callable { @ArgGroup(exclusive = false) ConnectOptions connectOptions; + @ArgGroup(exclusive = true) + AuthenticationOptions authenticationOptions; + protected abstract void execute(SessionFactory sessionFactory) throws Exception; @Override @@ -29,8 +33,13 @@ public final Void call() throws Exception { Runtime.getRuntime() .addShutdownHook(new Thread(() -> onShutdown(scheduler, managedChannel))); - SessionFactory factory = DaggerDeephavenSessionRoot.create().factoryBuilder() - .managedChannel(managedChannel).scheduler(scheduler).build(); + final Builder builder = DaggerDeephavenSessionRoot.create().factoryBuilder() + .managedChannel(managedChannel) + .scheduler(scheduler); + if (authenticationOptions != null) { + authenticationOptions.ifPresent(builder::authenticationTypeAndValue); + } + SessionFactory factory = builder.build(); execute(factory); scheduler.shutdownNow(); diff --git a/java-client/session/src/main/java/io/deephaven/client/impl/Authentication.java b/java-client/session/src/main/java/io/deephaven/client/impl/Authentication.java new file mode 100644 index 00000000000..2c6dec097fb --- /dev/null +++ b/java-client/session/src/main/java/io/deephaven/client/impl/Authentication.java @@ -0,0 +1,263 @@ +package io.deephaven.client.impl; + +import io.deephaven.proto.DeephavenChannel; +import io.deephaven.proto.backplane.grpc.ConfigurationConstantsRequest; +import io.deephaven.proto.backplane.grpc.ConfigurationConstantsResponse; +import io.grpc.CallCredentials; +import io.grpc.ClientInterceptor; +import io.grpc.Metadata; +import io.grpc.Metadata.Key; +import io.grpc.Status; +import io.grpc.StatusException; +import io.grpc.StatusRuntimeException; +import io.grpc.stub.ClientCallStreamObserver; +import io.grpc.stub.ClientResponseObserver; + +import java.time.Duration; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +public final class Authentication { + + /** + * The "authorization" header. This is a misnomer in the specification, this is really an "authentication" header. + */ + public static final Key AUTHORIZATION_HEADER = Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER); + + /** + * Starts an authentication request. + * + * @param channel the channel + * @param authenticationTypeAndValue the authentication type and optional value + * @return the authentication + */ + public static Authentication authenticate(DeephavenChannel channel, String authenticationTypeAndValue) { + final Authentication authentication = new Authentication(channel, authenticationTypeAndValue); + authentication.start(); + return authentication; + } + + private final DeephavenChannel channel; + private final String authenticationTypeAndValue; + private final BearerHandler bearerHandler = new BearerHandler(); + private final CountDownLatch done = new CountDownLatch(1); + private final CompletableFuture future = new CompletableFuture<>(); + private ClientCallStreamObserver requestStream; + private ConfigurationConstantsResponse response; + private Throwable error; + + private Authentication(DeephavenChannel channel, String authenticationTypeAndValue) { + this.channel = Objects.requireNonNull(channel); + this.authenticationTypeAndValue = Objects.requireNonNull(authenticationTypeAndValue); + } + + /** + * Causes the current thread to wait until the authentication request is finished, unless the thread is interrupted. + * + * @throws InterruptedException if the current thread is interrupted while waiting + */ + public void await() throws InterruptedException { + done.await(); + } + + /** + * Causes the current thread to wait for up to {@code duration} until the authentication request is finished, unless + * the thread is interrupted, or the specified waiting time elapses. + * + * @param duration the duration to wait + * @return true if the authentication request is done and false if the waiting time elapsed before the request is + * done + * @throws InterruptedException if the current thread is interrupted while waiting + */ + public boolean await(Duration duration) throws InterruptedException { + return done.await(duration.toNanos(), TimeUnit.NANOSECONDS); + } + + /** + * Waits for the request to finish. On interrupted, will cancel the request, and re-throw the interrupted exception. + * + * @throws InterruptedException if the current thread is interrupted while waiting + */ + public void awaitOrCancel() throws InterruptedException { + try { + done.await(); + } catch (InterruptedException e) { + cancel("Thread interrupted", e); + throw e; + } + } + + /** + * Causes the current thread to wait for up to {@code duration} until the authentication request is finished, unless + * the thread is interrupted, or the specified waiting time elapses. On interrupted, will cancel the request, and + * re-throw the interrupted exception. On timed-out, will cancel the request. + * + * @param duration the duration to wait + * @return true if the authentication request is done and false if the waiting time elapsed before the request is + * done + * @throws InterruptedException if the current thread is interrupted while waiting + */ + public boolean awaitOrCancel(Duration duration) throws InterruptedException { + final boolean finished; + try { + finished = done.await(duration.toNanos(), TimeUnit.NANOSECONDS); + } catch (InterruptedException e) { + cancel("Thread interrupted", e); + throw e; + } + if (!finished) { + cancel("Timed out", null); + } + return finished; + } + + /** + * The future. Is always completed successfully with {@code this} when done. The caller is still responsible for + * checking {@link #error()} as necessary. Presented as an alternative to the await methods. + * + * @return the future + */ + public CompletableFuture future() { + return future.whenComplete((r, t) -> { + if (future.isCancelled()) { + requestStream.cancel("User cancelled", null); + } + }); + } + + /** + * Cancels the request. + * + * @param message the message + * @param cause the cause + */ + public void cancel(String message, Throwable cause) { + requestStream.cancel(message, cause); + } + + /** + * Upon success, will return a channel that handles setting the Bearer token when messages are sent, and handles + * updating the Bearer token when messages are received. The request must already be finished. + * + *

+ * Note: the caller is responsible for ensuring at least some messages are sent as appropriate during the token + * timeout schedule. See {@link #configurationConstants()}. + */ + public Optional bearerChannel() { + if (done.getCount() != 0) { + throw new IllegalStateException("Must await response"); + } + if (response == null) { + return Optional.empty(); + } + return Optional.of(credsAndInterceptor(channel, bearerHandler, bearerHandler)); + } + + /** + * The configuration constants. The request must already be finished. + * + * @return the configuration constants + */ + public Optional configurationConstants() { + if (done.getCount() != 0) { + throw new IllegalStateException("Must await response"); + } + return Optional.ofNullable(response); + } + + /** + * The error. The request must already be finished. + * + * @return the error + */ + public Optional error() { + if (done.getCount() != 0) { + throw new IllegalStateException("Must await response"); + } + return Optional.ofNullable(error); + } + + /** + * Throws if an error has been returned. The request must already be finished. + * + * @throws RuntimeException if an error has been returned + */ + public void throwOnError() { + if (done.getCount() != 0) { + throw new IllegalStateException("Must await response"); + } + if (error != null) { + throw toRuntimeException(error); + } + } + + BearerHandler bearerHandler() { + return bearerHandler; + } + + private void start() { + final DeephavenChannel initialChannel = credsAndInterceptor(channel, + new AuthenticationCallCredentials(authenticationTypeAndValue), bearerHandler); + initialChannel.config().getConfigurationConstants(ConfigurationConstantsRequest.getDefaultInstance(), + new Observer()); + } + + private class Observer + implements ClientResponseObserver { + + @Override + public void beforeStart(ClientCallStreamObserver stream) { + requestStream = stream; + } + + @Override + public void onNext(ConfigurationConstantsResponse response) { + Authentication.this.response = response; + } + + @Override + public void onError(Throwable t) { + error = t; + done.countDown(); + future.complete(Authentication.this); + } + + @Override + public void onCompleted() { + if (Authentication.this.response == null) { + error = new IllegalStateException("Completed without response"); + } + done.countDown(); + future.complete(Authentication.this); + } + } + + private static DeephavenChannel credsAndInterceptor(DeephavenChannel channel, CallCredentials callCredentials, + ClientInterceptor clientInterceptor) { + return DeephavenChannel.withClientInterceptors(DeephavenChannel.withCallCredentials(channel, callCredentials), + clientInterceptor); + } + + // Similar to io.grpc.stub.ClientCalls.toStatusRuntimeException + private static RuntimeException toRuntimeException(Throwable t) { + Throwable cause = Objects.requireNonNull(t); + while (cause != null) { + // If we have an embedded status, use it and replace the cause + if (cause instanceof StatusException) { + StatusException se = (StatusException) cause; + return new StatusRuntimeException(se.getStatus(), se.getTrailers()); + } else if (cause instanceof StatusRuntimeException) { + StatusRuntimeException se = (StatusRuntimeException) cause; + return new StatusRuntimeException(se.getStatus(), se.getTrailers()); + } else if (cause instanceof RuntimeException) { + return (RuntimeException) cause; + } + cause = cause.getCause(); + } + return Status.UNKNOWN.withDescription("unexpected exception").withCause(t) + .asRuntimeException(); + } +} diff --git a/java-client/session/src/main/java/io/deephaven/client/impl/AuthenticationCallCredentials.java b/java-client/session/src/main/java/io/deephaven/client/impl/AuthenticationCallCredentials.java new file mode 100644 index 00000000000..864f2f1244c --- /dev/null +++ b/java-client/session/src/main/java/io/deephaven/client/impl/AuthenticationCallCredentials.java @@ -0,0 +1,29 @@ +package io.deephaven.client.impl; + +import io.grpc.CallCredentials; +import io.grpc.Metadata; + +import java.util.Objects; +import java.util.concurrent.Executor; + +import static io.deephaven.client.impl.Authentication.AUTHORIZATION_HEADER; + +class AuthenticationCallCredentials extends CallCredentials { + private final String authenticationTypeAndValue; + + public AuthenticationCallCredentials(String authenticationTypeAndValue) { + this.authenticationTypeAndValue = Objects.requireNonNull(authenticationTypeAndValue); + } + + @Override + public void applyRequestMetadata(RequestInfo requestInfo, Executor appExecutor, MetadataApplier applier) { + final Metadata headers = new Metadata(); + headers.put(AUTHORIZATION_HEADER, authenticationTypeAndValue); + applier.apply(headers); + } + + @Override + public void thisUsesUnstableApi() { + + } +} diff --git a/java-client/session/src/main/java/io/deephaven/client/impl/BearerHandler.java b/java-client/session/src/main/java/io/deephaven/client/impl/BearerHandler.java new file mode 100644 index 00000000000..6ae782f64cd --- /dev/null +++ b/java-client/session/src/main/java/io/deephaven/client/impl/BearerHandler.java @@ -0,0 +1,125 @@ +package io.deephaven.client.impl; + +import io.grpc.CallCredentials; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientCall.Listener; +import io.grpc.ClientInterceptor; +import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; +import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.Status; + +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.Executor; + +import static io.deephaven.client.impl.Authentication.AUTHORIZATION_HEADER; + +/** + * As a {@link ClientInterceptor}, this parser the responses for the bearer token. + * + *

+ * As a {@link CallCredentials}, this sets the (previously attained) bearer token on requests. + */ +final class BearerHandler extends CallCredentials implements ClientInterceptor { + + // this is really about "authentication" not "authorization" + public static final String BEARER_PREFIX = "Bearer "; + private volatile String bearerToken; + + private static Optional parseBearerToken(Metadata metadata) { + final Iterable authenticationValues = metadata.getAll(AUTHORIZATION_HEADER); + if (authenticationValues == null) { + return Optional.empty(); + } + String lastBearerValue = null; + for (String authenticationValue : authenticationValues) { + if (authenticationValue.startsWith(BEARER_PREFIX)) { + lastBearerValue = authenticationValue; + } + } + if (lastBearerValue == null) { + return Optional.empty(); + } + return Optional.of(lastBearerValue.substring(BEARER_PREFIX.length())); + } + + // exposed for flight + void setBearerToken(String bearerToken) { + String localBearerToken = this.bearerToken; + // Only follow through with the volatile write if it's a different value. + if (!Objects.equals(localBearerToken, bearerToken)) { + this.bearerToken = Objects.requireNonNull(bearerToken); + } + } + + private void handleMetadata(Metadata metadata) { + parseBearerToken(metadata).ifPresent(BearerHandler.this::setBearerToken); + } + + // exposed for flight + String authenticationValue() { + return BEARER_PREFIX + bearerToken; + } + + @Override + public void applyRequestMetadata(RequestInfo requestInfo, Executor appExecutor, MetadataApplier applier) { + final String bearerToken = this.bearerToken; + if (bearerToken == null) { + applier.fail(Status.UNAUTHENTICATED); + return; + } + final Metadata headers = new Metadata(); + headers.put(AUTHORIZATION_HEADER, authenticationValue()); + applier.apply(headers); + } + + @Override + public void thisUsesUnstableApi() { + + } + + @Override + public ClientCall interceptCall(MethodDescriptor method, + CallOptions callOptions, Channel next) { + return new BearerCall<>(next.newCall(method, callOptions)); + } + + private final class BearerCall extends SimpleForwardingClientCall { + public BearerCall(ClientCall delegate) { + super(delegate); + } + + @Override + public void start(Listener responseListener, Metadata headers) { + super.start(new BearerListener<>(responseListener), headers); + } + } + + private final class BearerListener extends SimpleForwardingClientCallListener { + public BearerListener(Listener delegate) { + super(delegate); + } + + @Override + public void onHeaders(Metadata headers) { + try { + handleMetadata(headers); + } finally { + super.onHeaders(headers); + } + } + + @Override + public void onClose(Status status, Metadata trailers) { + try { + handleMetadata(trailers); + } finally { + super.onClose(status, trailers); + } + } + } +} diff --git a/java-client/session/src/main/java/io/deephaven/client/impl/SessionFactory.java b/java-client/session/src/main/java/io/deephaven/client/impl/SessionFactory.java index 58f06426f61..9c81ebc4905 100644 --- a/java-client/session/src/main/java/io/deephaven/client/impl/SessionFactory.java +++ b/java-client/session/src/main/java/io/deephaven/client/impl/SessionFactory.java @@ -3,11 +3,7 @@ */ package io.deephaven.client.impl; -import java.util.concurrent.CompletableFuture; - public interface SessionFactory { Session newSession(); - - CompletableFuture newSessionFuture(); } diff --git a/java-client/session/src/main/java/io/deephaven/client/impl/SessionImpl.java b/java-client/session/src/main/java/io/deephaven/client/impl/SessionImpl.java index 2a6f2142539..15810e208f1 100644 --- a/java-client/session/src/main/java/io/deephaven/client/impl/SessionImpl.java +++ b/java-client/session/src/main/java/io/deephaven/client/impl/SessionImpl.java @@ -6,38 +6,30 @@ import com.google.protobuf.ByteString; import io.deephaven.client.impl.script.Changes; import io.deephaven.proto.DeephavenChannel; -import io.deephaven.proto.DeephavenChannelMixin; import io.deephaven.proto.backplane.grpc.AddTableRequest; import io.deephaven.proto.backplane.grpc.AddTableResponse; -import io.deephaven.proto.backplane.grpc.ApplicationServiceGrpc.ApplicationServiceStub; import io.deephaven.proto.backplane.grpc.CloseSessionResponse; +import io.deephaven.proto.backplane.grpc.ConfigValue; +import io.deephaven.proto.backplane.grpc.ConfigurationConstantsRequest; +import io.deephaven.proto.backplane.grpc.ConfigurationConstantsResponse; import io.deephaven.proto.backplane.grpc.DeleteTableRequest; import io.deephaven.proto.backplane.grpc.DeleteTableResponse; import io.deephaven.proto.backplane.grpc.FetchObjectRequest; import io.deephaven.proto.backplane.grpc.FetchObjectResponse; import io.deephaven.proto.backplane.grpc.FieldsChangeUpdate; import io.deephaven.proto.backplane.grpc.HandshakeRequest; -import io.deephaven.proto.backplane.grpc.HandshakeResponse; -import io.deephaven.proto.backplane.grpc.InputTableServiceGrpc.InputTableServiceStub; import io.deephaven.proto.backplane.grpc.ListFieldsRequest; -import io.deephaven.proto.backplane.grpc.ObjectServiceGrpc.ObjectServiceStub; import io.deephaven.proto.backplane.grpc.ReleaseRequest; import io.deephaven.proto.backplane.grpc.ReleaseResponse; -import io.deephaven.proto.backplane.grpc.SessionServiceGrpc.SessionServiceStub; import io.deephaven.proto.backplane.grpc.Ticket; import io.deephaven.proto.backplane.grpc.TypedTicket; import io.deephaven.proto.backplane.script.grpc.BindTableToVariableRequest; import io.deephaven.proto.backplane.script.grpc.BindTableToVariableResponse; -import io.deephaven.proto.backplane.script.grpc.ConsoleServiceGrpc.ConsoleServiceStub; import io.deephaven.proto.backplane.script.grpc.ExecuteCommandRequest; import io.deephaven.proto.backplane.script.grpc.ExecuteCommandResponse; import io.deephaven.proto.backplane.script.grpc.StartConsoleRequest; import io.deephaven.proto.backplane.script.grpc.StartConsoleResponse; import io.deephaven.proto.util.ExportTicketHelper; -import io.grpc.CallCredentials; -import io.grpc.Metadata; -import io.grpc.Metadata.Key; -import io.grpc.stub.AbstractStub; import io.grpc.stub.ClientCallStreamObserver; import io.grpc.stub.ClientResponseObserver; import io.grpc.stub.StreamObserver; @@ -49,12 +41,14 @@ import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; +import java.time.Duration; import java.util.List; import java.util.Objects; +import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; -import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.stream.Collectors; @@ -69,156 +63,78 @@ public final class SessionImpl extends SessionBase { private static final Logger log = LoggerFactory.getLogger(SessionImpl.class); - private static final int REFRESH_RETRIES = 5; - - public interface Handler { - void onRefreshSuccess(); - - void onRefreshTokenError(Throwable t, Runnable invokeForRetry); - - void onCloseSessionError(Throwable t); - - void onClosed(); + public static SessionImpl create(SessionImplConfig config) throws InterruptedException { + final Authentication authentication = + Authentication.authenticate(config.channel(), config.authenticationTypeAndValue()); + authentication.awaitOrCancel(); + return create(config, authentication); } - private static class Retrying implements Handler { - private static final Logger log = LoggerFactory.getLogger(Retrying.class); - - private final int maxRefreshes; - private int remainingRefreshes; - - Retrying(int maxRefreshes) { - this.maxRefreshes = maxRefreshes; - } - - @Override - public void onRefreshSuccess() { - remainingRefreshes = maxRefreshes; - } - - @Override - public void onRefreshTokenError(Throwable t, Runnable invokeForRetry) { - if (remainingRefreshes > 0) { - remainingRefreshes--; - log.warn("Error refreshing token, trying again", t); - invokeForRetry.run(); - return; - } - log.error("Error refreshing token, giving up", t); - } - - @Override - public void onCloseSessionError(Throwable t) { - log.error("onCloseSessionError", t); - } - - @Override - public void onClosed() { - + public static SessionImpl create(SessionImplConfig config, Authentication authentication) { + authentication.throwOnError(); + final DeephavenChannel bearerChannel = authentication.bearerChannel().orElseThrow(IllegalStateException::new); + final ConfigurationConstantsResponse response = + authentication.configurationConstants().orElseThrow(IllegalStateException::new); + final Optional httpSessionDuration = parseHttpSessionDuration(response); + if (!httpSessionDuration.isPresent()) { + log.warn( + "Server did not return an 'http.session.durationMs', defaulting to pinging the server every minute."); } + final Duration pingFrequency = httpSessionDuration.map(d -> d.dividedBy(3)).orElse(Duration.ofMinutes(1)); + return new SessionImpl(config, bearerChannel, pingFrequency, authentication.bearerHandler()); } - public static SessionImpl create(SessionImplConfig config) { - final HandshakeRequest request = initialHandshake(); - final HandshakeResponse response = config.channel().sessionBlocking().newSession(request); - final AuthenticationInfo initialAuth = AuthenticationInfo.of(response); - final SessionImpl session = - new SessionImpl(config, new Retrying(REFRESH_RETRIES), initialAuth); - - - session.scheduleRefreshSessionToken(response); - return session; + private static Optional parseHttpSessionDuration(ConfigurationConstantsResponse response) { + return getHttpSessionDurationMs(response).map(SessionImpl::stringValue).flatMap(SessionImpl::parseMillis); } - public static CompletableFuture createFuture(SessionImplConfig config) { - final HandshakeRequest request = initialHandshake(); - final SessionObserver sessionObserver = new SessionObserver(config); - config.channel().session().newSession(request, sessionObserver); - return sessionObserver.future; + private static String stringValue(ConfigValue value) { + if (!value.hasStringValue()) { + throw new IllegalArgumentException("Expected string value"); + } + return value.getStringValue(); } - private static HandshakeRequest initialHandshake() { - return HandshakeRequest.newBuilder().setAuthProtocol(1).build(); + private static Optional getHttpSessionDurationMs(ConfigurationConstantsResponse response) { + return Optional.ofNullable(response.getConfigValuesMap().get("http.session.durationMs")); } - private static class SessionObserver - implements ClientResponseObserver { - - private final SessionImplConfig config; - private final CompletableFuture future = new CompletableFuture<>(); - - SessionObserver(SessionImplConfig config) { - this.config = Objects.requireNonNull(config); - } - - @Override - public void beforeStart(ClientCallStreamObserver requestStream) { - future.whenComplete((session, throwable) -> { - if (future.isCancelled()) { - requestStream.cancel("User cancelled", null); - } - }); - } - - @Override - public void onNext(HandshakeResponse response) { - AuthenticationInfo initialAuth = AuthenticationInfo.of(response); - SessionImpl session = - new SessionImpl(config, new Retrying(REFRESH_RETRIES), initialAuth); - if (future.complete(session)) { - session.scheduleRefreshSessionToken(response); - } else { - // Make sure we don't leak a session if we aren't able to pass it off to the user - session.close(); - } - } - - @Override - public void onError(Throwable t) { - future.completeExceptionally(t); - } - - @Override - public void onCompleted() { - if (!future.isDone()) { - future.completeExceptionally( - new IllegalStateException("Observer completed without response")); - } + private static Optional parseMillis(String x) { + try { + return Optional.of(Duration.ofMillis(Long.parseLong(x))); + } catch (NumberFormatException e) { + return Optional.empty(); } } private final SessionImplConfig config; - private final SessionServiceStub sessionService; - private final ConsoleServiceStub consoleService; - private final ObjectServiceStub objectService; - private final InputTableServiceStub inputTableService; - private final ApplicationServiceStub applicationServiceStub; - private final Handler handler; + private final DeephavenChannel bearerChannel; + // Needed for downstream flight workarounds + private final BearerHandler bearerHandler; private final ExportTicketCreator exportTicketCreator; private final ExportStates states; - private volatile AuthenticationInfo auth; private final TableHandleManagerSerial serialManager; private final TableHandleManagerBatch batchManager; + private final ScheduledFuture pingJob; - private SessionImpl(SessionImplConfig config, Handler handler, AuthenticationInfo auth) { - SessionCallCredentials credentials = new SessionCallCredentials(); - this.auth = Objects.requireNonNull(auth); - this.handler = Objects.requireNonNull(handler); + private SessionImpl(SessionImplConfig config, DeephavenChannel bearerChannel, Duration pingFrequency, + BearerHandler bearerHandler) { this.config = Objects.requireNonNull(config); - this.sessionService = config.channel().session().withCallCredentials(credentials); - this.consoleService = config.channel().console().withCallCredentials(credentials); - this.objectService = config.channel().object().withCallCredentials(credentials); - this.inputTableService = config.channel().inputTable().withCallCredentials(credentials); - this.applicationServiceStub = config.channel().application().withCallCredentials(credentials); + this.bearerChannel = Objects.requireNonNull(bearerChannel); + this.bearerHandler = Objects.requireNonNull(bearerHandler); this.exportTicketCreator = new ExportTicketCreator(); - this.states = new ExportStates(this, sessionService, config.channel().table().withCallCredentials(credentials), - exportTicketCreator); + this.states = new ExportStates(this, bearerChannel.session(), bearerChannel.table(), exportTicketCreator); this.serialManager = TableHandleManagerSerial.of(this); this.batchManager = TableHandleManagerBatch.of(this, config.mixinStacktrace()); + this.pingJob = config.executor().scheduleAtFixedRate( + () -> bearerChannel.config().getConfigurationConstants( + ConfigurationConstantsRequest.getDefaultInstance(), PingObserverNoOp.INSTANCE), + pingFrequency.toNanos(), pingFrequency.toNanos(), TimeUnit.NANOSECONDS); } - public AuthenticationInfo auth() { - return auth; + // exposed for Flight + BearerHandler _hackBearerHandler() { + return bearerHandler; } @Override @@ -232,7 +148,7 @@ public CompletableFuture console(String type) { final StartConsoleRequest request = StartConsoleRequest.newBuilder().setSessionType(type) .setResultId(consoleId.ticketId().ticket()).build(); final ConsoleHandler handler = new ConsoleHandler(request); - consoleService.startConsole(request, handler); + bearerChannel.console().startConsole(request, handler); return handler.future(); } @@ -242,7 +158,7 @@ public CompletableFuture publish(String name, HasTicketId ticketId) { throw new IllegalArgumentException("Invalid name"); } PublishObserver observer = new PublishObserver(); - consoleService.bindTableToVariable(BindTableToVariableRequest.newBuilder() + bearerChannel.console().bindTableToVariable(BindTableToVariableRequest.newBuilder() .setVariableName(name).setTableId(ticketId.ticketId().ticket()).build(), observer); return observer.future; } @@ -256,7 +172,7 @@ public CompletableFuture fetchObject(String type, HasTicketId tic .build()) .build(); final FetchObserver observer = new FetchObserver(); - objectService.fetchObject(request, observer); + bearerChannel.object().fetchObject(request, observer); return observer.future; } @@ -276,10 +192,10 @@ public void close() { @Override public CompletableFuture closeFuture() { - HandshakeRequest handshakeRequest = HandshakeRequest.newBuilder().setAuthProtocol(0) - .setPayload(ByteString.copyFromUtf8(auth.session())).build(); + pingJob.cancel(false); + HandshakeRequest handshakeRequest = HandshakeRequest.getDefaultInstance(); CloseSessionHandler handler = new CloseSessionHandler(); - sessionService.closeSession(handshakeRequest, handler); + bearerChannel.session().closeSession(handshakeRequest, handler); return handler.future; } @@ -314,14 +230,14 @@ public ExportId newExportId() { @Override public CompletableFuture release(ExportId exportId) { final ReleaseTicketObserver observer = new ReleaseTicketObserver(); - sessionService.release( + bearerChannel.session().release( ReleaseRequest.newBuilder().setId(exportId.ticketId().ticket()).build(), observer); return observer.future; } @Override public DeephavenChannel channel() { - return new DeephavenChannelWithCredentials(); + return bearerChannel; } @Override @@ -331,7 +247,7 @@ public CompletableFuture addToInputTable(HasTicketId destination, HasTicke .setTableToAdd(source.ticketId().ticket()) .build(); final AddToInputTableObserver observer = new AddToInputTableObserver(); - inputTableService.addTableToInputTable(request, observer); + bearerChannel.inputTable().addTableToInputTable(request, observer); return observer.future; } @@ -342,7 +258,7 @@ public CompletableFuture deleteFromInputTable(HasTicketId destination, Has .setTableToRemove(source.ticketId().ticket()) .build(); final DeleteFromInputTableObserver observer = new DeleteFromInputTableObserver(); - inputTableService.deleteTableFromInputTable(request, observer); + bearerChannel.inputTable().deleteTableFromInputTable(request, observer); return observer.future; } @@ -350,7 +266,7 @@ public CompletableFuture deleteFromInputTable(HasTicketId destination, Has public Cancel subscribeToFields(Listener listener) { final ListFieldsRequest request = ListFieldsRequest.newBuilder().build(); final ListFieldsObserver observer = new ListFieldsObserver(listener); - applicationServiceStub.listFields(request, observer); + bearerChannel.application().listFields(request, observer); return observer; } @@ -366,26 +282,6 @@ public long releaseCount() { return states.releaseCount(); } - private void scheduleRefreshSessionToken(HandshakeResponse response) { - final long now = System.currentTimeMillis(); - final long targetRefreshTime = Math.min( - now + response.getTokenExpirationDelayMillis() / 3, - response.getTokenDeadlineTimeMillis() - response.getTokenExpirationDelayMillis() / 10); - final long refreshDelayMs = Math.max(targetRefreshTime - now, 0); - config.executor().schedule(SessionImpl.this::refreshSessionToken, refreshDelayMs, TimeUnit.MILLISECONDS); - } - - private void scheduleRefreshSessionTokenNow() { - config.executor().schedule(SessionImpl.this::refreshSessionToken, 0, TimeUnit.MILLISECONDS); - } - - private void refreshSessionToken() { - HandshakeRequest handshakeRequest = HandshakeRequest.newBuilder().setAuthProtocol(0) - .setPayload(ByteString.copyFromUtf8(auth.session())).build(); - HandshakeHandler handler = new HandshakeHandler(); - sessionService.refreshSessionToken(handshakeRequest, handler); - } - private static class PublishObserver implements ClientResponseObserver { private final CompletableFuture future = new CompletableFuture<>(); @@ -467,56 +363,17 @@ public void onCompleted() { } } - private class SessionCallCredentials extends CallCredentials { - - @Override - public void applyRequestMetadata(RequestInfo requestInfo, Executor appExecutor, - MetadataApplier applier) { - AuthenticationInfo localAuth = auth; - Metadata metadata = new Metadata(); - metadata.put(Key.of(localAuth.sessionHeaderKey(), Metadata.ASCII_STRING_MARSHALLER), - localAuth.session()); - applier.apply(metadata); - } - - @Override - public void thisUsesUnstableApi() { - - } - } - - private class HandshakeHandler implements StreamObserver { - - @Override - public void onNext(HandshakeResponse value) { - auth = AuthenticationInfo.of(value); - scheduleRefreshSessionToken(value); - handler.onRefreshSuccess(); - } - - @Override - public void onError(Throwable t) { - handler.onRefreshTokenError(t, SessionImpl.this::scheduleRefreshSessionTokenNow); - } - - @Override - public void onCompleted() { - // ignore - } - } - - private class CloseSessionHandler implements StreamObserver { + private static class CloseSessionHandler implements StreamObserver { private final CompletableFuture future = new CompletableFuture<>(); @Override public void onNext(CloseSessionResponse value) { - handler.onClosed(); + } @Override public void onError(Throwable t) { - handler.onCloseSessionError(t); future.completeExceptionally(t); } @@ -623,7 +480,7 @@ public CompletableFuture executeCodeFuture(String code) { final ExecuteCommandRequest request = ExecuteCommandRequest.newBuilder().setConsoleId(ticket()).setCode(code).build(); final ExecuteCommandHandler handler = new ExecuteCommandHandler(); - consoleService.executeCommand(request, handler); + bearerChannel.console().executeCommand(request, handler); return handler.future; } @@ -636,7 +493,7 @@ public CompletableFuture executeScriptFuture(Path path) throws IOExcept @Override public CompletableFuture closeFuture() { final ConsoleCloseHandler handler = new ConsoleCloseHandler(); - sessionService.release(ReleaseRequest.newBuilder().setId(request.getResultId()).build(), handler); + bearerChannel.session().release(ReleaseRequest.newBuilder().setId(request.getResultId()).build(), handler); return handler.future(); } @@ -815,14 +672,22 @@ public void onCompleted() { } } - private final class DeephavenChannelWithCredentials extends DeephavenChannelMixin { - public DeephavenChannelWithCredentials() { - super(config.channel()); + private enum PingObserverNoOp implements StreamObserver { + INSTANCE; + + @Override + public void onNext(ConfigurationConstantsResponse value) { + } @Override - protected > S mixin(S stub) { - return stub.withCallCredentials(new SessionCallCredentials()); + public void onError(Throwable t) { + + } + + @Override + public void onCompleted() { + } } } diff --git a/java-client/session/src/main/java/io/deephaven/client/impl/SessionImplConfig.java b/java-client/session/src/main/java/io/deephaven/client/impl/SessionImplConfig.java index 3012430becd..901176a237b 100644 --- a/java-client/session/src/main/java/io/deephaven/client/impl/SessionImplConfig.java +++ b/java-client/session/src/main/java/io/deephaven/client/impl/SessionImplConfig.java @@ -9,7 +9,6 @@ import org.immutables.value.Value.Immutable; import java.time.Duration; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ScheduledExecutorService; @Immutable @@ -24,6 +23,11 @@ public static Builder builder() { public abstract DeephavenChannel channel(); + @Default + public String authenticationTypeAndValue() { + return "Anonymous"; + } + /** * Whether the {@link Session} implementation will implement a batch {@link TableHandleManager}. By default, is * {@code false}. The default can be overridden via the system property {@code deephaven.session.batch}. @@ -68,20 +72,18 @@ public Duration closeTimeout() { return Duration.parse(System.getProperty("deephaven.session.closeTimeout", "PT5s")); } - public final SessionImpl createSession() { + public final SessionImpl createSession() throws InterruptedException { return SessionImpl.create(this); } - public final CompletableFuture createSessionFuture() { - return SessionImpl.createFuture(this); - } - public interface Builder { Builder executor(ScheduledExecutorService executor); Builder channel(DeephavenChannel channel); + Builder authenticationTypeAndValue(String authenticationTypeAndValue); + Builder delegateToBatch(boolean delegateToBatch); Builder mixinStacktrace(boolean mixinStacktrace); diff --git a/proto/proto-backplane-grpc/src/main/java/io/deephaven/proto/DeephavenChannel.java b/proto/proto-backplane-grpc/src/main/java/io/deephaven/proto/DeephavenChannel.java index a42712f6c83..5f1454d1632 100644 --- a/proto/proto-backplane-grpc/src/main/java/io/deephaven/proto/DeephavenChannel.java +++ b/proto/proto-backplane-grpc/src/main/java/io/deephaven/proto/DeephavenChannel.java @@ -6,6 +6,9 @@ import io.deephaven.proto.backplane.grpc.ApplicationServiceGrpc.ApplicationServiceBlockingStub; import io.deephaven.proto.backplane.grpc.ApplicationServiceGrpc.ApplicationServiceFutureStub; import io.deephaven.proto.backplane.grpc.ApplicationServiceGrpc.ApplicationServiceStub; +import io.deephaven.proto.backplane.grpc.ConfigServiceGrpc.ConfigServiceBlockingStub; +import io.deephaven.proto.backplane.grpc.ConfigServiceGrpc.ConfigServiceFutureStub; +import io.deephaven.proto.backplane.grpc.ConfigServiceGrpc.ConfigServiceStub; import io.deephaven.proto.backplane.grpc.InputTableServiceGrpc.InputTableServiceBlockingStub; import io.deephaven.proto.backplane.grpc.InputTableServiceGrpc.InputTableServiceFutureStub; import io.deephaven.proto.backplane.grpc.InputTableServiceGrpc.InputTableServiceStub; @@ -21,8 +24,21 @@ import io.deephaven.proto.backplane.script.grpc.ConsoleServiceGrpc.ConsoleServiceBlockingStub; import io.deephaven.proto.backplane.script.grpc.ConsoleServiceGrpc.ConsoleServiceFutureStub; import io.deephaven.proto.backplane.script.grpc.ConsoleServiceGrpc.ConsoleServiceStub; +import io.grpc.CallCredentials; +import io.grpc.Channel; +import io.grpc.ClientInterceptor; public interface DeephavenChannel { + static DeephavenChannel withCallCredentials(DeephavenChannel channel, CallCredentials callCredentials) { + return new DeephavenChannelWithCallCredentials(channel, callCredentials); + } + + static DeephavenChannel withClientInterceptors(DeephavenChannel channel, ClientInterceptor... clientInterceptors) { + return new DeephavenChannelWithClientInterceptors(channel, clientInterceptors); + } + + Channel channel(); + SessionServiceStub session(); TableServiceStub table(); @@ -35,6 +51,8 @@ public interface DeephavenChannel { InputTableServiceStub inputTable(); + ConfigServiceStub config(); + SessionServiceBlockingStub sessionBlocking(); TableServiceBlockingStub tableBlocking(); @@ -47,6 +65,8 @@ public interface DeephavenChannel { InputTableServiceBlockingStub inputTableBlocking(); + ConfigServiceBlockingStub configBlocking(); + SessionServiceFutureStub sessionFuture(); TableServiceFutureStub tableFuture(); @@ -58,4 +78,6 @@ public interface DeephavenChannel { ApplicationServiceFutureStub applicationFuture(); InputTableServiceFutureStub inputTableFuture(); + + ConfigServiceFutureStub configFuture(); } diff --git a/proto/proto-backplane-grpc/src/main/java/io/deephaven/proto/DeephavenChannelImpl.java b/proto/proto-backplane-grpc/src/main/java/io/deephaven/proto/DeephavenChannelImpl.java index 0d07d1437e4..f88c29790cc 100644 --- a/proto/proto-backplane-grpc/src/main/java/io/deephaven/proto/DeephavenChannelImpl.java +++ b/proto/proto-backplane-grpc/src/main/java/io/deephaven/proto/DeephavenChannelImpl.java @@ -7,6 +7,10 @@ import io.deephaven.proto.backplane.grpc.ApplicationServiceGrpc.ApplicationServiceBlockingStub; import io.deephaven.proto.backplane.grpc.ApplicationServiceGrpc.ApplicationServiceFutureStub; import io.deephaven.proto.backplane.grpc.ApplicationServiceGrpc.ApplicationServiceStub; +import io.deephaven.proto.backplane.grpc.ConfigServiceGrpc; +import io.deephaven.proto.backplane.grpc.ConfigServiceGrpc.ConfigServiceBlockingStub; +import io.deephaven.proto.backplane.grpc.ConfigServiceGrpc.ConfigServiceFutureStub; +import io.deephaven.proto.backplane.grpc.ConfigServiceGrpc.ConfigServiceStub; import io.deephaven.proto.backplane.grpc.InputTableServiceGrpc; import io.deephaven.proto.backplane.grpc.InputTableServiceGrpc.InputTableServiceBlockingStub; import io.deephaven.proto.backplane.grpc.InputTableServiceGrpc.InputTableServiceFutureStub; @@ -43,6 +47,7 @@ public DeephavenChannelImpl(Channel channel) { this.channel = Objects.requireNonNull(channel); } + @Override public Channel channel() { return channel; } @@ -77,6 +82,11 @@ public InputTableServiceStub inputTable() { return InputTableServiceGrpc.newStub(channel); } + @Override + public ConfigServiceStub config() { + return ConfigServiceGrpc.newStub(channel); + } + @Override public SessionServiceBlockingStub sessionBlocking() { return SessionServiceGrpc.newBlockingStub(channel); @@ -107,6 +117,11 @@ public InputTableServiceBlockingStub inputTableBlocking() { return InputTableServiceGrpc.newBlockingStub(channel); } + @Override + public ConfigServiceBlockingStub configBlocking() { + return ConfigServiceGrpc.newBlockingStub(channel); + } + @Override public SessionServiceFutureStub sessionFuture() { return SessionServiceGrpc.newFutureStub(channel); @@ -136,4 +151,9 @@ public ApplicationServiceFutureStub applicationFuture() { public InputTableServiceFutureStub inputTableFuture() { return InputTableServiceGrpc.newFutureStub(channel); } + + @Override + public ConfigServiceFutureStub configFuture() { + return ConfigServiceGrpc.newFutureStub(channel); + } } diff --git a/proto/proto-backplane-grpc/src/main/java/io/deephaven/proto/DeephavenChannelMixin.java b/proto/proto-backplane-grpc/src/main/java/io/deephaven/proto/DeephavenChannelMixin.java index f09e6377b34..f0dfda4422a 100644 --- a/proto/proto-backplane-grpc/src/main/java/io/deephaven/proto/DeephavenChannelMixin.java +++ b/proto/proto-backplane-grpc/src/main/java/io/deephaven/proto/DeephavenChannelMixin.java @@ -6,6 +6,9 @@ import io.deephaven.proto.backplane.grpc.ApplicationServiceGrpc.ApplicationServiceBlockingStub; import io.deephaven.proto.backplane.grpc.ApplicationServiceGrpc.ApplicationServiceFutureStub; import io.deephaven.proto.backplane.grpc.ApplicationServiceGrpc.ApplicationServiceStub; +import io.deephaven.proto.backplane.grpc.ConfigServiceGrpc.ConfigServiceBlockingStub; +import io.deephaven.proto.backplane.grpc.ConfigServiceGrpc.ConfigServiceFutureStub; +import io.deephaven.proto.backplane.grpc.ConfigServiceGrpc.ConfigServiceStub; import io.deephaven.proto.backplane.grpc.InputTableServiceGrpc.InputTableServiceBlockingStub; import io.deephaven.proto.backplane.grpc.InputTableServiceGrpc.InputTableServiceFutureStub; import io.deephaven.proto.backplane.grpc.InputTableServiceGrpc.InputTableServiceStub; @@ -21,6 +24,7 @@ import io.deephaven.proto.backplane.script.grpc.ConsoleServiceGrpc.ConsoleServiceBlockingStub; import io.deephaven.proto.backplane.script.grpc.ConsoleServiceGrpc.ConsoleServiceFutureStub; import io.deephaven.proto.backplane.script.grpc.ConsoleServiceGrpc.ConsoleServiceStub; +import io.grpc.Channel; import io.grpc.stub.AbstractStub; import java.util.Objects; @@ -34,6 +38,13 @@ public DeephavenChannelMixin(DeephavenChannel delegate) { protected abstract > S mixin(S stub); + protected abstract Channel mixinChannel(Channel channel); + + @Override + public Channel channel() { + return mixinChannel(delegate.channel()); + } + @Override public final SessionServiceStub session() { return mixin(delegate.session()); @@ -64,6 +75,11 @@ public final InputTableServiceStub inputTable() { return mixin(delegate.inputTable()); } + @Override + public ConfigServiceStub config() { + return mixin(delegate.config()); + } + @Override public final SessionServiceBlockingStub sessionBlocking() { return mixin(delegate.sessionBlocking()); @@ -94,6 +110,11 @@ public final InputTableServiceBlockingStub inputTableBlocking() { return mixin(delegate.inputTableBlocking()); } + @Override + public ConfigServiceBlockingStub configBlocking() { + return mixin(delegate.configBlocking()); + } + @Override public final SessionServiceFutureStub sessionFuture() { return mixin(delegate.sessionFuture()); @@ -123,4 +144,9 @@ public final ApplicationServiceFutureStub applicationFuture() { public final InputTableServiceFutureStub inputTableFuture() { return mixin(delegate.inputTableFuture()); } + + @Override + public ConfigServiceFutureStub configFuture() { + return mixin(delegate.configFuture()); + } } diff --git a/proto/proto-backplane-grpc/src/main/java/io/deephaven/proto/DeephavenChannelWithCallCredentials.java b/proto/proto-backplane-grpc/src/main/java/io/deephaven/proto/DeephavenChannelWithCallCredentials.java new file mode 100644 index 00000000000..ec57c31b563 --- /dev/null +++ b/proto/proto-backplane-grpc/src/main/java/io/deephaven/proto/DeephavenChannelWithCallCredentials.java @@ -0,0 +1,48 @@ +package io.deephaven.proto; + +import io.grpc.CallCredentials; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.MethodDescriptor; +import io.grpc.stub.AbstractStub; + +import java.util.Objects; + +final class DeephavenChannelWithCallCredentials extends DeephavenChannelMixin { + private final CallCredentials callCredentials; + + public DeephavenChannelWithCallCredentials(DeephavenChannel delegate, CallCredentials callCredentials) { + super(delegate); + this.callCredentials = Objects.requireNonNull(callCredentials); + } + + @Override + protected > S mixin(S stub) { + return stub.withCallCredentials(callCredentials); + } + + @Override + protected Channel mixinChannel(Channel channel) { + return new CallCredentialsChannel(channel); + } + + private final class CallCredentialsChannel extends Channel { + private final Channel delegate; + + public CallCredentialsChannel(Channel delegate) { + this.delegate = Objects.requireNonNull(delegate); + } + + @Override + public ClientCall newCall( + MethodDescriptor methodDescriptor, CallOptions callOptions) { + return delegate.newCall(methodDescriptor, callOptions.withCallCredentials(callCredentials)); + } + + @Override + public String authority() { + return delegate.authority(); + } + } +} diff --git a/proto/proto-backplane-grpc/src/main/java/io/deephaven/proto/DeephavenChannelWithClientInterceptors.java b/proto/proto-backplane-grpc/src/main/java/io/deephaven/proto/DeephavenChannelWithClientInterceptors.java new file mode 100644 index 00000000000..615828b5aab --- /dev/null +++ b/proto/proto-backplane-grpc/src/main/java/io/deephaven/proto/DeephavenChannelWithClientInterceptors.java @@ -0,0 +1,27 @@ +package io.deephaven.proto; + +import io.grpc.Channel; +import io.grpc.ClientInterceptor; +import io.grpc.ClientInterceptors; +import io.grpc.stub.AbstractStub; + +import java.util.Objects; + +final class DeephavenChannelWithClientInterceptors extends DeephavenChannelMixin { + private final ClientInterceptor[] clientInterceptors; + + public DeephavenChannelWithClientInterceptors(DeephavenChannel delegate, ClientInterceptor... clientInterceptors) { + super(delegate); + this.clientInterceptors = Objects.requireNonNull(clientInterceptors); + } + + @Override + protected > S mixin(S stub) { + return stub.withInterceptors(clientInterceptors); + } + + @Override + protected Channel mixinChannel(Channel channel) { + return ClientInterceptors.intercept(channel, clientInterceptors); + } +} diff --git a/server/jetty/src/test/java/io/deephaven/server/jetty/JettyFlightRoundTripTest.java b/server/jetty/src/test/java/io/deephaven/server/jetty/JettyFlightRoundTripTest.java index 96d64bb1b83..0ca04297cfd 100644 --- a/server/jetty/src/test/java/io/deephaven/server/jetty/JettyFlightRoundTripTest.java +++ b/server/jetty/src/test/java/io/deephaven/server/jetty/JettyFlightRoundTripTest.java @@ -7,6 +7,7 @@ import dagger.Module; import dagger.Provides; import io.deephaven.server.arrow.ArrowModule; +import io.deephaven.server.config.ConfigServiceModule; import io.deephaven.server.console.ConsoleModule; import io.deephaven.server.log.LogModule; import io.deephaven.server.runner.ExecutionContextUnitTestModule; @@ -35,6 +36,7 @@ static JettyConfig providesJettyConfig() { @Singleton @Component(modules = { ArrowModule.class, + ConfigServiceModule.class, ConsoleModule.class, ExecutionContextUnitTestModule.class, FlightTestModule.class, diff --git a/server/netty/src/test/java/io/deephaven/server/netty/NettyFlightRoundTripTest.java b/server/netty/src/test/java/io/deephaven/server/netty/NettyFlightRoundTripTest.java index 69534875ff9..e7a0c94edae 100644 --- a/server/netty/src/test/java/io/deephaven/server/netty/NettyFlightRoundTripTest.java +++ b/server/netty/src/test/java/io/deephaven/server/netty/NettyFlightRoundTripTest.java @@ -7,6 +7,7 @@ import dagger.Module; import dagger.Provides; import io.deephaven.server.arrow.ArrowModule; +import io.deephaven.server.config.ConfigServiceModule; import io.deephaven.server.console.ConsoleModule; import io.deephaven.server.log.LogModule; import io.deephaven.server.runner.ExecutionContextUnitTestModule; @@ -35,6 +36,7 @@ static NettyConfig providesNettyConfig() { @Singleton @Component(modules = { ArrowModule.class, + ConfigServiceModule.class, ConsoleModule.class, ExecutionContextUnitTestModule.class, FlightTestModule.class, diff --git a/server/src/main/java/io/deephaven/server/session/SessionState.java b/server/src/main/java/io/deephaven/server/session/SessionState.java index 00c1771beae..ae4198ac999 100644 --- a/server/src/main/java/io/deephaven/server/session/SessionState.java +++ b/server/src/main/java/io/deephaven/server/session/SessionState.java @@ -204,10 +204,8 @@ protected void updateExpiration(@NotNull final SessionService.TokenExpiration ex throw GrpcUtil.statusRuntimeException(Code.UNAUTHENTICATED, "session has expired"); } - log.info().append(logPrefix) - .append("token rotating to '").append(expiration.token.toString()) - .append("' which expires at ").append(MILLIS_FROM_EPOCH_FORMATTER, expiration.deadlineMillis) - .append(".").endl(); + log.info().append(logPrefix).append("token, expires at ") + .append(MILLIS_FROM_EPOCH_FORMATTER, expiration.deadlineMillis).append(".").endl(); } /** diff --git a/server/src/main/java/io/deephaven/server/uri/BarrageTableResolver.java b/server/src/main/java/io/deephaven/server/uri/BarrageTableResolver.java index 3135b19fbf6..24114c02be2 100644 --- a/server/src/main/java/io/deephaven/server/uri/BarrageTableResolver.java +++ b/server/src/main/java/io/deephaven/server/uri/BarrageTableResolver.java @@ -308,6 +308,7 @@ private BarrageSession newSession(ClientConfig config) { } private BarrageSession newSession(ManagedChannel channel) { + // TODO(deephaven-core#3421): DH URI / BarrageTableResolver authentication support return builder .allocator(allocator) .managedChannel(channel) diff --git a/server/src/test/java/io/deephaven/server/runner/DeephavenApiServerSingleAuthenticatedBase.java b/server/src/test/java/io/deephaven/server/runner/DeephavenApiServerSingleAuthenticatedBase.java index 5f561c5775a..db78fc8ed6a 100644 --- a/server/src/test/java/io/deephaven/server/runner/DeephavenApiServerSingleAuthenticatedBase.java +++ b/server/src/test/java/io/deephaven/server/runner/DeephavenApiServerSingleAuthenticatedBase.java @@ -6,19 +6,12 @@ import io.deephaven.UncheckedDeephavenException; import io.deephaven.auth.AuthenticationException; import io.deephaven.proto.DeephavenChannel; -import io.deephaven.proto.DeephavenChannelMixin; import io.deephaven.proto.backplane.grpc.HandshakeRequest; import io.deephaven.proto.backplane.grpc.HandshakeResponse; import io.deephaven.server.session.SessionState; -import io.grpc.CallOptions; -import io.grpc.Channel; -import io.grpc.ClientCall; -import io.grpc.ClientInterceptor; -import io.grpc.ForwardingClientCall; import io.grpc.Metadata; import io.grpc.Metadata.Key; -import io.grpc.MethodDescriptor; -import io.grpc.stub.AbstractStub; +import io.grpc.stub.MetadataUtils; import org.junit.Before; public abstract class DeephavenApiServerSingleAuthenticatedBase extends DeephavenApiServerTestBase { @@ -39,26 +32,10 @@ public void setUp() throws Exception { sessionToken = result.getSessionToken().toStringUtf8(); final String sessionHeader = result.getMetadataHeader().toStringUtf8(); final Key sessionHeaderKey = Metadata.Key.of(sessionHeader, Metadata.ASCII_STRING_MARSHALLER); - this.channel = new DeephavenChannelMixin(channel) { - @Override - protected > S mixin(S stub) { - return stub.withInterceptors(new ClientInterceptor() { - @Override - public ClientCall interceptCall( - final MethodDescriptor methodDescriptor, final CallOptions callOptions, - final Channel channel) { - return new ForwardingClientCall.SimpleForwardingClientCall<>( - channel.newCall(methodDescriptor, callOptions)) { - @Override - public void start(final Listener responseListener, final Metadata headers) { - headers.put(sessionHeaderKey, sessionToken); - super.start(responseListener, headers); - } - }; - } - }); - } - }; + final Metadata extraHeaders = new Metadata(); + extraHeaders.put(sessionHeaderKey, sessionToken); + this.channel = DeephavenChannel.withClientInterceptors(channel, + MetadataUtils.newAttachHeadersInterceptor(extraHeaders)); } public SessionState authenticatedSessionState() { diff --git a/server/test/src/main/java/io/deephaven/server/test/FlightMessageRoundTripTest.java b/server/test/src/main/java/io/deephaven/server/test/FlightMessageRoundTripTest.java index 46d2d344616..c125f140785 100644 --- a/server/test/src/main/java/io/deephaven/server/test/FlightMessageRoundTripTest.java +++ b/server/test/src/main/java/io/deephaven/server/test/FlightMessageRoundTripTest.java @@ -49,6 +49,7 @@ import io.deephaven.server.session.SessionState; import io.deephaven.server.session.TicketResolver; import io.deephaven.server.session.TicketResolverBase; +import io.deephaven.server.test.TestAuthModule.FakeBearer; import io.deephaven.server.util.Scheduler; import io.deephaven.util.SafeCloseable; import io.deephaven.auth.AuthContext; @@ -372,38 +373,11 @@ public void testLoginHeaderCustomBearer() { scriptSession.setVariable("test", TableTools.emptyTable(10).update("I=i")); // add the bearer token override - final String bearerToken = UUID.randomUUID().toString(); - component.authRequestHandlers().put(Auth2Constants.BEARER_PREFIX.trim(), new AuthenticationRequestHandler() { - @Override - public String getAuthType() { - return Auth2Constants.BEARER_PREFIX.trim(); - } - - @Override - public Optional login(long protocolVersion, ByteBuffer payload, - HandshakeResponseListener listener) { - return Optional.empty(); - } - - @Override - public Optional login(String payload, MetadataResponseListener listener) { - if (payload.equals(bearerToken)) { - return Optional.of(new AuthContext.SuperUser()); - } - return Optional.empty(); - } - - @Override - public void initialize(String targetUrl) { - // do nothing - } - }); - final MutableBoolean tokenChanged = new MutableBoolean(); flightClient = FlightClient.builder().location(serverLocation) .allocator(new RootAllocator()) .intercept(info -> new FlightClientMiddleware() { - String currToken = Auth2Constants.BEARER_PREFIX + bearerToken; + String currToken = Auth2Constants.BEARER_PREFIX + FakeBearer.TOKEN; @Override public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) { @@ -442,9 +416,6 @@ public void testLoginHandshakeAnonymous() { final Ticket ticket = new Ticket("s/test".getBytes(StandardCharsets.UTF_8)); fullyReadStream(ticket, false); - // install the auth handler - component.authRequestHandlers().put(ANONYMOUS, new AnonymousRequestHandler()); - flightClient.authenticate(new ClientAuthHandler() { byte[] callToken = new byte[0]; @@ -476,9 +447,6 @@ public void testLoginHeaderAnonymous() { closeClient(); scriptSession.setVariable("test", TableTools.emptyTable(10).update("I=i")); - // install the auth handler - component.authRequestHandlers().put(ANONYMOUS, new AnonymousRequestHandler()); - final MutableBoolean tokenChanged = new MutableBoolean(); flightClient = FlightClient.builder().location(serverLocation) .allocator(new RootAllocator()) @@ -1014,34 +982,4 @@ private void assertRoundTripDataEqual(Table deephavenTable) throws Exception { assertEquals(0, (long) TableTools .diffPair(deephavenTable, uploadedTable, 0, EnumSet.noneOf(TableDiff.DiffItems.class)).getSecond()); } - - private static class AnonymousRequestHandler implements AuthenticationRequestHandler { - @Override - public String getAuthType() { - return ANONYMOUS; - } - - @Override - public Optional login(long protocolVersion, ByteBuffer payload, - AuthenticationRequestHandler.HandshakeResponseListener listener) { - if (!payload.hasRemaining()) { - return Optional.of(new AuthContext.Anonymous()); - } - return Optional.empty(); - } - - @Override - public Optional login(String payload, - AuthenticationRequestHandler.MetadataResponseListener listener) { - if (payload.isEmpty()) { - return Optional.of(new AuthContext.Anonymous()); - } - return Optional.empty(); - } - - @Override - public void initialize(String targetUrl) { - // do nothing - } - } } diff --git a/server/test/src/main/java/io/deephaven/server/test/TestAuthModule.java b/server/test/src/main/java/io/deephaven/server/test/TestAuthModule.java index 7d427705f4f..4c51d836359 100644 --- a/server/test/src/main/java/io/deephaven/server/test/TestAuthModule.java +++ b/server/test/src/main/java/io/deephaven/server/test/TestAuthModule.java @@ -5,15 +5,20 @@ import dagger.Module; import dagger.Provides; +import io.deephaven.auth.AnonymousAuthenticationHandler; import io.deephaven.auth.AuthenticationRequestHandler; import io.deephaven.auth.BasicAuthMarshaller; import io.deephaven.auth.AuthContext; +import org.apache.arrow.flight.auth2.Auth2Constants; import javax.inject.Singleton; +import java.nio.ByteBuffer; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.UUID; @Module public class TestAuthModule { @@ -45,9 +50,41 @@ public BasicAuthMarshaller bindBasicAuth(BasicAuthTestImpl handler) { @Provides @Singleton public Map bindAuthHandlerMap(BasicAuthMarshaller basicAuthMarshaller) { - // note this is mutable - HashMap map = new HashMap<>(); + Map map = new HashMap<>(); map.put(basicAuthMarshaller.getAuthType(), basicAuthMarshaller); - return map; + AnonymousAuthenticationHandler anonymous = new AnonymousAuthenticationHandler(); + map.put(anonymous.getAuthType(), anonymous); + map.put(FakeBearer.INSTANCE.getAuthType(), FakeBearer.INSTANCE); + return Collections.unmodifiableMap(map); + } + + public enum FakeBearer implements AuthenticationRequestHandler { + INSTANCE; + + public static final String TOKEN = UUID.randomUUID().toString(); + + @Override + public String getAuthType() { + return Auth2Constants.BEARER_PREFIX.trim(); + } + + @Override + public Optional login(long protocolVersion, ByteBuffer payload, + HandshakeResponseListener listener) { + return Optional.empty(); + } + + @Override + public Optional login(String payload, MetadataResponseListener listener) { + if (payload.equals(TOKEN)) { + return Optional.of(new AuthContext.SuperUser()); + } + return Optional.empty(); + } + + @Override + public void initialize(String targetUrl) { + // do nothing + } } }