From af58faa63ff211c225c4348cb490c70a3456ab1b Mon Sep 17 00:00:00 2001 From: Johnny Schmidt Date: Tue, 3 Sep 2024 16:28:37 -0700 Subject: [PATCH] Unit tests for streams manager (#45090) --- .../airbyte/cdk/command/DestinationCatalog.kt | 6 +- .../airbyte/cdk/message/MessageQueueWriter.kt | 5 +- .../{StreamManager.kt => StreamsManager.kt} | 125 +++++--- .../io/airbyte/cdk/task/TeardownTask.kt | 2 +- .../airbyte/cdk/command/MockCatalogFactory.kt | 27 ++ .../io/airbyte/cdk/state/StateManagerTest.kt | 36 +-- .../airbyte/cdk/state/StreamsManagerTest.kt | 282 ++++++++++++++++++ 7 files changed, 412 insertions(+), 71 deletions(-) rename airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/{StreamManager.kt => StreamsManager.kt} (61%) create mode 100644 airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/command/MockCatalogFactory.kt create mode 100644 airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/StreamsManagerTest.kt diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/command/DestinationCatalog.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/command/DestinationCatalog.kt index c476c049d840..e5ac1ddf92c0 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/command/DestinationCatalog.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/command/DestinationCatalog.kt @@ -25,8 +25,12 @@ data class DestinationCatalog( } } +interface DestinationCatalogFactory { + fun make(): DestinationCatalog +} + @Factory -class DestinationCatalogFactory( +class DefaultDestinationCatalogFactory( private val catalog: ConfiguredAirbyteCatalog, private val streamFactory: DestinationStreamFactory ) { diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/MessageQueueWriter.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/MessageQueueWriter.kt index 50c9f637ef88..a09979a2d7d1 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/MessageQueueWriter.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/MessageQueueWriter.kt @@ -43,13 +43,12 @@ class DestinationMessageQueueWriter( /* If the input message represents a record. */ is DestinationRecordMessage -> { val manager = streamsManager.getManager(message.stream) - val index = manager.countRecordIn(sizeBytes) when (message) { /* If a data record */ is DestinationRecord -> { val wrapped = StreamRecordWrapped( - index = index, + index = manager.countRecordIn(), sizeBytes = sizeBytes, record = message ) @@ -58,7 +57,7 @@ class DestinationMessageQueueWriter( /* If an end-of-stream marker. */ is DestinationStreamComplete -> { - val wrapped = StreamCompleteWrapped(index) + val wrapped = StreamCompleteWrapped(index = manager.countEndOfStream()) messageQueue.getChannel(message.stream).send(wrapped) } } diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/StreamManager.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/StreamsManager.kt similarity index 61% rename from airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/StreamManager.kt rename to airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/StreamsManager.kt index 4b74df5bf3ba..c58e126a9f8c 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/StreamManager.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/StreamsManager.kt @@ -15,15 +15,17 @@ import io.github.oshai.kotlinlogging.KotlinLogging import io.micronaut.context.annotation.Factory import jakarta.inject.Singleton import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.CountDownLatch +import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicLong -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.withContext +import kotlinx.coroutines.channels.Channel /** Manages the state of all streams in the destination. */ interface StreamsManager { + /** Get the manager for the given stream. Throws an exception if the stream is not found. */ fun getManager(stream: DestinationStream): StreamManager - suspend fun awaitAllStreamsComplete() + + /** Suspend until all streams are closed. */ + suspend fun awaitAllStreamsClosed() } class DefaultStreamsManager( @@ -33,68 +35,98 @@ class DefaultStreamsManager( return streamManagers[stream] ?: throw IllegalArgumentException("Stream not found: $stream") } - override suspend fun awaitAllStreamsComplete() { + override suspend fun awaitAllStreamsClosed() { streamManagers.forEach { (_, manager) -> manager.awaitStreamClosed() } } } /** Manages the state of a single stream. */ interface StreamManager { - fun countRecordIn(sizeBytes: Long): Long + /** Count incoming record and return the record's *index*. */ + fun countRecordIn(): Long + + /** + * Count the end-of-stream. Expect this exactly once. Expect no further `countRecordIn`, and + * expect that `markClosed` will always occur after this. + */ + fun countEndOfStream(): Long + + /** + * Mark a checkpoint in the stream and return the current index and the number of records since + * the last one. + * + * NOTE: Single-writer. If in the future multiple threads set checkpoints, this method should be + * synchronized. + */ fun markCheckpoint(): Pair + + /** Record that the given batch's state has been reached for the associated range(s). */ fun updateBatchState(batch: BatchEnvelope) + + /** + * True if all are true: + * * all records have been seen (ie, we've counted an end-of-stream) + * * a [Batch.State.COMPLETE] batch range has been seen covering every record + * + * Does NOT require that the stream be closed. + */ fun isBatchProcessingComplete(): Boolean + + /** + * True if all records in [0, index] have at least reached [Batch.State.PERSISTED]. This is + * implicitly true if they have all reached [Batch.State.COMPLETE]. + */ fun areRecordsPersistedUntil(index: Long): Boolean + /** Mark the stream as closed. This should only be called after all records have been read. */ fun markClosed() + + /** True if the stream has been marked as closed. */ fun streamIsClosed(): Boolean + + /** Suspend until the stream is closed. */ suspend fun awaitStreamClosed() } -/** - * Maintains a map of stream -> status metadata, and a map of batch state -> record ranges for which - * that state has been reached. - * - * TODO: Log a detailed report of the stream status on a regular cadence. - */ class DefaultStreamManager( val stream: DestinationStream, ) : StreamManager { private val log = KotlinLogging.logger {} - data class StreamStatus( - val recordCount: AtomicLong = AtomicLong(0), - val totalBytes: AtomicLong = AtomicLong(0), - val enqueuedSize: AtomicLong = AtomicLong(0), - val lastCheckpoint: AtomicLong = AtomicLong(0L), - val closedLatch: CountDownLatch = CountDownLatch(1), - ) + private val recordCount = AtomicLong(0) + private val lastCheckpoint = AtomicLong(0L) + private val readIsClosed = AtomicBoolean(false) + private val streamIsClosed = AtomicBoolean(false) + private val closedLock = Channel() - private val streamStatus: StreamStatus = StreamStatus() private val rangesState: ConcurrentHashMap> = ConcurrentHashMap() init { Batch.State.entries.forEach { rangesState[it] = TreeRangeSet.create() } } - override fun countRecordIn(sizeBytes: Long): Long { - val index = streamStatus.recordCount.getAndIncrement() - streamStatus.totalBytes.addAndGet(sizeBytes) - streamStatus.enqueuedSize.addAndGet(sizeBytes) - return index + override fun countRecordIn(): Long { + if (readIsClosed.get()) { + throw IllegalStateException("Stream is closed for reading") + } + + return recordCount.getAndIncrement() + } + + override fun countEndOfStream(): Long { + if (readIsClosed.getAndSet(true)) { + throw IllegalStateException("Stream is closed for reading") + } + + return recordCount.get() } - /** - * Mark a checkpoint in the stream and return the current index and the number of records since - * the last one. - */ override fun markCheckpoint(): Pair { - val index = streamStatus.recordCount.get() - val lastCheckpoint = streamStatus.lastCheckpoint.getAndSet(index) + val index = recordCount.get() + val lastCheckpoint = lastCheckpoint.getAndSet(index) return Pair(index, index - lastCheckpoint) } - /** Record that the given batch's state has been reached for the associated range(s). */ override fun updateBatchState(batch: BatchEnvelope) { val stateRanges = rangesState[batch.batch.state] @@ -112,37 +144,44 @@ class DefaultStreamManager( log.info { "Updated ranges for $stream[${batch.batch.state}]: $stateRanges" } } - /** True if all records in [0, index] have reached the given state. */ + /** True if all records in `[0, index)` have reached the given state. */ private fun isProcessingCompleteForState(index: Long, state: Batch.State): Boolean { - val completeRanges = rangesState[state]!! return completeRanges.encloses(Range.closedOpen(0L, index)) } - /** True if all records have associated [Batch.State.COMPLETE] batches. */ override fun isBatchProcessingComplete(): Boolean { - return isProcessingCompleteForState(streamStatus.recordCount.get(), Batch.State.COMPLETE) + /* If the stream hasn't been fully read, it can't be done. */ + if (!readIsClosed.get()) { + return false + } + + return isProcessingCompleteForState(recordCount.get(), Batch.State.COMPLETE) } - /** - * True if all records in [0, index] have at least reached [Batch.State.PERSISTED]. This is - * implicitly true if they have all reached [Batch.State.COMPLETE]. - */ override fun areRecordsPersistedUntil(index: Long): Boolean { return isProcessingCompleteForState(index, Batch.State.PERSISTED) || isProcessingCompleteForState(index, Batch.State.COMPLETE) // complete => persisted } override fun markClosed() { - streamStatus.closedLatch.countDown() + if (!readIsClosed.get()) { + throw IllegalStateException("Stream must be fully read before it can be closed") + } + + if (streamIsClosed.compareAndSet(false, true)) { + closedLock.trySend(Unit) + } } override fun streamIsClosed(): Boolean { - return streamStatus.closedLatch.count == 0L + return streamIsClosed.get() } override suspend fun awaitStreamClosed() { - withContext(Dispatchers.IO) { streamStatus.closedLatch.await() } + if (!streamIsClosed.get()) { + closedLock.receive() + } } } diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/TeardownTask.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/TeardownTask.kt index 52fec0acaf3d..5d76c2b260f9 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/TeardownTask.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/TeardownTask.kt @@ -34,7 +34,7 @@ class TeardownTask( } /** Ensure we don't run until all streams have completed */ - streamsManager.awaitAllStreamsComplete() + streamsManager.awaitAllStreamsClosed() destination.teardown() taskLauncher.stop() diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/command/MockCatalogFactory.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/command/MockCatalogFactory.kt new file mode 100644 index 000000000000..dc19a9e28452 --- /dev/null +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/command/MockCatalogFactory.kt @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.command + +import io.micronaut.context.annotation.Factory +import io.micronaut.context.annotation.Replaces +import io.micronaut.context.annotation.Requires +import jakarta.inject.Named +import jakarta.inject.Singleton + +@Factory +@Replaces(factory = DestinationCatalogFactory::class) +@Requires(env = ["test"]) +class MockCatalogFactory : DestinationCatalogFactory { + companion object { + val stream1 = DestinationStream(DestinationStream.Descriptor("test", "stream1")) + val stream2 = DestinationStream(DestinationStream.Descriptor("test", "stream2")) + } + + @Singleton + @Named("mockCatalog") + override fun make(): DestinationCatalog { + return DestinationCatalog(streams = listOf(stream1, stream2)) + } +} diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/StateManagerTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/StateManagerTest.kt index 5c34cd8446e8..edbdaad56100 100644 --- a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/StateManagerTest.kt +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/StateManagerTest.kt @@ -8,17 +8,17 @@ import com.google.common.collect.Range import com.google.common.collect.RangeSet import com.google.common.collect.TreeRangeSet import io.airbyte.cdk.command.DestinationCatalog -import io.airbyte.cdk.command.DestinationCatalogFactory import io.airbyte.cdk.command.DestinationStream +import io.airbyte.cdk.command.MockCatalogFactory.Companion.stream1 +import io.airbyte.cdk.command.MockCatalogFactory.Companion.stream2 import io.airbyte.cdk.message.Batch import io.airbyte.cdk.message.BatchEnvelope import io.airbyte.cdk.message.MessageConverter -import io.micronaut.context.annotation.Factory import io.micronaut.context.annotation.Prototype -import io.micronaut.context.annotation.Replaces import io.micronaut.context.annotation.Requires import io.micronaut.test.extensions.junit5.annotation.MicronautTest import jakarta.inject.Inject +import jakarta.inject.Named import jakarta.inject.Singleton import java.util.function.Consumer import java.util.stream.Stream @@ -29,25 +29,10 @@ import org.junit.jupiter.params.provider.Arguments import org.junit.jupiter.params.provider.ArgumentsProvider import org.junit.jupiter.params.provider.ArgumentsSource -@MicronautTest +@MicronautTest(environments = ["StateManagerTest"]) class StateManagerTest { @Inject lateinit var stateManager: TestStateManager - companion object { - val stream1 = DestinationStream(DestinationStream.Descriptor("test", "stream1")) - val stream2 = DestinationStream(DestinationStream.Descriptor("test", "stream2")) - } - - @Factory - @Replaces(factory = DestinationCatalogFactory::class) - class MockCatalogFactory { - @Singleton - @Requires(env = ["test"]) - fun make(): DestinationCatalog { - return DestinationCatalog(streams = listOf(stream1, stream2)) - } - } - /** * Test state messages. * @@ -95,7 +80,11 @@ class StateManagerTest { class MockStreamManager : StreamManager { var persistedRanges: RangeSet = TreeRangeSet.create() - override fun countRecordIn(sizeBytes: Long): Long { + override fun countRecordIn(): Long { + throw NotImplementedError() + } + + override fun countEndOfStream(): Long { throw NotImplementedError() } @@ -129,7 +118,8 @@ class StateManagerTest { } @Prototype - class MockStreamsManager(catalog: DestinationCatalog) : StreamsManager { + @Requires(env = ["StateManagerTest"]) + class MockStreamsManager(@Named("mockCatalog") catalog: DestinationCatalog) : StreamsManager { private val mockManagers = catalog.streams.associateWith { MockStreamManager() } fun addPersistedRanges(stream: DestinationStream, ranges: List>) { @@ -141,14 +131,14 @@ class StateManagerTest { ?: throw IllegalArgumentException("Stream not found: $stream") } - override suspend fun awaitAllStreamsComplete() { + override suspend fun awaitAllStreamsClosed() { throw NotImplementedError() } } @Prototype class TestStateManager( - override val catalog: DestinationCatalog, + @Named("mockCatalog") override val catalog: DestinationCatalog, override val streamsManager: MockStreamsManager, override val outputFactory: MessageConverter, override val outputConsumer: MockOutputConsumer diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/StreamsManagerTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/StreamsManagerTest.kt new file mode 100644 index 000000000000..5ba013f4fab4 --- /dev/null +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/StreamsManagerTest.kt @@ -0,0 +1,282 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.state + +import com.google.common.collect.Range +import io.airbyte.cdk.command.DestinationCatalog +import io.airbyte.cdk.command.DestinationStream +import io.airbyte.cdk.command.MockCatalogFactory.Companion.stream1 +import io.airbyte.cdk.command.MockCatalogFactory.Companion.stream2 +import io.airbyte.cdk.message.Batch +import io.airbyte.cdk.message.BatchEnvelope +import io.airbyte.cdk.message.SimpleBatch +import io.micronaut.test.extensions.junit5.annotation.MicronautTest +import jakarta.inject.Inject +import jakarta.inject.Named +import java.util.concurrent.atomic.AtomicBoolean +import java.util.stream.Stream +import kotlinx.coroutines.launch +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withTimeout +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.ExtensionContext +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.ArgumentsProvider +import org.junit.jupiter.params.provider.ArgumentsSource + +@MicronautTest +class StreamsManagerTest { + @Inject @Named("mockCatalog") lateinit var catalog: DestinationCatalog + + @Test + fun testCountRecordsAndCheckpoint() { + val streamsManager = StreamsManagerFactory(catalog).make() + val manager1 = streamsManager.getManager(stream1) + val manager2 = streamsManager.getManager(stream2) + + // Incrementing once yields (n, n) + repeat(10) { manager1.countRecordIn() } + val (index, count) = manager1.markCheckpoint() + + Assertions.assertEquals(10, index) + Assertions.assertEquals(10, count) + + // Incrementing a second time yields (n + m, m) + repeat(5) { manager1.countRecordIn() } + val (index2, count2) = manager1.markCheckpoint() + + Assertions.assertEquals(15, index2) + Assertions.assertEquals(5, count2) + + // Never incrementing yields (0, 0) + val (index3, count3) = manager2.markCheckpoint() + + Assertions.assertEquals(0, index3) + Assertions.assertEquals(0, count3) + + // Incrementing twice in a row yields (n + m + 0, 0) + val (index4, count4) = manager1.markCheckpoint() + + Assertions.assertEquals(15, index4) + Assertions.assertEquals(0, count4) + } + + @Test + fun testGettingNonexistentManagerFails() { + val streamsManager = StreamsManagerFactory(catalog).make() + Assertions.assertThrows(IllegalArgumentException::class.java) { + streamsManager.getManager( + DestinationStream(DestinationStream.Descriptor("test", "non-existent")) + ) + } + } + + sealed class TestEvent + data class SetRecordCount(val count: Long) : TestEvent() + data object SetEndOfStream : TestEvent() + data class AddPersisted(val firstIndex: Long, val lastIndex: Long) : TestEvent() + data class AddComplete(val firstIndex: Long, val lastIndex: Long) : TestEvent() + data class ExpectPersistedUntil(val end: Long, val expectation: Boolean = true) : TestEvent() + data class ExpectComplete(val expectation: Boolean = true) : TestEvent() + + data class TestCase( + val name: String, + val events: List>, + ) + + class TestUpdateBatchStateProvider : ArgumentsProvider { + override fun provideArguments(context: ExtensionContext): Stream { + return listOf( + TestCase( + "Single stream, single batch", + listOf( + Pair(stream1, SetRecordCount(10)), + Pair(stream1, AddPersisted(0, 9)), + Pair(stream1, ExpectPersistedUntil(9)), + Pair(stream1, ExpectPersistedUntil(10)), + Pair(stream1, ExpectComplete(false)), + Pair(stream1, ExpectPersistedUntil(11, false)), + Pair(stream2, ExpectPersistedUntil(10, false)), + ) + ), + TestCase( + "Single stream, multiple batches", + listOf( + Pair(stream1, SetRecordCount(10)), + Pair(stream1, AddPersisted(0, 4)), + Pair(stream1, ExpectPersistedUntil(4)), + Pair(stream1, AddPersisted(5, 9)), + Pair(stream1, ExpectPersistedUntil(9)), + Pair(stream1, ExpectPersistedUntil(10)), + Pair(stream1, ExpectComplete(false)), + Pair(stream1, AddComplete(0, 9)), + Pair(stream1, ExpectComplete(false)), + Pair(stream1, SetEndOfStream), + Pair(stream1, ExpectComplete(true)), + Pair(stream1, ExpectPersistedUntil(11, false)), + Pair(stream2, ExpectPersistedUntil(10, false)), + ) + ), + TestCase( + "Single stream, multiple batches, out of order", + listOf( + Pair(stream1, SetRecordCount(10)), + Pair(stream1, AddPersisted(5, 9)), + Pair(stream1, ExpectPersistedUntil(10, false)), + Pair(stream1, AddPersisted(0, 4)), + Pair(stream1, ExpectPersistedUntil(10)), + Pair(stream1, ExpectComplete(false)), + Pair(stream1, SetEndOfStream), + Pair(stream1, AddComplete(5, 9)), + Pair(stream1, ExpectComplete(false)), + Pair(stream1, AddComplete(0, 4)), + Pair(stream1, ExpectComplete(true)), + ) + ), + TestCase( + "multiple streams", + listOf( + Pair(stream1, SetRecordCount(10)), + Pair(stream2, SetRecordCount(20)), + Pair(stream2, AddPersisted(0, 9)), + Pair(stream2, ExpectPersistedUntil(10, true)), + Pair(stream1, ExpectPersistedUntil(10, false)), + Pair(stream2, SetEndOfStream), + Pair(stream2, ExpectComplete(false)), + Pair(stream1, AddPersisted(0, 9)), + Pair(stream1, ExpectPersistedUntil(10)), + Pair(stream1, ExpectComplete(false)), + Pair(stream2, AddComplete(10, 20)), + Pair(stream2, ExpectComplete(false)), + Pair(stream1, SetEndOfStream), + Pair(stream1, ExpectComplete(false)), + Pair(stream1, AddComplete(0, 9)), + Pair(stream1, ExpectComplete(true)), + Pair(stream2, AddComplete(0, 9)), + Pair(stream2, ExpectPersistedUntil(20, true)), + Pair(stream2, ExpectComplete(true)), + ) + ) + ) + .map { Arguments.of(it) } + .stream() + } + } + + @ParameterizedTest + @ArgumentsSource(TestUpdateBatchStateProvider::class) + fun testUpdateBatchState(testCase: TestCase) { + val streamsManager = StreamsManagerFactory(catalog).make() + testCase.events.forEach { (stream, event) -> + val manager = streamsManager.getManager(stream) + when (event) { + is SetRecordCount -> repeat(event.count.toInt()) { manager.countRecordIn() } + is SetEndOfStream -> manager.countEndOfStream() + is AddPersisted -> + manager.updateBatchState( + BatchEnvelope( + SimpleBatch(Batch.State.PERSISTED), + Range.closed(event.firstIndex, event.lastIndex) + ) + ) + is AddComplete -> + manager.updateBatchState( + BatchEnvelope( + SimpleBatch(Batch.State.COMPLETE), + Range.closed(event.firstIndex, event.lastIndex) + ) + ) + is ExpectPersistedUntil -> + Assertions.assertEquals( + event.expectation, + manager.areRecordsPersistedUntil(event.end), + "$stream: ${testCase.name}: ${event.end}" + ) + is ExpectComplete -> + Assertions.assertEquals( + event.expectation, + manager.isBatchProcessingComplete(), + "$stream: ${testCase.name}" + ) + } + } + } + + @Test + fun testCannotUpdateOrCloseReadClosedStream() { + val streamsManager = StreamsManagerFactory(catalog).make() + val manager = streamsManager.getManager(stream1) + + // Can't close before end-of-stream + Assertions.assertThrows(IllegalStateException::class.java) { manager.markClosed() } + + manager.countEndOfStream() + + // Can't update after end-of-stream + Assertions.assertThrows(IllegalStateException::class.java) { manager.countRecordIn() } + + Assertions.assertThrows(IllegalStateException::class.java) { manager.countEndOfStream() } + + // Can close now + Assertions.assertDoesNotThrow(manager::markClosed) + } + + @Test + fun testAwaitStreamClosed() = runTest { + val streamsManager = StreamsManagerFactory(catalog).make() + val manager = streamsManager.getManager(stream1) + val hasClosed = AtomicBoolean(false) + + val job = launch { + manager.awaitStreamClosed() + hasClosed.set(true) + } + + Assertions.assertFalse(hasClosed.get()) + manager.countEndOfStream() + manager.markClosed() + try { + withTimeout(5000) { job.join() } + } catch (e: Exception) { + Assertions.fail("Stream did not close in time") + } + Assertions.assertTrue(hasClosed.get()) + } + + @Test + fun testAwaitAllStreamsClosed() = runTest { + val streamsManager = StreamsManagerFactory(catalog).make() + val manager1 = streamsManager.getManager(stream1) + val manager2 = streamsManager.getManager(stream2) + val allHaveClosed = AtomicBoolean(false) + + val awaitStream1 = launch { manager1.awaitStreamClosed() } + + val awaitAllStreams = launch { + streamsManager.awaitAllStreamsClosed() + allHaveClosed.set(true) + } + + Assertions.assertFalse(allHaveClosed.get()) + manager1.countEndOfStream() + manager1.markClosed() + try { + withTimeout(5000) { awaitStream1.join() } + } catch (e: Exception) { + Assertions.fail("Stream1 did not close in time") + } + Assertions.assertFalse(allHaveClosed.get()) + manager2.countEndOfStream() + manager2.markClosed() + try { + withTimeout(5000) { awaitAllStreams.join() } + } catch (e: Exception) { + Assertions.fail("Streams did not close in time") + } + Assertions.assertTrue(allHaveClosed.get()) + } +}