Skip to content

Commit

Permalink
[source-mysql-v2] add cdc initial read partition (#46901)
Browse files Browse the repository at this point in the history
Co-authored-by: Rodi Reich Zilberman <867491+rodireich@users.noreply.github.com>
  • Loading branch information
xiaohansong and rodireich authored Oct 24, 2024
1 parent 4cf57cd commit b22f250
Show file tree
Hide file tree
Showing 8 changed files with 326 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ dependencies {

testImplementation platform('org.testcontainers:testcontainers-bom:1.20.2')
testImplementation 'org.testcontainers:mysql'
testImplementation("io.mockk:mockk:1.12.0")
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ data:
connectorSubtype: database
connectorType: source
definitionId: 561393ed-7e3a-4d0d-8b8b-90ded371754c
dockerImageTag: 0.0.26
dockerImageTag: 0.0.27
dockerRepository: airbyte/source-mysql-v2
documentationUrl: https://docs.airbyte.com/integrations/sources/mysql
githubIssueLabel: source-mysql-v2
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.integrations.source.mysql

import com.fasterxml.jackson.annotation.JsonProperty
import com.fasterxml.jackson.databind.JsonNode
import io.airbyte.cdk.command.OpaqueStateValue
import io.airbyte.cdk.discover.Field
import io.airbyte.cdk.read.Stream
import io.airbyte.cdk.util.Jsons

/**
 * Per-stream state value emitted while (and after) a stream's initial snapshot is read in CDC
 * mode.
 *
 * Every property is optional and defaults to `null`; the JSON field names are fixed by the
 * [JsonProperty] annotations. The two factory functions in the companion object populate
 * different subsets of the properties:
 * - an in-progress checkpoint sets [pkName], [pkVal] and [stateType];
 * - a completed snapshot sets [streamName], [streamNamespace] and an empty [cursorField], and
 * leaves [pkName] unset.
 */
data class MysqlCdcInitialSnapshotStateValue(
    @JsonProperty("pk_val") val pkVal: String? = null,
    @JsonProperty("pk_name") val pkName: String? = null,
    @JsonProperty("version") val version: Int? = null,
    @JsonProperty("state_type") val stateType: String? = null,
    @JsonProperty("incremental_state") val incrementalState: JsonNode? = null,
    @JsonProperty("stream_name") val streamName: String? = null,
    @JsonProperty("cursor_field") val cursorField: List<String>? = null,
    @JsonProperty("stream_namespace") val streamNamespace: String? = null,
) {
    companion object {
        /**
         * Builds the state marking the snapshot of [stream] as finished.
         *
         * Only the stream name, the stream namespace and an empty cursor field list are set; no
         * primary key information is recorded.
         */
        fun getSnapshotCompletedState(stream: Stream): OpaqueStateValue {
            val completed =
                MysqlCdcInitialSnapshotStateValue(
                    streamName = stream.name,
                    cursorField = emptyList(),
                    streamNamespace = stream.namespace,
                )
            return Jsons.valueToTree(completed)
        }

        /**
         * Builds the state describing how far an ongoing snapshot has progressed.
         *
         * Only the first element of [primaryKey] and the first element of
         * [primaryKeyCheckpoint] are used; the checkpoint value is stored in its textual form.
         * Both lists must be non-empty, otherwise `first()` throws [NoSuchElementException].
         */
        fun snapshotCheckpoint(
            primaryKey: List<Field>,
            primaryKeyCheckpoint: List<JsonNode>,
        ): OpaqueStateValue =
            Jsons.valueToTree(
                MysqlCdcInitialSnapshotStateValue(
                    pkName = primaryKey.first().id,
                    pkVal = primaryKeyCheckpoint.first().asText(),
                    stateType = "primary_key",
                )
            )
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -185,16 +185,14 @@ class MysqlJdbcSnapshotPartition(
override val upperBound: List<JsonNode>?,
) : MysqlJdbcResumablePartition(selectQueryGenerator, streamState, primaryKey) {

// TODO: this needs to reflect lastRecord. Complete state needs to have last primary key value
// in RFR case.
override val completeState: OpaqueStateValue
get() =
when (upperBound) {
null -> MysqlJdbcStreamStateValue.snapshotCompleted
else ->
MysqlJdbcStreamStateValue.snapshotCheckpoint(
primaryKey = checkpointColumns,
primaryKeyCheckpoint = upperBound,
)
}
MysqlJdbcStreamStateValue.snapshotCheckpoint(
primaryKey = checkpointColumns,
primaryKeyCheckpoint = listOf(),
)

override fun incompleteState(lastRecord: ObjectNode): OpaqueStateValue =
MysqlJdbcStreamStateValue.snapshotCheckpoint(
Expand All @@ -203,6 +201,24 @@ class MysqlJdbcSnapshotPartition(
)
}

/**
 * Implementation of a [JdbcPartition] for the snapshot of a stream read in CDC mode.
 *
 * The partition has no upper bound. [lowerBound] is supplied by the caller: `null` to start from
 * the beginning, or the checkpointed primary key value(s) to resume from.
 */
class MysqlJdbcCdcSnapshotPartition(
    selectQueryGenerator: SelectQueryGenerator,
    override val streamState: DefaultJdbcStreamState,
    primaryKey: List<Field>,
    override val lowerBound: List<JsonNode>?,
) : MysqlJdbcResumablePartition(selectQueryGenerator, streamState, primaryKey) {
    /** Always `null`: this partition is never bounded from above. */
    override val upperBound: List<JsonNode>? = null

    /** State emitted once the whole partition has been read. */
    override val completeState: OpaqueStateValue
        get() {
            return MysqlCdcInitialSnapshotStateValue.getSnapshotCompletedState(stream)
        }

    /**
     * State emitted when reading stops before the partition is exhausted.
     *
     * The checkpoint is built from the values [lastRecord] holds for each checkpoint column; a
     * column missing from the record contributes a JSON null node.
     */
    override fun incompleteState(lastRecord: ObjectNode): OpaqueStateValue {
        val lastSeenKeyValues =
            checkpointColumns.map { column -> lastRecord[column.id] ?: Jsons.nullNode() }
        return MysqlCdcInitialSnapshotStateValue.snapshotCheckpoint(
            primaryKey = checkpointColumns,
            primaryKeyCheckpoint = lastSeenKeyValues,
        )
    }
}

/**
* Default implementation of a [JdbcPartition] for a splittable partition involving cursor columns.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,8 @@ class MysqlJdbcPartitionFactory(
private fun coldStart(streamState: DefaultJdbcStreamState): MysqlJdbcPartition {
val stream: Stream = streamState.stream
val pkChosenFromCatalog: List<Field> = stream.configuredPrimaryKey ?: listOf()
if (
stream.configuredSyncMode == ConfiguredSyncMode.FULL_REFRESH ||
sharedState.configuration.global
) {

if (stream.configuredSyncMode == ConfiguredSyncMode.FULL_REFRESH) {
if (pkChosenFromCatalog.isEmpty()) {
return MysqlJdbcNonResumableSnapshotPartition(
selectQueryGenerator,
Expand All @@ -58,6 +56,15 @@ class MysqlJdbcPartitionFactory(
)
}

if (sharedState.configuration.global) {
return MysqlJdbcCdcSnapshotPartition(
selectQueryGenerator,
streamState,
pkChosenFromCatalog,
lowerBound = null,
)
}

val cursorChosenFromCatalog: Field =
stream.configuredCursor as? Field ?: throw ConfigErrorException("no cursor")

Expand Down Expand Up @@ -104,19 +111,37 @@ class MysqlJdbcPartitionFactory(
if (opaqueStateValue == null) {
return coldStart(streamState)
}
val sv: MysqlJdbcStreamStateValue =
Jsons.treeToValue(opaqueStateValue, MysqlJdbcStreamStateValue::class.java)

val isCursorBasedIncremental: Boolean =
stream.configuredSyncMode == ConfiguredSyncMode.INCREMENTAL &&
!sharedState.configuration.global

if (!isCursorBasedIncremental) {
// TODO: This should consider v1 state format for CDC initial read and return
// a MysqlJdbcSnapshotPartition, or a different partition if we can't reuse
// MysqlJdbcStreamStateValue.
return null
val sv: MysqlCdcInitialSnapshotStateValue =
Jsons.treeToValue(opaqueStateValue, MysqlCdcInitialSnapshotStateValue::class.java)

if (sv.pkName == null) {
// This indicates initial snapshot has been completed. CDC snapshot will be handled
// by CDCPartitionFactory.
// Nothing to do here.
return null
} else {
// This branch indicates snapshot is incomplete. We need to resume based on previous
// snapshot state.
val pkChosenFromCatalog: List<Field> = stream.configuredPrimaryKey!!
val pkField = pkChosenFromCatalog.first()
val pkLowerBound: JsonNode = stateValueToJsonNode(pkField, sv.pkVal)
return MysqlJdbcCdcSnapshotPartition(
selectQueryGenerator,
streamState,
pkChosenFromCatalog,
lowerBound = listOf(pkLowerBound),
)
}
} else {
val sv: MysqlJdbcStreamStateValue =
Jsons.treeToValue(opaqueStateValue, MysqlJdbcStreamStateValue::class.java)

if (sv.stateType != "cursor_based") {
// Loading value from catalog. Note there could be unexpected behaviors if user
// updates their schema but did not reset their state.
Expand All @@ -137,23 +162,8 @@ class MysqlJdbcPartitionFactory(
}
// resume back to cursor based increment.
val cursor: Field = stream.fields.find { it.id == sv.cursorField.first() } as Field
val cursorCheckpoint: JsonNode =
when (cursor.type.airbyteSchemaType) {
is LeafAirbyteSchemaType ->
when (cursor.type.airbyteSchemaType as LeafAirbyteSchemaType) {
LeafAirbyteSchemaType.INTEGER -> {
Jsons.valueToTree(sv.cursors.toInt())
}
LeafAirbyteSchemaType.NUMBER -> {
Jsons.valueToTree(sv.cursors.toDouble())
}
else -> Jsons.valueToTree(sv.cursors)
}
else ->
throw IllegalStateException(
"Cursor field must be leaf type but is ${cursor.type.airbyteSchemaType}."
)
}
val cursorCheckpoint: JsonNode = stateValueToJsonNode(cursor, sv.cursors)

// Compose a jsonnode of cursor label to cursor value to fit in
// DefaultJdbcCursorIncrementalPartition
if (cursorCheckpoint == streamState.cursorUpperBound) {
Expand All @@ -171,6 +181,25 @@ class MysqlJdbcPartitionFactory(
}
}

/**
 * Converts a checkpoint value that was persisted as a string back into a [JsonNode] whose type
 * follows the Airbyte schema type of [field]. Used for both primary key and cursor checkpoints.
 *
 * - `INTEGER`: parsed as a whole number. The value is parsed as an `Int` when it fits, so values
 * that were already supported keep producing the same node; larger values (for example
 * BIGINT keys beyond the `Int` range) widen to `Long` and then `BigInteger` instead of failing.
 * - `NUMBER`: parsed as a `Double`.
 * - any other leaf type: the string is used as is.
 *
 * A `null` [stateValue] is handed to `Jsons.valueToTree` unchanged.
 *
 * @throws NumberFormatException if [stateValue] is not a valid number for a numeric [field].
 * @throws IllegalStateException if [field] is not of a leaf schema type.
 */
private fun stateValueToJsonNode(field: Field, stateValue: String?): JsonNode {
    val schemaType = field.type.airbyteSchemaType
    if (schemaType !is LeafAirbyteSchemaType) {
        throw IllegalStateException(
            "PK or cursor field must be leaf type but is $schemaType."
        )
    }
    return when (schemaType) {
        LeafAirbyteSchemaType.INTEGER -> {
            // Int first to preserve the node type for values that parsed before this change;
            // toBigInteger() is last so malformed input still raises NumberFormatException.
            val parsed: Number? =
                stateValue?.let { it.toIntOrNull() ?: it.toLongOrNull() ?: it.toBigInteger() }
            Jsons.valueToTree(parsed)
        }
        LeafAirbyteSchemaType.NUMBER -> Jsons.valueToTree(stateValue?.toDouble())
        else -> Jsons.valueToTree(stateValue)
    }
}

override fun split(
unsplitPartition: MysqlJdbcPartition,
opaqueStateValues: List<OpaqueStateValue>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ data class MysqlJdbcStreamStateValue(
)
}

/** Value representing the progress of a ongoing snapshot not involving cursor columns. */
/** Value representing the progress of an ongoing snapshot not involving cursor columns. */
fun snapshotCheckpoint(
primaryKey: List<Field>,
primaryKeyCheckpoint: List<JsonNode>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package io.airbyte.integrations.source.mysql

import io.airbyte.cdk.StreamIdentifier
import io.airbyte.cdk.command.CliRunner
import io.airbyte.cdk.discover.CommonMetaField
import io.airbyte.cdk.discover.DiscoveredStream
import io.airbyte.cdk.discover.Field
import io.airbyte.cdk.discover.JdbcAirbyteStreamFactory
Expand Down Expand Up @@ -114,7 +113,9 @@ class MysqlCdcIntegrationTest {
CatalogHelpers.toDefaultConfiguredStream(stream)
.withSyncMode(SyncMode.INCREMENTAL)
.withPrimaryKey(discoveredStream.primaryKeyColumnIDs)
.withCursorField(listOf(CommonMetaField.CDC_LSN.id))
.withCursorField(
listOf(MysqlJdbcStreamFactory.MysqlCDCMetaFields.CDC_CURSOR.id)
)
ConfiguredAirbyteCatalog().withStreams(listOf(configuredStream))
}

Expand Down
Loading

0 comments on commit b22f250

Please sign in to comment.