Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unit tests for streams manager #45090

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,12 @@ data class DestinationCatalog(
}
}

/**
 * Factory abstraction for producing the destination's [DestinationCatalog].
 * Allows the default implementation to be replaced (e.g. by a mock catalog
 * in test environments) via dependency injection.
 */
interface DestinationCatalogFactory {
/** Builds and returns the catalog of streams this destination will write. */
fun make(): DestinationCatalog
}

@Factory
class DestinationCatalogFactory(
class DefaultDestinationCatalogFactory(
private val catalog: ConfiguredAirbyteCatalog,
private val streamFactory: DestinationStreamFactory
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,12 @@ class DestinationMessageQueueWriter(
/* If the input message represents a record. */
is DestinationRecordMessage -> {
val manager = streamsManager.getManager(message.stream)
val index = manager.countRecordIn(sizeBytes)
when (message) {
/* If a data record */
is DestinationRecord -> {
val wrapped =
StreamRecordWrapped(
index = index,
index = manager.countRecordIn(),
sizeBytes = sizeBytes,
record = message
)
Expand All @@ -58,7 +57,7 @@ class DestinationMessageQueueWriter(

/* If an end-of-stream marker. */
is DestinationStreamComplete -> {
val wrapped = StreamCompleteWrapped(index)
val wrapped = StreamCompleteWrapped(index = manager.countEndOfStream())
messageQueue.getChannel(message.stream).send(wrapped)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@ import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.context.annotation.Factory
import jakarta.inject.Singleton
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.CountDownLatch
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicLong
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import kotlinx.coroutines.channels.Channel

/** Manages the state of all streams in the destination. */
interface StreamsManager {
/** Get the manager for the given stream. Throws an exception if the stream is not found. */
fun getManager(stream: DestinationStream): StreamManager
suspend fun awaitAllStreamsComplete()

/** Suspend until all streams are closed. */
suspend fun awaitAllStreamsClosed()
}

class DefaultStreamsManager(
Expand All @@ -33,68 +35,98 @@ class DefaultStreamsManager(
return streamManagers[stream] ?: throw IllegalArgumentException("Stream not found: $stream")
}

override suspend fun awaitAllStreamsComplete() {
override suspend fun awaitAllStreamsClosed() {
streamManagers.forEach { (_, manager) -> manager.awaitStreamClosed() }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: error handling

}
}

/** Manages the state of a single stream. */
interface StreamManager {
fun countRecordIn(sizeBytes: Long): Long
/** Count incoming record and return the record's *index*. */
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks to me like most of the functions here are setting state on the stream itself. Maybe this is just part of a Stream?

fun countRecordIn(): Long

/**
* Count the end-of-stream. Expect this exactly once. Expect no further `countRecordIn`, and
* expect that `markClosed` will always occur after this.
*/
fun countEndOfStream(): Long
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's a weird name. I'm not really sure what "counting the end of stream" means. Total number of records in the stream?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand when reading further.
I think the notion of "counting" is off here. I don't expect count to increment a counter. We can leave it as is for now


/**
* Mark a checkpoint in the stream and return the current index and the number of records since
* the last one.
*
* NOTE: Single-writer. If in the future multiple threads set checkpoints, this method should be
* synchronized.
*/
fun markCheckpoint(): Pair<Long, Long>

/** Record that the given batch's state has been reached for the associated range(s). */
fun <B : Batch> updateBatchState(batch: BatchEnvelope<B>)

/**
* True if all are true:
* * all records have been seen (ie, we've counted an end-of-stream)
* * a [Batch.State.COMPLETE] batch range has been seen covering every record
*
* Does NOT require that the stream be closed.
*/
fun isBatchProcessingComplete(): Boolean

/**
* True if all records in [0, index] have at least reached [Batch.State.PERSISTED]. This is
* implicitly true if they have all reached [Batch.State.COMPLETE].
*/
fun areRecordsPersistedUntil(index: Long): Boolean

/** Mark the stream as closed. This should only be called after all records have been read. */
fun markClosed()

/** True if the stream has been marked as closed. */
fun streamIsClosed(): Boolean

/** Suspend until the stream is closed. */
suspend fun awaitStreamClosed()
}

/**
* Maintains a map of stream -> status metadata, and a map of batch state -> record ranges for which
* that state has been reached.
*
* TODO: Log a detailed report of the stream status on a regular cadence.
*/
class DefaultStreamManager(
val stream: DestinationStream,
) : StreamManager {
private val log = KotlinLogging.logger {}

data class StreamStatus(
val recordCount: AtomicLong = AtomicLong(0),
val totalBytes: AtomicLong = AtomicLong(0),
val enqueuedSize: AtomicLong = AtomicLong(0),
val lastCheckpoint: AtomicLong = AtomicLong(0L),
val closedLatch: CountDownLatch = CountDownLatch(1),
)
private val recordCount = AtomicLong(0)
private val lastCheckpoint = AtomicLong(0L)
private val readIsClosed = AtomicBoolean(false)
private val streamIsClosed = AtomicBoolean(false)
private val closedLock = Channel<Unit>()

private val streamStatus: StreamStatus = StreamStatus()
private val rangesState: ConcurrentHashMap<Batch.State, RangeSet<Long>> = ConcurrentHashMap()

init {
Batch.State.entries.forEach { rangesState[it] = TreeRangeSet.create() }
}

override fun countRecordIn(sizeBytes: Long): Long {
val index = streamStatus.recordCount.getAndIncrement()
streamStatus.totalBytes.addAndGet(sizeBytes)
streamStatus.enqueuedSize.addAndGet(sizeBytes)
return index
/**
 * Registers one incoming record and returns its zero-based index within the
 * stream. Must not be called after the end-of-stream has been counted.
 */
override fun countRecordIn(): Long {
    // Guard the single-reader invariant: no records may arrive once the
    // read side has been closed by countEndOfStream().
    check(!readIsClosed.get()) { "Stream is closed for reading" }
    return recordCount.getAndIncrement()
}

/**
 * Marks the read side of the stream as closed and returns the total number of
 * records seen. Expected exactly once; a second call fails the check below.
 */
override fun countEndOfStream(): Long {
    // Atomically flip the closed flag; a prior true value means end-of-stream
    // was already counted, which is a protocol violation.
    check(!readIsClosed.getAndSet(true)) { "Stream is closed for reading" }
    return recordCount.get()
}

/**
* Mark a checkpoint in the stream and return the current index and the number of records since
* the last one.
*/
override fun markCheckpoint(): Pair<Long, Long> {
val index = streamStatus.recordCount.get()
val lastCheckpoint = streamStatus.lastCheckpoint.getAndSet(index)
val index = recordCount.get()
val lastCheckpoint = lastCheckpoint.getAndSet(index)
return Pair(index, index - lastCheckpoint)
}

/** Record that the given batch's state has been reached for the associated range(s). */
override fun <B : Batch> updateBatchState(batch: BatchEnvelope<B>) {
val stateRanges =
rangesState[batch.batch.state]
Expand All @@ -112,37 +144,44 @@ class DefaultStreamManager(
log.info { "Updated ranges for $stream[${batch.batch.state}]: $stateRanges" }
}

/** True if all records in [0, index] have reached the given state. */
/** True if all records in `[0, index)` have reached the given state. */
/** True if every record index in `[0, index)` has reached the given [state]. */
private fun isProcessingCompleteForState(index: Long, state: Batch.State): Boolean =
    // rangesState is pre-populated for every Batch.State in init, so the
    // lookup is non-null by construction.
    rangesState[state]!!.encloses(Range.closedOpen(0L, index))

/** True if all records have associated [Batch.State.COMPLETE] batches. */
override fun isBatchProcessingComplete(): Boolean {
return isProcessingCompleteForState(streamStatus.recordCount.get(), Batch.State.COMPLETE)
/* If the stream hasn't been fully read, it can't be done. */
if (!readIsClosed.get()) {
return false
}

return isProcessingCompleteForState(recordCount.get(), Batch.State.COMPLETE)
}

/**
* True if all records in [0, index] have at least reached [Batch.State.PERSISTED]. This is
* implicitly true if they have all reached [Batch.State.COMPLETE].
*/
/**
 * True if all records in `[0, index)` have at least reached
 * [Batch.State.PERSISTED]; COMPLETE implies PERSISTED, so either suffices.
 */
override fun areRecordsPersistedUntil(index: Long): Boolean =
    isProcessingCompleteForState(index, Batch.State.PERSISTED) ||
        isProcessingCompleteForState(index, Batch.State.COMPLETE) // complete => persisted

override fun markClosed() {
streamStatus.closedLatch.countDown()
if (!readIsClosed.get()) {
throw IllegalStateException("Stream must be fully read before it can be closed")
}

if (streamIsClosed.compareAndSet(false, true)) {
closedLock.trySend(Unit)
}
}

override fun streamIsClosed(): Boolean {
return streamStatus.closedLatch.count == 0L
return streamIsClosed.get()
}

override suspend fun awaitStreamClosed() {
withContext(Dispatchers.IO) { streamStatus.closedLatch.await() }
if (!streamIsClosed.get()) {
closedLock.receive()
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class TeardownTask(
}

/** Ensure we don't run until all streams have completed */
streamsManager.awaitAllStreamsComplete()
streamsManager.awaitAllStreamsClosed()

destination.teardown()
taskLauncher.stop()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.cdk.command

import io.micronaut.context.annotation.Factory
import io.micronaut.context.annotation.Replaces
import io.micronaut.context.annotation.Requires
import jakarta.inject.Named
import jakarta.inject.Singleton

/**
 * Test-only replacement for [DestinationCatalogFactory], active only in the
 * "test" environment. Produces a fixed catalog of two streams so tests can
 * exercise stream management without a real configured catalog.
 */
@Factory
@Replaces(factory = DestinationCatalogFactory::class)
@Requires(env = ["test"])
class MockCatalogFactory : DestinationCatalogFactory {
    /** Returns the fixed two-stream catalog, exposed under the name "mockCatalog". */
    @Singleton
    @Named("mockCatalog")
    override fun make(): DestinationCatalog =
        DestinationCatalog(streams = listOf(stream1, stream2))

    companion object {
        val stream1 = DestinationStream(DestinationStream.Descriptor("test", "stream1"))
        val stream2 = DestinationStream(DestinationStream.Descriptor("test", "stream2"))
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@ import com.google.common.collect.Range
import com.google.common.collect.RangeSet
import com.google.common.collect.TreeRangeSet
import io.airbyte.cdk.command.DestinationCatalog
import io.airbyte.cdk.command.DestinationCatalogFactory
import io.airbyte.cdk.command.DestinationStream
import io.airbyte.cdk.command.MockCatalogFactory.Companion.stream1
import io.airbyte.cdk.command.MockCatalogFactory.Companion.stream2
import io.airbyte.cdk.message.Batch
import io.airbyte.cdk.message.BatchEnvelope
import io.airbyte.cdk.message.MessageConverter
import io.micronaut.context.annotation.Factory
import io.micronaut.context.annotation.Prototype
import io.micronaut.context.annotation.Replaces
import io.micronaut.context.annotation.Requires
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
import jakarta.inject.Inject
import jakarta.inject.Named
import jakarta.inject.Singleton
import java.util.function.Consumer
import java.util.stream.Stream
Expand All @@ -29,25 +29,10 @@ import org.junit.jupiter.params.provider.Arguments
import org.junit.jupiter.params.provider.ArgumentsProvider
import org.junit.jupiter.params.provider.ArgumentsSource

@MicronautTest
@MicronautTest(environments = ["StateManagerTest"])
class StateManagerTest {
@Inject lateinit var stateManager: TestStateManager

companion object {
val stream1 = DestinationStream(DestinationStream.Descriptor("test", "stream1"))
val stream2 = DestinationStream(DestinationStream.Descriptor("test", "stream2"))
}

@Factory
@Replaces(factory = DestinationCatalogFactory::class)
class MockCatalogFactory {
@Singleton
@Requires(env = ["test"])
fun make(): DestinationCatalog {
return DestinationCatalog(streams = listOf(stream1, stream2))
}
}

/**
* Test state messages.
*
Expand Down Expand Up @@ -95,7 +80,11 @@ class StateManagerTest {
class MockStreamManager : StreamManager {
var persistedRanges: RangeSet<Long> = TreeRangeSet.create()

override fun countRecordIn(sizeBytes: Long): Long {
/** Not exercised by these tests; fails loudly if reached. */
override fun countRecordIn(): Long = throw NotImplementedError()

/** Not exercised by these tests; fails loudly if reached. */
override fun countEndOfStream(): Long = throw NotImplementedError()

Expand Down Expand Up @@ -129,7 +118,8 @@ class StateManagerTest {
}

@Prototype
class MockStreamsManager(catalog: DestinationCatalog) : StreamsManager {
@Requires(env = ["StateManagerTest"])
class MockStreamsManager(@Named("mockCatalog") catalog: DestinationCatalog) : StreamsManager {
private val mockManagers = catalog.streams.associateWith { MockStreamManager() }

fun addPersistedRanges(stream: DestinationStream, ranges: List<Range<Long>>) {
Expand All @@ -141,14 +131,14 @@ class StateManagerTest {
?: throw IllegalArgumentException("Stream not found: $stream")
}

override suspend fun awaitAllStreamsComplete() {
override suspend fun awaitAllStreamsClosed() {
throw NotImplementedError()
}
}

@Prototype
class TestStateManager(
override val catalog: DestinationCatalog,
@Named("mockCatalog") override val catalog: DestinationCatalog,
override val streamsManager: MockStreamsManager,
override val outputFactory: MessageConverter<MockStateIn, MockStateOut>,
override val outputConsumer: MockOutputConsumer
Expand Down
Loading
Loading