diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/DestinationMessage.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/DestinationMessage.kt index 1c7697579d05..ae388a275bb6 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/DestinationMessage.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/DestinationMessage.kt @@ -38,9 +38,9 @@ data class DestinationStreamComplete( ) : DestinationRecordMessage() /** State. */ -sealed class DestinationStateMessage : DestinationMessage() { +sealed class CheckpointMessage : DestinationMessage() { data class Stats(val recordCount: Long) - data class StreamState( + data class StreamCheckpoint( val stream: DestinationStream, val state: JsonNode, ) @@ -48,26 +48,26 @@ sealed class DestinationStateMessage : DestinationMessage() { abstract val sourceStats: Stats abstract val destinationStats: Stats? - abstract fun withDestinationStats(stats: Stats): DestinationStateMessage + abstract fun withDestinationStats(stats: Stats): CheckpointMessage } -data class DestinationStreamState( - val streamState: StreamState, +data class StreamCheckpoint( + val streamCheckpoint: StreamCheckpoint, override val sourceStats: Stats, override val destinationStats: Stats? = null -) : DestinationStateMessage() { +) : CheckpointMessage() { override fun withDestinationStats(stats: Stats) = - DestinationStreamState(streamState, sourceStats, stats) + StreamCheckpoint(streamCheckpoint, sourceStats, stats) } -data class DestinationGlobalState( +data class GlobalCheckpoint( val state: JsonNode, override val sourceStats: Stats, override val destinationStats: Stats? = null, - val streamStates: List = emptyList() -) : DestinationStateMessage() { + val streamCheckpoints: List = emptyList() +) : CheckpointMessage() { override fun withDestinationStats(stats: Stats) = - DestinationGlobalState(state, sourceStats, stats, streamStates) + GlobalCheckpoint(state, sourceStats, stats, streamCheckpoints) } /** Catchall for anything unimplemented. */ @@ -108,21 +108,21 @@ class DestinationMessageFactory(private val catalog: DestinationCatalog) { AirbyteMessage.Type.STATE -> { when (message.state.type) { AirbyteStateMessage.AirbyteStateType.STREAM -> - DestinationStreamState( - streamState = fromAirbyteStreamState(message.state.stream), + StreamCheckpoint( + streamCheckpoint = fromAirbyteStreamState(message.state.stream), sourceStats = - DestinationStateMessage.Stats( + CheckpointMessage.Stats( recordCount = message.state.sourceStats.recordCount.toLong() ) ) AirbyteStateMessage.AirbyteStateType.GLOBAL -> - DestinationGlobalState( + GlobalCheckpoint( sourceStats = - DestinationStateMessage.Stats( + CheckpointMessage.Stats( recordCount = message.state.sourceStats.recordCount.toLong() ), state = message.state.global.sharedState, - streamStates = + streamCheckpoints = message.state.global.streamStates.map { fromAirbyteStreamState(it) } ) else -> // TODO: Do we still need to handle LEGACY? @@ -135,9 +135,9 @@ class DestinationMessageFactory(private val catalog: DestinationCatalog) { private fun fromAirbyteStreamState( streamState: AirbyteStreamState - ): DestinationStateMessage.StreamState { + ): CheckpointMessage.StreamCheckpoint { val descriptor = streamState.streamDescriptor - return DestinationStateMessage.StreamState( + return CheckpointMessage.StreamCheckpoint( stream = catalog.getStream(namespace = descriptor.namespace, name = descriptor.name), state = streamState.streamState ) diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/MessageConverter.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/MessageConverter.kt index 97d40900b210..5a325115c9d0 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/MessageConverter.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/MessageConverter.kt @@ -13,19 +13,19 @@ import io.airbyte.protocol.models.v0.StreamDescriptor import jakarta.inject.Singleton /** - * Converts the internal @[DestinationStateMessage] case class to the Protocol state messages - * required by @[io.airbyte.cdk.output.OutputConsumer] + * Converts the internal @[CheckpointMessage] case class to the Protocol state messages required by + * @[io.airbyte.cdk.output.OutputConsumer] */ interface MessageConverter { fun from(message: T): U } @Singleton -class DefaultMessageConverter : MessageConverter { - override fun from(message: DestinationStateMessage): AirbyteMessage { +class DefaultMessageConverter : MessageConverter { + override fun from(message: CheckpointMessage): AirbyteMessage { val state = when (message) { - is DestinationStreamState -> + is StreamCheckpoint -> AirbyteStateMessage() .withSourceStats( AirbyteStateStats() @@ -40,8 +40,8 @@ class DefaultMessageConverter : MessageConverter + .withStream(fromStreamState(message.streamCheckpoint)) + is GlobalCheckpoint -> AirbyteStateMessage() .withSourceStats( AirbyteStateStats() @@ -56,21 +56,23 @@ class DefaultMessageConverter : MessageConverter { /** * Routes @[DestinationRecordMessage]s by stream to the appropriate channel and @ - * [DestinationStateMessage]s to the state manager. + * [CheckpointMessage]s to the state manager. * * TODO: Handle other message types. */ @@ -31,7 +31,7 @@ class DestinationMessageQueueWriter( private val catalog: DestinationCatalog, private val messageQueue: MessageQueue, private val streamsManager: StreamsManager, - private val stateManager: StateManager + private val checkpointManager: CheckpointManager ) : MessageQueueWriter { /** * Deserialize and route the message to the appropriate channel. @@ -62,28 +62,30 @@ class DestinationMessageQueueWriter( } } } - is DestinationStateMessage -> { + is CheckpointMessage -> { when (message) { /** * For a stream state message, mark the checkpoint and add the message with * index and count to the state manager. Also, add the count to the destination * stats. */ - is DestinationStreamState -> { - val stream = message.streamState.stream + is StreamCheckpoint -> { + val stream = message.streamCheckpoint.stream val manager = streamsManager.getManager(stream) val (currentIndex, countSinceLast) = manager.markCheckpoint() val messageWithCount = - message.withDestinationStats( - DestinationStateMessage.Stats(countSinceLast) - ) - stateManager.addStreamState(stream, currentIndex, messageWithCount) + message.withDestinationStats(CheckpointMessage.Stats(countSinceLast)) + checkpointManager.addStreamCheckpoint( + stream, + currentIndex, + messageWithCount + ) } /** * For a global state message, collect the index per stream, but add the total * count to the destination stats. */ - is DestinationGlobalState -> { + is GlobalCheckpoint -> { val streamWithIndexAndCount = catalog.streams.map { stream -> val manager = streamsManager.getManager(stream) @@ -92,9 +94,9 @@ class DestinationMessageQueueWriter( } val totalCount = streamWithIndexAndCount.sumOf { it.third } val messageWithCount = - message.withDestinationStats(DestinationStateMessage.Stats(totalCount)) + message.withDestinationStats(CheckpointMessage.Stats(totalCount)) val streamIndexes = streamWithIndexAndCount.map { it.first to it.second } - stateManager.addGlobalState(streamIndexes, messageWithCount) + checkpointManager.addGlobalCheckpoint(streamIndexes, messageWithCount) } } } diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/CheckpointManager.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/CheckpointManager.kt new file mode 100644 index 000000000000..4b162b5b14ae --- /dev/null +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/CheckpointManager.kt @@ -0,0 +1,176 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.state + +import io.airbyte.cdk.command.DestinationCatalog +import io.airbyte.cdk.command.DestinationStream +import io.airbyte.cdk.message.CheckpointMessage +import io.airbyte.cdk.message.MessageConverter +import io.airbyte.protocol.models.v0.AirbyteMessage +import io.github.oshai.kotlinlogging.KotlinLogging +import io.micronaut.core.util.clhm.ConcurrentLinkedHashMap +import jakarta.inject.Singleton +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.ConcurrentLinkedQueue +import java.util.concurrent.atomic.AtomicReference +import java.util.function.Consumer + +/** + * Interface for checkpoint management. Should accept stream and global checkpoints, as well as + * requests to flush all data-sufficient checkpoints. + */ +interface CheckpointManager { + fun addStreamCheckpoint(key: K, index: Long, checkpointMessage: T) + fun addGlobalCheckpoint(keyIndexes: List>, checkpointMessage: T) + fun flushReadyCheckpointMessages() +} + +/** + * Message-type agnostic streams checkpoint manager. + * + * Accepts global and stream checkpoints, and enforces that stream and global checkpoints are not + * mixed. Determines ready checkpoints by querying the StreamsManager for the checkpoint of the + * record index range associated with each checkpoint message. + * + * TODO: Force flush on a configured schedule + * + * TODO: Ensure that checkpoint is flushed at the end, and require that all checkpoints be flushed + * before the destination can succeed. + */ +abstract class StreamsCheckpointManager() : CheckpointManager { + private val log = KotlinLogging.logger {} + + abstract val catalog: DestinationCatalog + abstract val streamsManager: StreamsManager + abstract val outputFactory: MessageConverter + abstract val outputConsumer: Consumer + + data class GlobalCheckpoint( + val streamIndexes: List>, + val checkpointMessage: T + ) + + private val checkpointsAreGlobal: AtomicReference = AtomicReference(null) + private val streamCheckpoints: + ConcurrentHashMap> = + ConcurrentHashMap() + private val globalCheckpoints: ConcurrentLinkedQueue> = + ConcurrentLinkedQueue() + + override fun addStreamCheckpoint(key: DestinationStream, index: Long, checkpointMessage: T) { + if (checkpointsAreGlobal.updateAndGet { it == true } != false) { + throw IllegalStateException( + "Global checkpoints cannot be mixed with non-global checkpoints" + ) + } + + streamCheckpoints.compute(key) { _, indexToMessage -> + val map = + if (indexToMessage == null) { + // If the map doesn't exist yet, build it. + ConcurrentLinkedHashMap.Builder().maximumWeightedCapacity(1000).build() + } else { + if (indexToMessage.isNotEmpty()) { + // Make sure the messages are coming in order + val oldestIndex = indexToMessage.ascendingKeySet().first() + if (oldestIndex > index) { + throw IllegalStateException( + "Checkpoint message received out of order ($oldestIndex before $index)" + ) + } + } + indexToMessage + } + // Actually add the message + map[index] = checkpointMessage + map + } + + log.info { "Added checkpoint for stream: $key at index: $index" } + } + + // TODO: Is it an error if we don't get all the streams every time? + override fun addGlobalCheckpoint( + keyIndexes: List>, + checkpointMessage: T + ) { + if (checkpointsAreGlobal.updateAndGet { it != false } != true) { + throw IllegalStateException( + "Global checkpoint cannot be mixed with non-global checkpoints" + ) + } + + val head = globalCheckpoints.peek() + if (head != null) { + val keyIndexesByStream = keyIndexes.associate { it.first to it.second } + head.streamIndexes.forEach { + if (keyIndexesByStream[it.first]!! < it.second) { + throw IllegalStateException("Global checkpoint message received out of order") + } + } + } + + globalCheckpoints.add(GlobalCheckpoint(keyIndexes, checkpointMessage)) + log.info { "Added global checkpoint with stream indexes: $keyIndexes" } + } + + override fun flushReadyCheckpointMessages() { + /* + Iterate over the checkpoints in order, evicting each that passes + the persistence check. If a checkpoint is not persisted, then + we can break the loop since the checkpoints are ordered. For global + checkpoints, all streams must be persisted up to the checkpoint. + */ + when (checkpointsAreGlobal.get()) { + null -> log.info { "No checkpoints to flush" } + true -> flushGlobalCheckpoints() + false -> flushStreamCheckpoints() + } + } + + private fun flushGlobalCheckpoints() { + while (!globalCheckpoints.isEmpty()) { + val head = globalCheckpoints.peek() + val allStreamsPersisted = + head.streamIndexes.all { (stream, index) -> + streamsManager.getManager(stream).areRecordsPersistedUntil(index) + } + if (allStreamsPersisted) { + globalCheckpoints.poll() + val outMessage = outputFactory.from(head.checkpointMessage) + outputConsumer.accept(outMessage) + } else { + break + } + } + } + + private fun flushStreamCheckpoints() { + for (stream in catalog.streams) { + val manager = streamsManager.getManager(stream) + val streamCheckpoints = streamCheckpoints[stream] ?: return + for (index in streamCheckpoints.keys) { + if (manager.areRecordsPersistedUntil(index)) { + val checkpointMessage = + streamCheckpoints.remove(index) + ?: throw IllegalStateException("Checkpoint not found for index: $index") + log.info { "Flushing checkpoint for stream: $stream at index: $index" } + val outMessage = outputFactory.from(checkpointMessage) + outputConsumer.accept(outMessage) + } else { + break + } + } + } + } +} + +@Singleton +class DefaultCheckpointManager( + override val catalog: DestinationCatalog, + override val streamsManager: StreamsManager, + override val outputFactory: MessageConverter, + override val outputConsumer: Consumer +) : StreamsCheckpointManager() diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/StateManager.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/StateManager.kt deleted file mode 100644 index 9c900b4e4379..000000000000 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/StateManager.kt +++ /dev/null @@ -1,168 +0,0 @@ -/* - * Copyright (c) 2024 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.state - -import io.airbyte.cdk.command.DestinationCatalog -import io.airbyte.cdk.command.DestinationStream -import io.airbyte.cdk.message.DestinationStateMessage -import io.airbyte.cdk.message.MessageConverter -import io.airbyte.protocol.models.v0.AirbyteMessage -import io.github.oshai.kotlinlogging.KotlinLogging -import io.micronaut.core.util.clhm.ConcurrentLinkedHashMap -import jakarta.inject.Singleton -import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.ConcurrentLinkedQueue -import java.util.concurrent.atomic.AtomicReference -import java.util.function.Consumer - -/** - * Interface for state management. Should accept stream and global state, as well as requests to - * flush all data-sufficient states. - */ -interface StateManager { - fun addStreamState(key: K, index: Long, stateMessage: T) - fun addGlobalState(keyIndexes: List>, stateMessage: T) - fun flushStates() -} - -/** - * Message-type agnostic streams state manager. - * - * Accepts global and stream states, and enforces that stream and global state are not mixed. - * Determines ready states by querying the StreamsManager for the state of the record index range - * associated with each state message. - * - * TODO: Force flush on a configured schedule - * - * TODO: Ensure that state is flushed at the end, and require that all state be flushed before the - * destination can succeed. - */ -abstract class StreamsStateManager() : StateManager { - private val log = KotlinLogging.logger {} - - abstract val catalog: DestinationCatalog - abstract val streamsManager: StreamsManager - abstract val outputFactory: MessageConverter - abstract val outputConsumer: Consumer - - data class GlobalState( - val streamIndexes: List>, - val stateMessage: T - ) - - private val stateIsGlobal: AtomicReference = AtomicReference(null) - private val streamStates: - ConcurrentHashMap> = - ConcurrentHashMap() - private val globalStates: ConcurrentLinkedQueue> = ConcurrentLinkedQueue() - - override fun addStreamState(key: DestinationStream, index: Long, stateMessage: T) { - if (stateIsGlobal.updateAndGet { it == true } != false) { - throw IllegalStateException("Global state cannot be mixed with non-global state") - } - - streamStates.compute(key) { _, indexToMessage -> - val map = - if (indexToMessage == null) { - // If the map doesn't exist yet, build it. - ConcurrentLinkedHashMap.Builder().maximumWeightedCapacity(1000).build() - } else { - if (indexToMessage.isNotEmpty()) { - // Make sure the messages are coming in order - val oldestIndex = indexToMessage.ascendingKeySet().first() - if (oldestIndex > index) { - throw IllegalStateException( - "State message received out of order ($oldestIndex before $index)" - ) - } - } - indexToMessage - } - // Actually add the message - map[index] = stateMessage - map - } - - log.info { "Added state for stream: $key at index: $index" } - } - - // TODO: Is it an error if we don't get all the streams every time? - override fun addGlobalState(keyIndexes: List>, stateMessage: T) { - if (stateIsGlobal.updateAndGet { it != false } != true) { - throw IllegalStateException("Global state cannot be mixed with non-global state") - } - - val head = globalStates.peek() - if (head != null) { - val keyIndexesByStream = keyIndexes.associate { it.first to it.second } - head.streamIndexes.forEach { - if (keyIndexesByStream[it.first]!! < it.second) { - throw IllegalStateException("Global state message received out of order") - } - } - } - - globalStates.add(GlobalState(keyIndexes, stateMessage)) - log.info { "Added global state with stream indexes: $keyIndexes" } - } - - override fun flushStates() { - /* - Iterate over the states in order, evicting each that passes - the persistence check. If a state is not persisted, then - we can break the loop since the states are ordered. For global - states, all streams must be persisted up to the checkpoint. - */ - when (stateIsGlobal.get()) { - null -> log.info { "No states to flush" } - true -> flushGlobalStates() - false -> flushStreamStates() - } - } - - private fun flushGlobalStates() { - while (!globalStates.isEmpty()) { - val head = globalStates.peek() - val allStreamsPersisted = - head.streamIndexes.all { (stream, index) -> - streamsManager.getManager(stream).areRecordsPersistedUntil(index) - } - if (allStreamsPersisted) { - globalStates.poll() - val outMessage = outputFactory.from(head.stateMessage) - outputConsumer.accept(outMessage) - } else { - break - } - } - } - - private fun flushStreamStates() { - for (stream in catalog.streams) { - val manager = streamsManager.getManager(stream) - val streamStates = streamStates[stream] ?: return - for (index in streamStates.keys) { - if (manager.areRecordsPersistedUntil(index)) { - val stateMessage = - streamStates.remove(index) - ?: throw IllegalStateException("State not found for index: $index") - log.info { "Flushing state for stream: $stream at index: $index" } - val outMessage = outputFactory.from(stateMessage) - outputConsumer.accept(outMessage) - } else { - break - } - } - } - } -} - -@Singleton -class DefaultStateManager( - override val catalog: DestinationCatalog, - override val streamsManager: StreamsManager, - override val outputFactory: MessageConverter, - override val outputConsumer: Consumer -) : StreamsStateManager() diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/DestinationTaskLauncher.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/DestinationTaskLauncher.kt index c713aa654db2..edec860efc76 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/DestinationTaskLauncher.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/DestinationTaskLauncher.kt @@ -5,8 +5,11 @@ package io.airbyte.cdk.task import io.airbyte.cdk.command.DestinationCatalog +import io.airbyte.cdk.command.DestinationStream import io.airbyte.cdk.message.BatchEnvelope +import io.airbyte.cdk.message.CheckpointMessage import io.airbyte.cdk.message.SpooledRawMessagesLocalFile +import io.airbyte.cdk.state.CheckpointManager import io.airbyte.cdk.write.StreamLoader import io.github.oshai.kotlinlogging.KotlinLogging import io.micronaut.context.annotation.Factory @@ -24,6 +27,7 @@ import jakarta.inject.Singleton class DestinationTaskLauncher( private val catalog: DestinationCatalog, override val taskRunner: TaskRunner, + private val checkpointManager: CheckpointManager, private val setupTaskFactory: SetupTaskFactory, private val openStreamTaskFactory: OpenStreamTaskFactory, private val spillToDiskTaskFactory: SpillToDiskTaskFactory, @@ -74,6 +78,7 @@ class DestinationTaskLauncher( suspend fun startTeardownTask() { log.info { "Starting teardown task" } + checkpointManager.flushReadyCheckpointMessages() taskRunner.enqueue(teardownTaskFactory.make(this)) } } @@ -82,6 +87,7 @@ class DestinationTaskLauncher( class DestinationTaskLauncherFactory( private val catalog: DestinationCatalog, private val taskRunner: TaskRunner, + private val checkpointManager: CheckpointManager, private val setupTaskFactory: SetupTaskFactory, private val openStreamTaskFactory: OpenStreamTaskFactory, private val spillToDiskTaskFactory: SpillToDiskTaskFactory, @@ -96,6 +102,7 @@ class DestinationTaskLauncherFactory( return DestinationTaskLauncher( catalog, taskRunner, + checkpointManager, setupTaskFactory, openStreamTaskFactory, spillToDiskTaskFactory, diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/StateManagerTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/CheckpointManagerTest.kt similarity index 85% rename from airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/StateManagerTest.kt rename to airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/CheckpointManagerTest.kt index edbdaad56100..a8abba9ab742 100644 --- a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/StateManagerTest.kt +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/CheckpointManagerTest.kt @@ -30,8 +30,8 @@ import org.junit.jupiter.params.provider.ArgumentsProvider import org.junit.jupiter.params.provider.ArgumentsSource @MicronautTest(environments = ["StateManagerTest"]) -class StateManagerTest { - @Inject lateinit var stateManager: TestStateManager +class CheckpointManagerTest { + @Inject lateinit var checkpointManager: TestCheckpointManager /** * Test state messages. @@ -39,35 +39,36 @@ class StateManagerTest { * StateIn: What is passed to the manager. StateOut: What is sent from the manager to the output * consumer. */ - sealed class MockStateIn - data class MockStreamStateIn(val stream: DestinationStream, val payload: Int) : MockStateIn() - data class MockGlobalStateIn(val payload: Int) : MockStateIn() + sealed class MockCheckpointIn + data class MockStreamCheckpointIn(val stream: DestinationStream, val payload: Int) : + MockCheckpointIn() + data class MockGlobalCheckpointIn(val payload: Int) : MockCheckpointIn() - sealed class MockStateOut - data class MockStreamStateOut(val stream: DestinationStream, val payload: String) : - MockStateOut() - data class MockGlobalStateOut(val payload: String) : MockStateOut() + sealed class MockCheckpointOut + data class MockStreamCheckpointOut(val stream: DestinationStream, val payload: String) : + MockCheckpointOut() + data class MockGlobalCheckpointOut(val payload: String) : MockCheckpointOut() @Singleton - class MockStateMessageFactory : MessageConverter { - override fun from(message: MockStateIn): MockStateOut { + class MockStateMessageFactory : MessageConverter { + override fun from(message: MockCheckpointIn): MockCheckpointOut { return when (message) { - is MockStreamStateIn -> - MockStreamStateOut(message.stream, message.payload.toString()) - is MockGlobalStateIn -> MockGlobalStateOut(message.payload.toString()) + is MockStreamCheckpointIn -> + MockStreamCheckpointOut(message.stream, message.payload.toString()) + is MockGlobalCheckpointIn -> MockGlobalCheckpointOut(message.payload.toString()) } } } @Prototype - class MockOutputConsumer : Consumer { + class MockOutputConsumer : Consumer { val collectedStreamOutput = mutableMapOf>() val collectedGlobalOutput = mutableListOf() - override fun accept(t: MockStateOut) { + override fun accept(t: MockCheckpointOut) { when (t) { - is MockStreamStateOut -> + is MockStreamCheckpointOut -> collectedStreamOutput.getOrPut(t.stream) { mutableListOf() }.add(t.payload) - is MockGlobalStateOut -> collectedGlobalOutput.add(t.payload) + is MockGlobalCheckpointOut -> collectedGlobalOutput.add(t.payload) } } } @@ -137,23 +138,23 @@ class StateManagerTest { } @Prototype - class TestStateManager( + class TestCheckpointManager( @Named("mockCatalog") override val catalog: DestinationCatalog, override val streamsManager: MockStreamsManager, - override val outputFactory: MessageConverter, + override val outputFactory: MessageConverter, override val outputConsumer: MockOutputConsumer - ) : StreamsStateManager() + ) : StreamsCheckpointManager() sealed class TestEvent data class TestStreamMessage(val stream: DestinationStream, val index: Long, val message: Int) : TestEvent() { - fun toMockStateIn() = MockStreamStateIn(stream, message) + fun toMockCheckpointIn() = MockStreamCheckpointIn(stream, message) } data class TestGlobalMessage( val streamIndexes: List>, val message: Int ) : TestEvent() { - fun toMockStateIn() = MockGlobalStateIn(message) + fun toMockCheckpointIn() = MockGlobalCheckpointIn(message) } data class FlushPoint( val persistedRanges: Map>> = mapOf() @@ -168,7 +169,7 @@ class StateManagerTest { val expectedException: Class? = null ) - class StateManagerTestArgumentsProvider : ArgumentsProvider { + class CheckpointManagerTestArgumentsProvider : ArgumentsProvider { override fun provideArguments(context: ExtensionContext?): Stream { return listOf( TestCase( @@ -244,7 +245,7 @@ class StateManagerTest { expectedException = IllegalStateException::class.java ), TestCase( - name = "Global state, two messages, flush all", + name = "Global checkpoint, two messages, flush all", events = listOf( TestGlobalMessage(listOf(stream1 to 10L, stream2 to 20L), 1), @@ -260,7 +261,7 @@ class StateManagerTest { expectedGlobalOutput = listOf("1", "2") ), TestCase( - name = "Global state, two messages, range only covers the first", + name = "Global checkpoint, two messages, range only covers the first", events = listOf( TestGlobalMessage(listOf(stream1 to 10L, stream2 to 20L), 1), @@ -277,7 +278,7 @@ class StateManagerTest { ), TestCase( name = - "Global state, two messages, where the range only covers *one stream*", + "Global checkpoint, two messages, where the range only covers *one stream*", events = listOf( TestGlobalMessage(listOf(stream1 to 10L, stream2 to 20L), 1), @@ -292,7 +293,7 @@ class StateManagerTest { expectedGlobalOutput = listOf("1") ), TestCase( - name = "Global state, out of order (should fail)", + name = "Global checkpoint, out of order (should fail)", events = listOf( TestGlobalMessage(listOf(stream1 to 20L, stream2 to 30L), 2), @@ -307,7 +308,7 @@ class StateManagerTest { expectedException = IllegalStateException::class.java ), TestCase( - name = "Mixed: first stream state, then global (should fail)", + name = "Mixed: first stream checkpoint, then global (should fail)", events = listOf( TestStreamMessage(stream1, 10L, 1), @@ -322,7 +323,7 @@ class StateManagerTest { expectedException = IllegalStateException::class.java ), TestCase( - name = "Mixed: first global, then stream state (should fail)", + name = "Mixed: first global, then stream checkpoint (should fail)", events = listOf( TestGlobalMessage(listOf(stream1 to 10L, stream2 to 20L), 1), @@ -354,7 +355,7 @@ class StateManagerTest { expectedStreamOutput = mapOf() ), TestCase( - name = "Stream state, multiple flush points", + name = "Stream checkpoint, multiple flush points", events = listOf( TestStreamMessage(stream1, 10L, 1), @@ -367,7 +368,7 @@ class StateManagerTest { expectedStreamOutput = mapOf(stream1 to listOf("1", "2", "3")) ), TestCase( - name = "Global state, multiple flush points, no output", + name = "Global checkpoint, multiple flush points, no output", events = listOf( TestGlobalMessage(listOf(stream1 to 10L, stream2 to 20L), 1), @@ -384,7 +385,7 @@ class StateManagerTest { expectedGlobalOutput = listOf() ), TestCase( - name = "Global state, multiple flush points, no output until end", + name = "Global checkpoint, multiple flush points, no output until end", events = listOf( TestGlobalMessage(listOf(stream1 to 10L, stream2 to 20L), 1), @@ -412,20 +413,20 @@ class StateManagerTest { } @ParameterizedTest - @ArgumentsSource(StateManagerTestArgumentsProvider::class) - fun testAddingAndFlushingState(testCase: TestCase) { + @ArgumentsSource(CheckpointManagerTestArgumentsProvider::class) + fun testAddingAndFlushingCheckpoints(testCase: TestCase) { if (testCase.expectedException != null) { Assertions.assertThrows(testCase.expectedException) { runTestCase(testCase) } } else { runTestCase(testCase) Assertions.assertEquals( testCase.expectedStreamOutput, - stateManager.outputConsumer.collectedStreamOutput, + checkpointManager.outputConsumer.collectedStreamOutput, testCase.name ) Assertions.assertEquals( testCase.expectedGlobalOutput, - stateManager.outputConsumer.collectedGlobalOutput, + checkpointManager.outputConsumer.collectedGlobalOutput, testCase.name ) } @@ -435,16 +436,20 @@ class StateManagerTest { testCase.events.forEach { when (it) { is TestStreamMessage -> { - stateManager.addStreamState(it.stream, it.index, it.toMockStateIn()) + checkpointManager.addStreamCheckpoint( + it.stream, + it.index, + it.toMockCheckpointIn() + ) } is TestGlobalMessage -> { - stateManager.addGlobalState(it.streamIndexes, it.toMockStateIn()) + checkpointManager.addGlobalCheckpoint(it.streamIndexes, it.toMockCheckpointIn()) } is FlushPoint -> { it.persistedRanges.forEach { (stream, ranges) -> - stateManager.streamsManager.addPersistedRanges(stream, ranges) + checkpointManager.streamsManager.addPersistedRanges(stream, ranges) } - stateManager.flushStates() + checkpointManager.flushReadyCheckpointMessages() } } }