diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java b/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java index 5be92558ee0..7203d3a4945 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java @@ -18,6 +18,8 @@ import com.mongodb.lang.Nullable; +import java.util.concurrent.atomic.AtomicBoolean; + /** * See {@link AsyncRunnable} *

@@ -33,4 +35,28 @@ public interface AsyncFunction { * @param callback the callback */ void unsafeFinish(T value, SingleResultCallback callback); + + /** + * Must be invoked at end of async chain or when executing a callback handler supplied by the caller. + * + * @param callback the callback provided by the method the chain is used in. + */ + default void finish(final T value, final SingleResultCallback callback) { + final AtomicBoolean callbackInvoked = new AtomicBoolean(false); + try { + this.unsafeFinish(value, (v, e) -> { + if (!callbackInvoked.compareAndSet(false, true)) { + throw new AssertionError(String.format("Callback has been already completed. It could happen " + + "if code throws an exception after invoking an async method. Value: %s", v), e); + } + callback.onResult(v, e); + }); + } catch (Throwable t) { + if (!callbackInvoked.compareAndSet(false, true)) { + throw t; + } else { + callback.completeExceptionally(t); + } + } + } } diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java index a81b2fdd12c..d4ead3c5b96 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java @@ -171,7 +171,9 @@ default AsyncRunnable thenRun(final AsyncRunnable runnable) { return (c) -> { this.unsafeFinish((r, e) -> { if (e == null) { - runnable.unsafeFinish(c); + /* If 'runnable' is executed on a different thread from the one that executed the initial 'finish()', + then invoking 'finish()' within 'runnable' will catch and propagate any exceptions to 'c' (the callback). */ + runnable.finish(c); } else { c.completeExceptionally(e); } @@ -236,7 +238,7 @@ default AsyncRunnable thenRunIf(final Supplier condition, final AsyncRu return; } if (matched) { - runnable.unsafeFinish(callback); + runnable.finish(callback); } else { callback.complete(callback); } @@ -253,7 +255,7 @@ default AsyncSupplier thenSupply(final AsyncSupplier supplier) { return (c) -> { this.unsafeFinish((r, e) -> { if (e == null) { - supplier.unsafeFinish(c); + supplier.finish(c); } else { c.completeExceptionally(e); } diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java index b7d24dd3df5..77c289c8723 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java @@ -18,6 +18,7 @@ import com.mongodb.lang.Nullable; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Predicate; @@ -54,18 +55,25 @@ default void unsafeFinish(@Nullable final Void value, final SingleResultCallback } /** - * Must be invoked at end of async chain. + * Must be invoked at end of async chain or when executing a callback handler supplied by the caller. + * + * @see #thenApply(AsyncFunction) + * @see #thenConsume(AsyncConsumer) + * @see #onErrorIf(Predicate, AsyncFunction) * @param callback the callback provided by the method the chain is used in */ default void finish(final SingleResultCallback callback) { - final boolean[] callbackInvoked = {false}; + final AtomicBoolean callbackInvoked = new AtomicBoolean(false); try { this.unsafeFinish((v, e) -> { - callbackInvoked[0] = true; + if (!callbackInvoked.compareAndSet(false, true)) { + throw new AssertionError(String.format("Callback has been already completed. It could happen " + + "if code throws an exception after invoking an async method. Value: %s", v), e); + } callback.onResult(v, e); }); } catch (Throwable t) { - if (callbackInvoked[0]) { + if (!callbackInvoked.compareAndSet(false, true)) { throw t; } else { callback.completeExceptionally(t); @@ -80,9 +88,9 @@ default void finish(final SingleResultCallback callback) { */ default AsyncSupplier thenApply(final AsyncFunction function) { return (c) -> { - this.unsafeFinish((v, e) -> { + this.finish((v, e) -> { if (e == null) { - function.unsafeFinish(v, c); + function.finish(v, c); } else { c.completeExceptionally(e); } @@ -99,7 +107,7 @@ default AsyncRunnable thenConsume(final AsyncConsumer consumer) { return (c) -> { this.unsafeFinish((v, e) -> { if (e == null) { - consumer.unsafeFinish(v, c); + consumer.finish(v, c); } else { c.completeExceptionally(e); } @@ -131,7 +139,7 @@ default AsyncSupplier onErrorIf( return; } if (errorMatched) { - errorFunction.unsafeFinish(e, callback); + errorFunction.finish(e, callback); } else { callback.completeExceptionally(e); } diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java index 98e43fe5fbe..de12e5f092f 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java @@ -610,6 +610,7 @@ private void sendCommandMessageAsync(final int messageId, final Decoder d return; } assertNotNull(responseBuffers); + T commandResult; try { updateSessionContext(operationContext.getSessionContext(), responseBuffers); boolean commandOk = @@ -624,13 +625,14 @@ private void sendCommandMessageAsync(final int messageId, final Decoder d } commandEventSender.sendSucceededEvent(responseBuffers); - T result1 = getCommandResult(decoder, responseBuffers, messageId, operationContext.getTimeoutContext()); - callback.onResult(result1, null); + commandResult = getCommandResult(decoder, responseBuffers, messageId, operationContext.getTimeoutContext()); } catch (Throwable localThrowable) { callback.onResult(null, localThrowable); + return; } finally { responseBuffers.close(); } + callback.onResult(commandResult, null); })); } }); diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java index ee509873e40..6fca357b080 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java @@ -101,7 +101,14 @@ public void startHandshakeAsync(final InternalConnection internalConnection, fin callback.onResult(null, t instanceof MongoException ? mapHelloException((MongoException) t) : t); } else { setSpeculativeAuthenticateResponse(helloResult); - callback.onResult(createInitializationDescription(helloResult, internalConnection, startTime), null); + InternalConnectionInitializationDescription initializationDescription; + try { + initializationDescription = createInitializationDescription(helloResult, internalConnection, startTime); + } catch (Throwable localThrowable) { + callback.onResult(null, localThrowable); + return; + } + callback.onResult(initializationDescription, null); } }); } diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsAbstractTest.java similarity index 97% rename from driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java rename to driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsAbstractTest.java index 20553fe881a..611b90fc675 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsAbstractTest.java @@ -25,10 +25,10 @@ import static com.mongodb.assertions.Assertions.assertNotNull; import static com.mongodb.internal.async.AsyncRunnable.beginAsync; -import static org.junit.jupiter.api.Assertions.assertThrows; -final class AsyncFunctionsTest extends AsyncFunctionsTestAbstract { +abstract class AsyncFunctionsAbstractTest extends AsyncFunctionsTestBase { private static final TimeoutContext TIMEOUT_CONTEXT = new TimeoutContext(new TimeoutSettings(0, 0, 0, 0L, 0)); + @Test void test1Method() { // the number of expected variations is often: 1 + N methods invoked @@ -760,25 +760,6 @@ void testVariables() { }); } - @Test - void testInvalid() { - setIsTestingAbruptCompletion(false); - setAsyncStep(true); - assertThrows(IllegalStateException.class, () -> { - beginAsync().thenRun(c -> { - async(3, c); - throw new IllegalStateException("must not cause second callback invocation"); - }).finish((v, e) -> {}); - }); - assertThrows(IllegalStateException.class, () -> { - beginAsync().thenRun(c -> { - async(3, c); - }).finish((v, e) -> { - throw new IllegalStateException("must not cause second callback invocation"); - }); - }); - } - @Test void testDerivation() { // Demonstrates the progression from nested async to the API. @@ -866,5 +847,4 @@ void testDerivation() { }).finish(callback); }); } - } diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTestAbstract.java b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTestBase.java similarity index 80% rename from driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTestAbstract.java rename to driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTestBase.java index 7cc8b456f1c..1229dbcfcad 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTestAbstract.java +++ b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTestBase.java @@ -17,11 +17,17 @@ package com.mongodb.internal.async; import com.mongodb.client.TestListener; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; import org.opentest4j.AssertionFailedError; import java.util.ArrayList; import java.util.List; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Supplier; @@ -31,11 +37,12 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; -public class AsyncFunctionsTestAbstract { +public abstract class AsyncFunctionsTestBase { private final TestListener listener = new TestListener(); private final InvocationTracker invocationTracker = new InvocationTracker(); private boolean isTestingAbruptCompletion = false; + private ExecutorService asyncExecutor; void setIsTestingAbruptCompletion(final boolean b) { isTestingAbruptCompletion = b; @@ -53,6 +60,23 @@ public void listenerAdd(final String s) { listener.add(s); } + /** + * Create an executor service for async operations before each test. + * + * @return the executor service. + */ + public abstract ExecutorService createAsyncExecutor(); + + @BeforeEach + public void setUp() { + asyncExecutor = createAsyncExecutor(); + } + + @AfterEach + public void shutDown() { + asyncExecutor.shutdownNow(); + } + void plain(final int i) { int cur = invocationTracker.getNextOption(2); if (cur == 0) { @@ -98,32 +122,47 @@ Integer syncReturns(final int i) { return affectedReturns(i); } + + public void submit(final Runnable task) { + asyncExecutor.execute(task); + } void async(final int i, final SingleResultCallback callback) { assertTrue(invocationTracker.isAsyncStep); if (isTestingAbruptCompletion) { + /* We should not test for abrupt completion in a separate thread. Once a callback is registered for an async operation, + the Async Framework does not handle exceptions thrown outside of callbacks by the executing thread. Such exception management + should be the responsibility of the thread conducting the asynchronous operations. */ affected(i); - callback.complete(callback); - - } else { - try { - affected(i); + submit(() -> { callback.complete(callback); - } catch (Throwable t) { - callback.onResult(null, t); - } + }); + } else { + submit(() -> { + try { + affected(i); + callback.complete(callback); + } catch (Throwable t) { + callback.onResult(null, t); + } + }); } } void asyncReturns(final int i, final SingleResultCallback callback) { assertTrue(invocationTracker.isAsyncStep); if (isTestingAbruptCompletion) { - callback.complete(affectedReturns(i)); + int result = affectedReturns(i); + submit(() -> { + callback.complete(result); + }); } else { - try { - callback.complete(affectedReturns(i)); - } catch (Throwable t) { - callback.onResult(null, t); - } + submit(() -> { + try { + callback.complete(affectedReturns(i)); + } catch (Throwable t) { + callback.onResult(null, t); + } + }); } } @@ -200,24 +239,26 @@ private void assertBehavesSame(final Supplier sync, final Runnable betwee AtomicReference actualValue = new AtomicReference<>(); AtomicReference actualException = new AtomicReference<>(); - AtomicBoolean wasCalled = new AtomicBoolean(false); + CompletableFuture wasCalledFuture = new CompletableFuture<>(); try { async.accept((v, e) -> { actualValue.set(v); actualException.set(e); - if (wasCalled.get()) { + if (wasCalledFuture.isDone()) { fail(); } - wasCalled.set(true); + wasCalledFuture.complete(null); }); } catch (Throwable e) { fail("async threw instead of using callback"); } + await(wasCalledFuture, "Callback should have been called"); + // The following code can be used to debug variations: // System.out.println("===VARIATION START"); // System.out.println("sync: " + expectedEvents); -// System.out.println("callback called?: " + wasCalled.get()); +// System.out.println("callback called?: " + wasCalledFuture.isDone()); // System.out.println("value -- sync: " + expectedValue + " -- async: " + actualValue.get()); // System.out.println("excep -- sync: " + expectedException + " -- async: " + actualException.get()); // System.out.println("exception mode: " + (isTestingAbruptCompletion @@ -229,7 +270,7 @@ private void assertBehavesSame(final Supplier sync, final Runnable betwee throw (AssertionFailedError) actualException.get(); } - assertTrue(wasCalled.get(), "callback should have been called"); + assertTrue(wasCalledFuture.isDone(), "callback should have been called"); assertEquals(expectedEvents, listener.getEventStrings(), "steps should have matched"); assertEquals(expectedValue, actualValue.get()); assertEquals(expectedException == null, actualException.get() == null, @@ -242,6 +283,14 @@ private void assertBehavesSame(final Supplier sync, final Runnable betwee listener.clear(); } + protected T await(final CompletableFuture voidCompletableFuture, final String errorMessage) { + try { + return voidCompletableFuture.get(1, TimeUnit.MINUTES); + } catch (InterruptedException | ExecutionException | TimeoutException e) { + throw new AssertionError(errorMessage); + } + } + /** * Tracks invocations: allows testing of all variations of a method calls */ diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/SameThreadAsyncFunctionsTest.java b/driver-core/src/test/unit/com/mongodb/internal/async/SameThreadAsyncFunctionsTest.java new file mode 100644 index 00000000000..04b9290af55 --- /dev/null +++ b/driver-core/src/test/unit/com/mongodb/internal/async/SameThreadAsyncFunctionsTest.java @@ -0,0 +1,94 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.async; + +import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.List; +import java.util.concurrent.AbstractExecutorService; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; + +import static com.mongodb.internal.async.AsyncRunnable.beginAsync; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +@DisplayName("The same thread async functions") +public class SameThreadAsyncFunctionsTest extends AsyncFunctionsAbstractTest { + @Override + public ExecutorService createAsyncExecutor() { + return new SameThreadExecutorService(); + } + + @Test + void testInvalid() { + setIsTestingAbruptCompletion(false); + setAsyncStep(true); + IllegalStateException illegalStateException = new IllegalStateException("must not cause second callback invocation"); + + assertThrows(IllegalStateException.class, () -> { + beginAsync().thenRun(c -> { + async(3, c); + throw illegalStateException; + }).finish((v, e) -> { + assertNotEquals(e, illegalStateException); + }); + }); + assertThrows(IllegalStateException.class, () -> { + beginAsync().thenRun(c -> { + async(3, c); + }).finish((v, e) -> { + throw illegalStateException; + }); + }); + } + + private static class SameThreadExecutorService extends AbstractExecutorService { + @Override + public void execute(@NotNull final Runnable command) { + command.run(); + } + + @Override + public void shutdown() { + } + + @NotNull + @Override + public List shutdownNow() { + return Collections.emptyList(); + } + + @Override + public boolean isShutdown() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isTerminated() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean awaitTermination(final long timeout, @NotNull final TimeUnit unit) { + return true; + } + } +} diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/SeparateThreadAsyncFunctionsTest.java b/driver-core/src/test/unit/com/mongodb/internal/async/SeparateThreadAsyncFunctionsTest.java new file mode 100644 index 00000000000..401c4d2c18e --- /dev/null +++ b/driver-core/src/test/unit/com/mongodb/internal/async/SeparateThreadAsyncFunctionsTest.java @@ -0,0 +1,118 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.async; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; + +import static com.mongodb.internal.async.AsyncRunnable.beginAsync; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@DisplayName("Separate thread async functions") +public class SeparateThreadAsyncFunctionsTest extends AsyncFunctionsAbstractTest { + + private UncaughtExceptionHandler uncaughtExceptionHandler; + + @Override + public ExecutorService createAsyncExecutor() { + uncaughtExceptionHandler = new UncaughtExceptionHandler(); + return Executors.newFixedThreadPool(1, r -> { + Thread thread = new Thread(r); + thread.setUncaughtExceptionHandler(uncaughtExceptionHandler); + return thread; + }); + } + + /** + * This test covers the scenario where a callback is erroneously invoked after a callback had been completed. + * Such behavior is considered a bug and is not expected. An AssertionError should be thrown if an asynchronous invocation + * attempts to use a callback that has already been marked as completed. + */ + @Test + void shouldPropagateAssertionErrorIfCallbackHasBeenCompletedAfterAsyncInvocation() { + //given + setIsTestingAbruptCompletion(false); + setAsyncStep(true); + IllegalStateException illegalStateException = new IllegalStateException("must not cause second callback invocation"); + AtomicBoolean callbackInvoked = new AtomicBoolean(false); + + //when + beginAsync().thenRun(c -> { + async(3, c); + throw illegalStateException; + }).thenRun(c -> { + assertInvokedOnce(callbackInvoked); + c.complete(c); + }) + .finish((v, e) -> { + assertEquals(illegalStateException, e); + } + ); + + //then + Throwable exception = uncaughtExceptionHandler.getException(); + assertNotNull(exception); + assertEquals(AssertionError.class, exception.getClass()); + assertEquals("Callback has been already completed. It could happen " + + "if code throws an exception after invoking an async method. Value: null", exception.getMessage()); + } + + @Test + void shouldPropagateUnexpectedExceptionFromFinishCallback() { + //given + setIsTestingAbruptCompletion(false); + setAsyncStep(true); + IllegalStateException illegalStateException = new IllegalStateException("must not cause second callback invocation"); + + //when + beginAsync().thenRun(c -> { + async(3, c); + }).finish((v, e) -> { + throw illegalStateException; + }); + + //then + Throwable exception = uncaughtExceptionHandler.getException(); + assertNotNull(exception); + assertEquals(illegalStateException, exception); + } + + private static void assertInvokedOnce(final AtomicBoolean callbackInvoked1) { + assertTrue(callbackInvoked1.compareAndSet(false, true)); + } + + private final class UncaughtExceptionHandler implements Thread.UncaughtExceptionHandler { + + private final CompletableFuture completable = new CompletableFuture<>(); + + @Override + public void uncaughtException(final Thread t, final Throwable e) { + completable.complete(e); + } + + public Throwable getException() { + return await(completable, "No exception was thrown"); + } + } +}