diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java index 762b37859b948..0b614cf3013f2 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java @@ -98,7 +98,9 @@ public class FlightClient implements AutoCloseable { this.middleware = middleware; final ClientInterceptor[] interceptors; - interceptors = new ClientInterceptor[]{authInterceptor, new ClientInterceptorAdapter(middleware)}; + interceptors = new ClientInterceptor[]{authInterceptor, new RefreshClientInterceptor(), new ClientInterceptorAdapter(middleware)}; + +// interceptors = new ClientInterceptor[]{authInterceptor, new ClientInterceptorAdapter(middleware)}; // Create a channel with interceptors pre-applied for DoGet and DoPut this.interceptedChannel = ClientInterceptors.intercept(channel, interceptors); diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/RefreshClientInterceptor.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/RefreshClientInterceptor.java new file mode 100644 index 0000000000000..5bb777d6725a1 --- /dev/null +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/RefreshClientInterceptor.java @@ -0,0 +1,130 @@ +package org.apache.arrow.flight; + +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.ForwardingClientCall; +import io.grpc.ForwardingClientCallListener; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.Status; + +import javax.annotation.Nullable; + +public class RefreshClientInterceptor implements ClientInterceptor { + @Override + public ClientCall interceptCall(MethodDescriptor method, + io.grpc.CallOptions callOptions, Channel next) { + System.out.println("Intercept the call"); + return new RetryClientCall<>(callOptions, next, method); + } + + static class RetryClientCall extends ClientCall { + + Listener listener; + Metadata metadata; + CallOptions callOptions; + Channel next; + int req; + ReqT msg; + ClientCall call; + MethodDescriptor method; + + public RetryClientCall(CallOptions callOptions, Channel next, MethodDescriptor method) { + this.callOptions = callOptions; + this.next = next; + this.method = method; + } + + @Override + public void start(Listener responseListener, Metadata headers) { + this.listener = responseListener; + this.metadata = headers; + + startCall(new CheckingListener()); + + + System.out.println("Run start method from interceptor"); + } + + + + @Override + public void request(int numMessages) { + System.out.println("Run request method from interceptor"); + req += numMessages; + call.request(numMessages); + + } + + @Override + public void cancel(@Nullable String message, @Nullable Throwable cause) { + System.out.println("Run cancel method from interceptor"); + } + + @Override + public void halfClose() { + System.out.println("Run halfClose method from interceptor"); + call.halfClose(); + } + + private void startCall(Listener listener) { + System.out.println("Run startCall method from interceptor"); + + call = next.newCall(method, callOptions); + Metadata headers = new Metadata(); + headers.merge(metadata); + call.start(listener, headers); + } + + @Override + public void sendMessage(ReqT message) { + assert this.msg == null; + this.msg = message; + call.sendMessage(msg); + } + + class CheckingListener extends ForwardingClientCallListener { + Listener delegate; + + @Override + protected Listener delegate() { + if (delegate == null) { + throw new IllegalStateException(); + } + return delegate; + } + + @Override + public void onReady() { + listener.onReady(); + } + + @Override + public void onHeaders(Metadata headers) { + delegate = listener; + super.onHeaders(headers); + } + + @Override + public void onClose(Status status, Metadata trailers) { + System.out.println("Run close method from listener interceptor"); + if (delegate != null) { + super.onClose(status, trailers); + return; + } + if (!needToRetry(status, trailers)) { // YOUR CODE HERE + delegate = listener; + super.onClose(status, trailers); + return; + } + start(listener, trailers); // to allow multiple retries + } + } + + private boolean needToRetry(Status status, Metadata trailers) { + return status.getCode().toStatus() == Status.UNAUTHENTICATED; + } + } +}