Skip to content

Commit

Permalink
Bulk Load CDK: State -> Checkpoint & flush at end (#45377)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnny-schmidt authored Sep 11, 2024
1 parent b00dac8 commit 7056428
Show file tree
Hide file tree
Showing 7 changed files with 278 additions and 254 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,36 +38,36 @@ data class DestinationStreamComplete(
) : DestinationRecordMessage()

/** State. */
sealed class DestinationStateMessage : DestinationMessage() {
sealed class CheckpointMessage : DestinationMessage() {
data class Stats(val recordCount: Long)
data class StreamState(
data class StreamCheckpoint(
val stream: DestinationStream,
val state: JsonNode,
)

abstract val sourceStats: Stats
abstract val destinationStats: Stats?

abstract fun withDestinationStats(stats: Stats): DestinationStateMessage
abstract fun withDestinationStats(stats: Stats): CheckpointMessage
}

data class DestinationStreamState(
val streamState: StreamState,
data class StreamCheckpoint(
val streamCheckpoint: StreamCheckpoint,
override val sourceStats: Stats,
override val destinationStats: Stats? = null
) : DestinationStateMessage() {
) : CheckpointMessage() {
override fun withDestinationStats(stats: Stats) =
DestinationStreamState(streamState, sourceStats, stats)
StreamCheckpoint(streamCheckpoint, sourceStats, stats)
}

data class DestinationGlobalState(
data class GlobalCheckpoint(
val state: JsonNode,
override val sourceStats: Stats,
override val destinationStats: Stats? = null,
val streamStates: List<StreamState> = emptyList()
) : DestinationStateMessage() {
val streamCheckpoints: List<StreamCheckpoint> = emptyList()
) : CheckpointMessage() {
override fun withDestinationStats(stats: Stats) =
DestinationGlobalState(state, sourceStats, stats, streamStates)
GlobalCheckpoint(state, sourceStats, stats, streamCheckpoints)
}

/** Catchall for anything unimplemented. */
Expand Down Expand Up @@ -108,21 +108,21 @@ class DestinationMessageFactory(private val catalog: DestinationCatalog) {
AirbyteMessage.Type.STATE -> {
when (message.state.type) {
AirbyteStateMessage.AirbyteStateType.STREAM ->
DestinationStreamState(
streamState = fromAirbyteStreamState(message.state.stream),
StreamCheckpoint(
streamCheckpoint = fromAirbyteStreamState(message.state.stream),
sourceStats =
DestinationStateMessage.Stats(
CheckpointMessage.Stats(
recordCount = message.state.sourceStats.recordCount.toLong()
)
)
AirbyteStateMessage.AirbyteStateType.GLOBAL ->
DestinationGlobalState(
GlobalCheckpoint(
sourceStats =
DestinationStateMessage.Stats(
CheckpointMessage.Stats(
recordCount = message.state.sourceStats.recordCount.toLong()
),
state = message.state.global.sharedState,
streamStates =
streamCheckpoints =
message.state.global.streamStates.map { fromAirbyteStreamState(it) }
)
else -> // TODO: Do we still need to handle LEGACY?
Expand All @@ -135,9 +135,9 @@ class DestinationMessageFactory(private val catalog: DestinationCatalog) {

private fun fromAirbyteStreamState(
streamState: AirbyteStreamState
): DestinationStateMessage.StreamState {
): CheckpointMessage.StreamCheckpoint {
val descriptor = streamState.streamDescriptor
return DestinationStateMessage.StreamState(
return CheckpointMessage.StreamCheckpoint(
stream = catalog.getStream(namespace = descriptor.namespace, name = descriptor.name),
state = streamState.streamState
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,19 @@ import io.airbyte.protocol.models.v0.StreamDescriptor
import jakarta.inject.Singleton

/**
* Converts the internal @[DestinationStateMessage] case class to the Protocol state messages
* required by @[io.airbyte.cdk.output.OutputConsumer]
* Converts the internal @[CheckpointMessage] case class to the Protocol state messages required by
* @[io.airbyte.cdk.output.OutputConsumer]
*/
interface MessageConverter<T, U> {
fun from(message: T): U
}

@Singleton
class DefaultMessageConverter : MessageConverter<DestinationStateMessage, AirbyteMessage> {
override fun from(message: DestinationStateMessage): AirbyteMessage {
class DefaultMessageConverter : MessageConverter<CheckpointMessage, AirbyteMessage> {
override fun from(message: CheckpointMessage): AirbyteMessage {
val state =
when (message) {
is DestinationStreamState ->
is StreamCheckpoint ->
AirbyteStateMessage()
.withSourceStats(
AirbyteStateStats()
Expand All @@ -40,8 +40,8 @@ class DefaultMessageConverter : MessageConverter<DestinationStateMessage, Airbyt
)
)
.withType(AirbyteStateMessage.AirbyteStateType.STREAM)
.withStream(fromStreamState(message.streamState))
is DestinationGlobalState ->
.withStream(fromStreamState(message.streamCheckpoint))
is GlobalCheckpoint ->
AirbyteStateMessage()
.withSourceStats(
AirbyteStateStats()
Expand All @@ -56,21 +56,23 @@ class DefaultMessageConverter : MessageConverter<DestinationStateMessage, Airbyt
.withGlobal(
AirbyteGlobalState()
.withSharedState(message.state)
.withStreamStates(message.streamStates.map { fromStreamState(it) })
.withStreamStates(
message.streamCheckpoints.map { fromStreamState(it) }
)
)
}
return AirbyteMessage().withState(state)
return AirbyteMessage().withType(AirbyteMessage.Type.STATE).withState(state)
}

private fun fromStreamState(
streamState: DestinationStateMessage.StreamState
streamCheckpoint: CheckpointMessage.StreamCheckpoint
): AirbyteStreamState {
return AirbyteStreamState()
.withStreamDescriptor(
StreamDescriptor()
.withNamespace(streamState.stream.descriptor.namespace)
.withName(streamState.stream.descriptor.name)
.withNamespace(streamCheckpoint.stream.descriptor.namespace)
.withName(streamCheckpoint.stream.descriptor.name)
)
.withStreamState(streamState.state)
.withStreamState(streamCheckpoint.state)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ package io.airbyte.cdk.message
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings
import io.airbyte.cdk.command.DestinationCatalog
import io.airbyte.cdk.command.DestinationStream
import io.airbyte.cdk.state.StateManager
import io.airbyte.cdk.state.CheckpointManager
import io.airbyte.cdk.state.StreamsManager
import jakarta.inject.Singleton

Expand All @@ -18,7 +18,7 @@ interface MessageQueueWriter<T : Any> {

/**
* Routes @[DestinationRecordMessage]s by stream to the appropriate channel and @
* [DestinationStateMessage]s to the state manager.
* [CheckpointMessage]s to the state manager.
*
* TODO: Handle other message types.
*/
Expand All @@ -31,7 +31,7 @@ class DestinationMessageQueueWriter(
private val catalog: DestinationCatalog,
private val messageQueue: MessageQueue<DestinationStream, DestinationRecordWrapped>,
private val streamsManager: StreamsManager,
private val stateManager: StateManager<DestinationStream, DestinationStateMessage>
private val checkpointManager: CheckpointManager<DestinationStream, CheckpointMessage>
) : MessageQueueWriter<DestinationMessage> {
/**
* Deserialize and route the message to the appropriate channel.
Expand Down Expand Up @@ -62,28 +62,30 @@ class DestinationMessageQueueWriter(
}
}
}
is DestinationStateMessage -> {
is CheckpointMessage -> {
when (message) {
/**
* For a stream state message, mark the checkpoint and add the message with
* index and count to the state manager. Also, add the count to the destination
* stats.
*/
is DestinationStreamState -> {
val stream = message.streamState.stream
is StreamCheckpoint -> {
val stream = message.streamCheckpoint.stream
val manager = streamsManager.getManager(stream)
val (currentIndex, countSinceLast) = manager.markCheckpoint()
val messageWithCount =
message.withDestinationStats(
DestinationStateMessage.Stats(countSinceLast)
)
stateManager.addStreamState(stream, currentIndex, messageWithCount)
message.withDestinationStats(CheckpointMessage.Stats(countSinceLast))
checkpointManager.addStreamCheckpoint(
stream,
currentIndex,
messageWithCount
)
}
/**
* For a global state message, collect the index per stream, but add the total
* count to the destination stats.
*/
is DestinationGlobalState -> {
is GlobalCheckpoint -> {
val streamWithIndexAndCount =
catalog.streams.map { stream ->
val manager = streamsManager.getManager(stream)
Expand All @@ -92,9 +94,9 @@ class DestinationMessageQueueWriter(
}
val totalCount = streamWithIndexAndCount.sumOf { it.third }
val messageWithCount =
message.withDestinationStats(DestinationStateMessage.Stats(totalCount))
message.withDestinationStats(CheckpointMessage.Stats(totalCount))
val streamIndexes = streamWithIndexAndCount.map { it.first to it.second }
stateManager.addGlobalState(streamIndexes, messageWithCount)
checkpointManager.addGlobalCheckpoint(streamIndexes, messageWithCount)
}
}
}
Expand Down
Loading

0 comments on commit 7056428

Please sign in to comment.