From 2b91a27afae91c3eaf83c82eae6b6bcaa1c67bb1 Mon Sep 17 00:00:00 2001 From: Johnny Schmidt Date: Fri, 20 Sep 2024 18:28:23 -0700 Subject: [PATCH] Bulk Load CDK: QueueWriter Unit Tests (#45107) --- .../airbyte/cdk/message/MessageQueueWriter.kt | 2 +- .../DestinationMessageQueueWriterTest.kt | 253 ++++++++++++++++++ .../cdk/message/MockStreamsManagerFactory.kt | 86 ++++++ .../cdk/state/CheckpointManagerTest.kt | 73 +---- 4 files changed, 342 insertions(+), 72 deletions(-) create mode 100644 airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/message/DestinationMessageQueueWriterTest.kt create mode 100644 airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/message/MockStreamsManagerFactory.kt diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/MessageQueueWriter.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/MessageQueueWriter.kt index 9ea0862f6ae3..292b9db1faa5 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/MessageQueueWriter.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/MessageQueueWriter.kt @@ -22,11 +22,11 @@ interface MessageQueueWriter { * * TODO: Handle other message types. */ -@Singleton @SuppressFBWarnings( "NP_NONNULL_PARAM_VIOLATION", justification = "message is guaranteed to be non-null by Kotlin's type system" ) +@Singleton class DestinationMessageQueueWriter( private val catalog: DestinationCatalog, private val messageQueue: MessageQueue, diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/message/DestinationMessageQueueWriterTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/message/DestinationMessageQueueWriterTest.kt new file mode 100644 index 000000000000..5bf7a362a5e2 --- /dev/null +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/message/DestinationMessageQueueWriterTest.kt @@ -0,0 +1,253 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.message + +import com.fasterxml.jackson.databind.node.JsonNodeFactory +import io.airbyte.cdk.command.DestinationCatalog +import io.airbyte.cdk.command.DestinationStream +import io.airbyte.cdk.command.MockCatalogFactory.Companion.stream1 +import io.airbyte.cdk.command.MockCatalogFactory.Companion.stream2 +import io.airbyte.cdk.data.NullValue +import io.airbyte.cdk.state.CheckpointManager +import io.micronaut.context.annotation.Prototype +import io.micronaut.test.extensions.junit5.annotation.MicronautTest +import jakarta.inject.Inject +import jakarta.inject.Named +import kotlinx.coroutines.test.runTest +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test + +@MicronautTest(environments = ["MockStreamsManager"]) +class DestinationMessageQueueWriterTest { + @Inject lateinit var queueWriterFactory: TestDestinationMessageQueueWriterFactory + + @Prototype + class TestDestinationMessageQueueWriterFactory( + @Named("mockCatalog") private val catalog: DestinationCatalog, + val messageQueue: MockMessageQueue, + val streamsManager: MockStreamsManager, + val checkpointManager: MockCheckpointManager + ) { + fun make(): DestinationMessageQueueWriter { + return DestinationMessageQueueWriter( + catalog, + messageQueue, + streamsManager, + checkpointManager + ) + } + } + + class MockQueueChannel : QueueChannel { + val messages = mutableListOf() + var closed = false + + override suspend fun close() { + closed = true + } + + override suspend fun isClosed(): Boolean { + return closed + } + + override suspend fun send(message: DestinationRecordWrapped) { + messages.add(message) + } + + override suspend fun receive(): DestinationRecordWrapped { + return messages.removeAt(0) + } + } + + @Prototype + class MockMessageQueue : MessageQueue { + private val channels = + mutableMapOf>() + + override suspend fun getChannel( + key: DestinationStream + ): QueueChannel { + return channels.getOrPut(key) { MockQueueChannel() } + } + + override suspend fun acquireQueueBytesBlocking(bytes: Long) { + TODO("Not yet implemented") + } + + override suspend fun releaseQueueBytes(bytes: Long) { + TODO("Not yet implemented") + } + } + + @Prototype + class MockCheckpointManager : CheckpointManager { + val streamStates = + mutableMapOf>>() + val globalStates = + mutableListOf>, CheckpointMessage>>() + + override fun addStreamCheckpoint( + key: DestinationStream, + index: Long, + checkpointMessage: CheckpointMessage + ) { + streamStates.getOrPut(key) { mutableListOf() }.add(index to checkpointMessage) + } + + override fun addGlobalCheckpoint( + keyIndexes: List>, + checkpointMessage: CheckpointMessage + ) { + globalStates.add(keyIndexes to checkpointMessage) + } + + override fun flushReadyCheckpointMessages() { + TODO("Not yet implemented") + } + } + + private fun makeRecord(stream: DestinationStream, record: String): DestinationRecord { + return DestinationRecord( + stream = stream, + data = NullValue, + emittedAtMs = 0, + meta = null, + serialized = record + ) + } + + private fun makeStreamComplete(stream: DestinationStream): DestinationStreamComplete { + return DestinationStreamComplete(stream = stream, emittedAtMs = 0) + } + + private fun makeStreamState(stream: DestinationStream, recordCount: Long): CheckpointMessage { + return StreamCheckpoint( + checkpoint = + CheckpointMessage.Checkpoint(stream, JsonNodeFactory.instance.objectNode()), + sourceStats = CheckpointMessage.Stats(recordCount) + ) + } + + private fun makeGlobalState(recordCount: Long): CheckpointMessage { + return GlobalCheckpoint( + state = JsonNodeFactory.instance.objectNode(), + sourceStats = CheckpointMessage.Stats(recordCount), + checkpoints = emptyList() + ) + } + + @Test + fun testSendRecords() = runTest { + val writer = queueWriterFactory.make() + + val channel1 = queueWriterFactory.messageQueue.getChannel(stream1) as MockQueueChannel + val channel2 = queueWriterFactory.messageQueue.getChannel(stream2) as MockQueueChannel + + val manager1 = queueWriterFactory.streamsManager.getManager(stream1) as MockStreamManager + val manager2 = queueWriterFactory.streamsManager.getManager(stream2) as MockStreamManager + + (0 until 10).forEach { writer.publish(makeRecord(stream1, "test${it}"), it * 2L) } + Assertions.assertEquals(10, channel1.messages.size) + val expectedRecords = + (0 until 10).map { + StreamRecordWrapped(it.toLong(), it * 2L, makeRecord(stream1, "test${it}")) + } + + Assertions.assertEquals(expectedRecords, channel1.messages) + Assertions.assertEquals(10, manager1.countedRecords) + + Assertions.assertEquals(emptyList(), channel2.messages) + Assertions.assertEquals(0, manager2.countedRecords) + + writer.publish(makeRecord(stream2, "test"), 1L) + writer.publish(makeStreamComplete(stream1), 0L) + Assertions.assertEquals( + listOf(StreamRecordWrapped(0, 1L, makeRecord(stream2, "test"))), + channel2.messages + ) + Assertions.assertEquals(1, manager2.countedRecords) + + Assertions.assertFalse(manager2.countedEndOfStream) + Assertions.assertTrue(manager1.countedEndOfStream) + Assertions.assertEquals(11, channel1.messages.size) + Assertions.assertEquals(channel1.messages[10], StreamCompleteWrapped(10)) + } + + @Test + fun testSendStreamState() = runTest { + val writer = queueWriterFactory.make() + + data class TestEvent( + val stream: DestinationStream, + val count: Int, + val stateLookupIndex: Int, + val expectedStateIndex: Long + ) + + val batches = + listOf( + TestEvent(stream1, 10, 0, 10), + TestEvent(stream1, 5, 1, 15), + TestEvent(stream2, 4, 0, 4), + TestEvent(stream1, 3, 2, 18), + ) + + batches.forEach { (stream, count, stateLookupIndex, expectedCount) -> + repeat(count) { writer.publish(makeRecord(stream, "test"), 1L) } + writer.publish(makeStreamState(stream, count.toLong()), 0L) + val state = + queueWriterFactory.checkpointManager.streamStates[stream]!![stateLookupIndex] + Assertions.assertEquals(expectedCount, state.first) + Assertions.assertEquals(count.toLong(), state.second.destinationStats?.recordCount) + } + } + + @Test + fun testSendGlobalState() = runTest { + val writer = queueWriterFactory.make() + + open class TestEvent + data class AddRecords(val stream: DestinationStream, val count: Int) : TestEvent() + data class SendState( + val stateLookupIndex: Int, + val expectedStream1Count: Long, + val expectedStream2Count: Long, + val expectedStats: Long = 0 + ) : TestEvent() + + val batches = + listOf( + AddRecords(stream1, 10), + SendState(0, 10, 0, 10), + AddRecords(stream2, 5), + AddRecords(stream1, 4), + SendState(1, 14, 5, 9), + AddRecords(stream2, 3), + SendState(2, 14, 8, 3), + SendState(3, 14, 8, 0), + ) + + batches.forEach { event -> + when (event) { + is AddRecords -> { + repeat(event.count) { writer.publish(makeRecord(event.stream, "test"), 1L) } + } + is SendState -> { + writer.publish(makeGlobalState(event.expectedStream1Count), 0L) + val state = + queueWriterFactory.checkpointManager.globalStates[event.stateLookupIndex] + val stream1State = state.first.find { it.first == stream1 }!! + val stream2State = state.first.find { it.first == stream2 }!! + Assertions.assertEquals(event.expectedStream1Count, stream1State.second) + Assertions.assertEquals(event.expectedStream2Count, stream2State.second) + Assertions.assertEquals( + event.expectedStats, + state.second.destinationStats?.recordCount + ) + } + } + } + } +} diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/message/MockStreamsManagerFactory.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/message/MockStreamsManagerFactory.kt new file mode 100644 index 000000000000..68058babc9e7 --- /dev/null +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/message/MockStreamsManagerFactory.kt @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.message + +import com.google.common.collect.Range +import com.google.common.collect.RangeSet +import com.google.common.collect.TreeRangeSet +import io.airbyte.cdk.command.DestinationCatalog +import io.airbyte.cdk.command.DestinationStream +import io.airbyte.cdk.state.StreamManager +import io.airbyte.cdk.state.StreamsManager +import io.micronaut.context.annotation.Prototype +import io.micronaut.context.annotation.Requires +import jakarta.inject.Named + +class MockStreamManager : StreamManager { + var persistedRanges: RangeSet = TreeRangeSet.create() + var countedRecords: Long = 0 + var countedEndOfStream: Boolean = false + var lastCheckpoint: Long = 0 + + override fun countRecordIn(): Long { + return countedRecords++ + } + + override fun countEndOfStream(): Long { + return if (countedEndOfStream) { + throw IllegalStateException("End-of-stream already counted") + } else { + countedEndOfStream = true + countedRecords + } + } + + override fun markCheckpoint(): Pair { + val checkpoint = countedRecords + val count = checkpoint - lastCheckpoint + lastCheckpoint = checkpoint + + return Pair(checkpoint, count) + } + + override fun updateBatchState(batch: BatchEnvelope) { + throw NotImplementedError() + } + + override fun isBatchProcessingComplete(): Boolean { + throw NotImplementedError() + } + + override fun areRecordsPersistedUntil(index: Long): Boolean { + return persistedRanges.encloses(Range.closedOpen(0, index)) + } + + override fun markClosed() { + throw NotImplementedError() + } + + override fun streamIsClosed(): Boolean { + throw NotImplementedError() + } + + override suspend fun awaitStreamClosed() { + throw NotImplementedError() + } +} + +@Prototype +@Requires(env = ["MockStreamsManager"]) +class MockStreamsManager(@Named("mockCatalog") catalog: DestinationCatalog) : StreamsManager { + private val mockManagers = catalog.streams.associateWith { MockStreamManager() } + + fun addPersistedRanges(stream: DestinationStream, ranges: List>) { + mockManagers[stream]!!.persistedRanges.addAll(ranges) + } + + override fun getManager(stream: DestinationStream): StreamManager { + return mockManagers[stream] ?: throw IllegalArgumentException("Stream not found: $stream") + } + + override suspend fun awaitAllStreamsClosed() { + throw NotImplementedError() + } +} diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/CheckpointManagerTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/CheckpointManagerTest.kt index a8abba9ab742..514d2cb5eff0 100644 --- a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/CheckpointManagerTest.kt +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/CheckpointManagerTest.kt @@ -5,17 +5,13 @@ package io.airbyte.cdk.state import com.google.common.collect.Range -import com.google.common.collect.RangeSet -import com.google.common.collect.TreeRangeSet import io.airbyte.cdk.command.DestinationCatalog import io.airbyte.cdk.command.DestinationStream import io.airbyte.cdk.command.MockCatalogFactory.Companion.stream1 import io.airbyte.cdk.command.MockCatalogFactory.Companion.stream2 -import io.airbyte.cdk.message.Batch -import io.airbyte.cdk.message.BatchEnvelope import io.airbyte.cdk.message.MessageConverter +import io.airbyte.cdk.message.MockStreamsManager import io.micronaut.context.annotation.Prototype -import io.micronaut.context.annotation.Requires import io.micronaut.test.extensions.junit5.annotation.MicronautTest import jakarta.inject.Inject import jakarta.inject.Named @@ -29,10 +25,9 @@ import org.junit.jupiter.params.provider.Arguments import org.junit.jupiter.params.provider.ArgumentsProvider import org.junit.jupiter.params.provider.ArgumentsSource -@MicronautTest(environments = ["StateManagerTest"]) +@MicronautTest(environments = ["MockStreamsManager"]) class CheckpointManagerTest { @Inject lateinit var checkpointManager: TestCheckpointManager - /** * Test state messages. * @@ -73,70 +68,6 @@ class CheckpointManagerTest { } } - /** - * The only thing we really need is `areRecordsPersistedUntil`. (Technically we're emulating the - * @[StreamManager] behavior here, since the state manager doesn't actually know what ranges are - * closed, but less than that would make the test unrealistic.) - */ - class MockStreamManager : StreamManager { - var persistedRanges: RangeSet = TreeRangeSet.create() - - override fun countRecordIn(): Long { - throw NotImplementedError() - } - - override fun countEndOfStream(): Long { - throw NotImplementedError() - } - - override fun markCheckpoint(): Pair { - throw NotImplementedError() - } - - override fun updateBatchState(batch: BatchEnvelope) { - throw NotImplementedError() - } - - override fun isBatchProcessingComplete(): Boolean { - throw NotImplementedError() - } - - override fun areRecordsPersistedUntil(index: Long): Boolean { - return persistedRanges.encloses(Range.closedOpen(0, index)) - } - - override fun markClosed() { - throw NotImplementedError() - } - - override fun streamIsClosed(): Boolean { - throw NotImplementedError() - } - - override suspend fun awaitStreamClosed() { - throw NotImplementedError() - } - } - - @Prototype - @Requires(env = ["StateManagerTest"]) - class MockStreamsManager(@Named("mockCatalog") catalog: DestinationCatalog) : StreamsManager { - private val mockManagers = catalog.streams.associateWith { MockStreamManager() } - - fun addPersistedRanges(stream: DestinationStream, ranges: List>) { - mockManagers[stream]!!.persistedRanges.addAll(ranges) - } - - override fun getManager(stream: DestinationStream): StreamManager { - return mockManagers[stream] - ?: throw IllegalArgumentException("Stream not found: $stream") - } - - override suspend fun awaitAllStreamsClosed() { - throw NotImplementedError() - } - } - @Prototype class TestCheckpointManager( @Named("mockCatalog") override val catalog: DestinationCatalog,