Skip to content

Commit

Permalink
[Source-mysql] : Implement WASS algo (#38240)
Browse files Browse the repository at this point in the history
Co-authored-by: Evan Tahler <evan@airbyte.io>
  • Loading branch information
akashkulk and evantahler authored Jul 11, 2024
1 parent 68903b4 commit 310e6bd
Show file tree
Hide file tree
Showing 20 changed files with 371 additions and 114 deletions.
1 change: 1 addition & 0 deletions airbyte-cdk/java/airbyte-cdk/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ corresponds to that version.

| Version | Date | Pull Request | Subject |
|:-----------|:-----------| :--------------------------------------------------------- |:---------------------------------------------------------------------------------------------------------------------------------------------------------------|
| 0.41.0 | 2024-07-11 | [\#38240](https://github.com/airbytehq/airbyte/pull/38240) | Sources : Changes in CDC interfaces to support WASS algorithm |
| 0.40.11 | 2024-07-08 | [\#41041](https://github.com/airbytehq/airbyte/pull/41041) | Destinations: Fix truncate refreshes incorrectly discarding data if successful attempt had 0 records |
| 0.40.10 | 2024-07-05 | [\#40719](https://github.com/airbytehq/airbyte/pull/40719) | Update test to refrlect isResumable field in catalog |
| 0.40.9 | 2024-07-01 | [\#39473](https://github.com/airbytehq/airbyte/pull/39473) | minor changes around error logging and testing |
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version=0.40.11
version=0.41.0
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,12 @@ abstract class AbstractJdbcSource<Datatype>(
)
return augmentWithStreamStatus(
airbyteStream,
initialLoadHandler.getIteratorForStream(airbyteStream, table, Instant.now())
initialLoadHandler.getIteratorForStream(
airbyteStream,
table,
Instant.now(),
Optional.empty()
)
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@ import io.airbyte.commons.util.AutoCloseableIterator
import io.airbyte.protocol.models.CommonField
import io.airbyte.protocol.models.v0.AirbyteMessage
import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream
import java.time.Duration
import java.time.Instant
import java.util.Optional

interface InitialLoadHandler<T> {
fun getIteratorForStream(
airbyteStream: ConfiguredAirbyteStream,
table: TableInfo<CommonField<T>>,
emittedAt: Instant
emittedAt: Instant,
cdcInitialLoadTimeout: Optional<Duration>,
): AutoCloseableIterator<AirbyteMessage>
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.cdk.integrations.source.relationaldb

import com.fasterxml.jackson.databind.JsonNode
import io.github.oshai.kotlinlogging.KotlinLogging
import java.time.Duration
import java.util.*

private val LOGGER = KotlinLogging.logger {}

object InitialLoadTimeoutUtil {

val MIN_INITIAL_LOAD_TIMEOUT: Duration = Duration.ofHours(4)
val MAX_INITIAL_LOAD_TIMEOUT: Duration = Duration.ofHours(24)
val DEFAULT_INITIAL_LOAD_TIMEOUT: Duration = Duration.ofHours(8)

@JvmStatic
fun getInitialLoadTimeout(config: JsonNode): Duration {
val isTest = config.has("is_test") && config["is_test"].asBoolean()
var initialLoadTimeout = DEFAULT_INITIAL_LOAD_TIMEOUT

val initialLoadTimeoutHours = getInitialLoadTimeoutHours(config)

if (initialLoadTimeoutHours.isPresent) {
initialLoadTimeout = Duration.ofHours(initialLoadTimeoutHours.get().toLong())
if (!isTest && initialLoadTimeout.compareTo(MIN_INITIAL_LOAD_TIMEOUT) < 0) {
LOGGER.warn {
"Initial Load timeout is overridden to ${MIN_INITIAL_LOAD_TIMEOUT.toHours()} hours, " +
"which is the min time allowed for safety."
}
initialLoadTimeout = MIN_INITIAL_LOAD_TIMEOUT
} else if (!isTest && initialLoadTimeout.compareTo(MAX_INITIAL_LOAD_TIMEOUT) > 0) {
LOGGER.warn {
"Initial Load timeout is overridden to ${MAX_INITIAL_LOAD_TIMEOUT.toHours()} hours, " +
"which is the max time allowed for safety."
}
initialLoadTimeout = MAX_INITIAL_LOAD_TIMEOUT
}
}

LOGGER.info { "Initial Load timeout: ${initialLoadTimeout.seconds} seconds" }
return initialLoadTimeout
}

fun getInitialLoadTimeoutHours(config: JsonNode): Optional<Int> {
val replicationMethod = config["replication_method"]
if (replicationMethod != null && replicationMethod.has("initial_load_timeout_hours")) {
val seconds = config["replication_method"]["initial_load_timeout_hours"].asInt()
return Optional.of(seconds)
}
return Optional.empty()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,17 @@ class GlobalStateManager(
// Populate global state
val globalState = AirbyteGlobalState()
globalState.sharedState = Jsons.jsonNode(cdcStateManager.cdcState)
globalState.streamStates = StateGeneratorUtils.generateStreamStateList(pairToCursorInfoMap)
// If stream state exists in the global manager, it should be used to reflect the partial
// states of initial loads.
if (
cdcStateManager.rawStateMessage?.global?.streamStates != null &&
cdcStateManager.rawStateMessage.global?.streamStates?.size != 0
) {
globalState.streamStates = cdcStateManager.rawStateMessage.global.streamStates
} else {
globalState.streamStates =
StateGeneratorUtils.generateStreamStateList(pairToCursorInfoMap)
}

// Generate the legacy state for backwards compatibility
val dbState =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
Assertions.assertEquals(
AirbyteMessage.Type.TRACE,
actualMessage.type,
"[Debug] all Message: $allMessages"
"[Debug] all Message: $allMessages",
)
var traceMessage = actualMessage.trace
Assertions.assertNotNull(traceMessage.streamStatus)
Expand Down Expand Up @@ -305,7 +305,7 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
val recordsPerStream = extractRecordMessagesStreamWise(messages)
val consolidatedRecords: MutableSet<AirbyteRecordMessage> = HashSet()
recordsPerStream.values.forEach(
Consumer { c: Set<AirbyteRecordMessage> -> consolidatedRecords.addAll(c) }
Consumer { c: Set<AirbyteRecordMessage> -> consolidatedRecords.addAll(c) },
)
return consolidatedRecords
}
Expand Down Expand Up @@ -415,17 +415,17 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED
)
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED,
),
)
assertStreamStatusTraceMessageIndex(
actualRecords.size - 1,
actualRecords,
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE
)
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE,
),
)

Assertions.assertNotNull(targetPosition)
Expand Down Expand Up @@ -495,17 +495,17 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED
)
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED,
),
)
assertStreamStatusTraceMessageIndex(
actualRecords1.size - 1,
actualRecords1,
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE
)
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE,
),
)

updateCommand(MODELS_STREAM_NAME, COL_MODEL, updatedModel, COL_ID, 11)
Expand Down Expand Up @@ -589,17 +589,17 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED
)
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED,
),
)
assertStreamStatusTraceMessageIndex(
dataFromSecondBatch.size - 1,
dataFromSecondBatch,
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE
)
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE,
),
)

val stateAfterSecondBatch = extractStateMessages(dataFromSecondBatch)
Expand Down Expand Up @@ -711,17 +711,17 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED
)
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED,
),
)
assertStreamStatusTraceMessageIndex(
actualMessages1.size - 1,
actualMessages1,
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME_2,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE
)
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE,
),
)

val recordMessages1 = extractRecordMessages(actualMessages1)
Expand Down Expand Up @@ -753,17 +753,17 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME_2,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE
)
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE,
),
)
assertStreamStatusTraceMessageIndex(
MODEL_RECORDS_2.size + 1,
actualMessages1,
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED
)
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED,
),
)

val state = Jsons.jsonNode(listOf(stateMessages1[stateMessages1.size - 1]))
Expand All @@ -787,7 +787,7 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
// We are expecting count match for all streams, including non RFR streams.
assertExpectedStateMessageCountMatches(
stateMessages1,
MODEL_RECORDS.size.toLong() + MODEL_RECORDS_2.size.toLong()
MODEL_RECORDS.size.toLong() + MODEL_RECORDS_2.size.toLong(),
)

// Expect state and record message from MODEL_RECORDS_2.
Expand All @@ -797,17 +797,17 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE
)
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE,
),
)
assertStreamStatusTraceMessageIndex(
2 * MODEL_RECORDS_2.size + 3,
actualMessages1,
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME_2,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED
)
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED,
),
)

assertExpectedRecords(
Expand Down Expand Up @@ -935,7 +935,7 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
// Non resumeable full refresh will also get state messages with count.
assertExpectedStateMessageCountMatches(
stateMessages1,
MODEL_RECORDS.size.toLong() + MODEL_RECORDS_2.size.toLong()
MODEL_RECORDS.size.toLong() + MODEL_RECORDS_2.size.toLong(),
)
stateMessages1.map { state -> assertStateDoNotHaveDuplicateStreams(state) }
assertExpectedRecords(
Expand Down Expand Up @@ -1002,17 +1002,17 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED
)
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED,
),
)
assertStreamStatusTraceMessageIndex(
actualRecords.size - 1,
actualRecords,
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE
)
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE,
),
)
}

Expand Down Expand Up @@ -1058,11 +1058,11 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {

Assertions.assertEquals(
expectedCatalog.streams.sortedWith(
Comparator.comparing { obj: AirbyteStream -> obj.name }
Comparator.comparing { obj: AirbyteStream -> obj.name },
),
actualCatalog.streams.sortedWith(
Comparator.comparing { obj: AirbyteStream -> obj.name }
)
Comparator.comparing { obj: AirbyteStream -> obj.name },
),
)
}

Expand Down Expand Up @@ -1225,7 +1225,7 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
recordsWrittenInRandomTable.add(record2)
}

val state2 = stateAfterSecondBatch[stateAfterSecondBatch.size - 1].data
val state2 = Jsons.jsonNode(listOf(stateAfterSecondBatch[stateAfterSecondBatch.size - 1]))
val thirdBatchIterator = source().read(config()!!, updatedCatalog, state2)
val dataFromThirdBatch = AutoCloseableIterators.toListAndClose(thirdBatchIterator)

Expand Down
2 changes: 1 addition & 1 deletion airbyte-integrations/connectors/source-mysql/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ plugins {
}

airbyteJavaConnector {
cdkVersionRequired = '0.40.7'
cdkVersionRequired = '0.41.0'
features = ['db-sources']
useLocalCdk = false
}
Expand Down
2 changes: 1 addition & 1 deletion airbyte-integrations/connectors/source-mysql/metadata.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ data:
connectorSubtype: database
connectorType: source
definitionId: 435bb9a5-7887-4809-aa58-28c27df0d7ad
dockerImageTag: 3.4.12
dockerImageTag: 3.5.0
dockerRepository: airbyte/source-mysql
documentationUrl: https://docs.airbyte.com/integrations/sources/mysql
githubIssueLabel: source-mysql
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ public List<AutoCloseableIterator<AirbyteMessage>> getIncrementalIterators(final
final List<AutoCloseableIterator<AirbyteMessage>> initialLoadIterator = new ArrayList<>(initialLoadHandler.getIncrementalIterators(
new ConfiguredAirbyteCatalog().withStreams(initialLoadStreams.streamsForInitialLoad()),
tableNameToTable,
emittedAt, true, true));
emittedAt, true, true, Optional.empty()));

// Build Cursor based iterator
final List<AutoCloseableIterator<AirbyteMessage>> cursorBasedIterator =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ public List<AutoCloseableIterator<AirbyteMessage>> getIncrementalIterators(
final Map<String, TableInfo<CommonField<MysqlType>>> tableNameToTable,
final Instant emittedAt,
final boolean decorateWithStartedStatus,
final boolean decorateWithCompletedStatus) {
final boolean decorateWithCompletedStatus,
final Optional<Duration> cdcInitialLoadTimeout) {
final List<AutoCloseableIterator<AirbyteMessage>> iteratorList = new ArrayList<>();
for (final ConfiguredAirbyteStream airbyteStream : catalog.getStreams()) {
final AirbyteStream stream = airbyteStream.getStream();
Expand All @@ -107,7 +108,7 @@ public List<AutoCloseableIterator<AirbyteMessage>> getIncrementalIterators(
new StreamStatusTraceEmitterIterator(new AirbyteStreamStatusHolder(pair, AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED)));
}

iteratorList.add(getIteratorForStream(airbyteStream, table, emittedAt));
iteratorList.add(getIteratorForStream(airbyteStream, table, emittedAt, cdcInitialLoadTimeout));
if (decorateWithCompletedStatus) {
iteratorList.add(new StreamStatusTraceEmitterIterator(
new AirbyteStreamStatusHolder(pair, AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE)));
Expand All @@ -121,7 +122,8 @@ public List<AutoCloseableIterator<AirbyteMessage>> getIncrementalIterators(
public AutoCloseableIterator<AirbyteMessage> getIteratorForStream(
@NotNull ConfiguredAirbyteStream airbyteStream,
@NotNull TableInfo<CommonField<MysqlType>> table,
@NotNull Instant emittedAt) {
@NotNull Instant emittedAt,
@NotNull final Optional<Duration> cdcInitialLoadTimeout) {

final AirbyteStream stream = airbyteStream.getStream();
final String streamName = stream.getName();
Expand All @@ -134,7 +136,8 @@ public AutoCloseableIterator<AirbyteMessage> getIteratorForStream(
.collect(Collectors.toList());
final AutoCloseableIterator<AirbyteRecordData> queryStream =
new MySqlInitialLoadRecordIterator(database, sourceOperations, quoteString, initialLoadStateManager, selectedDatabaseFields, pair,
Long.min(calculateChunkSize(tableSizeInfoMap.get(pair), pair), MAX_CHUNK_SIZE), isCompositePrimaryKey(airbyteStream));
Long.min(calculateChunkSize(tableSizeInfoMap.get(pair), pair), MAX_CHUNK_SIZE), isCompositePrimaryKey(airbyteStream), emittedAt,
cdcInitialLoadTimeout);
final AutoCloseableIterator<AirbyteMessage> recordIterator =
getRecordIterator(queryStream, streamName, namespace, emittedAt.toEpochMilli());
final AutoCloseableIterator<AirbyteMessage> recordAndMessageIterator = augmentWithState(recordIterator, airbyteStream, pair);
Expand Down
Loading

0 comments on commit 310e6bd

Please sign in to comment.