diff --git a/gax-java/gax/clirr-ignored-differences.xml b/gax-java/gax/clirr-ignored-differences.xml
index cab9fe4f8a..cbaed47eb5 100644
--- a/gax-java/gax/clirr-ignored-differences.xml
+++ b/gax-java/gax/clirr-ignored-differences.xml
@@ -105,4 +105,10 @@
com/google/api/gax/tracing/MetricsTracer
*
+
+
+ 7012
+ com/google/api/gax/batching/Batcher
+ *
+
diff --git a/gax-java/gax/src/main/java/com/google/api/gax/batching/Batcher.java b/gax-java/gax/src/main/java/com/google/api/gax/batching/Batcher.java
index 1e069d53e0..6f9f878905 100644
--- a/gax-java/gax/src/main/java/com/google/api/gax/batching/Batcher.java
+++ b/gax-java/gax/src/main/java/com/google/api/gax/batching/Batcher.java
@@ -32,6 +32,9 @@
import com.google.api.core.ApiFuture;
import com.google.api.core.InternalExtensionOnly;
import com.google.api.gax.rpc.ApiCallContext;
+import java.time.Duration;
+import java.util.concurrent.TimeoutException;
+import javax.annotation.Nullable;
/**
* Represents a batching context where individual elements will be accumulated and flushed in a
@@ -77,13 +80,25 @@ public interface Batcher extends AutoCloseable {
*/
void sendOutstanding();
+ /** Cancels all outstanding batch RPCs. */
+ void cancelOutstanding();
+
/**
- * Closes this Batcher by preventing new elements from being added, and then flushing the existing
- * elements.
+ * Closes this Batcher by preventing new elements from being added, then flushing the existing
+ * elements and waiting for all the outstanding work to be resolved.
*/
@Override
void close() throws InterruptedException;
+ /**
+ * Closes this Batcher by preventing new elements from being added, then flushing the existing
+ * elements and waiting for all the outstanding work to be resolved. If all of the outstanding
+ * work has not been resolved, then a {@link BatchingException} will be thrown with details of the
+ * remaining work. The batcher will remain in a closed state and will not allow additional
+ * elements to be added.
+ */
+ void close(@Nullable Duration timeout) throws InterruptedException, TimeoutException;
+
/**
* Closes this Batcher by preventing new elements from being added, and then sending outstanding
* elements. The returned future will be resolved when the last element completes
diff --git a/gax-java/gax/src/main/java/com/google/api/gax/batching/BatcherImpl.java b/gax-java/gax/src/main/java/com/google/api/gax/batching/BatcherImpl.java
index 8cb437a5e2..51549f70b3 100644
--- a/gax-java/gax/src/main/java/com/google/api/gax/batching/BatcherImpl.java
+++ b/gax-java/gax/src/main/java/com/google/api/gax/batching/BatcherImpl.java
@@ -42,6 +42,7 @@
import com.google.api.gax.rpc.ApiCallContext;
import com.google.api.gax.rpc.UnaryCallable;
import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.base.Stopwatch;
import com.google.common.util.concurrent.Futures;
@@ -49,17 +50,21 @@
import java.lang.ref.ReferenceQueue;
import java.lang.ref.SoftReference;
import java.lang.ref.WeakReference;
+import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
+import java.util.Optional;
+import java.util.StringJoiner;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.TimeoutException;
import java.util.logging.Level;
import java.util.logging.Logger;
+import javax.annotation.Nonnull;
import javax.annotation.Nullable;
/**
@@ -86,7 +91,8 @@ public class BatcherImpl
private final BatcherReference currentBatcherReference;
private Batch currentOpenBatch;
- private final AtomicInteger numOfOutstandingBatches = new AtomicInteger(0);
+ private final ConcurrentMap, Boolean>
+ outstandingBatches = new ConcurrentHashMap<>();
private final Object flushLock = new Object();
private final Object elementLock = new Object();
private final Future> scheduledFuture;
@@ -297,8 +303,10 @@ public void sendOutstanding() {
} catch (Exception ex) {
batchResponse = ApiFutures.immediateFailedFuture(ex);
}
+ accumulatedBatch.setResponseFuture(batchResponse);
+
+ outstandingBatches.put(accumulatedBatch, Boolean.TRUE);
- numOfOutstandingBatches.incrementAndGet();
ApiFutures.addCallback(
batchResponse,
new ApiFutureCallback() {
@@ -310,7 +318,7 @@ public void onSuccess(ResponseT response) {
accumulatedBatch.resource.getByteCount());
accumulatedBatch.onBatchSuccess(response);
} finally {
- onBatchCompletion();
+ onBatchCompletion(accumulatedBatch);
}
}
@@ -322,18 +330,19 @@ public void onFailure(Throwable throwable) {
accumulatedBatch.resource.getByteCount());
accumulatedBatch.onBatchFailure(throwable);
} finally {
- onBatchCompletion();
+ onBatchCompletion(accumulatedBatch);
}
}
},
directExecutor());
}
- private void onBatchCompletion() {
+ private void onBatchCompletion(Batch batch) {
boolean shouldClose = false;
synchronized (flushLock) {
- if (numOfOutstandingBatches.decrementAndGet() == 0) {
+ outstandingBatches.remove(batch);
+ if (outstandingBatches.isEmpty()) {
flushLock.notifyAll();
shouldClose = closeFuture != null;
}
@@ -349,10 +358,10 @@ private void onBatchCompletion() {
}
private void awaitAllOutstandingBatches() throws InterruptedException {
- while (numOfOutstandingBatches.get() > 0) {
+ while (!outstandingBatches.isEmpty()) {
synchronized (flushLock) {
// Check again under lock to avoid racing with onBatchCompletion
- if (numOfOutstandingBatches.get() == 0) {
+ if (outstandingBatches.isEmpty()) {
break;
}
flushLock.wait();
@@ -360,11 +369,32 @@ private void awaitAllOutstandingBatches() throws InterruptedException {
}
}
+ @Override
+ public void cancelOutstanding() {
+ for (Batch, ?, ?, ?> batch : outstandingBatches.keySet()) {
+ batch.cancel();
+ }
+ }
/** {@inheritDoc} */
@Override
public void close() throws InterruptedException {
try {
- closeAsync().get();
+ close(null);
+ } catch (TimeoutException e) {
+ // should never happen with a null timeout
+ throw new IllegalStateException(
+ "Unexpected timeout exception when trying to close the batcher without a timeout", e);
+ }
+ }
+
+ @Override
+ public void close(@Nullable Duration timeout) throws InterruptedException, TimeoutException {
+ try {
+ if (timeout != null) {
+ closeAsync().get(timeout.toMillis(), TimeUnit.MILLISECONDS);
+ } else {
+ closeAsync().get();
+ }
} catch (ExecutionException e) {
// Original stacktrace of a batching exception is not useful, so rethrow the error with
// the caller stacktrace
@@ -374,6 +404,17 @@ public void close() throws InterruptedException {
} else {
throw new IllegalStateException("unexpected error closing the batcher", e.getCause());
}
+ } catch (TimeoutException e) {
+ StringJoiner batchesStr = new StringJoiner(",");
+ for (Batch batch :
+ outstandingBatches.keySet()) {
+ batchesStr.add(batch.toString());
+ }
+ String msg = "Timed out trying to close batcher after " + timeout + ".";
+ msg += " Batch request prototype: " + prototype + ".";
+ msg += " Outstanding batches: " + batchesStr;
+
+ throw new TimeoutException(msg);
}
}
@@ -393,7 +434,7 @@ public ApiFuture closeAsync() {
// prevent admission of new elements
closeFuture = SettableApiFuture.create();
// check if we can close immediately
- closeImmediately = numOfOutstandingBatches.get() == 0;
+ closeImmediately = outstandingBatches.isEmpty();
}
// Clean up accounting
@@ -435,6 +476,8 @@ private static class Batch {
private long totalThrottledTimeMs = 0;
private BatchResource resource;
+ private volatile ApiFuture responseFuture;
+
private Batch(
RequestT prototype,
BatchingDescriptor descriptor,
@@ -457,6 +500,17 @@ void add(
totalThrottledTimeMs += throttledTimeMs;
}
+ void setResponseFuture(@Nonnull ApiFuture responseFuture) {
+ Preconditions.checkNotNull(responseFuture);
+ this.responseFuture = responseFuture;
+ }
+
+ void cancel() {
+ if (this.responseFuture != null) {
+ this.responseFuture.cancel(true);
+ }
+ }
+
void onBatchSuccess(ResponseT response) {
try {
descriptor.splitResponse(response, entries);
@@ -480,6 +534,19 @@ void onBatchFailure(Throwable throwable) {
boolean isEmpty() {
return resource.getElementCount() == 0;
}
+
+ @Override
+ public String toString() {
+ StringJoiner elementsStr = new StringJoiner(",");
+ for (BatchEntry entry : entries) {
+ elementsStr.add(
+ Optional.ofNullable(entry.getElement()).map(Object::toString).orElse("null"));
+ }
+ return MoreObjects.toStringHelper(this)
+ .add("responseFuture", responseFuture)
+ .add("elements", elementsStr)
+ .toString();
+ }
}
/**
diff --git a/gax-java/gax/src/test/java/com/google/api/gax/batching/BatcherImplTest.java b/gax-java/gax/src/test/java/com/google/api/gax/batching/BatcherImplTest.java
index 3ebcc2c5d0..7f8957a4b2 100644
--- a/gax-java/gax/src/test/java/com/google/api/gax/batching/BatcherImplTest.java
+++ b/gax-java/gax/src/test/java/com/google/api/gax/batching/BatcherImplTest.java
@@ -40,6 +40,7 @@
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
+import com.google.api.core.AbstractApiFuture;
import com.google.api.core.ApiFuture;
import com.google.api.core.ApiFutures;
import com.google.api.core.SettableApiFuture;
@@ -55,11 +56,13 @@
import com.google.common.base.Stopwatch;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Queues;
+import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.Callable;
+import java.util.concurrent.CancellationException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutionException;
@@ -249,6 +252,107 @@ public ApiFuture> futureCall(
closeFuture.get();
}
+ @Test
+ void testCloseTimeout() throws ExecutionException, InterruptedException {
+ final String futureToStringMsg = "some descriptive message about this future";
+ MySettableApiFuture> innerFuture = new MySettableApiFuture<>(futureToStringMsg);
+
+ UnaryCallable> unaryCallable =
+ new UnaryCallable>() {
+ @Override
+ public ApiFuture> futureCall(
+ LabeledIntList request, ApiCallContext context) {
+ return innerFuture;
+ }
+ };
+ underTest =
+ new BatcherImpl<>(
+ SQUARER_BATCHING_DESC_V2, unaryCallable, labeledIntList, batchingSettings, EXECUTOR);
+
+ underTest.add(1);
+
+ Stopwatch stopwatch = Stopwatch.createStarted();
+
+ TimeoutException closeException =
+ assertThrows(TimeoutException.class, () -> underTest.close(Duration.ofMillis(10)));
+
+ // resolve the future to allow batcher to close
+ innerFuture.set(ImmutableList.of(1));
+
+ assertThat(stopwatch.elapsed()).isAtMost(java.time.Duration.ofSeconds(1));
+ System.out.println();
+ assertThat(closeException)
+ .hasMessageThat()
+ .matches(".*Outstanding batches.*" + futureToStringMsg + ".*elements=1.*");
+ }
+
+ @Test
+ void testCloseTimeoutPreventsAdd() throws ExecutionException, InterruptedException {
+ final String futureToStringMsg = "some descriptive message about this future";
+ MySettableApiFuture> innerFuture = new MySettableApiFuture<>(futureToStringMsg);
+
+ UnaryCallable> unaryCallable =
+ new UnaryCallable>() {
+ @Override
+ public ApiFuture> futureCall(
+ LabeledIntList request, ApiCallContext context) {
+ return innerFuture;
+ }
+ };
+ underTest =
+ new BatcherImpl<>(
+ SQUARER_BATCHING_DESC_V2, unaryCallable, labeledIntList, batchingSettings, EXECUTOR);
+
+ underTest.add(1);
+
+ try {
+ underTest.close(Duration.ofMillis(10));
+ } catch (TimeoutException ignored) {
+ // ignored
+ }
+
+ // Even though the close operation timed out, the batcher should be in a closed state
+ // and reject new additions
+ assertThrows(IllegalStateException.class, () -> underTest.add(2));
+
+ // resolve the future to allow batcher to close
+ innerFuture.set(ImmutableList.of(1));
+ }
+
+ @Test
+ void testCancelOutstanding() throws ExecutionException, InterruptedException {
+ SettableApiFuture> innerFuture = SettableApiFuture.create();
+
+ UnaryCallable> unaryCallable =
+ new UnaryCallable>() {
+ @Override
+ public ApiFuture> futureCall(
+ LabeledIntList request, ApiCallContext context) {
+ return innerFuture;
+ }
+ };
+ underTest =
+ new BatcherImpl<>(
+ SQUARER_BATCHING_DESC_V2, unaryCallable, labeledIntList, batchingSettings, EXECUTOR);
+
+ ApiFuture elementF = underTest.add(1);
+
+ // Initial close will timeout
+ TimeoutException firstCloseException =
+ assertThrows(TimeoutException.class, () -> underTest.close(Duration.ofMillis(10)));
+ assertThat(firstCloseException).hasMessageThat().contains("Timed out");
+
+ underTest.cancelOutstanding();
+
+ BatchingException finalCloseException =
+ assertThrows(BatchingException.class, () -> underTest.close(Duration.ofSeconds(1)));
+ assertThat(finalCloseException).hasMessageThat().contains("Batching finished");
+
+ // element future should resolve to a cancelled future
+ ExecutionException elementException = assertThrows(ExecutionException.class, elementF::get);
+ assertThat(elementException).hasCauseThat().isInstanceOf(CancellationException.class);
+ }
+
/** Verifies exception occurred at RPC is propagated to element results */
@Test
void testResultFailureAfterRPCFailure() throws Exception {
@@ -1102,4 +1206,27 @@ private BatcherImpl> createDefau
EXECUTOR,
flowController);
}
+
+ private static class MySettableApiFuture extends AbstractApiFuture {
+ private final String desc;
+
+ MySettableApiFuture(String desc) {
+ this.desc = desc;
+ }
+
+ @Override
+ public boolean set(T value) {
+ return super.set(value);
+ }
+
+ @Override
+ public boolean setException(Throwable throwable) {
+ return super.setException(throwable);
+ }
+
+ @Override
+ public String toString() {
+ return desc;
+ }
+ }
}