Skip to content

Commit

Permalink
bulk-cdk-core-extract: fix TRACE STATUS message emission (#46314)
Browse files Browse the repository at this point in the history
  • Loading branch information
postamar authored Oct 4, 2024
1 parent 4c84f58 commit 4b8f113
Show file tree
Hide file tree
Showing 6 changed files with 415 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,11 @@ data class Stream(
override val label: String
get() = id.toString()
}

/** List of [Stream]s this [Feed] emits records for. */
val Feed.streams
get() =
when (this) {
is Global -> streams
is Stream -> listOf(this)
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
package io.airbyte.cdk.read

import io.airbyte.cdk.SystemErrorException
import io.airbyte.cdk.asProtocolStreamDescriptor
import io.airbyte.cdk.command.OpaqueStateValue
import io.airbyte.cdk.util.ThreadRenamingCoroutineName
import io.airbyte.protocol.models.v0.AirbyteStateMessage
import io.airbyte.protocol.models.v0.AirbyteStreamStatusTraceMessage
import io.github.oshai.kotlinlogging.KotlinLogging
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.coroutineContext
Expand Down Expand Up @@ -46,7 +44,7 @@ class FeedReader(
// Publish a checkpoint if applicable.
maybeCheckpoint()
// Publish stream completion.
emitStreamStatus(AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE)
root.streamStatusManager.notifyComplete(feed)
break
}
// Launch coroutines which read from each partition.
Expand Down Expand Up @@ -85,7 +83,7 @@ class FeedReader(
acquirePartitionsCreatorResources(partitionsCreatorID, partitionsCreator)
}
if (1L == partitionsCreatorID) {
emitStreamStatus(AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED)
root.streamStatusManager.notifyStarting(feed)
}
return withContext(ctx("round-$partitionsCreatorID-create-partitions")) {
createPartitionsWithResources(partitionsCreatorID, partitionsCreator)
Expand Down Expand Up @@ -309,14 +307,4 @@ class FeedReader(
root.outputConsumer.accept(stateMessage)
}
}

private fun emitStreamStatus(status: AirbyteStreamStatusTraceMessage.AirbyteStreamStatus) {
if (feed is Stream) {
root.outputConsumer.accept(
AirbyteStreamStatusTraceMessage()
.withStreamDescriptor(feed.id.asProtocolStreamDescriptor())
.withStatus(status),
)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import kotlin.coroutines.CoroutineContext
import kotlin.time.toKotlinDuration
import kotlinx.coroutines.CoroutineExceptionHandler
import kotlinx.coroutines.Job
import kotlinx.coroutines.cancel
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.collectLatest
import kotlinx.coroutines.flow.update
Expand Down Expand Up @@ -51,6 +50,8 @@ class RootReader(
}
}

val streamStatusManager = StreamStatusManager(stateManager.feeds, outputConsumer::accept)

/** Reads records from all [Feed]s. */
suspend fun read(listener: suspend (Map<Feed, Job>) -> Unit = {}) {
supervisorScope {
Expand All @@ -60,7 +61,7 @@ class RootReader(
val feedJobs: Map<Feed, Job> =
feeds.associateWith { feed: Feed ->
val coroutineName = ThreadRenamingCoroutineName(feed.label)
val handler = FeedExceptionHandler(feed, exceptions)
val handler = FeedExceptionHandler(feed, streamStatusManager, exceptions)
launch(coroutineName + handler) { FeedReader(this@RootReader, feed).read() }
}
// Call listener hook.
Expand All @@ -71,21 +72,6 @@ class RootReader(
feedJobs[it]?.join()
exceptions[it]
}
// Cancel any incomplete global feed job whose stream feed jobs have not all succeeded.
for ((global, globalJob) in feedJobs) {
if (global !is Global) continue
if (globalJob.isCompleted) continue
val globalStreamExceptions: List<Throwable> =
global.streams.mapNotNull { streamExceptions[it] }
if (globalStreamExceptions.isNotEmpty()) {
val cause: Throwable =
globalStreamExceptions.reduce { acc: Throwable, exception: Throwable ->
acc.addSuppressed(exception)
acc
}
globalJob.cancel("at least one stream did non complete", cause)
}
}
// Join on all global feeds and collect caught exceptions.
val globalExceptions: Map<Global, Throwable?> =
feeds.filterIsInstance<Global>().associateWith {
Expand All @@ -109,6 +95,7 @@ class RootReader(

class FeedExceptionHandler(
val feed: Feed,
val streamStatusManager: StreamStatusManager,
private val exceptions: ConcurrentHashMap<Feed, Throwable>,
) : CoroutineExceptionHandler {
private val log = KotlinLogging.logger {}
Expand All @@ -121,6 +108,7 @@ class RootReader(
exception: Throwable,
) {
log.warn(exception) { "canceled feed '${feed.label}' due to thrown exception" }
streamStatusManager.notifyFailure(feed)
exceptions[feed] = exception
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/*
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.cdk.read

import io.airbyte.cdk.StreamIdentifier
import io.airbyte.cdk.asProtocolStreamDescriptor
import io.airbyte.protocol.models.v0.AirbyteStreamStatusTraceMessage
import io.airbyte.protocol.models.v0.AirbyteStreamStatusTraceMessage.AirbyteStreamStatus
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.atomic.AtomicReference
import org.apache.mina.util.ConcurrentHashSet

/**
* [StreamStatusManager] emits [AirbyteStreamStatusTraceMessage]s in response to [Feed] activity
* events, via [notifyStarting], [notifyComplete] and [notifyFailure].
*/
class StreamStatusManager(
feeds: List<Feed>,
private val emit: (AirbyteStreamStatusTraceMessage) -> Unit,
) {
private val streamStates: Map<StreamIdentifier, StreamState> =
feeds
.flatMap { feed: Feed -> feed.streams.map { it.id to feed } }
.groupBy({ it.first }, { it.second })
.mapValues { (id: StreamIdentifier, feeds: List<Feed>) ->
StreamState(id, feeds.toSet())
}

/**
* Notify that the [feed] is about to start running.
*
* Emits Airbyte TRACE messages of type STATUS accordingly. Safe to call even if
* [notifyStarting], [notifyComplete] or [notifyFailure] have been called before.
*/
fun notifyStarting(feed: Feed) {
handle(feed) { it.onStarting() }
}

/**
* Notify that the [feed] has completed running.
*
* Emits Airbyte TRACE messages of type STATUS accordingly. Idempotent. Safe to call even if
* [notifyStarting] hasn't been called previously.
*/
fun notifyComplete(feed: Feed) {
handle(feed) { it.onComplete(feed) }
}

/**
* Notify that the [feed] has stopped running due to a failure.
*
* Emits Airbyte TRACE messages of type STATUS accordingly. Idempotent. Safe to call even if
* [notifyStarting] hasn't been called previously.
*/
fun notifyFailure(feed: Feed) {
handle(feed) { it.onFailure(feed) }
}

private fun handle(feed: Feed, notification: (StreamState) -> List<AirbyteStreamStatus>) {
for (stream in feed.streams) {
val streamState: StreamState = streamStates[stream.id] ?: continue
for (statusToEmit: AirbyteStreamStatus in notification(streamState)) {
emit(
AirbyteStreamStatusTraceMessage()
.withStreamDescriptor(stream.id.asProtocolStreamDescriptor())
.withStatus(statusToEmit)
)
}
}
}

data class StreamState(
val id: StreamIdentifier,
val feeds: Set<Feed>,
val state: AtomicReference<State> = AtomicReference(State.PENDING),
val stoppedFeeds: ConcurrentHashSet<Feed> = ConcurrentHashSet(),
val numStoppedFeeds: AtomicInteger = AtomicInteger()
) {
fun onStarting(): List<AirbyteStreamStatus> =
if (state.compareAndSet(State.PENDING, State.SUCCESS)) {
listOf(AirbyteStreamStatus.STARTED)
} else {
emptyList()
}

fun onComplete(feed: Feed): List<AirbyteStreamStatus> =
onStarting() + // ensure the state is not PENDING
run {
if (!finalStop(feed)) {
return@run emptyList()
}
// At this point, we just stopped the last feed for this stream.
// Transition to DONE.
if (state.compareAndSet(State.SUCCESS, State.DONE)) {
listOf(AirbyteStreamStatus.COMPLETE)
} else if (state.compareAndSet(State.FAILURE, State.DONE)) {
listOf(AirbyteStreamStatus.INCOMPLETE)
} else {
emptyList() // this should never happen
}
}

fun onFailure(feed: Feed): List<AirbyteStreamStatus> =
onStarting() + // ensure the state is not PENDING
run {
state.compareAndSet(State.SUCCESS, State.FAILURE)
if (!finalStop(feed)) {
return@run emptyList()
}
// At this point, we just stopped the last feed for this stream.
// Transition from FAILURE to DONE.
if (state.compareAndSet(State.FAILURE, State.DONE)) {
listOf(AirbyteStreamStatus.INCOMPLETE)
} else {
emptyList() // this should never happen
}
}

private fun finalStop(feed: Feed): Boolean {
if (!stoppedFeeds.add(feed)) {
// This feed was stopped before.
return false
}
// True if and only if this feed was stopped and all others were already stopped.
return numStoppedFeeds.incrementAndGet() == feeds.size
}
}

enum class State {
PENDING,
SUCCESS,
FAILURE,
DONE,
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ data class TestCase(
fun verifyTraces(traceMessages: List<AirbyteTraceMessage>) {
var hasStarted = false
var hasCompleted = false
var hasIncompleted = false
for (trace in traceMessages) {
when (trace.type) {
AirbyteTraceMessage.Type.STREAM_STATUS -> {
Expand All @@ -282,14 +283,29 @@ data class TestCase(
hasStarted = true
Assertions.assertFalse(
hasCompleted,
"Case $name cannot emit a STARTED trace message because it already emitted a COMPLETE."
"Case $name cannot emit a STARTED trace " +
"message because it already emitted a COMPLETE."
)
Assertions.assertFalse(
hasIncompleted,
"Case $name cannot emit a STARTED trace " +
"message because it already emitted an INCOMPLETE."
)
}
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE -> {
hasCompleted = true
Assertions.assertTrue(
hasStarted,
"Case $name cannot emit a COMPLETE trace message because it hasn't emitted a STARTED yet."
"Case $name cannot emit a COMPLETE trace " +
"message because it hasn't emitted a STARTED yet."
)
}
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.INCOMPLETE -> {
hasIncompleted = true
Assertions.assertTrue(
hasStarted,
"Case $name cannot emit an INCOMPLETE trace " +
"message because it hasn't emitted a STARTED yet."
)
}
else ->
Expand All @@ -310,15 +326,25 @@ data class TestCase(
"Case $name should have emitted a STARTED trace message, but hasn't."
)
if (isSuccessful) {
Assertions.assertTrue(
hasCompleted,
"Case $name should have emitted a COMPLETE trace message, but hasn't."
if (!hasCompleted) {
Assertions.assertTrue(
hasCompleted,
"Case $name should have emitted a COMPLETE trace message, but hasn't."
)
}
Assertions.assertFalse(
hasIncompleted,
"Case $name should not have emitted an INCOMPLETE trace message, but did anyway."
)
} else {
Assertions.assertFalse(
hasCompleted,
"Case $name should not have emitted a COMPLETE trace message, but did anyway."
)
Assertions.assertTrue(
hasIncompleted,
"Case $name should have emitted an INCOMPLETE trace message, but hasn't."
)
}
}

Expand Down Expand Up @@ -556,21 +582,17 @@ class TestPartitionsCreatorFactory(
feed: Feed,
): PartitionsCreator {
if (feed is Global) {
// For a global feed, return a bogus PartitionsCreator which backs off forever.
// This tests that the corresponding coroutine gets canceled properly.
return object : PartitionsCreator {
override fun tryAcquireResources(): PartitionsCreator.TryAcquireResourcesStatus {
log.info { "failed to acquire resources for global feed, as always" }
return PartitionsCreator.TryAcquireResourcesStatus.RETRY_LATER
return PartitionsCreator.TryAcquireResourcesStatus.READY_TO_RUN
}

override suspend fun run(): List<PartitionReader> {
TODO("unreachable code")
// Do nothing.
return emptyList()
}

override fun releaseResources() {
TODO("unreachable code")
}
override fun releaseResources() {}
}
}
// For a stream feed, pick the CreatorCase in the corresponding TestCase
Expand Down
Loading

0 comments on commit 4b8f113

Please sign in to comment.