diff --git a/gax-java/gax/clirr-ignored-differences.xml b/gax-java/gax/clirr-ignored-differences.xml index cab9fe4f8a..cbaed47eb5 100644 --- a/gax-java/gax/clirr-ignored-differences.xml +++ b/gax-java/gax/clirr-ignored-differences.xml @@ -105,4 +105,10 @@ com/google/api/gax/tracing/MetricsTracer * + + + 7012 + com/google/api/gax/batching/Batcher + * + diff --git a/gax-java/gax/src/main/java/com/google/api/gax/batching/Batcher.java b/gax-java/gax/src/main/java/com/google/api/gax/batching/Batcher.java index 1e069d53e0..6f9f878905 100644 --- a/gax-java/gax/src/main/java/com/google/api/gax/batching/Batcher.java +++ b/gax-java/gax/src/main/java/com/google/api/gax/batching/Batcher.java @@ -32,6 +32,9 @@ import com.google.api.core.ApiFuture; import com.google.api.core.InternalExtensionOnly; import com.google.api.gax.rpc.ApiCallContext; +import java.time.Duration; +import java.util.concurrent.TimeoutException; +import javax.annotation.Nullable; /** * Represents a batching context where individual elements will be accumulated and flushed in a @@ -77,13 +80,25 @@ public interface Batcher extends AutoCloseable { */ void sendOutstanding(); + /** Cancels all outstanding batch RPCs. */ + void cancelOutstanding(); + /** - * Closes this Batcher by preventing new elements from being added, and then flushing the existing - * elements. + * Closes this Batcher by preventing new elements from being added, then flushing the existing + * elements and waiting for all the outstanding work to be resolved. */ @Override void close() throws InterruptedException; + /** + * Closes this Batcher by preventing new elements from being added, then flushing the existing + * elements and waiting for all the outstanding work to be resolved. If all of the outstanding + * work has not been resolved, then a {@link BatchingException} will be thrown with details of the + * remaining work. The batcher will remain in a closed state and will not allow additional + * elements to be added. + */ + void close(@Nullable Duration timeout) throws InterruptedException, TimeoutException; + /** * Closes this Batcher by preventing new elements from being added, and then sending outstanding * elements. The returned future will be resolved when the last element completes diff --git a/gax-java/gax/src/main/java/com/google/api/gax/batching/BatcherImpl.java b/gax-java/gax/src/main/java/com/google/api/gax/batching/BatcherImpl.java index 8cb437a5e2..51549f70b3 100644 --- a/gax-java/gax/src/main/java/com/google/api/gax/batching/BatcherImpl.java +++ b/gax-java/gax/src/main/java/com/google/api/gax/batching/BatcherImpl.java @@ -42,6 +42,7 @@ import com.google.api.gax.rpc.ApiCallContext; import com.google.api.gax.rpc.UnaryCallable; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; import com.google.common.base.Stopwatch; import com.google.common.util.concurrent.Futures; @@ -49,17 +50,21 @@ import java.lang.ref.ReferenceQueue; import java.lang.ref.SoftReference; import java.lang.ref.WeakReference; +import java.time.Duration; import java.util.ArrayList; import java.util.List; +import java.util.Optional; +import java.util.StringJoiner; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.TimeoutException; import java.util.logging.Level; import java.util.logging.Logger; +import javax.annotation.Nonnull; import javax.annotation.Nullable; /** @@ -86,7 +91,8 @@ public class BatcherImpl private final BatcherReference currentBatcherReference; private Batch currentOpenBatch; - private final AtomicInteger numOfOutstandingBatches = new AtomicInteger(0); + private final ConcurrentMap, Boolean> + outstandingBatches = new ConcurrentHashMap<>(); private final Object flushLock = new Object(); private final Object elementLock = new Object(); private final Future scheduledFuture; @@ -297,8 +303,10 @@ public void sendOutstanding() { } catch (Exception ex) { batchResponse = ApiFutures.immediateFailedFuture(ex); } + accumulatedBatch.setResponseFuture(batchResponse); + + outstandingBatches.put(accumulatedBatch, Boolean.TRUE); - numOfOutstandingBatches.incrementAndGet(); ApiFutures.addCallback( batchResponse, new ApiFutureCallback() { @@ -310,7 +318,7 @@ public void onSuccess(ResponseT response) { accumulatedBatch.resource.getByteCount()); accumulatedBatch.onBatchSuccess(response); } finally { - onBatchCompletion(); + onBatchCompletion(accumulatedBatch); } } @@ -322,18 +330,19 @@ public void onFailure(Throwable throwable) { accumulatedBatch.resource.getByteCount()); accumulatedBatch.onBatchFailure(throwable); } finally { - onBatchCompletion(); + onBatchCompletion(accumulatedBatch); } } }, directExecutor()); } - private void onBatchCompletion() { + private void onBatchCompletion(Batch batch) { boolean shouldClose = false; synchronized (flushLock) { - if (numOfOutstandingBatches.decrementAndGet() == 0) { + outstandingBatches.remove(batch); + if (outstandingBatches.isEmpty()) { flushLock.notifyAll(); shouldClose = closeFuture != null; } @@ -349,10 +358,10 @@ private void onBatchCompletion() { } private void awaitAllOutstandingBatches() throws InterruptedException { - while (numOfOutstandingBatches.get() > 0) { + while (!outstandingBatches.isEmpty()) { synchronized (flushLock) { // Check again under lock to avoid racing with onBatchCompletion - if (numOfOutstandingBatches.get() == 0) { + if (outstandingBatches.isEmpty()) { break; } flushLock.wait(); @@ -360,11 +369,32 @@ private void awaitAllOutstandingBatches() throws InterruptedException { } } + @Override + public void cancelOutstanding() { + for (Batch batch : outstandingBatches.keySet()) { + batch.cancel(); + } + } /** {@inheritDoc} */ @Override public void close() throws InterruptedException { try { - closeAsync().get(); + close(null); + } catch (TimeoutException e) { + // should never happen with a null timeout + throw new IllegalStateException( + "Unexpected timeout exception when trying to close the batcher without a timeout", e); + } + } + + @Override + public void close(@Nullable Duration timeout) throws InterruptedException, TimeoutException { + try { + if (timeout != null) { + closeAsync().get(timeout.toMillis(), TimeUnit.MILLISECONDS); + } else { + closeAsync().get(); + } } catch (ExecutionException e) { // Original stacktrace of a batching exception is not useful, so rethrow the error with // the caller stacktrace @@ -374,6 +404,17 @@ public void close() throws InterruptedException { } else { throw new IllegalStateException("unexpected error closing the batcher", e.getCause()); } + } catch (TimeoutException e) { + StringJoiner batchesStr = new StringJoiner(","); + for (Batch batch : + outstandingBatches.keySet()) { + batchesStr.add(batch.toString()); + } + String msg = "Timed out trying to close batcher after " + timeout + "."; + msg += " Batch request prototype: " + prototype + "."; + msg += " Outstanding batches: " + batchesStr; + + throw new TimeoutException(msg); } } @@ -393,7 +434,7 @@ public ApiFuture closeAsync() { // prevent admission of new elements closeFuture = SettableApiFuture.create(); // check if we can close immediately - closeImmediately = numOfOutstandingBatches.get() == 0; + closeImmediately = outstandingBatches.isEmpty(); } // Clean up accounting @@ -435,6 +476,8 @@ private static class Batch { private long totalThrottledTimeMs = 0; private BatchResource resource; + private volatile ApiFuture responseFuture; + private Batch( RequestT prototype, BatchingDescriptor descriptor, @@ -457,6 +500,17 @@ void add( totalThrottledTimeMs += throttledTimeMs; } + void setResponseFuture(@Nonnull ApiFuture responseFuture) { + Preconditions.checkNotNull(responseFuture); + this.responseFuture = responseFuture; + } + + void cancel() { + if (this.responseFuture != null) { + this.responseFuture.cancel(true); + } + } + void onBatchSuccess(ResponseT response) { try { descriptor.splitResponse(response, entries); @@ -480,6 +534,19 @@ void onBatchFailure(Throwable throwable) { boolean isEmpty() { return resource.getElementCount() == 0; } + + @Override + public String toString() { + StringJoiner elementsStr = new StringJoiner(","); + for (BatchEntry entry : entries) { + elementsStr.add( + Optional.ofNullable(entry.getElement()).map(Object::toString).orElse("null")); + } + return MoreObjects.toStringHelper(this) + .add("responseFuture", responseFuture) + .add("elements", elementsStr) + .toString(); + } } /** diff --git a/gax-java/gax/src/test/java/com/google/api/gax/batching/BatcherImplTest.java b/gax-java/gax/src/test/java/com/google/api/gax/batching/BatcherImplTest.java index 3ebcc2c5d0..7f8957a4b2 100644 --- a/gax-java/gax/src/test/java/com/google/api/gax/batching/BatcherImplTest.java +++ b/gax-java/gax/src/test/java/com/google/api/gax/batching/BatcherImplTest.java @@ -40,6 +40,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.api.core.AbstractApiFuture; import com.google.api.core.ApiFuture; import com.google.api.core.ApiFutures; import com.google.api.core.SettableApiFuture; @@ -55,11 +56,13 @@ import com.google.common.base.Stopwatch; import com.google.common.collect.ImmutableList; import com.google.common.collect.Queues; +import java.time.Duration; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Queue; import java.util.concurrent.Callable; +import java.util.concurrent.CancellationException; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutionException; @@ -249,6 +252,107 @@ public ApiFuture> futureCall( closeFuture.get(); } + @Test + void testCloseTimeout() throws ExecutionException, InterruptedException { + final String futureToStringMsg = "some descriptive message about this future"; + MySettableApiFuture> innerFuture = new MySettableApiFuture<>(futureToStringMsg); + + UnaryCallable> unaryCallable = + new UnaryCallable>() { + @Override + public ApiFuture> futureCall( + LabeledIntList request, ApiCallContext context) { + return innerFuture; + } + }; + underTest = + new BatcherImpl<>( + SQUARER_BATCHING_DESC_V2, unaryCallable, labeledIntList, batchingSettings, EXECUTOR); + + underTest.add(1); + + Stopwatch stopwatch = Stopwatch.createStarted(); + + TimeoutException closeException = + assertThrows(TimeoutException.class, () -> underTest.close(Duration.ofMillis(10))); + + // resolve the future to allow batcher to close + innerFuture.set(ImmutableList.of(1)); + + assertThat(stopwatch.elapsed()).isAtMost(java.time.Duration.ofSeconds(1)); + System.out.println(); + assertThat(closeException) + .hasMessageThat() + .matches(".*Outstanding batches.*" + futureToStringMsg + ".*elements=1.*"); + } + + @Test + void testCloseTimeoutPreventsAdd() throws ExecutionException, InterruptedException { + final String futureToStringMsg = "some descriptive message about this future"; + MySettableApiFuture> innerFuture = new MySettableApiFuture<>(futureToStringMsg); + + UnaryCallable> unaryCallable = + new UnaryCallable>() { + @Override + public ApiFuture> futureCall( + LabeledIntList request, ApiCallContext context) { + return innerFuture; + } + }; + underTest = + new BatcherImpl<>( + SQUARER_BATCHING_DESC_V2, unaryCallable, labeledIntList, batchingSettings, EXECUTOR); + + underTest.add(1); + + try { + underTest.close(Duration.ofMillis(10)); + } catch (TimeoutException ignored) { + // ignored + } + + // Even though the close operation timed out, the batcher should be in a closed state + // and reject new additions + assertThrows(IllegalStateException.class, () -> underTest.add(2)); + + // resolve the future to allow batcher to close + innerFuture.set(ImmutableList.of(1)); + } + + @Test + void testCancelOutstanding() throws ExecutionException, InterruptedException { + SettableApiFuture> innerFuture = SettableApiFuture.create(); + + UnaryCallable> unaryCallable = + new UnaryCallable>() { + @Override + public ApiFuture> futureCall( + LabeledIntList request, ApiCallContext context) { + return innerFuture; + } + }; + underTest = + new BatcherImpl<>( + SQUARER_BATCHING_DESC_V2, unaryCallable, labeledIntList, batchingSettings, EXECUTOR); + + ApiFuture elementF = underTest.add(1); + + // Initial close will timeout + TimeoutException firstCloseException = + assertThrows(TimeoutException.class, () -> underTest.close(Duration.ofMillis(10))); + assertThat(firstCloseException).hasMessageThat().contains("Timed out"); + + underTest.cancelOutstanding(); + + BatchingException finalCloseException = + assertThrows(BatchingException.class, () -> underTest.close(Duration.ofSeconds(1))); + assertThat(finalCloseException).hasMessageThat().contains("Batching finished"); + + // element future should resolve to a cancelled future + ExecutionException elementException = assertThrows(ExecutionException.class, elementF::get); + assertThat(elementException).hasCauseThat().isInstanceOf(CancellationException.class); + } + /** Verifies exception occurred at RPC is propagated to element results */ @Test void testResultFailureAfterRPCFailure() throws Exception { @@ -1102,4 +1206,27 @@ private BatcherImpl> createDefau EXECUTOR, flowController); } + + private static class MySettableApiFuture extends AbstractApiFuture { + private final String desc; + + MySettableApiFuture(String desc) { + this.desc = desc; + } + + @Override + public boolean set(T value) { + return super.set(value); + } + + @Override + public boolean setException(Throwable throwable) { + return super.setException(throwable); + } + + @Override + public String toString() { + return desc; + } + } }