diff --git a/kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt b/kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt index 3b106c440d..54e88677e1 100644 --- a/kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt +++ b/kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt @@ -1,10 +1,13 @@ package kotlinx.coroutines +import kotlinx.coroutines.flow.* import kotlinx.coroutines.testing.* import org.junit.Test +import java.util.concurrent.CopyOnWriteArrayList +import java.util.concurrent.ExecutorService +import java.util.concurrent.Executors import kotlin.coroutines.* import kotlin.test.* -import kotlinx.coroutines.flow.* class ThreadContextElementTest : TestBase() { @@ -155,39 +158,81 @@ class ThreadContextElementTest : TestBase() { } } - class JobCaptor(val capturees: ArrayList = ArrayList()) : ThreadContextElement { + class JobCaptor(val capturees: MutableList = CopyOnWriteArrayList()) : ThreadContextElement { companion object Key : CoroutineContext.Key override val key: CoroutineContext.Key<*> get() = Key override fun updateThreadContext(context: CoroutineContext) { - capturees.add(context.job) + capturees.add("Update: ${context.job}") } override fun restoreThreadContext(context: CoroutineContext, oldState: Unit) { + capturees.add("Restore: ${context.job}") } } + /** + * For stability of the test, it is important to make sure that + * the parent job actually suspends when calling + * `withContext(dispatcher2 + CoroutineName("dispatched"))`. + * + * Here this requirement is fulfilled by forcing execution on a single thread. + * However, dispatching is performed with two non-equal dispatchers to force dispatching. + * + * Suspend of the parent coroutine [kotlinx.coroutines.DispatchedCoroutine.trySuspend] is out of the control of the test, + * while being executed concurrently with resume of the child coroutine [kotlinx.coroutines.DispatchedCoroutine.tryResume]. + */ @Test fun testWithContextJobAccess() = runTest { + val executor = Executors.newSingleThreadExecutor() + // Emulate non-equal dispatchers + val executor1 = object : ExecutorService by executor {} + val executor2 = object : ExecutorService by executor {} + val dispatcher1 = executor1.asCoroutineDispatcher() + val dispatcher2 = executor2.asCoroutineDispatcher() val captor = JobCaptor() - val manuallyCaptured = ArrayList() - runBlocking(captor) { - manuallyCaptured += coroutineContext.job + val manuallyCaptured = mutableListOf() + + fun registerUpdate(job: Job?) = manuallyCaptured.add("Update: $job") + fun registerRestore(job: Job?) = manuallyCaptured.add("Restore: $job") + + var rootJob: Job? = null + runBlocking(captor + dispatcher1) { + rootJob = coroutineContext.job + registerUpdate(rootJob) + var undispatchedJob: Job? = null withContext(CoroutineName("undispatched")) { - manuallyCaptured += coroutineContext.job - withContext(Dispatchers.IO) { - manuallyCaptured += coroutineContext.job + undispatchedJob = coroutineContext.job + registerUpdate(undispatchedJob) + // These 2 restores and the corresponding next 2 updates happen only if the following `withContext` + // call actually suspends. + registerRestore(undispatchedJob) + registerRestore(rootJob) + // Without forcing of single backing thread the code inside `withContext` + // may already complete at the moment when the parent coroutine decides + // whether it needs to suspend or not. + var dispatchedJob: Job? = null + withContext(dispatcher2 + CoroutineName("dispatched")) { + dispatchedJob = coroutineContext.job + registerUpdate(dispatchedJob) } + registerRestore(dispatchedJob) // Context restored, captured again - manuallyCaptured += coroutineContext.job + registerUpdate(undispatchedJob) } + registerRestore(undispatchedJob) // Context restored, captured again - manuallyCaptured += coroutineContext.job + registerUpdate(rootJob) } + registerRestore(rootJob) - assertEquals(manuallyCaptured, captor.capturees) + // Restores may be called concurrently to the update calls in other threads, so their order is not checked. + val expected = manuallyCaptured.filter { it.startsWith("Update: ") }.joinToString(separator = "\n") + val actual = captor.capturees.filter { it.startsWith("Update: ") }.joinToString(separator = "\n") + assertEquals(expected, actual) + executor.shutdownNow() } @Test