diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java b/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java
index 5be92558ee0..7203d3a4945 100644
--- a/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java
+++ b/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java
@@ -18,6 +18,8 @@
import com.mongodb.lang.Nullable;
+import java.util.concurrent.atomic.AtomicBoolean;
+
/**
* See {@link AsyncRunnable}
*
@@ -33,4 +35,28 @@ public interface AsyncFunction {
* @param callback the callback
*/
void unsafeFinish(T value, SingleResultCallback callback);
+
+ /**
+ * Must be invoked at end of async chain or when executing a callback handler supplied by the caller.
+ *
+ * @param callback the callback provided by the method the chain is used in.
+ */
+ default void finish(final T value, final SingleResultCallback callback) {
+ final AtomicBoolean callbackInvoked = new AtomicBoolean(false);
+ try {
+ this.unsafeFinish(value, (v, e) -> {
+ if (!callbackInvoked.compareAndSet(false, true)) {
+ throw new AssertionError(String.format("Callback has been already completed. It could happen "
+ + "if code throws an exception after invoking an async method. Value: %s", v), e);
+ }
+ callback.onResult(v, e);
+ });
+ } catch (Throwable t) {
+ if (!callbackInvoked.compareAndSet(false, true)) {
+ throw t;
+ } else {
+ callback.completeExceptionally(t);
+ }
+ }
+ }
}
diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java
index a81b2fdd12c..d4ead3c5b96 100644
--- a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java
+++ b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java
@@ -171,7 +171,9 @@ default AsyncRunnable thenRun(final AsyncRunnable runnable) {
return (c) -> {
this.unsafeFinish((r, e) -> {
if (e == null) {
- runnable.unsafeFinish(c);
+ /* If 'runnable' is executed on a different thread from the one that executed the initial 'finish()',
+ then invoking 'finish()' within 'runnable' will catch and propagate any exceptions to 'c' (the callback). */
+ runnable.finish(c);
} else {
c.completeExceptionally(e);
}
@@ -236,7 +238,7 @@ default AsyncRunnable thenRunIf(final Supplier condition, final AsyncRu
return;
}
if (matched) {
- runnable.unsafeFinish(callback);
+ runnable.finish(callback);
} else {
callback.complete(callback);
}
@@ -253,7 +255,7 @@ default AsyncSupplier thenSupply(final AsyncSupplier supplier) {
return (c) -> {
this.unsafeFinish((r, e) -> {
if (e == null) {
- supplier.unsafeFinish(c);
+ supplier.finish(c);
} else {
c.completeExceptionally(e);
}
diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java
index b7d24dd3df5..77c289c8723 100644
--- a/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java
+++ b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java
@@ -18,6 +18,7 @@
import com.mongodb.lang.Nullable;
+import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Predicate;
@@ -54,18 +55,25 @@ default void unsafeFinish(@Nullable final Void value, final SingleResultCallback
}
/**
- * Must be invoked at end of async chain.
+ * Must be invoked at end of async chain or when executing a callback handler supplied by the caller.
+ *
+ * @see #thenApply(AsyncFunction)
+ * @see #thenConsume(AsyncConsumer)
+ * @see #onErrorIf(Predicate, AsyncFunction)
* @param callback the callback provided by the method the chain is used in
*/
default void finish(final SingleResultCallback callback) {
- final boolean[] callbackInvoked = {false};
+ final AtomicBoolean callbackInvoked = new AtomicBoolean(false);
try {
this.unsafeFinish((v, e) -> {
- callbackInvoked[0] = true;
+ if (!callbackInvoked.compareAndSet(false, true)) {
+ throw new AssertionError(String.format("Callback has been already completed. It could happen "
+ + "if code throws an exception after invoking an async method. Value: %s", v), e);
+ }
callback.onResult(v, e);
});
} catch (Throwable t) {
- if (callbackInvoked[0]) {
+ if (!callbackInvoked.compareAndSet(false, true)) {
throw t;
} else {
callback.completeExceptionally(t);
@@ -80,9 +88,9 @@ default void finish(final SingleResultCallback callback) {
*/
default AsyncSupplier thenApply(final AsyncFunction function) {
return (c) -> {
- this.unsafeFinish((v, e) -> {
+ this.finish((v, e) -> {
if (e == null) {
- function.unsafeFinish(v, c);
+ function.finish(v, c);
} else {
c.completeExceptionally(e);
}
@@ -99,7 +107,7 @@ default AsyncRunnable thenConsume(final AsyncConsumer consumer) {
return (c) -> {
this.unsafeFinish((v, e) -> {
if (e == null) {
- consumer.unsafeFinish(v, c);
+ consumer.finish(v, c);
} else {
c.completeExceptionally(e);
}
@@ -131,7 +139,7 @@ default AsyncSupplier onErrorIf(
return;
}
if (errorMatched) {
- errorFunction.unsafeFinish(e, callback);
+ errorFunction.finish(e, callback);
} else {
callback.completeExceptionally(e);
}
diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java
index 98e43fe5fbe..de12e5f092f 100644
--- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java
+++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java
@@ -610,6 +610,7 @@ private void sendCommandMessageAsync(final int messageId, final Decoder d
return;
}
assertNotNull(responseBuffers);
+ T commandResult;
try {
updateSessionContext(operationContext.getSessionContext(), responseBuffers);
boolean commandOk =
@@ -624,13 +625,14 @@ private void sendCommandMessageAsync(final int messageId, final Decoder d
}
commandEventSender.sendSucceededEvent(responseBuffers);
- T result1 = getCommandResult(decoder, responseBuffers, messageId, operationContext.getTimeoutContext());
- callback.onResult(result1, null);
+ commandResult = getCommandResult(decoder, responseBuffers, messageId, operationContext.getTimeoutContext());
} catch (Throwable localThrowable) {
callback.onResult(null, localThrowable);
+ return;
} finally {
responseBuffers.close();
}
+ callback.onResult(commandResult, null);
}));
}
});
diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java
index ee509873e40..6fca357b080 100644
--- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java
+++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java
@@ -101,7 +101,14 @@ public void startHandshakeAsync(final InternalConnection internalConnection, fin
callback.onResult(null, t instanceof MongoException ? mapHelloException((MongoException) t) : t);
} else {
setSpeculativeAuthenticateResponse(helloResult);
- callback.onResult(createInitializationDescription(helloResult, internalConnection, startTime), null);
+ InternalConnectionInitializationDescription initializationDescription;
+ try {
+ initializationDescription = createInitializationDescription(helloResult, internalConnection, startTime);
+ } catch (Throwable localThrowable) {
+ callback.onResult(null, localThrowable);
+ return;
+ }
+ callback.onResult(initializationDescription, null);
}
});
}
diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsAbstractTest.java
similarity index 97%
rename from driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java
rename to driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsAbstractTest.java
index 20553fe881a..611b90fc675 100644
--- a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java
+++ b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsAbstractTest.java
@@ -25,10 +25,10 @@
import static com.mongodb.assertions.Assertions.assertNotNull;
import static com.mongodb.internal.async.AsyncRunnable.beginAsync;
-import static org.junit.jupiter.api.Assertions.assertThrows;
-final class AsyncFunctionsTest extends AsyncFunctionsTestAbstract {
+abstract class AsyncFunctionsAbstractTest extends AsyncFunctionsTestBase {
private static final TimeoutContext TIMEOUT_CONTEXT = new TimeoutContext(new TimeoutSettings(0, 0, 0, 0L, 0));
+
@Test
void test1Method() {
// the number of expected variations is often: 1 + N methods invoked
@@ -760,25 +760,6 @@ void testVariables() {
});
}
- @Test
- void testInvalid() {
- setIsTestingAbruptCompletion(false);
- setAsyncStep(true);
- assertThrows(IllegalStateException.class, () -> {
- beginAsync().thenRun(c -> {
- async(3, c);
- throw new IllegalStateException("must not cause second callback invocation");
- }).finish((v, e) -> {});
- });
- assertThrows(IllegalStateException.class, () -> {
- beginAsync().thenRun(c -> {
- async(3, c);
- }).finish((v, e) -> {
- throw new IllegalStateException("must not cause second callback invocation");
- });
- });
- }
-
@Test
void testDerivation() {
// Demonstrates the progression from nested async to the API.
@@ -866,5 +847,4 @@ void testDerivation() {
}).finish(callback);
});
}
-
}
diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTestAbstract.java b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTestBase.java
similarity index 80%
rename from driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTestAbstract.java
rename to driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTestBase.java
index 7cc8b456f1c..1229dbcfcad 100644
--- a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTestAbstract.java
+++ b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTestBase.java
@@ -17,11 +17,17 @@
package com.mongodb.internal.async;
import com.mongodb.client.TestListener;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
import org.opentest4j.AssertionFailedError;
import java.util.ArrayList;
import java.util.List;
-import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Supplier;
@@ -31,11 +37,12 @@
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
-public class AsyncFunctionsTestAbstract {
+public abstract class AsyncFunctionsTestBase {
private final TestListener listener = new TestListener();
private final InvocationTracker invocationTracker = new InvocationTracker();
private boolean isTestingAbruptCompletion = false;
+ private ExecutorService asyncExecutor;
void setIsTestingAbruptCompletion(final boolean b) {
isTestingAbruptCompletion = b;
@@ -53,6 +60,23 @@ public void listenerAdd(final String s) {
listener.add(s);
}
+ /**
+ * Create an executor service for async operations before each test.
+ *
+ * @return the executor service.
+ */
+ public abstract ExecutorService createAsyncExecutor();
+
+ @BeforeEach
+ public void setUp() {
+ asyncExecutor = createAsyncExecutor();
+ }
+
+ @AfterEach
+ public void shutDown() {
+ asyncExecutor.shutdownNow();
+ }
+
void plain(final int i) {
int cur = invocationTracker.getNextOption(2);
if (cur == 0) {
@@ -98,32 +122,47 @@ Integer syncReturns(final int i) {
return affectedReturns(i);
}
+
+ public void submit(final Runnable task) {
+ asyncExecutor.execute(task);
+ }
void async(final int i, final SingleResultCallback callback) {
assertTrue(invocationTracker.isAsyncStep);
if (isTestingAbruptCompletion) {
+ /* We should not test for abrupt completion in a separate thread. Once a callback is registered for an async operation,
+ the Async Framework does not handle exceptions thrown outside of callbacks by the executing thread. Such exception management
+ should be the responsibility of the thread conducting the asynchronous operations. */
affected(i);
- callback.complete(callback);
-
- } else {
- try {
- affected(i);
+ submit(() -> {
callback.complete(callback);
- } catch (Throwable t) {
- callback.onResult(null, t);
- }
+ });
+ } else {
+ submit(() -> {
+ try {
+ affected(i);
+ callback.complete(callback);
+ } catch (Throwable t) {
+ callback.onResult(null, t);
+ }
+ });
}
}
void asyncReturns(final int i, final SingleResultCallback callback) {
assertTrue(invocationTracker.isAsyncStep);
if (isTestingAbruptCompletion) {
- callback.complete(affectedReturns(i));
+ int result = affectedReturns(i);
+ submit(() -> {
+ callback.complete(result);
+ });
} else {
- try {
- callback.complete(affectedReturns(i));
- } catch (Throwable t) {
- callback.onResult(null, t);
- }
+ submit(() -> {
+ try {
+ callback.complete(affectedReturns(i));
+ } catch (Throwable t) {
+ callback.onResult(null, t);
+ }
+ });
}
}
@@ -200,24 +239,26 @@ private void assertBehavesSame(final Supplier sync, final Runnable betwee
AtomicReference actualValue = new AtomicReference<>();
AtomicReference actualException = new AtomicReference<>();
- AtomicBoolean wasCalled = new AtomicBoolean(false);
+ CompletableFuture wasCalledFuture = new CompletableFuture<>();
try {
async.accept((v, e) -> {
actualValue.set(v);
actualException.set(e);
- if (wasCalled.get()) {
+ if (wasCalledFuture.isDone()) {
fail();
}
- wasCalled.set(true);
+ wasCalledFuture.complete(null);
});
} catch (Throwable e) {
fail("async threw instead of using callback");
}
+ await(wasCalledFuture, "Callback should have been called");
+
// The following code can be used to debug variations:
// System.out.println("===VARIATION START");
// System.out.println("sync: " + expectedEvents);
-// System.out.println("callback called?: " + wasCalled.get());
+// System.out.println("callback called?: " + wasCalledFuture.isDone());
// System.out.println("value -- sync: " + expectedValue + " -- async: " + actualValue.get());
// System.out.println("excep -- sync: " + expectedException + " -- async: " + actualException.get());
// System.out.println("exception mode: " + (isTestingAbruptCompletion
@@ -229,7 +270,7 @@ private void assertBehavesSame(final Supplier sync, final Runnable betwee
throw (AssertionFailedError) actualException.get();
}
- assertTrue(wasCalled.get(), "callback should have been called");
+ assertTrue(wasCalledFuture.isDone(), "callback should have been called");
assertEquals(expectedEvents, listener.getEventStrings(), "steps should have matched");
assertEquals(expectedValue, actualValue.get());
assertEquals(expectedException == null, actualException.get() == null,
@@ -242,6 +283,14 @@ private void assertBehavesSame(final Supplier sync, final Runnable betwee
listener.clear();
}
+ protected T await(final CompletableFuture voidCompletableFuture, final String errorMessage) {
+ try {
+ return voidCompletableFuture.get(1, TimeUnit.MINUTES);
+ } catch (InterruptedException | ExecutionException | TimeoutException e) {
+ throw new AssertionError(errorMessage);
+ }
+ }
+
/**
* Tracks invocations: allows testing of all variations of a method calls
*/
diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/SameThreadAsyncFunctionsTest.java b/driver-core/src/test/unit/com/mongodb/internal/async/SameThreadAsyncFunctionsTest.java
new file mode 100644
index 00000000000..04b9290af55
--- /dev/null
+++ b/driver-core/src/test/unit/com/mongodb/internal/async/SameThreadAsyncFunctionsTest.java
@@ -0,0 +1,94 @@
+/*
+ * Copyright 2008-present MongoDB, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.mongodb.internal.async;
+
+import org.jetbrains.annotations.NotNull;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.AbstractExecutorService;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.TimeUnit;
+
+import static com.mongodb.internal.async.AsyncRunnable.beginAsync;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+@DisplayName("The same thread async functions")
+public class SameThreadAsyncFunctionsTest extends AsyncFunctionsAbstractTest {
+ @Override
+ public ExecutorService createAsyncExecutor() {
+ return new SameThreadExecutorService();
+ }
+
+ @Test
+ void testInvalid() {
+ setIsTestingAbruptCompletion(false);
+ setAsyncStep(true);
+ IllegalStateException illegalStateException = new IllegalStateException("must not cause second callback invocation");
+
+ assertThrows(IllegalStateException.class, () -> {
+ beginAsync().thenRun(c -> {
+ async(3, c);
+ throw illegalStateException;
+ }).finish((v, e) -> {
+ assertNotEquals(e, illegalStateException);
+ });
+ });
+ assertThrows(IllegalStateException.class, () -> {
+ beginAsync().thenRun(c -> {
+ async(3, c);
+ }).finish((v, e) -> {
+ throw illegalStateException;
+ });
+ });
+ }
+
+ private static class SameThreadExecutorService extends AbstractExecutorService {
+ @Override
+ public void execute(@NotNull final Runnable command) {
+ command.run();
+ }
+
+ @Override
+ public void shutdown() {
+ }
+
+ @NotNull
+ @Override
+ public List shutdownNow() {
+ return Collections.emptyList();
+ }
+
+ @Override
+ public boolean isShutdown() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean isTerminated() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean awaitTermination(final long timeout, @NotNull final TimeUnit unit) {
+ return true;
+ }
+ }
+}
diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/SeparateThreadAsyncFunctionsTest.java b/driver-core/src/test/unit/com/mongodb/internal/async/SeparateThreadAsyncFunctionsTest.java
new file mode 100644
index 00000000000..401c4d2c18e
--- /dev/null
+++ b/driver-core/src/test/unit/com/mongodb/internal/async/SeparateThreadAsyncFunctionsTest.java
@@ -0,0 +1,118 @@
+/*
+ * Copyright 2008-present MongoDB, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.mongodb.internal.async;
+
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import static com.mongodb.internal.async.AsyncRunnable.beginAsync;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+@DisplayName("Separate thread async functions")
+public class SeparateThreadAsyncFunctionsTest extends AsyncFunctionsAbstractTest {
+
+ private UncaughtExceptionHandler uncaughtExceptionHandler;
+
+ @Override
+ public ExecutorService createAsyncExecutor() {
+ uncaughtExceptionHandler = new UncaughtExceptionHandler();
+ return Executors.newFixedThreadPool(1, r -> {
+ Thread thread = new Thread(r);
+ thread.setUncaughtExceptionHandler(uncaughtExceptionHandler);
+ return thread;
+ });
+ }
+
+ /**
+ * This test covers the scenario where a callback is erroneously invoked after a callback had been completed.
+ * Such behavior is considered a bug and is not expected. An AssertionError should be thrown if an asynchronous invocation
+ * attempts to use a callback that has already been marked as completed.
+ */
+ @Test
+ void shouldPropagateAssertionErrorIfCallbackHasBeenCompletedAfterAsyncInvocation() {
+ //given
+ setIsTestingAbruptCompletion(false);
+ setAsyncStep(true);
+ IllegalStateException illegalStateException = new IllegalStateException("must not cause second callback invocation");
+ AtomicBoolean callbackInvoked = new AtomicBoolean(false);
+
+ //when
+ beginAsync().thenRun(c -> {
+ async(3, c);
+ throw illegalStateException;
+ }).thenRun(c -> {
+ assertInvokedOnce(callbackInvoked);
+ c.complete(c);
+ })
+ .finish((v, e) -> {
+ assertEquals(illegalStateException, e);
+ }
+ );
+
+ //then
+ Throwable exception = uncaughtExceptionHandler.getException();
+ assertNotNull(exception);
+ assertEquals(AssertionError.class, exception.getClass());
+ assertEquals("Callback has been already completed. It could happen "
+ + "if code throws an exception after invoking an async method. Value: null", exception.getMessage());
+ }
+
+ @Test
+ void shouldPropagateUnexpectedExceptionFromFinishCallback() {
+ //given
+ setIsTestingAbruptCompletion(false);
+ setAsyncStep(true);
+ IllegalStateException illegalStateException = new IllegalStateException("must not cause second callback invocation");
+
+ //when
+ beginAsync().thenRun(c -> {
+ async(3, c);
+ }).finish((v, e) -> {
+ throw illegalStateException;
+ });
+
+ //then
+ Throwable exception = uncaughtExceptionHandler.getException();
+ assertNotNull(exception);
+ assertEquals(illegalStateException, exception);
+ }
+
+ private static void assertInvokedOnce(final AtomicBoolean callbackInvoked1) {
+ assertTrue(callbackInvoked1.compareAndSet(false, true));
+ }
+
+ private final class UncaughtExceptionHandler implements Thread.UncaughtExceptionHandler {
+
+ private final CompletableFuture completable = new CompletableFuture<>();
+
+ @Override
+ public void uncaughtException(final Thread t, final Throwable e) {
+ completable.complete(e);
+ }
+
+ public Throwable getException() {
+ return await(completable, "No exception was thrown");
+ }
+ }
+}