Skip to content

Commit

Permalink
Bulk Load CDK: Add a dedicated flush checkpoints task (#46318)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnny-schmidt authored Oct 3, 2024
1 parent b38057b commit d8f6799
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@ interface Batch {
COMPLETE
}

fun isPersisted(): Boolean =
when (state) {
State.PERSISTED,
State.COMPLETE -> true
else -> false
}

val state: State
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicReference
import java.util.function.Consumer
import kotlinx.coroutines.sync.Mutex

/**
* Interface for checkpoint management. Should accept stream and global checkpoints, as well as
Expand All @@ -43,6 +44,7 @@ interface CheckpointManager<K, T> {
abstract class StreamsCheckpointManager<T, U>() :
CheckpointManager<DestinationStream.Descriptor, T> {
private val log = KotlinLogging.logger {}
private val flushLock = Mutex()

abstract val catalog: DestinationCatalog
abstract val syncManager: SyncManager
Expand Down Expand Up @@ -123,16 +125,24 @@ abstract class StreamsCheckpointManager<T, U>() :
}

override suspend fun flushReadyCheckpointMessages() {
if (!flushLock.tryLock()) {
log.info { "Flush already in progress, skipping" }
return
}
/*
Iterate over the checkpoints in order, evicting each that passes
the persistence check. If a checkpoint is not persisted, then
we can break the loop since the checkpoints are ordered. For global
checkpoints, all streams must be persisted up to the checkpoint.
*/
when (checkpointsAreGlobal.get()) {
null -> log.info { "No checkpoints to flush" }
true -> flushGlobalCheckpoints()
false -> flushStreamCheckpoints()
try {
when (checkpointsAreGlobal.get()) {
null -> log.info { "No checkpoints to flush" }
true -> flushGlobalCheckpoints()
false -> flushStreamCheckpoints()
}
} finally {
flushLock.unlock()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ class DefaultStreamManager(
/** True if all records in `[0, index)` have reached the given state. */
private fun isProcessingCompleteForState(index: Long, state: Batch.State): Boolean {
val completeRanges = rangesState[state]!!

return completeRanges.encloses(Range.closedOpen(0L, index))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class DefaultDestinationTaskLauncher(
private val processBatchTaskFactory: ProcessBatchTaskFactory,
private val closeStreamTaskFactory: CloseStreamTaskFactory,
private val teardownTaskFactory: TeardownTaskFactory,
private val flushCheckpointsTaskFactory: FlushCheckpointsTaskFactory,
private val exceptionHandler: TaskLauncherExceptionHandler<DestinationWriteTask>
) : DestinationTaskLauncher {
private val log = KotlinLogging.logger {}
Expand Down Expand Up @@ -155,6 +156,10 @@ class DefaultDestinationTaskLauncher(
val streamManager = syncManager.getStreamManager(stream.descriptor)
streamManager.updateBatchState(wrapped)

if (wrapped.batch.isPersisted()) {
enqueue(flushCheckpointsTaskFactory.make())
}

if (wrapped.batch.state != Batch.State.COMPLETE) {
log.info {
"Batch not complete: Starting process batch task for ${stream.descriptor}, batch $wrapped"
Expand Down Expand Up @@ -223,6 +228,7 @@ class DefaultDestinationTaskLauncherExceptionHandler(
)
}
log.info { "Sync terminated, skipping task $innerTask." }

return
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.cdk.task

import io.airbyte.cdk.state.CheckpointManager
import io.micronaut.context.annotation.Secondary
import jakarta.inject.Singleton

interface FlushCheckpointsTask : SyncTask

class DefaultFlushCheckpointsTask(
private val checkpointManager: CheckpointManager<*, *>,
) : FlushCheckpointsTask {
override suspend fun execute() {
checkpointManager.flushReadyCheckpointMessages()
}
}

interface FlushCheckpointsTaskFactory {
fun make(): FlushCheckpointsTask
}

@Singleton
@Secondary
class DefaultFlushCheckpointsTaskFactory(
private val checkpointManager: CheckpointManager<*, *>,
) : FlushCheckpointsTaskFactory {
override fun make(): FlushCheckpointsTask {
return DefaultFlushCheckpointsTask(checkpointManager)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import io.airbyte.cdk.message.Batch
import io.airbyte.cdk.message.BatchEnvelope
import io.airbyte.cdk.message.SpilledRawMessagesLocalFile
import io.airbyte.cdk.state.SyncManager
import io.micronaut.context.annotation.Primary
import io.micronaut.context.annotation.Replaces
import io.micronaut.context.annotation.Requires
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
Expand Down Expand Up @@ -51,6 +52,7 @@ class DestinationTaskLauncherTest {
@Inject lateinit var processBatchTaskFactory: MockProcessBatchTaskFactory
@Inject lateinit var closeStreamTaskFactory: MockCloseStreamTaskFactory
@Inject lateinit var teardownTaskFactory: MockTeardownTaskFactory
@Inject lateinit var flushCheckpointsTaskFactory: MockFlushCheckpointsTaskFactory

@Singleton
@Replaces(DefaultSetupTaskFactory::class)
Expand Down Expand Up @@ -189,6 +191,21 @@ class DestinationTaskLauncherTest {
}
}

@Singleton
@Primary
@Requires(env = ["DestinationTaskLauncherTest"])
class MockFlushCheckpointsTaskFactory : FlushCheckpointsTaskFactory {
val hasRun: Channel<Boolean> = Channel(Channel.UNLIMITED)

override fun make(): FlushCheckpointsTask {
return object : FlushCheckpointsTask {
override suspend fun execute() {
hasRun.send(true)
}
}
}
}

class MockBatch(override val state: Batch.State) : Batch

@Singleton
Expand Down Expand Up @@ -297,6 +314,7 @@ class DestinationTaskLauncherTest {
val range = TreeRangeSet.create(listOf(Range.closed(0L, 100L)))
val streamManager = syncManager.getStreamManager(stream1.descriptor)
repeat(100) { streamManager.countRecordIn() }

streamManager.markEndOfStream()

// Verify incomplete batch triggers process batch
Expand All @@ -305,6 +323,7 @@ class DestinationTaskLauncherTest {
Assertions.assertTrue(streamManager.areRecordsPersistedUntil(100L))
val batchReceived = processBatchTaskFactory.hasRun.receive()
Assertions.assertEquals(incompleteBatch, batchReceived)
Assertions.assertTrue(flushCheckpointsTaskFactory.hasRun.receive())

// Verify complete batch w/o batch processing complete does nothing
val halfRange = TreeRangeSet.create(listOf(Range.closed(0L, 50L)))
Expand Down

0 comments on commit d8f6799

Please sign in to comment.