From 4df76ecb67d895f179bb974fa1209aa3c2c3ab01 Mon Sep 17 00:00:00 2001 From: Viacheslav Babanin Date: Mon, 19 Aug 2024 11:10:49 -0700 Subject: [PATCH] Fix exception propagation in Async API methods (#1479) - Resolve an issue where exceptions thrown during thenRun, thenSupply, and related operations in the asynchronous API were not properly propagated to the completion callback. This issue was addressed by replacing `unsafeFinish` with `finish`, ensuring that exceptions are caught and correctly passed to the completion callback when executed on different threads. - Update existing Async API tests to ensure they simulate separate async thread execution. - Modify the async callback to catch and handle exceptions locally. Exceptions are now directly processed and passed as an error argument to the callback function, avoiding propagation to the parent callback. - Move `callback.onResult` outside the catch block to ensure it's not invoked twice when an exception occurs. JAVA-5562 --- .../mongodb/internal/async/AsyncFunction.java | 26 ++ .../mongodb/internal/async/AsyncRunnable.java | 8 +- .../mongodb/internal/async/AsyncSupplier.java | 24 +- .../connection/InternalStreamConnection.java | 6 +- .../InternalStreamConnectionInitializer.java | 9 +- ...t.java => AsyncFunctionsAbstractTest.java} | 310 +-------------- .../async/AsyncFunctionsTestBase.java | 373 ++++++++++++++++++ .../async/SameThreadAsyncFunctionsTest.java | 94 +++++ .../SeparateThreadAsyncFunctionsTest.java | 118 ++++++ 9 files changed, 647 insertions(+), 321 deletions(-) rename driver-core/src/test/unit/com/mongodb/internal/async/{AsyncFunctionsTest.java => AsyncFunctionsAbstractTest.java} (70%) create mode 100644 driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTestBase.java create mode 100644 driver-core/src/test/unit/com/mongodb/internal/async/SameThreadAsyncFunctionsTest.java create mode 100644 driver-core/src/test/unit/com/mongodb/internal/async/SeparateThreadAsyncFunctionsTest.java diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java b/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java index 5be92558ee0..7203d3a4945 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java @@ -18,6 +18,8 @@ import com.mongodb.lang.Nullable; +import java.util.concurrent.atomic.AtomicBoolean; + /** * See {@link AsyncRunnable} *

@@ -33,4 +35,28 @@ public interface AsyncFunction { * @param callback the callback */ void unsafeFinish(T value, SingleResultCallback callback); + + /** + * Must be invoked at end of async chain or when executing a callback handler supplied by the caller. + * + * @param callback the callback provided by the method the chain is used in. + */ + default void finish(final T value, final SingleResultCallback callback) { + final AtomicBoolean callbackInvoked = new AtomicBoolean(false); + try { + this.unsafeFinish(value, (v, e) -> { + if (!callbackInvoked.compareAndSet(false, true)) { + throw new AssertionError(String.format("Callback has been already completed. It could happen " + + "if code throws an exception after invoking an async method. Value: %s", v), e); + } + callback.onResult(v, e); + }); + } catch (Throwable t) { + if (!callbackInvoked.compareAndSet(false, true)) { + throw t; + } else { + callback.completeExceptionally(t); + } + } + } } diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java index fcf8d61387d..7a872ded718 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java @@ -170,7 +170,9 @@ default AsyncRunnable thenRun(final AsyncRunnable runnable) { return (c) -> { this.unsafeFinish((r, e) -> { if (e == null) { - runnable.unsafeFinish(c); + /* If 'runnable' is executed on a different thread from the one that executed the initial 'finish()', + then invoking 'finish()' within 'runnable' will catch and propagate any exceptions to 'c' (the callback). */ + runnable.finish(c); } else { c.completeExceptionally(e); } @@ -199,7 +201,7 @@ default AsyncRunnable thenRunIf(final Supplier condition, final AsyncRu return; } if (matched) { - runnable.unsafeFinish(callback); + runnable.finish(callback); } else { callback.complete(callback); } @@ -216,7 +218,7 @@ default AsyncSupplier thenSupply(final AsyncSupplier supplier) { return (c) -> { this.unsafeFinish((r, e) -> { if (e == null) { - supplier.unsafeFinish(c); + supplier.finish(c); } else { c.completeExceptionally(e); } diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java index b7d24dd3df5..77c289c8723 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java @@ -18,6 +18,7 @@ import com.mongodb.lang.Nullable; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Predicate; @@ -54,18 +55,25 @@ default void unsafeFinish(@Nullable final Void value, final SingleResultCallback } /** - * Must be invoked at end of async chain. + * Must be invoked at end of async chain or when executing a callback handler supplied by the caller. + * + * @see #thenApply(AsyncFunction) + * @see #thenConsume(AsyncConsumer) + * @see #onErrorIf(Predicate, AsyncFunction) * @param callback the callback provided by the method the chain is used in */ default void finish(final SingleResultCallback callback) { - final boolean[] callbackInvoked = {false}; + final AtomicBoolean callbackInvoked = new AtomicBoolean(false); try { this.unsafeFinish((v, e) -> { - callbackInvoked[0] = true; + if (!callbackInvoked.compareAndSet(false, true)) { + throw new AssertionError(String.format("Callback has been already completed. It could happen " + + "if code throws an exception after invoking an async method. Value: %s", v), e); + } callback.onResult(v, e); }); } catch (Throwable t) { - if (callbackInvoked[0]) { + if (!callbackInvoked.compareAndSet(false, true)) { throw t; } else { callback.completeExceptionally(t); @@ -80,9 +88,9 @@ default void finish(final SingleResultCallback callback) { */ default AsyncSupplier thenApply(final AsyncFunction function) { return (c) -> { - this.unsafeFinish((v, e) -> { + this.finish((v, e) -> { if (e == null) { - function.unsafeFinish(v, c); + function.finish(v, c); } else { c.completeExceptionally(e); } @@ -99,7 +107,7 @@ default AsyncRunnable thenConsume(final AsyncConsumer consumer) { return (c) -> { this.unsafeFinish((v, e) -> { if (e == null) { - consumer.unsafeFinish(v, c); + consumer.finish(v, c); } else { c.completeExceptionally(e); } @@ -131,7 +139,7 @@ default AsyncSupplier onErrorIf( return; } if (errorMatched) { - errorFunction.unsafeFinish(e, callback); + errorFunction.finish(e, callback); } else { callback.completeExceptionally(e); } diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java index 218835f083e..7751bcba86f 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java @@ -595,6 +595,7 @@ private void sendCommandMessageAsync(final int messageId, final Decoder d return; } assertNotNull(responseBuffers); + T commandResult; try { updateSessionContext(sessionContext, responseBuffers); boolean commandOk = @@ -609,13 +610,14 @@ private void sendCommandMessageAsync(final int messageId, final Decoder d } commandEventSender.sendSucceededEvent(responseBuffers); - T result1 = getCommandResult(decoder, responseBuffers, messageId); - callback.onResult(result1, null); + commandResult = getCommandResult(decoder, responseBuffers, messageId); } catch (Throwable localThrowable) { callback.onResult(null, localThrowable); + return; } finally { responseBuffers.close(); } + callback.onResult(commandResult, null); })); } }); diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java index d4858f3d973..b8f85289a0b 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java @@ -98,7 +98,14 @@ public void startHandshakeAsync(final InternalConnection internalConnection, callback.onResult(null, t instanceof MongoException ? mapHelloException((MongoException) t) : t); } else { setSpeculativeAuthenticateResponse(helloResult); - callback.onResult(createInitializationDescription(helloResult, internalConnection, startTime), null); + InternalConnectionInitializationDescription initializationDescription; + try { + initializationDescription = createInitializationDescription(helloResult, internalConnection, startTime); + } catch (Throwable localThrowable) { + callback.onResult(null, localThrowable); + return; + } + callback.onResult(initializationDescription, null); } }); } diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsAbstractTest.java similarity index 70% rename from driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java rename to driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsAbstractTest.java index b783b3de93b..16e4e978bf4 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsAbstractTest.java @@ -15,30 +15,16 @@ */ package com.mongodb.internal.async; -import com.mongodb.client.TestListener; import org.junit.jupiter.api.Test; -import org.opentest4j.AssertionFailedError; -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Supplier; import static com.mongodb.assertions.Assertions.assertNotNull; import static com.mongodb.internal.async.AsyncRunnable.beginAsync; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Assertions.fail; -final class AsyncFunctionsTest { - private final TestListener listener = new TestListener(); - private final InvocationTracker invocationTracker = new InvocationTracker(); - private boolean isTestingAbruptCompletion = false; +abstract class AsyncFunctionsAbstractTest extends AsyncFunctionsTestBase { @Test void test1Method() { @@ -720,25 +706,6 @@ void testVariables() { }); } - @Test - void testInvalid() { - isTestingAbruptCompletion = false; - invocationTracker.isAsyncStep = true; - assertThrows(IllegalStateException.class, () -> { - beginAsync().thenRun(c -> { - async(3, c); - throw new IllegalStateException("must not cause second callback invocation"); - }).finish((v, e) -> {}); - }); - assertThrows(IllegalStateException.class, () -> { - beginAsync().thenRun(c -> { - async(3, c); - }).finish((v, e) -> { - throw new IllegalStateException("must not cause second callback invocation"); - }); - }); - } - @Test void testDerivation() { // Demonstrates the progression from nested async to the API. @@ -746,8 +713,8 @@ void testDerivation() { // Stand-ins for sync-async methods; these "happily" do not throw // exceptions, to avoid complicating this demo async code. Consumer happySync = (i) -> { - invocationTracker.getNextOption(1); - listener.add("affected-success-" + i); + getNextOption(1); + listenerAdd("affected-success-" + i); }; BiConsumer> happyAsync = (i, c) -> { happySync.accept(i); @@ -827,275 +794,4 @@ void testDerivation() { }); } - // invoked methods: - - private void plain(final int i) { - int cur = invocationTracker.getNextOption(2); - if (cur == 0) { - listener.add("plain-exception-" + i); - throw new RuntimeException("affected method exception-" + i); - } else { - listener.add("plain-success-" + i); - } - } - - private int plainReturns(final int i) { - int cur = invocationTracker.getNextOption(2); - if (cur == 0) { - listener.add("plain-exception-" + i); - throw new RuntimeException("affected method exception-" + i); - } else { - listener.add("plain-success-" + i); - return i; - } - } - - private boolean plainTest(final int i) { - int cur = invocationTracker.getNextOption(3); - if (cur == 0) { - listener.add("plain-exception-" + i); - throw new RuntimeException("affected method exception-" + i); - } else if (cur == 1) { - listener.add("plain-false-" + i); - return false; - } else { - listener.add("plain-true-" + i); - return true; - } - } - - private void sync(final int i) { - assertFalse(invocationTracker.isAsyncStep); - affected(i); - } - - - private Integer syncReturns(final int i) { - assertFalse(invocationTracker.isAsyncStep); - return affectedReturns(i); - } - - private void async(final int i, final SingleResultCallback callback) { - assertTrue(invocationTracker.isAsyncStep); - if (isTestingAbruptCompletion) { - affected(i); - callback.complete(callback); - - } else { - try { - affected(i); - callback.complete(callback); - } catch (Throwable t) { - callback.onResult(null, t); - } - } - } - - private void asyncReturns(final int i, final SingleResultCallback callback) { - assertTrue(invocationTracker.isAsyncStep); - if (isTestingAbruptCompletion) { - callback.complete(affectedReturns(i)); - } else { - try { - callback.complete(affectedReturns(i)); - } catch (Throwable t) { - callback.onResult(null, t); - } - } - } - - private void affected(final int i) { - int cur = invocationTracker.getNextOption(2); - if (cur == 0) { - listener.add("affected-exception-" + i); - throw new RuntimeException("exception-" + i); - } else { - listener.add("affected-success-" + i); - } - } - - private int affectedReturns(final int i) { - int cur = invocationTracker.getNextOption(2); - if (cur == 0) { - listener.add("affected-exception-" + i); - throw new RuntimeException("exception-" + i); - } else { - listener.add("affected-success-" + i); - return i; - } - } - - // assert methods: - - private void assertBehavesSameVariations(final int expectedVariations, final Runnable sync, - final Consumer> async) { - assertBehavesSameVariations(expectedVariations, - () -> { - sync.run(); - return null; - }, - (c) -> { - async.accept((v, e) -> c.onResult(v, e)); - }); - } - - private void assertBehavesSameVariations(final int expectedVariations, final Supplier sync, - final Consumer> async) { - // run the variation-trying code twice, with direct/indirect exceptions - for (int i = 0; i < 2; i++) { - isTestingAbruptCompletion = i != 0; - - // the variation-trying code: - invocationTracker.reset(); - do { - invocationTracker.startInitialStep(); - assertBehavesSame( - sync, - () -> invocationTracker.startMatchStep(), - async); - } while (invocationTracker.countDown()); - assertEquals(expectedVariations, invocationTracker.getVariationCount(), - "number of variations did not match"); - } - - } - - private void assertBehavesSame(final Supplier sync, final Runnable between, - final Consumer> async) { - - T expectedValue = null; - Throwable expectedException = null; - try { - expectedValue = sync.get(); - } catch (Throwable e) { - expectedException = e; - } - List expectedEvents = listener.getEventStrings(); - - listener.clear(); - between.run(); - - AtomicReference actualValue = new AtomicReference<>(); - AtomicReference actualException = new AtomicReference<>(); - AtomicBoolean wasCalled = new AtomicBoolean(false); - try { - async.accept((v, e) -> { - actualValue.set(v); - actualException.set(e); - if (wasCalled.get()) { - fail(); - } - wasCalled.set(true); - }); - } catch (Throwable e) { - fail("async threw instead of using callback"); - } - - // The following code can be used to debug variations: -// System.out.println("===VARIATION START"); -// System.out.println("sync: " + expectedEvents); -// System.out.println("callback called?: " + wasCalled.get()); -// System.out.println("value -- sync: " + expectedValue + " -- async: " + actualValue.get()); -// System.out.println("excep -- sync: " + expectedException + " -- async: " + actualException.get()); -// System.out.println("exception mode: " + (isTestingAbruptCompletion -// ? "exceptions thrown directly (abrupt completion)" : "exceptions into callbacks")); -// System.out.println("===VARIATION END"); - - // show assertion failures arising in async tests - if (actualException.get() != null && actualException.get() instanceof AssertionFailedError) { - throw (AssertionFailedError) actualException.get(); - } - - assertTrue(wasCalled.get(), "callback should have been called"); - assertEquals(expectedEvents, listener.getEventStrings(), "steps should have matched"); - assertEquals(expectedValue, actualValue.get()); - assertEquals(expectedException == null, actualException.get() == null, - "both or neither should have produced an exception"); - if (expectedException != null) { - assertEquals(expectedException.getMessage(), actualException.get().getMessage()); - assertEquals(expectedException.getClass(), actualException.get().getClass()); - } - - listener.clear(); - } - - /** - * Tracks invocations: allows testing of all variations of a method calls - */ - private static class InvocationTracker { - public static final int DEPTH_LIMIT = 50; - private final List invocationOptionSequence = new ArrayList<>(); - private boolean isAsyncStep; // async = matching, vs initial step = populating - private int currentInvocationIndex; - private int variationCount; - - public void reset() { - variationCount = 0; - } - - public void startInitialStep() { - variationCount++; - isAsyncStep = false; - currentInvocationIndex = -1; - } - - public int getNextOption(final int myOptionsSize) { - /* - This method creates (or gets) the next invocation's option. Each - invoker of this method has the "option" to behave in various ways, - usually just success (option 1) and exceptional failure (option 0), - though some callers might have more options. A sequence of method - outcomes (options) is one "variation". Tests automatically test - all possible variations (up to a limit, to prevent infinite loops). - - Methods generally have labels, to ensure that corresponding - sync/async methods are called in the right order, but these labels - are unrelated to the "variation" logic here. There are two "modes" - (whether completion is abrupt, or not), which are also unrelated. - */ - - currentInvocationIndex++; // which invocation result we are dealing with - - if (currentInvocationIndex >= invocationOptionSequence.size()) { - if (isAsyncStep) { - fail("result should have been pre-initialized: steps may not match"); - } - if (isWithinDepthLimit()) { - invocationOptionSequence.add(myOptionsSize - 1); - } else { - invocationOptionSequence.add(0); // choose "0" option, should always be an exception - } - } - return invocationOptionSequence.get(currentInvocationIndex); - } - - public void startMatchStep() { - isAsyncStep = true; - currentInvocationIndex = -1; - } - - private boolean countDown() { - while (!invocationOptionSequence.isEmpty()) { - int lastItemIndex = invocationOptionSequence.size() - 1; - int lastItem = invocationOptionSequence.get(lastItemIndex); - if (lastItem > 0) { - // count current digit down by 1, until 0 - invocationOptionSequence.set(lastItemIndex, lastItem - 1); - return true; - } else { - // current digit completed, remove (move left) - invocationOptionSequence.remove(lastItemIndex); - } - } - return false; - } - - public int getVariationCount() { - return variationCount; - } - - public boolean isWithinDepthLimit() { - return invocationOptionSequence.size() < DEPTH_LIMIT; - } - } } diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTestBase.java b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTestBase.java new file mode 100644 index 00000000000..207e06b8a47 --- /dev/null +++ b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTestBase.java @@ -0,0 +1,373 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.async; + +import com.mongodb.client.TestListener; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.opentest4j.AssertionFailedError; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Supplier; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public abstract class AsyncFunctionsTestBase { + + private final TestListener listener = new TestListener(); + private final InvocationTracker invocationTracker = new InvocationTracker(); + private boolean isTestingAbruptCompletion = false; + private ExecutorService asyncExecutor; + + void setIsTestingAbruptCompletion(final boolean b) { + isTestingAbruptCompletion = b; + } + + public void setAsyncStep(final boolean isAsyncStep) { + invocationTracker.isAsyncStep = isAsyncStep; + } + + public void getNextOption(final int i) { + invocationTracker.getNextOption(i); + } + + public void listenerAdd(final String s) { + listener.add(s); + } + + /** + * Create an executor service for async operations before each test. + * + * @return the executor service. + */ + public abstract ExecutorService createAsyncExecutor(); + + @BeforeEach + public void setUp() { + asyncExecutor = createAsyncExecutor(); + } + + @AfterEach + public void shutDown() { + asyncExecutor.shutdownNow(); + } + + void plain(final int i) { + int cur = invocationTracker.getNextOption(2); + if (cur == 0) { + listener.add("plain-exception-" + i); + throw new RuntimeException("affected method exception-" + i); + } else { + listener.add("plain-success-" + i); + } + } + + int plainReturns(final int i) { + int cur = invocationTracker.getNextOption(2); + if (cur == 0) { + listener.add("plain-returns-exception-" + i); + throw new RuntimeException("affected method exception-" + i); + } else { + listener.add("plain-returns-success-" + i); + return i; + } + } + + boolean plainTest(final int i) { + int cur = invocationTracker.getNextOption(3); + if (cur == 0) { + listener.add("plain-exception-" + i); + throw new RuntimeException("affected method exception-" + i); + } else if (cur == 1) { + listener.add("plain-false-" + i); + return false; + } else { + listener.add("plain-true-" + i); + return true; + } + } + + void sync(final int i) { + assertFalse(invocationTracker.isAsyncStep); + affected(i); + } + + Integer syncReturns(final int i) { + assertFalse(invocationTracker.isAsyncStep); + return affectedReturns(i); + } + + + public void submit(final Runnable task) { + asyncExecutor.execute(task); + } + void async(final int i, final SingleResultCallback callback) { + assertTrue(invocationTracker.isAsyncStep); + if (isTestingAbruptCompletion) { + /* We should not test for abrupt completion in a separate thread. Once a callback is registered for an async operation, + the Async Framework does not handle exceptions thrown outside of callbacks by the executing thread. Such exception management + should be the responsibility of the thread conducting the asynchronous operations. */ + affected(i); + submit(() -> { + callback.complete(callback); + }); + } else { + submit(() -> { + try { + affected(i); + callback.complete(callback); + } catch (Throwable t) { + callback.onResult(null, t); + } + }); + } + } + + void asyncReturns(final int i, final SingleResultCallback callback) { + assertTrue(invocationTracker.isAsyncStep); + if (isTestingAbruptCompletion) { + int result = affectedReturns(i); + submit(() -> { + callback.complete(result); + }); + } else { + submit(() -> { + try { + callback.complete(affectedReturns(i)); + } catch (Throwable t) { + callback.onResult(null, t); + } + }); + } + } + + private void affected(final int i) { + int cur = invocationTracker.getNextOption(2); + if (cur == 0) { + listener.add("affected-exception-" + i); + throw new RuntimeException("exception-" + i); + } else { + listener.add("affected-success-" + i); + } + } + + private int affectedReturns(final int i) { + int cur = invocationTracker.getNextOption(2); + if (cur == 0) { + listener.add("affected-returns-exception-" + i); + throw new RuntimeException("exception-" + i); + } else { + listener.add("affected-returns-success-" + i); + return i; + } + } + + // assert methods: + + void assertBehavesSameVariations(final int expectedVariations, final Runnable sync, + final Consumer> async) { + assertBehavesSameVariations(expectedVariations, + () -> { + sync.run(); + return null; + }, + (c) -> { + async.accept((v, e) -> c.onResult(v, e)); + }); + } + + void assertBehavesSameVariations(final int expectedVariations, final Supplier sync, + final Consumer> async) { + // run the variation-trying code twice, with direct/indirect exceptions + for (int i = 0; i < 2; i++) { + isTestingAbruptCompletion = i != 0; + + // the variation-trying code: + invocationTracker.reset(); + do { + invocationTracker.startInitialStep(); + assertBehavesSame( + sync, + () -> invocationTracker.startMatchStep(), + async); + } while (invocationTracker.countDown()); + assertEquals(expectedVariations, invocationTracker.getVariationCount(), + "number of variations did not match"); + } + + } + + private void assertBehavesSame(final Supplier sync, final Runnable between, + final Consumer> async) { + + T expectedValue = null; + Throwable expectedException = null; + try { + expectedValue = sync.get(); + } catch (Throwable e) { + expectedException = e; + } + List expectedEvents = listener.getEventStrings(); + + listener.clear(); + between.run(); + + AtomicReference actualValue = new AtomicReference<>(); + AtomicReference actualException = new AtomicReference<>(); + CompletableFuture wasCalledFuture = new CompletableFuture<>(); + try { + async.accept((v, e) -> { + actualValue.set(v); + actualException.set(e); + if (wasCalledFuture.isDone()) { + fail(); + } + wasCalledFuture.complete(null); + }); + } catch (Throwable e) { + fail("async threw instead of using callback"); + } + + await(wasCalledFuture, "Callback should have been called"); + + // The following code can be used to debug variations: +// System.out.println("===VARIATION START"); +// System.out.println("sync: " + expectedEvents); +// System.out.println("callback called?: " + wasCalledFuture.isDone()); +// System.out.println("value -- sync: " + expectedValue + " -- async: " + actualValue.get()); +// System.out.println("excep -- sync: " + expectedException + " -- async: " + actualException.get()); +// System.out.println("exception mode: " + (isTestingAbruptCompletion +// ? "exceptions thrown directly (abrupt completion)" : "exceptions into callbacks")); +// System.out.println("===VARIATION END"); + + // show assertion failures arising in async tests + if (actualException.get() != null && actualException.get() instanceof AssertionFailedError) { + throw (AssertionFailedError) actualException.get(); + } + + assertTrue(wasCalledFuture.isDone(), "callback should have been called"); + assertEquals(expectedEvents, listener.getEventStrings(), "steps should have matched"); + assertEquals(expectedValue, actualValue.get()); + assertEquals(expectedException == null, actualException.get() == null, + "both or neither should have produced an exception"); + if (expectedException != null) { + assertEquals(expectedException.getMessage(), actualException.get().getMessage()); + assertEquals(expectedException.getClass(), actualException.get().getClass()); + } + + listener.clear(); + } + + protected T await(final CompletableFuture voidCompletableFuture, final String errorMessage) { + try { + return voidCompletableFuture.get(1, TimeUnit.MINUTES); + } catch (InterruptedException | ExecutionException | TimeoutException e) { + throw new AssertionError(errorMessage); + } + } + + /** + * Tracks invocations: allows testing of all variations of a method calls + */ + static class InvocationTracker { + public static final int DEPTH_LIMIT = 50; + private final List invocationOptionSequence = new ArrayList<>(); + private boolean isAsyncStep; // async = matching, vs initial step = populating + private int currentInvocationIndex; + private int variationCount; + + public void reset() { + variationCount = 0; + } + + public void startInitialStep() { + variationCount++; + isAsyncStep = false; + currentInvocationIndex = -1; + } + + public int getNextOption(final int myOptionsSize) { + /* + This method creates (or gets) the next invocation's option. Each + invoker of this method has the "option" to behave in various ways, + usually just success (option 1) and exceptional failure (option 0), + though some callers might have more options. A sequence of method + outcomes (options) is one "variation". Tests automatically test + all possible variations (up to a limit, to prevent infinite loops). + + Methods generally have labels, to ensure that corresponding + sync/async methods are called in the right order, but these labels + are unrelated to the "variation" logic here. There are two "modes" + (whether completion is abrupt, or not), which are also unrelated. + */ + + currentInvocationIndex++; // which invocation result we are dealing with + + if (currentInvocationIndex >= invocationOptionSequence.size()) { + if (isAsyncStep) { + fail("result should have been pre-initialized: steps may not match"); + } + if (isWithinDepthLimit()) { + invocationOptionSequence.add(myOptionsSize - 1); + } else { + invocationOptionSequence.add(0); // choose "0" option, should always be an exception + } + } + return invocationOptionSequence.get(currentInvocationIndex); + } + + public void startMatchStep() { + isAsyncStep = true; + currentInvocationIndex = -1; + } + + private boolean countDown() { + while (!invocationOptionSequence.isEmpty()) { + int lastItemIndex = invocationOptionSequence.size() - 1; + int lastItem = invocationOptionSequence.get(lastItemIndex); + if (lastItem > 0) { + // count current digit down by 1, until 0 + invocationOptionSequence.set(lastItemIndex, lastItem - 1); + return true; + } else { + // current digit completed, remove (move left) + invocationOptionSequence.remove(lastItemIndex); + } + } + return false; + } + + public int getVariationCount() { + return variationCount; + } + + public boolean isWithinDepthLimit() { + return invocationOptionSequence.size() < DEPTH_LIMIT; + } + } +} diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/SameThreadAsyncFunctionsTest.java b/driver-core/src/test/unit/com/mongodb/internal/async/SameThreadAsyncFunctionsTest.java new file mode 100644 index 00000000000..04b9290af55 --- /dev/null +++ b/driver-core/src/test/unit/com/mongodb/internal/async/SameThreadAsyncFunctionsTest.java @@ -0,0 +1,94 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.async; + +import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.List; +import java.util.concurrent.AbstractExecutorService; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; + +import static com.mongodb.internal.async.AsyncRunnable.beginAsync; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +@DisplayName("The same thread async functions") +public class SameThreadAsyncFunctionsTest extends AsyncFunctionsAbstractTest { + @Override + public ExecutorService createAsyncExecutor() { + return new SameThreadExecutorService(); + } + + @Test + void testInvalid() { + setIsTestingAbruptCompletion(false); + setAsyncStep(true); + IllegalStateException illegalStateException = new IllegalStateException("must not cause second callback invocation"); + + assertThrows(IllegalStateException.class, () -> { + beginAsync().thenRun(c -> { + async(3, c); + throw illegalStateException; + }).finish((v, e) -> { + assertNotEquals(e, illegalStateException); + }); + }); + assertThrows(IllegalStateException.class, () -> { + beginAsync().thenRun(c -> { + async(3, c); + }).finish((v, e) -> { + throw illegalStateException; + }); + }); + } + + private static class SameThreadExecutorService extends AbstractExecutorService { + @Override + public void execute(@NotNull final Runnable command) { + command.run(); + } + + @Override + public void shutdown() { + } + + @NotNull + @Override + public List shutdownNow() { + return Collections.emptyList(); + } + + @Override + public boolean isShutdown() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isTerminated() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean awaitTermination(final long timeout, @NotNull final TimeUnit unit) { + return true; + } + } +} diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/SeparateThreadAsyncFunctionsTest.java b/driver-core/src/test/unit/com/mongodb/internal/async/SeparateThreadAsyncFunctionsTest.java new file mode 100644 index 00000000000..401c4d2c18e --- /dev/null +++ b/driver-core/src/test/unit/com/mongodb/internal/async/SeparateThreadAsyncFunctionsTest.java @@ -0,0 +1,118 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.async; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; + +import static com.mongodb.internal.async.AsyncRunnable.beginAsync; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@DisplayName("Separate thread async functions") +public class SeparateThreadAsyncFunctionsTest extends AsyncFunctionsAbstractTest { + + private UncaughtExceptionHandler uncaughtExceptionHandler; + + @Override + public ExecutorService createAsyncExecutor() { + uncaughtExceptionHandler = new UncaughtExceptionHandler(); + return Executors.newFixedThreadPool(1, r -> { + Thread thread = new Thread(r); + thread.setUncaughtExceptionHandler(uncaughtExceptionHandler); + return thread; + }); + } + + /** + * This test covers the scenario where a callback is erroneously invoked after a callback had been completed. + * Such behavior is considered a bug and is not expected. An AssertionError should be thrown if an asynchronous invocation + * attempts to use a callback that has already been marked as completed. + */ + @Test + void shouldPropagateAssertionErrorIfCallbackHasBeenCompletedAfterAsyncInvocation() { + //given + setIsTestingAbruptCompletion(false); + setAsyncStep(true); + IllegalStateException illegalStateException = new IllegalStateException("must not cause second callback invocation"); + AtomicBoolean callbackInvoked = new AtomicBoolean(false); + + //when + beginAsync().thenRun(c -> { + async(3, c); + throw illegalStateException; + }).thenRun(c -> { + assertInvokedOnce(callbackInvoked); + c.complete(c); + }) + .finish((v, e) -> { + assertEquals(illegalStateException, e); + } + ); + + //then + Throwable exception = uncaughtExceptionHandler.getException(); + assertNotNull(exception); + assertEquals(AssertionError.class, exception.getClass()); + assertEquals("Callback has been already completed. It could happen " + + "if code throws an exception after invoking an async method. Value: null", exception.getMessage()); + } + + @Test + void shouldPropagateUnexpectedExceptionFromFinishCallback() { + //given + setIsTestingAbruptCompletion(false); + setAsyncStep(true); + IllegalStateException illegalStateException = new IllegalStateException("must not cause second callback invocation"); + + //when + beginAsync().thenRun(c -> { + async(3, c); + }).finish((v, e) -> { + throw illegalStateException; + }); + + //then + Throwable exception = uncaughtExceptionHandler.getException(); + assertNotNull(exception); + assertEquals(illegalStateException, exception); + } + + private static void assertInvokedOnce(final AtomicBoolean callbackInvoked1) { + assertTrue(callbackInvoked1.compareAndSet(false, true)); + } + + private final class UncaughtExceptionHandler implements Thread.UncaughtExceptionHandler { + + private final CompletableFuture completable = new CompletableFuture<>(); + + @Override + public void uncaughtException(final Thread t, final Throwable e) { + completable.complete(e); + } + + public Throwable getException() { + return await(completable, "No exception was thrown"); + } + } +}