From e8402c9f14498ded2e718091d64936b5d2a71313 Mon Sep 17 00:00:00 2001 From: Johnny Schmidt Date: Fri, 4 Oct 2024 15:45:05 -0700 Subject: [PATCH] Bulk Load CDK: Checkpoint flush every 15 minutes (#46382) --- .../cdk/command/DestinationConfiguration.kt | 6 + .../io/airbyte/cdk/file/TimeProvider.kt | 25 +++ .../io/airbyte/cdk/state/CheckpointManager.kt | 158 ++++++++++++------ .../io/airbyte/cdk/state/EventConsumer.kt | 22 +++ .../io/airbyte/cdk/state/EventProducer.kt | 34 ++++ .../io/airbyte/cdk/state/FlushStrategy.kt | 30 +++- .../cdk/task/DestinationTaskLauncher.kt | 10 ++ .../task/TimedForcedCheckpointFlushTask.kt | 92 ++++++++++ .../command/MockDestinationConfiguration.kt | 22 +++ .../io/airbyte/cdk/file/MockTimeProvider.kt | 28 ++++ .../DestinationMessageQueueWriterTest.kt | 38 +---- .../cdk/state/CheckpointManagerTest.kt | 155 ++++++++++++++++- .../cdk/state/DefaultFlushStrategyTest.kt | 102 +++++++++++ .../cdk/state/MockCheckpointManager.kt | 54 ++++++ .../io/airbyte/cdk/state/SyncManagerUtils.kt | 23 +++ .../cdk/task/DestinationTaskLauncherTest.kt | 35 ++++ .../io/airbyte/cdk/task/MockTaskLauncher.kt | 5 + .../airbyte/cdk/task/SpillToDiskTaskTest.kt | 14 +- .../TimedForcedCheckpointFlushTaskTest.kt | 101 +++++++++++ 19 files changed, 855 insertions(+), 99 deletions(-) create mode 100644 airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/file/TimeProvider.kt create mode 100644 airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/EventConsumer.kt create mode 100644 airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/EventProducer.kt create mode 100644 airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/TimedForcedCheckpointFlushTask.kt create mode 100644 airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/command/MockDestinationConfiguration.kt create mode 100644 airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/file/MockTimeProvider.kt create mode 100644 airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/DefaultFlushStrategyTest.kt create mode 100644 airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/MockCheckpointManager.kt create mode 100644 airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/SyncManagerUtils.kt create mode 100644 airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/task/TimedForcedCheckpointFlushTaskTest.kt diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/command/DestinationConfiguration.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/command/DestinationConfiguration.kt index e0a0a0e4b987..c35cd0672044 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/command/DestinationConfiguration.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/command/DestinationConfiguration.kt @@ -19,6 +19,12 @@ abstract class DestinationConfiguration : Configuration { open val estimatedRecordMemoryOverheadRatio: Double = 0.1 // 0 => No overhead, 1.0 => 100% overhead + /** + * If we have not flushed state checkpoints in this amount of time, make a best-effort attempt + * to force a flush. + */ + open val maxCheckpointFlushTimeMs: Long = 15 * 60 * 1000L // 15 minutes + /** * Micronaut factory which glues [ConfigurationSpecificationSupplier] and * [DestinationConfigurationFactory] together to produce a [DestinationConfiguration] singleton. diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/file/TimeProvider.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/file/TimeProvider.kt new file mode 100644 index 000000000000..9090007b83b4 --- /dev/null +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/file/TimeProvider.kt @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.file + +import io.micronaut.context.annotation.Secondary +import jakarta.inject.Singleton + +interface TimeProvider { + fun currentTimeMillis(): Long + suspend fun delay(ms: Long) +} + +@Singleton +@Secondary +class DefaultTimeProvider : TimeProvider { + override fun currentTimeMillis(): Long { + return System.currentTimeMillis() + } + + override suspend fun delay(ms: Long) { + kotlinx.coroutines.delay(ms) + } +} diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/CheckpointManager.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/CheckpointManager.kt index 80021c42b448..db2dc62657e8 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/CheckpointManager.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/CheckpointManager.kt @@ -6,6 +6,7 @@ package io.airbyte.cdk.state import io.airbyte.cdk.command.DestinationCatalog import io.airbyte.cdk.command.DestinationStream +import io.airbyte.cdk.file.TimeProvider import io.airbyte.cdk.message.CheckpointMessage import io.airbyte.cdk.message.MessageConverter import io.airbyte.protocol.models.v0.AirbyteMessage @@ -15,18 +16,24 @@ import io.micronaut.core.util.clhm.ConcurrentLinkedHashMap import jakarta.inject.Singleton import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentLinkedQueue +import java.util.concurrent.atomic.AtomicLong import java.util.concurrent.atomic.AtomicReference import java.util.function.Consumer +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock +import kotlinx.coroutines.withContext /** * Interface for checkpoint management. Should accept stream and global checkpoints, as well as * requests to flush all data-sufficient checkpoints. */ interface CheckpointManager { - fun addStreamCheckpoint(key: K, index: Long, checkpointMessage: T) - fun addGlobalCheckpoint(keyIndexes: List>, checkpointMessage: T) + suspend fun addStreamCheckpoint(key: K, index: Long, checkpointMessage: T) + suspend fun addGlobalCheckpoint(keyIndexes: List>, checkpointMessage: T) suspend fun flushReadyCheckpointMessages() + suspend fun getLastSuccessfulFlushTimeMs(): Long + suspend fun getNextCheckpointIndexes(): Map } /** @@ -41,15 +48,17 @@ interface CheckpointManager { * TODO: Ensure that checkpoint is flushed at the end, and require that all checkpoints be flushed * before the destination can succeed. */ -abstract class StreamsCheckpointManager() : - CheckpointManager { +abstract class StreamsCheckpointManager : CheckpointManager { + private val log = KotlinLogging.logger {} private val flushLock = Mutex() + protected val lastFlushTimeMs = AtomicLong(0L) abstract val catalog: DestinationCatalog abstract val syncManager: SyncManager abstract val outputFactory: MessageConverter abstract val outputConsumer: Consumer + abstract val timeProvider: TimeProvider data class GlobalCheckpoint( val streamIndexes: List>, @@ -63,65 +72,73 @@ abstract class StreamsCheckpointManager() : private val globalCheckpoints: ConcurrentLinkedQueue> = ConcurrentLinkedQueue() - override fun addStreamCheckpoint( + override suspend fun addStreamCheckpoint( key: DestinationStream.Descriptor, index: Long, checkpointMessage: T ) { - if (checkpointsAreGlobal.updateAndGet { it == true } != false) { - throw IllegalStateException( - "Global checkpoints cannot be mixed with non-global checkpoints" - ) - } + flushLock.withLock { + if (checkpointsAreGlobal.updateAndGet { it == true } != false) { + throw IllegalStateException( + "Global checkpoints cannot be mixed with non-global checkpoints" + ) + } - streamCheckpoints.compute(key) { _, indexToMessage -> - val map = - if (indexToMessage == null) { - // If the map doesn't exist yet, build it. - ConcurrentLinkedHashMap.Builder().maximumWeightedCapacity(1000).build() - } else { - if (indexToMessage.isNotEmpty()) { - // Make sure the messages are coming in order - val oldestIndex = indexToMessage.ascendingKeySet().first() - if (oldestIndex > index) { - throw IllegalStateException( - "Checkpoint message received out of order ($oldestIndex before $index)" - ) + streamCheckpoints.compute(key) { _, indexToMessage -> + val map = + if (indexToMessage == null) { + // If the map doesn't exist yet, build it. + ConcurrentLinkedHashMap.Builder() + .maximumWeightedCapacity(1000) + .build() + } else { + if (indexToMessage.isNotEmpty()) { + // Make sure the messages are coming in order + val oldestIndex = indexToMessage.ascendingKeySet().first() + if (oldestIndex > index) { + throw IllegalStateException( + "Checkpoint message received out of order ($oldestIndex before $index)" + ) + } } + indexToMessage } - indexToMessage - } - // Actually add the message - map[index] = checkpointMessage - map - } + // Actually add the message + map[index] = checkpointMessage + map + } - log.info { "Added checkpoint for stream: $key at index: $index" } + log.info { "Added checkpoint for stream: $key at index: $index" } + } } // TODO: Is it an error if we don't get all the streams every time? - override fun addGlobalCheckpoint( + override suspend fun addGlobalCheckpoint( keyIndexes: List>, checkpointMessage: T ) { - if (checkpointsAreGlobal.updateAndGet { it != false } != true) { - throw IllegalStateException( - "Global checkpoint cannot be mixed with non-global checkpoints" - ) - } + flushLock.withLock { + if (checkpointsAreGlobal.updateAndGet { it != false } != true) { + throw IllegalStateException( + "Global checkpoint cannot be mixed with non-global checkpoints" + ) + } - val head = globalCheckpoints.peek() - if (head != null) { - val keyIndexesByStream = keyIndexes.associate { it.first to it.second } - head.streamIndexes.forEach { - if (keyIndexesByStream[it.first]!! < it.second) { - throw IllegalStateException("Global checkpoint message received out of order") + val head = globalCheckpoints.peek() + if (head != null) { + val keyIndexesByStream = keyIndexes.associate { it.first to it.second } + head.streamIndexes.forEach { + if (keyIndexesByStream[it.first]!! < it.second) { + throw IllegalStateException( + "Global checkpoint message received out of order" + ) + } } } - } - globalCheckpoints.add(GlobalCheckpoint(keyIndexes, checkpointMessage)) - log.info { "Added global checkpoint with stream indexes: $keyIndexes" } + globalCheckpoints.add(GlobalCheckpoint(keyIndexes, checkpointMessage)) + log.info { "Added global checkpoint with stream indexes: $keyIndexes" } + } } override suspend fun flushReadyCheckpointMessages() { @@ -146,7 +163,7 @@ abstract class StreamsCheckpointManager() : } } - private fun flushGlobalCheckpoints() { + private suspend fun flushGlobalCheckpoints() { while (!globalCheckpoints.isEmpty()) { val head = globalCheckpoints.peek() val allStreamsPersisted = @@ -155,15 +172,14 @@ abstract class StreamsCheckpointManager() : } if (allStreamsPersisted) { globalCheckpoints.poll() - val outMessage = outputFactory.from(head.checkpointMessage) - outputConsumer.accept(outMessage) + sendMessage(head.checkpointMessage) } else { break } } } - private fun flushStreamCheckpoints() { + private suspend fun flushStreamCheckpoints() { for (stream in catalog.streams) { val manager = syncManager.getStreamManager(stream.descriptor) val streamCheckpoints = streamCheckpoints[stream.descriptor] ?: return @@ -173,14 +189,45 @@ abstract class StreamsCheckpointManager() : streamCheckpoints.remove(index) ?: throw IllegalStateException("Checkpoint not found for index: $index") log.info { "Flushing checkpoint for stream: $stream at index: $index" } - val outMessage = outputFactory.from(checkpointMessage) - outputConsumer.accept(outMessage) + sendMessage(checkpointMessage) } else { break } } } } + + private suspend fun sendMessage(checkpointMessage: T) = + withContext(Dispatchers.IO) { + lastFlushTimeMs.set(timeProvider.currentTimeMillis()) + val outMessage = outputFactory.from(checkpointMessage) + outputConsumer.accept(outMessage) + } + + override suspend fun getLastSuccessfulFlushTimeMs(): Long = + // Return inside the lock to ensure the value reflects flushes in progress + flushLock.withLock { lastFlushTimeMs.get() } + + override suspend fun getNextCheckpointIndexes(): Map { + flushLock.withLock { + return when (checkpointsAreGlobal.get()) { + null -> { + emptyMap() + } + true -> { + val head = globalCheckpoints.peek() + head?.streamIndexes?.associate { it } ?: emptyMap() + } + false -> { + println("streamCheckpoints: $streamCheckpoints") + streamCheckpoints + .mapValues { it.value.ascendingKeySet().firstOrNull() } + .filterValues { it != null } + .mapValues { it.value!! } + } + } + } + } } @Singleton @@ -189,5 +236,10 @@ class DefaultCheckpointManager( override val catalog: DestinationCatalog, override val syncManager: SyncManager, override val outputFactory: MessageConverter, - override val outputConsumer: Consumer -) : StreamsCheckpointManager() + override val outputConsumer: Consumer, + override val timeProvider: TimeProvider +) : StreamsCheckpointManager() { + init { + lastFlushTimeMs.set(timeProvider.currentTimeMillis()) + } +} diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/EventConsumer.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/EventConsumer.kt new file mode 100644 index 000000000000..ee36f3180482 --- /dev/null +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/EventConsumer.kt @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.state + +/** + * A multi-reader consumer of events produced by a single-writer [EventProducer]. + * + * To use: + * - set up an [EventProducer] with the same type parameter as described in the producer's + * documentation + * - declare a subclass of [EventConsumer] and mark it `@Prototype` (multi-reader) + * - inject the producer and consumers where needed + */ +abstract class EventConsumer(producer: EventProducer) { + val channel = producer.subscribe() + + suspend fun consumeMaybe(): T? { + return channel.tryReceive().getOrNull() + } +} diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/EventProducer.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/EventProducer.kt new file mode 100644 index 000000000000..5e41df959eee --- /dev/null +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/EventProducer.kt @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.state + +import java.util.concurrent.ConcurrentLinkedQueue +import kotlinx.coroutines.channels.Channel + +/** + * A single-writer event producer for a multi-reader consumer. + * + * To use + * - declare a subclass of [EventProducer] with the type parameter of the events to produce + * - mark it `@Singleton` (single-writer!) + * - configure [EventConsumer]s as described in the consumer's documentation + * - inject the producer and consumers where needed + * + * TODO: If we need to support different paradigms (multi-writer, etc.), abstract this into an + * interface and provide abstract implementations for each type. + */ +abstract class EventProducer { + private val subscribers = ConcurrentLinkedQueue>() + + fun subscribe(): Channel { + val channel = Channel(Channel.UNLIMITED) + subscribers.add(channel) + return channel + } + + suspend fun produce(event: T) { + subscribers.forEach { it.send(event) } + } +} diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/FlushStrategy.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/FlushStrategy.kt index c4749d39284e..56ff3a09062c 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/FlushStrategy.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/FlushStrategy.kt @@ -5,10 +5,13 @@ package io.airbyte.cdk.state import com.google.common.collect.Range +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings import io.airbyte.cdk.command.DestinationConfiguration import io.airbyte.cdk.command.DestinationStream +import io.airbyte.cdk.task.ForceFlushEvent import io.micronaut.context.annotation.Secondary import jakarta.inject.Singleton +import java.util.concurrent.ConcurrentHashMap interface FlushStrategy { suspend fun shouldFlush( @@ -18,17 +21,42 @@ interface FlushStrategy { ): Boolean } +/** + * Flush whenever + * - bytes consumed >= the configured batch size + * - the current range of indexes being consumed encloses a force flush index + */ +@SuppressFBWarnings( + "NP_NONNULL_PARAM_VIOLATION", + justification = "message is guaranteed to be non-null by Kotlin's type system" +) @Singleton @Secondary class DefaultFlushStrategy( private val config: DestinationConfiguration, + private val eventConsumer: EventConsumer ) : FlushStrategy { + private val forceFlushIndexes = ConcurrentHashMap() override suspend fun shouldFlush( stream: DestinationStream, rangeRead: Range, bytesProcessed: Long ): Boolean { - return bytesProcessed >= config.recordBatchSizeBytes + if (bytesProcessed >= config.recordBatchSizeBytes) { + return true + } + + // Listen to the event stream for a new force flush index + val nextFlushIndex = eventConsumer.consumeMaybe()?.indexes?.get(stream.descriptor) + + // Always update the index if the new one is not null + return when ( + val testIndex = + forceFlushIndexes.compute(stream.descriptor) { _, v -> nextFlushIndex ?: v } + ) { + null -> false + else -> rangeRead.contains(testIndex) + } } } diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/DestinationTaskLauncher.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/DestinationTaskLauncher.kt index 3227ae171c2d..78d48e9150f8 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/DestinationTaskLauncher.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/DestinationTaskLauncher.kt @@ -38,6 +38,7 @@ interface DestinationTaskLauncher : TaskLauncher { suspend fun handleNewBatch(stream: DestinationStream, wrapped: BatchEnvelope<*>) suspend fun handleStreamClosed(stream: DestinationStream) suspend fun handleTeardownComplete() + suspend fun scheduleNextForceFlushAttempt(msFromNow: Long) } interface DestinationTaskLauncherExceptionHandler : @@ -95,6 +96,7 @@ class DefaultDestinationTaskLauncher( private val closeStreamTaskFactory: CloseStreamTaskFactory, private val teardownTaskFactory: TeardownTaskFactory, private val flushCheckpointsTaskFactory: FlushCheckpointsTaskFactory, + private val timedFlushTaskFactory: TimedForcedCheckpointFlushTaskFactory, private val exceptionHandler: TaskLauncherExceptionHandler ) : DestinationTaskLauncher { private val log = KotlinLogging.logger {} @@ -114,6 +116,8 @@ class DefaultDestinationTaskLauncher( val spillTask = spillToDiskTaskFactory.make(this, stream) enqueue(spillTask) } + val forceFlushTask = timedFlushTaskFactory.make(this) + enqueue(forceFlushTask) } /** Called when the initial destination setup completes. */ @@ -187,6 +191,12 @@ class DefaultDestinationTaskLauncher( enqueue(teardownTaskFactory.make(this)) } + /** Called when a force flush is scheduled. */ + override suspend fun scheduleNextForceFlushAttempt(msFromNow: Long) { + val task = timedFlushTaskFactory.make(this, msFromNow) + enqueue(task) + } + /** Called exactly once when all streams are closed. */ override suspend fun handleTeardownComplete() { stop() diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/TimedForcedCheckpointFlushTask.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/TimedForcedCheckpointFlushTask.kt new file mode 100644 index 000000000000..2eb14ebc712e --- /dev/null +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/TimedForcedCheckpointFlushTask.kt @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.task + +import io.airbyte.cdk.command.DestinationConfiguration +import io.airbyte.cdk.command.DestinationStream +import io.airbyte.cdk.file.TimeProvider +import io.airbyte.cdk.state.CheckpointManager +import io.airbyte.cdk.state.EventConsumer +import io.airbyte.cdk.state.EventProducer +import io.micronaut.context.annotation.Prototype +import io.micronaut.context.annotation.Secondary +import jakarta.inject.Singleton +import kotlinx.coroutines.delay + +interface TimedForcedCheckpointFlushTask : SyncTask + +class DefaultTimedForcedCheckpointFlushTask( + private val delayMs: Long, + private val cadenceMs: Long, + private val checkpointManager: CheckpointManager, + private val eventProducer: EventProducer, + private val timeProvider: TimeProvider, + private val taskLauncher: DestinationTaskLauncher +) : TimedForcedCheckpointFlushTask { + + override suspend fun execute() { + // Wait for the configured time + timeProvider.delay(delayMs) + + // Flush whatever is handy + checkpointManager.flushReadyCheckpointMessages() + + // Compare the time since the last successful flush to the configured interval + val lastFlushTimeMs = checkpointManager.getLastSuccessfulFlushTimeMs() + val nowMs = timeProvider.currentTimeMillis() + val timeSinceLastFlushMs = nowMs - lastFlushTimeMs + + if (timeSinceLastFlushMs >= cadenceMs) { + // If the max time has elapsed, emit a force flush event with provided next checkpoint + // indexes + val nextIndexes = checkpointManager.getNextCheckpointIndexes() + eventProducer.produce(ForceFlushEvent(nextIndexes)) + taskLauncher.scheduleNextForceFlushAttempt(cadenceMs) + } else { + // Otherwise schedule the next attempt to run at {time of last flush + configured + // interval} + taskLauncher.scheduleNextForceFlushAttempt(cadenceMs - timeSinceLastFlushMs) + } + } +} + +interface TimedForcedCheckpointFlushTaskFactory { + fun make( + taskLauncher: DestinationTaskLauncher, + delayMs: Long? = null + ): TimedForcedCheckpointFlushTask +} + +@Singleton +@Secondary +class DefaultTimedForcedCheckpointFlushTaskFactory( + private val config: DestinationConfiguration, + private val checkpointManager: CheckpointManager, + private val eventProducer: EventProducer, + private val timeProvider: TimeProvider +) : TimedForcedCheckpointFlushTaskFactory { + override fun make( + taskLauncher: DestinationTaskLauncher, + delayMs: Long? + ): TimedForcedCheckpointFlushTask { + return DefaultTimedForcedCheckpointFlushTask( + delayMs ?: config.maxCheckpointFlushTimeMs, + config.maxCheckpointFlushTimeMs, + checkpointManager, + eventProducer, + timeProvider, + taskLauncher + ) + } +} + +data class ForceFlushEvent(val indexes: Map) + +@Singleton @Secondary class DefaultForceFlushEventProducer : EventProducer() + +@Prototype +@Secondary +class DefaultForceFlushEventConsumer(private val eventProducer: EventProducer) : + EventConsumer(eventProducer) diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/command/MockDestinationConfiguration.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/command/MockDestinationConfiguration.kt new file mode 100644 index 000000000000..fa0704b49503 --- /dev/null +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/command/MockDestinationConfiguration.kt @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.command + +import io.micronaut.context.annotation.Primary +import io.micronaut.context.annotation.Requires +import jakarta.inject.Singleton +import java.nio.file.Path + +@Singleton +@Primary +@Requires(env = ["MockDestinationConfiguration"]) +class MockDestinationConfiguration : DestinationConfiguration() { + override val recordBatchSizeBytes: Long = 1024L + override val tmpFileDirectory: Path = Path.of("/tmp-test") + override val firstStageTmpFilePrefix: String = "spilled" + override val firstStageTmpFileSuffix: String = ".jsonl" + + override val maxCheckpointFlushTimeMs: Long = 1000L +} diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/file/MockTimeProvider.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/file/MockTimeProvider.kt new file mode 100644 index 000000000000..330ac100f2ba --- /dev/null +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/file/MockTimeProvider.kt @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.file + +import io.micronaut.context.annotation.Primary +import io.micronaut.context.annotation.Requires +import jakarta.inject.Singleton + +@Singleton +@Primary +@Requires(env = ["MockTimeProvider"]) +class MockTimeProvider : TimeProvider { + private var currentTime: Long = 0 + + override fun currentTimeMillis(): Long { + return currentTime + } + + fun setCurrentTime(currentTime: Long) { + this.currentTime = currentTime + } + + override suspend fun delay(ms: Long) { + currentTime += ms + } +} diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/message/DestinationMessageQueueWriterTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/message/DestinationMessageQueueWriterTest.kt index 09ad533422fa..e620d4a534de 100644 --- a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/message/DestinationMessageQueueWriterTest.kt +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/message/DestinationMessageQueueWriterTest.kt @@ -10,7 +10,7 @@ import io.airbyte.cdk.command.DestinationStream import io.airbyte.cdk.command.MockDestinationCatalogFactory.Companion.stream1 import io.airbyte.cdk.command.MockDestinationCatalogFactory.Companion.stream2 import io.airbyte.cdk.data.NullValue -import io.airbyte.cdk.state.CheckpointManager +import io.airbyte.cdk.state.MockCheckpointManager import io.airbyte.cdk.state.SyncManager import io.micronaut.context.annotation.Prototype import io.micronaut.context.annotation.Requires @@ -23,7 +23,12 @@ import org.junit.jupiter.api.Test @MicronautTest( rebuildContext = true, - environments = ["DestinationMessageQueueWriterTest", "MockDestinationCatalog"] + environments = + [ + "DestinationMessageQueueWriterTest", + "MockDestinationCatalog", + "MockCheckpointManager", + ] ) class DestinationMessageQueueWriterTest { @Inject lateinit var queueWriterFactory: TestDestinationMessageQueueWriterFactory @@ -88,35 +93,6 @@ class DestinationMessageQueueWriterTest { } } - @Prototype - @Requires(env = ["DestinationMessageQueueWriterTest"]) - class MockCheckpointManager : - CheckpointManager { - val streamStates = - mutableMapOf>>() - val globalStates = - mutableListOf>, CheckpointMessage>>() - - override fun addStreamCheckpoint( - key: DestinationStream.Descriptor, - index: Long, - checkpointMessage: CheckpointMessage - ) { - streamStates.getOrPut(key) { mutableListOf() }.add(index to checkpointMessage) - } - - override fun addGlobalCheckpoint( - keyIndexes: List>, - checkpointMessage: CheckpointMessage - ) { - globalStates.add(keyIndexes to checkpointMessage) - } - - override suspend fun flushReadyCheckpointMessages() { - throw NotImplementedError() - } - } - private fun makeRecord(stream: DestinationStream, record: String): DestinationRecord { return DestinationRecord( stream = stream.descriptor, diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/CheckpointManagerTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/CheckpointManagerTest.kt index b2667f392022..5ea8bf86ec68 100644 --- a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/CheckpointManagerTest.kt +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/CheckpointManagerTest.kt @@ -10,10 +10,12 @@ import io.airbyte.cdk.command.DestinationCatalog import io.airbyte.cdk.command.DestinationStream import io.airbyte.cdk.command.MockDestinationCatalogFactory.Companion.stream1 import io.airbyte.cdk.command.MockDestinationCatalogFactory.Companion.stream2 +import io.airbyte.cdk.file.TimeProvider import io.airbyte.cdk.message.Batch import io.airbyte.cdk.message.BatchEnvelope import io.airbyte.cdk.message.MessageConverter import io.airbyte.cdk.message.SimpleBatch +import io.micronaut.context.annotation.Requires import io.micronaut.test.extensions.junit5.annotation.MicronautTest import jakarta.inject.Inject import jakarta.inject.Singleton @@ -21,6 +23,7 @@ import java.util.function.Consumer import java.util.stream.Stream import kotlinx.coroutines.test.runTest import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test import org.junit.jupiter.api.extension.ExtensionContext import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.Arguments @@ -55,6 +58,7 @@ class CheckpointManagerTest { data class MockGlobalCheckpointOut(val payload: String) : MockCheckpointOut() @Singleton + @Requires(env = ["CheckpointManagerTest"]) class MockStateMessageFactory : MessageConverter { override fun from(message: MockCheckpointIn): MockCheckpointOut { return when (message) { @@ -66,6 +70,7 @@ class CheckpointManagerTest { } @Singleton + @Requires(env = ["CheckpointManagerTest"]) class MockOutputConsumer : Consumer { val collectedStreamOutput = mutableMapOf>() @@ -82,11 +87,13 @@ class CheckpointManagerTest { } @Singleton + @Requires(env = ["CheckpointManagerTest"]) class TestCheckpointManager( override val catalog: DestinationCatalog, override val syncManager: SyncManager, override val outputFactory: MessageConverter, - override val outputConsumer: MockOutputConsumer + override val outputConsumer: MockOutputConsumer, + override val timeProvider: TimeProvider ) : StreamsCheckpointManager() sealed class TestEvent @@ -469,4 +476,150 @@ class CheckpointManagerTest { } } } + + @Test + fun testGetLastFlushTimeMs() = runTest { + val startTime = System.currentTimeMillis() + checkpointManager.addStreamCheckpoint( + stream1.descriptor, + 1L, + MockStreamCheckpointIn(stream1, 1) + ) + syncManager.markPersisted(stream1, Range.closed(0L, 1L)) + Assertions.assertTrue(startTime >= checkpointManager.getLastSuccessfulFlushTimeMs()) + checkpointManager.flushReadyCheckpointMessages() + Assertions.assertTrue(startTime < checkpointManager.getLastSuccessfulFlushTimeMs()) + } + + @Test + fun testGetNextStreamCheckpoints() = runTest { + Assertions.assertEquals( + emptyMap(), + checkpointManager.getNextCheckpointIndexes() + ) + + checkpointManager.addStreamCheckpoint( + stream1.descriptor, + 1L, + MockStreamCheckpointIn(stream1, 1) + ) + Assertions.assertEquals( + mapOf(stream1.descriptor to 1L), + checkpointManager.getNextCheckpointIndexes() + ) + + checkpointManager.addStreamCheckpoint( + stream2.descriptor, + 10L, + MockStreamCheckpointIn(stream2, 10) + ) + Assertions.assertEquals( + mapOf(stream1.descriptor to 1L, stream2.descriptor to 10L), + checkpointManager.getNextCheckpointIndexes() + ) + + checkpointManager.addStreamCheckpoint( + stream1.descriptor, + 2L, + MockStreamCheckpointIn(stream1, 2) + ) + Assertions.assertEquals( + mapOf(stream1.descriptor to 1L, stream2.descriptor to 10L), + checkpointManager.getNextCheckpointIndexes(), + "only the first checkpoint is returned" + ) + + syncManager.markPersisted(stream1, Range.singleton(0)) + Assertions.assertEquals( + mapOf(stream1.descriptor to 1L, stream2.descriptor to 10L), + checkpointManager.getNextCheckpointIndexes(), + "marking persisted is not sufficient" + ) + + checkpointManager.flushReadyCheckpointMessages() + Assertions.assertEquals( + mapOf(stream1.descriptor to 2L, stream2.descriptor to 10L), + checkpointManager.getNextCheckpointIndexes(), + "flushing the first checkpoint reveals the second one" + ) + + checkpointManager.addStreamCheckpoint( + stream2.descriptor, + 20L, + MockStreamCheckpointIn(stream2, 20) + ) + checkpointManager.flushReadyCheckpointMessages() + Assertions.assertEquals( + mapOf(stream1.descriptor to 2L, stream2.descriptor to 10L), + checkpointManager.getNextCheckpointIndexes(), + "but only on the stream that was flushed" + ) + + syncManager.markPersisted(stream2, Range.closed(0L, 19L)) + checkpointManager.flushReadyCheckpointMessages() + Assertions.assertEquals( + mapOf(stream1.descriptor to 2L), + checkpointManager.getNextCheckpointIndexes(), + "flushing all the checkpoints clears the stream from the map" + ) + + syncManager.markPersisted(stream1, Range.singleton(1)) + checkpointManager.flushReadyCheckpointMessages() + Assertions.assertEquals( + emptyMap(), + checkpointManager.getNextCheckpointIndexes(), + "flushing all the checkpoints clears the map" + ) + } + + @Test + fun testGetNextGlobalCheckpoints() = runTest { + Assertions.assertEquals( + emptyMap(), + checkpointManager.getNextCheckpointIndexes() + ) + + checkpointManager.addGlobalCheckpoint( + listOf(stream1.descriptor to 1L, stream2.descriptor to 10L), + MockGlobalCheckpointIn(1) + ) + Assertions.assertEquals( + mapOf(stream1.descriptor to 1L, stream2.descriptor to 10L), + checkpointManager.getNextCheckpointIndexes() + ) + + checkpointManager.addGlobalCheckpoint( + listOf(stream1.descriptor to 2L, stream2.descriptor to 20L), + MockGlobalCheckpointIn(2) + ) + Assertions.assertEquals( + mapOf(stream1.descriptor to 1L, stream2.descriptor to 10L), + checkpointManager.getNextCheckpointIndexes(), + "only the first checkpoint is returned" + ) + + syncManager.markPersisted(stream1, Range.singleton(0)) + checkpointManager.flushReadyCheckpointMessages() + Assertions.assertEquals( + mapOf(stream1.descriptor to 1L, stream2.descriptor to 10L), + checkpointManager.getNextCheckpointIndexes(), + "if only 1 stream is persisted, neither are returned" + ) + + syncManager.markPersisted(stream2, Range.closed(0L, 19L)) + checkpointManager.flushReadyCheckpointMessages() + Assertions.assertEquals( + mapOf(stream1.descriptor to 2L, stream2.descriptor to 20L), + checkpointManager.getNextCheckpointIndexes(), + "persisting the second stream triggers both to flush, revealing the next pair" + ) + + syncManager.markPersisted(stream1, Range.singleton(1)) + checkpointManager.flushReadyCheckpointMessages() + Assertions.assertEquals( + emptyMap(), + checkpointManager.getNextCheckpointIndexes(), + "flushing all the checkpoints clears the map" + ) + } } diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/DefaultFlushStrategyTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/DefaultFlushStrategyTest.kt new file mode 100644 index 000000000000..58313422c039 --- /dev/null +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/DefaultFlushStrategyTest.kt @@ -0,0 +1,102 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.state + +import com.google.common.collect.Range +import io.airbyte.cdk.command.DestinationConfiguration +import io.airbyte.cdk.command.MockDestinationCatalogFactory +import io.airbyte.cdk.command.MockDestinationCatalogFactory.Companion.stream2 +import io.airbyte.cdk.task.ForceFlushEvent +import io.micronaut.context.annotation.Primary +import io.micronaut.context.annotation.Requires +import io.micronaut.test.extensions.junit5.annotation.MicronautTest +import jakarta.inject.Singleton +import kotlinx.coroutines.test.runTest +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test + +@MicronautTest( + environments = + [ + "FlushStrategyTest", + "MockDestinationConfiguration", + ] +) +class DefaultFlushStrategyTest { + val stream1 = MockDestinationCatalogFactory.stream1 + + @Singleton + @Primary + @Requires(env = ["FlushStrategyTest"]) + class MockForceFlushEventProducer : EventProducer() + + @Test + fun testFlushByByteSize(flushStrategy: DefaultFlushStrategy, config: DestinationConfiguration) = + runTest { + Assertions.assertFalse( + flushStrategy.shouldFlush(stream1, Range.all(), config.recordBatchSizeBytes - 1L) + ) + Assertions.assertTrue( + flushStrategy.shouldFlush(stream1, Range.all(), config.recordBatchSizeBytes) + ) + Assertions.assertTrue( + flushStrategy.shouldFlush(stream1, Range.all(), config.recordBatchSizeBytes * 1000L) + ) + } + + @Test + fun testFlushByIndex( + flushStrategy: DefaultFlushStrategy, + config: DestinationConfiguration, + forceFlushEventProducer: MockForceFlushEventProducer + ) = runTest { + // Ensure the size trigger is not a factor + val insufficientSize = config.recordBatchSizeBytes - 1L + + Assertions.assertFalse( + flushStrategy.shouldFlush(stream1, Range.all(), insufficientSize), + "Should not flush even with whole range if no event" + ) + + forceFlushEventProducer.produce(ForceFlushEvent(mapOf(stream1.descriptor to 42L))) + Assertions.assertFalse( + flushStrategy.shouldFlush(stream1, Range.closed(0, 41), insufficientSize), + "Should not flush if index is not in range" + ) + Assertions.assertTrue( + flushStrategy.shouldFlush(stream1, Range.closed(0, 42), insufficientSize), + "Should flush if index is in range" + ) + + Assertions.assertFalse( + flushStrategy.shouldFlush(stream2, Range.closed(0, 42), insufficientSize), + "Should not flush other streams" + ) + forceFlushEventProducer.produce(ForceFlushEvent(mapOf(stream2.descriptor to 200L))) + Assertions.assertTrue( + flushStrategy.shouldFlush(stream2, Range.closed(0, 200), insufficientSize), + "(Unless they also have flush points)" + ) + + Assertions.assertTrue( + flushStrategy.shouldFlush(stream1, Range.closed(42, 100), insufficientSize), + "Should flush even if barely in range" + ) + Assertions.assertFalse( + flushStrategy.shouldFlush(stream1, Range.closed(43, 100), insufficientSize), + "Should not flush if index has been passed" + ) + + forceFlushEventProducer.produce(ForceFlushEvent(mapOf(stream1.descriptor to 100L))) + Assertions.assertFalse( + flushStrategy.shouldFlush(stream1, Range.closed(0, 42), insufficientSize), + "New events indexes should invalidate old ones" + ) + Assertions.assertTrue( + flushStrategy.shouldFlush(stream1, Range.closed(43, 100), insufficientSize), + "New event indexes should be honored" + ) + } +} diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/MockCheckpointManager.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/MockCheckpointManager.kt new file mode 100644 index 000000000000..6808061a82b4 --- /dev/null +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/MockCheckpointManager.kt @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.state + +import io.airbyte.cdk.command.DestinationStream +import io.airbyte.cdk.file.TimeProvider +import io.airbyte.cdk.message.CheckpointMessage +import io.micronaut.context.annotation.Requires +import jakarta.inject.Inject +import jakarta.inject.Singleton + +@Singleton +@Requires(env = ["MockCheckpointManager"]) +class MockCheckpointManager : CheckpointManager { + @Inject lateinit var timeProvider: TimeProvider + + val streamStates = + mutableMapOf>>() + val globalStates = + mutableListOf>, CheckpointMessage>>() + + val flushedAtMs = mutableListOf() + var mockCheckpointIndexes = mutableMapOf() + var mockLastFlushTimeMs = 0L + + override suspend fun addStreamCheckpoint( + key: DestinationStream.Descriptor, + index: Long, + checkpointMessage: CheckpointMessage + ) { + streamStates.getOrPut(key) { mutableListOf() }.add(index to checkpointMessage) + } + + override suspend fun addGlobalCheckpoint( + keyIndexes: List>, + checkpointMessage: CheckpointMessage + ) { + globalStates.add(keyIndexes to checkpointMessage) + } + + override suspend fun flushReadyCheckpointMessages() { + flushedAtMs.add(timeProvider.currentTimeMillis()) + } + + override suspend fun getLastSuccessfulFlushTimeMs(): Long { + return mockLastFlushTimeMs + } + + override suspend fun getNextCheckpointIndexes(): Map { + return mockCheckpointIndexes + } +} diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/SyncManagerUtils.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/SyncManagerUtils.kt new file mode 100644 index 000000000000..ea57af85f963 --- /dev/null +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/SyncManagerUtils.kt @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.state + +import com.google.common.collect.Range +import io.airbyte.cdk.command.DestinationStream +import io.airbyte.cdk.message.Batch +import io.airbyte.cdk.message.BatchEnvelope +import io.airbyte.cdk.message.SimpleBatch + +/** + * Because [SyncManager] and [StreamManager] have thin interfaces with no side effects, mocking them + * is overkill (the mock implementation converges with the real one). Instead, we provide + * convenience extension functions to simplify mocking state for testing. + * + * TODO: add more of these and apply them throughout the tests to simplify the code. + */ +fun SyncManager.markPersisted(stream: DestinationStream, range: Range) { + this.getStreamManager(stream.descriptor) + .updateBatchState(BatchEnvelope(SimpleBatch(Batch.State.PERSISTED), range)) +} diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/task/DestinationTaskLauncherTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/task/DestinationTaskLauncherTest.kt index be850fbc6186..ed6f6097269c 100644 --- a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/task/DestinationTaskLauncherTest.kt +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/task/DestinationTaskLauncherTest.kt @@ -36,6 +36,7 @@ import org.junit.jupiter.api.Test environments = [ "DestinationTaskLauncherTest", + "MockDestinationConfiguration", "MockDestinationCatalog", ] ) @@ -53,6 +54,7 @@ class DestinationTaskLauncherTest { @Inject lateinit var closeStreamTaskFactory: MockCloseStreamTaskFactory @Inject lateinit var teardownTaskFactory: MockTeardownTaskFactory @Inject lateinit var flushCheckpointsTaskFactory: MockFlushCheckpointsTaskFactory + @Inject lateinit var forceFlushTaskFactory: MockForceFlushTaskFactory @Singleton @Replaces(DefaultSetupTaskFactory::class) @@ -206,6 +208,24 @@ class DestinationTaskLauncherTest { } } + @Singleton + @Primary + @Requires(env = ["DestinationTaskLauncherTest"]) + class MockForceFlushTaskFactory : TimedForcedCheckpointFlushTaskFactory { + val ranWithDelay = Channel(Channel.UNLIMITED) + + override fun make( + taskLauncher: DestinationTaskLauncher, + delayMs: Long? + ): TimedForcedCheckpointFlushTask { + return object : TimedForcedCheckpointFlushTask { + override suspend fun execute() { + ranWithDelay.send(delayMs) + } + } + } + } + class MockBatch(override val state: Batch.State) : Batch @Singleton @@ -232,6 +252,9 @@ class DestinationTaskLauncherTest { // Verify that spill to disk ran for each stream mockSpillToDiskTaskFactory.streamHasRun.values.forEach { it.receive() } + // Verify that we kicked off the timed force flush w/o a specific delay + Assertions.assertNull(forceFlushTaskFactory.ranWithDelay.receive()) + // Collect the tasks wrapped by the exception handler: expect one Setup and [nStreams] // SpillToDisk mockExceptionHandler.wrappedTasks.close() @@ -321,6 +344,7 @@ class DestinationTaskLauncherTest { val incompleteBatch = BatchEnvelope(MockBatch(Batch.State.LOCAL), range) taskLauncher.handleNewBatch(stream1, incompleteBatch) Assertions.assertFalse(streamManager.areRecordsPersistedUntil(100L)) + val batchReceived = processBatchTaskFactory.hasRun.receive() Assertions.assertEquals(incompleteBatch, batchReceived) delay(500) @@ -358,4 +382,15 @@ class DestinationTaskLauncherTest { taskLauncher.stop() } + + @Test + fun testHandleScheduleForceFlush() = runTest { + launch { taskRunner.run() } + + // This should run force flush task with delay. + taskLauncher.scheduleNextForceFlushAttempt(1000) + Assertions.assertEquals(1000, forceFlushTaskFactory.ranWithDelay.receive()) + + taskLauncher.stop() + } } diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/task/MockTaskLauncher.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/task/MockTaskLauncher.kt index 220bde8e4389..d6b89cdf34fc 100644 --- a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/task/MockTaskLauncher.kt +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/task/MockTaskLauncher.kt @@ -17,6 +17,7 @@ import jakarta.inject.Singleton class MockTaskLauncher(override val taskRunner: TaskRunner) : DestinationTaskLauncher { val spilledFiles = mutableListOf>() val batchEnvelopes = mutableListOf>() + val scheduledForcedFlushes = mutableListOf() override suspend fun handleSetupComplete() { throw NotImplementedError() @@ -49,4 +50,8 @@ class MockTaskLauncher(override val taskRunner: TaskRunner) : DestinationTaskLau override suspend fun start() { throw NotImplementedError() } + + override suspend fun scheduleNextForceFlushAttempt(msFromNow: Long) { + scheduledForcedFlushes.add(msFromNow) + } } diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/task/SpillToDiskTaskTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/task/SpillToDiskTaskTest.kt index 51f2ef88a6cd..ccf8cd784dee 100644 --- a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/task/SpillToDiskTaskTest.kt +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/task/SpillToDiskTaskTest.kt @@ -5,7 +5,6 @@ package io.airbyte.cdk.task import com.google.common.collect.Range -import io.airbyte.cdk.command.DestinationConfiguration import io.airbyte.cdk.command.DestinationStream import io.airbyte.cdk.command.MockDestinationCatalogFactory.Companion.stream1 import io.airbyte.cdk.data.NullValue @@ -21,7 +20,6 @@ import io.micronaut.context.annotation.Requires import io.micronaut.test.extensions.junit5.annotation.MicronautTest import jakarta.inject.Inject import jakarta.inject.Singleton -import java.nio.file.Path import java.util.concurrent.atomic.AtomicLong import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.flow @@ -33,6 +31,7 @@ import org.junit.jupiter.api.Test environments = [ "SpillToDiskTaskTest", + "MockDestinationConfiguration", "MockTempFileProvider", "MockTaskLauncher", ] @@ -42,16 +41,6 @@ class SpillToDiskTaskTest { @Inject lateinit var spillToDiskTaskFactory: DefaultSpillToDiskTaskFactory @Inject lateinit var mockTempFileProvider: MockTempFileProvider - @Singleton - @Primary - @Requires(env = ["SpillToDiskTaskTest"]) - class MockWriteConfiguration : DestinationConfiguration() { - override val recordBatchSizeBytes: Long = 1024L - override val tmpFileDirectory: Path = Path.of("/tmp-test") - override val firstStageTmpFilePrefix: String = "spilled" - override val firstStageTmpFileSuffix: String = ".jsonl" - } - @Singleton @Requires(env = ["SpillToDiskTaskTest"]) class MockQueueReader : @@ -92,7 +81,6 @@ class SpillToDiskTaskTest { rangeRead: Range, bytesProcessed: Long ): Boolean { - println(bytesProcessed) return bytesProcessed >= 1024 } } diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/task/TimedForcedCheckpointFlushTaskTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/task/TimedForcedCheckpointFlushTaskTest.kt new file mode 100644 index 000000000000..bd729b000698 --- /dev/null +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/task/TimedForcedCheckpointFlushTaskTest.kt @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.task + +import io.airbyte.cdk.command.DestinationConfiguration +import io.airbyte.cdk.command.DestinationStream +import io.airbyte.cdk.file.MockTimeProvider +import io.airbyte.cdk.state.EventConsumer +import io.airbyte.cdk.state.MockCheckpointManager +import io.micronaut.test.extensions.junit5.annotation.MicronautTest +import jakarta.inject.Inject +import kotlinx.coroutines.test.runTest +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test + +@MicronautTest( + rebuildContext = true, + environments = + [ + "TimedForcedCheckpointFlushTaskTest", + "MockDestinationConfiguration", + "MockCheckpointManager", + "MockTaskLauncher", + "MockTimeProvider" + ] +) +class TimedForcedCheckpointFlushTaskTest { + @Inject lateinit var flushTaskFactory: DefaultTimedForcedCheckpointFlushTaskFactory + @Inject lateinit var taskLauncher: MockTaskLauncher + @Inject lateinit var timeProvider: MockTimeProvider + @Inject lateinit var checkpointManager: MockCheckpointManager + @Inject lateinit var config: DestinationConfiguration + @Inject lateinit var eventConsumer: EventConsumer + + @Test + fun testTaskWillNotFlushIfTimeNotElapsed() = runTest { + val delayMs = 100L + val task = flushTaskFactory.make(taskLauncher, delayMs) + timeProvider.setCurrentTime(0L) + val mockLastFlushTime = delayMs + config.maxCheckpointFlushTimeMs - 1L + checkpointManager.mockLastFlushTimeMs = mockLastFlushTime + task.execute() + Assertions.assertEquals( + delayMs, + timeProvider.currentTimeMillis(), + "task delayed the specified time" + ) + Assertions.assertEquals( + mutableListOf(delayMs), + checkpointManager.flushedAtMs, + "task tried to flush" + ) + Assertions.assertNull( + eventConsumer.consumeMaybe(), + "task did not produce a force flush event" + ) + val mockTimeSinceLastFlush = timeProvider.currentTimeMillis() - mockLastFlushTime + val nextRun = config.maxCheckpointFlushTimeMs - mockTimeSinceLastFlush + Assertions.assertEquals( + listOf(nextRun), + taskLauncher.scheduledForcedFlushes, + "task scheduled next flush for remaining interval" + ) + } + + @Test + fun testTaskWillFlushIfTimeElapsed() = runTest { + val delayMs = + config.maxCheckpointFlushTimeMs // task uses flush interval as delay by default + val task = flushTaskFactory.make(taskLauncher) + timeProvider.setCurrentTime(0L) + checkpointManager.mockLastFlushTimeMs = 0L + val expectedMap = + mutableMapOf(DestinationStream.Descriptor(name = "test", namespace = "testing") to 999L) + checkpointManager.mockCheckpointIndexes = expectedMap + task.execute() + Assertions.assertEquals( + delayMs, + timeProvider.currentTimeMillis(), + "task delayed for the configured interval" + ) + Assertions.assertEquals( + listOf(delayMs), + checkpointManager.flushedAtMs, + "task tried to flush" + ) + val flushEvent = eventConsumer.consumeMaybe() + Assertions.assertEquals( + expectedMap, + flushEvent?.indexes, + "task produced a force flush event with indexes provided by the checkpoint manager" + ) + Assertions.assertEquals( + listOf(config.maxCheckpointFlushTimeMs), + taskLauncher.scheduledForcedFlushes, + "task scheduled next flush for full interval" + ) + } +}