-
Notifications
You must be signed in to change notification settings - Fork 4.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Unit tests for streams manager #45090
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,15 +15,17 @@ import io.github.oshai.kotlinlogging.KotlinLogging | |
import io.micronaut.context.annotation.Factory | ||
import jakarta.inject.Singleton | ||
import java.util.concurrent.ConcurrentHashMap | ||
import java.util.concurrent.CountDownLatch | ||
import java.util.concurrent.atomic.AtomicBoolean | ||
import java.util.concurrent.atomic.AtomicLong | ||
import kotlinx.coroutines.Dispatchers | ||
import kotlinx.coroutines.withContext | ||
import kotlinx.coroutines.channels.Channel | ||
|
||
/** Manages the state of all streams in the destination. */ | ||
interface StreamsManager { | ||
/** Get the manager for the given stream. Throws an exception if the stream is not found. */ | ||
fun getManager(stream: DestinationStream): StreamManager | ||
suspend fun awaitAllStreamsComplete() | ||
|
||
/** Suspend until all streams are closed. */ | ||
suspend fun awaitAllStreamsClosed() | ||
} | ||
|
||
class DefaultStreamsManager( | ||
|
@@ -33,68 +35,98 @@ class DefaultStreamsManager( | |
return streamManagers[stream] ?: throw IllegalArgumentException("Stream not found: $stream") | ||
} | ||
|
||
override suspend fun awaitAllStreamsComplete() { | ||
override suspend fun awaitAllStreamsClosed() { | ||
streamManagers.forEach { (_, manager) -> manager.awaitStreamClosed() } | ||
} | ||
} | ||
|
||
/** Manages the state of a single stream. */ | ||
interface StreamManager { | ||
fun countRecordIn(sizeBytes: Long): Long | ||
/** Count incoming record and return the record's *index*. */ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks to me like most of the functions here are setting state on the stream itself. Maybe this is just part of a |
||
fun countRecordIn(): Long | ||
|
||
/** | ||
* Count the end-of-stream. Expect this exactly once. Expect no further `countRecordIn`, and | ||
* expect that `markClosed` will always occur after this. | ||
*/ | ||
fun countEndOfStream(): Long | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. that's a weird name. I'm not really sure what "counting the end of stream" means. Total number of records in the stream? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I understand when reading further. |
||
|
||
/** | ||
* Mark a checkpoint in the stream and return the current index and the number of records since | ||
* the last one. | ||
* | ||
* NOTE: Single-writer. If in the future multiple threads set checkpoints, this method should be | ||
* synchronized. | ||
*/ | ||
fun markCheckpoint(): Pair<Long, Long> | ||
|
||
/** Record that the given batch's state has been reached for the associated range(s). */ | ||
fun <B : Batch> updateBatchState(batch: BatchEnvelope<B>) | ||
|
||
/** | ||
* True if all are true: | ||
* * all records have been seen (ie, we've counted an end-of-stream) | ||
* * a [Batch.State.COMPLETE] batch range has been seen covering every record | ||
* | ||
* Does NOT require that the stream be closed. | ||
*/ | ||
fun isBatchProcessingComplete(): Boolean | ||
|
||
/** | ||
* True if all records in [0, index] have at least reached [Batch.State.PERSISTED]. This is | ||
* implicitly true if they have all reached [Batch.State.COMPLETE]. | ||
*/ | ||
fun areRecordsPersistedUntil(index: Long): Boolean | ||
|
||
/** Mark the stream as closed. This should only be called after all records have been read. */ | ||
fun markClosed() | ||
|
||
/** True if the stream has been marked as closed. */ | ||
fun streamIsClosed(): Boolean | ||
|
||
/** Suspend until the stream is closed. */ | ||
suspend fun awaitStreamClosed() | ||
} | ||
|
||
/** | ||
* Maintains a map of stream -> status metadata, and a map of batch state -> record ranges for which | ||
* that state has been reached. | ||
* | ||
* TODO: Log a detailed report of the stream status on a regular cadence. | ||
*/ | ||
class DefaultStreamManager( | ||
val stream: DestinationStream, | ||
) : StreamManager { | ||
private val log = KotlinLogging.logger {} | ||
|
||
data class StreamStatus( | ||
val recordCount: AtomicLong = AtomicLong(0), | ||
val totalBytes: AtomicLong = AtomicLong(0), | ||
val enqueuedSize: AtomicLong = AtomicLong(0), | ||
val lastCheckpoint: AtomicLong = AtomicLong(0L), | ||
val closedLatch: CountDownLatch = CountDownLatch(1), | ||
) | ||
private val recordCount = AtomicLong(0) | ||
private val lastCheckpoint = AtomicLong(0L) | ||
private val readIsClosed = AtomicBoolean(false) | ||
private val streamIsClosed = AtomicBoolean(false) | ||
private val closedLock = Channel<Unit>() | ||
|
||
private val streamStatus: StreamStatus = StreamStatus() | ||
private val rangesState: ConcurrentHashMap<Batch.State, RangeSet<Long>> = ConcurrentHashMap() | ||
|
||
init { | ||
Batch.State.entries.forEach { rangesState[it] = TreeRangeSet.create() } | ||
} | ||
|
||
override fun countRecordIn(sizeBytes: Long): Long { | ||
val index = streamStatus.recordCount.getAndIncrement() | ||
streamStatus.totalBytes.addAndGet(sizeBytes) | ||
streamStatus.enqueuedSize.addAndGet(sizeBytes) | ||
return index | ||
override fun countRecordIn(): Long { | ||
if (readIsClosed.get()) { | ||
throw IllegalStateException("Stream is closed for reading") | ||
} | ||
|
||
return recordCount.getAndIncrement() | ||
} | ||
|
||
override fun countEndOfStream(): Long { | ||
if (readIsClosed.getAndSet(true)) { | ||
throw IllegalStateException("Stream is closed for reading") | ||
} | ||
|
||
return recordCount.get() | ||
} | ||
|
||
/** | ||
* Mark a checkpoint in the stream and return the current index and the number of records since | ||
* the last one. | ||
*/ | ||
override fun markCheckpoint(): Pair<Long, Long> { | ||
val index = streamStatus.recordCount.get() | ||
val lastCheckpoint = streamStatus.lastCheckpoint.getAndSet(index) | ||
val index = recordCount.get() | ||
val lastCheckpoint = lastCheckpoint.getAndSet(index) | ||
return Pair(index, index - lastCheckpoint) | ||
} | ||
|
||
/** Record that the given batch's state has been reached for the associated range(s). */ | ||
override fun <B : Batch> updateBatchState(batch: BatchEnvelope<B>) { | ||
val stateRanges = | ||
rangesState[batch.batch.state] | ||
|
@@ -112,37 +144,44 @@ class DefaultStreamManager( | |
log.info { "Updated ranges for $stream[${batch.batch.state}]: $stateRanges" } | ||
} | ||
|
||
/** True if all records in [0, index] have reached the given state. */ | ||
/** True if all records in `[0, index)` have reached the given state. */ | ||
private fun isProcessingCompleteForState(index: Long, state: Batch.State): Boolean { | ||
|
||
val completeRanges = rangesState[state]!! | ||
return completeRanges.encloses(Range.closedOpen(0L, index)) | ||
} | ||
|
||
/** True if all records have associated [Batch.State.COMPLETE] batches. */ | ||
override fun isBatchProcessingComplete(): Boolean { | ||
return isProcessingCompleteForState(streamStatus.recordCount.get(), Batch.State.COMPLETE) | ||
/* If the stream hasn't been fully read, it can't be done. */ | ||
if (!readIsClosed.get()) { | ||
return false | ||
} | ||
|
||
return isProcessingCompleteForState(recordCount.get(), Batch.State.COMPLETE) | ||
} | ||
|
||
/** | ||
* True if all records in [0, index] have at least reached [Batch.State.PERSISTED]. This is | ||
* implicitly true if they have all reached [Batch.State.COMPLETE]. | ||
*/ | ||
override fun areRecordsPersistedUntil(index: Long): Boolean { | ||
return isProcessingCompleteForState(index, Batch.State.PERSISTED) || | ||
isProcessingCompleteForState(index, Batch.State.COMPLETE) // complete => persisted | ||
} | ||
|
||
override fun markClosed() { | ||
streamStatus.closedLatch.countDown() | ||
if (!readIsClosed.get()) { | ||
throw IllegalStateException("Stream must be fully read before it can be closed") | ||
} | ||
|
||
if (streamIsClosed.compareAndSet(false, true)) { | ||
closedLock.trySend(Unit) | ||
} | ||
} | ||
|
||
override fun streamIsClosed(): Boolean { | ||
return streamStatus.closedLatch.count == 0L | ||
return streamIsClosed.get() | ||
} | ||
|
||
override suspend fun awaitStreamClosed() { | ||
withContext(Dispatchers.IO) { streamStatus.closedLatch.await() } | ||
if (!streamIsClosed.get()) { | ||
closedLock.receive() | ||
} | ||
} | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
/* | ||
* Copyright (c) 2024 Airbyte, Inc., all rights reserved. | ||
*/ | ||
|
||
package io.airbyte.cdk.command | ||
|
||
import io.micronaut.context.annotation.Factory | ||
import io.micronaut.context.annotation.Replaces | ||
import io.micronaut.context.annotation.Requires | ||
import jakarta.inject.Named | ||
import jakarta.inject.Singleton | ||
|
||
@Factory | ||
@Replaces(factory = DestinationCatalogFactory::class) | ||
@Requires(env = ["test"]) | ||
class MockCatalogFactory : DestinationCatalogFactory { | ||
companion object { | ||
val stream1 = DestinationStream(DestinationStream.Descriptor("test", "stream1")) | ||
val stream2 = DestinationStream(DestinationStream.Descriptor("test", "stream2")) | ||
} | ||
|
||
@Singleton | ||
@Named("mockCatalog") | ||
override fun make(): DestinationCatalog { | ||
return DestinationCatalog(streams = listOf(stream1, stream2)) | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO: error handling