Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bulk Load CDK: State -> Checkpoint & flush at end #45377

Merged
merged 1 commit into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,36 +38,36 @@ data class DestinationStreamComplete(
) : DestinationRecordMessage()

/** State. */
sealed class DestinationStateMessage : DestinationMessage() {
sealed class CheckpointMessage : DestinationMessage() {
data class Stats(val recordCount: Long)
data class StreamState(
data class StreamCheckpoint(
val stream: DestinationStream,
val state: JsonNode,
)

abstract val sourceStats: Stats
abstract val destinationStats: Stats?

abstract fun withDestinationStats(stats: Stats): DestinationStateMessage
abstract fun withDestinationStats(stats: Stats): CheckpointMessage
}

data class DestinationStreamState(
val streamState: StreamState,
data class StreamCheckpoint(
val streamCheckpoint: StreamCheckpoint,
override val sourceStats: Stats,
override val destinationStats: Stats? = null
) : DestinationStateMessage() {
) : CheckpointMessage() {
override fun withDestinationStats(stats: Stats) =
DestinationStreamState(streamState, sourceStats, stats)
StreamCheckpoint(streamCheckpoint, sourceStats, stats)
}

data class DestinationGlobalState(
data class GlobalCheckpoint(
val state: JsonNode,
override val sourceStats: Stats,
override val destinationStats: Stats? = null,
val streamStates: List<StreamState> = emptyList()
) : DestinationStateMessage() {
val streamCheckpoints: List<StreamCheckpoint> = emptyList()
) : CheckpointMessage() {
override fun withDestinationStats(stats: Stats) =
DestinationGlobalState(state, sourceStats, stats, streamStates)
GlobalCheckpoint(state, sourceStats, stats, streamCheckpoints)
}

/** Catchall for anything unimplemented. */
Expand Down Expand Up @@ -108,21 +108,21 @@ class DestinationMessageFactory(private val catalog: DestinationCatalog) {
AirbyteMessage.Type.STATE -> {
when (message.state.type) {
AirbyteStateMessage.AirbyteStateType.STREAM ->
DestinationStreamState(
streamState = fromAirbyteStreamState(message.state.stream),
StreamCheckpoint(
streamCheckpoint = fromAirbyteStreamState(message.state.stream),
sourceStats =
DestinationStateMessage.Stats(
CheckpointMessage.Stats(
recordCount = message.state.sourceStats.recordCount.toLong()
)
)
AirbyteStateMessage.AirbyteStateType.GLOBAL ->
DestinationGlobalState(
GlobalCheckpoint(
sourceStats =
DestinationStateMessage.Stats(
CheckpointMessage.Stats(
recordCount = message.state.sourceStats.recordCount.toLong()
),
state = message.state.global.sharedState,
streamStates =
streamCheckpoints =
message.state.global.streamStates.map { fromAirbyteStreamState(it) }
)
else -> // TODO: Do we still need to handle LEGACY?
Expand All @@ -135,9 +135,9 @@ class DestinationMessageFactory(private val catalog: DestinationCatalog) {

private fun fromAirbyteStreamState(
streamState: AirbyteStreamState
): DestinationStateMessage.StreamState {
): CheckpointMessage.StreamCheckpoint {
val descriptor = streamState.streamDescriptor
return DestinationStateMessage.StreamState(
return CheckpointMessage.StreamCheckpoint(
stream = catalog.getStream(namespace = descriptor.namespace, name = descriptor.name),
state = streamState.streamState
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,19 @@ import io.airbyte.protocol.models.v0.StreamDescriptor
import jakarta.inject.Singleton

/**
* Converts the internal @[DestinationStateMessage] case class to the Protocol state messages
* required by @[io.airbyte.cdk.output.OutputConsumer]
* Converts the internal @[CheckpointMessage] case class to the Protocol state messages required by
* @[io.airbyte.cdk.output.OutputConsumer]
*/
interface MessageConverter<T, U> {
fun from(message: T): U
}

@Singleton
class DefaultMessageConverter : MessageConverter<DestinationStateMessage, AirbyteMessage> {
override fun from(message: DestinationStateMessage): AirbyteMessage {
class DefaultMessageConverter : MessageConverter<CheckpointMessage, AirbyteMessage> {
override fun from(message: CheckpointMessage): AirbyteMessage {
val state =
when (message) {
is DestinationStreamState ->
is StreamCheckpoint ->
AirbyteStateMessage()
.withSourceStats(
AirbyteStateStats()
Expand All @@ -40,8 +40,8 @@ class DefaultMessageConverter : MessageConverter<DestinationStateMessage, Airbyt
)
)
.withType(AirbyteStateMessage.AirbyteStateType.STREAM)
.withStream(fromStreamState(message.streamState))
is DestinationGlobalState ->
.withStream(fromStreamState(message.streamCheckpoint))
is GlobalCheckpoint ->
AirbyteStateMessage()
.withSourceStats(
AirbyteStateStats()
Expand All @@ -56,21 +56,23 @@ class DefaultMessageConverter : MessageConverter<DestinationStateMessage, Airbyt
.withGlobal(
AirbyteGlobalState()
.withSharedState(message.state)
.withStreamStates(message.streamStates.map { fromStreamState(it) })
.withStreamStates(
message.streamCheckpoints.map { fromStreamState(it) }
)
)
}
return AirbyteMessage().withState(state)
return AirbyteMessage().withType(AirbyteMessage.Type.STATE).withState(state)
}

private fun fromStreamState(
streamState: DestinationStateMessage.StreamState
streamCheckpoint: CheckpointMessage.StreamCheckpoint
): AirbyteStreamState {
return AirbyteStreamState()
.withStreamDescriptor(
StreamDescriptor()
.withNamespace(streamState.stream.descriptor.namespace)
.withName(streamState.stream.descriptor.name)
.withNamespace(streamCheckpoint.stream.descriptor.namespace)
.withName(streamCheckpoint.stream.descriptor.name)
)
.withStreamState(streamState.state)
.withStreamState(streamCheckpoint.state)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ package io.airbyte.cdk.message
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings
import io.airbyte.cdk.command.DestinationCatalog
import io.airbyte.cdk.command.DestinationStream
import io.airbyte.cdk.state.StateManager
import io.airbyte.cdk.state.CheckpointManager
import io.airbyte.cdk.state.StreamsManager
import jakarta.inject.Singleton

Expand All @@ -18,7 +18,7 @@ interface MessageQueueWriter<T : Any> {

/**
* Routes @[DestinationRecordMessage]s by stream to the appropriate channel and @
* [DestinationStateMessage]s to the state manager.
* [CheckpointMessage]s to the state manager.
*
* TODO: Handle other message types.
*/
Expand All @@ -31,7 +31,7 @@ class DestinationMessageQueueWriter(
private val catalog: DestinationCatalog,
private val messageQueue: MessageQueue<DestinationStream, DestinationRecordWrapped>,
private val streamsManager: StreamsManager,
private val stateManager: StateManager<DestinationStream, DestinationStateMessage>
private val checkpointManager: CheckpointManager<DestinationStream, CheckpointMessage>
) : MessageQueueWriter<DestinationMessage> {
/**
* Deserialize and route the message to the appropriate channel.
Expand Down Expand Up @@ -62,28 +62,30 @@ class DestinationMessageQueueWriter(
}
}
}
is DestinationStateMessage -> {
is CheckpointMessage -> {
when (message) {
/**
* For a stream state message, mark the checkpoint and add the message with
* index and count to the state manager. Also, add the count to the destination
* stats.
*/
is DestinationStreamState -> {
val stream = message.streamState.stream
is StreamCheckpoint -> {
val stream = message.streamCheckpoint.stream
val manager = streamsManager.getManager(stream)
val (currentIndex, countSinceLast) = manager.markCheckpoint()
val messageWithCount =
message.withDestinationStats(
DestinationStateMessage.Stats(countSinceLast)
)
stateManager.addStreamState(stream, currentIndex, messageWithCount)
message.withDestinationStats(CheckpointMessage.Stats(countSinceLast))
checkpointManager.addStreamCheckpoint(
stream,
currentIndex,
messageWithCount
)
}
/**
* For a global state message, collect the index per stream, but add the total
* count to the destination stats.
*/
is DestinationGlobalState -> {
is GlobalCheckpoint -> {
val streamWithIndexAndCount =
catalog.streams.map { stream ->
val manager = streamsManager.getManager(stream)
Expand All @@ -92,9 +94,9 @@ class DestinationMessageQueueWriter(
}
val totalCount = streamWithIndexAndCount.sumOf { it.third }
val messageWithCount =
message.withDestinationStats(DestinationStateMessage.Stats(totalCount))
message.withDestinationStats(CheckpointMessage.Stats(totalCount))
val streamIndexes = streamWithIndexAndCount.map { it.first to it.second }
stateManager.addGlobalState(streamIndexes, messageWithCount)
checkpointManager.addGlobalCheckpoint(streamIndexes, messageWithCount)
}
}
}
Expand Down
Loading
Loading