diff --git a/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/db/factory/DSLContextFactory.kt b/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/db/factory/DSLContextFactory.kt
index 1d95b235a8f5..2ecf74e0fd94 100644
--- a/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/db/factory/DSLContextFactory.kt
+++ b/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/db/factory/DSLContextFactory.kt
@@ -67,7 +67,7 @@ object DSLContextFactory {
         driverClassName: String,
         jdbcConnectionString: String?,
         dialect: SQLDialect?,
-        connectionProperties: Map<String?, String?>?,
+        connectionProperties: Map<String, String>?,
         connectionTimeout: Duration?
     ): DSLContext {
         return DSL.using(
diff --git a/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/db/factory/DataSourceFactory.kt b/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/db/factory/DataSourceFactory.kt
index 507a4f366bdb..0b3625d18dd2 100644
--- a/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/db/factory/DataSourceFactory.kt
+++ b/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/db/factory/DataSourceFactory.kt
@@ -50,7 +50,7 @@ object DataSourceFactory {
         password: String?,
         driverClassName: String,
         jdbcConnectionString: String?,
-        connectionProperties: Map<String?, String?>?,
+        connectionProperties: Map<String, String>?,
         connectionTimeout: Duration?
     ): DataSource {
         return DataSourceBuilder(username, password, driverClassName, jdbcConnectionString)
@@ -100,7 +100,7 @@ object DataSourceFactory {
         port: Int,
         database: String?,
         driverClassName: String,
-        connectionProperties: Map<String?, String?>?
+        connectionProperties: Map<String, String>?
     ): DataSource {
         return DataSourceBuilder(username, password, driverClassName, host, port, database)
             .withConnectionProperties(connectionProperties)
@@ -152,7 +152,7 @@ object DataSourceFactory {
         private var password: String?,
         private var driverClassName: String
     ) {
-        private var connectionProperties: Map<String?, String?> = java.util.Map.of()
+        private var connectionProperties: Map<String, String> = java.util.Map.of()
         private var database: String? = null
         private var host: String? = null
         private var jdbcUrl: String? = null
@@ -185,7 +185,7 @@ object DataSourceFactory {
         }

         fun withConnectionProperties(
-            connectionProperties: Map<String?, String?>?
+            connectionProperties: Map<String, String>?
         ): DataSourceBuilder {
             if (connectionProperties != null) {
                 this.connectionProperties = connectionProperties
diff --git a/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/db/jdbc/JdbcSourceOperations.kt b/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/db/jdbc/JdbcSourceOperations.kt
index 01e976ee8d71..29e13f73424a 100644
--- a/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/db/jdbc/JdbcSourceOperations.kt
+++ b/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/db/jdbc/JdbcSourceOperations.kt
@@ -16,7 +16,7 @@ import org.slf4j.LoggerFactory

 /** Implementation of source operations with standard JDBC types. */
 class JdbcSourceOperations :
-    AbstractJdbcCompatibleSourceOperations<JDBCType>(), SourceOperations<ResultSet, JDBCType?> {
+    AbstractJdbcCompatibleSourceOperations<JDBCType>(), SourceOperations<ResultSet, JDBCType> {
     protected fun safeGetJdbcType(columnTypeInt: Int): JDBCType {
         return try {
             JDBCType.valueOf(columnTypeInt)
@@ -80,12 +80,12 @@ class JdbcSourceOperations :
             JDBCType.TINYINT,
             JDBCType.SMALLINT -> setShortInt(preparedStatement, parameterIndex, value!!)
             JDBCType.INTEGER -> setInteger(preparedStatement, parameterIndex, value!!)
-            JDBCType.BIGINT -> setBigInteger(preparedStatement, parameterIndex, value)
+            JDBCType.BIGINT -> setBigInteger(preparedStatement, parameterIndex, value!!)
             JDBCType.FLOAT,
             JDBCType.DOUBLE -> setDouble(preparedStatement, parameterIndex, value!!)
             JDBCType.REAL -> setReal(preparedStatement, parameterIndex, value!!)
             JDBCType.NUMERIC,
-            JDBCType.DECIMAL -> setDecimal(preparedStatement, parameterIndex, value)
+            JDBCType.DECIMAL -> setDecimal(preparedStatement, parameterIndex, value!!)
             JDBCType.CHAR,
             JDBCType.NCHAR,
             JDBCType.NVARCHAR,
@@ -147,7 +147,7 @@ class JdbcSourceOperations :
         return JdbcUtils.ALLOWED_CURSOR_TYPES.contains(type)
     }

-    override fun getAirbyteType(jdbcType: JDBCType?): JsonSchemaType {
+    override fun getAirbyteType(jdbcType: JDBCType): JsonSchemaType {
         return when (jdbcType) {
             JDBCType.BIT,
             JDBCType.BOOLEAN -> JsonSchemaType.BOOLEAN
diff --git a/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/JdbcConnector.kt b/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/JdbcConnector.kt
index 297925119c87..e8ff27ccad66 100644
--- a/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/JdbcConnector.kt
+++ b/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/JdbcConnector.kt
@@ -11,7 +11,7 @@ import java.util.*

 abstract class JdbcConnector
 protected constructor(@JvmField protected val driverClassName: String) : BaseConnector() {
-    protected fun getConnectionTimeout(connectionProperties: Map<String?, String?>): Duration {
+    protected fun getConnectionTimeout(connectionProperties: Map<String, String>): Duration {
         return getConnectionTimeout(connectionProperties, driverClassName)
     }

@@ -37,7 +37,7 @@ protected constructor(@JvmField protected val driverClassName: String) : BaseCon
      * @return DataSourceBuilder class used to create dynamic fields for DataSource
      */
     fun getConnectionTimeout(
-        connectionProperties: Map<String?, String?>,
+        connectionProperties: Map<String, String>,
         driverClassName: String?
     ): Duration {
         val parsedConnectionTimeout =
diff --git a/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/base/Source.kt b/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/base/Source.kt
index 9a39069444ec..a1943109bae2 100644
--- a/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/base/Source.kt
+++ b/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/base/Source.kt
@@ -42,7 +42,7 @@ interface Source : Integration {
     @Throws(Exception::class)
     fun read(
         config: JsonNode,
-        catalog: ConfiguredAirbyteCatalog?,
+        catalog: ConfiguredAirbyteCatalog,
         state: JsonNode?
     ): AutoCloseableIterator<AirbyteMessage>

@@ -65,7 +65,7 @@ interface Source : Integration {
     @Throws(Exception::class)
     fun readStreams(
         config: JsonNode,
-        catalog: ConfiguredAirbyteCatalog?,
+        catalog: ConfiguredAirbyteCatalog,
         state: JsonNode?
     ): Collection<AutoCloseableIterator<AirbyteMessage>>? {
         return List.of(read(config, catalog, state))
diff --git a/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/base/spec_modification/SpecModifyingSource.kt b/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/base/spec_modification/SpecModifyingSource.kt
index 4a6306d28f13..07231c743c57 100644
--- a/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/base/spec_modification/SpecModifyingSource.kt
+++ b/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/base/spec_modification/SpecModifyingSource.kt
@@ -35,7 +35,7 @@ abstract class SpecModifyingSource(private val source: Source) : Source {
     @Throws(Exception::class)
     override fun read(
         config: JsonNode,
-        catalog: ConfiguredAirbyteCatalog?,
+        catalog: ConfiguredAirbyteCatalog,
         state: JsonNode?
     ): AutoCloseableIterator<AirbyteMessage> {
         return source.read(config, catalog, state)
@@ -44,7 +44,7 @@ abstract class SpecModifyingSource(private val source: Source) : Source {
     @Throws(Exception::class)
     override fun readStreams(
         config: JsonNode,
-        catalog: ConfiguredAirbyteCatalog?,
+        catalog: ConfiguredAirbyteCatalog,
         state: JsonNode?
     ): Collection<AutoCloseableIterator<AirbyteMessage>>? {
         return source.readStreams(config, catalog, state)
diff --git a/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/base/ssh/SshWrappedSource.kt b/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/base/ssh/SshWrappedSource.kt
index cd77066827d9..db045767eb8e 100644
--- a/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/base/ssh/SshWrappedSource.kt
+++ b/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/base/ssh/SshWrappedSource.kt
@@ -76,7 +76,7 @@ class SshWrappedSource : Source {
     @Throws(Exception::class)
     override fun read(
         config: JsonNode,
-        catalog: ConfiguredAirbyteCatalog?,
+        catalog: ConfiguredAirbyteCatalog,
         state: JsonNode?
     ): AutoCloseableIterator<AirbyteMessage> {
         val tunnel: SshTunnel = SshTunnel.Companion.getInstance(config, hostKey, portKey)
@@ -97,7 +97,7 @@ class SshWrappedSource : Source {
     @Throws(Exception::class)
     override fun readStreams(
         config: JsonNode,
-        catalog: ConfiguredAirbyteCatalog?,
+        catalog: ConfiguredAirbyteCatalog,
         state: JsonNode?
     ): Collection<AutoCloseableIterator<AirbyteMessage>>? {
         val tunnel: SshTunnel = SshTunnel.Companion.getInstance(config, hostKey, portKey)
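
Taken together, the core changes above tighten signatures that the Java-to-Kotlin conversion had left overly nullable: a configured catalog is now required at the type level, and connection properties are a plain Map<String, String>. A minimal, hypothetical caller-side sketch in Kotlin — the create overload matches the parameters visible in the DataSourceFactory hunk above, but the connection details are illustrative, not taken from the diff:

import io.airbyte.cdk.db.factory.DataSourceFactory
import java.time.Duration
import javax.sql.DataSource

// Hypothetical usage of the tightened factory signature: keys and values in
// the connection-properties map are non-null Strings, so no !! or null checks
// are needed downstream.
fun buildDataSource(): DataSource =
    DataSourceFactory.create(
        "user",
        "password",
        "org.postgresql.Driver",
        "jdbc:postgresql://localhost:5432/db",
        mapOf("ssl" to "false"),
        Duration.ofSeconds(60)
    )
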
diff --git a/airbyte-cdk/java/airbyte-cdk/core/src/main/resources/version.properties b/airbyte-cdk/java/airbyte-cdk/core/src/main/resources/version.properties
index 3a7b1b095571..c7be3358f550 100644
--- a/airbyte-cdk/java/airbyte-cdk/core/src/main/resources/version.properties
+++ b/airbyte-cdk/java/airbyte-cdk/core/src/main/resources/version.properties
@@ -1 +1 @@
-version=0.28.10
+version=0.28.11
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/build.gradle b/airbyte-cdk/java/airbyte-cdk/db-sources/build.gradle
index 5ac716385f16..3f21973ce07a 100644
--- a/airbyte-cdk/java/airbyte-cdk/db-sources/build.gradle
+++ b/airbyte-cdk/java/airbyte-cdk/db-sources/build.gradle
@@ -11,6 +11,11 @@ java {
     }
 }

+compileKotlin.compilerOptions.allWarningsAsErrors = false
+compileTestFixturesKotlin.compilerOptions.allWarningsAsErrors = false
+compileTestKotlin.compilerOptions.allWarningsAsErrors = false
+
+
 // Convert yaml to java: relationaldb.models
 jsonSchema2Pojo {
     sourceType = SourceType.YAMLSCHEMA
@@ -53,4 +58,5 @@ dependencies {

     testImplementation testFixtures(project(':airbyte-cdk:java:airbyte-cdk:datastore-postgres'))
     testImplementation 'uk.org.webcompere:system-stubs-jupiter:2.0.1'
+    testImplementation 'org.mockito.kotlin:mockito-kotlin:5.2.1'
 }
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/AirbyteDebeziumHandler.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/AirbyteDebeziumHandler.java
index cf190d2637f3..bbed6723cae4 100644
--- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/AirbyteDebeziumHandler.java
+++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/AirbyteDebeziumHandler.java
@@ -1,149 +1,3 @@
 /*
- * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+ * Copyright (c) 2024 Airbyte, Inc., all rights reserved.
  */
-
-package io.airbyte.cdk.integrations.debezium;
-
-import static io.airbyte.cdk.integrations.debezium.DebeziumIteratorConstants.*;
-
-import com.fasterxml.jackson.databind.JsonNode;
-import io.airbyte.cdk.db.jdbc.JdbcUtils;
-import io.airbyte.cdk.integrations.debezium.internals.*;
-import io.airbyte.cdk.integrations.source.relationaldb.state.SourceStateIterator;
-import io.airbyte.cdk.integrations.source.relationaldb.state.StateEmitFrequency;
-import io.airbyte.commons.util.AutoCloseableIterator;
-import io.airbyte.commons.util.AutoCloseableIterators;
-import io.airbyte.protocol.models.v0.AirbyteMessage;
-import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog;
-import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream;
-import io.airbyte.protocol.models.v0.SyncMode;
-import io.debezium.engine.ChangeEvent;
-import io.debezium.engine.DebeziumEngine;
-import java.time.Duration;
-import java.time.Instant;
-import java.time.temporal.ChronoUnit;
-import java.util.Optional;
-import java.util.concurrent.LinkedBlockingQueue;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-/**
- * This class acts as the bridge between Airbyte DB connectors and debezium. If a DB connector wants
- * to use debezium for CDC, it should use this class
- */
-public class AirbyteDebeziumHandler<T> {
-
-  private static final Logger LOGGER = LoggerFactory.getLogger(AirbyteDebeziumHandler.class);
-  /**
-   * We use 10000 as capacity cause the default queue size and batch size of debezium is :
-   * {@link io.debezium.config.CommonConnectorConfig#DEFAULT_MAX_BATCH_SIZE}is 2048
-   * {@link io.debezium.config.CommonConnectorConfig#DEFAULT_MAX_QUEUE_SIZE} is 8192
-   */
-  public static final int QUEUE_CAPACITY = 10_000;
-
-  private final JsonNode config;
-  private final CdcTargetPosition<T> targetPosition;
-  private final boolean trackSchemaHistory;
-  private final Duration firstRecordWaitTime, subsequentRecordWaitTime;
-  private final int queueSize;
-  private final boolean addDbNameToOffsetState;
-
-  public AirbyteDebeziumHandler(final JsonNode config,
-                                final CdcTargetPosition<T> targetPosition,
-                                final boolean trackSchemaHistory,
-                                final Duration firstRecordWaitTime,
-                                final Duration subsequentRecordWaitTime,
-                                final int queueSize,
-                                final boolean addDbNameToOffsetState) {
-    this.config = config;
-    this.targetPosition = targetPosition;
-    this.trackSchemaHistory = trackSchemaHistory;
-    this.firstRecordWaitTime = firstRecordWaitTime;
-    this.subsequentRecordWaitTime = subsequentRecordWaitTime;
-    this.queueSize = queueSize;
-    this.addDbNameToOffsetState = addDbNameToOffsetState;
-  }
-
-  class CapacityReportingBlockingQueue<E> extends LinkedBlockingQueue<E> {
-
-    private static Duration REPORT_DURATION = Duration.of(10, ChronoUnit.SECONDS);
-    private Instant lastReport;
-
-    CapacityReportingBlockingQueue(final int capacity) {
-      super(capacity);
-    }
-
-    private void reportQueueUtilization() {
-      if (lastReport == null || Duration.between(lastReport, Instant.now()).compareTo(REPORT_DURATION) > 0) {
-        LOGGER.info("CDC events queue size: {}. remaining {}", this.size(), this.remainingCapacity());
-        synchronized (this) {
-          lastReport = Instant.now();
-        }
-      }
-    }
-
-    @Override
-    public void put(final E e) throws InterruptedException {
-      reportQueueUtilization();
-      super.put(e);
-    }
-
-    @Override
-    public E poll() {
-      reportQueueUtilization();
-      return super.poll();
-    }
-
-  }
-
-  public AutoCloseableIterator<AirbyteMessage> getIncrementalIterators(final DebeziumPropertiesManager debeziumPropertiesManager,
-                                                                       final DebeziumEventConverter eventConverter,
-                                                                       final CdcSavedInfoFetcher cdcSavedInfoFetcher,
-                                                                       final CdcStateHandler cdcStateHandler) {
-    LOGGER.info("Using CDC: {}", true);
-    LOGGER.info("Using DBZ version: {}", DebeziumEngine.class.getPackage().getImplementationVersion());
-    final AirbyteFileOffsetBackingStore offsetManager = AirbyteFileOffsetBackingStore.initializeState(
-        cdcSavedInfoFetcher.getSavedOffset(),
-        addDbNameToOffsetState ? Optional.ofNullable(config.get(JdbcUtils.DATABASE_KEY).asText()) : Optional.empty());
-    final var schemaHistoryManager = trackSchemaHistory
-        ? Optional.of(AirbyteSchemaHistoryStorage.initializeDBHistory(
-            cdcSavedInfoFetcher.getSavedSchemaHistory(), cdcStateHandler.compressSchemaHistoryForState()))
-        : Optional.empty();
-    final var publisher = new DebeziumRecordPublisher(debeziumPropertiesManager);
-    final var queue = new CapacityReportingBlockingQueue<ChangeEvent<String, String>>(queueSize);
-    publisher.start(queue, offsetManager, schemaHistoryManager);
-    // handle state machine around pub/sub logic.
-    final AutoCloseableIterator<ChangeEventWithMetadata> eventIterator = new DebeziumRecordIterator<>(
-        queue,
-        targetPosition,
-        publisher::hasClosed,
-        new DebeziumShutdownProcedure<>(queue, publisher::close, publisher::hasClosed),
-        firstRecordWaitTime,
-        subsequentRecordWaitTime);
-
-    final Duration syncCheckpointDuration = config.has(SYNC_CHECKPOINT_DURATION_PROPERTY)
-        ? Duration.ofSeconds(config.get(SYNC_CHECKPOINT_DURATION_PROPERTY).asLong())
-        : SYNC_CHECKPOINT_DURATION;
-    final Long syncCheckpointRecords = config.has(SYNC_CHECKPOINT_RECORDS_PROPERTY)
-        ? config.get(SYNC_CHECKPOINT_RECORDS_PROPERTY).asLong()
-        : SYNC_CHECKPOINT_RECORDS;
-
-    DebeziumMessageProducer<T> messageProducer = new DebeziumMessageProducer<T>(cdcStateHandler,
-        targetPosition,
-        eventConverter,
-        offsetManager,
-        schemaHistoryManager);
-
-    // Usually sourceStateIterator requires airbyteStream as input. For DBZ iterator, stream is not used
-    // at all thus we will pass in null.
-    SourceStateIterator<ChangeEventWithMetadata> iterator =
-        new SourceStateIterator<>(eventIterator, null, messageProducer, new StateEmitFrequency(syncCheckpointRecords, syncCheckpointDuration));
-    return AutoCloseableIterators.fromIterator(iterator);
-  }
-
-  public static boolean isAnyStreamIncrementalSyncMode(final ConfiguredAirbyteCatalog catalog) {
-    return catalog.getStreams().stream().map(ConfiguredAirbyteStream::getSyncMode)
-        .anyMatch(syncMode -> syncMode == SyncMode.INCREMENTAL);
-  }
-
-}
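
The deleted handler's CapacityReportingBlockingQueue is a small, self-contained pattern worth keeping visible: a bounded queue that throttles its own utilization logging. A standalone Kotlin sketch of the same idea, with logging reduced to println and every name assumed rather than taken from the CDK:

import java.time.Duration
import java.time.Instant
import java.util.concurrent.LinkedBlockingQueue

// Hypothetical re-creation of the capacity-reporting queue pattern: report
// utilization at most once per reporting window as producers/consumers touch it.
class ReportingQueue<E>(capacity: Int) : LinkedBlockingQueue<E>(capacity) {
    private val reportInterval = Duration.ofSeconds(10)
    @Volatile private var lastReport: Instant? = null

    private fun reportUtilization() {
        val now = Instant.now()
        val last = lastReport
        if (last == null || Duration.between(last, now) > reportInterval) {
            println("queue size: $size, remaining capacity: ${remainingCapacity()}")
            lastReport = now
        }
    }

    override fun put(e: E) {
        reportUtilization()
        super.put(e)
    }

    override fun poll(): E? {
        reportUtilization()
        return super.poll()
    }
}
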
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/CdcMetadataInjector.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/CdcMetadataInjector.java
index 0c6eb87deed3..bbed6723cae4 100644
--- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/CdcMetadataInjector.java
+++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/CdcMetadataInjector.java
@@ -1,56 +1,3 @@
 /*
- * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+ * Copyright (c) 2024 Airbyte, Inc., all rights reserved.
  */
-
-package io.airbyte.cdk.integrations.debezium;
-
-import com.fasterxml.jackson.databind.JsonNode;
-import com.fasterxml.jackson.databind.node.ObjectNode;
-
-/**
- * This interface is used to add metadata to the records fetched from the database. For instance, in
- * Postgres we add the lsn to the records. In MySql we add the file name and position to the
- * records.
- */
-public interface CdcMetadataInjector<T> {
-
-  /**
-   * A debezium record contains multiple pieces. Ref :
-   * https://debezium.io/documentation/reference/1.9/connectors/mysql.html#mysql-create-events
-   *
-   * @param event is the actual record which contains data and would be written to the destination
-   * @param source contains the metadata about the record and we need to extract that metadata and add
-   *        it to the event before writing it to destination
-   */
-  void addMetaData(ObjectNode event, JsonNode source);
-
-  // TODO : Remove this - it is deprecated.
-  default void addMetaDataToRowsFetchedOutsideDebezium(final ObjectNode record, final String transactionTimestamp, final T metadataToAdd) {
-    throw new RuntimeException("Not Supported");
-  }
-
-  default void addMetaDataToRowsFetchedOutsideDebezium(final ObjectNode record) {
-    throw new RuntimeException("Not Supported");
-  }
-
-  /**
-   * As part of Airbyte record we need to add the namespace (schema name)
-   *
-   * @param source part of debezium record and contains the metadata about the record. We need to
-   *        extract namespace out of this metadata and return Ref :
-   *        https://debezium.io/documentation/reference/1.9/connectors/mysql.html#mysql-create-events
-   * @return the stream namespace extracted from the change event source.
-   */
-  String namespace(JsonNode source);
-
-  /**
-   * As part of Airbyte record we need to add the name (e.g. table name)
-   *
-   * @param source part of debezium record and contains the metadata about the record. We need to
-   *        extract namespace out of this metadata and return Ref :
-   *        https://debezium.io/documentation/reference/1.9/connectors/mysql.html#mysql-create-events
-   * @return The stream name extracted from the change event source.
-   */
-  String name(JsonNode source);
-
-}
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/CdcStateHandler.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/CdcStateHandler.java
deleted file mode 100644
index 7d0d64be1027..000000000000
--- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/CdcStateHandler.java
+++ /dev/null
@@ -1,35 +0,0 @@
-/*
- * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.cdk.integrations.debezium;
-
-import io.airbyte.cdk.integrations.debezium.internals.AirbyteSchemaHistoryStorage.SchemaHistory;
-import io.airbyte.protocol.models.v0.AirbyteMessage;
-import java.util.Map;
-
-/**
- * This interface is used to allow connectors to save the offset and schema history in the manner
- * which suits them. Also, it adds some utils to verify CDC event status.
- */
-public interface CdcStateHandler {
-
-  AirbyteMessage saveState(final Map<String, String> offset, final SchemaHistory<String> dbHistory);
-
-  AirbyteMessage saveStateAfterCompletionOfSnapshotOfNewStreams();
-
-  default boolean compressSchemaHistoryForState() {
-    return false;
-  }
-
-  /**
-   * This function is used as feature flag for sending state messages as checkpoints in CDC syncs.
-   *
-   * @return Returns `true` if checkpoint state messages are enabled for CDC syncs. Otherwise, it
-   *         returns `false`
-   */
-  default boolean isCdcCheckpointEnabled() {
-    return false;
-  }
-
-}
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/CdcTargetPosition.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/CdcTargetPosition.java
deleted file mode 100644
index 56ae64066283..000000000000
--- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/CdcTargetPosition.java
+++ /dev/null
@@ -1,80 +0,0 @@
-/*
- * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.cdk.integrations.debezium;
-
-import io.airbyte.cdk.integrations.debezium.internals.ChangeEventWithMetadata;
-import java.util.Map;
-
-/**
- * This interface is used to define the target position at the beginning of the sync so that once we
- * reach the desired target, we can shutdown the sync. This is needed because it might happen that
- * while we are syncing the data, new changes are being made in the source database and as a result
- * we might end up syncing forever.
- * In order to tackle that, we need to define a point to end at the
- * beginning of the sync
- */
-public interface CdcTargetPosition<T> {
-
-  /**
-   * Reads a position value (ex: LSN) from a change event and compares it to target position
-   *
-   * @param changeEventWithMetadata change event from Debezium with extra calculated metadata
-   * @return true if event position is equal or greater than target position, or if last snapshot
-   *         event
-   */
-  boolean reachedTargetPosition(final ChangeEventWithMetadata changeEventWithMetadata);
-
-  /**
-   * Reads a position value (lsn) from a change event and compares it to target lsn
-   *
-   * @param positionFromHeartbeat is the position extracted out of a heartbeat event (if the connector
-   *        supports heartbeat)
-   * @return true if heartbeat position is equal or greater than target position
-   */
-  default boolean reachedTargetPosition(final T positionFromHeartbeat) {
-    throw new UnsupportedOperationException();
-  }
-
-  /**
-   * Indicates whether the implementation supports heartbeat position.
-   *
-   * @return true if heartbeats are supported
-   */
-  default boolean isHeartbeatSupported() {
-    return false;
-  }
-
-  /**
-   * Returns a position value from a heartbeat event offset.
-   *
-   * @param sourceOffset source offset params from heartbeat change event
-   * @return the heartbeat position in a heartbeat change event or null
-   */
-  T extractPositionFromHeartbeatOffset(final Map<String, ?> sourceOffset);
-
-  /**
-   * This function checks if the event we are processing in the loop is already behind the offset so
-   * the process can safety save the state.
-   *
-   * @param offset DB CDC offset
-   * @param event Event from the CDC load
-   * @return Returns `true` when the event is ahead of the offset. Otherwise, it returns `false`
-   */
-  default boolean isEventAheadOffset(final Map<String, String> offset, final ChangeEventWithMetadata event) {
-    return false;
-  }
-
-  /**
-   * This function compares two offsets to make sure both are not pointing to the same position. The
-   * main purpose is to avoid sending same offset multiple times.
-   *
-   * @param offsetA Offset to compare
-   * @param offsetB Offset to compare
-   * @return Returns `true` if both offsets are at the same position. Otherwise, it returns `false`
-   */
-  default boolean isSameOffset(final Map<String, String> offsetA, final Map<String, String> offsetB) {
-    return false;
-  }
-
-}
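
To make the deleted contract concrete, here is a hypothetical standalone Kotlin analogue of CdcTargetPosition for a source whose position is a monotonically increasing Long (an LSN-like value). The class and method names are illustrative, not the CDK's API:

// Standalone sketch mirroring the shape of CdcTargetPosition<T> for T = Long.
class LongTargetPosition(private val target: Long) {
    // Analogue of reachedTargetPosition(ChangeEventWithMetadata): stop once an
    // event's position is at or past the target captured when the sync began.
    fun reached(eventPosition: Long, isLastSnapshotEvent: Boolean): Boolean =
        isLastSnapshotEvent || eventPosition >= target

    // Analogue of the heartbeat overload of reachedTargetPosition.
    fun reachedByHeartbeat(heartbeatPosition: Long): Boolean = heartbeatPosition >= target

    // Analogue of extractPositionFromHeartbeatOffset; the "lsn" key is assumed.
    fun positionFromOffset(sourceOffset: Map<String, Any?>): Long =
        sourceOffset.getValue("lsn").toString().toLong()

    // Analogue of isSameOffset: used to avoid emitting duplicate states.
    fun isSameOffset(a: Map<String, String>, b: Map<String, String>): Boolean = a == b
}
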
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/DebeziumIteratorConstants.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/DebeziumIteratorConstants.java
deleted file mode 100644
index 9148f93cdac4..000000000000
--- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/DebeziumIteratorConstants.java
+++ /dev/null
@@ -1,19 +0,0 @@
-/*
- * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.cdk.integrations.debezium;
-
-import java.time.Duration;
-
-public class DebeziumIteratorConstants {
-
-  public static final String SYNC_CHECKPOINT_DURATION_PROPERTY = "sync_checkpoint_seconds";
-  public static final String SYNC_CHECKPOINT_RECORDS_PROPERTY = "sync_checkpoint_records";
-
-  // TODO: Move these variables to a separate class IteratorConstants, as they will be used in state
-  // iterators for non debezium cases too.
-  public static final Duration SYNC_CHECKPOINT_DURATION = Duration.ofMinutes(15);
-  public static final Integer SYNC_CHECKPOINT_RECORDS = 10_000;
-
-}
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/AirbyteFileOffsetBackingStore.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/AirbyteFileOffsetBackingStore.java
deleted file mode 100644
index d1576c3f2868..000000000000
--- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/AirbyteFileOffsetBackingStore.java
+++ /dev/null
@@ -1,193 +0,0 @@
-/*
- * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.cdk.integrations.debezium.internals;
-
-import com.fasterxml.jackson.databind.JsonNode;
-import com.google.common.base.Preconditions;
-import io.airbyte.commons.json.Jsons;
-import java.io.EOFException;
-import java.io.IOException;
-import java.io.ObjectOutputStream;
-import java.nio.ByteBuffer;
-import java.nio.charset.StandardCharsets;
-import java.nio.file.Files;
-import java.nio.file.NoSuchFileException;
-import java.nio.file.Path;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.LinkedHashMap;
-import java.util.Map;
-import java.util.Optional;
-import java.util.Properties;
-import java.util.function.BiFunction;
-import java.util.stream.Collectors;
-import org.apache.commons.io.FileUtils;
-import org.apache.kafka.connect.errors.ConnectException;
-import org.apache.kafka.connect.util.SafeObjectInputStream;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-/**
- * This class handles reading and writing a debezium offset file. In many cases it is duplicating
- * logic in debezium because that logic is not exposed in the public API. We mostly treat the
- * contents of this state file like a black box. We know it is a Map&lt;ByteBuffer, Bytebuffer&gt;.
- * We deserialize it to a Map&lt;String, String&gt; so that the state file can be human readable. If
- * we ever discover that any of the contents of these offset files is not string serializable we
- * will likely have to drop the human readability support and just base64 encode it.
- */
-public class AirbyteFileOffsetBackingStore {
-
-  private static final Logger LOGGER = LoggerFactory.getLogger(AirbyteFileOffsetBackingStore.class);
-  private static final BiFunction<String, String, String> SQL_SERVER_STATE_MUTATION = (key, databaseName) -> key.substring(0, key.length() - 2)
-      + ",\"database\":\"" + databaseName + "\"" + key.substring(key.length() - 2);
-  private final Path offsetFilePath;
-  private final Optional<String> dbName;
-
-  public AirbyteFileOffsetBackingStore(final Path offsetFilePath, final Optional<String> dbName) {
-    this.offsetFilePath = offsetFilePath;
-    this.dbName = dbName;
-  }
-
-  public Map<String, String> read() {
-    final Map<ByteBuffer, ByteBuffer> raw = load();
-
-    return raw.entrySet().stream().collect(Collectors.toMap(
-        e -> byteBufferToString(e.getKey()),
-        e -> byteBufferToString(e.getValue())));
-  }
-
-  @SuppressWarnings("unchecked")
-  public void persist(final JsonNode cdcState) {
-    final Map<String, String> mapAsString =
-        cdcState != null ?
-            Jsons.object(cdcState, Map.class) : Collections.emptyMap();
-
-    final Map<String, String> updatedMap = updateStateForDebezium2_1(mapAsString);
-
-    final Map<ByteBuffer, ByteBuffer> mappedAsStrings = updatedMap.entrySet().stream().collect(Collectors.toMap(
-        e -> stringToByteBuffer(e.getKey()),
-        e -> stringToByteBuffer(e.getValue())));
-
-    FileUtils.deleteQuietly(offsetFilePath.toFile());
-    save(mappedAsStrings);
-  }
-
-  private Map<String, String> updateStateForDebezium2_1(final Map<String, String> mapAsString) {
-    final Map<String, String> updatedMap = new LinkedHashMap<>();
-    if (mapAsString.size() > 0) {
-      final String key = mapAsString.keySet().stream().toList().get(0);
-      final int i = key.indexOf('[');
-      final int i1 = key.lastIndexOf(']');
-
-      if (i == 0 && i1 == key.length() - 1) {
-        // The state is Debezium 2.1 compatible. No need to change anything.
-        return mapAsString;
-      }
-
-      LOGGER.info("Mutating sate to make it Debezium 2.1 compatible");
-      final String newKey = dbName.isPresent() ? SQL_SERVER_STATE_MUTATION.apply(key.substring(i, i1 + 1), dbName.get()) : key.substring(i, i1 + 1);
-      final String value = mapAsString.get(key);
-      updatedMap.put(newKey, value);
-    }
-    return updatedMap;
-  }
-
-  private static String byteBufferToString(final ByteBuffer byteBuffer) {
-    Preconditions.checkNotNull(byteBuffer);
-    return new String(byteBuffer.array(), StandardCharsets.UTF_8);
-  }
-
-  private static ByteBuffer stringToByteBuffer(final String s) {
-    Preconditions.checkNotNull(s);
-    return ByteBuffer.wrap(s.getBytes(StandardCharsets.UTF_8));
-  }
-
-  /**
-   * See FileOffsetBackingStore#load - logic is mostly borrowed from here. duplicated because this
-   * method is not public. Reduced the try catch block to only the read operation from original code
-   * to reduce errors when reading the file.
-   */
-  @SuppressWarnings("unchecked")
-  private Map<ByteBuffer, ByteBuffer> load() {
-    Object obj;
-    try (final SafeObjectInputStream is = new SafeObjectInputStream(Files.newInputStream(offsetFilePath))) {
-      // todo (cgardens) - we currently suppress a security warning for this line. use of readObject from
-      // untrusted sources is considered unsafe. Since the source is controlled by us in this case it
-      // should be safe. That said, changing this implementation to not use readObject would remove some
-      // headache.
-      obj = is.readObject();
-    } catch (final NoSuchFileException | EOFException e) {
-      // NoSuchFileException: Ignore, may be new.
-      // EOFException: Ignore, this means the file was missing or corrupt
-      return Collections.emptyMap();
-    } catch (final IOException | ClassNotFoundException e) {
-      throw new ConnectException(e);
-    }
-
-    if (!(obj instanceof HashMap))
-      throw new ConnectException("Expected HashMap but found " + obj.getClass());
-    final Map<byte[], byte[]> raw = (Map<byte[], byte[]>) obj;
-    final Map<ByteBuffer, ByteBuffer> data = new HashMap<>();
-    for (final Map.Entry<byte[], byte[]> mapEntry : raw.entrySet()) {
-      final ByteBuffer key = (mapEntry.getKey() != null) ? ByteBuffer.wrap(mapEntry.getKey()) : null;
-      final ByteBuffer value = (mapEntry.getValue() != null) ? ByteBuffer.wrap(mapEntry.getValue()) : null;
-      data.put(key, value);
-    }
-
-    return data;
-  }
-
-  /**
-   * See FileOffsetBackingStore#save - logic is mostly borrowed from here. duplicated because this
-   * method is not public.
-   */
-  private void save(final Map<ByteBuffer, ByteBuffer> data) {
-    try (final ObjectOutputStream os = new ObjectOutputStream(Files.newOutputStream(offsetFilePath))) {
-      final Map<byte[], byte[]> raw = new HashMap<>();
-      for (final Map.Entry<ByteBuffer, ByteBuffer> mapEntry : data.entrySet()) {
-        final byte[] key = (mapEntry.getKey() != null) ? mapEntry.getKey().array() : null;
-        final byte[] value = (mapEntry.getValue() != null) ?
-            mapEntry.getValue().array() : null;
-        raw.put(key, value);
-      }
-      os.writeObject(raw);
-    } catch (final IOException e) {
-      throw new ConnectException(e);
-    }
-  }
-
-  public static AirbyteFileOffsetBackingStore initializeState(final JsonNode cdcState, final Optional<String> dbName) {
-    final Path cdcWorkingDir;
-    try {
-      cdcWorkingDir = Files.createTempDirectory(Path.of("/tmp"), "cdc-state-offset");
-    } catch (final IOException e) {
-      throw new RuntimeException(e);
-    }
-    final Path cdcOffsetFilePath = cdcWorkingDir.resolve("offset.dat");
-
-    final AirbyteFileOffsetBackingStore offsetManager = new AirbyteFileOffsetBackingStore(cdcOffsetFilePath, dbName);
-    offsetManager.persist(cdcState);
-    return offsetManager;
-  }
-
-  public static AirbyteFileOffsetBackingStore initializeDummyStateForSnapshotPurpose() {
-    final Path cdcWorkingDir;
-    try {
-      cdcWorkingDir = Files.createTempDirectory(Path.of("/tmp"), "cdc-dummy-state-offset");
-    } catch (final IOException e) {
-      throw new RuntimeException(e);
-    }
-    final Path cdcOffsetFilePath = cdcWorkingDir.resolve("offset.dat");
-
-    return new AirbyteFileOffsetBackingStore(cdcOffsetFilePath, Optional.empty());
-  }
-
-  public void setDebeziumProperties(Properties props) {
-    // debezium engine configuration
-    // https://debezium.io/documentation/reference/2.2/development/engine.html#engine-properties
-    props.setProperty("offset.storage", "org.apache.kafka.connect.storage.FileOffsetBackingStore");
-    props.setProperty("offset.storage.file.filename", offsetFilePath.toString());
-    props.setProperty("offset.flush.interval.ms", "1000"); // todo: make this longer
-  }
-
-}
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/AirbyteSchemaHistoryStorage.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/AirbyteSchemaHistoryStorage.java
deleted file mode 100644
index a3525cc0a3c8..000000000000
--- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/AirbyteSchemaHistoryStorage.java
+++ /dev/null
@@ -1,234 +0,0 @@
-/*
- * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.cdk.integrations.debezium.internals;
-
-import com.fasterxml.jackson.databind.JsonNode;
-import com.google.common.annotations.VisibleForTesting;
-import io.airbyte.commons.json.Jsons;
-import io.debezium.document.Document;
-import io.debezium.document.DocumentReader;
-import io.debezium.document.DocumentWriter;
-import java.io.BufferedReader;
-import java.io.BufferedWriter;
-import java.io.ByteArrayInputStream;
-import java.io.ByteArrayOutputStream;
-import java.io.FileOutputStream;
-import java.io.IOException;
-import java.nio.charset.Charset;
-import java.nio.charset.StandardCharsets;
-import java.nio.file.FileAlreadyExistsException;
-import java.nio.file.Files;
-import java.nio.file.Path;
-import java.nio.file.StandardOpenOption;
-import java.util.Optional;
-import java.util.Properties;
-import java.util.zip.GZIPInputStream;
-import java.util.zip.GZIPOutputStream;
-import org.apache.commons.io.FileUtils;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-/**
- * The purpose of this class is : to , 1. Read the contents of the file {@link #path} which contains
- * the schema history at the end of the sync so that it can be saved in state for future syncs.
- * Check {@link #read()} 2.
- * Write the saved content back to the file {@link #path} at the beginning
- * of the sync so that debezium can function smoothly. Check persist(Optional&lt;JsonNode&gt;).
- */
-public class AirbyteSchemaHistoryStorage {
-
-  private static final Logger LOGGER = LoggerFactory.getLogger(AirbyteSchemaHistoryStorage.class);
-  private static final long SIZE_LIMIT_TO_COMPRESS_MB = 1;
-  public static final int ONE_MB = 1024 * 1024;
-  private static final Charset UTF8 = StandardCharsets.UTF_8;
-
-  private final DocumentReader reader = DocumentReader.defaultReader();
-  private final DocumentWriter writer = DocumentWriter.defaultWriter();
-  private final Path path;
-  private final boolean compressSchemaHistoryForState;
-
-  public AirbyteSchemaHistoryStorage(final Path path, final boolean compressSchemaHistoryForState) {
-    this.path = path;
-    this.compressSchemaHistoryForState = compressSchemaHistoryForState;
-  }
-
-  public record SchemaHistory<T> (T schema, boolean isCompressed) {}
-
-  public SchemaHistory<String> read() {
-    final double fileSizeMB = (double) path.toFile().length() / (ONE_MB);
-    if ((fileSizeMB > SIZE_LIMIT_TO_COMPRESS_MB) && compressSchemaHistoryForState) {
-      LOGGER.info("File Size {} MB is greater than the size limit of {} MB, compressing the content of the file.", fileSizeMB,
-          SIZE_LIMIT_TO_COMPRESS_MB);
-      final String schemaHistory = readCompressed();
-      final double compressedSizeMB = calculateSizeOfStringInMB(schemaHistory);
-      if (fileSizeMB > compressedSizeMB) {
-        LOGGER.info("Content Size post compression is {} MB ", compressedSizeMB);
-      } else {
-        throw new RuntimeException("Compressing increased the size of the content. Size before compression " + fileSizeMB + ", after compression "
-            + compressedSizeMB);
-      }
-      return new SchemaHistory<>(schemaHistory, true);
-    }
-    if (compressSchemaHistoryForState) {
-      LOGGER.info("File Size {} MB is less than the size limit of {} MB, reading the content of the file without compression.", fileSizeMB,
-          SIZE_LIMIT_TO_COMPRESS_MB);
-    } else {
-      LOGGER.info("File Size {} MB.", fileSizeMB);
-    }
-    final String schemaHistory = readUncompressed();
-    return new SchemaHistory<>(schemaHistory, false);
-  }
-
-  @VisibleForTesting
-  public String readUncompressed() {
-    final StringBuilder fileAsString = new StringBuilder();
-    try {
-      for (final String line : Files.readAllLines(path, UTF8)) {
-        if (line != null && !line.isEmpty()) {
-          final Document record = reader.read(line);
-          final String recordAsString = writer.write(record);
-          fileAsString.append(recordAsString);
-          fileAsString.append(System.lineSeparator());
-        }
-      }
-      return fileAsString.toString();
-    } catch (final IOException e) {
-      throw new RuntimeException(e);
-    }
-  }
-
-  private String readCompressed() {
-    final String lineSeparator = System.lineSeparator();
-    final ByteArrayOutputStream compressedStream = new ByteArrayOutputStream();
-    try (final GZIPOutputStream gzipOutputStream = new GZIPOutputStream(compressedStream);
-        final BufferedReader bufferedReader = Files.newBufferedReader(path, UTF8)) {
-      for (;;) {
-        final String line = bufferedReader.readLine();
-        if (line == null) {
-          break;
-        }
-
-        if (!line.isEmpty()) {
-          final Document record = reader.read(line);
-          final String recordAsString = writer.write(record);
-          gzipOutputStream.write(recordAsString.getBytes(StandardCharsets.UTF_8));
-          gzipOutputStream.write(lineSeparator.getBytes(StandardCharsets.UTF_8));
-        }
-      }
-    } catch (IOException e) {
-      throw new RuntimeException(e);
-    }
-    return Jsons.serialize(compressedStream.toByteArray());
-  }
-
-  private void makeSureFileExists() {
-    try {
-      // Make sure the file exists ...
-      if (!Files.exists(path)) {
-        // Create parent directories if we have them ...
-        if (path.getParent() != null) {
-          Files.createDirectories(path.getParent());
-        }
-        try {
-          Files.createFile(path);
-        } catch (final FileAlreadyExistsException e) {
-          // do nothing
-        }
-      }
-    } catch (final IOException e) {
-      throw new IllegalStateException(
-          "Unable to check or create history file at " + path + ": " + e.getMessage(), e);
-    }
-  }
-
-  private void persist(final SchemaHistory<Optional<JsonNode>> schemaHistory) {
-    if (schemaHistory.schema().isEmpty()) {
-      return;
-    }
-    final String fileAsString = Jsons.object(schemaHistory.schema().get(), String.class);
-
-    if (fileAsString == null || fileAsString.isEmpty()) {
-      return;
-    }
-
-    FileUtils.deleteQuietly(path.toFile());
-    makeSureFileExists();
-    if (schemaHistory.isCompressed()) {
-      writeCompressedStringToFile(fileAsString);
-    } else {
-      writeToFile(fileAsString);
-    }
-  }
-
-  /**
-   * @param fileAsString Represents the contents of the file saved in state from previous syncs
-   */
-  private void writeToFile(final String fileAsString) {
-    try {
-      final String[] split = fileAsString.split(System.lineSeparator());
-      for (final String element : split) {
-        final Document read = reader.read(element);
-        final String line = writer.write(read);
-
-        try (final BufferedWriter historyWriter = Files
-            .newBufferedWriter(path, StandardOpenOption.APPEND)) {
-          try {
-            historyWriter.append(line);
-            historyWriter.newLine();
-          } catch (final IOException e) {
-            throw new RuntimeException(e);
-          }
-        }
-      }
-    } catch (final IOException e) {
-      throw new RuntimeException(e);
-    }
-  }
-
-  private void writeCompressedStringToFile(final String compressedString) {
-    try (final ByteArrayInputStream inputStream = new ByteArrayInputStream(Jsons.deserialize(compressedString, byte[].class));
-        final GZIPInputStream gzipInputStream = new GZIPInputStream(inputStream);
-        final FileOutputStream fileOutputStream = new FileOutputStream(path.toFile())) {
-      final byte[] buffer = new byte[1024];
-      int bytesRead;
-      while ((bytesRead = gzipInputStream.read(buffer)) != -1) {
-        fileOutputStream.write(buffer, 0, bytesRead);
-      }
-    } catch (IOException e) {
-      throw new RuntimeException(e);
-    }
-  }
-
-  @VisibleForTesting
-  public static double calculateSizeOfStringInMB(final String string) {
-    return (double) string.getBytes(StandardCharsets.UTF_8).length / (ONE_MB);
-  }
-
-  public static AirbyteSchemaHistoryStorage initializeDBHistory(final SchemaHistory<Optional<JsonNode>> schemaHistory,
-                                                                final boolean compressSchemaHistoryForState) {
-    final Path dbHistoryWorkingDir;
-    try {
-      dbHistoryWorkingDir = Files.createTempDirectory(Path.of("/tmp"), "cdc-db-history");
-    } catch (final IOException e) {
-      throw new RuntimeException(e);
-    }
-    final Path dbHistoryFilePath = dbHistoryWorkingDir.resolve("dbhistory.dat");
-
-    final AirbyteSchemaHistoryStorage schemaHistoryManager =
-        new AirbyteSchemaHistoryStorage(dbHistoryFilePath, compressSchemaHistoryForState);
-    schemaHistoryManager.persist(schemaHistory);
-    return schemaHistoryManager;
-  }
-
-  public void setDebeziumProperties(Properties props) {
-    // https://debezium.io/documentation/reference/2.2/operations/debezium-server.html#debezium-source-database-history-class
-    // https://debezium.io/documentation/reference/development/engine.html#_in_the_code
-    // As mentioned in the documents above, debezium connector for MySQL needs to track the schema
-    // changes. If we don't do this, we can't fetch records for the table.
- props.setProperty("schema.history.internal", "io.debezium.storage.file.history.FileSchemaHistory"); - props.setProperty("schema.history.internal.file.filename", path.toString()); - props.setProperty("schema.history.internal.store.only.captured.databases.ddl", "true"); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/ChangeEventWithMetadata.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/ChangeEventWithMetadata.java deleted file mode 100644 index c6469b864877..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/ChangeEventWithMetadata.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.debezium.internals; - -import com.fasterxml.jackson.databind.JsonNode; -import io.airbyte.commons.json.Jsons; -import io.debezium.engine.ChangeEvent; - -public class ChangeEventWithMetadata { - - private final ChangeEvent event; - private final JsonNode eventKeyAsJson; - private final JsonNode eventValueAsJson; - private final SnapshotMetadata snapshotMetadata; - - public ChangeEventWithMetadata(final ChangeEvent event) { - this.event = event; - this.eventKeyAsJson = Jsons.deserialize(event.key()); - this.eventValueAsJson = Jsons.deserialize(event.value()); - this.snapshotMetadata = SnapshotMetadata.fromString(eventValueAsJson.get("source").get("snapshot").asText()); - } - - public ChangeEvent event() { - return event; - } - - public JsonNode eventKeyAsJson() { - return eventKeyAsJson; - } - - public JsonNode eventValueAsJson() { - return eventValueAsJson; - } - - public boolean isSnapshotEvent() { - return SnapshotMetadata.isSnapshotEventMetadata(snapshotMetadata); - } - - public SnapshotMetadata snapshotMetadata() { - return snapshotMetadata; - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumConverterUtils.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumConverterUtils.java deleted file mode 100644 index 4bb065476a41..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumConverterUtils.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.debezium.internals; - -import io.airbyte.cdk.db.DataTypeUtils; -import io.debezium.spi.converter.RelationalColumn; -import java.sql.Date; -import java.sql.Timestamp; -import java.time.Duration; -import java.time.LocalDate; -import java.time.LocalDateTime; -import java.time.format.DateTimeParseException; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public final class DebeziumConverterUtils { - - private static final Logger LOGGER = LoggerFactory.getLogger(DebeziumConverterUtils.class); - - private DebeziumConverterUtils() { - throw new UnsupportedOperationException(); - } - - /** - * TODO : Replace usage of this method with {@link io.airbyte.cdk.db.jdbc.DateTimeConverter} - */ - public static String convertDate(final Object input) { - /** - * While building this custom converter we were not sure what type debezium could return cause there - * is no mention of it in the documentation. 
-     * Secondly if you take a look at
-     * {@link io.debezium.connector.mysql.converters.TinyIntOneToBooleanConverter#converterFor(io.debezium.spi.converter.RelationalColumn, io.debezium.spi.converter.CustomConverter.ConverterRegistration)}
-     * method, even it is handling multiple data types but its not clear under what circumstances which
-     * data type would be returned. I just went ahead and handled the data types that made sense.
-     * Secondly, we use LocalDateTime to handle this cause it represents DATETIME datatype in JAVA
-     */
-    if (input instanceof LocalDateTime) {
-      return DataTypeUtils.toISO8601String((LocalDateTime) input);
-    } else if (input instanceof LocalDate) {
-      return DataTypeUtils.toISO8601String((LocalDate) input);
-    } else if (input instanceof Duration) {
-      return DataTypeUtils.toISO8601String((Duration) input);
-    } else if (input instanceof Timestamp) {
-      return DataTypeUtils.toISO8601StringWithMicroseconds((((Timestamp) input).toInstant()));
-    } else if (input instanceof Number) {
-      return DataTypeUtils.toISO8601String(
-          new Timestamp(((Number) input).longValue()).toLocalDateTime());
-    } else if (input instanceof Date) {
-      return DataTypeUtils.toISO8601String((Date) input);
-    } else if (input instanceof String) {
-      try {
-        return LocalDateTime.parse((String) input).toString();
-      } catch (final DateTimeParseException e) {
-        LOGGER.warn("Cannot convert value '{}' to LocalDateTime type", input);
-        return input.toString();
-      }
-    }
-    LOGGER.warn("Uncovered date class type '{}'. Use default converter", input.getClass().getName());
-    return input.toString();
-  }
-
-  public static Object convertDefaultValue(final RelationalColumn field) {
-    if (field.isOptional()) {
-      return null;
-    } else if (field.hasDefaultValue()) {
-      return field.defaultValue();
-    }
-    return null;
-  }
-
-}
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumEventConverter.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumEventConverter.java
deleted file mode 100644
index 74c6e026a0b9..000000000000
--- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumEventConverter.java
+++ /dev/null
@@ -1,66 +0,0 @@
-/*
- * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.cdk.integrations.debezium.internals;
-
-import com.fasterxml.jackson.databind.JsonNode;
-import com.fasterxml.jackson.databind.node.ObjectNode;
-import io.airbyte.cdk.integrations.debezium.CdcMetadataInjector;
-import io.airbyte.protocol.models.v0.AirbyteMessage;
-import io.airbyte.protocol.models.v0.AirbyteRecordMessage;
-import java.time.Instant;
-
-public interface DebeziumEventConverter {
-
-  String CDC_LSN = "_ab_cdc_lsn";
-  String CDC_UPDATED_AT = "_ab_cdc_updated_at";
-  String CDC_DELETED_AT = "_ab_cdc_deleted_at";
-  String AFTER_EVENT = "after";
-  String BEFORE_EVENT = "before";
-  String OPERATION_FIELD = "op";
-  String SOURCE_EVENT = "source";
-
-  static AirbyteMessage buildAirbyteMessage(
-      final JsonNode source,
-      final CdcMetadataInjector cdcMetadataInjector,
-      final Instant emittedAt,
-      final JsonNode data) {
-    final String streamNamespace = cdcMetadataInjector.namespace(source);
-    final String streamName = cdcMetadataInjector.name(source);
-
-    final AirbyteRecordMessage airbyteRecordMessage = new AirbyteRecordMessage()
-        .withStream(streamName)
-        .withNamespace(streamNamespace)
-        .withEmittedAt(emittedAt.toEpochMilli())
-        .withData(data);
-
-    return new AirbyteMessage()
-        .withType(AirbyteMessage.Type.RECORD)
-        .withRecord(airbyteRecordMessage);
-  }
-
-  static JsonNode addCdcMetadata(
-      final ObjectNode baseNode,
-      final JsonNode source,
-      final CdcMetadataInjector cdcMetadataInjector,
-      final boolean isDelete) {
-
-    final long transactionMillis = source.get("ts_ms").asLong();
-    final String transactionTimestamp = Instant.ofEpochMilli(transactionMillis).toString();
-
-    baseNode.put(CDC_UPDATED_AT, transactionTimestamp);
-    cdcMetadataInjector.addMetaData(baseNode, source);
-
-    if (isDelete) {
-      baseNode.put(CDC_DELETED_AT, transactionTimestamp);
-    } else {
-      baseNode.put(CDC_DELETED_AT, (String) null);
-    }
-
-    return baseNode;
-  }
-
-  AirbyteMessage toAirbyteMessage(final ChangeEventWithMetadata event);
-
-}
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumMessageProducer.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumMessageProducer.java
deleted file mode 100644
index 4b8c7c65d176..000000000000
--- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumMessageProducer.java
+++ /dev/null
@@ -1,144 +0,0 @@
-/*
- * Copyright (c) 2024 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.cdk.integrations.debezium.internals;
-
-import io.airbyte.cdk.integrations.debezium.CdcStateHandler;
-import io.airbyte.cdk.integrations.debezium.CdcTargetPosition;
-import io.airbyte.cdk.integrations.source.relationaldb.state.SourceStateMessageProducer;
-import io.airbyte.protocol.models.v0.AirbyteMessage;
-import io.airbyte.protocol.models.v0.AirbyteStateMessage;
-import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Optional;
-import org.apache.kafka.connect.errors.ConnectException;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-public class DebeziumMessageProducer<T> implements SourceStateMessageProducer<ChangeEventWithMetadata> {
-
-  private static final Logger LOGGER = LoggerFactory.getLogger(DebeziumMessageProducer.class);
-
-  private final CdcStateHandler cdcStateHandler;
-
-  /**
-   * `checkpointOffsetToSend` is used as temporal storage for the offset that we want to send as
-   * message.
-   * As Debezium is reading records faster that we process them, if we try to send
-   * `offsetManger.read()` offset, it is possible that the state is behind the record we are currently
-   * propagating. To avoid that, we store the offset as soon as we reach the checkpoint threshold
-   * (time or records) and we wait to send it until we are sure that the record we are processing is
-   * behind the offset to be sent.
-   */
-  private final HashMap<String, String> checkpointOffsetToSend = new HashMap<>();
-
-  /**
-   * `previousCheckpointOffset` is used to make sure we don't send duplicated states with the same
-   * offset. Is it possible that the offset Debezium report doesn't move for a period of time, and if
-   * we just rely on the `offsetManger.read()`, there is a chance to sent duplicate states, generating
-   * an unneeded usage of networking and processing.
-   */
-  private final HashMap<String, String> initialOffset, previousCheckpointOffset;
-  private final AirbyteFileOffsetBackingStore offsetManager;
-  private final CdcTargetPosition<T> targetPosition;
-  private final Optional<AirbyteSchemaHistoryStorage> schemaHistoryManager;
-
-  private boolean shouldEmitStateMessage = false;
-
-  private final DebeziumEventConverter eventConverter;
-
-  public DebeziumMessageProducer(
-                                 final CdcStateHandler cdcStateHandler,
-                                 final CdcTargetPosition<T> targetPosition,
-                                 final DebeziumEventConverter eventConverter,
-                                 final AirbyteFileOffsetBackingStore offsetManager,
-                                 final Optional<AirbyteSchemaHistoryStorage> schemaHistoryManager) {
-    this.cdcStateHandler = cdcStateHandler;
-    this.targetPosition = targetPosition;
-    this.eventConverter = eventConverter;
-    this.offsetManager = offsetManager;
-    if (offsetManager == null) {
-      throw new RuntimeException("Offset manager cannot be null");
-    }
-    this.schemaHistoryManager = schemaHistoryManager;
-    this.previousCheckpointOffset = (HashMap<String, String>) offsetManager.read();
-    this.initialOffset = new HashMap<>(this.previousCheckpointOffset);
-  }
-
-  @Override
-  public AirbyteStateMessage generateStateMessageAtCheckpoint(ConfiguredAirbyteStream stream) {
-    LOGGER.info("Sending CDC checkpoint state message.");
-    final AirbyteStateMessage stateMessage = createStateMessage(checkpointOffsetToSend);
-    previousCheckpointOffset.clear();
-    previousCheckpointOffset.putAll(checkpointOffsetToSend);
-    checkpointOffsetToSend.clear();
-    shouldEmitStateMessage = false;
-    return stateMessage;
-  }
-
-  /**
-   * @param stream
-   * @param message
-   * @return
-   */
-  @Override
-  public AirbyteMessage processRecordMessage(ConfiguredAirbyteStream stream, ChangeEventWithMetadata message) {
-
-    if (checkpointOffsetToSend.isEmpty()) {
-      try {
-        final HashMap<String, String> temporalOffset = (HashMap<String, String>) offsetManager.read();
-        if (!targetPosition.isSameOffset(previousCheckpointOffset, temporalOffset)) {
-          checkpointOffsetToSend.putAll(temporalOffset);
-        }
-      } catch (final ConnectException e) {
-        LOGGER.warn("Offset file is being written by Debezium. Skipping CDC checkpoint in this loop.");
-      }
-    }
-
-    if (checkpointOffsetToSend.size() == 1 && !message.isSnapshotEvent()) {
-      if (targetPosition.isEventAheadOffset(checkpointOffsetToSend, message)) {
-        shouldEmitStateMessage = true;
-      } else {
-        LOGGER.info("Encountered records with the same event offset.");
-      }
-    }
-
-    return eventConverter.toAirbyteMessage(message);
-  }
-
-  @Override
-  public AirbyteStateMessage createFinalStateMessage(ConfiguredAirbyteStream stream) {
-
-    final var syncFinishedOffset = (HashMap<String, String>) offsetManager.read();
-    if (targetPosition.isSameOffset(initialOffset, syncFinishedOffset)) {
-      // Edge case where no progress has been made: wrap up the
-      // sync by returning the initial offset instead of the
-      // current offset. We do this because we found that
-      // for some databases, heartbeats will cause Debezium to
-      // overwrite the offset file with a state which doesn't
-      // include all necessary data such as snapshot completion.
-      // This is the case for MS SQL Server, at least.
-      return createStateMessage(initialOffset);
-    }
-    return createStateMessage(syncFinishedOffset);
-  }
-
-  @Override
-  public boolean shouldEmitStateMessage(ConfiguredAirbyteStream stream) {
-    return shouldEmitStateMessage;
-  }
-
-  /**
-   * Creates {@link AirbyteStateMessage} while updating CDC data, used to checkpoint the state of the
-   * process.
-   *
-   * @return {@link AirbyteStateMessage} which includes offset and schema history if used.
-   */
-  private AirbyteStateMessage createStateMessage(final Map<String, String> offset) {
-    final AirbyteStateMessage message =
-        cdcStateHandler.saveState(offset, schemaHistoryManager.map(AirbyteSchemaHistoryStorage::read).orElse(null)).getState();
-    return message;
-  }
-
-}
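
The checkpoint-staging logic documented in the deleted producer above can be condensed into a hypothetical standalone Kotlin sketch: stage an offset only once it has actually moved, and emit it only after a record is provably ahead of it, so the emitted state never outruns the emitted records. All names here are assumed:

// Two-step checkpoint gate, mirroring checkpointOffsetToSend and
// previousCheckpointOffset in the deleted class.
class CheckpointGate(private var previous: Map<String, String>) {
    private var staged: Map<String, String>? = null

    // Returns the offset to checkpoint, or null if no state should be emitted yet.
    fun observe(
        currentOffset: Map<String, String>,
        eventIsAheadOf: (Map<String, String>) -> Boolean
    ): Map<String, String>? {
        // Stage only when the offset moved since the last checkpoint (dedup).
        if (staged == null && currentOffset != previous) staged = currentOffset
        // Emit only once the record being processed is ahead of the staged offset.
        val ready = staged?.takeIf(eventIsAheadOf) ?: return null
        previous = ready
        staged = null
        return ready
    }
}
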
- */ - -package io.airbyte.cdk.integrations.debezium.internals; - -import com.fasterxml.jackson.databind.JsonNode; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog; -import io.debezium.spi.common.ReplacementFunction; -import java.util.Optional; -import java.util.Properties; - -public abstract class DebeziumPropertiesManager { - - private static final String BYTE_VALUE_256_MB = Integer.toString(256 * 1024 * 1024); - - public static final String NAME_KEY = "name"; - public static final String TOPIC_PREFIX_KEY = "topic.prefix"; - - private final JsonNode config; - private final Properties properties; - private final ConfiguredAirbyteCatalog catalog; - - public DebeziumPropertiesManager(final Properties properties, - final JsonNode config, - final ConfiguredAirbyteCatalog catalog) { - this.properties = properties; - this.config = config; - this.catalog = catalog; - } - - public Properties getDebeziumProperties(final AirbyteFileOffsetBackingStore offsetManager) { - return getDebeziumProperties(offsetManager, Optional.empty()); - } - - public Properties getDebeziumProperties( - final AirbyteFileOffsetBackingStore offsetManager, - final Optional schemaHistoryManager) { - final Properties props = new Properties(); - props.putAll(properties); - - // debezium engine configuration - offsetManager.setDebeziumProperties(props); - // default values from debezium CommonConnectorConfig - props.setProperty("max.batch.size", "2048"); - props.setProperty("max.queue.size", "8192"); - - props.setProperty("errors.max.retries", "5"); - // This property must be strictly less than errors.retry.delay.max.ms - // (https://github.com/debezium/debezium/blob/bcc7d49519a4f07d123c616cfa45cd6268def0b9/debezium-core/src/main/java/io/debezium/util/DelayStrategy.java#L135) - props.setProperty("errors.retry.delay.initial.ms", "299"); - props.setProperty("errors.retry.delay.max.ms", "300"); - - schemaHistoryManager.ifPresent(m -> m.setDebeziumProperties(props)); - - // https://debezium.io/documentation/reference/2.2/configuration/avro.html - props.setProperty("key.converter.schemas.enable", "false"); - props.setProperty("value.converter.schemas.enable", "false"); - - // debezium names - props.setProperty(NAME_KEY, getName(config)); - - // connection configuration - props.putAll(getConnectionConfiguration(config)); - - // By default "decimal.handling.mode=precise", which causes this value to be returned as binary. - // The "double" type may cause a loss of precision, so set Debezium's config to store it as a String - // explicitly in its Kafka messages. For more details see: - // https://debezium.io/documentation/reference/2.2/connectors/postgresql.html#postgresql-decimal-types - // https://debezium.io/documentation/faq/#how_to_retrieve_decimal_field_from_binary_representation - props.setProperty("decimal.handling.mode", "string"); - - // https://debezium.io/documentation/reference/2.2/connectors/postgresql.html#postgresql-property-max-queue-size-in-bytes - props.setProperty("max.queue.size.in.bytes", BYTE_VALUE_256_MB); - - // WARNING: Never change the value of this, otherwise all the connectors would start syncing from - // scratch.
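
The topic prefix set just below has to stay within the character set that Debezium's topic naming strategy accepts, which is what sanitizeTopicPrefix (defined further down in this class) enforces by substituting underscores. A small illustration with a hypothetical name:

    final String prefix = DebeziumPropertiesManager.sanitizeTopicPrefix("prod db/v2");
    // "prod_db_v2": ' ' and '/' fall outside [A-Za-z0-9._-] and are replaced with '_'
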
- props.setProperty(TOPIC_PREFIX_KEY, sanitizeTopicPrefix(getName(config))); - // https://issues.redhat.com/browse/DBZ-7635 - // https://cwiki.apache.org/confluence/display/KAFKA/KIP-581%3A+Value+of+optional+null+field+which+has+default+value - // A null value in a column with default value won't be generated correctly in CDC unless we set the - // following - props.setProperty("value.converter.replace.null.with.default", "false"); - // includes - props.putAll(getIncludeConfiguration(catalog, config)); - - return props; - } - - public static String sanitizeTopicPrefix(final String topicName) { - StringBuilder sanitizedNameBuilder = new StringBuilder(topicName.length()); - boolean changed = false; - - for (int i = 0; i < topicName.length(); ++i) { - char c = topicName.charAt(i); - if (isValidCharacter(c)) { - sanitizedNameBuilder.append(c); - } else { - sanitizedNameBuilder.append(ReplacementFunction.UNDERSCORE_REPLACEMENT.replace(c)); - changed = true; - } - } - - if (changed) { - return sanitizedNameBuilder.toString(); - } else { - return topicName; - } - } - - // We need to keep the validation rule the same as debezium engine, which is defined here: - // https://github.com/debezium/debezium/blob/c51ef3099a688efb41204702d3aa6d4722bb4825/debezium-core/src/main/java/io/debezium/schema/AbstractTopicNamingStrategy.java#L178 - private static boolean isValidCharacter(char c) { - return c == '.' || c == '_' || c == '-' || c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' || c >= '0' && c <= '9'; - } - - protected abstract Properties getConnectionConfiguration(final JsonNode config); - - protected abstract String getName(final JsonNode config); - - protected abstract Properties getIncludeConfiguration(final ConfiguredAirbyteCatalog catalog, final JsonNode config); - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumRecordIterator.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumRecordIterator.java deleted file mode 100644 index aa0dd90ea5ab..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumRecordIterator.java +++ /dev/null @@ -1,250 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.debezium.internals; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.AbstractIterator; -import io.airbyte.cdk.integrations.debezium.CdcTargetPosition; -import io.airbyte.commons.lang.MoreBooleans; -import io.airbyte.commons.util.AutoCloseableIterator; -import io.debezium.engine.ChangeEvent; -import java.lang.reflect.Field; -import java.time.Duration; -import java.time.LocalDateTime; -import java.util.HashMap; -import java.util.Map; -import java.util.Objects; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.TimeUnit; -import java.util.function.Supplier; -import org.apache.kafka.connect.source.SourceRecord; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * The record iterator is the consumer (in the producer / consumer relationship with debezium) - * responsible for 1. making sure every record produced by the record publisher is processed 2. - * signalling to the record publisher when it is time for it to stop producing records. 
It emits - * this signal either when the publisher has not produced a new record for a long time or when it - * has processed at least all of the records that were present in the database when the source was - * started. Because the publisher might publish more records between the consumer sending this - * signal and the publisher actually shutting down, the consumer must stay alive as long as the - * publisher is not closed. Even after the publisher is closed, the consumer will finish processing - * any produced records before closing. - */ -public class DebeziumRecordIterator extends AbstractIterator - implements AutoCloseableIterator { - - private static final Logger LOGGER = LoggerFactory.getLogger(DebeziumRecordIterator.class); - - private final Map, Field> heartbeatEventSourceField; - private final LinkedBlockingQueue> queue; - private final CdcTargetPosition targetPosition; - private final Supplier publisherStatusSupplier; - private final Duration firstRecordWaitTime, subsequentRecordWaitTime; - private final DebeziumShutdownProcedure> debeziumShutdownProcedure; - - private boolean receivedFirstRecord; - private boolean hasSnapshotFinished; - private LocalDateTime tsLastHeartbeat; - private T lastHeartbeatPosition; - private int maxInstanceOfNoRecordsFound; - private boolean signalledDebeziumEngineShutdown; - - public DebeziumRecordIterator(final LinkedBlockingQueue> queue, - final CdcTargetPosition targetPosition, - final Supplier publisherStatusSupplier, - final DebeziumShutdownProcedure> debeziumShutdownProcedure, - final Duration firstRecordWaitTime, - final Duration subsequentRecordWaitTime) { - this.queue = queue; - this.targetPosition = targetPosition; - this.publisherStatusSupplier = publisherStatusSupplier; - this.debeziumShutdownProcedure = debeziumShutdownProcedure; - this.firstRecordWaitTime = firstRecordWaitTime; - this.subsequentRecordWaitTime = firstRecordWaitTime.dividedBy(2); - this.heartbeatEventSourceField = new HashMap<>(1); - - this.receivedFirstRecord = false; - this.hasSnapshotFinished = true; - this.tsLastHeartbeat = null; - this.lastHeartbeatPosition = null; - this.maxInstanceOfNoRecordsFound = 0; - this.signalledDebeziumEngineShutdown = false; - } - - // The following logic incorporates heartbeat (CDC postgres only for now): - // 1. Wait on the queue for the configured first-record time, or half of it once a record has been received - // 2. If nothing came out of the queue, finish the sync - // 3. If received heartbeat: check if heartbeat_lsn reached target or hasn't changed in a while; if so, - // finish sync - // 4. If the change event lsn reached the target, finish the sync - // 5. Otherwise check the message queue again - @Override - protected ChangeEventWithMetadata computeNext() { - // keep trying until the publisher is closed or until the queue is empty. The latter case is - // possible when the publisher has shut down but the consumer has not yet processed all messages it - // emitted. - while (!MoreBooleans.isTruthy(publisherStatusSupplier.get()) || !queue.isEmpty()) { - final ChangeEvent next; - - final Duration waitTime = receivedFirstRecord ? this.subsequentRecordWaitTime : this.firstRecordWaitTime; - try { - next = queue.poll(waitTime.getSeconds(), TimeUnit.SECONDS); - } catch (final InterruptedException e) { - throw new RuntimeException(e); - } - - // if the consumer could not get a record within the timeout, it is time to tell the producer to - // shut down.
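
Note that, as written, the constructor above derives the subsequent wait by halving firstRecordWaitTime rather than using the subsequentRecordWaitTime argument it receives. To make the resulting policy concrete, assuming the 5-minute default first-record wait:

    final Duration firstRecordWaitTime = Duration.ofMinutes(5);                  // RecordWaitTimeUtil default
    final Duration subsequentRecordWaitTime = firstRecordWaitTime.dividedBy(2);  // 150s poll timeout
    // heartbeatPosNotChanging() below uses the same 150s bound to declare a heartbeat position stale
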
- if (next == null) { - if (!receivedFirstRecord || hasSnapshotFinished || maxInstanceOfNoRecordsFound >= 10) { - requestClose(String.format("No records were returned by Debezium within %s seconds, closing the engine and iterator", - waitTime.getSeconds())); - } - LOGGER.info("no record found. polling again."); - maxInstanceOfNoRecordsFound++; - continue; - } - - if (isHeartbeatEvent(next)) { - if (!hasSnapshotFinished) { - continue; - } - - final T heartbeatPos = getHeartbeatPosition(next); - // wrap up sync if heartbeat position crossed the target OR heartbeat position hasn't changed for - // too long - if (targetPosition.reachedTargetPosition(heartbeatPos)) { - requestClose("Closing: Heartbeat indicates sync is done by reaching the target position"); - } else if (heartbeatPos.equals(this.lastHeartbeatPosition) && heartbeatPosNotChanging()) { - requestClose("Closing: Heartbeat indicates sync is not progressing"); - } - - if (!heartbeatPos.equals(lastHeartbeatPosition)) { - this.tsLastHeartbeat = LocalDateTime.now(); - this.lastHeartbeatPosition = heartbeatPos; - } - continue; - } - - final ChangeEventWithMetadata changeEventWithMetadata = new ChangeEventWithMetadata(next); - hasSnapshotFinished = !changeEventWithMetadata.isSnapshotEvent(); - - // if the last record matches the target file position, it is time to tell the producer to shut down. - if (targetPosition.reachedTargetPosition(changeEventWithMetadata)) { - requestClose("Closing: Change event reached target position"); - } - this.tsLastHeartbeat = null; - this.lastHeartbeatPosition = null; - this.receivedFirstRecord = true; - this.maxInstanceOfNoRecordsFound = 0; - return changeEventWithMetadata; - } - - if (!signalledDebeziumEngineShutdown) { - LOGGER.warn("Debezium engine has not been signalled to shut down; this is unexpected"); - } - - // Read the records that Debezium might have fetched right at the time we called shutdown - while (!debeziumShutdownProcedure.getRecordsRemainingAfterShutdown().isEmpty()) { - final ChangeEvent event; - try { - event = debeziumShutdownProcedure.getRecordsRemainingAfterShutdown().poll(100, TimeUnit.MILLISECONDS); - } catch (final InterruptedException e) { - throw new RuntimeException(e); - } - if (event == null || isHeartbeatEvent(event)) { - continue; - } - final ChangeEventWithMetadata changeEventWithMetadata = new ChangeEventWithMetadata(event); - hasSnapshotFinished = !changeEventWithMetadata.isSnapshotEvent(); - return changeEventWithMetadata; - } - throwExceptionIfSnapshotNotFinished(); - return endOfData(); - } - - /** - * Debezium was built as an ever-running process which keeps on listening for new changes on the DB and - * immediately processing them. Airbyte needs debezium to work as a start-stop mechanism. In order - * to determine when to stop the debezium engine we rely on a few factors: 1. TargetPosition logic. At the - * beginning of the sync we define a target position in the logs of the DB. This can be an LSN or - * anything specific to the DB which can help us identify that we have reached a specific position - * in the log-based replication. When we start processing records from debezium, we extract the - * log position from the metadata of the record and compare it with our target that we defined at - * the beginning of the sync. If we have reached the target position, we shut down the debezium - * engine. 2.
The TargetPosition logic might not always work and in order to tackle that we have - * another logic where if we do not receive records from debezium for a given duration, we ask - * debezium engine to shutdown 3. We also take the Snapshot into consideration, when a connector is - * running for the first time, we let it complete the snapshot and only after the completion of - * snapshot we should shutdown the engine. If we are closing the engine before completion of - * snapshot, we throw an exception - */ - @Override - public void close() throws Exception { - requestClose("Closing: Iterator closing"); - } - - private boolean isHeartbeatEvent(final ChangeEvent event) { - return targetPosition.isHeartbeatSupported() && Objects.nonNull(event) && !event.value().contains("source"); - } - - private boolean heartbeatPosNotChanging() { - if (this.tsLastHeartbeat == null) { - return false; - } - final Duration timeElapsedSinceLastHeartbeatTs = Duration.between(this.tsLastHeartbeat, LocalDateTime.now()); - LOGGER.info("Time since last hb_pos change {}s", timeElapsedSinceLastHeartbeatTs.toSeconds()); - // wait time for no change in heartbeat position is half of initial waitTime - return timeElapsedSinceLastHeartbeatTs.compareTo(this.firstRecordWaitTime.dividedBy(2)) > 0; - } - - private void requestClose(final String closeLogMessage) { - if (signalledDebeziumEngineShutdown) { - return; - } - LOGGER.info(closeLogMessage); - debeziumShutdownProcedure.initiateShutdownProcedure(); - signalledDebeziumEngineShutdown = true; - } - - private void throwExceptionIfSnapshotNotFinished() { - if (!hasSnapshotFinished) { - throw new RuntimeException("Closing down debezium engine but snapshot has not finished"); - } - } - - /** - * {@link DebeziumRecordIterator#heartbeatEventSourceField} acts as a cache so that we avoid using - * reflection to setAccessible for each event - */ - @VisibleForTesting - protected T getHeartbeatPosition(final ChangeEvent heartbeatEvent) { - - try { - final Class eventClass = heartbeatEvent.getClass(); - final Field f; - if (heartbeatEventSourceField.containsKey(eventClass)) { - f = heartbeatEventSourceField.get(eventClass); - } else { - f = eventClass.getDeclaredField("sourceRecord"); - f.setAccessible(true); - heartbeatEventSourceField.put(eventClass, f); - - if (heartbeatEventSourceField.size() > 1) { - LOGGER.warn("Field Cache size growing beyond expected size of 1, size is " + heartbeatEventSourceField.size()); - } - } - - final SourceRecord sr = (SourceRecord) f.get(heartbeatEvent); - return targetPosition.extractPositionFromHeartbeatOffset(sr.sourceOffset()); - } catch (final NoSuchFieldException | IllegalAccessException e) { - LOGGER.info("failed to get heartbeat source offset"); - throw new RuntimeException(e); - } - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumRecordPublisher.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumRecordPublisher.java deleted file mode 100644 index 93a05b70f586..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumRecordPublisher.java +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
- */ - -package io.airbyte.cdk.integrations.debezium.internals; - -import io.debezium.engine.ChangeEvent; -import io.debezium.engine.DebeziumEngine; -import io.debezium.engine.format.Json; -import io.debezium.engine.spi.OffsetCommitPolicy; -import java.util.Optional; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * The purpose of this class is to initialize and spawn the debezium engine with the right - * properties to fetch records - */ -public class DebeziumRecordPublisher implements AutoCloseable { - - private static final Logger LOGGER = LoggerFactory.getLogger(DebeziumRecordPublisher.class); - private final ExecutorService executor; - private DebeziumEngine> engine; - private final AtomicBoolean hasClosed; - private final AtomicBoolean isClosing; - private final AtomicReference thrownError; - private final CountDownLatch engineLatch; - private final DebeziumPropertiesManager debeziumPropertiesManager; - - public DebeziumRecordPublisher(DebeziumPropertiesManager debeziumPropertiesManager) { - this.debeziumPropertiesManager = debeziumPropertiesManager; - this.hasClosed = new AtomicBoolean(false); - this.isClosing = new AtomicBoolean(false); - this.thrownError = new AtomicReference<>(); - this.executor = Executors.newSingleThreadExecutor(); - this.engineLatch = new CountDownLatch(1); - } - - public void start(final BlockingQueue> queue, - final AirbyteFileOffsetBackingStore offsetManager, - final Optional schemaHistoryManager) { - engine = DebeziumEngine.create(Json.class) - .using(debeziumPropertiesManager.getDebeziumProperties(offsetManager, schemaHistoryManager)) - .using(new OffsetCommitPolicy.AlwaysCommitOffsetPolicy()) - .notifying(e -> { - // debezium outputs a tombstone event that has a value of null. this is an artifact of how it - // interacts with kafka. we want to ignore it. - // more on the tombstone: - // https://debezium.io/documentation/reference/2.2/transformations/event-flattening.html - if (e.value() != null) { - try { - queue.put(e); - } catch (final InterruptedException ex) { - Thread.currentThread().interrupt(); - throw new RuntimeException(ex); - } - } - }) - .using((success, message, error) -> { - LOGGER.info("Debezium engine shutdown. Engine terminated successfully : {}", success); - LOGGER.info(message); - if (!success) { - if (error != null) { - thrownError.set(error); - } else { - // There are cases where Debezium doesn't succeed but only fills the message field. - // In that case, we still want to fail loud and clear - thrownError.set(new RuntimeException(message)); - } - } - engineLatch.countDown(); - }) - .build(); - - // Run the engine asynchronously ... - executor.execute(engine); - } - - public boolean hasClosed() { - return hasClosed.get(); - } - - public void close() throws Exception { - if (isClosing.compareAndSet(false, true)) { - // consumers should assume records can be produced until engine has closed. 
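
For orientation, a caller drives this publisher as start, drain, close. A minimal sketch (the queue capacity and the manager instances are placeholders, not values this class prescribes):

    final BlockingQueue<ChangeEvent<String, String>> queue = new LinkedBlockingQueue<>(10_000);
    final DebeziumRecordPublisher publisher = new DebeziumRecordPublisher(propertiesManager);
    publisher.start(queue, offsetManager, Optional.of(schemaHistoryManager));
    // ... consume events from the queue until the target position is reached ...
    publisher.close(); // isClosing guards the shutdown sequence against double invocation
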
- if (engine != null) { - engine.close(); - } - - // wait for closure before shutting down executor service - engineLatch.await(5, TimeUnit.MINUTES); - - // shut down and await for thread to actually go down - executor.shutdown(); - executor.awaitTermination(5, TimeUnit.MINUTES); - - // after the engine is completely off, we can mark this as closed - hasClosed.set(true); - - if (thrownError.get() != null) { - throw new RuntimeException(thrownError.get()); - } - } - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumShutdownProcedure.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumShutdownProcedure.java deleted file mode 100644 index d0661e0a7cdc..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumShutdownProcedure.java +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.debezium.internals; - -import io.airbyte.commons.concurrency.VoidCallable; -import io.airbyte.commons.lang.MoreBooleans; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.TimeUnit; -import java.util.function.Supplier; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * This class has the logic for shutting down Debezium Engine in graceful manner. We made it Generic - * to allow us to write tests easily. - */ -public class DebeziumShutdownProcedure { - - private static final Logger LOGGER = LoggerFactory.getLogger(DebeziumShutdownProcedure.class); - private final LinkedBlockingQueue sourceQueue; - private final LinkedBlockingQueue targetQueue; - private final ExecutorService executorService; - private final Supplier publisherStatusSupplier; - private final VoidCallable debeziumThreadRequestClose; - private Throwable exception; - private boolean hasTransferThreadShutdown; - - public DebeziumShutdownProcedure(final LinkedBlockingQueue sourceQueue, - final VoidCallable debeziumThreadRequestClose, - final Supplier publisherStatusSupplier) { - this.sourceQueue = sourceQueue; - this.targetQueue = new LinkedBlockingQueue<>(); - this.debeziumThreadRequestClose = debeziumThreadRequestClose; - this.publisherStatusSupplier = publisherStatusSupplier; - this.hasTransferThreadShutdown = false; - this.executorService = Executors.newSingleThreadExecutor(r -> { - final Thread thread = new Thread(r, "queue-data-transfer-thread"); - thread.setUncaughtExceptionHandler((t, e) -> { - exception = e; - }); - return thread; - }); - } - - private Runnable transfer() { - return () -> { - while (!sourceQueue.isEmpty() || !hasEngineShutDown()) { - try { - final T event = sourceQueue.poll(100, TimeUnit.MILLISECONDS); - if (event != null) { - targetQueue.put(event); - } - } catch (final InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException(e); - } - } - }; - } - - private boolean hasEngineShutDown() { - return MoreBooleans.isTruthy(publisherStatusSupplier.get()); - } - - private void initiateTransfer() { - executorService.execute(transfer()); - } - - public LinkedBlockingQueue getRecordsRemainingAfterShutdown() { - if (!hasTransferThreadShutdown) { - LOGGER.warn("Queue transfer thread has not shut down, some records might be missing."); - } - return targetQueue; - } - - /** - * This 
method triggers the shutdown of Debezium Engine. When we trigger Debezium shutdown, the main - * thread pauses, as a result we stop reading data from the {@link sourceQueue} and since the queue - * is of fixed size, if it's already at capacity, Debezium won't be able to put remaining records in - * the queue. So before we trigger Debezium shutdown, we initiate a transfer of the records from the - * {@link sourceQueue} to a new queue i.e. {@link targetQueue}. This allows Debezium to continue to - * put records in the {@link sourceQueue} and once done, gracefully shutdown. After the shutdown is - * complete we just have to read the remaining records from the {@link targetQueue} - */ - public void initiateShutdownProcedure() { - if (hasEngineShutDown()) { - LOGGER.info("Debezium Engine has already shut down."); - return; - } - Exception exceptionDuringEngineClose = null; - try { - initiateTransfer(); - debeziumThreadRequestClose.call(); - } catch (final Exception e) { - exceptionDuringEngineClose = e; - throw new RuntimeException(e); - } finally { - try { - shutdownTransferThread(); - } catch (final Exception e) { - if (exceptionDuringEngineClose != null) { - e.addSuppressed(exceptionDuringEngineClose); - throw e; - } - } - } - } - - private void shutdownTransferThread() { - executorService.shutdown(); - boolean terminated = false; - while (!terminated) { - try { - terminated = executorService.awaitTermination(5, TimeUnit.MINUTES); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException(e); - } - } - hasTransferThreadShutdown = true; - if (exception != null) { - throw new RuntimeException(exception); - } - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumStateUtil.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumStateUtil.java deleted file mode 100644 index 243faccb5939..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumStateUtil.java +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.debezium.internals; - -import io.debezium.config.Configuration; -import io.debezium.embedded.KafkaConnectUtil; -import java.util.Map; -import java.util.Properties; -import org.apache.kafka.connect.json.JsonConverter; -import org.apache.kafka.connect.json.JsonConverterConfig; -import org.apache.kafka.connect.runtime.WorkerConfig; -import org.apache.kafka.connect.runtime.standalone.StandaloneConfig; -import org.apache.kafka.connect.storage.FileOffsetBackingStore; -import org.apache.kafka.connect.storage.OffsetStorageReaderImpl; - -/** - * Represents a utility class that assists with the parsing of Debezium offset state. - */ -public interface DebeziumStateUtil { - - /** - * The name of the Debezium property that contains the unique name for the Debezium connector. - */ - String CONNECTOR_NAME_PROPERTY = "name"; - - /** - * Configuration for offset state key/value converters. - */ - Map INTERNAL_CONVERTER_CONFIG = Map.of(JsonConverterConfig.SCHEMAS_ENABLE_CONFIG, Boolean.FALSE.toString()); - - /** - * Creates and starts a {@link FileOffsetBackingStore} that is used to store the tracked Debezium - * offset state. - * - * @param properties The Debezium configuration properties for the selected Debezium connector. 
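
The helpers on this interface are meant to compose: the store built from the connector's Debezium properties feeds the reader that exposes the saved offset. A minimal sketch (stateUtil and props are hypothetical stand-ins for an implementor and its configured properties):

    final FileOffsetBackingStore store = stateUtil.getFileOffsetBackingStore(props);
    final OffsetStorageReaderImpl reader = stateUtil.getOffsetStorageReader(store, props);
    // reader.offset(partition) then returns the persisted Debezium offset map for that partition
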
- * @return A configured and started {@link FileOffsetBackingStore} instance. - */ - default FileOffsetBackingStore getFileOffsetBackingStore(final Properties properties) { - final FileOffsetBackingStore fileOffsetBackingStore = KafkaConnectUtil.fileOffsetBackingStore(); - final Map propertiesMap = Configuration.from(properties).asMap(); - propertiesMap.put(WorkerConfig.KEY_CONVERTER_CLASS_CONFIG, JsonConverter.class.getName()); - propertiesMap.put(WorkerConfig.VALUE_CONVERTER_CLASS_CONFIG, JsonConverter.class.getName()); - fileOffsetBackingStore.configure(new StandaloneConfig(propertiesMap)); - fileOffsetBackingStore.start(); - return fileOffsetBackingStore; - } - - /** - * Creates and returns a {@link JsonConverter} that can be used to parse keys in the Debezium offset - * state storage. - * - * @return A {@link JsonConverter} for key conversion. - */ - default JsonConverter getKeyConverter() { - final JsonConverter keyConverter = new JsonConverter(); - keyConverter.configure(INTERNAL_CONVERTER_CONFIG, true); - return keyConverter; - } - - /** - * Creates and returns an {@link OffsetStorageReaderImpl} instance that can be used to load offset - * state from the offset file storage. - * - * @param fileOffsetBackingStore The {@link FileOffsetBackingStore} that contains the offset state - * saved to disk. - * @param properties The Debezium configuration properties for the selected Debezium connector. - * @return An {@link OffsetStorageReaderImpl} instance that can be used to load the offset state - * from the offset file storage. - */ - default OffsetStorageReaderImpl getOffsetStorageReader(final FileOffsetBackingStore fileOffsetBackingStore, final Properties properties) { - return new OffsetStorageReaderImpl(fileOffsetBackingStore, properties.getProperty(CONNECTOR_NAME_PROPERTY), getKeyConverter(), - getValueConverter()); - } - - /** - * Creates and returns a {@link JsonConverter} that can be used to parse values in the Debezium - * offset state storage. - * - * @return A {@link JsonConverter} for value conversion. - */ - default JsonConverter getValueConverter() { - final JsonConverter valueConverter = new JsonConverter(); - valueConverter.configure(INTERNAL_CONVERTER_CONFIG, false); - return valueConverter; - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/RecordWaitTimeUtil.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/RecordWaitTimeUtil.java deleted file mode 100644 index be44f8882044..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/RecordWaitTimeUtil.java +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
- */ - -package io.airbyte.cdk.integrations.debezium.internals; - -import com.fasterxml.jackson.databind.JsonNode; -import java.time.Duration; -import java.util.Optional; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class RecordWaitTimeUtil { - - private static final Logger LOGGER = LoggerFactory.getLogger(RecordWaitTimeUtil.class); - - public static final Duration MIN_FIRST_RECORD_WAIT_TIME = Duration.ofMinutes(2); - public static final Duration MAX_FIRST_RECORD_WAIT_TIME = Duration.ofMinutes(40); - public static final Duration DEFAULT_FIRST_RECORD_WAIT_TIME = Duration.ofMinutes(5); - public static final Duration DEFAULT_SUBSEQUENT_RECORD_WAIT_TIME = Duration.ofMinutes(1); - - public static void checkFirstRecordWaitTime(final JsonNode config) { - // we need to skip the check because in tests, we set initial_waiting_seconds - // to 5 seconds for performance reasons, which is shorter than the minimum - // value allowed in production - if (config.has("is_test") && config.get("is_test").asBoolean()) { - return; - } - - final Optional firstRecordWaitSeconds = getFirstRecordWaitSeconds(config); - if (firstRecordWaitSeconds.isPresent()) { - final int seconds = firstRecordWaitSeconds.get(); - if (seconds < MIN_FIRST_RECORD_WAIT_TIME.getSeconds() || seconds > MAX_FIRST_RECORD_WAIT_TIME.getSeconds()) { - throw new IllegalArgumentException( - String.format("initial_waiting_seconds must be between %d and %d seconds", - MIN_FIRST_RECORD_WAIT_TIME.getSeconds(), MAX_FIRST_RECORD_WAIT_TIME.getSeconds())); - } - } - } - - public static Duration getFirstRecordWaitTime(final JsonNode config) { - final boolean isTest = config.has("is_test") && config.get("is_test").asBoolean(); - Duration firstRecordWaitTime = DEFAULT_FIRST_RECORD_WAIT_TIME; - - final Optional firstRecordWaitSeconds = getFirstRecordWaitSeconds(config); - if (firstRecordWaitSeconds.isPresent()) { - firstRecordWaitTime = Duration.ofSeconds(firstRecordWaitSeconds.get()); - if (!isTest && firstRecordWaitTime.compareTo(MIN_FIRST_RECORD_WAIT_TIME) < 0) { - LOGGER.warn("First record waiting time is overridden to {} minutes, which is the min time allowed for safety.", - MIN_FIRST_RECORD_WAIT_TIME.toMinutes()); - firstRecordWaitTime = MIN_FIRST_RECORD_WAIT_TIME; - } else if (!isTest && firstRecordWaitTime.compareTo(MAX_FIRST_RECORD_WAIT_TIME) > 0) { - LOGGER.warn("First record waiting time is overridden to {} minutes, which is the max time allowed for safety.", - MAX_FIRST_RECORD_WAIT_TIME.toMinutes()); - firstRecordWaitTime = MAX_FIRST_RECORD_WAIT_TIME; - } - } - - LOGGER.info("First record waiting time: {} seconds", firstRecordWaitTime.getSeconds()); - return firstRecordWaitTime; - } - - public static Duration getSubsequentRecordWaitTime(final JsonNode config) { - Duration subsequentRecordWaitTime = DEFAULT_SUBSEQUENT_RECORD_WAIT_TIME; - final boolean isTest = config.has("is_test") && config.get("is_test").asBoolean(); - final Optional firstRecordWaitSeconds = getFirstRecordWaitSeconds(config); - if (isTest && firstRecordWaitSeconds.isPresent()) { - // In tests, reuse the initial_waiting_seconds property to speed things up. 
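
These helpers read the override from the replication_method block of the connector configuration. A small sketch of the lookup, assuming the Jsons helper from airbyte-commons and an illustrative config fragment:

    final JsonNode config = Jsons.deserialize(
        "{ \"replication_method\": { \"initial_waiting_seconds\": 300 } }");
    final Duration firstWait = RecordWaitTimeUtil.getFirstRecordWaitTime(config);
    // 300 seconds sits inside the [2 min, 40 min] clamp above, so it is used as given
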
- subsequentRecordWaitTime = Duration.ofSeconds(firstRecordWaitSeconds.get()); - } - LOGGER.info("Subsequent record waiting time: {} seconds", subsequentRecordWaitTime.getSeconds()); - return subsequentRecordWaitTime; - } - - public static Optional getFirstRecordWaitSeconds(final JsonNode config) { - final JsonNode replicationMethod = config.get("replication_method"); - if (replicationMethod != null && replicationMethod.has("initial_waiting_seconds")) { - final int seconds = config.get("replication_method").get("initial_waiting_seconds").asInt(); - return Optional.of(seconds); - } - return Optional.empty(); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/RelationalDbDebeziumEventConverter.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/RelationalDbDebeziumEventConverter.java deleted file mode 100644 index 4003007ba807..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/RelationalDbDebeziumEventConverter.java +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.debezium.internals; - -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.node.ObjectNode; -import io.airbyte.cdk.integrations.debezium.CdcMetadataInjector; -import io.airbyte.protocol.models.v0.AirbyteMessage; -import java.time.Instant; - -public class RelationalDbDebeziumEventConverter implements DebeziumEventConverter { - - private final CdcMetadataInjector cdcMetadataInjector; - private final Instant emittedAt; - - public RelationalDbDebeziumEventConverter(CdcMetadataInjector cdcMetadataInjector, Instant emittedAt) { - this.cdcMetadataInjector = cdcMetadataInjector; - this.emittedAt = emittedAt; - } - - @Override - public AirbyteMessage toAirbyteMessage(ChangeEventWithMetadata event) { - final JsonNode debeziumEvent = event.eventValueAsJson(); - final JsonNode before = debeziumEvent.get(DebeziumEventConverter.BEFORE_EVENT); - final JsonNode after = debeziumEvent.get(DebeziumEventConverter.AFTER_EVENT); - final JsonNode source = debeziumEvent.get(DebeziumEventConverter.SOURCE_EVENT); - - final ObjectNode baseNode = (ObjectNode) (after.isNull() ? before : after); - final JsonNode data = DebeziumEventConverter.addCdcMetadata(baseNode, source, cdcMetadataInjector, after.isNull()); - return DebeziumEventConverter.buildAirbyteMessage(source, cdcMetadataInjector, emittedAt, data); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/RelationalDbDebeziumPropertiesManager.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/RelationalDbDebeziumPropertiesManager.java deleted file mode 100644 index 53af1cf72656..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/RelationalDbDebeziumPropertiesManager.java +++ /dev/null @@ -1,116 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
- */ - -package io.airbyte.cdk.integrations.debezium.internals; - -import com.fasterxml.jackson.databind.JsonNode; -import io.airbyte.cdk.db.jdbc.JdbcUtils; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream; -import io.airbyte.protocol.models.v0.SyncMode; -import java.util.Iterator; -import java.util.Properties; -import java.util.regex.Pattern; -import java.util.stream.Collectors; -import java.util.stream.StreamSupport; -import org.codehaus.plexus.util.StringUtils; - -public class RelationalDbDebeziumPropertiesManager extends DebeziumPropertiesManager { - - public RelationalDbDebeziumPropertiesManager(final Properties properties, - final JsonNode config, - final ConfiguredAirbyteCatalog catalog) { - super(properties, config, catalog); - } - - @Override - protected Properties getConnectionConfiguration(JsonNode config) { - final Properties properties = new Properties(); - - // db connection configuration - properties.setProperty("database.hostname", config.get(JdbcUtils.HOST_KEY).asText()); - properties.setProperty("database.port", config.get(JdbcUtils.PORT_KEY).asText()); - properties.setProperty("database.user", config.get(JdbcUtils.USERNAME_KEY).asText()); - properties.setProperty("database.dbname", config.get(JdbcUtils.DATABASE_KEY).asText()); - - if (config.has(JdbcUtils.PASSWORD_KEY)) { - properties.setProperty("database.password", config.get(JdbcUtils.PASSWORD_KEY).asText()); - } - - return properties; - } - - @Override - protected String getName(JsonNode config) { - return config.get(JdbcUtils.DATABASE_KEY).asText(); - } - - @Override - protected Properties getIncludeConfiguration(ConfiguredAirbyteCatalog catalog, JsonNode config) { - final Properties properties = new Properties(); - - // table selection - properties.setProperty("table.include.list", getTableIncludelist(catalog)); - // column selection - properties.setProperty("column.include.list", getColumnIncludeList(catalog)); - - return properties; - } - - public static String getTableIncludelist(final ConfiguredAirbyteCatalog catalog) { - // Turn "stream": { - // "namespace": "schema1" - // "name": "table1 - // }, - // "stream": { - // "namespace": "schema2" - // "name": "table2 - // } -------> info "schema1.table1, schema2.table2" - - return catalog.getStreams().stream() - .filter(s -> s.getSyncMode() == SyncMode.INCREMENTAL) - .map(ConfiguredAirbyteStream::getStream) - .map(stream -> stream.getNamespace() + "." + stream.getName()) - // debezium needs commas escaped to split properly - .map(x -> StringUtils.escape(Pattern.quote(x), ",".toCharArray(), "\\,")) - .collect(Collectors.joining(",")); - } - - public static String getColumnIncludeList(final ConfiguredAirbyteCatalog catalog) { - // Turn "stream": { - // "namespace": "schema1" - // "name": "table1" - // "jsonSchema": { - // "properties": { - // "column1": { - // }, - // "column2": { - // } - // } - // } - // } -------> info "schema1.table1.(column1 | column2)" - - return catalog.getStreams().stream() - .filter(s -> s.getSyncMode() == SyncMode.INCREMENTAL) - .map(ConfiguredAirbyteStream::getStream) - .map(s -> { - final String fields = parseFields(s.getJsonSchema().get("properties").fieldNames()); - // schema.table.(col1|col2) - return Pattern.quote(s.getNamespace() + "." + s.getName()) + (StringUtils.isNotBlank(fields) ? "\\." 
+ fields : ""); - }) - .map(x -> StringUtils.escape(x, ",".toCharArray(), "\\,")) - .collect(Collectors.joining(",")); - } - - private static String parseFields(final Iterator fieldNames) { - if (fieldNames == null || !fieldNames.hasNext()) { - return ""; - } - final Iterable iter = () -> fieldNames; - return StreamSupport.stream(iter.spliterator(), false) - .map(f -> Pattern.quote(f)) - .collect(Collectors.joining("|", "(", ")")); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/SnapshotMetadata.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/SnapshotMetadata.java deleted file mode 100644 index 35dcbe119f9d..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/internals/SnapshotMetadata.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.debezium.internals; - -import com.google.common.collect.ImmutableSet; -import java.util.HashMap; -import java.util.Map; -import java.util.Set; - -public enum SnapshotMetadata { - - FIRST, - FIRST_IN_DATA_COLLECTION, - LAST_IN_DATA_COLLECTION, - TRUE, - LAST, - FALSE, - NULL; - - private static final Set ENTRIES_OF_SNAPSHOT_EVENTS = - ImmutableSet.of(TRUE, FIRST, FIRST_IN_DATA_COLLECTION, LAST_IN_DATA_COLLECTION); - private static final Map STRING_TO_ENUM; - static { - STRING_TO_ENUM = new HashMap<>(12); - STRING_TO_ENUM.put("true", TRUE); - STRING_TO_ENUM.put("TRUE", TRUE); - STRING_TO_ENUM.put("false", FALSE); - STRING_TO_ENUM.put("FALSE", FALSE); - STRING_TO_ENUM.put("last", LAST); - STRING_TO_ENUM.put("LAST", LAST); - STRING_TO_ENUM.put("first", FIRST); - STRING_TO_ENUM.put("FIRST", FIRST); - STRING_TO_ENUM.put("last_in_data_collection", LAST_IN_DATA_COLLECTION); - STRING_TO_ENUM.put("LAST_IN_DATA_COLLECTION", LAST_IN_DATA_COLLECTION); - STRING_TO_ENUM.put("first_in_data_collection", FIRST_IN_DATA_COLLECTION); - STRING_TO_ENUM.put("FIRST_IN_DATA_COLLECTION", FIRST_IN_DATA_COLLECTION); - STRING_TO_ENUM.put("NULL", NULL); - STRING_TO_ENUM.put("null", NULL); - } - - public static SnapshotMetadata fromString(final String value) { - if (STRING_TO_ENUM.containsKey(value)) { - return STRING_TO_ENUM.get(value); - } - throw new RuntimeException("ENUM value not found for " + value); - } - - public static boolean isSnapshotEventMetadata(final SnapshotMetadata snapshotMetadata) { - return ENTRIES_OF_SNAPSHOT_EVENTS.contains(snapshotMetadata); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/jdbc/AbstractJdbcSource.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/jdbc/AbstractJdbcSource.java deleted file mode 100644 index bd5b83880826..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/jdbc/AbstractJdbcSource.java +++ /dev/null @@ -1,520 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
- */ - -package io.airbyte.cdk.integrations.source.jdbc; - -import static io.airbyte.cdk.db.jdbc.JdbcConstants.INTERNAL_COLUMN_NAME; -import static io.airbyte.cdk.db.jdbc.JdbcConstants.INTERNAL_COLUMN_SIZE; -import static io.airbyte.cdk.db.jdbc.JdbcConstants.INTERNAL_COLUMN_TYPE; -import static io.airbyte.cdk.db.jdbc.JdbcConstants.INTERNAL_COLUMN_TYPE_NAME; -import static io.airbyte.cdk.db.jdbc.JdbcConstants.INTERNAL_DECIMAL_DIGITS; -import static io.airbyte.cdk.db.jdbc.JdbcConstants.INTERNAL_IS_NULLABLE; -import static io.airbyte.cdk.db.jdbc.JdbcConstants.INTERNAL_SCHEMA_NAME; -import static io.airbyte.cdk.db.jdbc.JdbcConstants.INTERNAL_TABLE_NAME; -import static io.airbyte.cdk.db.jdbc.JdbcConstants.JDBC_COLUMN_COLUMN_NAME; -import static io.airbyte.cdk.db.jdbc.JdbcConstants.JDBC_COLUMN_DATABASE_NAME; -import static io.airbyte.cdk.db.jdbc.JdbcConstants.JDBC_COLUMN_DATA_TYPE; -import static io.airbyte.cdk.db.jdbc.JdbcConstants.JDBC_COLUMN_SCHEMA_NAME; -import static io.airbyte.cdk.db.jdbc.JdbcConstants.JDBC_COLUMN_SIZE; -import static io.airbyte.cdk.db.jdbc.JdbcConstants.JDBC_COLUMN_TABLE_NAME; -import static io.airbyte.cdk.db.jdbc.JdbcConstants.JDBC_COLUMN_TYPE_NAME; -import static io.airbyte.cdk.db.jdbc.JdbcConstants.JDBC_DECIMAL_DIGITS; -import static io.airbyte.cdk.db.jdbc.JdbcConstants.JDBC_IS_NULLABLE; -import static io.airbyte.cdk.db.jdbc.JdbcConstants.KEY_SEQ; -import static io.airbyte.cdk.integrations.source.relationaldb.RelationalDbQueryUtils.enquoteIdentifier; -import static io.airbyte.cdk.integrations.source.relationaldb.RelationalDbQueryUtils.enquoteIdentifierList; -import static io.airbyte.cdk.integrations.source.relationaldb.RelationalDbQueryUtils.getFullyQualifiedTableNameWithQuoting; - -import com.fasterxml.jackson.databind.JsonNode; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Sets; -import datadog.trace.api.Trace; -import io.airbyte.cdk.db.JdbcCompatibleSourceOperations; -import io.airbyte.cdk.db.SqlDatabase; -import io.airbyte.cdk.db.factory.DataSourceFactory; -import io.airbyte.cdk.db.jdbc.AirbyteRecordData; -import io.airbyte.cdk.db.jdbc.JdbcDatabase; -import io.airbyte.cdk.db.jdbc.JdbcUtils; -import io.airbyte.cdk.db.jdbc.StreamingJdbcDatabase; -import io.airbyte.cdk.db.jdbc.streaming.JdbcStreamingQueryConfig; -import io.airbyte.cdk.integrations.base.Source; -import io.airbyte.cdk.integrations.source.jdbc.dto.JdbcPrivilegeDto; -import io.airbyte.cdk.integrations.source.relationaldb.AbstractDbSource; -import io.airbyte.cdk.integrations.source.relationaldb.CursorInfo; -import io.airbyte.cdk.integrations.source.relationaldb.TableInfo; -import io.airbyte.cdk.integrations.source.relationaldb.state.StateManager; -import io.airbyte.commons.functional.CheckedConsumer; -import io.airbyte.commons.json.Jsons; -import io.airbyte.commons.stream.AirbyteStreamUtils; -import io.airbyte.commons.util.AutoCloseableIterator; -import io.airbyte.commons.util.AutoCloseableIterators; -import io.airbyte.protocol.models.CommonField; -import io.airbyte.protocol.models.JsonSchemaType; -import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream; -import io.airbyte.protocol.models.v0.SyncMode; -import java.sql.Connection; -import java.sql.PreparedStatement; -import java.sql.ResultSet; -import java.sql.SQLException; -import 
java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.Comparator; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.function.Predicate; -import java.util.function.Supplier; -import java.util.stream.Collectors; -import java.util.stream.Stream; -import javax.sql.DataSource; -import org.apache.commons.lang3.tuple.ImmutablePair; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * This class contains helper functions and boilerplate for implementing a source connector for a - * relational DB source which can be accessed via JDBC driver. If you are implementing a connector - * for a relational DB which has a JDBC driver, make an effort to use this class. - */ -public abstract class AbstractJdbcSource extends AbstractDbSource implements Source { - - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractJdbcSource.class); - - protected final Supplier streamingQueryConfigProvider; - protected final JdbcCompatibleSourceOperations sourceOperations; - - protected String quoteString; - protected Collection dataSources = new ArrayList<>(); - - public AbstractJdbcSource(final String driverClass, - final Supplier streamingQueryConfigProvider, - final JdbcCompatibleSourceOperations sourceOperations) { - super(driverClass); - this.streamingQueryConfigProvider = streamingQueryConfigProvider; - this.sourceOperations = sourceOperations; - } - - @Override - protected AutoCloseableIterator queryTableFullRefresh(final JdbcDatabase database, - final List columnNames, - final String schemaName, - final String tableName, - final SyncMode syncMode, - final Optional cursorField) { - LOGGER.info("Queueing query for table: {}", tableName); - final io.airbyte.protocol.models.AirbyteStreamNameNamespacePair airbyteStream = - AirbyteStreamUtils.convertFromNameAndNamespace(tableName, schemaName); - return AutoCloseableIterators.lazyIterator(() -> { - try { - final Stream stream = database.unsafeQuery( - connection -> { - LOGGER.info("Preparing query for table: {}", tableName); - final String fullTableName = getFullyQualifiedTableNameWithQuoting(schemaName, tableName, getQuoteString()); - - final String wrappedColumnNames = getWrappedColumnNames(database, connection, columnNames, schemaName, tableName); - final StringBuilder sql = new StringBuilder(String.format("SELECT %s FROM %s", - wrappedColumnNames, - fullTableName)); - // if the connector emits intermediate states, the incremental query must be sorted by the cursor - // field - if (syncMode.equals(SyncMode.INCREMENTAL) && getStateEmissionFrequency() > 0) { - final String quotedCursorField = enquoteIdentifier(cursorField.get(), getQuoteString()); - sql.append(String.format(" ORDER BY %s ASC", quotedCursorField)); - } - - final PreparedStatement preparedStatement = connection.prepareStatement(sql.toString()); - LOGGER.info("Executing query for table {}: {}", tableName, preparedStatement); - return preparedStatement; - }, - sourceOperations::convertDatabaseRowToAirbyteRecordData); - return AutoCloseableIterators.fromStream(stream, airbyteStream); - } catch (final SQLException e) { - throw new RuntimeException(e); - } - }, airbyteStream); - } - - /** - * Configures a list of operations that can be used to check the connection to the source. - * - * @return list of consumers that run queries for the check command. 
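
Subclasses can extend the returned list with dialect-specific probes. An illustrative (purely hypothetical) addition alongside the metadata query used below:

    return ImmutableList.of(
        database -> database.bufferedResultSetQuery(
            connection -> connection.getMetaData().getCatalogs(), sourceOperations::rowToJson),
        database -> database.execute("SELECT 1")); // hypothetical extra connectivity probe
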
- */ - @Trace(operationName = CHECK_TRACE_OPERATION_NAME) - protected List> getCheckOperations(final JsonNode config) throws Exception { - return ImmutableList.of(database -> { - LOGGER.info("Attempting to get metadata from the database to see if we can connect."); - database.bufferedResultSetQuery(connection -> connection.getMetaData().getCatalogs(), sourceOperations::rowToJson); - }); - } - - /** - * Aggregate list of @param entries of StreamName and PrimaryKey and - * - * @return a map by StreamName to associated list of primary keys - */ - @VisibleForTesting - public static Map> aggregatePrimateKeys(final List entries) { - final Map> result = new HashMap<>(); - entries.stream().sorted(Comparator.comparingInt(PrimaryKeyAttributesFromDb::keySequence)).forEach(entry -> { - if (!result.containsKey(entry.streamName())) { - result.put(entry.streamName(), new ArrayList<>()); - } - result.get(entry.streamName()).add(entry.primaryKey()); - }); - return result; - } - - private String getCatalog(final SqlDatabase database) { - JsonNode sourceConfig = database.sourceConfig; - if (sourceConfig != null) { - return (sourceConfig.has(JdbcUtils.DATABASE_KEY) ? sourceConfig.get(JdbcUtils.DATABASE_KEY).asText() : null); - } - throw new NullPointerException(); - } - - @Override - protected List>> discoverInternal(final JdbcDatabase database, final String schema) throws Exception { - final Set internalSchemas = new HashSet<>(getExcludedInternalNameSpaces()); - LOGGER.info("Internal schemas to exclude: {}", internalSchemas); - final Set tablesWithSelectGrantPrivilege = getPrivilegesTableForCurrentUser(database, schema); - return database.bufferedResultSetQuery( - // retrieve column metadata from the database - connection -> connection.getMetaData().getColumns(getCatalog(database), schema, null, null), - // store essential column metadata to a Json object from the result set about each column - this::getColumnMetadata) - .stream() - .filter(excludeNotAccessibleTables(internalSchemas, tablesWithSelectGrantPrivilege)) - // group by schema and table name to handle the case where a table with the same name exists in - // multiple schemas. 
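
To make the grouping below concrete (hypothetical rows): column metadata discovered as (schema1, users, id), (schema1, users, name) and (schema2, users, id) collapses into two groups, keyed by ImmutablePair.of("schema1", "users") and ImmutablePair.of("schema2", "users"), which become two TableInfo entries.
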
- .collect(Collectors.groupingBy(t -> ImmutablePair.of(t.get(INTERNAL_SCHEMA_NAME).asText(), t.get(INTERNAL_TABLE_NAME).asText()))) - .values() - .stream() - .map(fields -> TableInfo.>builder() - .nameSpace(fields.get(0).get(INTERNAL_SCHEMA_NAME).asText()) - .name(fields.get(0).get(INTERNAL_TABLE_NAME).asText()) - .fields(fields.stream() - // read the column metadata Json object, and determine its type - .map(f -> { - final Datatype datatype = sourceOperations.getDatabaseFieldType(f); - final JsonSchemaType jsonType = getAirbyteType(datatype); - LOGGER.debug("Table {} column {} (type {}[{}], nullable {}) -> {}", - fields.get(0).get(INTERNAL_TABLE_NAME).asText(), - f.get(INTERNAL_COLUMN_NAME).asText(), - f.get(INTERNAL_COLUMN_TYPE_NAME).asText(), - f.get(INTERNAL_COLUMN_SIZE).asInt(), - f.get(INTERNAL_IS_NULLABLE).asBoolean(), - jsonType); - return new CommonField(f.get(INTERNAL_COLUMN_NAME).asText(), datatype) {}; - }) - .collect(Collectors.toList())) - .cursorFields(extractCursorFields(fields)) - .build()) - .collect(Collectors.toList()); - } - - private List extractCursorFields(final List fields) { - return fields.stream() - .filter(field -> isCursorType(sourceOperations.getDatabaseFieldType(field))) - .map(field -> field.get(INTERNAL_COLUMN_NAME).asText()) - .collect(Collectors.toList()); - } - - protected Predicate excludeNotAccessibleTables(final Set internalSchemas, - final Set tablesWithSelectGrantPrivilege) { - return jsonNode -> { - if (tablesWithSelectGrantPrivilege.isEmpty()) { - return isNotInternalSchema(jsonNode, internalSchemas); - } - return tablesWithSelectGrantPrivilege.stream() - .anyMatch(e -> e.getSchemaName().equals(jsonNode.get(INTERNAL_SCHEMA_NAME).asText())) - && tablesWithSelectGrantPrivilege.stream() - .anyMatch(e -> e.getTableName().equals(jsonNode.get(INTERNAL_TABLE_NAME).asText())) - && !internalSchemas.contains(jsonNode.get(INTERNAL_SCHEMA_NAME).asText()); - }; - } - - // needs to override isNotInternalSchema for connectors that override - // getPrivilegesTableForCurrentUser() - protected boolean isNotInternalSchema(final JsonNode jsonNode, final Set internalSchemas) { - return !internalSchemas.contains(jsonNode.get(INTERNAL_SCHEMA_NAME).asText()); - } - - /** - * @param resultSet Description of a column available in the table catalog. - * @return Essential information about a column to determine which table it belongs to and its type. - */ - private JsonNode getColumnMetadata(final ResultSet resultSet) throws SQLException { - final var fieldMap = ImmutableMap.builder() - // we always want a namespace, if we cannot get a schema, use db name. - .put(INTERNAL_SCHEMA_NAME, - resultSet.getObject(JDBC_COLUMN_SCHEMA_NAME) != null ? 
resultSet.getString(JDBC_COLUMN_SCHEMA_NAME) - : resultSet.getObject(JDBC_COLUMN_DATABASE_NAME)) - .put(INTERNAL_TABLE_NAME, resultSet.getString(JDBC_COLUMN_TABLE_NAME)) - .put(INTERNAL_COLUMN_NAME, resultSet.getString(JDBC_COLUMN_COLUMN_NAME)) - .put(INTERNAL_COLUMN_TYPE, resultSet.getString(JDBC_COLUMN_DATA_TYPE)) - .put(INTERNAL_COLUMN_TYPE_NAME, resultSet.getString(JDBC_COLUMN_TYPE_NAME)) - .put(INTERNAL_COLUMN_SIZE, resultSet.getInt(JDBC_COLUMN_SIZE)) - .put(INTERNAL_IS_NULLABLE, resultSet.getString(JDBC_IS_NULLABLE)); - if (resultSet.getString(JDBC_DECIMAL_DIGITS) != null) { - fieldMap.put(INTERNAL_DECIMAL_DIGITS, resultSet.getString(JDBC_DECIMAL_DIGITS)); - } - return Jsons.jsonNode(fieldMap.build()); - } - - @Override - public List>> discoverInternal(final JdbcDatabase database) - throws Exception { - return discoverInternal(database, null); - } - - @Override - public JsonSchemaType getAirbyteType(final Datatype columnType) { - return sourceOperations.getAirbyteType(columnType); - } - - @VisibleForTesting - public record PrimaryKeyAttributesFromDb(String streamName, - String primaryKey, - int keySequence) { - - } - - @Override - protected Map> discoverPrimaryKeys(final JdbcDatabase database, - final List>> tableInfos) { - LOGGER.info("Discover primary keys for tables: " + tableInfos.stream().map(TableInfo::getName).collect( - Collectors.toSet())); - try { - // Get all primary keys without specifying a table name - final Map> tablePrimaryKeys = aggregatePrimateKeys(database.bufferedResultSetQuery( - connection -> connection.getMetaData().getPrimaryKeys(getCatalog(database), null, null), - r -> { - final String schemaName = - r.getObject(JDBC_COLUMN_SCHEMA_NAME) != null ? r.getString(JDBC_COLUMN_SCHEMA_NAME) : r.getString(JDBC_COLUMN_DATABASE_NAME); - final String streamName = JdbcUtils.getFullyQualifiedTableName(schemaName, r.getString(JDBC_COLUMN_TABLE_NAME)); - final String primaryKey = r.getString(JDBC_COLUMN_COLUMN_NAME); - final int keySeq = r.getInt(KEY_SEQ); - return new PrimaryKeyAttributesFromDb(streamName, primaryKey, keySeq); - })); - if (!tablePrimaryKeys.isEmpty()) { - return tablePrimaryKeys; - } - } catch (final SQLException e) { - LOGGER.debug(String.format("Could not retrieve primary keys without a table name (%s), retrying", e)); - } - // Get primary keys one table at a time - return tableInfos.stream() - .collect(Collectors.toMap( - tableInfo -> JdbcUtils.getFullyQualifiedTableName(tableInfo.getNameSpace(), tableInfo.getName()), - tableInfo -> { - final String streamName = JdbcUtils.getFullyQualifiedTableName(tableInfo.getNameSpace(), tableInfo.getName()); - try { - final Map> primaryKeys = aggregatePrimateKeys(database.bufferedResultSetQuery( - connection -> connection.getMetaData().getPrimaryKeys(getCatalog(database), tableInfo.getNameSpace(), tableInfo.getName()), - r -> new PrimaryKeyAttributesFromDb(streamName, r.getString(JDBC_COLUMN_COLUMN_NAME), r.getInt(KEY_SEQ)))); - return primaryKeys.getOrDefault(streamName, Collections.emptyList()); - } catch (final SQLException e) { - LOGGER.error(String.format("Could not retrieve primary keys for %s: %s", streamName, e)); - return Collections.emptyList(); - } - })); - } - - @Override - protected String getQuoteString() { - return quoteString; - } - - @Override - public boolean isCursorType(final Datatype type) { - return sourceOperations.isCursorType(type); - } - - @Override - public AutoCloseableIterator queryTableIncremental(final JdbcDatabase database, - final List columnNames, - final String schemaName, - final 
String tableName, - final CursorInfo cursorInfo, - final Datatype cursorFieldType) { - LOGGER.info("Queueing query for table: {}", tableName); - final io.airbyte.protocol.models.AirbyteStreamNameNamespacePair airbyteStream = - AirbyteStreamUtils.convertFromNameAndNamespace(tableName, schemaName); - return AutoCloseableIterators.lazyIterator(() -> { - try { - final Stream stream = database.unsafeQuery( - connection -> { - LOGGER.info("Preparing query for table: {}", tableName); - final String fullTableName = getFullyQualifiedTableNameWithQuoting(schemaName, tableName, getQuoteString()); - final String quotedCursorField = enquoteIdentifier(cursorInfo.getCursorField(), getQuoteString()); - - final String operator; - if (cursorInfo.getCursorRecordCount() <= 0L) { - operator = ">"; - } else { - final long actualRecordCount = getActualCursorRecordCount( - connection, fullTableName, quotedCursorField, cursorFieldType, cursorInfo.getCursor()); - LOGGER.info("Table {} cursor count: expected {}, actual {}", tableName, cursorInfo.getCursorRecordCount(), actualRecordCount); - if (actualRecordCount == cursorInfo.getCursorRecordCount()) { - operator = ">"; - } else { - operator = ">="; - } - } - - final String wrappedColumnNames = getWrappedColumnNames(database, connection, columnNames, schemaName, tableName); - final StringBuilder sql = new StringBuilder(String.format("SELECT %s FROM %s WHERE %s %s ?", - wrappedColumnNames, - fullTableName, - quotedCursorField, - operator)); - // if the connector emits intermediate states, the incremental query must be sorted by the cursor - // field - if (getStateEmissionFrequency() > 0) { - sql.append(String.format(" ORDER BY %s ASC", quotedCursorField)); - } - - final PreparedStatement preparedStatement = connection.prepareStatement(sql.toString()); - LOGGER.info("Executing query for table {}: {}", tableName, preparedStatement); - sourceOperations.setCursorField(preparedStatement, 1, cursorFieldType, cursorInfo.getCursor()); - return preparedStatement; - }, - sourceOperations::convertDatabaseRowToAirbyteRecordData); - return AutoCloseableIterators.fromStream(stream, airbyteStream); - } catch (final SQLException e) { - throw new RuntimeException(e); - } - }, airbyteStream); - } - - /** - * Some databases need special column names in the query. 
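 * (Editorial example, not part of the original diff: a source whose driver cannot stream certain column types directly could override this hook roughly as follows, reusing the enquoteIdentifier helper seen above; the CAST target is hypothetical.)
 *
 *   return columnNames.stream()
 *       .map(col -> "CAST(" + enquoteIdentifier(col, getQuoteString()) + " AS VARCHAR) AS " + enquoteIdentifier(col, getQuoteString()))
 *       .collect(java.util.stream.Collectors.joining(", "));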
- */ - protected String getWrappedColumnNames(final JdbcDatabase database, - final Connection connection, - final List columnNames, - final String schemaName, - final String tableName) - throws SQLException { - return enquoteIdentifierList(columnNames, getQuoteString()); - } - - protected String getCountColumnName() { - return "record_count"; - } - - protected long getActualCursorRecordCount(final Connection connection, - final String fullTableName, - final String quotedCursorField, - final Datatype cursorFieldType, - final String cursor) - throws SQLException { - final String columnName = getCountColumnName(); - final PreparedStatement cursorRecordStatement; - if (cursor == null) { - final String cursorRecordQuery = String.format("SELECT COUNT(*) AS %s FROM %s WHERE %s IS NULL", - columnName, - fullTableName, - quotedCursorField); - cursorRecordStatement = connection.prepareStatement(cursorRecordQuery); - } else { - final String cursorRecordQuery = String.format("SELECT COUNT(*) AS %s FROM %s WHERE %s = ?", - columnName, - fullTableName, - quotedCursorField); - cursorRecordStatement = connection.prepareStatement(cursorRecordQuery); - sourceOperations.setCursorField(cursorRecordStatement, 1, cursorFieldType, cursor); - } - final ResultSet resultSet = cursorRecordStatement.executeQuery(); - if (resultSet.next()) { - return resultSet.getLong(columnName); - } else { - return 0L; - } - } - - @Override - public JdbcDatabase createDatabase(final JsonNode sourceConfig) throws SQLException { - return createDatabase(sourceConfig, JdbcDataSourceUtils.DEFAULT_JDBC_PARAMETERS_DELIMITER); - } - - public JdbcDatabase createDatabase(final JsonNode sourceConfig, String delimiter) throws SQLException { - final JsonNode jdbcConfig = toDatabaseConfig(sourceConfig); - Map connectionProperties = JdbcDataSourceUtils.getConnectionProperties(sourceConfig, delimiter); - // Create the data source - final DataSource dataSource = DataSourceFactory.create( - jdbcConfig.has(JdbcUtils.USERNAME_KEY) ? jdbcConfig.get(JdbcUtils.USERNAME_KEY).asText() : null, - jdbcConfig.has(JdbcUtils.PASSWORD_KEY) ? jdbcConfig.get(JdbcUtils.PASSWORD_KEY).asText() : null, - driverClassName, - jdbcConfig.get(JdbcUtils.JDBC_URL_KEY).asText(), - connectionProperties, - getConnectionTimeout(connectionProperties)); - // Record the data source so that it can be closed. - dataSources.add(dataSource); - - final JdbcDatabase database = new StreamingJdbcDatabase( - dataSource, - sourceOperations, - streamingQueryConfigProvider); - - quoteString = (quoteString == null ? database.getMetaData().getIdentifierQuoteString() : quoteString); - database.sourceConfig = sourceConfig; - database.databaseConfig = jdbcConfig; - return database; - } - - /** - * {@inheritDoc} - * - * @param database database instance - * @param catalog schema of the incoming messages. 
- * @throws SQLException - */ - @Override - protected void logPreSyncDebugData(final JdbcDatabase database, final ConfiguredAirbyteCatalog catalog) - throws SQLException { - LOGGER.info("Data source product recognized as {}:{}", - database.getMetaData().getDatabaseProductName(), - database.getMetaData().getDatabaseProductVersion()); - } - - @Override - public void close() { - dataSources.forEach(d -> { - try { - DataSourceFactory.close(d); - } catch (final Exception e) { - LOGGER.warn("Unable to close data source.", e); - } - }); - dataSources.clear(); - } - - protected List identifyStreamsToSnapshot(final ConfiguredAirbyteCatalog catalog, final StateManager stateManager) { - final Set alreadySyncedStreams = stateManager.getCdcStateManager().getInitialStreamsSynced(); - if (alreadySyncedStreams.isEmpty() && (stateManager.getCdcStateManager().getCdcState() == null - || stateManager.getCdcStateManager().getCdcState().getState() == null)) { - return Collections.emptyList(); - } - - final Set allStreams = AirbyteStreamNameNamespacePair.fromConfiguredCatalog(catalog); - - final Set newlyAddedStreams = new HashSet<>(Sets.difference(allStreams, alreadySyncedStreams)); - - return catalog.getStreams().stream() - .filter(c -> c.getSyncMode() == SyncMode.INCREMENTAL) - .filter(stream -> newlyAddedStreams.contains(AirbyteStreamNameNamespacePair.fromAirbyteStream(stream.getStream()))) - .map(Jsons::clone) - .collect(Collectors.toList()); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/jdbc/JdbcDataSourceUtils.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/jdbc/JdbcDataSourceUtils.java deleted file mode 100644 index f11193178ec4..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/jdbc/JdbcDataSourceUtils.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
- */ - -package io.airbyte.cdk.integrations.source.jdbc; - -import com.fasterxml.jackson.databind.JsonNode; -import io.airbyte.cdk.db.jdbc.JdbcUtils; -import io.airbyte.commons.map.MoreMaps; -import java.util.Map; -import java.util.Objects; - -public class JdbcDataSourceUtils { - - public static final String DEFAULT_JDBC_PARAMETERS_DELIMITER = "&"; - - /** - * Validates for duplicated parameters - * - * @param customParameters custom connection properties map as specified by each Jdbc source - * @param defaultParameters connection properties map as specified by each Jdbc source - * @throws IllegalArgumentException - */ - public static void assertCustomParametersDontOverwriteDefaultParameters(final Map customParameters, - final Map defaultParameters) { - for (final String key : defaultParameters.keySet()) { - if (customParameters.containsKey(key) && !Objects.equals(customParameters.get(key), defaultParameters.get(key))) { - throw new IllegalArgumentException("Cannot overwrite default JDBC parameter " + key); - } - } - } - - /** - * Retrieves connection_properties from config and also validates if custom jdbc_url parameters - * overlap with the default properties - * - * @param config A configuration used to check Jdbc connection - * @return A mapping of connection properties - */ - public static Map getConnectionProperties(final JsonNode config) { - return getConnectionProperties(config, DEFAULT_JDBC_PARAMETERS_DELIMITER); - } - - public static Map getConnectionProperties(final JsonNode config, String parameterDelimiter) { - final Map customProperties = JdbcUtils.parseJdbcParameters(config, JdbcUtils.JDBC_URL_PARAMS_KEY, parameterDelimiter); - final Map defaultProperties = JdbcDataSourceUtils.getDefaultConnectionProperties(config); - assertCustomParametersDontOverwriteDefaultParameters(customProperties, defaultProperties); - return MoreMaps.merge(customProperties, defaultProperties); - } - - /** - * Retrieves default connection_properties from config - * - * TODO: make this method abstract and add parity features to destination connectors - * - * @param config A configuration used to check Jdbc connection - * @return A mapping of the default connection properties - */ - public static Map getDefaultConnectionProperties(final JsonNode config) { - // NOTE that Postgres returns an empty map for some reason? - return JdbcUtils.parseJdbcParameters(config, "connection_properties", DEFAULT_JDBC_PARAMETERS_DELIMITER); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/jdbc/JdbcSSLConnectionUtils.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/jdbc/JdbcSSLConnectionUtils.java deleted file mode 100644 index 83106c17d6ce..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/jdbc/JdbcSSLConnectionUtils.java +++ /dev/null @@ -1,206 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
- */ - -package io.airbyte.cdk.integrations.source.jdbc; - -import com.fasterxml.jackson.databind.JsonNode; -import io.airbyte.cdk.db.jdbc.JdbcUtils; -import io.airbyte.cdk.db.util.SSLCertificateUtils; -import java.io.IOException; -import java.net.MalformedURLException; -import java.net.URI; -import java.nio.file.Files; -import java.nio.file.Path; -import java.security.KeyStoreException; -import java.security.NoSuchAlgorithmException; -import java.security.cert.CertificateException; -import java.security.spec.InvalidKeySpecException; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; -import org.apache.commons.lang3.RandomStringUtils; -import org.apache.commons.lang3.tuple.ImmutablePair; -import org.apache.commons.lang3.tuple.Pair; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class JdbcSSLConnectionUtils { - - public static final String SSL_MODE = "sslMode"; - - public static final String TRUST_KEY_STORE_URL = "trustCertificateKeyStoreUrl"; - public static final String TRUST_KEY_STORE_PASS = "trustCertificateKeyStorePassword"; - public static final String CLIENT_KEY_STORE_URL = "clientCertificateKeyStoreUrl"; - public static final String CLIENT_KEY_STORE_PASS = "clientCertificateKeyStorePassword"; - public static final String CLIENT_KEY_STORE_TYPE = "clientCertificateKeyStoreType"; - public static final String TRUST_KEY_STORE_TYPE = "trustCertificateKeyStoreType"; - public static final String KEY_STORE_TYPE_PKCS12 = "PKCS12"; - public static final String PARAM_MODE = "mode"; - Pair caCertKeyStorePair; - Pair clientCertKeyStorePair; - - public enum SslMode { - - DISABLED("disable"), - ALLOWED("allow"), - PREFERRED("preferred", "prefer"), - REQUIRED("required", "require"), - VERIFY_CA("verify_ca", "verify-ca"), - VERIFY_IDENTITY("verify_identity", "verify-full"); - - public final List spec; - - SslMode(final String... spec) { - this.spec = Arrays.asList(spec); - } - - public static Optional bySpec(final String spec) { - return Arrays.stream(SslMode.values()) - .filter(sslMode -> sslMode.spec.contains(spec)) - .findFirst(); - } - - } - - private static final Logger LOGGER = LoggerFactory.getLogger(JdbcSSLConnectionUtils.class); - public static final String PARAM_CA_CERTIFICATE = "ca_certificate"; - public static final String PARAM_CLIENT_CERTIFICATE = "client_certificate"; - public static final String PARAM_CLIENT_KEY = "client_key"; - public static final String PARAM_CLIENT_KEY_PASSWORD = "client_key_password"; - - /** - * Parses SSL related configuration and generates keystores to be used by connector - * - * @param config configuration - * @return map containing relevant parsed values including location of keystore or an empty map - */ - public static Map parseSSLConfig(final JsonNode config) { - LOGGER.debug("source config: {}", config); - - Pair caCertKeyStorePair = null; - Pair clientCertKeyStorePair = null; - final Map additionalParameters = new HashMap<>(); - // assume ssl if not explicitly mentioned. 
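// (Editorial illustration, not part of the original diff: how SslMode.bySpec normalizes the user-facing spec strings consumed below.)
//
//   SslMode.bySpec("prefer")      // -> Optional.of(PREFERRED)
//   SslMode.bySpec("verify-ca")   // -> Optional.of(VERIFY_CA)
//   SslMode.bySpec("verify-full") // -> Optional.of(VERIFY_IDENTITY)
//   SslMode.bySpec("bogus")       // -> Optional.empty(); the orElseThrow below turns this into an IllegalArgumentException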
- if (!config.has(JdbcUtils.SSL_KEY) || config.get(JdbcUtils.SSL_KEY).asBoolean()) { - if (config.has(JdbcUtils.SSL_MODE_KEY)) { - final String specMode = config.get(JdbcUtils.SSL_MODE_KEY).get(PARAM_MODE).asText(); - additionalParameters.put(SSL_MODE, - SslMode.bySpec(specMode).orElseThrow(() -> new IllegalArgumentException("unexpected ssl mode")).name()); - if (Objects.isNull(caCertKeyStorePair)) { - caCertKeyStorePair = JdbcSSLConnectionUtils.prepareCACertificateKeyStore(config); - } - - if (Objects.nonNull(caCertKeyStorePair)) { - LOGGER.debug("uri for ca cert keystore: {}", caCertKeyStorePair.getLeft().toString()); - try { - additionalParameters.putAll(Map.of( - TRUST_KEY_STORE_URL, caCertKeyStorePair.getLeft().toURL().toString(), - TRUST_KEY_STORE_PASS, caCertKeyStorePair.getRight(), - TRUST_KEY_STORE_TYPE, KEY_STORE_TYPE_PKCS12)); - } catch (final MalformedURLException e) { - throw new RuntimeException("Unable to get a URL for trust key store", e); - } - - } - - if (Objects.isNull(clientCertKeyStorePair)) { - clientCertKeyStorePair = JdbcSSLConnectionUtils.prepareClientCertificateKeyStore(config); - } - - if (Objects.nonNull(clientCertKeyStorePair)) { - LOGGER.debug("uri for client cert keystore: {} / {}", clientCertKeyStorePair.getLeft().toString(), clientCertKeyStorePair.getRight()); - try { - additionalParameters.putAll(Map.of( - CLIENT_KEY_STORE_URL, clientCertKeyStorePair.getLeft().toURL().toString(), - CLIENT_KEY_STORE_PASS, clientCertKeyStorePair.getRight(), - CLIENT_KEY_STORE_TYPE, KEY_STORE_TYPE_PKCS12)); - } catch (final MalformedURLException e) { - throw new RuntimeException("Unable to get a URL for client key store", e); - } - } - } else { - additionalParameters.put(SSL_MODE, SslMode.DISABLED.name()); - } - } - LOGGER.debug("additional params: {}", additionalParameters); - return additionalParameters; - } - - public static Pair prepareCACertificateKeyStore(final JsonNode config) { - // if config available - // if has CA cert - make keystore - // if has client cert - // if has client password - make keystore using password - // if no client password - make keystore using random password - Pair caCertKeyStorePair = null; - if (Objects.nonNull(config)) { - if (!config.has(JdbcUtils.SSL_KEY) || config.get(JdbcUtils.SSL_KEY).asBoolean()) { - final var encryption = config.get(JdbcUtils.SSL_MODE_KEY); - if (encryption.has(PARAM_CA_CERTIFICATE) && !encryption.get(PARAM_CA_CERTIFICATE).asText().isEmpty()) { - final String clientKeyPassword = getOrGeneratePassword(encryption); - try { - final URI caCertKeyStoreUri = SSLCertificateUtils.keyStoreFromCertificate( - encryption.get(PARAM_CA_CERTIFICATE).asText(), - clientKeyPassword, - null, - null); - caCertKeyStorePair = new ImmutablePair<>(caCertKeyStoreUri, clientKeyPassword); - } catch (final CertificateException | IOException | KeyStoreException | NoSuchAlgorithmException e) { - throw new RuntimeException("Failed to create keystore for CA certificate", e); - } - } - } - } - return caCertKeyStorePair; - } - - private static String getOrGeneratePassword(final JsonNode sslModeConfig) { - final String clientKeyPassword; - if (sslModeConfig.has(PARAM_CLIENT_KEY_PASSWORD) && !sslModeConfig.get(PARAM_CLIENT_KEY_PASSWORD).asText().isEmpty()) { - clientKeyPassword = sslModeConfig.get(PARAM_CLIENT_KEY_PASSWORD).asText(); - } else { - clientKeyPassword = RandomStringUtils.randomAlphanumeric(10); - } - return clientKeyPassword; - } - - public static Pair prepareClientCertificateKeyStore(final JsonNode config) { - Pair clientCertKeyStorePair = 
null; - if (Objects.nonNull(config)) { - if (!config.has(JdbcUtils.SSL_KEY) || config.get(JdbcUtils.SSL_KEY).asBoolean()) { - final var encryption = config.get(JdbcUtils.SSL_MODE_KEY); - if (encryption.has(PARAM_CLIENT_CERTIFICATE) && !encryption.get(PARAM_CLIENT_CERTIFICATE).asText().isEmpty() - && encryption.has(PARAM_CLIENT_KEY) && !encryption.get(PARAM_CLIENT_KEY).asText().isEmpty()) { - final String clientKeyPassword = getOrGeneratePassword(encryption); - try { - final URI clientCertKeyStoreUri = SSLCertificateUtils.keyStoreFromClientCertificate(encryption.get(PARAM_CLIENT_CERTIFICATE).asText(), - encryption.get(PARAM_CLIENT_KEY).asText(), - clientKeyPassword, null); - clientCertKeyStorePair = new ImmutablePair<>(clientCertKeyStoreUri, clientKeyPassword); - } catch (final CertificateException | IOException - | KeyStoreException | NoSuchAlgorithmException - | InvalidKeySpecException | InterruptedException e) { - throw new RuntimeException("Failed to create keystore for Client certificate", e); - } - } - } - } - return clientCertKeyStorePair; - } - - public static Path fileFromCertPem(final String certPem) { - try { - final Path path = Files.createTempFile(null, ".crt"); - Files.writeString(path, certPem); - path.toFile().deleteOnExit(); - return path; - } catch (final IOException e) { - throw new RuntimeException("Cannot save root certificate to file", e); - } - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/jdbc/JdbcSource.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/jdbc/JdbcSource.java deleted file mode 100644 index 4669f92260f4..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/jdbc/JdbcSource.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.source.jdbc; - -import com.fasterxml.jackson.databind.JsonNode; -import io.airbyte.cdk.db.factory.DatabaseDriver; -import io.airbyte.cdk.db.jdbc.JdbcUtils; -import io.airbyte.cdk.db.jdbc.streaming.AdaptiveStreamingQueryConfig; -import io.airbyte.cdk.integrations.base.IntegrationRunner; -import io.airbyte.cdk.integrations.base.Source; -import java.sql.JDBCType; -import java.util.Set; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class JdbcSource extends AbstractJdbcSource implements Source { - - private static final Logger LOGGER = LoggerFactory.getLogger(JdbcSource.class); - - public JdbcSource() { - super(DatabaseDriver.POSTGRESQL.driverClassName, AdaptiveStreamingQueryConfig::new, JdbcUtils.defaultSourceOperations); - } - - // no-op for JdbcSource since the config it receives is designed to be used for JDBC. 
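// (Editorial sketch, not part of the original diff: because toDatabaseConfig is the identity here, this source expects an
// already-JDBC-shaped config; the values below are hypothetical, with key names matching the JdbcUtils constants read by
// AbstractJdbcSource.createDatabase.)
//
//   { "jdbc_url": "jdbc:postgresql://localhost:5432/mydb", "username": "airbyte", "password": "secret" }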
- @Override - public JsonNode toDatabaseConfig(final JsonNode config) { - return config; - } - - @Override - public Set getExcludedInternalNameSpaces() { - return Set.of("information_schema", "pg_catalog", "pg_internal", "catalog_history"); - } - - public static void main(final String[] args) throws Exception { - final Source source = new JdbcSource(); - LOGGER.info("starting source: {}", JdbcSource.class); - new IntegrationRunner(source).run(args); - LOGGER.info("completed source: {}", JdbcSource.class); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/jdbc/dto/JdbcPrivilegeDto.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/jdbc/dto/JdbcPrivilegeDto.java deleted file mode 100644 index b598f041dde4..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/jdbc/dto/JdbcPrivilegeDto.java +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.source.jdbc.dto; - -import com.google.common.base.Objects; - -/** - * The class to store values from privileges table - */ -public class JdbcPrivilegeDto { - - private String grantee; - private String tableName; - private String schemaName; - private String privilege; - - public JdbcPrivilegeDto(String grantee, String tableName, String schemaName, String privilege) { - this.grantee = grantee; - this.tableName = tableName; - this.schemaName = schemaName; - this.privilege = privilege; - } - - public String getGrantee() { - return grantee; - } - - public String getTableName() { - return tableName; - } - - public String getSchemaName() { - return schemaName; - } - - public String getPrivilege() { - return privilege; - } - - public static JdbcPrivilegeDtoBuilder builder() { - return new JdbcPrivilegeDtoBuilder(); - } - - public static class JdbcPrivilegeDtoBuilder { - - private String grantee; - private String tableName; - private String schemaName; - private String privilege; - - public JdbcPrivilegeDtoBuilder grantee(String grantee) { - this.grantee = grantee; - return this; - } - - public JdbcPrivilegeDtoBuilder tableName(String tableName) { - this.tableName = tableName; - return this; - } - - public JdbcPrivilegeDtoBuilder schemaName(String schemaName) { - this.schemaName = schemaName; - return this; - } - - public JdbcPrivilegeDtoBuilder privilege(String privilege) { - this.privilege = privilege; - return this; - } - - public JdbcPrivilegeDto build() { - return new JdbcPrivilegeDto(grantee, tableName, schemaName, privilege); - } - - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - JdbcPrivilegeDto that = (JdbcPrivilegeDto) o; - return Objects.equal(grantee, that.grantee) && Objects.equal(tableName, that.tableName) - && Objects.equal(schemaName, that.schemaName) && Objects.equal(privilege, that.privilege); - } - - @Override - public int hashCode() { - return Objects.hashCode(grantee, tableName, schemaName, privilege); - } - - @Override - public String toString() { - return "JdbcPrivilegeDto{" + - "grantee='" + grantee + '\'' + - ", columnName='" + tableName + '\'' + - ", schemaName='" + schemaName + '\'' + - ", privilege='" + privilege + '\'' + - '}'; - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/AbstractDbSource.java 
b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/AbstractDbSource.java deleted file mode 100644 index 945e36c8ba57..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/AbstractDbSource.java +++ /dev/null @@ -1,706 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.source.relationaldb; - -import com.fasterxml.jackson.databind.JsonNode; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Preconditions; -import datadog.trace.api.Trace; -import io.airbyte.cdk.db.AbstractDatabase; -import io.airbyte.cdk.db.IncrementalUtils; -import io.airbyte.cdk.db.jdbc.AirbyteRecordData; -import io.airbyte.cdk.db.jdbc.JdbcDatabase; -import io.airbyte.cdk.integrations.JdbcConnector; -import io.airbyte.cdk.integrations.base.AirbyteTraceMessageUtility; -import io.airbyte.cdk.integrations.base.Source; -import io.airbyte.cdk.integrations.base.errors.messages.ErrorMessage; -import io.airbyte.cdk.integrations.source.relationaldb.InvalidCursorInfoUtil.InvalidCursorInfo; -import io.airbyte.cdk.integrations.source.relationaldb.state.CursorStateMessageProducer; -import io.airbyte.cdk.integrations.source.relationaldb.state.SourceStateIterator; -import io.airbyte.cdk.integrations.source.relationaldb.state.StateEmitFrequency; -import io.airbyte.cdk.integrations.source.relationaldb.state.StateGeneratorUtils; -import io.airbyte.cdk.integrations.source.relationaldb.state.StateManager; -import io.airbyte.cdk.integrations.source.relationaldb.state.StateManagerFactory; -import io.airbyte.cdk.integrations.util.ApmTraceUtils; -import io.airbyte.cdk.integrations.util.ConnectorExceptionUtil; -import io.airbyte.commons.exceptions.ConfigErrorException; -import io.airbyte.commons.exceptions.ConnectionErrorException; -import io.airbyte.commons.features.EnvVariableFeatureFlags; -import io.airbyte.commons.features.FeatureFlags; -import io.airbyte.commons.functional.CheckedConsumer; -import io.airbyte.commons.lang.Exceptions; -import io.airbyte.commons.stream.AirbyteStreamUtils; -import io.airbyte.commons.util.AutoCloseableIterator; -import io.airbyte.commons.util.AutoCloseableIterators; -import io.airbyte.protocol.models.CommonField; -import io.airbyte.protocol.models.JsonSchemaPrimitiveUtil.JsonSchemaPrimitive; -import io.airbyte.protocol.models.JsonSchemaType; -import io.airbyte.protocol.models.v0.AirbyteCatalog; -import io.airbyte.protocol.models.v0.AirbyteConnectionStatus; -import io.airbyte.protocol.models.v0.AirbyteConnectionStatus.Status; -import io.airbyte.protocol.models.v0.AirbyteMessage; -import io.airbyte.protocol.models.v0.AirbyteMessage.Type; -import io.airbyte.protocol.models.v0.AirbyteRecordMessage; -import io.airbyte.protocol.models.v0.AirbyteRecordMessageMeta; -import io.airbyte.protocol.models.v0.AirbyteStateMessage.AirbyteStateType; -import io.airbyte.protocol.models.v0.AirbyteStream; -import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair; -import io.airbyte.protocol.models.v0.CatalogHelpers; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream; -import io.airbyte.protocol.models.v0.SyncMode; -import java.sql.SQLException; -import java.time.Duration; -import java.time.Instant; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import java.util.Map; 
-import java.util.Objects; -import java.util.Optional; -import java.util.Set; -import java.util.concurrent.atomic.AtomicLong; -import java.util.function.Function; -import java.util.stream.Collectors; -import java.util.stream.Stream; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * This class contains helper functions and boilerplate for implementing a source connector for a DB - * source of both non-relational and relational type - */ -public abstract class AbstractDbSource extends - JdbcConnector implements Source, AutoCloseable { - - public static final String CHECK_TRACE_OPERATION_NAME = "check-operation"; - public static final String DISCOVER_TRACE_OPERATION_NAME = "discover-operation"; - public static final String READ_TRACE_OPERATION_NAME = "read-operation"; - - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractDbSource.class); - - // TODO: Remove when the flag is not used anymore - protected FeatureFlags featureFlags = new EnvVariableFeatureFlags(); - - protected AbstractDbSource(String driverClassName) { - super(driverClassName); - } - - @VisibleForTesting - public void setFeatureFlags(FeatureFlags featureFlags) { - this.featureFlags = featureFlags; - } - - @Override - @Trace(operationName = CHECK_TRACE_OPERATION_NAME) - public AirbyteConnectionStatus check(final JsonNode config) throws Exception { - try { - final Database database = createDatabase(config); - for (final CheckedConsumer checkOperation : getCheckOperations(config)) { - checkOperation.accept(database); - } - - return new AirbyteConnectionStatus().withStatus(Status.SUCCEEDED); - } catch (final ConnectionErrorException ex) { - ApmTraceUtils.addExceptionToTrace(ex); - final String message = ErrorMessage.getErrorMessage(ex.getStateCode(), ex.getErrorCode(), - ex.getExceptionMessage(), ex); - AirbyteTraceMessageUtility.emitConfigErrorTrace(ex, message); - return new AirbyteConnectionStatus() - .withStatus(Status.FAILED) - .withMessage(message); - } catch (final Exception e) { - ApmTraceUtils.addExceptionToTrace(e); - LOGGER.info("Exception while checking connection: ", e); - return new AirbyteConnectionStatus() - .withStatus(Status.FAILED) - .withMessage(String.format(ConnectorExceptionUtil.COMMON_EXCEPTION_MESSAGE_TEMPLATE, e.getMessage())); - } finally { - close(); - } - } - - @Override - @Trace(operationName = DISCOVER_TRACE_OPERATION_NAME) - public AirbyteCatalog discover(final JsonNode config) throws Exception { - try { - final Database database = createDatabase(config); - final List>> tableInfos = discoverWithoutSystemTables(database); - final Map> fullyQualifiedTableNameToPrimaryKeys = discoverPrimaryKeys( - database, tableInfos); - return DbSourceDiscoverUtil.convertTableInfosToAirbyteCatalog(tableInfos, fullyQualifiedTableNameToPrimaryKeys, this::getAirbyteType); - } finally { - close(); - } - } - - /** - * Creates a list of AirbyteMessageIterators with all the streams selected in a configured catalog - * - * @param config - integration-specific configuration object as json. e.g. { "username": "airbyte", - * "password": "super secure" } - * @param catalog - schema of the incoming messages. - * @param state - state of the incoming messages. 
- * @return AirbyteMessageIterator with all the streams that are to be synced - * @throws Exception - */ - @Override - public AutoCloseableIterator read(final JsonNode config, - final ConfiguredAirbyteCatalog catalog, - final JsonNode state) - throws Exception { - final AirbyteStateType supportedStateType = getSupportedStateType(config); - final StateManager stateManager = - StateManagerFactory.createStateManager(supportedStateType, - StateGeneratorUtils.deserializeInitialState(state, supportedStateType), catalog); - final Instant emittedAt = Instant.now(); - - final Database database = createDatabase(config); - - logPreSyncDebugData(database, catalog); - - final Map>> fullyQualifiedTableNameToInfo = - discoverWithoutSystemTables(database) - .stream() - .collect(Collectors.toMap(t -> String.format("%s.%s", t.getNameSpace(), t.getName()), - Function - .identity())); - - validateCursorFieldForIncrementalTables(fullyQualifiedTableNameToInfo, catalog, database); - - DbSourceDiscoverUtil.logSourceSchemaChange(fullyQualifiedTableNameToInfo, catalog, this::getAirbyteType); - - final List> incrementalIterators = - getIncrementalIterators(database, catalog, fullyQualifiedTableNameToInfo, stateManager, - emittedAt); - final List> fullRefreshIterators = - getFullRefreshIterators(database, catalog, fullyQualifiedTableNameToInfo, stateManager, - emittedAt); - final List> iteratorList = Stream - .of(incrementalIterators, fullRefreshIterators) - .flatMap(Collection::stream) - .collect(Collectors.toList()); - - return AutoCloseableIterators - .appendOnClose(AutoCloseableIterators.concatWithEagerClose(iteratorList, AirbyteTraceMessageUtility::emitStreamStatusTrace), () -> { - LOGGER.info("Closing database connection pool."); - Exceptions.toRuntime(this::close); - LOGGER.info("Closed database connection pool."); - }); - } - - protected void validateCursorFieldForIncrementalTables( - final Map>> tableNameToTable, - final ConfiguredAirbyteCatalog catalog, - final Database database) - throws SQLException { - final List tablesWithInvalidCursor = new ArrayList<>(); - for (final ConfiguredAirbyteStream airbyteStream : catalog.getStreams()) { - final AirbyteStream stream = airbyteStream.getStream(); - final String fullyQualifiedTableName = DbSourceDiscoverUtil.getFullyQualifiedTableName(stream.getNamespace(), - stream.getName()); - final boolean hasSourceDefinedCursor = - !Objects.isNull(airbyteStream.getStream().getSourceDefinedCursor()) - && airbyteStream.getStream().getSourceDefinedCursor(); - if (!tableNameToTable.containsKey(fullyQualifiedTableName) - || airbyteStream.getSyncMode() != SyncMode.INCREMENTAL || hasSourceDefinedCursor) { - continue; - } - - final TableInfo> table = tableNameToTable - .get(fullyQualifiedTableName); - final Optional cursorField = IncrementalUtils.getCursorFieldOptional(airbyteStream); - if (cursorField.isEmpty()) { - continue; - } - final DataType cursorType = table.getFields().stream() - .filter(info -> info.getName().equals(cursorField.get())) - .map(CommonField::getType) - .findFirst() - .orElseThrow(); - - if (!isCursorType(cursorType)) { - tablesWithInvalidCursor.add( - new InvalidCursorInfo(fullyQualifiedTableName, cursorField.get(), - cursorType.toString(), "Unsupported cursor type")); - continue; - } - - if (!verifyCursorColumnValues(database, stream.getNamespace(), stream.getName(), cursorField.get())) { - tablesWithInvalidCursor.add( - new InvalidCursorInfo(fullyQualifiedTableName, cursorField.get(), - cursorType.toString(), "Cursor column contains NULL value")); - } - } 
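// (Editorial note, not part of the original diff: a stream configured as INCREMENTAL with cursor_field=["updated_at"]
// lands in tablesWithInvalidCursor above either when "updated_at" maps to a type for which isCursorType() is false, or
// when verifyCursorColumnValues() finds NULLs in the column; the check below then aborts with a single aggregated
// ConfigErrorException rather than one error per table.)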
- - if (!tablesWithInvalidCursor.isEmpty()) { - throw new ConfigErrorException( - InvalidCursorInfoUtil.getInvalidCursorConfigMessage(tablesWithInvalidCursor)); - } - } - - /** - * Verify that cursor column allows syncing to go through. - * - * @param database database - * @return true if syncing can go through. false otherwise - * @throws SQLException exception - */ - protected boolean verifyCursorColumnValues(final Database database, final String schema, final String tableName, final String columnName) - throws SQLException { - /* no-op */ - return true; - } - - /** - * Estimates the total volume (rows and bytes) to sync and emits a - * {@link AirbyteEstimateTraceMessage} associated with the full refresh stream. - * - * @param database database - */ - protected void estimateFullRefreshSyncSize(final Database database, - final ConfiguredAirbyteStream configuredAirbyteStream) { - /* no-op */ - } - - protected List>> discoverWithoutSystemTables(final Database database) - throws Exception { - final Set systemNameSpaces = getExcludedInternalNameSpaces(); - final Set systemViews = getExcludedViews(); - final List>> discoveredTables = discoverInternal(database); - return (systemNameSpaces == null || systemNameSpaces.isEmpty() ? discoveredTables - : discoveredTables.stream() - .filter(table -> !systemNameSpaces.contains(table.getNameSpace()) && !systemViews.contains(table.getName())).collect( - Collectors.toList())); - } - - protected List> getFullRefreshIterators( - final Database database, - final ConfiguredAirbyteCatalog catalog, - final Map>> tableNameToTable, - final StateManager stateManager, - final Instant emittedAt) { - return getSelectedIterators( - database, - catalog, - tableNameToTable, - stateManager, - emittedAt, - SyncMode.FULL_REFRESH); - } - - protected List> getIncrementalIterators( - final Database database, - final ConfiguredAirbyteCatalog catalog, - final Map>> tableNameToTable, - final StateManager stateManager, - final Instant emittedAt) { - return getSelectedIterators( - database, - catalog, - tableNameToTable, - stateManager, - emittedAt, - SyncMode.INCREMENTAL); - } - - /** - * Creates a list of read iterators for each stream within an ConfiguredAirbyteCatalog - * - * @param database Source Database - * @param catalog List of streams (e.g. 
database tables or API endpoints) with settings on sync mode - * @param tableNameToTable Mapping of table name to table - * @param stateManager Manager used to track the state of data synced by the connector - * @param emittedAt Time when data was emitted from the Source database - * @param syncMode the sync mode for which we want to grab the required iterators - * @return List of AirbyteMessageIterators containing all iterators for a catalog - */ - private List> getSelectedIterators( - final Database database, - final ConfiguredAirbyteCatalog catalog, - final Map>> tableNameToTable, - final StateManager stateManager, - final Instant emittedAt, - final SyncMode syncMode) { - final List> iteratorList = new ArrayList<>(); - for (final ConfiguredAirbyteStream airbyteStream : catalog.getStreams()) { - if (airbyteStream.getSyncMode().equals(syncMode)) { - final AirbyteStream stream = airbyteStream.getStream(); - final String fullyQualifiedTableName = DbSourceDiscoverUtil.getFullyQualifiedTableName(stream.getNamespace(), - stream.getName()); - if (!tableNameToTable.containsKey(fullyQualifiedTableName)) { - LOGGER - .info("Skipping stream {} because it is not in the source", fullyQualifiedTableName); - continue; - } - - final TableInfo> table = tableNameToTable - .get(fullyQualifiedTableName); - final AutoCloseableIterator tableReadIterator = createReadIterator( - database, - airbyteStream, - table, - stateManager, - emittedAt); - iteratorList.add(tableReadIterator); - } - } - - return iteratorList; - } - - /** - * ReadIterator is used to retrieve records from a source connector - * - * @param database Source Database - * @param airbyteStream represents an ingestion source (e.g. API endpoint or database table) - * @param table information in tabular format - * @param stateManager Manager used to track the state of data synced by the connector - * @param emittedAt Time when data was emitted from the Source database - * @return an AutoCloseableIterator of AirbyteMessages for the stream - */ - private AutoCloseableIterator createReadIterator(final Database database, - final ConfiguredAirbyteStream airbyteStream, - final TableInfo> table, - final StateManager stateManager, - final Instant emittedAt) { - final String streamName = airbyteStream.getStream().getName(); - final String namespace = airbyteStream.getStream().getNamespace(); - final AirbyteStreamNameNamespacePair pair = new AirbyteStreamNameNamespacePair(streamName, - namespace); - final Set selectedFieldsInCatalog = CatalogHelpers.getTopLevelFieldNames(airbyteStream); - final List selectedDatabaseFields = table.getFields() - .stream() - .map(CommonField::getName) - .filter(selectedFieldsInCatalog::contains) - .collect(Collectors.toList()); - - final AutoCloseableIterator iterator; - // checks for which sync mode we're using based on the configured airbytestream - // this is where the bifurcation between full refresh and incremental happens - if (airbyteStream.getSyncMode() == SyncMode.INCREMENTAL) { - final String cursorField = IncrementalUtils.getCursorField(airbyteStream); - final Optional cursorInfo = stateManager.getCursorInfo(pair); - - final AutoCloseableIterator airbyteMessageIterator; - if (cursorInfo.map(CursorInfo::getCursor).isPresent()) { - airbyteMessageIterator = getIncrementalStream( - database, - airbyteStream, - selectedDatabaseFields, - table, - cursorInfo.get(), - emittedAt); - } else { - // if no cursor is present then this is the first read, which is the same as doing a full refresh read. 
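// (Editorial sketch, not part of the original diff, of the reads each branch produces for a hypothetical cursor
// column "updated_at"; the WHERE clause is built by the queryTableIncremental implementation shown earlier.)
//
//   first sync, no saved cursor:  SELECT ... FROM tbl                        -- full refresh semantics, still emitted as INCREMENTAL
//   later syncs, saved cursor C:  SELECT ... FROM tbl WHERE updated_at > C   -- or >= C when the saved cursor record count no longer matches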
- estimateFullRefreshSyncSize(database, airbyteStream); - airbyteMessageIterator = getFullRefreshStream(database, streamName, namespace, - selectedDatabaseFields, table, emittedAt, SyncMode.INCREMENTAL, Optional.of(cursorField)); - } - - final JsonSchemaPrimitive cursorType = IncrementalUtils.getCursorType(airbyteStream, - cursorField); - - CursorStateMessageProducer messageProducer = new CursorStateMessageProducer( - stateManager, - cursorInfo.map(CursorInfo::getCursor)); - - iterator = AutoCloseableIterators.transform( - autoCloseableIterator -> new SourceStateIterator(autoCloseableIterator, airbyteStream, messageProducer, - new StateEmitFrequency(getStateEmissionFrequency(), - Duration.ZERO)), - airbyteMessageIterator, - AirbyteStreamUtils.convertFromNameAndNamespace(pair.getName(), pair.getNamespace())); - } else if (airbyteStream.getSyncMode() == SyncMode.FULL_REFRESH) { - estimateFullRefreshSyncSize(database, airbyteStream); - iterator = getFullRefreshStream(database, streamName, namespace, selectedDatabaseFields, - table, emittedAt, SyncMode.FULL_REFRESH, Optional.empty()); - } else if (airbyteStream.getSyncMode() == null) { - throw new IllegalArgumentException( - String.format("%s requires a source sync mode", this.getClass())); - } else { - throw new IllegalArgumentException( - String.format("%s does not support sync mode: %s.", this.getClass(), - airbyteStream.getSyncMode())); - } - - final AtomicLong recordCount = new AtomicLong(); - return AutoCloseableIterators.transform(iterator, - AirbyteStreamUtils.convertFromNameAndNamespace(pair.getName(), pair.getNamespace()), - r -> { - final long count = recordCount.incrementAndGet(); - if (count % 10000 == 0) { - LOGGER.info("Reading stream {}. Records read: {}", streamName, count); - } - return r; - }); - } - - /** - * @param database Source Database - * @param airbyteStream represents an ingestion source (e.g. 
API endpoint or database table) - * @param selectedDatabaseFields subset of database fields selected for replication - * @param table information in tabular format - * @param cursorInfo state of where to start the sync from - * @param emittedAt Time when data was emitted from the Source database - * @return AirbyteMessage iterator that reads records newer than the saved cursor - */ - private AutoCloseableIterator getIncrementalStream(final Database database, - final ConfiguredAirbyteStream airbyteStream, - final List selectedDatabaseFields, - final TableInfo> table, - final CursorInfo cursorInfo, - final Instant emittedAt) { - final String streamName = airbyteStream.getStream().getName(); - final String namespace = airbyteStream.getStream().getNamespace(); - final String cursorField = IncrementalUtils.getCursorField(airbyteStream); - final DataType cursorType = table.getFields().stream() - .filter(info -> info.getName().equals(cursorField)) - .map(CommonField::getType) - .findFirst() - .orElseThrow(); - - Preconditions.checkState( - table.getFields().stream().anyMatch(f -> f.getName().equals(cursorField)), - String.format("Could not find cursor field %s in table %s", cursorField, table.getName())); - - final AutoCloseableIterator queryIterator = queryTableIncremental( - database, - selectedDatabaseFields, - table.getNameSpace(), - table.getName(), - cursorInfo, - cursorType); - - return getMessageIterator(queryIterator, streamName, namespace, emittedAt.toEpochMilli()); - } - - /** - * Creates an AirbyteMessageIterator that contains all records for a database source connection - * - * @param database Source Database - * @param streamName name of an individual stream in which a stream represents a source (e.g. API - * endpoint or database table) - * @param namespace Namespace of the database (e.g. public) - * @param selectedDatabaseFields List of all interested database column names - * @param table information in tabular format - * @param emittedAt Time when data was emitted from the Source database - * @param syncMode The sync mode that this full refresh stream should be associated with. - * @return AirbyteMessageIterator with all records for a database source - */ - private AutoCloseableIterator getFullRefreshStream(final Database database, - final String streamName, - final String namespace, - final List selectedDatabaseFields, - final TableInfo> table, - final Instant emittedAt, - final SyncMode syncMode, - final Optional cursorField) { - final AutoCloseableIterator queryStream = - queryTableFullRefresh(database, selectedDatabaseFields, table.getNameSpace(), - table.getName(), syncMode, cursorField); - return getMessageIterator(queryStream, streamName, namespace, emittedAt.toEpochMilli()); - } - - private static AutoCloseableIterator getMessageIterator( - final AutoCloseableIterator recordIterator, - final String streamName, - final String namespace, - final long emittedAt) { - return AutoCloseableIterators.transform(recordIterator, - new io.airbyte.protocol.models.AirbyteStreamNameNamespacePair(streamName, namespace), - airbyteRecordData -> new AirbyteMessage() - .withType(Type.RECORD) - .withRecord(new AirbyteRecordMessage() - .withStream(streamName) - .withNamespace(namespace) - .withEmittedAt(emittedAt) - .withData(airbyteRecordData.rawRowData()) - .withMeta(isMetaChangesEmptyOrNull(airbyteRecordData.meta()) ? 
null : airbyteRecordData.meta()))); - } - - private static boolean isMetaChangesEmptyOrNull(AirbyteRecordMessageMeta meta) { - return meta == null || meta.getChanges() == null || meta.getChanges().isEmpty(); - } - - /** - * @param database - The database from which privileges for tables will be consumed - * @param schema - The schema from which privileges for tables will be consumed - * @return Set with privileges for tables for the current DB-session user. The method is responsible for - * SELECT-ing the table with privileges. In some cases such a SELECT is not required (e.g. in - * Oracle DB - the schema is the user, so you cannot REVOKE a privilege on a table from its - * owner). - */ - protected Set getPrivilegesTableForCurrentUser(final JdbcDatabase database, - final String schema) - throws SQLException { - return Collections.emptySet(); - } - - /** - * Map a database implementation-specific configuration to a json object that adheres to the database - * config spec. See resources/spec.json. - * - * @param config database implementation-specific configuration. - * @return database spec config - */ - @Trace(operationName = DISCOVER_TRACE_OPERATION_NAME) - public abstract JsonNode toDatabaseConfig(JsonNode config); - - /** - * Creates a database instance using the database spec config. - * - * @param config database spec config - * @return database instance - * @throws Exception might throw an error during connection to database - */ - @Trace(operationName = DISCOVER_TRACE_OPERATION_NAME) - protected abstract Database createDatabase(JsonNode config) throws Exception; - - /** - * Gets and logs relevant and useful database metadata such as DB product/version, index names and - * definition. Called before syncing data. Any logged information should be scoped to the configured - * catalog and database. - * - * @param database given database instance. - * @param catalog configured catalog. - */ - protected void logPreSyncDebugData(final Database database, final ConfiguredAirbyteCatalog catalog) throws Exception {} - - /** - * Configures a list of operations that can be used to check the connection to the source. - * - * @return list of consumers that run queries for the check command. - */ - protected abstract List> getCheckOperations(JsonNode config) - throws Exception; - - /** - * Map source types to Airbyte types - * - * @param columnType source data type - * @return airbyte data type - */ - protected abstract JsonSchemaType getAirbyteType(DataType columnType); - - /** - * Get list of system namespaces(schemas) in order to exclude them from the `discover` result list. - * - * @return set of system namespaces(schemas) to be excluded - */ - protected abstract Set getExcludedInternalNameSpaces(); - - /** - * Get list of system views in order to exclude them from the `discover` result list. - * - * @return set of views to be excluded - */ - protected Set getExcludedViews() { - return Collections.emptySet(); - } - - /** - * Discover all available tables in the source database. - * - * @param database source database - * @return list of the source tables - * @throws Exception access to the database might lead to exceptions. - */ - @Trace(operationName = DISCOVER_TRACE_OPERATION_NAME) - protected abstract List>> discoverInternal( - final Database database) - throws Exception; - - /** - * Discovers all available tables within a schema in the source database. 
- * - * @param database - source database - * @param schema - source schema - * @return list of source tables - * @throws Exception - access to the database might lead to exceptions. - */ - protected abstract List>> discoverInternal( - final Database database, - String schema) - throws Exception; - - /** - * Discover Primary keys for each table and @return a map of namespace.table name to their - * associated list of primary key fields. - * - * @param database source database - * @param tableInfos list of tables - * @return map of namespace.table and primary key fields. - */ - protected abstract Map> discoverPrimaryKeys(Database database, - List>> tableInfos); - - /** - * Returns quote symbol of the database - * - * @return quote symbol - */ - protected abstract String getQuoteString(); - - /** - * Read all data from a table. - * - * @param database source database - * @param columnNames interested column names - * @param schemaName table namespace - * @param tableName target table - * @param syncMode The sync mode that this full refresh stream should be associated with. - * @return iterator with read data - */ - protected abstract AutoCloseableIterator queryTableFullRefresh(final Database database, - final List columnNames, - final String schemaName, - final String tableName, - final SyncMode syncMode, - final Optional cursorField); - - /** - * Read incremental data from a table. Incremental read should return only records where cursor - * column value is bigger than cursor. Note that if the connector needs to emit intermediate state - * (i.e. {@link AbstractDbSource#getStateEmissionFrequency} > 0), the incremental query must be - * sorted by the cursor field. - * - * @return iterator with read data - */ - protected abstract AutoCloseableIterator queryTableIncremental(Database database, - List columnNames, - String schemaName, - String tableName, - CursorInfo cursorInfo, - DataType cursorFieldType); - - /** - * When larger than 0, the incremental iterator will emit intermediate state for every N records. - * Please note that if intermediate state emission is enabled, the incremental query must be ordered - * by the cursor field. - * - * TODO: Return an optional value instead of 0 to make it easier to understand. - */ - protected int getStateEmissionFrequency() { - return 0; - } - - /** - * @return true if the given type can be used as a cursor - */ - protected abstract boolean isCursorType(DataType type); - - /** - * Returns the {@link AirbyteStateType} supported by this connector. - * - * @param config The connector configuration. - * @return A {@link AirbyteStateType} representing the state supported by this connector. - */ - protected AirbyteStateType getSupportedStateType(final JsonNode config) { - return AirbyteStateType.STREAM; - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/CdcStateManager.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/CdcStateManager.java deleted file mode 100644 index c4532dcd0270..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/CdcStateManager.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
- */ - -package io.airbyte.cdk.integrations.source.relationaldb; - -import io.airbyte.cdk.integrations.source.relationaldb.models.CdcState; -import io.airbyte.commons.json.Jsons; -import io.airbyte.protocol.models.v0.AirbyteStateMessage; -import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair; -import java.util.Collections; -import java.util.Set; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class CdcStateManager { - - private static final Logger LOGGER = LoggerFactory.getLogger(CdcStateManager.class); - - private final CdcState initialState; - private final Set initialStreamsSynced; - private final AirbyteStateMessage rawStateMessage; - private CdcState currentState; - - public CdcStateManager(final CdcState serialized, - final Set initialStreamsSynced, - final AirbyteStateMessage stateMessage) { - this.initialState = serialized; - this.currentState = serialized; - this.initialStreamsSynced = initialStreamsSynced; - - this.rawStateMessage = stateMessage; - LOGGER.info("Initialized CDC state"); - } - - public void setCdcState(final CdcState state) { - this.currentState = state; - } - - public CdcState getCdcState() { - return currentState != null ? Jsons.clone(currentState) : null; - } - - public AirbyteStateMessage getRawStateMessage() { - return rawStateMessage; - } - - public Set getInitialStreamsSynced() { - return initialStreamsSynced != null ? Collections.unmodifiableSet(initialStreamsSynced) : null; - } - - @Override - public String toString() { - return "CdcStateManager{" + - "initialState=" + initialState + - ", currentState=" + currentState + - '}'; - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/CursorInfo.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/CursorInfo.java deleted file mode 100644 index cf92ed8668d4..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/CursorInfo.java +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
- */ - -package io.airbyte.cdk.integrations.source.relationaldb; - -import java.util.Objects; - -public class CursorInfo { - - private final String originalCursorField; - private final String originalCursor; - private final long originalCursorRecordCount; - - private final String cursorField; - private String cursor; - private long cursorRecordCount; - - public CursorInfo(final String originalCursorField, - final String originalCursor, - final String cursorField, - final String cursor) { - this(originalCursorField, originalCursor, 0L, cursorField, cursor, 0L); - } - - public CursorInfo(final String originalCursorField, - final String originalCursor, - final long originalCursorRecordCount, - final String cursorField, - final String cursor, - final long cursorRecordCount) { - this.originalCursorField = originalCursorField; - this.originalCursor = originalCursor; - this.originalCursorRecordCount = originalCursorRecordCount; - this.cursorField = cursorField; - this.cursor = cursor; - this.cursorRecordCount = cursorRecordCount; - } - - public String getOriginalCursorField() { - return originalCursorField; - } - - public String getOriginalCursor() { - return originalCursor; - } - - public long getOriginalCursorRecordCount() { - return originalCursorRecordCount; - } - - public String getCursorField() { - return cursorField; - } - - public String getCursor() { - return cursor; - } - - public long getCursorRecordCount() { - return cursorRecordCount; - } - - @SuppressWarnings("UnusedReturnValue") - public CursorInfo setCursor(final String cursor) { - this.cursor = cursor; - return this; - } - - public CursorInfo setCursorRecordCount(final long cursorRecordCount) { - this.cursorRecordCount = cursorRecordCount; - return this; - } - - @Override - public boolean equals(final Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - final CursorInfo that = (CursorInfo) o; - return Objects.equals(originalCursorField, that.originalCursorField) - && Objects.equals(originalCursor, that.originalCursor) - && Objects.equals(originalCursorRecordCount, that.originalCursorRecordCount) - && Objects.equals(cursorField, that.cursorField) - && Objects.equals(cursor, that.cursor) - && Objects.equals(cursorRecordCount, that.cursorRecordCount); - } - - @Override - public int hashCode() { - return Objects.hash(originalCursorField, originalCursor, originalCursorRecordCount, cursorField, cursor, cursorRecordCount); - } - - @Override - public String toString() { - return "CursorInfo{" + - "originalCursorField='" + originalCursorField + '\'' + - ", originalCursor='" + originalCursor + '\'' + - ", originalCursorRecordCount='" + originalCursorRecordCount + '\'' + - ", cursorField='" + cursorField + '\'' + - ", cursor='" + cursor + '\'' + - ", cursorRecordCount='" + cursorRecordCount + '\'' + - '}'; - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/DbSourceDiscoverUtil.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/DbSourceDiscoverUtil.java deleted file mode 100644 index 9377190b7595..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/DbSourceDiscoverUtil.java +++ /dev/null @@ -1,171 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
- */ - -package io.airbyte.cdk.integrations.source.relationaldb; - -import static io.airbyte.protocol.models.v0.CatalogHelpers.fieldsToJsonSchema; -import static java.util.stream.Collectors.toList; - -import com.fasterxml.jackson.databind.JsonNode; -import com.google.common.collect.Lists; -import io.airbyte.protocol.models.CommonField; -import io.airbyte.protocol.models.Field; -import io.airbyte.protocol.models.JsonSchemaType; -import io.airbyte.protocol.models.v0.AirbyteCatalog; -import io.airbyte.protocol.models.v0.AirbyteStream; -import io.airbyte.protocol.models.v0.CatalogHelpers; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream; -import io.airbyte.protocol.models.v0.SyncMode; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.function.Function; -import java.util.stream.Collectors; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Contains utilities and helper classes for discovering schemas in database sources. - */ -public class DbSourceDiscoverUtil { - - private static final Logger LOGGER = LoggerFactory.getLogger(DbSourceDiscoverUtil.class); - private static final List AIRBYTE_METADATA = Arrays.asList("_ab_cdc_lsn", - "_ab_cdc_updated_at", - "_ab_cdc_deleted_at"); - - /* - * This method logs schema drift between source table and the catalog. This can happen if (i) - * underlying table schema changed between syncs (ii) The source connector's mapping of datatypes to - * Airbyte types changed between runs - */ - public static void logSourceSchemaChange(final Map>> fullyQualifiedTableNameToInfo, - final ConfiguredAirbyteCatalog catalog, - final Function airbyteTypeConverter) { - for (final ConfiguredAirbyteStream airbyteStream : catalog.getStreams()) { - final AirbyteStream stream = airbyteStream.getStream(); - final String fullyQualifiedTableName = DbSourceDiscoverUtil.getFullyQualifiedTableName(stream.getNamespace(), - stream.getName()); - if (!fullyQualifiedTableNameToInfo.containsKey(fullyQualifiedTableName)) { - continue; - } - final TableInfo> table = fullyQualifiedTableNameToInfo.get(fullyQualifiedTableName); - final List fields = table.getFields() - .stream() - .map(commonField -> toField(commonField, airbyteTypeConverter)) - .distinct() - .collect(toList()); - final JsonNode currentJsonSchema = fieldsToJsonSchema(fields); - final JsonNode catalogSchema = stream.getJsonSchema(); - final JsonNode currentSchemaProperties = currentJsonSchema.get("properties"); - final JsonNode catalogProperties = catalogSchema.get("properties"); - final List mismatchedFields = new ArrayList<>(); - catalogProperties.fieldNames().forEachRemaining(fieldName -> { - // Ignoring metadata fields since those are automatically added onto the catalog schema by Airbyte - // and don't exist in the source schema. They should not be considered a change - if (AIRBYTE_METADATA.contains(fieldName)) { - return; - } - - if (!currentSchemaProperties.has(fieldName) || - !currentSchemaProperties.get(fieldName).equals(catalogProperties.get(fieldName))) { - mismatchedFields.add(fieldName); - } - }); - - if (!mismatchedFields.isEmpty()) { - LOGGER.warn( - "Source schema changed for table {}! Potential mismatches: {}. Actual schema: {}. 
Catalog schema: {}", - fullyQualifiedTableName, - String.join(", ", mismatchedFields.toString()), - currentJsonSchema, - catalogSchema); - } - } - } - - public static AirbyteCatalog convertTableInfosToAirbyteCatalog(final List>> tableInfos, - final Map> fullyQualifiedTableNameToPrimaryKeys, - final Function airbyteTypeConverter) { - final List> tableInfoFieldList = tableInfos.stream() - .map(t -> { - // some databases return multiple copies of the same record for a column (e.g. redshift) because - // they have at least once delivery guarantees. we want to dedupe these, but first we check that the - // records are actually the same and provide a good error message if they are not. - assertColumnsWithSameNameAreSame(t.getNameSpace(), t.getName(), t.getFields()); - final List fields = t.getFields() - .stream() - .map(commonField -> toField(commonField, airbyteTypeConverter)) - .distinct() - .collect(toList()); - final String fullyQualifiedTableName = getFullyQualifiedTableName(t.getNameSpace(), - t.getName()); - final List primaryKeys = fullyQualifiedTableNameToPrimaryKeys.getOrDefault( - fullyQualifiedTableName, Collections - .emptyList()); - return TableInfo.builder().nameSpace(t.getNameSpace()).name(t.getName()) - .fields(fields).primaryKeys(primaryKeys) - .cursorFields(t.getCursorFields()) - .build(); - }) - .collect(toList()); - - final List streams = tableInfoFieldList.stream() - .map(tableInfo -> { - final var primaryKeys = tableInfo.getPrimaryKeys().stream() - .filter(Objects::nonNull) - .map(Collections::singletonList) - .collect(toList()); - - return CatalogHelpers - .createAirbyteStream(tableInfo.getName(), tableInfo.getNameSpace(), - tableInfo.getFields()) - .withSupportedSyncModes( - tableInfo.getCursorFields() != null && tableInfo.getCursorFields().isEmpty() - ? Lists.newArrayList(SyncMode.FULL_REFRESH) - : Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL)) - .withSourceDefinedPrimaryKey(primaryKeys); - }) - .collect(toList()); - return new AirbyteCatalog().withStreams(streams); - } - - public static String getFullyQualifiedTableName(final String nameSpace, final String tableName) { - return nameSpace != null ? nameSpace + "." + tableName : tableName; - } - - private static Field toField(final CommonField commonField, final Function airbyteTypeConverter) { - if (airbyteTypeConverter.apply(commonField.getType()) == JsonSchemaType.OBJECT && commonField.getProperties() != null - && !commonField.getProperties().isEmpty()) { - final var properties = commonField.getProperties().stream().map(commField -> toField(commField, airbyteTypeConverter)).toList(); - return Field.of(commonField.getName(), airbyteTypeConverter.apply(commonField.getType()), properties); - } else { - return Field.of(commonField.getName(), airbyteTypeConverter.apply(commonField.getType())); - } - } - - private static void assertColumnsWithSameNameAreSame(final String nameSpace, - final String tableName, - final List> columns) { - columns.stream() - .collect(Collectors.groupingBy(CommonField::getName)) - .values() - .forEach(columnsWithSameName -> { - final CommonField comparisonColumn = columnsWithSameName.get(0); - columnsWithSameName.forEach(column -> { - if (!column.equals(comparisonColumn)) { - throw new RuntimeException( - String.format( - "Found multiple columns with same name: %s in table: %s.%s but the columns are not the same. 
columns: %s", - comparisonColumn.getName(), nameSpace, tableName, columns)); - } - }); - }); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/InvalidCursorInfoUtil.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/InvalidCursorInfoUtil.java deleted file mode 100644 index 650b2a60a0ac..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/InvalidCursorInfoUtil.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.source.relationaldb; - -import java.util.List; -import java.util.stream.Collectors; - -public class InvalidCursorInfoUtil { - - public static String getInvalidCursorConfigMessage(final List tablesWithInvalidCursor) { - return "The following tables have invalid columns selected as cursor, please select a column with a well-defined ordering with no null values as a cursor. " - + tablesWithInvalidCursor.stream().map(InvalidCursorInfo::toString) - .collect(Collectors.joining(",")); - } - - public record InvalidCursorInfo(String tableName, String cursorColumnName, String cursorSqlType, String cause) { - - @Override - public String toString() { - return "{" + - "tableName='" + tableName + '\'' + - ", cursorColumnName='" + cursorColumnName + '\'' + - ", cursorSqlType=" + cursorSqlType + - ", cause=" + cause + - '}'; - } - - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/RelationalDbQueryUtils.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/RelationalDbQueryUtils.java deleted file mode 100644 index fd66d1a43b35..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/RelationalDbQueryUtils.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.source.relationaldb; - -import com.fasterxml.jackson.databind.JsonNode; -import io.airbyte.cdk.db.SqlDatabase; -import io.airbyte.commons.stream.AirbyteStreamUtils; -import io.airbyte.commons.util.AutoCloseableIterator; -import io.airbyte.commons.util.AutoCloseableIterators; -import io.airbyte.protocol.models.AirbyteStreamNameNamespacePair; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream; -import java.util.List; -import java.util.StringJoiner; -import java.util.stream.Collectors; -import java.util.stream.Stream; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Utility class for methods to query a relational db. 
- */ -public class RelationalDbQueryUtils { - - private static final Logger LOGGER = LoggerFactory.getLogger(RelationalDbQueryUtils.class); - - public record TableSizeInfo(Long tableSize, Long avgRowLength) {} - - public static String getIdentifierWithQuoting(final String identifier, final String quoteString) { - // double-quoted values within a database name or column name should be wrapped with extra - // quoteString - if (identifier.startsWith(quoteString) && identifier.endsWith(quoteString)) { - return quoteString + quoteString + identifier + quoteString + quoteString; - } else { - return quoteString + identifier + quoteString; - } - } - - public static String enquoteIdentifierList(final List identifiers, final String quoteString) { - final StringJoiner joiner = new StringJoiner(","); - for (final String identifier : identifiers) { - joiner.add(getIdentifierWithQuoting(identifier, quoteString)); - } - return joiner.toString(); - } - - /** - * @return fully qualified table name with the schema (if a schema exists) in quotes. - */ - public static String getFullyQualifiedTableNameWithQuoting(final String nameSpace, final String tableName, final String quoteString) { - return (nameSpace == null || nameSpace.isEmpty() ? getIdentifierWithQuoting(tableName, quoteString) - : getIdentifierWithQuoting(nameSpace, quoteString) + "." + getIdentifierWithQuoting(tableName, quoteString)); - } - - /** - * @return fully qualified table name with the schema (if a schema exists) without quotes. - */ - public static String getFullyQualifiedTableName(final String schemaName, final String tableName) { - return schemaName != null ? schemaName + "." + tableName : tableName; - } - - /** - * @return the input identifier with quotes. - */ - public static String enquoteIdentifier(final String identifier, final String quoteString) { - return quoteString + identifier + quoteString; - } - - public static AutoCloseableIterator queryTable(final Database database, - final String sqlQuery, - final String tableName, - final String schemaName) { - final AirbyteStreamNameNamespacePair airbyteStreamNameNamespacePair = AirbyteStreamUtils.convertFromNameAndNamespace(tableName, schemaName); - return AutoCloseableIterators.lazyIterator(() -> { - try { - LOGGER.info("Queueing query: {}", sqlQuery); - final Stream stream = database.unsafeQuery(sqlQuery); - return AutoCloseableIterators.fromStream(stream, airbyteStreamNameNamespacePair); - } catch (final Exception e) { - throw new RuntimeException(e); - } - }, airbyteStreamNameNamespacePair); - } - - public static void logStreamSyncStatus(final List streams, final String syncType) { - if (streams.isEmpty()) { - LOGGER.info("No Streams will be synced via {}.", syncType); - } else { - LOGGER.info("Streams to be synced via {} : {}", syncType, streams.size()); - LOGGER.info("Streams: {}", prettyPrintConfiguredAirbyteStreamList(streams)); - } - } - - public static String prettyPrintConfiguredAirbyteStreamList(final List streamList) { - return streamList.stream().map(s -> "%s.%s".formatted(s.getStream().getNamespace(), s.getStream().getName())).collect(Collectors.joining(", ")); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/RelationalDbReadUtil.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/RelationalDbReadUtil.java deleted file mode 100644 index 9e1b8464e06a..000000000000 --- 
a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/RelationalDbReadUtil.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.source.relationaldb; - -import com.google.common.collect.Sets; -import io.airbyte.commons.json.Jsons; -import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream; -import io.airbyte.protocol.models.v0.SyncMode; -import java.util.HashSet; -import java.util.List; -import java.util.Set; -import java.util.stream.Collectors; - -public class RelationalDbReadUtil { - - public static List identifyStreamsToSnapshot(final ConfiguredAirbyteCatalog catalog, - final Set alreadySyncedStreams) { - final Set allStreams = AirbyteStreamNameNamespacePair.fromConfiguredCatalog(catalog); - final Set newlyAddedStreams = new HashSet<>(Sets.difference(allStreams, alreadySyncedStreams)); - return catalog.getStreams().stream() - .filter(c -> c.getSyncMode() == SyncMode.INCREMENTAL) - .filter(stream -> newlyAddedStreams.contains(AirbyteStreamNameNamespacePair.fromAirbyteStream(stream.getStream()))) - .map(Jsons::clone) - .collect(Collectors.toList()); - } - - public static List identifyStreamsForCursorBased(final ConfiguredAirbyteCatalog catalog, - final List streamsForInitialLoad) { - - final Set initialLoadStreamsNamespacePairs = - streamsForInitialLoad.stream().map(stream -> AirbyteStreamNameNamespacePair.fromAirbyteStream(stream.getStream())) - .collect( - Collectors.toSet()); - return catalog.getStreams().stream() - .filter(c -> c.getSyncMode() == SyncMode.INCREMENTAL) - .filter(stream -> !initialLoadStreamsNamespacePairs.contains(AirbyteStreamNameNamespacePair.fromAirbyteStream(stream.getStream()))) - .map(Jsons::clone) - .collect(Collectors.toList()); - } - - public static AirbyteStreamNameNamespacePair convertNameNamespacePairFromV0(final io.airbyte.protocol.models.AirbyteStreamNameNamespacePair v1NameNamespacePair) { - return new AirbyteStreamNameNamespacePair(v1NameNamespacePair.getName(), v1NameNamespacePair.getNamespace()); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/StateDecoratingIterator.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/StateDecoratingIterator.java deleted file mode 100644 index 91402de54816..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/StateDecoratingIterator.java +++ /dev/null @@ -1,241 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
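To illustrate the partitioning logic in RelationalDbReadUtil above: identifyStreamsToSnapshot takes the set difference between the incremental streams in the catalog and the streams the saved state says were already synced, so a stream added mid-connection gets an initial snapshot. A hedged sketch (generic signatures are inferred from the call sites, since the type parameters were stripped in this rendering; stream names are invented):

import java.util.List;
import java.util.Set;
import io.airbyte.cdk.integrations.source.relationaldb.RelationalDbReadUtil;
import io.airbyte.protocol.models.v0.AirbyteStream;
import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair;
import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog;
import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream;
import io.airbyte.protocol.models.v0.SyncMode;

public class SnapshotPartitionExample {
  public static void main(String[] args) {
    final ConfiguredAirbyteCatalog catalog = new ConfiguredAirbyteCatalog().withStreams(List.of(
        new ConfiguredAirbyteStream()
            .withSyncMode(SyncMode.INCREMENTAL)
            .withStream(new AirbyteStream().withName("users").withNamespace("public")),
        new ConfiguredAirbyteStream()
            .withSyncMode(SyncMode.INCREMENTAL)
            .withStream(new AirbyteStream().withName("orders").withNamespace("public"))));
    // "users" already appears in the saved state; "orders" was added since the last sync.
    final Set<AirbyteStreamNameNamespacePair> alreadySynced =
        Set.of(new AirbyteStreamNameNamespacePair("users", "public"));
    // Result contains only public.orders, which must be snapshotted before tailing incrementally.
    final List<ConfiguredAirbyteStream> toSnapshot =
        RelationalDbReadUtil.identifyStreamsToSnapshot(catalog, alreadySynced);
    System.out.println(toSnapshot);
  }
}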
- */ - -package io.airbyte.cdk.integrations.source.relationaldb; - -import com.google.common.collect.AbstractIterator; -import io.airbyte.cdk.db.IncrementalUtils; -import io.airbyte.cdk.integrations.source.relationaldb.state.StateManager; -import io.airbyte.protocol.models.JsonSchemaPrimitiveUtil.JsonSchemaPrimitive; -import io.airbyte.protocol.models.v0.AirbyteMessage; -import io.airbyte.protocol.models.v0.AirbyteMessage.Type; -import io.airbyte.protocol.models.v0.AirbyteStateMessage; -import io.airbyte.protocol.models.v0.AirbyteStateStats; -import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair; -import java.util.Iterator; -import java.util.Objects; -import java.util.Optional; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -@Deprecated -public class StateDecoratingIterator extends AbstractIterator implements Iterator { - - private static final Logger LOGGER = LoggerFactory.getLogger(StateDecoratingIterator.class); - - private final Iterator messageIterator; - private final StateManager stateManager; - private final AirbyteStreamNameNamespacePair pair; - private final String cursorField; - private final JsonSchemaPrimitive cursorType; - - private final String initialCursor; - private String currentMaxCursor; - private long currentMaxCursorRecordCount = 0L; - private boolean hasEmittedFinalState; - - /** - * These parameters are for intermediate state message emission. We can emit an intermediate state - * when the following two conditions are met. - *
- * 1. The records are sorted by the cursor field. This is true when {@code stateEmissionFrequency} > - * 0. This logic is guaranteed in {@code AbstractJdbcSource#queryTableIncremental}, in which an - * "ORDER BY" clause is appended to the SQL query if {@code stateEmissionFrequency} > 0. - *
- * 2. There is a cursor value that is ready for emission. A cursor value is "ready" if there is no - * more record with the same value. We cannot emit a cursor at will, because there may be multiple - * records with the same cursor value. If we emit a cursor ignoring this condition, should the sync - * fail right after the emission, the next sync may skip some records with the same cursor value due - * to "WHERE cursor_field > cursor" in {@code AbstractJdbcSource#queryTableIncremental}. - *
- * The {@code intermediateStateMessage} is set to the latest state message that is ready for - * emission. For every {@code stateEmissionFrequency} messages, {@code emitIntermediateState} is set - * to true and the latest "ready" state will be emitted in the next {@code computeNext} call. - */ - private final int stateEmissionFrequency; - private int totalRecordCount = 0; - // In between each state message, recordCountInStateMessage will be reset to 0. - private int recordCountInStateMessage = 0; - private boolean emitIntermediateState = false; - private AirbyteMessage intermediateStateMessage = null; - private boolean hasCaughtException = false; - - /** - * @param stateManager Manager that maintains connector state - * @param pair Stream Name and Namespace (e.g. public.users) - * @param cursorField Path to the comparator field used to track the records read so far - * @param initialCursor name of the initial cursor column - * @param cursorType ENUM type of primitive values that can be used as a cursor for checkpointing - * @param stateEmissionFrequency If larger than 0, the records are sorted by the cursor field, and - * intermediate states will be emitted for every {@code stateEmissionFrequency} records. The - * order of the records is guaranteed in {@code AbstractJdbcSource#queryTableIncremental}, in - * which an "ORDER BY" clause is appended to the SQL query if {@code stateEmissionFrequency} - * > 0. - */ - public StateDecoratingIterator(final Iterator messageIterator, - final StateManager stateManager, - final AirbyteStreamNameNamespacePair pair, - final String cursorField, - final String initialCursor, - final JsonSchemaPrimitive cursorType, - final int stateEmissionFrequency) { - this.messageIterator = messageIterator; - this.stateManager = stateManager; - this.pair = pair; - this.cursorField = cursorField; - this.cursorType = cursorType; - this.initialCursor = initialCursor; - this.currentMaxCursor = initialCursor; - this.stateEmissionFrequency = stateEmissionFrequency; - } - - private String getCursorCandidate(final AirbyteMessage message) { - final String cursorCandidate = message.getRecord().getData().get(cursorField).asText(); - return (cursorCandidate != null ? replaceNull(cursorCandidate) : null); - } - - private String replaceNull(final String cursorCandidate) { - if (cursorCandidate.contains("\u0000")) { - return cursorCandidate.replaceAll("\u0000", ""); - } - return cursorCandidate; - } - - /** - * Computes the next record retrieved from Source stream. Emits StateMessage containing data of the - * record that has been read so far - * - *
- * If this method throws an exception, it will propagate outward to the {@code hasNext} or - * {@code next} invocation that invoked this method. Any further attempts to use the iterator will - * result in an {@link IllegalStateException}. - *
- * - * @return {@link AirbyteStateMessage} containing information of the records read so far - */ - @Override - protected AirbyteMessage computeNext() { - if (hasCaughtException) { - // Mark iterator as done since the next call to messageIterator will result in an - // IllegalArgumentException, and reset the exception-caught state. - // This occurs when the previous iteration emitted state, so this iteration cycle will indicate - // iteration is complete - hasCaughtException = false; - return endOfData(); - } - - if (messageIterator.hasNext()) { - Optional<AirbyteMessage> optionalIntermediateMessage = getIntermediateMessage(); - if (optionalIntermediateMessage.isPresent()) { - return optionalIntermediateMessage.get(); - } - - totalRecordCount++; - recordCountInStateMessage++; - // Use try-catch to catch Exception that could occur when connection to the database fails - try { - final AirbyteMessage message = messageIterator.next(); - if (message.getRecord().getData().hasNonNull(cursorField)) { - final String cursorCandidate = getCursorCandidate(message); - final int cursorComparison = IncrementalUtils.compareCursors(currentMaxCursor, cursorCandidate, cursorType); - if (cursorComparison < 0) { - // Update the current max cursor only when current max cursor < cursor candidate from the message - if (stateEmissionFrequency > 0 && !Objects.equals(currentMaxCursor, initialCursor) && messageIterator.hasNext()) { - // Only create an intermediate state when it is not the first or last record message. - // The last state message will be processed separately. - intermediateStateMessage = createStateMessage(false, recordCountInStateMessage); - } - currentMaxCursor = cursorCandidate; - currentMaxCursorRecordCount = 1L; - } else if (cursorComparison == 0) { - currentMaxCursorRecordCount++; - } else if (cursorComparison > 0 && stateEmissionFrequency > 0) { - LOGGER.warn("Intermediate state emission feature requires records to be processed in order according to the cursor value. Otherwise, " - + "data loss can occur."); - } - } - - if (stateEmissionFrequency > 0 && totalRecordCount % stateEmissionFrequency == 0) { - emitIntermediateState = true; - } - - return message; - } catch (final Exception e) { - emitIntermediateState = true; - hasCaughtException = true; - LOGGER.error("Message iterator failed to read next record.", e); - optionalIntermediateMessage = getIntermediateMessage(); - return optionalIntermediateMessage.orElse(endOfData()); - } - } else if (!hasEmittedFinalState) { - return createStateMessage(true, recordCountInStateMessage); - } else { - return endOfData(); - } - } - - /** - * Returns AirbyteStateMessage when in a ready state; a state is "ready" when it satisfies the - * following conditions: - *
- * cursorField has changed (e.g. 08-22-2022 -> 08-23-2022) and there have been at least - * stateEmissionFrequency number of records since the last emission - *
- * - * @return AirbyteStateMessage if one exists, otherwise an empty Optional indicating the state was not ready to - * be emitted - */ - protected final Optional<AirbyteMessage> getIntermediateMessage() { - if (emitIntermediateState && intermediateStateMessage != null) { - final AirbyteMessage message = intermediateStateMessage; - if (message.getState() != null) { - message.getState().setSourceStats(new AirbyteStateStats().withRecordCount((double) recordCountInStateMessage)); - } - - intermediateStateMessage = null; - recordCountInStateMessage = 0; - emitIntermediateState = false; - return Optional.of(message); - } - return Optional.empty(); - } - - /** - * Creates AirbyteStateMessage while updating the cursor used to checkpoint the state of records - * read so far - * - * @param isFinalState marker for if the final state of the iterator has been reached - * @param recordCount count of read messages - * @return AirbyteMessage which includes information on state of records read so far - */ - public AirbyteMessage createStateMessage(final boolean isFinalState, final int recordCount) { - final AirbyteStateMessage stateMessage = stateManager.updateAndEmit(pair, currentMaxCursor, currentMaxCursorRecordCount); - final Optional<CursorInfo> cursorInfo = stateManager.getCursorInfo(pair); - - // logging once every 100 messages to reduce log verbosity - if (recordCount % 100 == 0) { - LOGGER.info("State report for stream {} - original: {} = {} (count {}) -> latest: {} = {} (count {})", - pair, - cursorInfo.map(CursorInfo::getOriginalCursorField).orElse(null), - cursorInfo.map(CursorInfo::getOriginalCursor).orElse(null), - cursorInfo.map(CursorInfo::getOriginalCursorRecordCount).orElse(null), - cursorInfo.map(CursorInfo::getCursorField).orElse(null), - cursorInfo.map(CursorInfo::getCursor).orElse(null), - cursorInfo.map(CursorInfo::getCursorRecordCount).orElse(null)); - } - - if (stateMessage != null) { - stateMessage.withSourceStats(new AirbyteStateStats().withRecordCount((double) recordCount)); - } - if (isFinalState) { - hasEmittedFinalState = true; - if (stateManager.getCursor(pair).isEmpty()) { - LOGGER.warn("Cursor for stream {} was null. This stream will replicate all records on the next run", pair); - } - } - - return new AirbyteMessage().withType(Type.STATE).withState(stateMessage); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/TableInfo.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/TableInfo.java deleted file mode 100644 index 1d990bdfd46b..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/TableInfo.java +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.source.relationaldb; - -import java.util.List; -import lombok.Builder; -import lombok.Getter; - -/** - * This class encapsulates all externally relevant Table information.
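The interplay between emitIntermediateState and the staged intermediateStateMessage in the StateDecoratingIterator above is easiest to see on a concrete trace. A hedged walkthrough, assuming initialCursor = "0" and stateEmissionFrequency = 2 (cursor values invented):

// Record cursor values arriving in order: 1, 1, 2, 2, 3
//
// r1 (cursor=1): cursor advances past the initial cursor "0", so nothing is staged
//                (the first advance is skipped); totalRecordCount = 1.
// r2 (cursor=1): same cursor value, so currentMaxCursorRecordCount = 2; totalRecordCount = 2
//                flips emitIntermediateState on, but no staged message exists yet.
// r3 (cursor=2): the cursor advances, so a state for cursor = 1 is now "ready" and is staged
//                in intermediateStateMessage.
// 4th call:      computeNext returns the staged state (cursor = 1) instead of a record;
//                r4 (cursor=2) is returned on the following call.
// r5 (cursor=3): last record, so messageIterator.hasNext() is false and nothing is staged;
//                the final state (cursor = 3) is emitted once the iterator is exhausted.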
- */ -@Getter -@Builder -public class TableInfo { - - private final String nameSpace; - private final String name; - private final List fields; - private final List primaryKeys; - private final List cursorFields; - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/AbstractStateManager.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/AbstractStateManager.java deleted file mode 100644 index ea4214cf30b1..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/AbstractStateManager.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.source.relationaldb.state; - -import io.airbyte.cdk.integrations.source.relationaldb.CursorInfo; -import io.airbyte.protocol.models.v0.AirbyteStateMessage; -import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog; -import java.util.Collection; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.function.Function; -import java.util.function.Supplier; - -/** - * Abstract implementation of the {@link StateManager} interface that provides common functionality - * for state manager implementations. - * - * @param The type associated with the state object managed by this manager. - * @param The type associated with the state object stored in the state managed by this manager. - */ -public abstract class AbstractStateManager implements StateManager { - - /** - * The {@link CursorManager} responsible for keeping track of the current cursor value for each - * stream managed by this state manager. - */ - private final CursorManager cursorManager; - - /** - * Constructs a new state manager for the given configured connector. - * - * @param catalog The connector's configured catalog. - * @param streamSupplier A {@link Supplier} that provides the cursor manager with the collection of - * streams tracked by the connector's state. - * @param cursorFunction A {@link Function} that extracts the current cursor from a stream stored in - * the connector's state. - * @param cursorFieldFunction A {@link Function} that extracts the cursor field name from a stream - * stored in the connector's state. - * @param cursorRecordCountFunction A {@link Function} that extracts the cursor record count for a - * stream stored in the connector's state. - * @param namespacePairFunction A {@link Function} that generates a - * {@link AirbyteStreamNameNamespacePair} that identifies each stream in the connector's - * state. 
- */ - public AbstractStateManager(final ConfiguredAirbyteCatalog catalog, - final Supplier> streamSupplier, - final Function cursorFunction, - final Function> cursorFieldFunction, - final Function cursorRecordCountFunction, - final Function namespacePairFunction) { - this(catalog, streamSupplier, cursorFunction, cursorFieldFunction, cursorRecordCountFunction, namespacePairFunction, false); - } - - public AbstractStateManager(final ConfiguredAirbyteCatalog catalog, - final Supplier> streamSupplier, - final Function cursorFunction, - final Function> cursorFieldFunction, - final Function cursorRecordCountFunction, - final Function namespacePairFunction, - final boolean onlyIncludeIncrementalStreams) { - cursorManager = new CursorManager(catalog, streamSupplier, cursorFunction, cursorFieldFunction, cursorRecordCountFunction, namespacePairFunction, - onlyIncludeIncrementalStreams); - } - - @Override - public Map getPairToCursorInfoMap() { - return cursorManager.getPairToCursorInfo(); - } - - @Override - public abstract AirbyteStateMessage toState(final Optional pair); - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/CursorManager.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/CursorManager.java deleted file mode 100644 index 2449c7666d55..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/CursorManager.java +++ /dev/null @@ -1,255 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.source.relationaldb.state; - -import com.google.common.annotations.VisibleForTesting; -import io.airbyte.cdk.integrations.source.relationaldb.CursorInfo; -import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream; -import io.airbyte.protocol.models.v0.SyncMode; -import java.util.Collection; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import java.util.function.Function; -import java.util.function.Supplier; -import java.util.stream.Collectors; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Manages the map of streams to current cursor values for state management. - * - * @param The type that represents the stream object which holds the current cursor information - * in the state. - */ -public class CursorManager { - - private static final Logger LOGGER = LoggerFactory.getLogger(CursorManager.class); - - /** - * Map of streams (name/namespace tuple) to the current cursor information stored in the state. - */ - private final Map pairToCursorInfo; - - /** - * Constructs a new {@link CursorManager} based on the configured connector and current state - * information. - * - * @param catalog The connector's configured catalog. - * @param streamSupplier A {@link Supplier} that provides the cursor manager with the collection of - * streams tracked by the connector's state. - * @param cursorFunction A {@link Function} that extracts the current cursor from a stream stored in - * the connector's state. - * @param cursorFieldFunction A {@link Function} that extracts the cursor field name from a stream - * stored in the connector's state. 
- * @param cursorRecordCountFunction A {@link Function} that extracts the cursor record count for a - * stream stored in the connector's state. - * @param namespacePairFunction A {@link Function} that generates a - * {@link AirbyteStreamNameNamespacePair} that identifies each stream in the connector's - * state. - */ - public CursorManager(final ConfiguredAirbyteCatalog catalog, - final Supplier> streamSupplier, - final Function cursorFunction, - final Function> cursorFieldFunction, - final Function cursorRecordCountFunction, - final Function namespacePairFunction, - final boolean onlyIncludeIncrementalStreams) { - pairToCursorInfo = createCursorInfoMap( - catalog, streamSupplier, cursorFunction, cursorFieldFunction, cursorRecordCountFunction, namespacePairFunction, - onlyIncludeIncrementalStreams); - } - - /** - * Creates the cursor information map that associates stream name/namespace tuples with the current - * cursor information for that stream as stored in the connector's state. - * - * @param catalog The connector's configured catalog. - * @param streamSupplier A {@link Supplier} that provides the cursor manager with the collection of - * streams tracked by the connector's state. - * @param cursorFunction A {@link Function} that extracts the current cursor from a stream stored in - * the connector's state. - * @param cursorFieldFunction A {@link Function} that extracts the cursor field name from a stream - * stored in the connector's state. - * @param cursorRecordCountFunction A {@link Function} that extracts the cursor record count for a - * stream stored in the connector's state. - * @param namespacePairFunction A {@link Function} that generates a - * {@link AirbyteStreamNameNamespacePair} that identifies each stream in the connector's - * state. - * @return A map of streams to current cursor information for the stream. 
- */ - @VisibleForTesting - protected Map createCursorInfoMap( - final ConfiguredAirbyteCatalog catalog, - final Supplier> streamSupplier, - final Function cursorFunction, - final Function> cursorFieldFunction, - final Function cursorRecordCountFunction, - final Function namespacePairFunction, - final boolean onlyIncludeIncrementalStreams) { - final Set allStreamNames = catalog.getStreams() - .stream() - .filter(c -> { - if (onlyIncludeIncrementalStreams) { - return c.getSyncMode() == SyncMode.INCREMENTAL; - } - return true; - }) - .map(ConfiguredAirbyteStream::getStream) - .map(AirbyteStreamNameNamespacePair::fromAirbyteStream) - .collect(Collectors.toSet()); - allStreamNames.addAll(streamSupplier.get().stream().map(namespacePairFunction).filter(Objects::nonNull).collect(Collectors.toSet())); - - final Map localMap = new ConcurrentHashMap<>(); - final Map pairToState = streamSupplier.get() - .stream() - .collect(Collectors.toMap(namespacePairFunction, Function.identity())); - final Map pairToConfiguredAirbyteStream = catalog.getStreams().stream() - .collect(Collectors.toMap(AirbyteStreamNameNamespacePair::fromConfiguredAirbyteSteam, Function.identity())); - - for (final AirbyteStreamNameNamespacePair pair : allStreamNames) { - final Optional stateOptional = Optional.ofNullable(pairToState.get(pair)); - final Optional streamOptional = Optional.ofNullable(pairToConfiguredAirbyteStream.get(pair)); - localMap.put(pair, - createCursorInfoForStream(pair, stateOptional, streamOptional, cursorFunction, cursorFieldFunction, cursorRecordCountFunction)); - } - - return localMap; - } - - /** - * Generates a {@link CursorInfo} object based on the data currently stored in the connector's state - * for the given stream. - * - * @param pair A {@link AirbyteStreamNameNamespacePair} that identifies a specific stream managed by - * the connector. - * @param stateOptional {@link Optional} containing the current state associated with the stream. - * @param streamOptional {@link Optional} containing the {@link ConfiguredAirbyteStream} associated - * with the stream. - * @param cursorFunction A {@link Function} that provides the current cursor from the state - * associated with the stream. - * @param cursorFieldFunction A {@link Function} that provides the cursor field name for the cursor - * stored in the state associated with the stream. - * @param cursorRecordCountFunction A {@link Function} that extracts the cursor record count for a - * stream stored in the connector's state. - * @return A {@link CursorInfo} object based on the data currently stored in the connector's state - * for the given stream. - */ - @SuppressWarnings("OptionalUsedAsFieldOrParameterType") - @VisibleForTesting - protected CursorInfo createCursorInfoForStream(final AirbyteStreamNameNamespacePair pair, - final Optional stateOptional, - final Optional streamOptional, - final Function cursorFunction, - final Function> cursorFieldFunction, - final Function cursorRecordCountFunction) { - final String originalCursorField = stateOptional - .map(cursorFieldFunction) - .flatMap(f -> f.size() > 0 ? Optional.of(f.get(0)) : Optional.empty()) - .orElse(null); - final String originalCursor = stateOptional.map(cursorFunction).orElse(null); - final long originalCursorRecordCount = stateOptional.map(cursorRecordCountFunction).orElse(0L); - - final String cursor; - final String cursorField; - final long cursorRecordCount; - - // if cursor field is set in catalog. 
- if (streamOptional.map(ConfiguredAirbyteStream::getCursorField).isPresent()) { - cursorField = streamOptional - .map(ConfiguredAirbyteStream::getCursorField) - .flatMap(f -> f.size() > 0 ? Optional.of(f.get(0)) : Optional.empty()) - .orElse(null); - // if cursor field is set in state. - if (stateOptional.map(cursorFieldFunction).isPresent()) { - // if cursor field in catalog and state are the same. - if (stateOptional.map(cursorFieldFunction).equals(streamOptional.map(ConfiguredAirbyteStream::getCursorField))) { - cursor = stateOptional.map(cursorFunction).orElse(null); - cursorRecordCount = stateOptional.map(cursorRecordCountFunction).orElse(0L); - // If a matching cursor is found in the state, and its value is null - this indicates a CDC stream - // and we shouldn't log anything. - if (cursor != null) { - LOGGER.info("Found matching cursor in state. Stream: {}. Cursor Field: {} Value: {} Count: {}", - pair, cursorField, cursor, cursorRecordCount); - } - // if cursor field in catalog and state are different. - } else { - cursor = null; - cursorRecordCount = 0L; - LOGGER.info( - "Found cursor field. Does not match previous cursor field. Stream: {}. Original Cursor Field: {} (count {}). New Cursor Field: {}. Resetting cursor value.", - pair, originalCursorField, originalCursorRecordCount, cursorField); - } - // if cursor field is not set in state but is set in catalog. - } else { - LOGGER.info("Cursor field set in catalog but not present in state. Stream: {}, New Cursor Field: {}. Resetting cursor value", pair, - cursorField); - cursor = null; - cursorRecordCount = 0L; - } - // if cursor field is not set in catalog. - } else { - LOGGER.info( - "Cursor field set in state but not present in catalog. Stream: {}. Original Cursor Field: {}. Original value: {}. Resetting cursor.", - pair, originalCursorField, originalCursor); - cursorField = null; - cursor = null; - cursorRecordCount = 0L; - } - - return new CursorInfo(originalCursorField, originalCursor, originalCursorRecordCount, cursorField, cursor, cursorRecordCount); - } - - /** - * Retrieves a copy of the stream name/namespace tuple to current cursor information map. - * - * @return A copy of the stream name/namespace tuple to current cursor information map. - */ - public Map<AirbyteStreamNameNamespacePair, CursorInfo> getPairToCursorInfo() { - return Map.copyOf(pairToCursorInfo); - } - - /** - * Retrieves an {@link Optional} possibly containing the current {@link CursorInfo} associated with - * the provided stream name/namespace tuple. - * - * @param pair The {@link AirbyteStreamNameNamespacePair} which identifies a stream. - * @return An {@link Optional} possibly containing the current {@link CursorInfo} associated with - * the provided stream name/namespace tuple. - */ - public Optional<CursorInfo> getCursorInfo(final AirbyteStreamNameNamespacePair pair) { - return Optional.ofNullable(pairToCursorInfo.get(pair)); - } - - /** - * Retrieves an {@link Optional} possibly containing the cursor field name associated with the - * cursor tracked in the state associated with the provided stream name/namespace tuple. - * - * @param pair The {@link AirbyteStreamNameNamespacePair} which identifies a stream. - * @return An {@link Optional} possibly containing the cursor field name associated with the cursor - * tracked in the state associated with the provided stream name/namespace tuple.
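The nested branches in createCursorInfoForStream above collapse to a small decision table; a hedged summary of the resulting CursorInfo for each combination:

//  cursor field in catalog | cursor field in state | resulting cursor / count
//  ------------------------|-----------------------|--------------------------------------------
//  set                     | same field            | carried over from state
//  set                     | different field       | reset to null / 0 (cursor column changed)
//  set                     | absent                | reset to null / 0 (no prior state to reuse)
//  absent                  | set or absent         | cursorField and cursor both null, count 0

In every branch the original* values are preserved, so the state report log can always show where the sync started from.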
- */ - public Optional getCursorField(final AirbyteStreamNameNamespacePair pair) { - return getCursorInfo(pair).map(CursorInfo::getCursorField); - } - - /** - * Retrieves an {@link Optional} possibly containing the cursor value tracked in the state - * associated with the provided stream name/namespace tuple. - * - * @param pair The {@link AirbyteStreamNameNamespacePair} which identifies a stream. - * @return An {@link Optional} possibly containing the cursor value tracked in the state associated - * with the provided stream name/namespace tuple. - */ - public Optional getCursor(final AirbyteStreamNameNamespacePair pair) { - return getCursorInfo(pair).map(CursorInfo::getCursor); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/CursorStateMessageProducer.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/CursorStateMessageProducer.java deleted file mode 100644 index a97c1d33a812..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/CursorStateMessageProducer.java +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.source.relationaldb.state; - -import io.airbyte.cdk.db.IncrementalUtils; -import io.airbyte.cdk.integrations.source.relationaldb.CursorInfo; -import io.airbyte.protocol.models.JsonSchemaPrimitiveUtil.JsonSchemaPrimitive; -import io.airbyte.protocol.models.v0.AirbyteMessage; -import io.airbyte.protocol.models.v0.AirbyteStateMessage; -import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream; -import java.util.Objects; -import java.util.Optional; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class CursorStateMessageProducer implements SourceStateMessageProducer { - - private static final Logger LOGGER = LoggerFactory.getLogger(CursorStateMessageProducer.class); - private static final int LOG_FREQUENCY = 100; - - private final StateManager stateManager; - private final Optional initialCursor; - private Optional currentMaxCursor; - - // We keep this field to mark `cursor_record_count` and also to control logging frequency. - private int currentCursorRecordCount = 0; - private AirbyteStateMessage intermediateStateMessage = null; - - private boolean cursorOutOfOrderDetected = false; - - public CursorStateMessageProducer(final StateManager stateManager, - final Optional initialCursor) { - this.stateManager = stateManager; - this.initialCursor = initialCursor; - this.currentMaxCursor = initialCursor; - } - - @Override - public AirbyteStateMessage generateStateMessageAtCheckpoint(final ConfiguredAirbyteStream stream) { - // At this stage intermediate state message should never be null; otherwise it would have been - // blocked by shouldEmitStateMessage check. - final AirbyteStateMessage message = intermediateStateMessage; - intermediateStateMessage = null; - if (cursorOutOfOrderDetected) { - LOGGER.warn("Intermediate state emission feature requires records to be processed in order according to the cursor value. Otherwise, " - + "data loss can occur."); - } - return message; - } - - /** - * Note: We do not try to catch exception here. 
If an error/exception happens, we should fail the sync, - * and since we have saved the state message before, we should be able to resume in the next sync if we - * have fixed the underlying issue, or if the issue is transient. - */ - @Override - public AirbyteMessage processRecordMessage(final ConfiguredAirbyteStream stream, AirbyteMessage message) { - final String cursorField = IncrementalUtils.getCursorField(stream); - if (message.getRecord().getData().hasNonNull(cursorField)) { - final String cursorCandidate = getCursorCandidate(cursorField, message); - final JsonSchemaPrimitive cursorType = IncrementalUtils.getCursorType(stream, - cursorField); - final int cursorComparison = IncrementalUtils.compareCursors(currentMaxCursor.orElse(null), cursorCandidate, cursorType); - if (cursorComparison < 0) { - // Reset cursor but include current record message. This value will be used to create state message. - // Update the current max cursor only when current max cursor < cursor candidate from the message - if (!Objects.equals(currentMaxCursor, initialCursor)) { - // Only create an intermediate state when it is not the first record. - intermediateStateMessage = createStateMessage(stream); - } - currentMaxCursor = Optional.of(cursorCandidate); - currentCursorRecordCount = 1; - } else if (cursorComparison > 0) { - cursorOutOfOrderDetected = true; - } else { - currentCursorRecordCount++; - } - } - System.out.println("processed a record message. count: " + currentCursorRecordCount); - return message; - - } - - @Override - public AirbyteStateMessage createFinalStateMessage(final ConfiguredAirbyteStream stream) { - return createStateMessage(stream); - } - - /** - * Only sends out a state message when there is one ready to be sent out. - */ - @Override - public boolean shouldEmitStateMessage(final ConfiguredAirbyteStream stream) { - return intermediateStateMessage != null; - } - - /** - * Creates AirbyteStateMessage while updating the cursor used to checkpoint the state of records - * read so far - * - * @return AirbyteStateMessage which includes information on state of records read so far - */ - private AirbyteStateMessage createStateMessage(final ConfiguredAirbyteStream stream) { - final AirbyteStreamNameNamespacePair pair = new AirbyteStreamNameNamespacePair(stream.getStream().getName(), stream.getStream().getNamespace()); - System.out.println("state message creation: " + pair + " " + currentMaxCursor.orElse(null) + " " + currentCursorRecordCount); - final AirbyteStateMessage stateMessage = stateManager.updateAndEmit(pair, currentMaxCursor.orElse(null), currentCursorRecordCount); - final Optional<CursorInfo> cursorInfo = stateManager.getCursorInfo(pair); - - // logging once every 100 messages to reduce log verbosity - if (currentCursorRecordCount % LOG_FREQUENCY == 0) { - LOGGER.info("State report for stream {}: {}", pair, cursorInfo); - } - - return stateMessage; - } - - private String getCursorCandidate(final String cursorField, AirbyteMessage message) { - final String cursorCandidate = message.getRecord().getData().get(cursorField).asText(); - return (cursorCandidate != null ?
replaceNull(cursorCandidate) : null); - } - - private String replaceNull(final String cursorCandidate) { - if (cursorCandidate.contains("\u0000")) { - return cursorCandidate.replaceAll("\u0000", ""); - } - return cursorCandidate; - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/FailedRecordIteratorException.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/FailedRecordIteratorException.java deleted file mode 100644 index e8fedc5d9e9c..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/FailedRecordIteratorException.java +++ /dev/null @@ -1,13 +0,0 @@ -/* - * Copyright (c) 2024 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.source.relationaldb.state; - -public class FailedRecordIteratorException extends RuntimeException { - - public FailedRecordIteratorException(Throwable cause) { - super(cause); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/GlobalStateManager.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/GlobalStateManager.java deleted file mode 100644 index 384bd4d0cb8e..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/GlobalStateManager.java +++ /dev/null @@ -1,162 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.source.relationaldb.state; - -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateGeneratorUtils.CURSOR_FIELD_FUNCTION; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateGeneratorUtils.CURSOR_FUNCTION; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateGeneratorUtils.CURSOR_RECORD_COUNT_FUNCTION; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateGeneratorUtils.NAME_NAMESPACE_PAIR_FUNCTION; - -import io.airbyte.cdk.integrations.source.relationaldb.CdcStateManager; -import io.airbyte.cdk.integrations.source.relationaldb.models.CdcState; -import io.airbyte.cdk.integrations.source.relationaldb.models.DbState; -import io.airbyte.cdk.integrations.source.relationaldb.models.DbStreamState; -import io.airbyte.commons.json.Jsons; -import io.airbyte.protocol.models.v0.AirbyteGlobalState; -import io.airbyte.protocol.models.v0.AirbyteStateMessage; -import io.airbyte.protocol.models.v0.AirbyteStateMessage.AirbyteStateType; -import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair; -import io.airbyte.protocol.models.v0.AirbyteStreamState; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog; -import io.airbyte.protocol.models.v0.StreamDescriptor; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import java.util.Optional; -import java.util.Set; -import java.util.function.Supplier; -import java.util.stream.Collectors; - -/** - * Global implementation of the {@link StateManager} interface. - * - * This implementation generates a single, global state object for the state tracked by this - * manager. - */ -public class GlobalStateManager extends AbstractStateManager { - - /** - * Legacy {@link CdcStateManager} used to manage state for connectors that support Change Data - * Capture (CDC). 
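For orientation, the CursorStateMessageProducer above is the record-by-record half of the newer checkpointing pair: SourceStateIterator (removed further down in this diff) decides when a checkpoint is due and then asks the producer for the message. A hedged wiring sketch (the StateEmitFrequency constructor shape is assumed from the accessors used later; the thresholds are invented):

import java.time.Duration;
import java.util.Iterator;
import java.util.Optional;
import io.airbyte.cdk.integrations.source.relationaldb.state.CursorStateMessageProducer;
import io.airbyte.cdk.integrations.source.relationaldb.state.SourceStateIterator;
import io.airbyte.cdk.integrations.source.relationaldb.state.StateEmitFrequency;
import io.airbyte.cdk.integrations.source.relationaldb.state.StateManager;
import io.airbyte.protocol.models.v0.AirbyteMessage;
import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream;

public class CheckpointWiringExample {
  static Iterator<AirbyteMessage> decorate(final StateManager stateManager,
                                           final Iterator<AirbyteMessage> recordIterator,
                                           final ConfiguredAirbyteStream configuredStream) {
    // Seed the producer with the cursor value saved from the previous sync.
    final CursorStateMessageProducer producer =
        new CursorStateMessageProducer(stateManager, Optional.of("2023-01-01"));
    return new SourceStateIterator<>(
        recordIterator,
        configuredStream,
        producer,
        // Assumed constructor shape: checkpoint every 10,000 records or every 15 minutes.
        new StateEmitFrequency(10_000L, Duration.ofMinutes(15)));
  }
}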
- */ - private final CdcStateManager cdcStateManager; - - /** - * Constructs a new {@link GlobalStateManager} that is seeded with the provided - * {@link AirbyteStateMessage}. - * - * @param airbyteStateMessage The initial state represented as an {@link AirbyteStateMessage}. - * @param catalog The {@link ConfiguredAirbyteCatalog} for the connector associated with this state - * manager. - */ - public GlobalStateManager(final AirbyteStateMessage airbyteStateMessage, final ConfiguredAirbyteCatalog catalog) { - super(catalog, - getStreamsSupplier(airbyteStateMessage), - CURSOR_FUNCTION, - CURSOR_FIELD_FUNCTION, - CURSOR_RECORD_COUNT_FUNCTION, - NAME_NAMESPACE_PAIR_FUNCTION, - true); - - this.cdcStateManager = new CdcStateManager(extractCdcState(airbyteStateMessage), extractStreams(airbyteStateMessage), airbyteStateMessage); - } - - @Override - public CdcStateManager getCdcStateManager() { - return cdcStateManager; - } - - @Override - public List getRawStateMessages() { - throw new UnsupportedOperationException("Raw state retrieval not supported by global state manager."); - } - - @Override - public AirbyteStateMessage toState(final Optional pair) { - // Populate global state - final AirbyteGlobalState globalState = new AirbyteGlobalState(); - globalState.setSharedState(Jsons.jsonNode(getCdcStateManager().getCdcState())); - globalState.setStreamStates(StateGeneratorUtils.generateStreamStateList(getPairToCursorInfoMap())); - - // Generate the legacy state for backwards compatibility - final DbState dbState = StateGeneratorUtils.generateDbState(getPairToCursorInfoMap()) - .withCdc(true) - .withCdcState(getCdcStateManager().getCdcState()); - - return new AirbyteStateMessage() - .withType(AirbyteStateType.GLOBAL) - // Temporarily include legacy state for backwards compatibility with the platform - .withData(Jsons.jsonNode(dbState)) - .withGlobal(globalState); - } - - /** - * Extracts the Change Data Capture (CDC) state stored in the initial state provided to this state - * manager. - * - * @param airbyteStateMessage The {@link AirbyteStateMessage} that contains the initial state - * provided to the state manager. - * @return The {@link CdcState} stored in the state, if any. Note that this will not be {@code null} - * but may be empty. - */ - private CdcState extractCdcState(final AirbyteStateMessage airbyteStateMessage) { - if (airbyteStateMessage.getType() == AirbyteStateType.GLOBAL) { - return Jsons.object(airbyteStateMessage.getGlobal().getSharedState(), CdcState.class); - } else { - final DbState legacyState = Jsons.object(airbyteStateMessage.getData(), DbState.class); - return legacyState != null ? legacyState.getCdcState() : null; - } - } - - private Set extractStreams(final AirbyteStateMessage airbyteStateMessage) { - if (airbyteStateMessage.getType() == AirbyteStateType.GLOBAL) { - return airbyteStateMessage.getGlobal().getStreamStates().stream() - .map(streamState -> { - final AirbyteStreamState cloned = Jsons.clone(streamState); - return new AirbyteStreamNameNamespacePair(cloned.getStreamDescriptor().getName(), cloned.getStreamDescriptor().getNamespace()); - }).collect(Collectors.toSet()); - } else { - final DbState legacyState = Jsons.object(airbyteStateMessage.getData(), DbState.class); - return legacyState != null ? 
extractNamespacePairsFromDbStreamState(legacyState.getStreams()) : Collections.emptySet(); - } - } - - private Set extractNamespacePairsFromDbStreamState(final List streams) { - return streams.stream().map(stream -> { - final DbStreamState cloned = Jsons.clone(stream); - return new AirbyteStreamNameNamespacePair(cloned.getStreamName(), cloned.getStreamNamespace()); - }).collect(Collectors.toSet()); - } - - /** - * Generates the {@link Supplier} that will be used to extract the streams from the incoming - * {@link AirbyteStateMessage}. - * - * @param airbyteStateMessage The {@link AirbyteStateMessage} supplied to this state manager with - * the initial state. - * @return A {@link Supplier} that will be used to fetch the streams present in the initial state. - */ - private static Supplier> getStreamsSupplier(final AirbyteStateMessage airbyteStateMessage) { - /* - * If the incoming message has the state type set to GLOBAL, it is using the new format. Therefore, - * we can look for streams in the "global" field of the message. Otherwise, the message is still - * storing state in the legacy "data" field. - */ - return () -> { - if (airbyteStateMessage.getType() == AirbyteStateType.GLOBAL) { - return airbyteStateMessage.getGlobal().getStreamStates(); - } else if (airbyteStateMessage.getData() != null) { - return Jsons.object(airbyteStateMessage.getData(), DbState.class).getStreams().stream() - .map(s -> new AirbyteStreamState().withStreamState(Jsons.jsonNode(s)) - .withStreamDescriptor(new StreamDescriptor().withNamespace(s.getStreamNamespace()).withName(s.getStreamName()))) - .collect( - Collectors.toList()); - } else { - return List.of(); - } - }; - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/LegacyStateManager.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/LegacyStateManager.java deleted file mode 100644 index c12137e607a7..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/LegacyStateManager.java +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.source.relationaldb.state; - -import io.airbyte.cdk.integrations.source.relationaldb.CdcStateManager; -import io.airbyte.cdk.integrations.source.relationaldb.models.DbState; -import io.airbyte.cdk.integrations.source.relationaldb.models.DbStreamState; -import io.airbyte.commons.json.Jsons; -import io.airbyte.protocol.models.v0.AirbyteStateMessage; -import io.airbyte.protocol.models.v0.AirbyteStateMessage.AirbyteStateType; -import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog; -import java.util.List; -import java.util.Objects; -import java.util.Optional; -import java.util.function.Function; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Legacy implementation (pre-per-stream state support) of the {@link StateManager} interface. - * - * This implementation assumes that the state matches the {@link DbState} object and effectively - * tracks state as global across the streams managed by a connector. - * - * @deprecated This manager may be removed in the future if/once all connectors support per-stream - * state management. 
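Concretely, the GlobalStateManager.toState implementation above produces one GLOBAL message in which the CDC offset travels in the shared state and each stream keeps only its own cursor, while the deprecated data blob duplicates the same information for older platforms. Roughly (a hedged sketch; exact key casing and nesting follow the protocol serializer, and the values are illustrative):

// {
//   "type": "GLOBAL",
//   "global": {
//     "shared_state": { /* CDC offsets managed by the CdcStateManager */ },
//     "stream_states": [
//       { "stream_descriptor": { "name": "users", "namespace": "public" },
//         "stream_state": { /* cursor field + cursor value for this stream */ } }
//     ]
//   },
//   "data": { /* legacy DbState duplicate, kept for backwards compatibility */ }
// }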
- */ -@Deprecated(forRemoval = true) -public class LegacyStateManager extends AbstractStateManager<DbState, DbStreamState> { - - private static final Logger LOGGER = LoggerFactory.getLogger(LegacyStateManager.class); - - /** - * {@link Function} that extracts the cursor from the stream state. - */ - private static final Function<DbStreamState, String> CURSOR_FUNCTION = DbStreamState::getCursor; - - /** - * {@link Function} that extracts the cursor field(s) from the stream state. - */ - private static final Function<DbStreamState, List<String>> CURSOR_FIELD_FUNCTION = DbStreamState::getCursorField; - - private static final Function<DbStreamState, Long> CURSOR_RECORD_COUNT_FUNCTION = - stream -> Objects.requireNonNullElse(stream.getCursorRecordCount(), 0L); - - /** - * {@link Function} that creates an {@link AirbyteStreamNameNamespacePair} from the stream state. - */ - private static final Function<DbStreamState, AirbyteStreamNameNamespacePair> NAME_NAMESPACE_PAIR_FUNCTION = - s -> new AirbyteStreamNameNamespacePair(s.getStreamName(), s.getStreamNamespace()); - - /** - * Tracks whether the connector associated with this state manager supports CDC. - */ - private Boolean isCdc; - - /** - * {@link CdcStateManager} used to manage state for connectors that support CDC. - */ - private final CdcStateManager cdcStateManager; - - /** - * Constructs a new {@link LegacyStateManager} that is seeded with the provided {@link DbState} - * instance. - * - * @param dbState The initial state represented as an {@link DbState} instance. - * @param catalog The {@link ConfiguredAirbyteCatalog} for the connector associated with this state - * manager. - */ - public LegacyStateManager(final DbState dbState, final ConfiguredAirbyteCatalog catalog) { - super(catalog, - dbState::getStreams, - CURSOR_FUNCTION, - CURSOR_FIELD_FUNCTION, - CURSOR_RECORD_COUNT_FUNCTION, - NAME_NAMESPACE_PAIR_FUNCTION); - - this.cdcStateManager = new CdcStateManager(dbState.getCdcState(), AirbyteStreamNameNamespacePair.fromConfiguredCatalog(catalog), null); - this.isCdc = dbState.getCdc(); - if (dbState.getCdc() == null) { - this.isCdc = false; - } - } - - @Override - public CdcStateManager getCdcStateManager() { - return cdcStateManager; - } - - @Override - public List<AirbyteStateMessage> getRawStateMessages() { - throw new UnsupportedOperationException("Raw state retrieval not supported by legacy state manager."); - } - - @Override - public AirbyteStateMessage toState(final Optional<AirbyteStreamNameNamespacePair> pair) { - final DbState dbState = StateGeneratorUtils.generateDbState(getPairToCursorInfoMap()) - .withCdc(isCdc) - .withCdcState(getCdcStateManager().getCdcState()); - - LOGGER.debug("Generated legacy state for {} streams", dbState.getStreams().size()); - return new AirbyteStateMessage().withType(AirbyteStateType.LEGACY).withData(Jsons.jsonNode(dbState)); - } - - @Override - public AirbyteStateMessage updateAndEmit(final AirbyteStreamNameNamespacePair pair, final String cursor) { - return updateAndEmit(pair, cursor, 0L); - } - - @Override - public AirbyteStateMessage updateAndEmit(final AirbyteStreamNameNamespacePair pair, final String cursor, final long cursorRecordCount) { - // The CDC state file gets updated by Debezium, so the "update" part is a no-op.
- if (!isCdc) { - return super.updateAndEmit(pair, cursor, cursorRecordCount); - } - - return toState(Optional.ofNullable(pair)); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateIterator.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateIterator.java deleted file mode 100644 index 238f9471a8b0..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateIterator.java +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.source.relationaldb.state; - -import com.google.common.collect.AbstractIterator; -import io.airbyte.protocol.models.v0.AirbyteMessage; -import io.airbyte.protocol.models.v0.AirbyteMessage.Type; -import io.airbyte.protocol.models.v0.AirbyteStateMessage; -import io.airbyte.protocol.models.v0.AirbyteStateStats; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream; -import java.time.Duration; -import java.time.Instant; -import java.time.OffsetDateTime; -import java.util.Iterator; -import javax.annotation.CheckForNull; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class SourceStateIterator extends AbstractIterator implements Iterator { - - private static final Logger LOGGER = LoggerFactory.getLogger(SourceStateIterator.class); - private final Iterator messageIterator; - private final ConfiguredAirbyteStream stream; - private final StateEmitFrequency stateEmitFrequency; - private boolean hasEmittedFinalState = false; - private long recordCount = 0L; - private Instant lastCheckpoint = Instant.now(); - - private final SourceStateMessageProducer sourceStateMessageProducer; - - public SourceStateIterator(final Iterator messageIterator, - final ConfiguredAirbyteStream stream, - final SourceStateMessageProducer sourceStateMessageProducer, - final StateEmitFrequency stateEmitFrequency) { - this.messageIterator = messageIterator; - this.stream = stream; - this.sourceStateMessageProducer = sourceStateMessageProducer; - this.stateEmitFrequency = stateEmitFrequency; - } - - @CheckForNull - @Override - protected AirbyteMessage computeNext() { - - boolean iteratorHasNextValue = false; - try { - iteratorHasNextValue = messageIterator.hasNext(); - } catch (final Exception ex) { - // If the underlying iterator throws an exception, we want to fail the sync, expecting sync/attempt - // will be restarted and - // sync will resume from the last state message. 
- throw new FailedRecordIteratorException(ex); - } - if (iteratorHasNextValue) { - if (shouldEmitStateMessage() && sourceStateMessageProducer.shouldEmitStateMessage(stream)) { - final AirbyteStateMessage stateMessage = sourceStateMessageProducer.generateStateMessageAtCheckpoint(stream); - stateMessage.withSourceStats(new AirbyteStateStats().withRecordCount((double) recordCount)); - - recordCount = 0L; - lastCheckpoint = Instant.now(); - return new AirbyteMessage() - .withType(Type.STATE) - .withState(stateMessage); - } - // Use try-catch to catch Exception that could occur when connection to the database fails - try { - final T message = messageIterator.next(); - final AirbyteMessage processedMessage = sourceStateMessageProducer.processRecordMessage(stream, message); - recordCount++; - return processedMessage; - } catch (final Exception e) { - throw new FailedRecordIteratorException(e); - } - } else if (!hasEmittedFinalState) { - hasEmittedFinalState = true; - final AirbyteStateMessage finalStateMessageForStream = sourceStateMessageProducer.createFinalStateMessage(stream); - finalStateMessageForStream.withSourceStats(new AirbyteStateStats().withRecordCount((double) recordCount)); - recordCount = 0L; - return new AirbyteMessage() - .withType(Type.STATE) - .withState(finalStateMessageForStream); - } else { - return endOfData(); - } - } - - // This method is used to check if we should emit a state message. If the record count is set to 0, - // we should not emit a state message. - // If the frequency is set to be zero, we should not use it. - private boolean shouldEmitStateMessage() { - if (stateEmitFrequency.syncCheckpointRecords() == 0) { - return false; - } - if (recordCount >= stateEmitFrequency.syncCheckpointRecords()) { - return true; - } - if (!stateEmitFrequency.syncCheckpointDuration().isZero()) { - return Duration.between(lastCheckpoint, OffsetDateTime.now()).compareTo(stateEmitFrequency.syncCheckpointDuration()) > 0; - } - return false; - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateMessageProducer.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateMessageProducer.java deleted file mode 100644 index c4d95b2b1fbb..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateMessageProducer.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.source.relationaldb.state; - -import io.airbyte.protocol.models.v0.AirbyteMessage; -import io.airbyte.protocol.models.v0.AirbyteStateMessage; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream; - -/** - * To be used with SourceStateIterator. SourceStateIterator will iterate over the records and - * generate state messages when needed. This interface defines how would those state messages be - * generated, and how the incoming record messages will be processed. - * - * @param - */ -public interface SourceStateMessageProducer { - - /** - * Returns a state message that should be emitted at checkpoint. - */ - AirbyteStateMessage generateStateMessageAtCheckpoint(final ConfiguredAirbyteStream stream); - - /** - * For the incoming record message, this method defines how the connector will consume it. 
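 To make the producer contract concrete, a minimal hypothetical implementation might look as follows. MyCursorProducer and its bookkeeping are invented for this sketch; real producers are connector-specific. import io.airbyte.protocol.models.v0.AirbyteMessage import io.airbyte.protocol.models.v0.AirbyteStateMessage import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream // Hypothetical producer that remembers the last record it saw and reports it // as state, both at intermediate checkpoints and at the end of the stream. class MyCursorProducer : SourceStateMessageProducer<AirbyteMessage> { private var lastRecord: AirbyteMessage? = null override fun generateStateMessageAtCheckpoint(stream: ConfiguredAirbyteStream): AirbyteStateMessage = stateFromLastRecord() override fun processRecordMessage(stream: ConfiguredAirbyteStream, message: AirbyteMessage): AirbyteMessage { lastRecord = message // remember progress for the next checkpoint return message // hand the record through unchanged } override fun createFinalStateMessage(stream: ConfiguredAirbyteStream): AirbyteStateMessage = stateFromLastRecord() override fun shouldEmitStateMessage(stream: ConfiguredAirbyteStream): Boolean = lastRecord != null // only checkpoint once there is progress to report // Placeholder: a real producer serializes its cursor into the message here. private fun stateFromLastRecord(): AirbyteStateMessage = AirbyteStateMessage() } Wired up as SourceStateIterator(recordIterator, stream, MyCursorProducer(), StateEmitFrequency(1000L, Duration.ofMinutes(15))), the iterator would checkpoint every 1000 records or every 15 minutes, whichever is reached first, matching the shouldEmitStateMessage logic above. 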
- */ - AirbyteMessage processRecordMessage(final ConfiguredAirbyteStream stream, final T message); - - /** - * At the end of the iteration, this method will be called and it will generate the final state - * message. - * - * @return - */ - AirbyteStateMessage createFinalStateMessage(final ConfiguredAirbyteStream stream); - - /** - * Determines if the iterator has reached checkpoint or not per connector's definition. By default - * iterator will check if the number of records processed is greater than the checkpoint interval or - * last state message has already passed syncCheckpointDuration. - */ - boolean shouldEmitStateMessage(final ConfiguredAirbyteStream stream); - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/StateEmitFrequency.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/StateEmitFrequency.java deleted file mode 100644 index ee1eef34c421..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/StateEmitFrequency.java +++ /dev/null @@ -1,9 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.source.relationaldb.state; - -import java.time.Duration; - -public record StateEmitFrequency(long syncCheckpointRecords, Duration syncCheckpointDuration) {} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/StateGeneratorUtils.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/StateGeneratorUtils.java deleted file mode 100644 index 4c272190946b..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/StateGeneratorUtils.java +++ /dev/null @@ -1,266 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.source.relationaldb.state; - -import com.fasterxml.jackson.databind.JsonNode; -import com.google.common.collect.Lists; -import io.airbyte.cdk.integrations.source.relationaldb.CursorInfo; -import io.airbyte.cdk.integrations.source.relationaldb.models.CdcState; -import io.airbyte.cdk.integrations.source.relationaldb.models.DbState; -import io.airbyte.cdk.integrations.source.relationaldb.models.DbStreamState; -import io.airbyte.commons.json.Jsons; -import io.airbyte.configoss.StateWrapper; -import io.airbyte.configoss.helpers.StateMessageHelper; -import io.airbyte.protocol.models.v0.AirbyteGlobalState; -import io.airbyte.protocol.models.v0.AirbyteStateMessage; -import io.airbyte.protocol.models.v0.AirbyteStateMessage.AirbyteStateType; -import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair; -import io.airbyte.protocol.models.v0.AirbyteStreamState; -import io.airbyte.protocol.models.v0.StreamDescriptor; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import java.util.Optional; -import java.util.function.Function; -import java.util.stream.Collectors; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Collection of utilities that facilitate the generation of state objects. 
- */ -public class StateGeneratorUtils { - - private static final Logger LOGGER = LoggerFactory.getLogger(StateGeneratorUtils.class); - - /** - * {@link Function} that extracts the cursor from the stream state. - */ - public static final Function CURSOR_FUNCTION = stream -> { - final Optional dbStreamState = StateGeneratorUtils.extractState(stream); - return dbStreamState.map(DbStreamState::getCursor).orElse(null); - }; - - /** - * {@link Function} that extracts the cursor field(s) from the stream state. - */ - public static final Function> CURSOR_FIELD_FUNCTION = stream -> { - final Optional dbStreamState = StateGeneratorUtils.extractState(stream); - if (dbStreamState.isPresent()) { - return dbStreamState.get().getCursorField(); - } else { - return List.of(); - } - }; - - public static final Function CURSOR_RECORD_COUNT_FUNCTION = stream -> { - final Optional dbStreamState = StateGeneratorUtils.extractState(stream); - return dbStreamState.map(DbStreamState::getCursorRecordCount).orElse(0L); - }; - - /** - * {@link Function} that creates an {@link AirbyteStreamNameNamespacePair} from the stream state. - */ - public static final Function NAME_NAMESPACE_PAIR_FUNCTION = - s -> isValidStreamDescriptor(s.getStreamDescriptor()) - ? new AirbyteStreamNameNamespacePair(s.getStreamDescriptor().getName(), s.getStreamDescriptor().getNamespace()) - : null; - - private StateGeneratorUtils() {} - - /** - * Generates the stream state for the given stream and cursor information. - * - * @param airbyteStreamNameNamespacePair The stream. - * @param cursorInfo The current cursor. - * @return The {@link AirbyteStreamState} representing the current state of the stream. - */ - public static AirbyteStreamState generateStreamState(final AirbyteStreamNameNamespacePair airbyteStreamNameNamespacePair, - final CursorInfo cursorInfo) { - return new AirbyteStreamState() - .withStreamDescriptor( - new StreamDescriptor().withName(airbyteStreamNameNamespacePair.getName()).withNamespace(airbyteStreamNameNamespacePair.getNamespace())) - .withStreamState(Jsons.jsonNode(generateDbStreamState(airbyteStreamNameNamespacePair, cursorInfo))); - } - - /** - * Generates a list of valid stream states from the provided stream and cursor information. A stream - * state is considered to be valid if the stream has a valid descriptor (see - * {@link #isValidStreamDescriptor(StreamDescriptor)} for more details). - * - * @param pairToCursorInfoMap The map of stream name/namespace tuple to the current cursor - * information for that stream - * @return The list of stream states derived from the state information extracted from the provided - * map. - */ - public static List generateStreamStateList(final Map pairToCursorInfoMap) { - return pairToCursorInfoMap.entrySet().stream() - .sorted(Entry.comparingByKey()) - .map(e -> generateStreamState(e.getKey(), e.getValue())) - .filter(s -> isValidStreamDescriptor(s.getStreamDescriptor())) - .collect(Collectors.toList()); - } - - /** - * Generates the legacy global state for backwards compatibility. - * - * @param pairToCursorInfoMap The map of stream name/namespace tuple to the current cursor - * information for that stream - * @return The legacy {@link DbState}. - */ - public static DbState generateDbState(final Map pairToCursorInfoMap) { - return new DbState() - .withCdc(false) - .withStreams(pairToCursorInfoMap.entrySet().stream() - .sorted(Entry.comparingByKey()) // sort by stream name then namespace for sanity. 
- .map(e -> generateDbStreamState(e.getKey(), e.getValue())) - .collect(Collectors.toList())); - } - - /** - * Generates the {@link DbStreamState} for the given stream and cursor. - * - * @param airbyteStreamNameNamespacePair The stream. - * @param cursorInfo The current cursor. - * @return The {@link DbStreamState}. - */ - public static DbStreamState generateDbStreamState(final AirbyteStreamNameNamespacePair airbyteStreamNameNamespacePair, - final CursorInfo cursorInfo) { - final DbStreamState state = new DbStreamState() - .withStreamName(airbyteStreamNameNamespacePair.getName()) - .withStreamNamespace(airbyteStreamNameNamespacePair.getNamespace()) - .withCursorField(cursorInfo.getCursorField() == null ? Collections.emptyList() : Lists.newArrayList(cursorInfo.getCursorField())) - .withCursor(cursorInfo.getCursor()); - if (cursorInfo.getCursorRecordCount() > 0L) { - state.setCursorRecordCount(cursorInfo.getCursorRecordCount()); - } - return state; - } - - /** - * Extracts the actual state from the {@link AirbyteStreamState} object. - * - * @param state The {@link AirbyteStreamState} that contains the actual stream state as JSON. - * @return An {@link Optional} possibly containing the deserialized representation of the stream - * state or an empty {@link Optional} if the state is not present or could not be - * deserialized. - */ - public static Optional extractState(final AirbyteStreamState state) { - try { - return Optional.ofNullable(Jsons.object(state.getStreamState(), DbStreamState.class)); - } catch (final IllegalArgumentException e) { - LOGGER.error("Unable to extract state.", e); - return Optional.empty(); - } - } - - /** - * Tests whether the provided {@link StreamDescriptor} is valid. A valid descriptor is defined as - * one that has a non-{@code null} name. - * - * See - * https://github.com/airbytehq/airbyte/blob/e63458fabb067978beb5eaa74d2bc130919b419f/docs/understanding-airbyte/airbyte-protocol.md - * for more details - * - * @param streamDescriptor A {@link StreamDescriptor} to be validated. - * @return {@code true} if the provided {@link StreamDescriptor} is valid or {@code false} if it is - * invalid. - */ - public static boolean isValidStreamDescriptor(final StreamDescriptor streamDescriptor) { - if (streamDescriptor != null) { - return streamDescriptor.getName() != null; - } else { - return false; - } - } - - /** - * Converts a {@link AirbyteStateType#LEGACY} state message into a {@link AirbyteStateType#GLOBAL} - * message. - * - * @param airbyteStateMessage A {@link AirbyteStateType#LEGACY} state message. - * @return A {@link AirbyteStateType#GLOBAL} state message. - */ - public static AirbyteStateMessage convertLegacyStateToGlobalState(final AirbyteStateMessage airbyteStateMessage) { - final DbState dbState = Jsons.object(airbyteStateMessage.getData(), DbState.class); - final AirbyteGlobalState globalState = new AirbyteGlobalState() - .withSharedState(Jsons.jsonNode(dbState.getCdcState())) - .withStreamStates(dbState.getStreams().stream() - .map(s -> new AirbyteStreamState() - .withStreamDescriptor(new StreamDescriptor().withName(s.getStreamName()).withNamespace(s.getStreamNamespace())) - .withStreamState(Jsons.jsonNode(s))) - .collect( - Collectors.toList())); - return new AirbyteStateMessage().withType(AirbyteStateType.GLOBAL).withGlobal(globalState); - } - - /** - * Converts a {@link AirbyteStateType#LEGACY} state message into a list of - * {@link AirbyteStateType#STREAM} messages. - * - * @param airbyteStateMessage A {@link AirbyteStateType#LEGACY} state message. 
- * @return A list {@link AirbyteStateType#STREAM} state messages. - */ - public static List convertLegacyStateToStreamState(final AirbyteStateMessage airbyteStateMessage) { - return Jsons.object(airbyteStateMessage.getData(), DbState.class).getStreams().stream() - .map(s -> new AirbyteStateMessage().withType(AirbyteStateType.STREAM) - .withStream(new AirbyteStreamState() - .withStreamDescriptor(new StreamDescriptor().withNamespace(s.getStreamNamespace()).withName(s.getStreamName())) - .withStreamState(Jsons.jsonNode(s)))) - .collect(Collectors.toList()); - } - - public static AirbyteStateMessage convertStateMessage(final io.airbyte.protocol.models.AirbyteStateMessage state) { - return Jsons.object(Jsons.jsonNode(state), AirbyteStateMessage.class); - } - - /** - * Deserializes the state represented as JSON into an object representation. - * - * @param initialStateJson The state as JSON. - * @Param supportedStateType the {@link AirbyteStateType} supported by this connector. - * @return The deserialized object representation of the state. - */ - public static List deserializeInitialState(final JsonNode initialStateJson, - final AirbyteStateType supportedStateType) { - final Optional typedState = StateMessageHelper.getTypedState(initialStateJson); - return typedState - .map(state -> switch (state.getStateType()) { - case GLOBAL -> List.of(StateGeneratorUtils.convertStateMessage(state.getGlobal())); - case STREAM -> state.getStateMessages() - .stream() - .map(StateGeneratorUtils::convertStateMessage).toList(); - default -> List.of(new AirbyteStateMessage().withType(AirbyteStateType.LEGACY) - .withData(state.getLegacyState())); - }) - .orElse(generateEmptyInitialState(supportedStateType)); - } - - /** - * Generates an empty, initial state for use by the connector. - * - * @Param supportedStateType the {@link AirbyteStateType} supported by this connector. - * @return The empty, initial state. - */ - private static List generateEmptyInitialState(final AirbyteStateType supportedStateType) { - // For backwards compatibility with existing connectors - if (supportedStateType == AirbyteStateType.LEGACY) { - return List.of(new AirbyteStateMessage() - .withType(AirbyteStateType.LEGACY) - .withData(Jsons.jsonNode(new DbState()))); - } else if (supportedStateType == AirbyteStateType.GLOBAL) { - final AirbyteGlobalState globalState = new AirbyteGlobalState() - .withSharedState(Jsons.jsonNode(new CdcState())) - .withStreamStates(List.of()); - return List.of(new AirbyteStateMessage().withType(AirbyteStateType.GLOBAL).withGlobal(globalState)); - } else { - return List.of(new AirbyteStateMessage() - .withType(AirbyteStateType.STREAM) - .withStream(new AirbyteStreamState())); - } - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/StateManager.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/StateManager.java deleted file mode 100644 index 3bfb211ea2aa..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/StateManager.java +++ /dev/null @@ -1,167 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
- */ - -package io.airbyte.cdk.integrations.source.relationaldb.state; - -import com.google.common.base.Preconditions; -import io.airbyte.cdk.integrations.source.relationaldb.CdcStateManager; -import io.airbyte.cdk.integrations.source.relationaldb.CursorInfo; -import io.airbyte.protocol.models.v0.AirbyteStateMessage; -import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Defines a manager that manages connector state. Connector state is used to keep track of the data - * synced by the connector. - * - * @param The type of the state maintained by the manager. - * @param The type of the stream(s) stored within the state maintained by the manager. - */ -public interface StateManager { - - Logger LOGGER = LoggerFactory.getLogger(StateManager.class); - - /** - * Retrieves the {@link CdcStateManager} associated with the state manager. - * - * @return The {@link CdcStateManager} - * @throws UnsupportedOperationException if the state manager does not support tracking change data - * capture (CDC) state. - */ - CdcStateManager getCdcStateManager(); - - /** - * Retries the raw state messages associated with the state manager. This is required for - * database-specific sync modes (e.g. Xmin) that would want to handle and parse their own state - * - * @return the list of airbyte state messages - * @throws UnsupportedOperationException if the state manager does not support retrieving raw state. - */ - List getRawStateMessages(); - - /** - * Retrieves the map of stream name/namespace tuple to the current cursor information for that - * stream. - * - * @return The map of stream name/namespace tuple to the current cursor information for that stream - * as maintained by this state manager. - */ - Map getPairToCursorInfoMap(); - - /** - * Generates an {@link AirbyteStateMessage} that represents the current state contained in the state - * manager. - * - * @param pair The {@link AirbyteStreamNameNamespacePair} that represents a stream managed by the - * state manager. - * @return The {@link AirbyteStateMessage} that represents the current state contained in the state - * manager. - */ - AirbyteStateMessage toState(final Optional pair); - - /** - * Retrieves an {@link Optional} possibly containing the cursor value tracked in the state - * associated with the provided stream name/namespace tuple. - * - * @param pair The {@link AirbyteStreamNameNamespacePair} which identifies a stream. - * @return An {@link Optional} possibly containing the cursor value tracked in the state associated - * with the provided stream name/namespace tuple. - */ - default Optional getCursor(final AirbyteStreamNameNamespacePair pair) { - return getCursorInfo(pair).map(CursorInfo::getCursor); - } - - /** - * Retrieves an {@link Optional} possibly containing the cursor field name associated with the - * cursor tracked in the state associated with the provided stream name/namespace tuple. - * - * @param pair The {@link AirbyteStreamNameNamespacePair} which identifies a stream. - * @return An {@link Optional} possibly containing the cursor field name associated with the cursor - * tracked in the state associated with the provided stream name/namespace tuple. 
- */ - default Optional getCursorField(final AirbyteStreamNameNamespacePair pair) { - return getCursorInfo(pair).map(CursorInfo::getCursorField); - } - - /** - * Retrieves an {@link Optional} possibly containing the original cursor value tracked in the state - * associated with the provided stream name/namespace tuple. - * - * @param pair The {@link AirbyteStreamNameNamespacePair} which identifies a stream. - * @return An {@link Optional} possibly containing the original cursor value tracked in the state - * associated with the provided stream name/namespace tuple. - */ - default Optional getOriginalCursor(final AirbyteStreamNameNamespacePair pair) { - return getCursorInfo(pair).map(CursorInfo::getOriginalCursor); - } - - /** - * Retrieves an {@link Optional} possibly containing the original cursor field name associated with - * the cursor tracked in the state associated with the provided stream name/namespace tuple. - * - * @param pair The {@link AirbyteStreamNameNamespacePair} which identifies a stream. - * @return An {@link Optional} possibly containing the original cursor field name associated with - * the cursor tracked in the state associated with the provided stream name/namespace tuple. - */ - default Optional getOriginalCursorField(final AirbyteStreamNameNamespacePair pair) { - return getCursorInfo(pair).map(CursorInfo::getOriginalCursorField); - } - - /** - * Retrieves the current cursor information stored in the state manager for the steam name/namespace - * tuple. - * - * @param pair The {@link AirbyteStreamNameNamespacePair} that represents a stream managed by the - * state manager. - * @return {@link Optional} that potentially contains the current cursor information for the given - * stream name/namespace tuple. - */ - default Optional getCursorInfo(final AirbyteStreamNameNamespacePair pair) { - return Optional.ofNullable(getPairToCursorInfoMap().get(pair)); - } - - /** - * Emits the current state maintained by the manager as an {@link AirbyteStateMessage}. - * - * @param pair The {@link AirbyteStreamNameNamespacePair} that represents a stream managed by the - * state manager. - * @return An {@link AirbyteStateMessage} that represents the current state maintained by the state - * manager. - */ - default AirbyteStateMessage emit(final Optional pair) { - return toState(pair); - } - - /** - * Updates the cursor associated with the provided stream name/namespace pair and emits the current - * state maintained by the state manager. - * - * @param pair The {@link AirbyteStreamNameNamespacePair} that represents a stream managed by the - * state manager. - * @param cursor The new value for the cursor associated with the - * {@link AirbyteStreamNameNamespacePair} that represents a stream managed by the state - * manager. - * @return An {@link AirbyteStateMessage} that represents the current state maintained by the state - * manager. 
- */ - default AirbyteStateMessage updateAndEmit(final AirbyteStreamNameNamespacePair pair, final String cursor) { - return updateAndEmit(pair, cursor, 0L); - } - - default AirbyteStateMessage updateAndEmit(final AirbyteStreamNameNamespacePair pair, final String cursor, final long cursorRecordCount) { - final Optional cursorInfo = getCursorInfo(pair); - Preconditions.checkState(cursorInfo.isPresent(), "Could not find cursor information for stream: " + pair); - cursorInfo.get().setCursor(cursor); - if (cursorRecordCount > 0L) { - cursorInfo.get().setCursorRecordCount(cursorRecordCount); - } - LOGGER.debug("Updating cursor value for {} to {} (count {})...", pair, cursor, cursorRecordCount); - return emit(Optional.ofNullable(pair)); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/StateManagerFactory.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/StateManagerFactory.java deleted file mode 100644 index 6c6d8b166443..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/StateManagerFactory.java +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.source.relationaldb.state; - -import io.airbyte.cdk.integrations.source.relationaldb.models.DbState; -import io.airbyte.commons.json.Jsons; -import io.airbyte.protocol.models.v0.AirbyteStateMessage; -import io.airbyte.protocol.models.v0.AirbyteStateMessage.AirbyteStateType; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog; -import java.util.ArrayList; -import java.util.List; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Factory class that creates {@link StateManager} instances based on the provided state. - */ -public class StateManagerFactory { - - private static final Logger LOGGER = LoggerFactory.getLogger(StateManagerFactory.class); - - /** - * Private constructor to prevent direct instantiation. - */ - private StateManagerFactory() {} - - /** - * Creates a {@link StateManager} based on the provided state object and catalog. This method will - * handle the conversion of the provided state to match the requested state manager based on the - * provided {@link AirbyteStateType}. - * - * @param supportedStateType The type of state supported by the connector. - * @param initialState The deserialized initial state that will be provided to the selected - * {@link StateManager}. - * @param catalog The {@link ConfiguredAirbyteCatalog} for the connector that will utilize the state - * manager. - * @return A newly created {@link StateManager} implementation based on the provided state. 
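 Putting the factory and the cursor-update path together, a hypothetical caller could look like the sketch below. The stream name/namespace and cursor value are invented; initialState and catalog are assumed to be already deserialized and configured. import io.airbyte.protocol.models.v0.AirbyteStateMessage import io.airbyte.protocol.models.v0.AirbyteStateMessage.AirbyteStateType import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog // Build a per-stream manager from already-deserialized initial state, then // advance the cursor for one tracked stream. The stream must already be known // to the manager; updateAndEmit enforces this with a precondition check. fun advanceCursor( initialState: List<AirbyteStateMessage>, catalog: ConfiguredAirbyteCatalog ): AirbyteStateMessage { val stateManager = StateManagerFactory.createStateManager(AirbyteStateType.STREAM, initialState, catalog) val pair = AirbyteStreamNameNamespacePair("users", "public") return stateManager.updateAndEmit(pair, "2024-01-01T00:00:00Z") } 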
- */ - public static StateManager createStateManager(final AirbyteStateType supportedStateType, - final List initialState, - final ConfiguredAirbyteCatalog catalog) { - if (initialState != null && !initialState.isEmpty()) { - final AirbyteStateMessage airbyteStateMessage = initialState.get(0); - switch (supportedStateType) { - case LEGACY: - LOGGER.info("Legacy state manager selected to manage state object with type {}.", airbyteStateMessage.getType()); - @SuppressWarnings("deprecation") - StateManager retVal = new LegacyStateManager(Jsons.object(airbyteStateMessage.getData(), DbState.class), catalog); - return retVal; - case GLOBAL: - LOGGER.info("Global state manager selected to manage state object with type {}.", airbyteStateMessage.getType()); - return new GlobalStateManager(generateGlobalState(airbyteStateMessage), catalog); - case STREAM: - default: - LOGGER.info("Stream state manager selected to manage state object with type {}.", airbyteStateMessage.getType()); - return new StreamStateManager(generateStreamState(initialState), catalog); - } - } else { - throw new IllegalArgumentException("Failed to create state manager due to empty state list."); - } - } - - /** - * Handles the conversion between a different state type and the global state. This method handles - * the following transitions: - *
 - * <ul> - * <li>Stream -> Global (not supported, results in {@link IllegalArgumentException})</li> - * <li>Legacy -> Global (supported)</li> - * <li>Global -> Global (supported/no conversion required)</li> - * </ul> 
- * - * @param airbyteStateMessage The current state that is to be converted to global state. - * @return The converted state message. - * @throws IllegalArgumentException if unable to convert between the given state type and global. - */ - private static AirbyteStateMessage generateGlobalState(final AirbyteStateMessage airbyteStateMessage) { - AirbyteStateMessage globalStateMessage = airbyteStateMessage; - - switch (airbyteStateMessage.getType()) { - case STREAM: - throw new IllegalArgumentException("Unable to convert connector state from stream to global. Please reset the connection to continue."); - case LEGACY: - globalStateMessage = StateGeneratorUtils.convertLegacyStateToGlobalState(airbyteStateMessage); - LOGGER.info("Legacy state converted to global state.", airbyteStateMessage.getType()); - break; - case GLOBAL: - default: - break; - } - - return globalStateMessage; - } - - /** - * Handles the conversion between a different state type and the stream state. This method handles - * the following transitions: - *
 - * <ul> - * <li>Global -> Stream (not supported, results in {@link IllegalArgumentException})</li> - * <li>Legacy -> Stream (supported)</li> - * <li>Stream -> Stream (supported/no conversion required)</li> - * </ul> 
- * - * @param states The list of current states. - * @return The converted state messages. - * @throws IllegalArgumentException if unable to convert between the given state type and stream. - */ - private static List generateStreamState(final List states) { - final AirbyteStateMessage airbyteStateMessage = states.get(0); - final List streamStates = new ArrayList<>(); - switch (airbyteStateMessage.getType()) { - case GLOBAL: - throw new IllegalArgumentException("Unable to convert connector state from global to stream. Please reset the connection to continue."); - case LEGACY: - streamStates.addAll(StateGeneratorUtils.convertLegacyStateToStreamState(airbyteStateMessage)); - break; - case STREAM: - default: - streamStates.addAll(states); - break; - } - - return streamStates; - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/StreamStateManager.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/StreamStateManager.java deleted file mode 100644 index efb874b8b034..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/source/relationaldb/state/StreamStateManager.java +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.source.relationaldb.state; - -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateGeneratorUtils.CURSOR_FIELD_FUNCTION; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateGeneratorUtils.CURSOR_FUNCTION; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateGeneratorUtils.CURSOR_RECORD_COUNT_FUNCTION; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateGeneratorUtils.NAME_NAMESPACE_PAIR_FUNCTION; - -import io.airbyte.cdk.integrations.source.relationaldb.CdcStateManager; -import io.airbyte.cdk.integrations.source.relationaldb.CursorInfo; -import io.airbyte.commons.json.Jsons; -import io.airbyte.protocol.models.v0.AirbyteStateMessage; -import io.airbyte.protocol.models.v0.AirbyteStateMessage.AirbyteStateType; -import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair; -import io.airbyte.protocol.models.v0.AirbyteStreamState; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.stream.Collectors; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Per-stream implementation of the {@link StateManager} interface. - *

- * This implementation generates a state object for each stream detected in catalog/map of known - * streams to cursor information stored in this manager. - */ -public class StreamStateManager extends AbstractStateManager { - - private static final Logger LOGGER = LoggerFactory.getLogger(StreamStateManager.class); - private final List rawAirbyteStateMessages; - - /** - * Constructs a new {@link StreamStateManager} that is seeded with the provided - * {@link AirbyteStateMessage}. - * - * @param airbyteStateMessages The initial state represented as a list of - * {@link AirbyteStateMessage}s. - * @param catalog The {@link ConfiguredAirbyteCatalog} for the connector associated with this state - * manager. - */ - public StreamStateManager(final List airbyteStateMessages, final ConfiguredAirbyteCatalog catalog) { - super(catalog, - () -> airbyteStateMessages.stream().map(AirbyteStateMessage::getStream).collect(Collectors.toList()), - CURSOR_FUNCTION, - CURSOR_FIELD_FUNCTION, - CURSOR_RECORD_COUNT_FUNCTION, - NAME_NAMESPACE_PAIR_FUNCTION); - this.rawAirbyteStateMessages = airbyteStateMessages; - } - - @Override - public CdcStateManager getCdcStateManager() { - throw new UnsupportedOperationException("CDC state management not supported by stream state manager."); - } - - @Override - public List getRawStateMessages() { - return rawAirbyteStateMessages; - } - - @Override - public AirbyteStateMessage toState(final Optional pair) { - if (pair.isPresent()) { - final Map pairToCursorInfoMap = getPairToCursorInfoMap(); - final Optional cursorInfo = Optional.ofNullable(pairToCursorInfoMap.get(pair.get())); - - if (cursorInfo.isPresent()) { - LOGGER.debug("Generating state message for {}...", pair); - return new AirbyteStateMessage() - .withType(AirbyteStateType.STREAM) - // Temporarily include legacy state for backwards compatibility with the platform - .withData(Jsons.jsonNode(StateGeneratorUtils.generateDbState(pairToCursorInfoMap))) - .withStream(StateGeneratorUtils.generateStreamState(pair.get(), cursorInfo.get())); - } else { - LOGGER.warn("Cursor information could not be located in state for stream {}. Returning a new, empty state message...", pair); - return new AirbyteStateMessage().withType(AirbyteStateType.STREAM).withStream(new AirbyteStreamState()); - } - } else { - LOGGER.warn("Stream not provided. Returning a new, empty state message..."); - return new AirbyteStateMessage().withType(AirbyteStateType.STREAM).withStream(new AirbyteStreamState()); - } - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/AirbyteDebeziumHandler.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/AirbyteDebeziumHandler.kt new file mode 100644 index 000000000000..c23ad4c0b73e --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/AirbyteDebeziumHandler.kt @@ -0,0 +1,164 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+ */ +package io.airbyte.cdk.integrations.debezium + +import com.fasterxml.jackson.databind.JsonNode +import io.airbyte.cdk.db.jdbc.JdbcUtils +import io.airbyte.cdk.integrations.debezium.internals.* +import io.airbyte.cdk.integrations.source.relationaldb.state.SourceStateIterator +import io.airbyte.cdk.integrations.source.relationaldb.state.StateEmitFrequency +import io.airbyte.commons.util.AutoCloseableIterator +import io.airbyte.commons.util.AutoCloseableIterators +import io.airbyte.protocol.models.v0.AirbyteMessage +import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog +import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream +import io.airbyte.protocol.models.v0.SyncMode +import io.debezium.engine.ChangeEvent +import io.debezium.engine.DebeziumEngine +import java.time.Duration +import java.time.Instant +import java.time.temporal.ChronoUnit +import java.util.* +import java.util.concurrent.LinkedBlockingQueue +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +/** + * This class acts as the bridge between Airbyte DB connectors and debezium. If a DB connector wants + * to use debezium for CDC, it should use this class + */ +class AirbyteDebeziumHandler( + private val config: JsonNode, + private val targetPosition: CdcTargetPosition, + private val trackSchemaHistory: Boolean, + private val firstRecordWaitTime: Duration, + private val subsequentRecordWaitTime: Duration, + private val queueSize: Int, + private val addDbNameToOffsetState: Boolean +) { + internal inner class CapacityReportingBlockingQueue(capacity: Int) : + LinkedBlockingQueue(capacity) { + private var lastReport: Instant? = null + + private fun reportQueueUtilization() { + if ( + lastReport == null || + Duration.between(lastReport, Instant.now()) + .compareTo(Companion.REPORT_DURATION) > 0 + ) { + LOGGER.info( + "CDC events queue size: {}. remaining {}", + this.size, + this.remainingCapacity() + ) + synchronized(this) { lastReport = Instant.now() } + } + } + + @Throws(InterruptedException::class) + override fun put(e: E) { + reportQueueUtilization() + super.put(e) + } + + override fun poll(): E { + reportQueueUtilization() + return super.poll() + } + } + + fun getIncrementalIterators( + debeziumPropertiesManager: DebeziumPropertiesManager, + eventConverter: DebeziumEventConverter, + cdcSavedInfoFetcher: CdcSavedInfoFetcher, + cdcStateHandler: CdcStateHandler + ): AutoCloseableIterator { + LOGGER.info("Using CDC: {}", true) + LOGGER.info( + "Using DBZ version: {}", + DebeziumEngine::class.java.getPackage().implementationVersion + ) + val offsetManager: AirbyteFileOffsetBackingStore = + AirbyteFileOffsetBackingStore.Companion.initializeState( + cdcSavedInfoFetcher.savedOffset, + if (addDbNameToOffsetState) + Optional.ofNullable(config[JdbcUtils.DATABASE_KEY].asText()) + else Optional.empty() + ) + val schemaHistoryManager: Optional = + if (trackSchemaHistory) + Optional.of( + AirbyteSchemaHistoryStorage.Companion.initializeDBHistory( + cdcSavedInfoFetcher.savedSchemaHistory, + cdcStateHandler.compressSchemaHistoryForState() + ) + ) + else Optional.empty() + val publisher = DebeziumRecordPublisher(debeziumPropertiesManager) + val queue: CapacityReportingBlockingQueue> = + CapacityReportingBlockingQueue(queueSize) + publisher.start(queue, offsetManager, schemaHistoryManager) + // handle state machine around pub/sub logic. 
+ val eventIterator: AutoCloseableIterator = + DebeziumRecordIterator( + queue, + targetPosition, + { publisher.hasClosed() }, + DebeziumShutdownProcedure(queue, { publisher.close() }, { publisher.hasClosed() }), + firstRecordWaitTime, + subsequentRecordWaitTime + ) + + val syncCheckpointDuration = + if (config.has(DebeziumIteratorConstants.SYNC_CHECKPOINT_DURATION_PROPERTY)) + Duration.ofSeconds( + config[DebeziumIteratorConstants.SYNC_CHECKPOINT_DURATION_PROPERTY].asLong() + ) + else DebeziumIteratorConstants.SYNC_CHECKPOINT_DURATION + val syncCheckpointRecords = + if (config.has(DebeziumIteratorConstants.SYNC_CHECKPOINT_RECORDS_PROPERTY)) + config[DebeziumIteratorConstants.SYNC_CHECKPOINT_RECORDS_PROPERTY].asLong() + else DebeziumIteratorConstants.SYNC_CHECKPOINT_RECORDS.toLong() + + val messageProducer: DebeziumMessageProducer = + DebeziumMessageProducer( + cdcStateHandler, + targetPosition, + eventConverter, + offsetManager, + schemaHistoryManager + ) + + // Usually sourceStateIterator requires airbyteStream as input. For DBZ iterator, stream is + // not used + // at all thus we will pass in null. + val iterator: SourceStateIterator = + SourceStateIterator( + eventIterator, + null, + messageProducer!!, + StateEmitFrequency(syncCheckpointRecords, syncCheckpointDuration) + ) + return AutoCloseableIterators.fromIterator(iterator) + } + + companion object { + private val LOGGER: Logger = LoggerFactory.getLogger(AirbyteDebeziumHandler::class.java) + private val REPORT_DURATION: Duration = Duration.of(10, ChronoUnit.SECONDS) + + /** + * We use 10000 as capacity cause the default queue size and batch size of debezium is : + * [io.debezium.config.CommonConnectorConfig.DEFAULT_MAX_BATCH_SIZE]is 2048 + * [io.debezium.config.CommonConnectorConfig.DEFAULT_MAX_QUEUE_SIZE] is 8192 + */ + const val QUEUE_CAPACITY: Int = 10000 + + fun isAnyStreamIncrementalSyncMode(catalog: ConfiguredAirbyteCatalog): Boolean { + return catalog.streams + .stream() + .map { obj: ConfiguredAirbyteStream -> obj.syncMode } + .anyMatch { syncMode: SyncMode -> syncMode == SyncMode.INCREMENTAL } + } + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/CdcMetadataInjector.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/CdcMetadataInjector.kt new file mode 100644 index 000000000000..7054290eb0df --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/CdcMetadataInjector.kt @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.debezium + +import com.fasterxml.jackson.databind.JsonNode +import com.fasterxml.jackson.databind.node.ObjectNode + +/** + * This interface is used to add metadata to the records fetched from the database. For instance, in + * Postgres we add the lsn to the records. In MySql we add the file name and position to the + * records. + */ +interface CdcMetadataInjector { + /** + * A debezium record contains multiple pieces. Ref : + * https://debezium.io/documentation/reference/1.9/connectors/mysql.html#mysql-create-events + * + * @param event is the actual record which contains data and would be written to the destination + * @param source contains the metadata about the record and we need to extract that metadata and + * add it to the event before writing it to destination + */ + fun addMetaData(event: ObjectNode?, source: JsonNode?) 
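 As an illustration of this contract, a hypothetical injector for a Postgres-style source is sketched below, together with the namespace/name accessors declared next in this interface. The field names "lsn", "schema" and "table" follow Debezium's Postgres source block, "_ab_cdc_lsn" is the usual Airbyte CDC column, and the sketch assumes the interface's stripped generic parameter is the metadata type. import com.fasterxml.jackson.databind.JsonNode import com.fasterxml.jackson.databind.node.ObjectNode // Hypothetical injector: copy the LSN out of the Debezium source metadata onto // the outgoing record, and derive the stream coordinates from the same block. class LsnMetadataInjector : CdcMetadataInjector<Long> { override fun addMetaData(event: ObjectNode?, source: JsonNode?) { event?.put("_ab_cdc_lsn", source?.get("lsn")?.asLong()) } override fun namespace(source: JsonNode?): String? = source?.get("schema")?.asText() override fun name(source: JsonNode?): String? = source?.get("table")?.asText() } 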
+ + fun addMetaDataToRowsFetchedOutsideDebezium( + record: ObjectNode?, + transactionTimestamp: String?, + metadataToAdd: T + ) { + throw RuntimeException("Not Supported") + } + + fun addMetaDataToRowsFetchedOutsideDebezium(record: ObjectNode?) { + throw java.lang.RuntimeException("Not Supported") + } + + /** + * As part of Airbyte record we need to add the namespace (schema name) + * + * @param source part of debezium record and contains the metadata about the record. We need to + * extract namespace out of this metadata and return Ref : + * https://debezium.io/documentation/reference/1.9/connectors/mysql.html#mysql-create-events + * @return the stream namespace extracted from the change event source. + */ + fun namespace(source: JsonNode?): String? + + /** + * As part of Airbyte record we need to add the name (e.g. table name) + * + * @param source part of debezium record and contains the metadata about the record. We need to + * extract namespace out of this metadata and return Ref : + * https://debezium.io/documentation/reference/1.9/connectors/mysql.html#mysql-create-events + * @return The stream name extracted from the change event source. + */ + fun name(source: JsonNode?): String? +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/CdcSavedInfoFetcher.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/CdcSavedInfoFetcher.kt similarity index 51% rename from airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/CdcSavedInfoFetcher.java rename to airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/CdcSavedInfoFetcher.kt index 27030d4a2597..abcc9e591539 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/java/io/airbyte/cdk/integrations/debezium/CdcSavedInfoFetcher.java +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/CdcSavedInfoFetcher.kt @@ -1,21 +1,18 @@ /* * Copyright (c) 2023 Airbyte, Inc., all rights reserved. */ +package io.airbyte.cdk.integrations.debezium -package io.airbyte.cdk.integrations.debezium; - -import com.fasterxml.jackson.databind.JsonNode; -import io.airbyte.cdk.integrations.debezium.internals.AirbyteSchemaHistoryStorage.SchemaHistory; -import java.util.Optional; +import com.fasterxml.jackson.databind.JsonNode +import io.airbyte.cdk.integrations.debezium.internals.AirbyteSchemaHistoryStorage +import java.util.* /** * This interface is used to fetch the saved info required for debezium to run incrementally. Each * connector saves offset and schema history in different manner */ -public interface CdcSavedInfoFetcher { - - JsonNode getSavedOffset(); - - SchemaHistory> getSavedSchemaHistory(); +interface CdcSavedInfoFetcher { + val savedOffset: JsonNode? + val savedSchemaHistory: AirbyteSchemaHistoryStorage.SchemaHistory>? } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/CdcStateHandler.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/CdcStateHandler.kt new file mode 100644 index 000000000000..976c97952a1d --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/CdcStateHandler.kt @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+ */ +package io.airbyte.cdk.integrations.debezium + +import io.airbyte.cdk.integrations.debezium.internals.AirbyteSchemaHistoryStorage +import io.airbyte.protocol.models.v0.AirbyteMessage + +/** + * This interface is used to allow connectors to save the offset and schema history in the manner + * which suits them. Also, it adds some utils to verify CDC event status. + */ +interface CdcStateHandler { + fun saveState( + offset: Map, + dbHistory: AirbyteSchemaHistoryStorage.SchemaHistory? + ): AirbyteMessage? + + fun saveStateAfterCompletionOfSnapshotOfNewStreams(): AirbyteMessage? + + fun compressSchemaHistoryForState(): Boolean { + return false + } + + val isCdcCheckpointEnabled: Boolean + /** + * This function is used as feature flag for sending state messages as checkpoints in CDC + * syncs. + * + * @return Returns `true` if checkpoint state messages are enabled for CDC syncs. Otherwise, + * it returns `false` + */ + get() = false +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/CdcTargetPosition.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/CdcTargetPosition.kt new file mode 100644 index 000000000000..0930a2ef3790 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/CdcTargetPosition.kt @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.debezium + +import io.airbyte.cdk.integrations.debezium.internals.ChangeEventWithMetadata + +/** + * This interface is used to define the target position at the beginning of the sync so that once we + * reach the desired target, we can shutdown the sync. This is needed because it might happen that + * while we are syncing the data, new changes are being made in the source database and as a result + * we might end up syncing forever. In order to tackle that, we need to define a point to end at the + * beginning of the sync + */ +interface CdcTargetPosition { + /** + * Reads a position value (ex: LSN) from a change event and compares it to target position + * + * @param changeEventWithMetadata change event from Debezium with extra calculated metadata + * @return true if event position is equal or greater than target position, or if last snapshot + * event + */ + fun reachedTargetPosition(changeEventWithMetadata: ChangeEventWithMetadata?): Boolean + + /** + * Reads a position value (lsn) from a change event and compares it to target lsn + * + * @param positionFromHeartbeat is the position extracted out of a heartbeat event (if the + * connector supports heartbeat) + * @return true if heartbeat position is equal or greater than target position + */ + fun reachedTargetPosition(positionFromHeartbeat: T): Boolean { + throw UnsupportedOperationException() + } + + val isHeartbeatSupported: Boolean + /** + * Indicates whether the implementation supports heartbeat position. + * + * @return true if heartbeats are supported + */ + get() = false + + /** + * Returns a position value from a heartbeat event offset. + * + * @param sourceOffset source offset params from heartbeat change event + * @return the heartbeat position in a heartbeat change event or null + */ + fun extractPositionFromHeartbeatOffset(sourceOffset: Map?): T + + /** + * This function checks if the event we are processing in the loop is already behind the offset + * so the process can safety save the state. 
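 The essence of the comparison is sketched standalone below for a simple numeric position; real implementations compare engine-specific structures such as LSNs or binlog coordinates, and the helper name is invented. // Capture the target once at the start of the sync; stop as soon as an event's // position catches up to it, so a busy source cannot keep the sync alive forever. fun reachedTarget(targetLsn: Long, eventLsn: Long?): Boolean = eventLsn != null && eventLsn >= targetLsn // reachedTarget(targetLsn = 5_000L, eventLsn = 5_123L) -> true: shut down the engine 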
+ * + * @param offset DB CDC offset + * @param event Event from the CDC load + * @return Returns `true` when the event is ahead of the offset. Otherwise, it returns `false` + */ + fun isEventAheadOffset(offset: Map?, event: ChangeEventWithMetadata?): Boolean { + return false + } + + /** + * This function compares two offsets to make sure both are not pointing to the same position. + * The main purpose is to avoid sending same offset multiple times. + * + * @param offsetA Offset to compare + * @param offsetB Offset to compare + * @return Returns `true` if both offsets are at the same position. Otherwise, it returns + * `false` + */ + fun isSameOffset(offsetA: Map<*, *>, offsetB: Map): Boolean { + return false + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/DebeziumIteratorConstants.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/DebeziumIteratorConstants.kt new file mode 100644 index 000000000000..f143e084e32a --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/DebeziumIteratorConstants.kt @@ -0,0 +1,17 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.debezium + +import java.time.Duration + +object DebeziumIteratorConstants { + const val SYNC_CHECKPOINT_DURATION_PROPERTY: String = "sync_checkpoint_seconds" + const val SYNC_CHECKPOINT_RECORDS_PROPERTY: String = "sync_checkpoint_records" + + // TODO: Move these variables to a separate class IteratorConstants, as they will be used in + // state + // iterators for non debezium cases too. + val SYNC_CHECKPOINT_DURATION: Duration = Duration.ofMinutes(15) + const val SYNC_CHECKPOINT_RECORDS: Int = 10000 +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/AirbyteFileOffsetBackingStore.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/AirbyteFileOffsetBackingStore.kt new file mode 100644 index 000000000000..ce9fa8cd0035 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/AirbyteFileOffsetBackingStore.kt @@ -0,0 +1,225 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.debezium.internals + +import com.fasterxml.jackson.databind.JsonNode +import com.google.common.base.Preconditions +import io.airbyte.commons.json.Jsons +import java.io.EOFException +import java.io.IOException +import java.io.ObjectOutputStream +import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets +import java.nio.file.Files +import java.nio.file.NoSuchFileException +import java.nio.file.Path +import java.util.* +import java.util.function.BiFunction +import java.util.function.Function +import java.util.stream.Collectors +import org.apache.commons.io.FileUtils +import org.apache.kafka.connect.errors.ConnectException +import org.apache.kafka.connect.util.SafeObjectInputStream +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +/** + * This class handles reading and writing a debezium offset file. In many cases it is duplicating + * logic in debezium because that logic is not exposed in the public API. We mostly treat the + * contents of this state file like a black box. We know it is a Map<ByteBuffer, Bytebuffer>. 
+ * We deserialize it to a Map<String, String> so that the state file can be human readable. If + * we ever discover that any of the contents of these offset files is not string serializable we + * will likely have to drop the human readability support and just base64 encode it. + */ +class AirbyteFileOffsetBackingStore( + private val offsetFilePath: Path, + private val dbName: Optional +) { + fun read(): Map { + val raw = load() + + return raw.entries + .stream() + .collect( + Collectors.toMap( + Function { e: Map.Entry -> + byteBufferToString(e.key) + }, + Function { e: Map.Entry -> + byteBufferToString(e.value) + } + ) + ) + } + + fun persist(cdcState: JsonNode?) { + val mapAsString: Map = + if (cdcState != null) + Jsons.`object`(cdcState, MutableMap::class.java) as Map + else emptyMap() + + val updatedMap = updateStateForDebezium2_1(mapAsString) + + val mappedAsStrings = + updatedMap.entries + .stream() + .collect( + Collectors.toMap( + Function { e: Map.Entry -> stringToByteBuffer(e.key) }, + Function { e: Map.Entry -> stringToByteBuffer(e.value) } + ) + ) + + FileUtils.deleteQuietly(offsetFilePath.toFile()) + save(mappedAsStrings) + } + + private fun updateStateForDebezium2_1(mapAsString: Map): Map { + val updatedMap: MutableMap = LinkedHashMap() + if (mapAsString.size > 0) { + val key = mapAsString.keys.stream().toList()[0] + val i = key.indexOf('[') + val i1 = key.lastIndexOf(']') + + if (i == 0 && i1 == key.length - 1) { + // The state is Debezium 2.1 compatible. No need to change anything. + return mapAsString + } + + LOGGER.info("Mutating sate to make it Debezium 2.1 compatible") + val newKey = + if (dbName.isPresent) + SQL_SERVER_STATE_MUTATION.apply(key.substring(i, i1 + 1), dbName.get()) + else key.substring(i, i1 + 1) + val value = mapAsString[key] + updatedMap[newKey] = value + } + return updatedMap + } + + /** + * See FileOffsetBackingStore#load - logic is mostly borrowed from here. duplicated because this + * method is not public. Reduced the try catch block to only the read operation from original + * code to reduce errors when reading the file. + */ + private fun load(): Map { + var obj: Any + try { + SafeObjectInputStream(Files.newInputStream(offsetFilePath)).use { `is` -> + // todo (cgardens) - we currently suppress a security warning for this line. use of + // readObject from + // untrusted sources is considered unsafe. Since the source is controlled by us in + // this case it + // should be safe. That said, changing this implementation to not use readObject + // would remove some + // headache. + obj = `is`.readObject() + } + } catch (e: NoSuchFileException) { + // NoSuchFileException: Ignore, may be new. + // EOFException: Ignore, this means the file was missing or corrupt + return emptyMap() + } catch (e: EOFException) { + return emptyMap() + } catch (e: IOException) { + throw ConnectException(e) + } catch (e: ClassNotFoundException) { + throw ConnectException(e) + } + + if (obj !is HashMap<*, *>) + throw ConnectException("Expected HashMap but found " + obj.javaClass) + val raw = obj as Map + val data: MutableMap = HashMap() + for ((key1, value1) in raw) { + val key = if ((key1 != null)) ByteBuffer.wrap(key1) else null + val value = if ((value1 != null)) ByteBuffer.wrap(value1) else null + data[key] = value + } + + return data + } + + /** + * See FileOffsetBackingStore#save - logic is mostly borrowed from here. duplicated because this + * method is not public. 
+ */ + private fun save(data: Map) { + try { + ObjectOutputStream(Files.newOutputStream(offsetFilePath)).use { os -> + val raw: MutableMap = HashMap() + for ((key1, value1) in data) { + val key = if ((key1 != null)) key1.array() else null + val value = if ((value1 != null)) value1.array() else null + raw[key] = value + } + os.writeObject(raw) + } + } catch (e: IOException) { + throw ConnectException(e) + } + } + + fun setDebeziumProperties(props: Properties) { + // debezium engine configuration + // https://debezium.io/documentation/reference/2.2/development/engine.html#engine-properties + props.setProperty( + "offset.storage", + "org.apache.kafka.connect.storage.FileOffsetBackingStore" + ) + props.setProperty("offset.storage.file.filename", offsetFilePath.toString()) + props.setProperty("offset.flush.interval.ms", "1000") // todo: make this longer + } + + companion object { + private val LOGGER: Logger = + LoggerFactory.getLogger(AirbyteFileOffsetBackingStore::class.java) + private val SQL_SERVER_STATE_MUTATION = BiFunction { key: String, databaseName: String -> + (key.substring(0, key.length - 2) + + ",\"database\":\"" + + databaseName + + "\"" + + key.substring(key.length - 2)) + } + + private fun byteBufferToString(byteBuffer: ByteBuffer?): String { + Preconditions.checkNotNull(byteBuffer) + return String(byteBuffer!!.array(), StandardCharsets.UTF_8) + } + + private fun stringToByteBuffer(s: String?): ByteBuffer { + Preconditions.checkNotNull(s) + return ByteBuffer.wrap(s!!.toByteArray(StandardCharsets.UTF_8)) + } + + fun initializeState( + cdcState: JsonNode?, + dbName: Optional + ): AirbyteFileOffsetBackingStore { + val cdcWorkingDir: Path + try { + cdcWorkingDir = Files.createTempDirectory(Path.of("/tmp"), "cdc-state-offset") + } catch (e: IOException) { + throw RuntimeException(e) + } + val cdcOffsetFilePath = cdcWorkingDir.resolve("offset.dat") + + val offsetManager = AirbyteFileOffsetBackingStore(cdcOffsetFilePath, dbName) + offsetManager.persist(cdcState) + return offsetManager + } + + fun initializeDummyStateForSnapshotPurpose(): AirbyteFileOffsetBackingStore { + val cdcWorkingDir: Path + try { + cdcWorkingDir = Files.createTempDirectory(Path.of("/tmp"), "cdc-dummy-state-offset") + } catch (e: IOException) { + throw RuntimeException(e) + } + val cdcOffsetFilePath = cdcWorkingDir.resolve("offset.dat") + + return AirbyteFileOffsetBackingStore(cdcOffsetFilePath, Optional.empty()) + } + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/AirbyteSchemaHistoryStorage.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/AirbyteSchemaHistoryStorage.kt new file mode 100644 index 000000000000..d6829b1d3d6f --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/AirbyteSchemaHistoryStorage.kt @@ -0,0 +1,261 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+ */
+package io.airbyte.cdk.integrations.debezium.internals
+
+import com.fasterxml.jackson.databind.JsonNode
+import com.google.common.annotations.VisibleForTesting
+import io.airbyte.commons.json.Jsons
+import io.debezium.document.DocumentReader
+import io.debezium.document.DocumentWriter
+import java.io.*
+import java.nio.charset.Charset
+import java.nio.charset.StandardCharsets
+import java.nio.file.FileAlreadyExistsException
+import java.nio.file.Files
+import java.nio.file.Path
+import java.nio.file.StandardOpenOption
+import java.util.*
+import java.util.zip.GZIPInputStream
+import java.util.zip.GZIPOutputStream
+import org.apache.commons.io.FileUtils
+import org.slf4j.Logger
+import org.slf4j.LoggerFactory
+
+/**
+ * The purpose of this class is to: 1. Read the contents of the file [path], which contains the
+ * schema history, at the end of the sync so that it can be saved in state for future syncs; see
+ * [read]. 2. Write the saved content back to the file [path] at the beginning of the sync so
+ * that Debezium can function smoothly; see persist(Optional<JsonNode>).
+ */
+class AirbyteSchemaHistoryStorage(
+    private val path: Path,
+    private val compressSchemaHistoryForState: Boolean
+) {
+    private val reader: DocumentReader = DocumentReader.defaultReader()
+    private val writer: DocumentWriter = DocumentWriter.defaultWriter()
+
+    class SchemaHistory<T>(val schema: T, val isCompressed: Boolean)
+
+    fun read(): SchemaHistory<String> {
+        val fileSizeMB = path.toFile().length().toDouble() / (ONE_MB)
+        if ((fileSizeMB > SIZE_LIMIT_TO_COMPRESS_MB) && compressSchemaHistoryForState) {
+            LOGGER.info(
+                "File size {} MB is greater than the size limit of {} MB, compressing the content of the file.",
+                fileSizeMB,
+                SIZE_LIMIT_TO_COMPRESS_MB
+            )
+            val schemaHistory = readCompressed()
+            val compressedSizeMB = calculateSizeOfStringInMB(schemaHistory)
+            if (fileSizeMB > compressedSizeMB) {
+                LOGGER.info("Content size post compression is {} MB ", compressedSizeMB)
+            } else {
+                throw RuntimeException(
+                    "Compressing increased the size of the content. 
Size before compression " + + fileSizeMB + + ", after compression " + + compressedSizeMB + ) + } + return SchemaHistory(schemaHistory, true) + } + if (compressSchemaHistoryForState) { + LOGGER.info( + "File Size {} MB is less than the size limit of {} MB, reading the content of the file without compression.", + fileSizeMB, + SIZE_LIMIT_TO_COMPRESS_MB + ) + } else { + LOGGER.info("File Size {} MB.", fileSizeMB) + } + val schemaHistory = readUncompressed() + return SchemaHistory(schemaHistory, false) + } + + @VisibleForTesting + fun readUncompressed(): String { + val fileAsString = StringBuilder() + try { + for (line in Files.readAllLines(path, UTF8)) { + if (line != null && !line.isEmpty()) { + val record = reader.read(line) + val recordAsString = writer.write(record) + fileAsString.append(recordAsString) + fileAsString.append(System.lineSeparator()) + } + } + return fileAsString.toString() + } catch (e: IOException) { + throw RuntimeException(e) + } + } + + private fun readCompressed(): String { + val lineSeparator = System.lineSeparator() + val compressedStream = ByteArrayOutputStream() + try { + GZIPOutputStream(compressedStream).use { gzipOutputStream -> + Files.newBufferedReader(path, UTF8).use { bufferedReader -> + while (true) { + val line = bufferedReader.readLine() ?: break + + if (!line.isEmpty()) { + val record = reader.read(line) + val recordAsString = writer.write(record) + gzipOutputStream.write( + recordAsString.toByteArray(StandardCharsets.UTF_8) + ) + gzipOutputStream.write( + lineSeparator.toByteArray(StandardCharsets.UTF_8) + ) + } + } + } + } + } catch (e: IOException) { + throw RuntimeException(e) + } + return Jsons.serialize(compressedStream.toByteArray()) + } + + private fun makeSureFileExists() { + try { + // Make sure the file exists ... + if (!Files.exists(path)) { + // Create parent directories if we have them ... + if (path.parent != null) { + Files.createDirectories(path.parent) + } + try { + Files.createFile(path) + } catch (e: FileAlreadyExistsException) { + // do nothing + } + } + } catch (e: IOException) { + throw IllegalStateException( + "Unable to check or create history file at " + path + ": " + e.message, + e + ) + } + } + + private fun persist(schemaHistory: SchemaHistory>?) 
{ + if (schemaHistory!!.schema!!.isEmpty) { + return + } + val fileAsString = Jsons.`object`(schemaHistory.schema!!.get(), String::class.java) + + if (fileAsString == null || fileAsString.isEmpty()) { + return + } + + FileUtils.deleteQuietly(path.toFile()) + makeSureFileExists() + if (schemaHistory.isCompressed) { + writeCompressedStringToFile(fileAsString) + } else { + writeToFile(fileAsString) + } + } + + /** + * @param fileAsString Represents the contents of the file saved in state from previous syncs + */ + private fun writeToFile(fileAsString: String) { + try { + val split = + fileAsString + .split(System.lineSeparator().toRegex()) + .dropLastWhile { it.isEmpty() } + .toTypedArray() + for (element in split) { + val read = reader.read(element) + val line = writer.write(read) + + Files.newBufferedWriter(path, StandardOpenOption.APPEND).use { historyWriter -> + try { + historyWriter.append(line) + historyWriter.newLine() + } catch (e: IOException) { + throw RuntimeException(e) + } + } + } + } catch (e: IOException) { + throw RuntimeException(e) + } + } + + private fun writeCompressedStringToFile(compressedString: String) { + try { + ByteArrayInputStream(Jsons.deserialize(compressedString, ByteArray::class.java)).use { + inputStream -> + GZIPInputStream(inputStream).use { gzipInputStream -> + FileOutputStream(path.toFile()).use { fileOutputStream -> + val buffer = ByteArray(1024) + var bytesRead: Int + while ((gzipInputStream.read(buffer).also { bytesRead = it }) != -1) { + fileOutputStream.write(buffer, 0, bytesRead) + } + } + } + } + } catch (e: IOException) { + throw RuntimeException(e) + } + } + + fun setDebeziumProperties(props: Properties) { + // https://debezium.io/documentation/reference/2.2/operations/debezium-server.html#debezium-source-database-history-class + // https://debezium.io/documentation/reference/development/engine.html#_in_the_code + // As mentioned in the documents above, debezium connector for MySQL needs to track the + // schema + // changes. If we don't do this, we can't fetch records for the table. 
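+        // Illustrative wiring (assumed, not prescribed by this change): a connector would
+        // typically do
+        //     val history = AirbyteSchemaHistoryStorage.initializeDBHistory(savedState, true)
+        //     val props = Properties()
+        //     history.setDebeziumProperties(props)
+        // before handing `props` to the Debezium engine, so that the properties set below point
+        // the engine at the restored history file.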
+ props.setProperty( + "schema.history.internal", + "io.debezium.storage.file.history.FileSchemaHistory" + ) + props.setProperty("schema.history.internal.file.filename", path.toString()) + props.setProperty("schema.history.internal.store.only.captured.databases.ddl", "true") + } + + companion object { + private val LOGGER: Logger = + LoggerFactory.getLogger(AirbyteSchemaHistoryStorage::class.java) + private const val SIZE_LIMIT_TO_COMPRESS_MB: Long = 1 + const val ONE_MB: Int = 1024 * 1024 + private val UTF8: Charset = StandardCharsets.UTF_8 + + @VisibleForTesting + fun calculateSizeOfStringInMB(string: String): Double { + return string.toByteArray(StandardCharsets.UTF_8).size.toDouble() / (ONE_MB) + } + + fun initializeDBHistory( + schemaHistory: SchemaHistory>?, + compressSchemaHistoryForState: Boolean + ): AirbyteSchemaHistoryStorage { + val dbHistoryWorkingDir: Path + try { + dbHistoryWorkingDir = Files.createTempDirectory(Path.of("/tmp"), "cdc-db-history") + } catch (e: IOException) { + throw RuntimeException(e) + } + val dbHistoryFilePath = dbHistoryWorkingDir.resolve("dbhistory.dat") + + val schemaHistoryManager = + AirbyteSchemaHistoryStorage(dbHistoryFilePath, compressSchemaHistoryForState) + schemaHistoryManager.persist(schemaHistory) + return schemaHistoryManager + } + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/ChangeEventWithMetadata.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/ChangeEventWithMetadata.kt new file mode 100644 index 000000000000..8e0a8985e2ff --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/ChangeEventWithMetadata.kt @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.debezium.internals + +import com.fasterxml.jackson.databind.JsonNode +import io.airbyte.commons.json.Jsons +import io.debezium.engine.ChangeEvent + +class ChangeEventWithMetadata(private val event: ChangeEvent) { + private val eventKeyAsJson: JsonNode = Jsons.deserialize(event.key()) + private val eventValueAsJson: JsonNode = Jsons.deserialize(event.value()) + private val snapshotMetadata: SnapshotMetadata? = + SnapshotMetadata.Companion.fromString(eventValueAsJson["source"]["snapshot"].asText()) + + fun event(): ChangeEvent { + return event + } + + fun eventKeyAsJson(): JsonNode { + return eventKeyAsJson + } + + fun eventValueAsJson(): JsonNode { + return eventValueAsJson + } + + val isSnapshotEvent: Boolean + get() = SnapshotMetadata.Companion.isSnapshotEventMetadata(snapshotMetadata) + + fun snapshotMetadata(): SnapshotMetadata? { + return snapshotMetadata + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumConverterUtils.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumConverterUtils.kt new file mode 100644 index 000000000000..6c87fb88a35a --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumConverterUtils.kt @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+ */
+package io.airbyte.cdk.integrations.debezium.internals
+
+import io.airbyte.cdk.db.DataTypeUtils.toISO8601String
+import io.airbyte.cdk.db.DataTypeUtils.toISO8601StringWithMicroseconds
+import io.debezium.spi.converter.RelationalColumn
+import java.sql.Date
+import java.sql.Timestamp
+import java.time.Duration
+import java.time.LocalDate
+import java.time.LocalDateTime
+import java.time.format.DateTimeParseException
+import org.slf4j.Logger
+import org.slf4j.LoggerFactory
+
+class DebeziumConverterUtils private constructor() {
+    init {
+        throw UnsupportedOperationException()
+    }
+
+    companion object {
+        private val LOGGER: Logger = LoggerFactory.getLogger(DebeziumConverterUtils::class.java)
+
+        /** TODO: Replace usage of this method with [io.airbyte.cdk.db.jdbc.DateTimeConverter] */
+        fun convertDate(input: Any): String {
+            /**
+             * While building this custom converter we were not sure which types Debezium could
+             * return, since the documentation does not mention it. If you look at
+             * [io.debezium.connector.mysql.converters.TinyIntOneToBooleanConverter.converterFor],
+             * it handles multiple data types, but it is not clear under which circumstances each
+             * data type is returned, so we handle the data types that made sense. We use
+             * LocalDateTime because it represents the DATETIME data type in Java.
+             */
+            if (input is LocalDateTime) {
+                return toISO8601String(input)
+            } else if (input is LocalDate) {
+                return toISO8601String(input)
+            } else if (input is Duration) {
+                return toISO8601String(input)
+            } else if (input is Timestamp) {
+                return toISO8601StringWithMicroseconds((input.toInstant()))
+            } else if (input is Number) {
+                return toISO8601String(Timestamp(input.toLong()).toLocalDateTime())
+            } else if (input is Date) {
+                return toISO8601String(input)
+            } else if (input is String) {
+                try {
+                    return LocalDateTime.parse(input).toString()
+                } catch (e: DateTimeParseException) {
+                    LOGGER.warn("Cannot convert value '{}' to LocalDateTime type", input)
+                    return input.toString()
+                }
+            }
+            LOGGER.warn(
+                "Unhandled date class type '{}'. Using the default converter.",
+                input.javaClass.name
+            )
+            return input.toString()
+        }
+
+        fun convertDefaultValue(field: RelationalColumn): Any? {
+            if (field.isOptional) {
+                return null
+            } else if (field.hasDefaultValue()) {
+                return field.defaultValue()
+            }
+            return null
+        }
+    }
+}
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumEventConverter.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumEventConverter.kt
new file mode 100644
index 000000000000..806371b69f74
--- /dev/null
+++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumEventConverter.kt
@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+ */ +package io.airbyte.cdk.integrations.debezium.internals + +import com.fasterxml.jackson.databind.JsonNode +import com.fasterxml.jackson.databind.node.ObjectNode +import io.airbyte.cdk.integrations.debezium.CdcMetadataInjector +import io.airbyte.protocol.models.v0.AirbyteMessage +import io.airbyte.protocol.models.v0.AirbyteRecordMessage +import java.time.Instant + +interface DebeziumEventConverter { + fun toAirbyteMessage(event: ChangeEventWithMetadata): AirbyteMessage + + companion object { + fun buildAirbyteMessage( + source: JsonNode?, + cdcMetadataInjector: CdcMetadataInjector<*>, + emittedAt: Instant, + data: JsonNode? + ): AirbyteMessage { + val streamNamespace = cdcMetadataInjector.namespace(source) + val streamName = cdcMetadataInjector.name(source) + + val airbyteRecordMessage = + AirbyteRecordMessage() + .withStream(streamName) + .withNamespace(streamNamespace) + .withEmittedAt(emittedAt.toEpochMilli()) + .withData(data) + + return AirbyteMessage() + .withType(AirbyteMessage.Type.RECORD) + .withRecord(airbyteRecordMessage) + } + + fun addCdcMetadata( + baseNode: ObjectNode, + source: JsonNode, + cdcMetadataInjector: CdcMetadataInjector<*>, + isDelete: Boolean + ): JsonNode { + val transactionMillis = source["ts_ms"].asLong() + val transactionTimestamp = Instant.ofEpochMilli(transactionMillis).toString() + + baseNode.put(CDC_UPDATED_AT, transactionTimestamp) + cdcMetadataInjector.addMetaData(baseNode, source) + + if (isDelete) { + baseNode.put(CDC_DELETED_AT, transactionTimestamp) + } else { + baseNode.put(CDC_DELETED_AT, null as String?) + } + + return baseNode + } + + const val CDC_LSN: String = "_ab_cdc_lsn" + const val CDC_UPDATED_AT: String = "_ab_cdc_updated_at" + const val CDC_DELETED_AT: String = "_ab_cdc_deleted_at" + const val AFTER_EVENT: String = "after" + const val BEFORE_EVENT: String = "before" + const val OPERATION_FIELD: String = "op" + const val SOURCE_EVENT: String = "source" + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumMessageProducer.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumMessageProducer.kt new file mode 100644 index 000000000000..8c4bc32ae2fc --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumMessageProducer.kt @@ -0,0 +1,148 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.debezium.internals + +import io.airbyte.cdk.integrations.debezium.CdcStateHandler +import io.airbyte.cdk.integrations.debezium.CdcTargetPosition +import io.airbyte.cdk.integrations.source.relationaldb.state.SourceStateMessageProducer +import io.airbyte.protocol.models.v0.AirbyteMessage +import io.airbyte.protocol.models.v0.AirbyteStateMessage +import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream +import java.util.* +import org.apache.kafka.connect.errors.ConnectException +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +class DebeziumMessageProducer( + private val cdcStateHandler: CdcStateHandler, + targetPosition: CdcTargetPosition, + eventConverter: DebeziumEventConverter, + offsetManager: AirbyteFileOffsetBackingStore?, + schemaHistoryManager: Optional +) : SourceStateMessageProducer { + /** + * `checkpointOffsetToSend` is used as temporal storage for the offset that we want to send as + * message. 
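(The checkpoint threshold itself, by default every
+ * [DebeziumIteratorConstants.SYNC_CHECKPOINT_RECORDS] records or
+ * [DebeziumIteratorConstants.SYNC_CHECKPOINT_DURATION], is applied by the surrounding state
+ * iterator rather than by this producer.)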
As Debezium is reading records faster than we process them, if we try to send the
+ * `offsetManager.read()` offset, it is possible that the state is behind the record we are
+ * currently propagating. To avoid that, we store the offset as soon as we reach the checkpoint
+ * threshold (time or records) and wait to send it until we are sure that the record we are
+ * processing is behind the offset to be sent.
+ */
+    private val checkpointOffsetToSend = HashMap<String, String>()
+
+    /**
+     * `previousCheckpointOffset` is used to make sure we don't send duplicated states with the
+     * same offset. It is possible that the offset Debezium reports doesn't move for a period of
+     * time, and if we just relied on `offsetManager.read()`, there would be a chance of sending
+     * duplicate states, wasting network and processing resources.
+     */
+    private val initialOffset: HashMap<String, String>
+    private val previousCheckpointOffset: HashMap<String, String>
+    private val offsetManager: AirbyteFileOffsetBackingStore?
+    private val targetPosition: CdcTargetPosition<T>
+    private val schemaHistoryManager: Optional<AirbyteSchemaHistoryStorage>
+
+    private var shouldEmitStateMessage = false
+
+    private val eventConverter: DebeziumEventConverter
+
+    init {
+        this.targetPosition = targetPosition
+        this.eventConverter = eventConverter
+        this.offsetManager = offsetManager
+        if (offsetManager == null) {
+            throw RuntimeException("Offset manager cannot be null")
+        }
+        this.schemaHistoryManager = schemaHistoryManager
+        this.previousCheckpointOffset = offsetManager.read() as HashMap<String, String>
+        this.initialOffset = HashMap(this.previousCheckpointOffset)
+    }
+
+    override fun generateStateMessageAtCheckpoint(
+        stream: ConfiguredAirbyteStream?
+    ): AirbyteStateMessage {
+        LOGGER.info("Sending CDC checkpoint state message.")
+        val stateMessage = createStateMessage(checkpointOffsetToSend)
+        previousCheckpointOffset.clear()
+        previousCheckpointOffset.putAll(checkpointOffsetToSend)
+        checkpointOffsetToSend.clear()
+        shouldEmitStateMessage = false
+        return stateMessage
+    }
+
+    /**
+     * Converts a change event to an [AirbyteMessage], first capturing a checkpoint offset if the
+     * checkpoint threshold has been reached.
+     */
+    override fun processRecordMessage(
+        stream: ConfiguredAirbyteStream?,
+        message: ChangeEventWithMetadata
+    ): AirbyteMessage {
+        if (checkpointOffsetToSend.isEmpty()) {
+            try {
+                val temporalOffset = offsetManager!!.read()
+                if (!targetPosition.isSameOffset(previousCheckpointOffset, temporalOffset)) {
+                    checkpointOffsetToSend.putAll(temporalOffset)
+                }
+            } catch (e: ConnectException) {
+                LOGGER.warn(
+                    "Offset file is being written by Debezium. Skipping CDC checkpoint in this loop."
+                )
+            }
+        }
+
+        if (checkpointOffsetToSend.size == 1 && !message.isSnapshotEvent) {
+            if (targetPosition.isEventAheadOffset(checkpointOffsetToSend, message)) {
+                shouldEmitStateMessage = true
+            } else {
+                LOGGER.info("Encountered records with the same event offset.")
+            }
+        }
+
+        return eventConverter.toAirbyteMessage(message)
+    }
+
+    override fun createFinalStateMessage(stream: ConfiguredAirbyteStream?): AirbyteStateMessage {
+        val syncFinishedOffset = offsetManager!!.read()
+        if (targetPosition.isSameOffset(initialOffset, syncFinishedOffset)) {
+            // Edge case where no progress has been made: wrap up the sync by returning the
+            // initial offset instead of the current offset. We do this because we found that,
+            // for some databases, heartbeats will cause Debezium to overwrite the offset file
+            // with a state which doesn't include all necessary data such as snapshot completion.
+            // This is the case for MS SQL Server, at least.
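+            // Returning the initial offset below therefore preserves markers (such as snapshot
+            // completion) that were captured at start-up.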
+ return createStateMessage(initialOffset) + } + return createStateMessage(syncFinishedOffset) + } + + override fun shouldEmitStateMessage(stream: ConfiguredAirbyteStream?): Boolean { + return shouldEmitStateMessage + } + + /** + * Creates [AirbyteStateMessage] while updating CDC data, used to checkpoint the state of the + * process. + * + * @return [AirbyteStateMessage] which includes offset and schema history if used. + */ + private fun createStateMessage(offset: Map): AirbyteStateMessage { + val message = + cdcStateHandler + .saveState( + offset, + schemaHistoryManager + .map { obj: AirbyteSchemaHistoryStorage -> obj.read() } + .orElse(null) + )!! + .state + return message + } + + companion object { + private val LOGGER: Logger = LoggerFactory.getLogger(DebeziumMessageProducer::class.java) + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumPropertiesManager.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumPropertiesManager.kt new file mode 100644 index 000000000000..d4787d615bc0 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumPropertiesManager.kt @@ -0,0 +1,130 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.debezium.internals + +import com.fasterxml.jackson.databind.JsonNode +import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog +import io.debezium.spi.common.ReplacementFunction +import java.util.* + +abstract class DebeziumPropertiesManager( + private val properties: Properties, + private val config: JsonNode, + private val catalog: ConfiguredAirbyteCatalog +) { + fun getDebeziumProperties(offsetManager: AirbyteFileOffsetBackingStore): Properties { + return getDebeziumProperties(offsetManager, Optional.empty()) + } + + fun getDebeziumProperties( + offsetManager: AirbyteFileOffsetBackingStore, + schemaHistoryManager: Optional + ): Properties { + val props = Properties() + props.putAll(properties) + + // debezium engine configuration + offsetManager.setDebeziumProperties(props) + // default values from debezium CommonConnectorConfig + props.setProperty("max.batch.size", "2048") + props.setProperty("max.queue.size", "8192") + + props.setProperty("errors.max.retries", "5") + // This property must be strictly less than errors.retry.delay.max.ms + // (https://github.com/debezium/debezium/blob/bcc7d49519a4f07d123c616cfa45cd6268def0b9/debezium-core/src/main/java/io/debezium/util/DelayStrategy.java#L135) + props.setProperty("errors.retry.delay.initial.ms", "299") + props.setProperty("errors.retry.delay.max.ms", "300") + + schemaHistoryManager.ifPresent { m: AirbyteSchemaHistoryStorage -> + m.setDebeziumProperties(props) + } + + // https://debezium.io/documentation/reference/2.2/configuration/avro.html + props.setProperty("key.converter.schemas.enable", "false") + props.setProperty("value.converter.schemas.enable", "false") + + // debezium names + props.setProperty(NAME_KEY, getName(config)) + + // connection configuration + props.putAll(getConnectionConfiguration(config)) + + // By default "decimal.handing.mode=precise" which's caused returning this value as a + // binary. 
+ // The "double" type may cause a loss of precision, so set Debezium's config to store it as + // a String + // explicitly in its Kafka messages for more details see: + // https://debezium.io/documentation/reference/2.2/connectors/postgresql.html#postgresql-decimal-types + // https://debezium.io/documentation/faq/#how_to_retrieve_decimal_field_from_binary_representation + props.setProperty("decimal.handling.mode", "string") + + // https://debezium.io/documentation/reference/2.2/connectors/postgresql.html#postgresql-property-max-queue-size-in-bytes + props.setProperty("max.queue.size.in.bytes", BYTE_VALUE_256_MB) + + // WARNING : Never change the value of this otherwise all the connectors would start syncing + // from + // scratch. + props.setProperty(TOPIC_PREFIX_KEY, sanitizeTopicPrefix(getName(config))) + // https://issues.redhat.com/browse/DBZ-7635 + // https://cwiki.apache.org/confluence/display/KAFKA/KIP-581%3A+Value+of+optional+null+field+which+has+default+value + // A null value in a column with default value won't be generated correctly in CDC unless we + // set the + // following + props.setProperty("value.converter.replace.null.with.default", "false") + // includes + props.putAll(getIncludeConfiguration(catalog, config)) + + return props + } + + protected abstract fun getConnectionConfiguration(config: JsonNode): Properties + + protected abstract fun getName(config: JsonNode): String + + protected abstract fun getIncludeConfiguration( + catalog: ConfiguredAirbyteCatalog, + config: JsonNode? + ): Properties + + companion object { + private const val BYTE_VALUE_256_MB = (256 * 1024 * 1024).toString() + + const val NAME_KEY: String = "name" + const val TOPIC_PREFIX_KEY: String = "topic.prefix" + + fun sanitizeTopicPrefix(topicName: String): String { + val sanitizedNameBuilder = StringBuilder(topicName.length) + var changed = false + + for (i in 0 until topicName.length) { + val c = topicName[i] + if (isValidCharacter(c)) { + sanitizedNameBuilder.append(c) + } else { + sanitizedNameBuilder.append( + ReplacementFunction.UNDERSCORE_REPLACEMENT.replace(c) + ) + changed = true + } + } + + return if (changed) { + sanitizedNameBuilder.toString() + } else { + topicName + } + } + + // We need to keep the validation rule the same as debezium engine, which is defined here: + // https://github.com/debezium/debezium/blob/c51ef3099a688efb41204702d3aa6d4722bb4825/debezium-core/src/main/java/io/debezium/schema/AbstractTopicNamingStrategy.java#L178 + private fun isValidCharacter(c: Char): Boolean { + return c == '.' || + c == '_' || + c == '-' || + c >= 'A' && c <= 'Z' || + c >= 'a' && c <= 'z' || + c >= '0' && c <= '9' + } + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumRecordIterator.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumRecordIterator.kt new file mode 100644 index 000000000000..a0b0253e4d68 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumRecordIterator.kt @@ -0,0 +1,255 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+ */
+package io.airbyte.cdk.integrations.debezium.internals
+
+import com.google.common.annotations.VisibleForTesting
+import com.google.common.collect.AbstractIterator
+import io.airbyte.cdk.integrations.debezium.CdcTargetPosition
+import io.airbyte.commons.lang.MoreBooleans
+import io.airbyte.commons.util.AutoCloseableIterator
+import io.debezium.engine.ChangeEvent
+import java.lang.reflect.Field
+import java.time.Duration
+import java.time.LocalDateTime
+import java.util.*
+import java.util.concurrent.*
+import java.util.function.Supplier
+import org.apache.kafka.connect.source.SourceRecord
+import org.slf4j.Logger
+import org.slf4j.LoggerFactory
+
+/**
+ * The record iterator is the consumer (in the producer/consumer relationship with Debezium)
+ * responsible for: 1. making sure every record produced by the record publisher is processed, and
+ * 2. signalling to the record publisher when it is time to stop producing records. It emits this
+ * signal either when the publisher has not produced a new record for a long time or when it has
+ * processed at least all of the records that were present in the database when the source was
+ * started. Because the publisher might publish more records between the consumer sending this
+ * signal and the publisher actually shutting down, the consumer must stay alive as long as the
+ * publisher is not closed. Even after the publisher is closed, the consumer will finish processing
+ * any produced records before closing.
+ */
+class DebeziumRecordIterator<T>(
+    private val queue: LinkedBlockingQueue<ChangeEvent<String?, String?>>,
+    private val targetPosition: CdcTargetPosition<T>,
+    private val publisherStatusSupplier: Supplier<Boolean>,
+    private val debeziumShutdownProcedure: DebeziumShutdownProcedure<ChangeEvent<String?, String?>>,
+    private val firstRecordWaitTime: Duration,
+    subsequentRecordWaitTime: Duration?
+) : AbstractIterator<ChangeEventWithMetadata>(), AutoCloseableIterator<ChangeEventWithMetadata> {
+    private val heartbeatEventSourceField: MutableMap<Class<out ChangeEvent<*, *>?>, Field?> =
+        HashMap(1)
+    // NOTE: the subsequentRecordWaitTime constructor parameter is currently unused; the wait is
+    // derived from firstRecordWaitTime instead.
+    private val subsequentRecordWaitTime: Duration = firstRecordWaitTime.dividedBy(2)
+
+    private var receivedFirstRecord = false
+    private var hasSnapshotFinished = true
+    private var tsLastHeartbeat: LocalDateTime? = null
+    private var lastHeartbeatPosition: T? = null
+    private var maxInstanceOfNoRecordsFound = 0
+    private var signalledDebeziumEngineShutdown = false
+
+    // The following logic incorporates heartbeats (CDC Postgres only for now):
+    // 1. Wait on the queue for the configured first-record time, or the shorter subsequent wait
+    //    once a record has been received
+    // 2. If nothing came out of the queue, finish the sync
+    // 3. If a heartbeat was received: finish the sync if heartbeat_lsn reached the target or
+    //    hasn't changed in a while
+    // 4. If a change event's lsn reached the target, finish the sync
+    // 5. Otherwise check the message queue again
+    override fun computeNext(): ChangeEventWithMetadata? {
+        // Keep trying until the publisher is closed or until the queue is empty. The latter case
+        // is possible when the publisher has shut down but the consumer has not yet processed all
+        // the messages it emitted.
+        while (!MoreBooleans.isTruthy(publisherStatusSupplier.get()) || !queue.isEmpty()) {
+            val next: ChangeEvent<String?, String?>?
+
+            val waitTime =
+                if (receivedFirstRecord) this.subsequentRecordWaitTime else this.firstRecordWaitTime
+            try {
+                next = queue.poll(waitTime.seconds, TimeUnit.SECONDS)
+            } catch (e: InterruptedException) {
+                throw RuntimeException(e)
+            }
+
+            // if the consumer could not get a record within the timeout, it is time to tell the
+            // producer to shutdown.
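+            // (Clarifying note: an empty poll triggers shutdown immediately if no record has
+            // arrived yet or the snapshot has finished; otherwise only after 10 empty polls.)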
+            if (next == null) {
+                if (
+                    !receivedFirstRecord || hasSnapshotFinished || maxInstanceOfNoRecordsFound >= 10
+                ) {
+                    requestClose(
+                        String.format(
+                            "No records were returned by Debezium in the timeout of %s seconds, closing the engine and iterator",
+                            waitTime.seconds
+                        )
+                    )
+                }
+                LOGGER.info("no record found. polling again.")
+                maxInstanceOfNoRecordsFound++
+                continue
+            }
+
+            if (isHeartbeatEvent(next)) {
+                if (!hasSnapshotFinished) {
+                    continue
+                }
+
+                val heartbeatPos = getHeartbeatPosition(next)
+                // wrap up the sync if the heartbeat position crossed the target OR the heartbeat
+                // position hasn't changed for too long
+                if (targetPosition.reachedTargetPosition(heartbeatPos)) {
+                    requestClose(
+                        "Closing: Heartbeat indicates sync is done by reaching the target position"
+                    )
+                } else if (
+                    heartbeatPos == this.lastHeartbeatPosition && heartbeatPosNotChanging()
+                ) {
+                    requestClose("Closing: Heartbeat indicates sync is not progressing")
+                }
+
+                if (heartbeatPos != lastHeartbeatPosition) {
+                    this.tsLastHeartbeat = LocalDateTime.now()
+                    this.lastHeartbeatPosition = heartbeatPos
+                }
+                continue
+            }
+
+            val changeEventWithMetadata = ChangeEventWithMetadata(next)
+            hasSnapshotFinished = !changeEventWithMetadata.isSnapshotEvent
+
+            // if the last record matches the target file position, it is time to tell the
+            // producer to shutdown.
+            if (targetPosition.reachedTargetPosition(changeEventWithMetadata)) {
+                requestClose("Closing: Change event reached target position")
+            }
+            this.tsLastHeartbeat = null
+            this.receivedFirstRecord = true
+            this.maxInstanceOfNoRecordsFound = 0
+            return changeEventWithMetadata
+        }
+
+        if (!signalledDebeziumEngineShutdown) {
+            LOGGER.warn("Debezium engine has not been signalled to shutdown, this is unexpected")
+        }
+
+        // Read the records that Debezium might have fetched right at the time we called shutdown
+        while (!debeziumShutdownProcedure.recordsRemainingAfterShutdown.isEmpty()) {
+            val event: ChangeEvent<String?, String?>?
+            try {
+                event =
+                    debeziumShutdownProcedure.recordsRemainingAfterShutdown.poll(
+                        100,
+                        TimeUnit.MILLISECONDS
+                    )
+            } catch (e: InterruptedException) {
+                throw RuntimeException(e)
+            }
+            if (event == null || isHeartbeatEvent(event)) {
+                continue
+            }
+            val changeEventWithMetadata = ChangeEventWithMetadata(event)
+            hasSnapshotFinished = !changeEventWithMetadata.isSnapshotEvent
+            return changeEventWithMetadata
+        }
+        throwExceptionIfSnapshotNotFinished()
+        return endOfData()
+    }
+
+    /**
+     * Debezium was built as an ever-running process which keeps listening for new changes on the
+     * DB and processing them immediately. Airbyte needs Debezium to work as a start/stop
+     * mechanism. To determine when to stop the Debezium engine we rely on a few factors: 1.
+     * TargetPosition logic. At the beginning of the sync we define a target position in the logs
+     * of the DB. This can be an LSN or anything DB-specific that identifies a position in
+     * log-based replication. When we start processing records from Debezium, we extract the log
+     * position from the record's metadata and compare it with the target we defined at the
+     * beginning of the sync; once we reach the target position, we shut down the Debezium
+     * engine. 2. The TargetPosition logic might not always work, so as a fallback, if we do not
+     * receive records from Debezium for a given duration, we ask the Debezium engine to shut
+     * down. 3.
We also take the + * Snapshot into consideration, when a connector is running for the first time, we let it + * complete the snapshot and only after the completion of snapshot we should shutdown the + * engine. If we are closing the engine before completion of snapshot, we throw an exception + */ + @Throws(Exception::class) + override fun close() { + requestClose("Closing: Iterator closing") + } + + private fun isHeartbeatEvent(event: ChangeEvent): Boolean { + return targetPosition.isHeartbeatSupported && + Objects.nonNull(event) && + !event.value()!!.contains("source") + } + + private fun heartbeatPosNotChanging(): Boolean { + if (this.tsLastHeartbeat == null) { + return false + } + val timeElapsedSinceLastHeartbeatTs = + Duration.between(this.tsLastHeartbeat, LocalDateTime.now()) + LOGGER.info( + "Time since last hb_pos change {}s", + timeElapsedSinceLastHeartbeatTs.toSeconds() + ) + // wait time for no change in heartbeat position is half of initial waitTime + return timeElapsedSinceLastHeartbeatTs.compareTo(firstRecordWaitTime.dividedBy(2)) > 0 + } + + private fun requestClose(closeLogMessage: String) { + if (signalledDebeziumEngineShutdown) { + return + } + LOGGER.info(closeLogMessage) + debeziumShutdownProcedure.initiateShutdownProcedure() + signalledDebeziumEngineShutdown = true + } + + private fun throwExceptionIfSnapshotNotFinished() { + if (!hasSnapshotFinished) { + throw RuntimeException("Closing down debezium engine but snapshot has not finished") + } + } + + /** + * [DebeziumRecordIterator.heartbeatEventSourceField] acts as a cache so that we avoid using + * reflection to setAccessible for each event + */ + @VisibleForTesting + internal fun getHeartbeatPosition(heartbeatEvent: ChangeEvent): T { + try { + val eventClass: Class?> = heartbeatEvent.javaClass + val f: Field? + if (heartbeatEventSourceField.containsKey(eventClass)) { + f = heartbeatEventSourceField[eventClass] + } else { + f = eventClass.getDeclaredField("sourceRecord") + f.isAccessible = true + heartbeatEventSourceField[eventClass] = f + + if (heartbeatEventSourceField.size > 1) { + LOGGER.warn( + "Field Cache size growing beyond expected size of 1, size is " + + heartbeatEventSourceField.size + ) + } + } + + val sr = f!![heartbeatEvent] as SourceRecord + return targetPosition.extractPositionFromHeartbeatOffset(sr.sourceOffset()) + } catch (e: NoSuchFieldException) { + LOGGER.info("failed to get heartbeat source offset") + throw RuntimeException(e) + } catch (e: IllegalAccessException) { + LOGGER.info("failed to get heartbeat source offset") + throw RuntimeException(e) + } + } + + companion object { + private val LOGGER: Logger = LoggerFactory.getLogger(DebeziumRecordIterator::class.java) + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumRecordPublisher.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumRecordPublisher.kt new file mode 100644 index 000000000000..4e0bfc1e14e8 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumRecordPublisher.kt @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+ */ +package io.airbyte.cdk.integrations.debezium.internals + +import io.debezium.engine.ChangeEvent +import io.debezium.engine.DebeziumEngine +import io.debezium.engine.format.Json +import io.debezium.engine.spi.OffsetCommitPolicy +import java.util.* +import java.util.concurrent.* +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicReference +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +/** + * The purpose of this class is to initialize and spawn the debezium engine with the right + * properties to fetch records + */ +class DebeziumRecordPublisher(private val debeziumPropertiesManager: DebeziumPropertiesManager) : + AutoCloseable { + private val executor: ExecutorService = Executors.newSingleThreadExecutor() + private var engine: DebeziumEngine>? = null + private val hasClosed = AtomicBoolean(false) + private val isClosing = AtomicBoolean(false) + private val thrownError = AtomicReference() + private val engineLatch = CountDownLatch(1) + + fun start( + queue: BlockingQueue>, + offsetManager: AirbyteFileOffsetBackingStore, + schemaHistoryManager: Optional + ) { + engine = + DebeziumEngine.create(Json::class.java) + .using( + debeziumPropertiesManager.getDebeziumProperties( + offsetManager, + schemaHistoryManager + ) + ) + .using(OffsetCommitPolicy.AlwaysCommitOffsetPolicy()) + .notifying { e: ChangeEvent -> + // debezium outputs a tombstone event that has a value of null. this is an + // artifact of how it + // interacts with kafka. we want to ignore it. + // more on the tombstone: + // https://debezium.io/documentation/reference/2.2/transformations/event-flattening.html + if (e.value() != null) { + try { + queue.put(e) + } catch (ex: InterruptedException) { + Thread.currentThread().interrupt() + throw RuntimeException(ex) + } + } + } + .using { success: Boolean, message: String?, error: Throwable? -> + LOGGER.info( + "Debezium engine shutdown. Engine terminated successfully : {}", + success + ) + LOGGER.info(message) + if (!success) { + if (error != null) { + thrownError.set(error) + } else { + // There are cases where Debezium doesn't succeed but only fills the + // message field. + // In that case, we still want to fail loud and clear + thrownError.set(RuntimeException(message)) + } + } + engineLatch.countDown() + } + .build() + + // Run the engine asynchronously ... + executor.execute(engine) + } + + fun hasClosed(): Boolean { + return hasClosed.get() + } + + @Throws(Exception::class) + override fun close() { + if (isClosing.compareAndSet(false, true)) { + // consumers should assume records can be produced until engine has closed. 
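+            // Shutdown ordering sketch (for clarity): engine.close() asks Debezium to stop, the
+            // completion handler registered in start() counts down engineLatch, and only then is
+            // the executor shut down and any error captured in thrownError rethrown.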
+ if (engine != null) { + engine!!.close() + } + + // wait for closure before shutting down executor service + engineLatch.await(5, TimeUnit.MINUTES) + + // shut down and await for thread to actually go down + executor.shutdown() + executor.awaitTermination(5, TimeUnit.MINUTES) + + // after the engine is completely off, we can mark this as closed + hasClosed.set(true) + + if (thrownError.get() != null) { + throw RuntimeException(thrownError.get()) + } + } + } + + companion object { + private val LOGGER: Logger = LoggerFactory.getLogger(DebeziumRecordPublisher::class.java) + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumShutdownProcedure.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumShutdownProcedure.kt new file mode 100644 index 000000000000..939303c1cc73 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumShutdownProcedure.kt @@ -0,0 +1,125 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.debezium.internals + +import io.airbyte.commons.concurrency.VoidCallable +import io.airbyte.commons.lang.MoreBooleans +import java.util.concurrent.* +import java.util.function.Supplier +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +/** + * This class has the logic for shutting down Debezium Engine in graceful manner. We made it Generic + * to allow us to write tests easily. + */ +class DebeziumShutdownProcedure( + private val sourceQueue: LinkedBlockingQueue, + private val debeziumThreadRequestClose: VoidCallable, + private val publisherStatusSupplier: Supplier +) { + private val targetQueue = LinkedBlockingQueue() + private val executorService: ExecutorService + private var exception: Throwable? = null + private var hasTransferThreadShutdown: Boolean + + init { + this.hasTransferThreadShutdown = false + this.executorService = + Executors.newSingleThreadExecutor { r: Runnable? -> + val thread = Thread(r, "queue-data-transfer-thread") + thread.uncaughtExceptionHandler = + Thread.UncaughtExceptionHandler { t: Thread?, e: Throwable? -> exception = e } + thread + } + } + + private fun transfer(): Runnable { + return Runnable { + while (!sourceQueue.isEmpty() || !hasEngineShutDown()) { + try { + val event = sourceQueue.poll(100, TimeUnit.MILLISECONDS) + if (event != null) { + targetQueue.put(event) + } + } catch (e: InterruptedException) { + Thread.currentThread().interrupt() + throw RuntimeException(e) + } + } + } + } + + private fun hasEngineShutDown(): Boolean { + return MoreBooleans.isTruthy(publisherStatusSupplier.get()) + } + + private fun initiateTransfer() { + executorService.execute(transfer()) + } + + val recordsRemainingAfterShutdown: LinkedBlockingQueue + get() { + if (!hasTransferThreadShutdown) { + LOGGER.warn( + "Queue transfer thread has not shut down, some records might be missing." + ) + } + return targetQueue + } + + /** + * This method triggers the shutdown of Debezium Engine. When we trigger Debezium shutdown, the + * main thread pauses, as a result we stop reading data from the [sourceQueue] and since the + * queue is of fixed size, if it's already at capacity, Debezium won't be able to put remaining + * records in the queue. So before we trigger Debezium shutdown, we initiate a transfer of the + * records from the [sourceQueue] to a new queue i.e. [targetQueue]. 
This allows Debezium to + * continue to put records in the [sourceQueue] and once done, gracefully shutdown. After the + * shutdown is complete we just have to read the remaining records from the [targetQueue] + */ + fun initiateShutdownProcedure() { + if (hasEngineShutDown()) { + LOGGER.info("Debezium Engine has already shut down.") + return + } + var exceptionDuringEngineClose: Exception? = null + try { + initiateTransfer() + debeziumThreadRequestClose.call() + } catch (e: Exception) { + exceptionDuringEngineClose = e + throw RuntimeException(e) + } finally { + try { + shutdownTransferThread() + } catch (e: Exception) { + if (exceptionDuringEngineClose != null) { + e.addSuppressed(exceptionDuringEngineClose) + throw e + } + } + } + } + + private fun shutdownTransferThread() { + executorService.shutdown() + var terminated = false + while (!terminated) { + try { + terminated = executorService.awaitTermination(5, TimeUnit.MINUTES) + } catch (e: InterruptedException) { + Thread.currentThread().interrupt() + throw RuntimeException(e) + } + } + hasTransferThreadShutdown = true + if (exception != null) { + throw RuntimeException(exception) + } + } + + companion object { + private val LOGGER: Logger = LoggerFactory.getLogger(DebeziumShutdownProcedure::class.java) + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumStateUtil.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumStateUtil.kt new file mode 100644 index 000000000000..fbc6534eb091 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumStateUtil.kt @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.debezium.internals + +import io.debezium.config.Configuration +import io.debezium.embedded.KafkaConnectUtil +import java.lang.Boolean +import java.util.* +import kotlin.String +import org.apache.kafka.connect.json.JsonConverter +import org.apache.kafka.connect.json.JsonConverterConfig +import org.apache.kafka.connect.runtime.WorkerConfig +import org.apache.kafka.connect.runtime.standalone.StandaloneConfig +import org.apache.kafka.connect.storage.FileOffsetBackingStore +import org.apache.kafka.connect.storage.OffsetStorageReaderImpl + +/** Represents a utility class that assists with the parsing of Debezium offset state. */ +interface DebeziumStateUtil { + /** + * Creates and starts a [FileOffsetBackingStore] that is used to store the tracked Debezium + * offset state. + * + * @param properties The Debezium configuration properties for the selected Debezium connector. + * @return A configured and started [FileOffsetBackingStore] instance. + */ + fun getFileOffsetBackingStore(properties: Properties?): FileOffsetBackingStore? { + val fileOffsetBackingStore = KafkaConnectUtil.fileOffsetBackingStore() + val propertiesMap = Configuration.from(properties).asMap() + propertiesMap[WorkerConfig.KEY_CONVERTER_CLASS_CONFIG] = JsonConverter::class.java.name + propertiesMap[WorkerConfig.VALUE_CONVERTER_CLASS_CONFIG] = JsonConverter::class.java.name + fileOffsetBackingStore.configure(StandaloneConfig(propertiesMap)) + fileOffsetBackingStore.start() + return fileOffsetBackingStore + } + + val keyConverter: JsonConverter? + /** + * Creates and returns a [JsonConverter] that can be used to parse keys in the Debezium + * offset state storage. 
+ * + * @return A [JsonConverter] for key conversion. + */ + get() { + val keyConverter = JsonConverter() + keyConverter.configure(INTERNAL_CONVERTER_CONFIG, true) + return keyConverter + } + + /** + * Creates and returns an [OffsetStorageReaderImpl] instance that can be used to load offset + * state from the offset file storage. + * + * @param fileOffsetBackingStore The [FileOffsetBackingStore] that contains the offset state + * saved to disk. + * @param properties The Debezium configuration properties for the selected Debezium connector. + * @return An [OffsetStorageReaderImpl] instance that can be used to load the offset state from + * the offset file storage. + */ + fun getOffsetStorageReader( + fileOffsetBackingStore: FileOffsetBackingStore?, + properties: Properties + ): OffsetStorageReaderImpl? { + return OffsetStorageReaderImpl( + fileOffsetBackingStore, + properties.getProperty(CONNECTOR_NAME_PROPERTY), + keyConverter, + valueConverter + ) + } + + val valueConverter: JsonConverter? + /** + * Creates and returns a [JsonConverter] that can be used to parse values in the Debezium + * offset state storage. + * + * @return A [JsonConverter] for value conversion. + */ + get() { + val valueConverter = JsonConverter() + valueConverter.configure(INTERNAL_CONVERTER_CONFIG, false) + return valueConverter + } + + companion object { + /** + * The name of the Debezium property that contains the unique name for the Debezium + * connector. + */ + const val CONNECTOR_NAME_PROPERTY: String = "name" + + /** Configuration for offset state key/value converters. */ + val INTERNAL_CONVERTER_CONFIG: Map = + java.util.Map.of(JsonConverterConfig.SCHEMAS_ENABLE_CONFIG, Boolean.FALSE.toString()) + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/RecordWaitTimeUtil.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/RecordWaitTimeUtil.kt new file mode 100644 index 000000000000..17bf9e0e1512 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/RecordWaitTimeUtil.kt @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+ */ +package io.airbyte.cdk.integrations.debezium.internals + +import com.fasterxml.jackson.databind.JsonNode +import java.time.Duration +import java.util.* +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +object RecordWaitTimeUtil { + private val LOGGER: Logger = LoggerFactory.getLogger(RecordWaitTimeUtil::class.java) + + val MIN_FIRST_RECORD_WAIT_TIME: Duration = Duration.ofMinutes(2) + val MAX_FIRST_RECORD_WAIT_TIME: Duration = Duration.ofMinutes(40) + val DEFAULT_FIRST_RECORD_WAIT_TIME: Duration = Duration.ofMinutes(5) + val DEFAULT_SUBSEQUENT_RECORD_WAIT_TIME: Duration = Duration.ofMinutes(1) + + fun checkFirstRecordWaitTime(config: JsonNode) { + // we need to skip the check because in tests, we set initial_waiting_seconds + // to 5 seconds for performance reasons, which is shorter than the minimum + // value allowed in production + if (config.has("is_test") && config["is_test"].asBoolean()) { + return + } + + val firstRecordWaitSeconds = getFirstRecordWaitSeconds(config) + if (firstRecordWaitSeconds.isPresent) { + val seconds = firstRecordWaitSeconds.get() + require( + !(seconds < MIN_FIRST_RECORD_WAIT_TIME.seconds || + seconds > MAX_FIRST_RECORD_WAIT_TIME.seconds) + ) { + String.format( + "initial_waiting_seconds must be between %d and %d seconds", + MIN_FIRST_RECORD_WAIT_TIME.seconds, + MAX_FIRST_RECORD_WAIT_TIME.seconds + ) + } + } + } + + fun getFirstRecordWaitTime(config: JsonNode): Duration { + val isTest = config.has("is_test") && config["is_test"].asBoolean() + var firstRecordWaitTime = DEFAULT_FIRST_RECORD_WAIT_TIME + + val firstRecordWaitSeconds = getFirstRecordWaitSeconds(config) + if (firstRecordWaitSeconds.isPresent) { + firstRecordWaitTime = Duration.ofSeconds(firstRecordWaitSeconds.get().toLong()) + if (!isTest && firstRecordWaitTime.compareTo(MIN_FIRST_RECORD_WAIT_TIME) < 0) { + LOGGER.warn( + "First record waiting time is overridden to {} minutes, which is the min time allowed for safety.", + MIN_FIRST_RECORD_WAIT_TIME.toMinutes() + ) + firstRecordWaitTime = MIN_FIRST_RECORD_WAIT_TIME + } else if (!isTest && firstRecordWaitTime.compareTo(MAX_FIRST_RECORD_WAIT_TIME) > 0) { + LOGGER.warn( + "First record waiting time is overridden to {} minutes, which is the max time allowed for safety.", + MAX_FIRST_RECORD_WAIT_TIME.toMinutes() + ) + firstRecordWaitTime = MAX_FIRST_RECORD_WAIT_TIME + } + } + + LOGGER.info("First record waiting time: {} seconds", firstRecordWaitTime.seconds) + return firstRecordWaitTime + } + + fun getSubsequentRecordWaitTime(config: JsonNode): Duration { + var subsequentRecordWaitTime = DEFAULT_SUBSEQUENT_RECORD_WAIT_TIME + val isTest = config.has("is_test") && config["is_test"].asBoolean() + val firstRecordWaitSeconds = getFirstRecordWaitSeconds(config) + if (isTest && firstRecordWaitSeconds.isPresent) { + // In tests, reuse the initial_waiting_seconds property to speed things up. 
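+            // Example (sketch of the assumed config shape, mirroring getFirstRecordWaitSeconds):
+            //     {"is_test": true, "replication_method": {"initial_waiting_seconds": 5}}
+            // would yield a 5-second subsequent wait instead of the 1-minute default.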
+ subsequentRecordWaitTime = Duration.ofSeconds(firstRecordWaitSeconds.get().toLong()) + } + LOGGER.info("Subsequent record waiting time: {} seconds", subsequentRecordWaitTime.seconds) + return subsequentRecordWaitTime + } + + fun getFirstRecordWaitSeconds(config: JsonNode): Optional { + val replicationMethod = config["replication_method"] + if (replicationMethod != null && replicationMethod.has("initial_waiting_seconds")) { + val seconds = config["replication_method"]["initial_waiting_seconds"].asInt() + return Optional.of(seconds) + } + return Optional.empty() + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/RelationalDbDebeziumEventConverter.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/RelationalDbDebeziumEventConverter.kt new file mode 100644 index 000000000000..b7e09e7c9b9e --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/RelationalDbDebeziumEventConverter.kt @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.debezium.internals + +import com.fasterxml.jackson.databind.JsonNode +import com.fasterxml.jackson.databind.node.ObjectNode +import io.airbyte.cdk.integrations.debezium.CdcMetadataInjector +import io.airbyte.protocol.models.v0.AirbyteMessage +import java.time.Instant + +class RelationalDbDebeziumEventConverter( + private val cdcMetadataInjector: CdcMetadataInjector<*>, + private val emittedAt: Instant +) : DebeziumEventConverter { + override fun toAirbyteMessage(event: ChangeEventWithMetadata): AirbyteMessage { + val debeziumEvent = event.eventValueAsJson() + val before: JsonNode = debeziumEvent!!.get(DebeziumEventConverter.Companion.BEFORE_EVENT) + val after: JsonNode = debeziumEvent.get(DebeziumEventConverter.Companion.AFTER_EVENT) + val source: JsonNode = debeziumEvent.get(DebeziumEventConverter.Companion.SOURCE_EVENT) + + val baseNode = (if (after.isNull) before else after) as ObjectNode + val data: JsonNode = + DebeziumEventConverter.Companion.addCdcMetadata( + baseNode, + source, + cdcMetadataInjector, + after.isNull + ) + return DebeziumEventConverter.Companion.buildAirbyteMessage( + source, + cdcMetadataInjector, + emittedAt, + data + ) + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/RelationalDbDebeziumPropertiesManager.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/RelationalDbDebeziumPropertiesManager.kt new file mode 100644 index 000000000000..c78ead79f77d --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/RelationalDbDebeziumPropertiesManager.kt @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+ */ +package io.airbyte.cdk.integrations.debezium.internals + +import com.fasterxml.jackson.databind.JsonNode +import io.airbyte.cdk.db.jdbc.JdbcUtils +import io.airbyte.protocol.models.v0.AirbyteStream +import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog +import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream +import io.airbyte.protocol.models.v0.SyncMode +import java.util.* +import java.util.regex.Pattern +import java.util.stream.Collectors +import java.util.stream.StreamSupport +import org.codehaus.plexus.util.StringUtils + +class RelationalDbDebeziumPropertiesManager( + properties: Properties, + config: JsonNode, + catalog: ConfiguredAirbyteCatalog +) : DebeziumPropertiesManager(properties, config, catalog) { + override fun getConnectionConfiguration(config: JsonNode): Properties { + val properties = Properties() + + // db connection configuration + properties.setProperty("database.hostname", config[JdbcUtils.HOST_KEY].asText()) + properties.setProperty("database.port", config[JdbcUtils.PORT_KEY].asText()) + properties.setProperty("database.user", config[JdbcUtils.USERNAME_KEY].asText()) + properties.setProperty("database.dbname", config[JdbcUtils.DATABASE_KEY].asText()) + + if (config.has(JdbcUtils.PASSWORD_KEY)) { + properties.setProperty("database.password", config[JdbcUtils.PASSWORD_KEY].asText()) + } + + return properties + } + + override fun getName(config: JsonNode): String { + return config[JdbcUtils.DATABASE_KEY].asText() + } + + override fun getIncludeConfiguration( + catalog: ConfiguredAirbyteCatalog, + config: JsonNode? + ): Properties { + val properties = Properties() + + // table selection + properties.setProperty("table.include.list", getTableIncludelist(catalog)) + // column selection + properties.setProperty("column.include.list", getColumnIncludeList(catalog)) + + return properties + } + + companion object { + fun getTableIncludelist(catalog: ConfiguredAirbyteCatalog): String { + // Turn "stream": { + // "namespace": "schema1" + // "name": "table1 + // }, + // "stream": { + // "namespace": "schema2" + // "name": "table2 + // } -------> info "schema1.table1, schema2.table2" + + return catalog.streams + .stream() + .filter { s: ConfiguredAirbyteStream -> s.syncMode == SyncMode.INCREMENTAL } + .map { obj: ConfiguredAirbyteStream -> obj.stream } + .map { stream: AirbyteStream -> + stream.namespace + "." + stream.name + } // debezium needs commas escaped to split properly + .map { x: String -> StringUtils.escape(Pattern.quote(x), ",".toCharArray(), "\\,") } + .collect(Collectors.joining(",")) + } + + fun getColumnIncludeList(catalog: ConfiguredAirbyteCatalog): String { + // Turn "stream": { + // "namespace": "schema1" + // "name": "table1" + // "jsonSchema": { + // "properties": { + // "column1": { + // }, + // "column2": { + // } + // } + // } + // } -------> info "schema1.table1.(column1 | column2)" + + return catalog.streams + .stream() + .filter { s: ConfiguredAirbyteStream -> s.syncMode == SyncMode.INCREMENTAL } + .map { obj: ConfiguredAirbyteStream -> obj.stream } + .map { s: AirbyteStream -> + val fields = parseFields(s.jsonSchema["properties"].fieldNames()) + Pattern.quote(s.namespace + "." + s.name) + + (if (StringUtils.isNotBlank(fields)) "\\.$fields" else "") + } + .map { x: String? 
-> StringUtils.escape(x, ",".toCharArray(), "\\,") }
+ .collect(Collectors.joining(","))
+ }
+
+ private fun parseFields(fieldNames: Iterator<String>?): String {
+ if (fieldNames == null || !fieldNames.hasNext()) {
+ return ""
+ }
+ val iter = Iterable { fieldNames }
+ return StreamSupport.stream(iter.spliterator(), false)
+ .map { f: String -> Pattern.quote(f) }
+ .collect(Collectors.joining("|", "(", ")"))
+ }
+ }
+}
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/SnapshotMetadata.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/SnapshotMetadata.kt
new file mode 100644
index 000000000000..f34141431ca1
--- /dev/null
+++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/SnapshotMetadata.kt
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+ */
+package io.airbyte.cdk.integrations.debezium.internals
+
+import com.google.common.collect.ImmutableSet
+
+enum class SnapshotMetadata {
+ FIRST,
+ FIRST_IN_DATA_COLLECTION,
+ LAST_IN_DATA_COLLECTION,
+ TRUE,
+ LAST,
+ FALSE,
+ NULL;
+
+ companion object {
+ private val ENTRIES_OF_SNAPSHOT_EVENTS: Set<SnapshotMetadata> =
+ ImmutableSet.of(TRUE, FIRST, FIRST_IN_DATA_COLLECTION, LAST_IN_DATA_COLLECTION)
+ private val STRING_TO_ENUM: MutableMap<String, SnapshotMetadata> = HashMap(12)
+
+ init {
+ STRING_TO_ENUM["true"] = TRUE
+ STRING_TO_ENUM["TRUE"] = TRUE
+ STRING_TO_ENUM["false"] = FALSE
+ STRING_TO_ENUM["FALSE"] = FALSE
+ STRING_TO_ENUM["last"] = LAST
+ STRING_TO_ENUM["LAST"] = LAST
+ STRING_TO_ENUM["first"] = FIRST
+ STRING_TO_ENUM["FIRST"] = FIRST
+ STRING_TO_ENUM["last_in_data_collection"] = LAST_IN_DATA_COLLECTION
+ STRING_TO_ENUM["LAST_IN_DATA_COLLECTION"] = LAST_IN_DATA_COLLECTION
+ STRING_TO_ENUM["first_in_data_collection"] = FIRST_IN_DATA_COLLECTION
+ STRING_TO_ENUM["FIRST_IN_DATA_COLLECTION"] = FIRST_IN_DATA_COLLECTION
+ STRING_TO_ENUM["NULL"] = NULL
+ STRING_TO_ENUM["null"] = NULL
+ }
+
+ fun fromString(value: String): SnapshotMetadata? {
+ if (STRING_TO_ENUM.containsKey(value)) {
+ return STRING_TO_ENUM[value]
+ }
+ throw RuntimeException("ENUM value not found for $value")
+ }
+
+ fun isSnapshotEventMetadata(snapshotMetadata: SnapshotMetadata?): Boolean {
+ return ENTRIES_OF_SNAPSHOT_EVENTS.contains(snapshotMetadata)
+ }
+ }
+}
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/AbstractJdbcSource.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/AbstractJdbcSource.kt
new file mode 100644
index 000000000000..b73d956d499b
--- /dev/null
+++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/AbstractJdbcSource.kt
@@ -0,0 +1,741 @@
+/*
+ * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
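+ *
+ * Example (illustrative, not part of this change) for the SnapshotMetadata enum above:
+ * fromString("first_in_data_collection") resolves to FIRST_IN_DATA_COLLECTION, for which
+ * isSnapshotEventMetadata(...) returns true; fromString("false") resolves to FALSE, which is
+ * not counted as a snapshot event.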
+ */ +package io.airbyte.cdk.integrations.source.jdbc + +import com.fasterxml.jackson.databind.JsonNode +import com.google.common.annotations.VisibleForTesting +import com.google.common.collect.ImmutableList +import com.google.common.collect.ImmutableMap +import com.google.common.collect.Sets +import datadog.trace.api.Trace +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings +import io.airbyte.cdk.db.JdbcCompatibleSourceOperations +import io.airbyte.cdk.db.SqlDatabase +import io.airbyte.cdk.db.factory.DataSourceFactory.close +import io.airbyte.cdk.db.factory.DataSourceFactory.create +import io.airbyte.cdk.db.jdbc.AirbyteRecordData +import io.airbyte.cdk.db.jdbc.JdbcConstants.INTERNAL_COLUMN_NAME +import io.airbyte.cdk.db.jdbc.JdbcConstants.INTERNAL_COLUMN_SIZE +import io.airbyte.cdk.db.jdbc.JdbcConstants.INTERNAL_COLUMN_TYPE +import io.airbyte.cdk.db.jdbc.JdbcConstants.INTERNAL_COLUMN_TYPE_NAME +import io.airbyte.cdk.db.jdbc.JdbcConstants.INTERNAL_DECIMAL_DIGITS +import io.airbyte.cdk.db.jdbc.JdbcConstants.INTERNAL_IS_NULLABLE +import io.airbyte.cdk.db.jdbc.JdbcConstants.INTERNAL_SCHEMA_NAME +import io.airbyte.cdk.db.jdbc.JdbcConstants.INTERNAL_TABLE_NAME +import io.airbyte.cdk.db.jdbc.JdbcConstants.JDBC_COLUMN_COLUMN_NAME +import io.airbyte.cdk.db.jdbc.JdbcConstants.JDBC_COLUMN_DATABASE_NAME +import io.airbyte.cdk.db.jdbc.JdbcConstants.JDBC_COLUMN_DATA_TYPE +import io.airbyte.cdk.db.jdbc.JdbcConstants.JDBC_COLUMN_SCHEMA_NAME +import io.airbyte.cdk.db.jdbc.JdbcConstants.JDBC_COLUMN_SIZE +import io.airbyte.cdk.db.jdbc.JdbcConstants.JDBC_COLUMN_TABLE_NAME +import io.airbyte.cdk.db.jdbc.JdbcConstants.JDBC_COLUMN_TYPE_NAME +import io.airbyte.cdk.db.jdbc.JdbcConstants.JDBC_DECIMAL_DIGITS +import io.airbyte.cdk.db.jdbc.JdbcConstants.JDBC_IS_NULLABLE +import io.airbyte.cdk.db.jdbc.JdbcConstants.KEY_SEQ +import io.airbyte.cdk.db.jdbc.JdbcDatabase +import io.airbyte.cdk.db.jdbc.JdbcUtils +import io.airbyte.cdk.db.jdbc.JdbcUtils.getFullyQualifiedTableName +import io.airbyte.cdk.db.jdbc.StreamingJdbcDatabase +import io.airbyte.cdk.db.jdbc.streaming.JdbcStreamingQueryConfig +import io.airbyte.cdk.integrations.base.Source +import io.airbyte.cdk.integrations.source.jdbc.dto.JdbcPrivilegeDto +import io.airbyte.cdk.integrations.source.relationaldb.AbstractDbSource +import io.airbyte.cdk.integrations.source.relationaldb.CursorInfo +import io.airbyte.cdk.integrations.source.relationaldb.RelationalDbQueryUtils +import io.airbyte.cdk.integrations.source.relationaldb.RelationalDbQueryUtils.enquoteIdentifier +import io.airbyte.cdk.integrations.source.relationaldb.TableInfo +import io.airbyte.cdk.integrations.source.relationaldb.state.StateManager +import io.airbyte.commons.functional.CheckedConsumer +import io.airbyte.commons.functional.CheckedFunction +import io.airbyte.commons.json.Jsons +import io.airbyte.commons.stream.AirbyteStreamUtils +import io.airbyte.commons.util.AutoCloseableIterator +import io.airbyte.commons.util.AutoCloseableIterators +import io.airbyte.protocol.models.CommonField +import io.airbyte.protocol.models.JsonSchemaType +import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair +import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog +import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream +import io.airbyte.protocol.models.v0.SyncMode +import java.sql.Connection +import java.sql.PreparedStatement +import java.sql.ResultSet +import java.sql.SQLException +import java.util.* +import java.util.function.Consumer +import java.util.function.Function +import 
java.util.function.Predicate
+import java.util.function.Supplier
+import java.util.stream.Collectors
+import javax.sql.DataSource
+import org.apache.commons.lang3.tuple.ImmutablePair
+import org.slf4j.Logger
+import org.slf4j.LoggerFactory
+
+/**
+ * This class contains helper functions and boilerplate for implementing a source connector for a
+ * relational DB source which can be accessed via a JDBC driver. If you are implementing a
+ * connector for a relational DB which has a JDBC driver, make an effort to use this class.
+ */
+// This suppression is only here because spotbugs complains about aggregatePrimateKeys, and it is
+// unclear what it is complaining about.
+@SuppressFBWarnings("NP_PARAMETER_MUST_BE_NONNULL_BUT_MARKED_AS_NULLABLE")
+abstract class AbstractJdbcSource<Datatype>(
+ driverClass: String,
+ protected val streamingQueryConfigProvider: Supplier<JdbcStreamingQueryConfig>,
+ sourceOperations: JdbcCompatibleSourceOperations<Datatype>
+) : AbstractDbSource<Datatype, JdbcDatabase>(driverClass), Source {
+ protected val sourceOperations: JdbcCompatibleSourceOperations<Datatype>
+
+ override var quoteString: String? = null
+ protected var dataSources: MutableCollection<DataSource> = ArrayList()
+
+ init {
+ this.sourceOperations = sourceOperations
+ }
+
+ override fun queryTableFullRefresh(
+ database: JdbcDatabase,
+ columnNames: List<String>,
+ schemaName: String?,
+ tableName: String,
+ syncMode: SyncMode,
+ cursorField: Optional<String>
+ ): AutoCloseableIterator<AirbyteRecordData> {
+ AbstractDbSource.LOGGER.info("Queueing query for table: {}", tableName)
+ val airbyteStream = AirbyteStreamUtils.convertFromNameAndNamespace(tableName, schemaName)
+ return AutoCloseableIterators.lazyIterator(
+ Supplier<AutoCloseableIterator<AirbyteRecordData>> {
+ try {
+ val stream =
+ database.unsafeQuery(
+ { connection: Connection ->
+ AbstractDbSource.LOGGER.info(
+ "Preparing query for table: {}",
+ tableName
+ )
+ val fullTableName: String =
+ RelationalDbQueryUtils.getFullyQualifiedTableNameWithQuoting(
+ schemaName,
+ tableName,
+ quoteString!!
+ )
+
+ val wrappedColumnNames =
+ getWrappedColumnNames(
+ database,
+ connection,
+ columnNames,
+ schemaName,
+ tableName
+ )
+ val sql =
+ java.lang.StringBuilder(
+ String.format(
+ "SELECT %s FROM %s",
+ wrappedColumnNames,
+ fullTableName
+ )
+ )
+ // if the connector emits intermediate states, the incremental query
+ // must be sorted by the cursor field
+ if (
+ syncMode == SyncMode.INCREMENTAL && stateEmissionFrequency > 0
+ ) {
+ val quotedCursorField: String =
+ enquoteIdentifier(cursorField.get(), quoteString)
+ sql.append(String.format(" ORDER BY %s ASC", quotedCursorField))
+ }
+
+ val preparedStatement = connection.prepareStatement(sql.toString())
+ AbstractDbSource.LOGGER.info(
+ "Executing query for table {}: {}",
+ tableName,
+ preparedStatement
+ )
+ preparedStatement
+ },
+ sourceOperations::convertDatabaseRowToAirbyteRecordData
+ )
+ return@Supplier AutoCloseableIterators.fromStream(
+ stream,
+ airbyteStream
+ )
+ } catch (e: SQLException) {
+ throw java.lang.RuntimeException(e)
+ }
+ },
+ airbyteStream
+ )
+ }
+
+ /**
+ * Configures a list of operations that can be used to check the connection to the source.
+ *
+ * @return list of consumers that run queries for the check command.
+ */
+ @Trace(operationName = AbstractDbSource.Companion.CHECK_TRACE_OPERATION_NAME)
+ @Throws(Exception::class)
+ override fun getCheckOperations(
+ config: JsonNode?
+ ): List<CheckedConsumer<JdbcDatabase, Exception>> {
+ return ImmutableList.of(
+ CheckedConsumer { database: JdbcDatabase ->
+ LOGGER.info(
+ "Attempting to get metadata from the database to see if we can connect."
+ ) + database.bufferedResultSetQuery( + CheckedFunction { connection: Connection -> connection.metaData.catalogs }, + CheckedFunction { queryResult: ResultSet? -> + sourceOperations.rowToJson(queryResult!!) + } + ) + } + ) + } + + private fun getCatalog(database: SqlDatabase): String? { + return (if (database.sourceConfig!!.has(JdbcUtils.DATABASE_KEY)) + database.sourceConfig!![JdbcUtils.DATABASE_KEY].asText() + else null) + } + + @Throws(Exception::class) + override fun discoverInternal( + database: JdbcDatabase, + schema: String? + ): List>> { + val internalSchemas: Set = HashSet(excludedInternalNameSpaces) + LOGGER.info("Internal schemas to exclude: {}", internalSchemas) + val tablesWithSelectGrantPrivilege = + getPrivilegesTableForCurrentUser(database, schema) + return database + .bufferedResultSetQuery( // retrieve column metadata from the database + { connection: Connection -> + connection.metaData.getColumns(getCatalog(database), schema, null, null) + }, // store essential column metadata to a Json object from the result set about + // each column + { resultSet: ResultSet -> this.getColumnMetadata(resultSet) } + ) + .stream() + .filter( + excludeNotAccessibleTables(internalSchemas, tablesWithSelectGrantPrivilege) + ) // group by schema and table name to handle the case where a table with the same name + // exists in + // multiple schemas. + .collect( + Collectors.groupingBy>( + Function> { t: JsonNode -> + ImmutablePair.of( + t.get(INTERNAL_SCHEMA_NAME).asText(), + t.get(INTERNAL_TABLE_NAME).asText() + ) + } + ) + ) + .values + .stream() + .map>> { fields: List -> + TableInfo>( + nameSpace = fields[0].get(INTERNAL_SCHEMA_NAME).asText(), + name = fields[0].get(INTERNAL_TABLE_NAME).asText(), + fields = + fields + .stream() // read the column metadata Json object, and determine its + // type + .map { f: JsonNode -> + val datatype = sourceOperations.getDatabaseFieldType(f) + val jsonType = getAirbyteType(datatype) + LOGGER.debug( + "Table {} column {} (type {}[{}], nullable {}) -> {}", + fields[0].get(INTERNAL_TABLE_NAME).asText(), + f.get(INTERNAL_COLUMN_NAME).asText(), + f.get(INTERNAL_COLUMN_TYPE_NAME).asText(), + f.get(INTERNAL_COLUMN_SIZE).asInt(), + f.get(INTERNAL_IS_NULLABLE).asBoolean(), + jsonType + ) + object : + CommonField( + f.get(INTERNAL_COLUMN_NAME).asText(), + datatype + ) {} + } + .collect(Collectors.toList>()), + cursorFields = extractCursorFields(fields) + ) + } + .collect(Collectors.toList>>()) + } + + private fun extractCursorFields(fields: List): List { + return fields + .stream() + .filter { field: JsonNode -> + isCursorType(sourceOperations.getDatabaseFieldType(field)) + } + .map( + Function { field: JsonNode -> + field.get(INTERNAL_COLUMN_NAME).asText() + } + ) + .collect(Collectors.toList()) + } + + protected fun excludeNotAccessibleTables( + internalSchemas: Set, + tablesWithSelectGrantPrivilege: Set? 
+ ): Predicate { + return Predicate { jsonNode: JsonNode -> + if (tablesWithSelectGrantPrivilege!!.isEmpty()) { + return@Predicate isNotInternalSchema(jsonNode, internalSchemas) + } + (tablesWithSelectGrantPrivilege.stream().anyMatch { e: JdbcPrivilegeDto -> + e.schemaName == jsonNode.get(INTERNAL_SCHEMA_NAME).asText() + } && + tablesWithSelectGrantPrivilege.stream().anyMatch { e: JdbcPrivilegeDto -> + e.tableName == jsonNode.get(INTERNAL_TABLE_NAME).asText() + } && + !internalSchemas.contains(jsonNode.get(INTERNAL_SCHEMA_NAME).asText())) + } + } + + // needs to override isNotInternalSchema for connectors that override + // getPrivilegesTableForCurrentUser() + protected fun isNotInternalSchema(jsonNode: JsonNode, internalSchemas: Set): Boolean { + return !internalSchemas.contains(jsonNode.get(INTERNAL_SCHEMA_NAME).asText()) + } + + /** + * @param resultSet Description of a column available in the table catalog. + * @return Essential information about a column to determine which table it belongs to and its + * type. + */ + @Throws(SQLException::class) + private fun getColumnMetadata(resultSet: ResultSet): JsonNode { + val fieldMap = + ImmutableMap.builder< + String, Any + >() // we always want a namespace, if we cannot get a schema, use db name. + .put( + INTERNAL_SCHEMA_NAME, + if (resultSet.getObject(JDBC_COLUMN_SCHEMA_NAME) != null) + resultSet.getString(JDBC_COLUMN_SCHEMA_NAME) + else resultSet.getObject(JDBC_COLUMN_DATABASE_NAME) + ) + .put(INTERNAL_TABLE_NAME, resultSet.getString(JDBC_COLUMN_TABLE_NAME)) + .put(INTERNAL_COLUMN_NAME, resultSet.getString(JDBC_COLUMN_COLUMN_NAME)) + .put(INTERNAL_COLUMN_TYPE, resultSet.getString(JDBC_COLUMN_DATA_TYPE)) + .put(INTERNAL_COLUMN_TYPE_NAME, resultSet.getString(JDBC_COLUMN_TYPE_NAME)) + .put(INTERNAL_COLUMN_SIZE, resultSet.getInt(JDBC_COLUMN_SIZE)) + .put(INTERNAL_IS_NULLABLE, resultSet.getString(JDBC_IS_NULLABLE)) + if (resultSet.getString(JDBC_DECIMAL_DIGITS) != null) { + fieldMap.put(INTERNAL_DECIMAL_DIGITS, resultSet.getString(JDBC_DECIMAL_DIGITS)) + } + return Jsons.jsonNode(fieldMap.build()) + } + + @Throws(Exception::class) + public override fun discoverInternal( + database: JdbcDatabase + ): List>> { + return discoverInternal(database, null) + } + + public override fun getAirbyteType(columnType: Datatype): JsonSchemaType { + return sourceOperations.getAirbyteType(columnType) + } + + @VisibleForTesting + @JvmRecord + data class PrimaryKeyAttributesFromDb( + val streamName: String, + val primaryKey: String, + val keySequence: Int + ) + + override fun discoverPrimaryKeys( + database: JdbcDatabase, + tableInfos: List>> + ): Map> { + LOGGER.info( + "Discover primary keys for tables: " + + tableInfos + .stream() + .map { obj: TableInfo> -> obj.name } + .collect(Collectors.toSet()) + ) + try { + // Get all primary keys without specifying a table name + val tablePrimaryKeys = + aggregatePrimateKeys( + database.bufferedResultSetQuery( + { connection: Connection -> + connection.metaData.getPrimaryKeys(getCatalog(database), null, null) + }, + { r: ResultSet -> + val schemaName: String = + if (r.getObject(JDBC_COLUMN_SCHEMA_NAME) != null) + r.getString(JDBC_COLUMN_SCHEMA_NAME) + else r.getString(JDBC_COLUMN_DATABASE_NAME) + val streamName = + getFullyQualifiedTableName( + schemaName, + r.getString(JDBC_COLUMN_TABLE_NAME) + ) + val primaryKey: String = r.getString(JDBC_COLUMN_COLUMN_NAME) + val keySeq: Int = r.getInt(KEY_SEQ) + PrimaryKeyAttributesFromDb(streamName, primaryKey, keySeq) + } + ) + ) + if (!tablePrimaryKeys.isEmpty()) { + return 
tablePrimaryKeys + } + } catch (e: SQLException) { + LOGGER.debug( + String.format( + "Could not retrieve primary keys without a table name (%s), retrying", + e + ) + ) + } + // Get primary keys one table at a time + return tableInfos + .stream() + .collect( + Collectors.toMap>, String, MutableList>( + Function>, String> { + tableInfo: TableInfo> -> + getFullyQualifiedTableName(tableInfo.nameSpace, tableInfo.name) + }, + Function>, MutableList> toMap@{ + tableInfo: TableInfo> -> + val streamName = + getFullyQualifiedTableName(tableInfo.nameSpace, tableInfo.name) + try { + val primaryKeys = + aggregatePrimateKeys( + database.bufferedResultSetQuery( + { connection: Connection -> + connection.metaData.getPrimaryKeys( + getCatalog(database), + tableInfo.nameSpace, + tableInfo.name + ) + }, + { r: ResultSet -> + PrimaryKeyAttributesFromDb( + streamName, + r.getString(JDBC_COLUMN_COLUMN_NAME), + r.getInt(KEY_SEQ) + ) + } + ) + ) + return@toMap primaryKeys.getOrDefault( + streamName, + mutableListOf() + ) + } catch (e: SQLException) { + LOGGER.error( + String.format( + "Could not retrieve primary keys for %s: %s", + streamName, + e + ) + ) + return@toMap mutableListOf() + } + } + ) + ) + } + + public override fun isCursorType(type: Datatype): Boolean { + return sourceOperations.isCursorType(type) + } + + override fun queryTableIncremental( + database: JdbcDatabase, + columnNames: List, + schemaName: String?, + tableName: String, + cursorInfo: CursorInfo, + cursorFieldType: Datatype + ): AutoCloseableIterator { + AbstractDbSource.LOGGER.info("Queueing query for table: {}", tableName) + val airbyteStream = AirbyteStreamUtils.convertFromNameAndNamespace(tableName, schemaName) + return AutoCloseableIterators.lazyIterator( + { + try { + val stream = + database.unsafeQuery( + { connection: Connection -> + AbstractDbSource.LOGGER.info( + "Preparing query for table: {}", + tableName + ) + val fullTableName: String = + RelationalDbQueryUtils.getFullyQualifiedTableNameWithQuoting( + schemaName, + tableName, + quoteString!! + ) + val quotedCursorField: String = + enquoteIdentifier(cursorInfo.cursorField, quoteString) + val operator: String + if (cursorInfo.cursorRecordCount <= 0L) { + operator = ">" + } else { + val actualRecordCount = + getActualCursorRecordCount( + connection, + fullTableName, + quotedCursorField, + cursorFieldType, + cursorInfo.cursor + ) + AbstractDbSource.LOGGER.info( + "Table {} cursor count: expected {}, actual {}", + tableName, + cursorInfo.cursorRecordCount, + actualRecordCount + ) + operator = + if (actualRecordCount == cursorInfo.cursorRecordCount) { + ">" + } else { + ">=" + } + } + val wrappedColumnNames = + getWrappedColumnNames( + database, + connection, + columnNames, + schemaName, + tableName + ) + val sql = + StringBuilder( + String.format( + "SELECT %s FROM %s WHERE %s %s ?", + wrappedColumnNames, + fullTableName, + quotedCursorField, + operator + ) + ) + // if the connector emits intermediate states, the incremental query + // must be sorted by the cursor + // field + if (stateEmissionFrequency > 0) { + sql.append(String.format(" ORDER BY %s ASC", quotedCursorField)) + } + val preparedStatement = connection.prepareStatement(sql.toString()) + AbstractDbSource.LOGGER.info( + "Executing query for table {}: {}", + tableName, + preparedStatement + ) + sourceOperations.setCursorField( + preparedStatement, + 1, + cursorFieldType, + cursorInfo.cursor!! 
+ ) + preparedStatement + }, + sourceOperations::convertDatabaseRowToAirbyteRecordData + ) + return@lazyIterator AutoCloseableIterators.fromStream( + stream, + airbyteStream + ) + } catch (e: SQLException) { + throw RuntimeException(e) + } + }, + airbyteStream + ) + } + + protected fun getCountColumnName(): String { + return "record_count" + } + + /** Some databases need special column names in the query. */ + @Throws(SQLException::class) + protected fun getWrappedColumnNames( + database: JdbcDatabase?, + connection: Connection?, + columnNames: List, + schemaName: String?, + tableName: String? + ): String? { + return RelationalDbQueryUtils.enquoteIdentifierList(columnNames, quoteString!!) + } + + @Throws(SQLException::class) + protected fun getActualCursorRecordCount( + connection: Connection, + fullTableName: String?, + quotedCursorField: String?, + cursorFieldType: Datatype, + cursor: String? + ): Long { + val columnName = getCountColumnName() + val cursorRecordStatement: PreparedStatement + if (cursor == null) { + val cursorRecordQuery = + String.format( + "SELECT COUNT(*) AS %s FROM %s WHERE %s IS NULL", + columnName, + fullTableName, + quotedCursorField + ) + cursorRecordStatement = connection.prepareStatement(cursorRecordQuery) + } else { + val cursorRecordQuery = + String.format( + "SELECT COUNT(*) AS %s FROM %s WHERE %s = ?", + columnName, + fullTableName, + quotedCursorField + ) + cursorRecordStatement = connection.prepareStatement(cursorRecordQuery) + + sourceOperations.setCursorField(cursorRecordStatement, 1, cursorFieldType, cursor) + } + val resultSet = cursorRecordStatement.executeQuery() + return if (resultSet.next()) { + resultSet.getLong(columnName) + } else { + 0L + } + } + + @Throws(SQLException::class) + public override fun createDatabase(sourceConfig: JsonNode): JdbcDatabase { + return createDatabase(sourceConfig, JdbcDataSourceUtils.DEFAULT_JDBC_PARAMETERS_DELIMITER) + } + + @Throws(SQLException::class) + fun createDatabase(sourceConfig: JsonNode, delimiter: String): JdbcDatabase { + val jdbcConfig = toDatabaseConfig(sourceConfig) + val connectionProperties = + JdbcDataSourceUtils.getConnectionProperties(sourceConfig, delimiter) + // Create the data source + val dataSource = + create( + if (jdbcConfig!!.has(JdbcUtils.USERNAME_KEY)) + jdbcConfig[JdbcUtils.USERNAME_KEY].asText() + else null, + if (jdbcConfig.has(JdbcUtils.PASSWORD_KEY)) + jdbcConfig[JdbcUtils.PASSWORD_KEY].asText() + else null, + driverClassName, + jdbcConfig[JdbcUtils.JDBC_URL_KEY].asText(), + connectionProperties, + getConnectionTimeout(connectionProperties!!) + ) + // Record the data source so that it can be closed. + dataSources.add(dataSource) + + val database: JdbcDatabase = + StreamingJdbcDatabase(dataSource, sourceOperations, streamingQueryConfigProvider) + + quoteString = + (if (quoteString == null) database.metaData.identifierQuoteString else quoteString) + database.sourceConfig = sourceConfig + database.databaseConfig = jdbcConfig + return database + } + + /** + * {@inheritDoc} + * + * @param database database instance + * @param catalog schema of the incoming messages. + * @throws SQLException + */ + @Throws(SQLException::class) + override fun logPreSyncDebugData(database: JdbcDatabase, catalog: ConfiguredAirbyteCatalog?) { + LOGGER.info( + "Data source product recognized as {}:{}", + database.metaData.databaseProductName, + database.metaData.databaseProductVersion + ) + } + + override fun close() { + dataSources.forEach( + Consumer { d: DataSource? 
-> + try { + close(d) + } catch (e: Exception) { + LOGGER.warn("Unable to close data source.", e) + } + } + ) + dataSources.clear() + } + + protected fun identifyStreamsToSnapshot( + catalog: ConfiguredAirbyteCatalog, + stateManager: StateManager + ): List { + val alreadySyncedStreams = stateManager.cdcStateManager.initialStreamsSynced + if ( + alreadySyncedStreams!!.isEmpty() && + (stateManager.cdcStateManager.cdcState?.state == null) + ) { + return emptyList() + } + + val allStreams = AirbyteStreamNameNamespacePair.fromConfiguredCatalog(catalog) + + val newlyAddedStreams: Set = + HashSet(Sets.difference(allStreams, alreadySyncedStreams)) + + return catalog.streams + .stream() + .filter { c: ConfiguredAirbyteStream -> c.syncMode == SyncMode.INCREMENTAL } + .filter { stream: ConfiguredAirbyteStream -> + newlyAddedStreams.contains( + AirbyteStreamNameNamespacePair.fromAirbyteStream(stream.stream) + ) + } + .map { `object`: ConfiguredAirbyteStream -> Jsons.clone(`object`) } + .collect(Collectors.toList()) + } + + companion object { + private val LOGGER: Logger = LoggerFactory.getLogger(AbstractJdbcSource::class.java) + + /** + * Aggregate list of @param entries of StreamName and PrimaryKey and + * + * @return a map by StreamName to associated list of primary keys + */ + @VisibleForTesting + fun aggregatePrimateKeys( + entries: List + ): Map> { + val result: MutableMap> = HashMap() + entries + .stream() + .sorted(Comparator.comparingInt(PrimaryKeyAttributesFromDb::keySequence)) + .forEach { entry: PrimaryKeyAttributesFromDb -> + if (!result.containsKey(entry.streamName)) { + result[entry.streamName] = ArrayList() + } + result[entry.streamName]!!.add(entry.primaryKey) + } + return result + } + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcDataSourceUtils.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcDataSourceUtils.kt new file mode 100644 index 000000000000..eda6d797635a --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcDataSourceUtils.kt @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
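+ *
+ * Example (illustrative, not part of this change) for the aggregatePrimateKeys helper above:
+ * given entries ("public.users", "tenant_id", 2) and ("public.users", "id", 1), the result is
+ * {"public.users" -> ["id", "tenant_id"]}, because entries are sorted by key sequence before
+ * being grouped by stream name.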
*/
+package io.airbyte.cdk.integrations.source.jdbc
+
+import com.fasterxml.jackson.databind.JsonNode
+import io.airbyte.cdk.db.jdbc.JdbcUtils
+import io.airbyte.cdk.db.jdbc.JdbcUtils.parseJdbcParameters
+import io.airbyte.commons.map.MoreMaps
+
+object JdbcDataSourceUtils {
+ const val DEFAULT_JDBC_PARAMETERS_DELIMITER: String = "&"
+
+ /**
+ * Validates that custom connection properties do not overwrite the default connection
+ * properties.
+ *
+ * @param customParameters custom connection properties map as specified by each JDBC source
+ * @param defaultParameters default connection properties map as specified by each JDBC source
+ * @throws IllegalArgumentException if a custom parameter overwrites a default parameter
+ */
+ fun assertCustomParametersDontOverwriteDefaultParameters(
+ customParameters: Map<String, String>,
+ defaultParameters: Map<String, String>
+ ) {
+ for (key in defaultParameters.keys) {
+ require(
+ !(customParameters.containsKey(key) &&
+ customParameters[key] != defaultParameters[key])
+ ) { "Cannot overwrite default JDBC parameter $key" }
+ }
+ }
+
+ /**
+ * Retrieves connection_properties from the config and validates that custom jdbc_url
+ * parameters do not overlap with the default properties.
+ *
+ * @param config A configuration used to check the JDBC connection
+ * @return A mapping of connection properties
+ */
+ fun getConnectionProperties(config: JsonNode): Map<String, String> {
+ return getConnectionProperties(config, DEFAULT_JDBC_PARAMETERS_DELIMITER)
+ }
+
+ fun getConnectionProperties(
+ config: JsonNode,
+ parameterDelimiter: String
+ ): Map<String, String> {
+ val customProperties =
+ parseJdbcParameters(config, JdbcUtils.JDBC_URL_PARAMS_KEY, parameterDelimiter)
+ val defaultProperties = getDefaultConnectionProperties(config)
+ assertCustomParametersDontOverwriteDefaultParameters(customProperties, defaultProperties)
+ return MoreMaps.merge(customProperties, defaultProperties)
+ }
+
+ /**
+ * Retrieves default connection_properties from the config
+ *
+ * TODO: make this method abstract and add parity features to destination connectors
+ *
+ * @param config A configuration used to check the JDBC connection
+ * @return A mapping of the default connection properties
+ */
+ fun getDefaultConnectionProperties(config: JsonNode): Map<String, String> {
+ // NOTE that Postgres returns an empty map for some reason?
+ return parseJdbcParameters(
+ config,
+ "connection_properties",
+ DEFAULT_JDBC_PARAMETERS_DELIMITER
+ )
+ }
+}
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcSSLConnectionUtils.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcSSLConnectionUtils.kt
new file mode 100644
index 000000000000..875239722991
--- /dev/null
+++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcSSLConnectionUtils.kt
@@ -0,0 +1,280 @@
+/*
+ * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
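+ *
+ * Example (illustrative, not part of this change; values are hypothetical) for
+ * JdbcDataSourceUtils.getConnectionProperties above: with "jdbc_url_params" set to
+ * "ssl=true&socketTimeout=60" and default properties {"connectTimeout": "30"}, the merged
+ * result is {"ssl": "true", "socketTimeout": "60", "connectTimeout": "30"}; re-declaring
+ * "connectTimeout" with a different value would fail with "Cannot overwrite default JDBC
+ * parameter connectTimeout".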
+ */ +package io.airbyte.cdk.integrations.source.jdbc + +import com.fasterxml.jackson.databind.JsonNode +import io.airbyte.cdk.db.jdbc.JdbcUtils +import io.airbyte.cdk.db.util.SSLCertificateUtils.keyStoreFromCertificate +import io.airbyte.cdk.db.util.SSLCertificateUtils.keyStoreFromClientCertificate +import java.io.IOException +import java.net.MalformedURLException +import java.net.URI +import java.nio.file.Files +import java.nio.file.Path +import java.security.KeyStoreException +import java.security.NoSuchAlgorithmException +import java.security.cert.CertificateException +import java.security.spec.InvalidKeySpecException +import java.util.* +import org.apache.commons.lang3.RandomStringUtils +import org.apache.commons.lang3.tuple.ImmutablePair +import org.apache.commons.lang3.tuple.Pair +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +class JdbcSSLConnectionUtils { + var caCertKeyStorePair: Pair? = null + var clientCertKeyStorePair: Pair? = null + + enum class SslMode(vararg spec: String) { + DISABLED("disable"), + ALLOWED("allow"), + PREFERRED("preferred", "prefer"), + REQUIRED("required", "require"), + VERIFY_CA("verify_ca", "verify-ca"), + VERIFY_IDENTITY("verify_identity", "verify-full"); + + val spec: List = Arrays.asList(*spec) + + companion object { + fun bySpec(spec: String): Optional { + return Arrays.stream(entries.toTypedArray()) + .filter { sslMode: SslMode -> sslMode.spec.contains(spec) } + .findFirst() + } + } + } + + companion object { + const val SSL_MODE: String = "sslMode" + + const val TRUST_KEY_STORE_URL: String = "trustCertificateKeyStoreUrl" + const val TRUST_KEY_STORE_PASS: String = "trustCertificateKeyStorePassword" + const val CLIENT_KEY_STORE_URL: String = "clientCertificateKeyStoreUrl" + const val CLIENT_KEY_STORE_PASS: String = "clientCertificateKeyStorePassword" + const val CLIENT_KEY_STORE_TYPE: String = "clientCertificateKeyStoreType" + const val TRUST_KEY_STORE_TYPE: String = "trustCertificateKeyStoreType" + const val KEY_STORE_TYPE_PKCS12: String = "PKCS12" + const val PARAM_MODE: String = "mode" + private val LOGGER: Logger = + LoggerFactory.getLogger(JdbcSSLConnectionUtils::class.java.javaClass) + const val PARAM_CA_CERTIFICATE: String = "ca_certificate" + const val PARAM_CLIENT_CERTIFICATE: String = "client_certificate" + const val PARAM_CLIENT_KEY: String = "client_key" + const val PARAM_CLIENT_KEY_PASSWORD: String = "client_key_password" + + /** + * Parses SSL related configuration and generates keystores to be used by connector + * + * @param config configuration + * @return map containing relevant parsed values including location of keystore or an empty + * map + */ + fun parseSSLConfig(config: JsonNode): Map { + LOGGER.debug("source config: {}", config) + + var caCertKeyStorePair: Pair? = null + var clientCertKeyStorePair: Pair? = null + val additionalParameters: MutableMap = HashMap() + // assume ssl if not explicitly mentioned. 
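+ // Illustrative example (not part of this change; the certificate value is hypothetical):
+ //   config = {"ssl": true, "ssl_mode": {"mode": "verify-ca", "ca_certificate": "<PEM>"}}
+ //   parseSSLConfig(config) -> {"sslMode": "VERIFY_CA",
+ //       "trustCertificateKeyStoreUrl": "file:<generated keystore path>",
+ //       "trustCertificateKeyStorePassword": "<random>",
+ //       "trustCertificateKeyStoreType": "PKCS12"}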
+ if (!config.has(JdbcUtils.SSL_KEY) || config[JdbcUtils.SSL_KEY].asBoolean()) { + if (config.has(JdbcUtils.SSL_MODE_KEY)) { + val specMode = config[JdbcUtils.SSL_MODE_KEY][PARAM_MODE].asText() + additionalParameters[SSL_MODE] = + SslMode.bySpec(specMode) + .orElseThrow { IllegalArgumentException("unexpected ssl mode") } + .name + if (Objects.isNull(caCertKeyStorePair)) { + caCertKeyStorePair = prepareCACertificateKeyStore(config) + } + + if (Objects.nonNull(caCertKeyStorePair)) { + LOGGER.debug( + "uri for ca cert keystore: {}", + caCertKeyStorePair!!.left.toString() + ) + try { + additionalParameters.putAll( + java.util.Map.of( + TRUST_KEY_STORE_URL, + caCertKeyStorePair.left.toURL().toString(), + TRUST_KEY_STORE_PASS, + caCertKeyStorePair.right, + TRUST_KEY_STORE_TYPE, + KEY_STORE_TYPE_PKCS12 + ) + ) + } catch (e: MalformedURLException) { + throw RuntimeException("Unable to get a URL for trust key store") + } + } + + if (Objects.isNull(clientCertKeyStorePair)) { + clientCertKeyStorePair = prepareClientCertificateKeyStore(config) + } + + if (Objects.nonNull(clientCertKeyStorePair)) { + LOGGER.debug( + "uri for client cert keystore: {} / {}", + clientCertKeyStorePair!!.left.toString(), + clientCertKeyStorePair.right + ) + try { + additionalParameters.putAll( + java.util.Map.of( + CLIENT_KEY_STORE_URL, + clientCertKeyStorePair.left.toURL().toString(), + CLIENT_KEY_STORE_PASS, + clientCertKeyStorePair.right, + CLIENT_KEY_STORE_TYPE, + KEY_STORE_TYPE_PKCS12 + ) + ) + } catch (e: MalformedURLException) { + throw RuntimeException("Unable to get a URL for client key store") + } + } + } else { + additionalParameters[SSL_MODE] = SslMode.DISABLED.name + } + } + LOGGER.debug("additional params: {}", additionalParameters) + return additionalParameters + } + + fun prepareCACertificateKeyStore(config: JsonNode): Pair? { + // if config available + // if has CA cert - make keystore + // if has client cert + // if has client password - make keystore using password + // if no client password - make keystore using random password + var caCertKeyStorePair: Pair? 
= null + if (Objects.nonNull(config)) { + if (!config.has(JdbcUtils.SSL_KEY) || config[JdbcUtils.SSL_KEY].asBoolean()) { + val encryption = config[JdbcUtils.SSL_MODE_KEY] + if ( + encryption.has(PARAM_CA_CERTIFICATE) && + !encryption[PARAM_CA_CERTIFICATE].asText().isEmpty() + ) { + val clientKeyPassword = getOrGeneratePassword(encryption) + try { + val caCertKeyStoreUri = + keyStoreFromCertificate( + encryption[PARAM_CA_CERTIFICATE].asText(), + clientKeyPassword, + null, + null + ) + caCertKeyStorePair = ImmutablePair(caCertKeyStoreUri, clientKeyPassword) + } catch (e: CertificateException) { + throw RuntimeException( + "Failed to create keystore for CA certificate", + e + ) + } catch (e: IOException) { + throw RuntimeException( + "Failed to create keystore for CA certificate", + e + ) + } catch (e: KeyStoreException) { + throw RuntimeException( + "Failed to create keystore for CA certificate", + e + ) + } catch (e: NoSuchAlgorithmException) { + throw RuntimeException( + "Failed to create keystore for CA certificate", + e + ) + } + } + } + } + return caCertKeyStorePair + } + + private fun getOrGeneratePassword(sslModeConfig: JsonNode): String { + val clientKeyPassword = + if ( + sslModeConfig.has(PARAM_CLIENT_KEY_PASSWORD) && + !sslModeConfig[PARAM_CLIENT_KEY_PASSWORD].asText().isEmpty() + ) { + sslModeConfig[PARAM_CLIENT_KEY_PASSWORD].asText() + } else { + RandomStringUtils.randomAlphanumeric(10) + } + return clientKeyPassword + } + + fun prepareClientCertificateKeyStore(config: JsonNode): Pair? { + var clientCertKeyStorePair: Pair? = null + if (Objects.nonNull(config)) { + if (!config.has(JdbcUtils.SSL_KEY) || config[JdbcUtils.SSL_KEY].asBoolean()) { + val encryption = config[JdbcUtils.SSL_MODE_KEY] + if ( + encryption.has(PARAM_CLIENT_CERTIFICATE) && + !encryption[PARAM_CLIENT_CERTIFICATE].asText().isEmpty() && + encryption.has(PARAM_CLIENT_KEY) && + !encryption[PARAM_CLIENT_KEY].asText().isEmpty() + ) { + val clientKeyPassword = getOrGeneratePassword(encryption) + try { + val clientCertKeyStoreUri = + keyStoreFromClientCertificate( + encryption[PARAM_CLIENT_CERTIFICATE].asText(), + encryption[PARAM_CLIENT_KEY].asText(), + clientKeyPassword, + null + ) + clientCertKeyStorePair = + ImmutablePair(clientCertKeyStoreUri, clientKeyPassword) + } catch (e: CertificateException) { + throw RuntimeException( + "Failed to create keystore for Client certificate", + e + ) + } catch (e: IOException) { + throw RuntimeException( + "Failed to create keystore for Client certificate", + e + ) + } catch (e: KeyStoreException) { + throw RuntimeException( + "Failed to create keystore for Client certificate", + e + ) + } catch (e: NoSuchAlgorithmException) { + throw RuntimeException( + "Failed to create keystore for Client certificate", + e + ) + } catch (e: InvalidKeySpecException) { + throw RuntimeException( + "Failed to create keystore for Client certificate", + e + ) + } catch (e: InterruptedException) { + throw RuntimeException( + "Failed to create keystore for Client certificate", + e + ) + } + } + } + } + return clientCertKeyStorePair + } + + fun fileFromCertPem(certPem: String?): Path { + try { + val path = Files.createTempFile(null, ".crt") + Files.writeString(path, certPem) + path.toFile().deleteOnExit() + return path + } catch (e: IOException) { + throw RuntimeException("Cannot save root certificate to file", e) + } + } + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcSource.kt 
b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcSource.kt new file mode 100644 index 000000000000..7e1f9b312534 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcSource.kt @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.source.jdbc + +import com.fasterxml.jackson.databind.JsonNode +import io.airbyte.cdk.db.factory.DatabaseDriver +import io.airbyte.cdk.db.jdbc.JdbcUtils +import io.airbyte.cdk.db.jdbc.streaming.AdaptiveStreamingQueryConfig +import io.airbyte.cdk.integrations.base.IntegrationRunner +import io.airbyte.cdk.integrations.base.Source +import java.sql.JDBCType +import java.util.function.Supplier +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +class JdbcSource : + AbstractJdbcSource( + DatabaseDriver.POSTGRESQL.driverClassName, + Supplier { AdaptiveStreamingQueryConfig() }, + JdbcUtils.defaultSourceOperations + ), + Source { + // no-op for JdbcSource since the config it receives is designed to be use for JDBC. + override fun toDatabaseConfig(config: JsonNode): JsonNode { + return config + } + + override val excludedInternalNameSpaces: Set + get() = setOf("information_schema", "pg_catalog", "pg_internal", "catalog_history") + + companion object { + private val LOGGER: Logger = LoggerFactory.getLogger(JdbcSource::class.java) + + @Throws(Exception::class) + @JvmStatic + fun main(args: Array) { + val source: Source = JdbcSource() + LOGGER.info("starting source: {}", JdbcSource::class.java) + IntegrationRunner(source).run(args) + LOGGER.info("completed source: {}", JdbcSource::class.java) + } + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/dto/JdbcPrivilegeDto.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/dto/JdbcPrivilegeDto.kt new file mode 100644 index 000000000000..0e689e819b3a --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/dto/JdbcPrivilegeDto.kt @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.source.jdbc.dto + +import com.google.common.base.Objects + +/** The class to store values from privileges table */ +class JdbcPrivilegeDto( + val grantee: String?, + val tableName: String?, + val schemaName: String?, + val privilege: String? +) { + class JdbcPrivilegeDtoBuilder { + private var grantee: String? = null + private var tableName: String? = null + private var schemaName: String? = null + private var privilege: String? 
= null
+
+ fun grantee(grantee: String?): JdbcPrivilegeDtoBuilder {
+ this.grantee = grantee
+ return this
+ }
+
+ fun tableName(tableName: String?): JdbcPrivilegeDtoBuilder {
+ this.tableName = tableName
+ return this
+ }
+
+ fun schemaName(schemaName: String?): JdbcPrivilegeDtoBuilder {
+ this.schemaName = schemaName
+ return this
+ }
+
+ fun privilege(privilege: String?): JdbcPrivilegeDtoBuilder {
+ this.privilege = privilege
+ return this
+ }
+
+ fun build(): JdbcPrivilegeDto {
+ return JdbcPrivilegeDto(grantee, tableName, schemaName, privilege)
+ }
+ }
+
+ override fun equals(o: Any?): Boolean {
+ if (this === o) {
+ return true
+ }
+ if (o == null || javaClass != o.javaClass) {
+ return false
+ }
+ val that = o as JdbcPrivilegeDto
+ return (Objects.equal(grantee, that.grantee) &&
+ Objects.equal(tableName, that.tableName) &&
+ Objects.equal(schemaName, that.schemaName) &&
+ Objects.equal(privilege, that.privilege))
+ }
+
+ override fun hashCode(): Int {
+ return Objects.hashCode(grantee, tableName, schemaName, privilege)
+ }
+
+ override fun toString(): String {
+ return "JdbcPrivilegeDto{" +
+ "grantee='" +
+ grantee +
+ '\'' +
+ ", tableName='" +
+ tableName +
+ '\'' +
+ ", schemaName='" +
+ schemaName +
+ '\'' +
+ ", privilege='" +
+ privilege +
+ '\'' +
+ '}'
+ }
+
+ companion object {
+ fun builder(): JdbcPrivilegeDtoBuilder {
+ return JdbcPrivilegeDtoBuilder()
+ }
+ }
+}
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/AbstractDbSource.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/AbstractDbSource.kt
new file mode 100644
index 000000000000..0dbfaeb7b9a9
--- /dev/null
+++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/AbstractDbSource.kt
@@ -0,0 +1,831 @@
+/*
+ * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
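+ *
+ * Example (illustrative, not part of this change) for the JdbcPrivilegeDto builder above:
+ * JdbcPrivilegeDto.builder().grantee("airbyte").schemaName("public").tableName("users")
+ * .privilege("SELECT").build() produces a value object whose equals/hashCode compare all four
+ * fields.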
+ */ +package io.airbyte.cdk.integrations.source.relationaldb + +import com.fasterxml.jackson.databind.JsonNode +import com.google.common.base.Preconditions +import datadog.trace.api.Trace +import io.airbyte.cdk.db.AbstractDatabase +import io.airbyte.cdk.db.IncrementalUtils.getCursorField +import io.airbyte.cdk.db.IncrementalUtils.getCursorFieldOptional +import io.airbyte.cdk.db.IncrementalUtils.getCursorType +import io.airbyte.cdk.db.jdbc.AirbyteRecordData +import io.airbyte.cdk.db.jdbc.JdbcDatabase +import io.airbyte.cdk.integrations.JdbcConnector +import io.airbyte.cdk.integrations.base.AirbyteTraceMessageUtility +import io.airbyte.cdk.integrations.base.AirbyteTraceMessageUtility.emitConfigErrorTrace +import io.airbyte.cdk.integrations.base.Source +import io.airbyte.cdk.integrations.base.errors.messages.ErrorMessage.getErrorMessage +import io.airbyte.cdk.integrations.source.relationaldb.state.* +import io.airbyte.cdk.integrations.util.ApmTraceUtils.addExceptionToTrace +import io.airbyte.cdk.integrations.util.ConnectorExceptionUtil +import io.airbyte.commons.exceptions.ConfigErrorException +import io.airbyte.commons.exceptions.ConnectionErrorException +import io.airbyte.commons.features.EnvVariableFeatureFlags +import io.airbyte.commons.features.FeatureFlags +import io.airbyte.commons.functional.CheckedConsumer +import io.airbyte.commons.lang.Exceptions +import io.airbyte.commons.stream.AirbyteStreamUtils +import io.airbyte.commons.util.AutoCloseableIterator +import io.airbyte.commons.util.AutoCloseableIterators +import io.airbyte.protocol.models.AirbyteStreamNameNamespacePair +import io.airbyte.protocol.models.CommonField +import io.airbyte.protocol.models.JsonSchemaType +import io.airbyte.protocol.models.v0.* +import java.sql.SQLException +import java.time.Duration +import java.time.Instant +import java.util.* +import java.util.concurrent.atomic.AtomicLong +import java.util.function.Function +import java.util.stream.Collectors +import java.util.stream.Stream +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +/** + * This class contains helper functions and boilerplate for implementing a source connector for a DB + * source of both non-relational and relational type + */ +abstract class AbstractDbSource +protected constructor(driverClassName: String) : + JdbcConnector(driverClassName), Source, AutoCloseable { + // TODO: Remove when the flag is not use anymore + protected var featureFlags: FeatureFlags = EnvVariableFeatureFlags() + + @Trace(operationName = CHECK_TRACE_OPERATION_NAME) + @Throws(Exception::class) + override fun check(config: JsonNode): AirbyteConnectionStatus? 
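+ // Illustrative flow (not part of this change): check() runs every consumer returned by
+ // getCheckOperations(config) against a freshly created database; a ConnectionErrorException
+ // is mapped to a FAILED status with a user-facing message, any other exception to a generic
+ // FAILED message, and the connection pool is closed in the finally block either way.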
{ + try { + val database = createDatabase(config) + for (checkOperation in getCheckOperations(config)) { + checkOperation.accept(database) + } + + return AirbyteConnectionStatus().withStatus(AirbyteConnectionStatus.Status.SUCCEEDED) + } catch (ex: ConnectionErrorException) { + addExceptionToTrace(ex) + val message = getErrorMessage(ex.stateCode, ex.errorCode, ex.exceptionMessage, ex) + emitConfigErrorTrace(ex, message) + return AirbyteConnectionStatus() + .withStatus(AirbyteConnectionStatus.Status.FAILED) + .withMessage(message) + } catch (e: Exception) { + addExceptionToTrace(e) + LOGGER.info("Exception while checking connection: ", e) + return AirbyteConnectionStatus() + .withStatus(AirbyteConnectionStatus.Status.FAILED) + .withMessage( + String.format( + ConnectorExceptionUtil.COMMON_EXCEPTION_MESSAGE_TEMPLATE, + e.message + ) + ) + } finally { + close() + } + } + + @Trace(operationName = DISCOVER_TRACE_OPERATION_NAME) + @Throws(Exception::class) + override fun discover(config: JsonNode): AirbyteCatalog? { + try { + val database = createDatabase(config) + val tableInfos = discoverWithoutSystemTables(database) + val fullyQualifiedTableNameToPrimaryKeys = discoverPrimaryKeys(database, tableInfos) + return DbSourceDiscoverUtil.convertTableInfosToAirbyteCatalog( + tableInfos, + fullyQualifiedTableNameToPrimaryKeys + ) { columnType: DataType -> this.getAirbyteType(columnType) } + } finally { + close() + } + } + + /** + * Creates a list of AirbyteMessageIterators with all the streams selected in a configured + * catalog + * + * @param config + * - integration-specific configuration object as json. e.g. { "username": "airbyte", + * "password": "super secure" } + * @param catalog + * - schema of the incoming messages. + * @param state + * - state of the incoming messages. + * @return AirbyteMessageIterator with all the streams that are to be synced + * @throws Exception + */ + @Throws(Exception::class) + override fun read( + config: JsonNode, + catalog: ConfiguredAirbyteCatalog, + state: JsonNode? 
+ ): AutoCloseableIterator { + val supportedStateType = getSupportedStateType(config) + val stateManager = + StateManagerFactory.createStateManager( + supportedStateType, + StateGeneratorUtils.deserializeInitialState(state, supportedStateType), + catalog + ) + val emittedAt = Instant.now() + + val database = createDatabase(config) + + logPreSyncDebugData(database, catalog) + + val fullyQualifiedTableNameToInfo = + discoverWithoutSystemTables(database) + .stream() + .collect( + Collectors.toMap( + Function { t: TableInfo> -> + String.format("%s.%s", t.nameSpace, t.name) + }, + Function.identity() + ) + ) + + validateCursorFieldForIncrementalTables(fullyQualifiedTableNameToInfo, catalog, database) + + DbSourceDiscoverUtil.logSourceSchemaChange(fullyQualifiedTableNameToInfo, catalog) { + columnType: DataType -> + this.getAirbyteType(columnType) + } + + val incrementalIterators = + getIncrementalIterators( + database, + catalog, + fullyQualifiedTableNameToInfo, + stateManager, + emittedAt + ) + val fullRefreshIterators = + getFullRefreshIterators( + database, + catalog, + fullyQualifiedTableNameToInfo, + stateManager, + emittedAt + ) + val iteratorList = + Stream.of(incrementalIterators, fullRefreshIterators) + .flatMap(Collection>::stream) + .collect(Collectors.toList()) + + return AutoCloseableIterators.appendOnClose( + AutoCloseableIterators.concatWithEagerClose( + iteratorList, + AirbyteTraceMessageUtility::emitStreamStatusTrace + ) + ) { + LOGGER.info("Closing database connection pool.") + Exceptions.toRuntime { this.close() } + LOGGER.info("Closed database connection pool.") + } + } + + @Throws(SQLException::class) + protected fun validateCursorFieldForIncrementalTables( + tableNameToTable: Map>>, + catalog: ConfiguredAirbyteCatalog, + database: Database + ) { + val tablesWithInvalidCursor: MutableList = + ArrayList() + for (airbyteStream in catalog.streams) { + val stream = airbyteStream.stream + val fullyQualifiedTableName = + DbSourceDiscoverUtil.getFullyQualifiedTableName(stream.namespace, stream.name) + val hasSourceDefinedCursor = + (!Objects.isNull(airbyteStream.stream.sourceDefinedCursor) && + airbyteStream.stream.sourceDefinedCursor) + if ( + !tableNameToTable.containsKey(fullyQualifiedTableName) || + airbyteStream.syncMode != SyncMode.INCREMENTAL || + hasSourceDefinedCursor + ) { + continue + } + + val table = tableNameToTable[fullyQualifiedTableName]!! + val cursorField = getCursorFieldOptional(airbyteStream) + if (cursorField.isEmpty) { + continue + } + val cursorType = + table.fields!! + .stream() + .filter { info: CommonField -> info.name == cursorField.get() } + .map { obj: CommonField -> obj.type } + .findFirst() + .orElseThrow() + + if (!isCursorType(cursorType)) { + tablesWithInvalidCursor.add( + InvalidCursorInfoUtil.InvalidCursorInfo( + fullyQualifiedTableName, + cursorField.get(), + cursorType.toString(), + "Unsupported cursor type" + ) + ) + continue + } + + if ( + !verifyCursorColumnValues( + database, + stream.namespace, + stream.name, + cursorField.get() + ) + ) { + tablesWithInvalidCursor.add( + InvalidCursorInfoUtil.InvalidCursorInfo( + fullyQualifiedTableName, + cursorField.get(), + cursorType.toString(), + "Cursor column contains NULL value" + ) + ) + } + } + + if (!tablesWithInvalidCursor.isEmpty()) { + throw ConfigErrorException( + InvalidCursorInfoUtil.getInvalidCursorConfigMessage(tablesWithInvalidCursor) + ) + } + } + + /** + * Verify that cursor column allows syncing to go through. 
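+ *
+ * Illustrative override (not part of this change; the query shape is hypothetical): a
+ * dialect-specific source might run "SELECT COUNT(*) FROM schema.table WHERE cursor_col IS
+ * NULL" and return false when the count is non-zero, so that
+ * validateCursorFieldForIncrementalTables fails fast instead of syncing with NULL cursor
+ * values.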
+ * + * @param database database + * @return true if syncing can go through. false otherwise + * @throws SQLException exception + */ + @Throws(SQLException::class) + protected fun verifyCursorColumnValues( + database: Database, + schema: String?, + tableName: String?, + columnName: String? + ): Boolean { + /* no-op */ + return true + } + + /** + * Estimates the total volume (rows and bytes) to sync and emits a [AirbyteEstimateTraceMessage] + * associated with the full refresh stream. + * + * @param database database + */ + protected fun estimateFullRefreshSyncSize( + database: Database, + configuredAirbyteStream: ConfiguredAirbyteStream? + ) { + /* no-op */ + } + + @Throws(Exception::class) + protected fun discoverWithoutSystemTables( + database: Database + ): List>> { + val systemNameSpaces = excludedInternalNameSpaces + val systemViews = excludedViews + val discoveredTables = discoverInternal(database) + return (if (systemNameSpaces == null || systemNameSpaces.isEmpty()) discoveredTables + else + discoveredTables + .stream() + .filter { table: TableInfo> -> + !systemNameSpaces.contains(table.nameSpace) && !systemViews.contains(table.name) + } + .collect(Collectors.toList())) + } + + protected fun getFullRefreshIterators( + database: Database, + catalog: ConfiguredAirbyteCatalog, + tableNameToTable: Map>>, + stateManager: StateManager?, + emittedAt: Instant + ): List> { + return getSelectedIterators( + database, + catalog, + tableNameToTable, + stateManager, + emittedAt, + SyncMode.FULL_REFRESH + ) + } + + protected fun getIncrementalIterators( + database: Database, + catalog: ConfiguredAirbyteCatalog, + tableNameToTable: Map>>, + stateManager: StateManager?, + emittedAt: Instant + ): List> { + return getSelectedIterators( + database, + catalog, + tableNameToTable, + stateManager, + emittedAt, + SyncMode.INCREMENTAL + ) + } + + /** + * Creates a list of read iterators for each stream within an ConfiguredAirbyteCatalog + * + * @param database Source Database + * @param catalog List of streams (e.g. database tables or API endpoints) with settings on sync + * mode + * @param tableNameToTable Mapping of table name to table + * @param stateManager Manager used to track the state of data synced by the connector + * @param emittedAt Time when data was emitted from the Source database + * @param syncMode the sync mode for which we want to grab the required iterators + * @return List of AirbyteMessageIterators containing all iterators for a catalog + */ + private fun getSelectedIterators( + database: Database, + catalog: ConfiguredAirbyteCatalog?, + tableNameToTable: Map>>, + stateManager: StateManager?, + emittedAt: Instant, + syncMode: SyncMode + ): List> { + val iteratorList: MutableList> = ArrayList() + for (airbyteStream in catalog!!.streams) { + if (airbyteStream.syncMode == syncMode) { + val stream = airbyteStream.stream + val fullyQualifiedTableName = + DbSourceDiscoverUtil.getFullyQualifiedTableName(stream.namespace, stream.name) + if (!tableNameToTable.containsKey(fullyQualifiedTableName)) { + LOGGER.info( + "Skipping stream {} because it is not in the source", + fullyQualifiedTableName + ) + continue + } + + val table = tableNameToTable[fullyQualifiedTableName]!! 
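+ // Illustrative example (not part of this change): for a catalog containing public.users
+ // (INCREMENTAL) and public.orders (FULL_REFRESH), a call with syncMode == SyncMode.INCREMENTAL
+ // returns only the read iterator for public.users; public.orders is picked up by the separate
+ // FULL_REFRESH pass in getFullRefreshIterators.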
+ val tableReadIterator =
+ createReadIterator(database, airbyteStream, table, stateManager, emittedAt)
+ iteratorList.add(tableReadIterator)
+ }
+ }
+
+ return iteratorList
+ }
+
+ /**
+ * ReadIterator is used to retrieve records from a source connector
+ *
+ * @param database Source Database
+ * @param airbyteStream represents an ingestion source (e.g. API endpoint or database table)
+ * @param table information in tabular format
+ * @param stateManager Manager used to track the state of data synced by the connector
+ * @param emittedAt Time when data was emitted from the Source database
+ * @return an AirbyteMessage iterator for the given stream
+ */
+ private fun createReadIterator(
+ database: Database,
+ airbyteStream: ConfiguredAirbyteStream,
+ table: TableInfo<CommonField<DataType>>,
+ stateManager: StateManager?,
+ emittedAt: Instant
+ ): AutoCloseableIterator<AirbyteMessage> {
+ val streamName = airbyteStream.stream.name
+ val namespace = airbyteStream.stream.namespace
+ val pair =
+ io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair(streamName, namespace)
+ val selectedFieldsInCatalog = CatalogHelpers.getTopLevelFieldNames(airbyteStream)
+ val selectedDatabaseFields =
+ table.fields
+ .stream()
+ .map { obj: CommonField<DataType> -> obj.name }
+ .filter { o: String -> selectedFieldsInCatalog.contains(o) }
+ .collect(Collectors.toList())
+
+ val iterator: AutoCloseableIterator<AirbyteMessage>
+ // Check which sync mode we're using based on the configured Airbyte stream; this is
+ // where full refresh and incremental reads diverge.
+ if (airbyteStream.syncMode == SyncMode.INCREMENTAL) {
+ val cursorField = getCursorField(airbyteStream)
+ val cursorInfo = stateManager!!.getCursorInfo(pair)
+
+ val airbyteMessageIterator: AutoCloseableIterator<AirbyteMessage>
+ if (cursorInfo!!.map { it.cursor }.isPresent) {
+ airbyteMessageIterator =
+ getIncrementalStream(
+ database,
+ airbyteStream,
+ selectedDatabaseFields,
+ table,
+ cursorInfo.get(),
+ emittedAt
+ )
+ } else {
+ // if no cursor is present then this is the first read, which is the same as doing
+ // a full refresh read.
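+ // Illustrative example (not part of this change; table and column names are hypothetical):
+ // on the very first incremental sync there is no saved cursor, so the stream falls back to a
+ // full-refresh-style query such as "SELECT <fields> FROM public.users ORDER BY updated_at ASC";
+ // on later syncs queryTableIncremental adds "WHERE updated_at > ?" bound to the saved cursor.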
+                estimateFullRefreshSyncSize(database, airbyteStream)
+                airbyteMessageIterator =
+                    getFullRefreshStream(
+                        database,
+                        streamName,
+                        namespace,
+                        selectedDatabaseFields,
+                        table,
+                        emittedAt,
+                        SyncMode.INCREMENTAL,
+                        Optional.of(cursorField)
+                    )
+            }
+
+            val cursorType = getCursorType(airbyteStream, cursorField)
+
+            val messageProducer =
+                CursorStateMessageProducer(stateManager, cursorInfo.map { it.cursor })
+
+            iterator =
+                AutoCloseableIterators.transform(
+                    { autoCloseableIterator: AutoCloseableIterator<AirbyteMessage> ->
+                        SourceStateIterator(
+                            autoCloseableIterator,
+                            airbyteStream,
+                            messageProducer,
+                            StateEmitFrequency(stateEmissionFrequency.toLong(), Duration.ZERO)
+                        )
+                    },
+                    airbyteMessageIterator,
+                    AirbyteStreamUtils.convertFromNameAndNamespace(pair.name, pair.namespace)
+                )
+        } else if (airbyteStream.syncMode == SyncMode.FULL_REFRESH) {
+            estimateFullRefreshSyncSize(database, airbyteStream)
+            iterator =
+                getFullRefreshStream(
+                    database,
+                    streamName,
+                    namespace,
+                    selectedDatabaseFields,
+                    table,
+                    emittedAt,
+                    SyncMode.FULL_REFRESH,
+                    Optional.empty()
+                )
+        } else if (airbyteStream.syncMode == null) {
+            throw IllegalArgumentException(
+                String.format("%s requires a source sync mode", this.javaClass)
+            )
+        } else {
+            throw IllegalArgumentException(
+                String.format(
+                    "%s does not support sync mode: %s.",
+                    this.javaClass,
+                    airbyteStream.syncMode
+                )
+            )
+        }
+
+        val recordCount = AtomicLong()
+        return AutoCloseableIterators.transform(
+            iterator,
+            AirbyteStreamUtils.convertFromNameAndNamespace(pair.name, pair.namespace)
+        ) { r: AirbyteMessage? ->
+            val count = recordCount.incrementAndGet()
+            if (count % 10000 == 0L) {
+                LOGGER.info("Reading stream {}. Records read: {}", streamName, count)
+            }
+            r
+        }
+    }
+
+    /**
+     * @param database Source Database
+     * @param airbyteStream represents an ingestion source (e.g. API endpoint or database table)
+     * @param selectedDatabaseFields subset of database fields selected for replication
+     * @param table information in tabular format
+     * @param cursorInfo state of where to start the sync from
+     * @param emittedAt Time when data was emitted from the Source database
+     * @return AirbyteMessage iterator with records that satisfy the incremental cursor condition
+     */
+    private fun getIncrementalStream(
+        database: Database,
+        airbyteStream: ConfiguredAirbyteStream,
+        selectedDatabaseFields: List<String>,
+        table: TableInfo<CommonField<DataType>>,
+        cursorInfo: CursorInfo,
+        emittedAt: Instant
+    ): AutoCloseableIterator<AirbyteMessage> {
+        val streamName = airbyteStream.stream.name
+        val namespace = airbyteStream.stream.namespace
+        val cursorField = getCursorField(airbyteStream)
+        val cursorType =
+            table.fields
+                .stream()
+                .filter { info: CommonField<DataType> -> info.name == cursorField }
+                .map { obj: CommonField<DataType> -> obj.type }
+                .findFirst()
+                .orElseThrow()
+
+        Preconditions.checkState(
+            table.fields.stream().anyMatch { f: CommonField<DataType> -> f.name == cursorField },
+            String.format("Could not find cursor field %s in table %s", cursorField, table.name)
+        )
+
+        val queryIterator =
+            queryTableIncremental(
+                database,
+                selectedDatabaseFields,
+                table.nameSpace,
+                table.name,
+                cursorInfo,
+                cursorType
+            )
+
+        return getMessageIterator(queryIterator, streamName, namespace, emittedAt.toEpochMilli())
+    }
+
+    /**
+     * Creates an AirbyteMessageIterator that contains all records for a database source connection
+     *
+     * @param database Source Database
+     * @param streamName name of an individual stream in which a stream represents a source (e.g.
+     * API endpoint or database table)
+     * @param namespace Namespace of the database (e.g. public)
+     * @param selectedDatabaseFields List of all interested database column names
+     * @param table information in tabular format
+     * @param emittedAt Time when data was emitted from the Source database
+     * @param syncMode The sync mode that this full refresh stream should be associated with.
+     * @return AirbyteMessageIterator with all records for a database source
+     */
+    private fun getFullRefreshStream(
+        database: Database,
+        streamName: String,
+        namespace: String,
+        selectedDatabaseFields: List<String>,
+        table: TableInfo<CommonField<DataType>>,
+        emittedAt: Instant,
+        syncMode: SyncMode,
+        cursorField: Optional<String>
+    ): AutoCloseableIterator<AirbyteMessage> {
+        val queryStream =
+            queryTableFullRefresh(
+                database,
+                selectedDatabaseFields,
+                table.nameSpace,
+                table.name,
+                syncMode,
+                cursorField
+            )
+        return getMessageIterator(queryStream, streamName, namespace, emittedAt.toEpochMilli())
+    }
+
+    /**
+     * @param database
+     * - The database from which privileges for tables will be consumed
+     * @param schema
+     * - The schema from which privileges for tables will be consumed
+     * @return Set with privileges for tables for the current DB-session user. The method is
+     * responsible for SELECT-ing the table with privileges. In some cases such a SELECT is not
+     * required (e.g. in Oracle DB the schema is the user, and you cannot REVOKE a privilege on a
+     * table from its owner).
+     */
+    @Throws(SQLException::class)
+    protected fun getPrivilegesTableForCurrentUser(
+        database: JdbcDatabase?,
+        schema: String?
+    ): Set<JdbcPrivilegeDto> {
+        return emptySet()
+    }
+
+    /**
+     * Map a database implementation-specific configuration to a JSON object that adheres to the
+     * database config spec. See resources/spec.json.
+     *
+     * @param config database implementation-specific configuration.
+     * @return database spec config
+     */
+    @Trace(operationName = DISCOVER_TRACE_OPERATION_NAME)
+    abstract fun toDatabaseConfig(config: JsonNode): JsonNode
+
+    /**
+     * Creates a database instance using the database spec config.
+     *
+     * @param config database spec config
+     * @return database instance
+     * @throws Exception might throw an error during connection to database
+     */
+    @Trace(operationName = DISCOVER_TRACE_OPERATION_NAME)
+    @Throws(Exception::class)
+    protected abstract fun createDatabase(config: JsonNode): Database
+
+    /**
+     * Gets and logs relevant and useful database metadata such as DB product/version, index names
+     * and definition. Called before syncing data. Any logged information should be scoped to the
+     * configured catalog and database.
+     *
+     * @param database given database instance.
+     * @param catalog configured catalog.
+     */
+    @Throws(Exception::class)
+    protected open fun logPreSyncDebugData(
+        database: Database,
+        catalog: ConfiguredAirbyteCatalog?
+    ) {}
+
+    /**
+     * Configures a list of operations that can be used to check the connection to the source.
+     *
+     * @return list of consumers that run queries for the check command.
+     */
+    @Throws(Exception::class)
+    protected abstract fun getCheckOperations(
+        config: JsonNode?
+    ): List<CheckedConsumer<Database, Exception>>
+
+    /**
+     * Map source types to Airbyte types
+     *
+     * @param columnType source data type
+     * @return airbyte data type
+     */
+    protected abstract fun getAirbyteType(columnType: DataType): JsonSchemaType
+
+    protected abstract val excludedInternalNameSpaces: Set<String>
+    /**
+     * Get the list of system namespaces (schemas) in order to exclude them from the `discover`
+     * result list.
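The check-operation contract above is a list of consumers that each run a probe against the source and throw on failure. A hedged, self-contained sketch of that pattern follows; the probe queries, `check` conditions, and JDBC URL are illustrative assumptions, not CDK defaults.

```kotlin
import java.sql.Connection
import java.sql.DriverManager

// Each operation throws on failure, in the spirit of getCheckOperations above.
fun checkOperations(): List<(Connection) -> Unit> =
    listOf(
        { conn -> conn.createStatement().use { it.execute("SELECT 1") } }, // basic connectivity
        { conn -> check(conn.metaData.databaseProductName.isNotBlank()) }  // metadata is readable
    )

// Runs every check against a fresh connection; any failure propagates to the caller.
fun runCheck(jdbcUrl: String) {
    DriverManager.getConnection(jdbcUrl).use { conn ->
        checkOperations().forEach { op -> op(conn) }
    }
}
```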
+     *
+     * @return set of system namespaces (schemas) to be excluded
+     */
+    get
+
+    protected val excludedViews: Set<String>
+        /**
+         * Get the list of system views in order to exclude them from the `discover` result list.
+         *
+         * @return set of views to be excluded
+         */
+        get() = emptySet()
+
+    /**
+     * Discover all available tables in the source database.
+     *
+     * @param database source database
+     * @return list of the source tables
+     * @throws Exception access to the database might lead to an exception.
+     */
+    @Trace(operationName = DISCOVER_TRACE_OPERATION_NAME)
+    @Throws(Exception::class)
+    protected abstract fun discoverInternal(
+        database: Database
+    ): List<TableInfo<CommonField<DataType>>>
+
+    /**
+     * Discovers all available tables within a schema in the source database.
+     *
+     * @param database
+     * - source database
+     * @param schema
+     * - source schema
+     * @return list of source tables
+     * @throws Exception
+     * - access to the database might lead to exceptions.
+     */
+    @Throws(Exception::class)
+    protected abstract fun discoverInternal(
+        database: Database,
+        schema: String?
+    ): List<TableInfo<CommonField<DataType>>>
+
+    /**
+     * Discovers primary keys for each table and returns a map of namespace.table name to the
+     * associated list of primary key fields.
+     *
+     * @param database source database
+     * @param tableInfos list of tables
+     * @return map of namespace.table and primary key fields.
+     */
+    protected abstract fun discoverPrimaryKeys(
+        database: Database,
+        tableInfos: List<TableInfo<CommonField<DataType>>>
+    ): Map<String, MutableList<String>>
+
+    protected abstract val quoteString: String?
+    /**
+     * Returns the quote symbol of the database
+     *
+     * @return quote symbol
+     */
+    get
+
+    /**
+     * Read all data from a table.
+     *
+     * @param database source database
+     * @param columnNames interested column names
+     * @param schemaName table namespace
+     * @param tableName target table
+     * @param syncMode The sync mode that this full refresh stream should be associated with.
+     * @return iterator with read data
+     */
+    protected abstract fun queryTableFullRefresh(
+        database: Database,
+        columnNames: List<String>,
+        schemaName: String?,
+        tableName: String,
+        syncMode: SyncMode,
+        cursorField: Optional<String>
+    ): AutoCloseableIterator<AirbyteRecordData>
+
+    /**
+     * Read incremental data from a table. Incremental read should return only records where the
+     * cursor column value is greater than the cursor. Note that if the connector needs to emit
+     * intermediate state (i.e. [AbstractDbSource.getStateEmissionFrequency] > 0), the incremental
+     * query must be sorted by the cursor field.
+     *
+     * @return iterator with read data
+     */
+    protected abstract fun queryTableIncremental(
+        database: Database,
+        columnNames: List<String>,
+        schemaName: String?,
+        tableName: String,
+        cursorInfo: CursorInfo,
+        cursorFieldType: DataType
+    ): AutoCloseableIterator<AirbyteRecordData>
+
+    protected val stateEmissionFrequency: Int
+        /**
+         * When larger than 0, the incremental iterator will emit intermediate state for every N
+         * records. Please note that if intermediate state emission is enabled, the incremental
+         * query must be ordered by the cursor field.
+         *
+         * TODO: Return an optional value instead of 0 to make it easier to understand.
+         */
+        get() = 0
+
+    /** @return true if the type can be used as a cursor */
+    protected abstract fun isCursorType(type: DataType): Boolean
+
+    /**
+     * Returns the [AirbyteStateType] supported by this connector.
+     *
+     * @param config The connector configuration.
+     * @return A [AirbyteStateType] representing the state supported by this connector.
+     */
+    protected open fun getSupportedStateType(
+        config: JsonNode?
+ ): AirbyteStateMessage.AirbyteStateType { + return AirbyteStateMessage.AirbyteStateType.STREAM + } + + companion object { + const val CHECK_TRACE_OPERATION_NAME: String = "check-operation" + const val DISCOVER_TRACE_OPERATION_NAME: String = "discover-operation" + const val READ_TRACE_OPERATION_NAME: String = "read-operation" + + @JvmStatic + protected val LOGGER: Logger = LoggerFactory.getLogger(AbstractDbSource::class.java) + + private fun getMessageIterator( + recordIterator: AutoCloseableIterator, + streamName: String, + namespace: String, + emittedAt: Long + ): AutoCloseableIterator { + return AutoCloseableIterators.transform( + recordIterator, + AirbyteStreamNameNamespacePair(streamName, namespace) + ) { airbyteRecordData -> + AirbyteMessage() + .withType(AirbyteMessage.Type.RECORD) + .withRecord( + AirbyteRecordMessage() + .withStream(streamName) + .withNamespace(namespace) + .withEmittedAt(emittedAt) + .withData(airbyteRecordData.rawRowData) + .withMeta( + if (isMetaChangesEmptyOrNull(airbyteRecordData.meta)) null + else airbyteRecordData.meta + ) + ) + } + } + + private fun isMetaChangesEmptyOrNull(meta: AirbyteRecordMessageMeta?): Boolean { + return meta == null || meta.changes == null || meta.changes.isEmpty() + } + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/CdcStateManager.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/CdcStateManager.kt new file mode 100644 index 000000000000..7662c0fb9878 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/CdcStateManager.kt @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.source.relationaldb + +import io.airbyte.cdk.integrations.source.relationaldb.models.CdcState +import io.airbyte.commons.json.Jsons +import io.airbyte.protocol.models.v0.AirbyteStateMessage +import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair +import java.util.* +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +class CdcStateManager( + private val initialState: CdcState?, + initialStreamsSynced: Set?, + stateMessage: AirbyteStateMessage? +) { + val initialStreamsSynced: Set? + val rawStateMessage: AirbyteStateMessage? + private var currentState: CdcState? + + init { + this.currentState = initialState + this.initialStreamsSynced = + if (initialStreamsSynced != null) Collections.unmodifiableSet(initialStreamsSynced) + else null + this.rawStateMessage = stateMessage + LOGGER.info("Initialized CDC state") + } + + var cdcState: CdcState? 
+ get() = if (currentState != null) Jsons.clone(currentState) else null + set(state) { + this.currentState = state + } + + override fun toString(): String { + return "CdcStateManager{" + + "initialState=" + + initialState + + ", currentState=" + + currentState + + '}' + } + + companion object { + private val LOGGER: Logger = LoggerFactory.getLogger(CdcStateManager::class.java) + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/CursorInfo.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/CursorInfo.kt new file mode 100644 index 000000000000..b4e4721d1bb1 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/CursorInfo.kt @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.source.relationaldb + +import java.util.* + +class CursorInfo( + val originalCursorField: String?, + val originalCursor: String?, + val originalCursorRecordCount: Long, + val cursorField: String?, + var cursor: String?, + var cursorRecordCount: Long +) { + constructor( + originalCursorField: String?, + originalCursor: String?, + cursorField: String?, + cursor: String? + ) : this(originalCursorField, originalCursor, 0L, cursorField, cursor, 0L) + + fun setCursor(cursor: String?): CursorInfo { + this.cursor = cursor + return this + } + + fun setCursorRecordCount(cursorRecordCount: Long): CursorInfo { + this.cursorRecordCount = cursorRecordCount + return this + } + + override fun equals(o: Any?): Boolean { + if (this === o) { + return true + } + if (o == null || javaClass != o.javaClass) { + return false + } + val that = o as CursorInfo + return originalCursorField == that.originalCursorField && + originalCursor == that.originalCursor && + originalCursorRecordCount == that.originalCursorRecordCount && + cursorField == that.cursorField && + cursor == that.cursor && + cursorRecordCount == that.cursorRecordCount + } + + override fun hashCode(): Int { + return Objects.hash( + originalCursorField, + originalCursor, + originalCursorRecordCount, + cursorField, + cursor, + cursorRecordCount + ) + } + + override fun toString(): String { + return "CursorInfo{" + + "originalCursorField='" + + originalCursorField + + '\'' + + ", originalCursor='" + + originalCursor + + '\'' + + ", originalCursorRecordCount='" + + originalCursorRecordCount + + '\'' + + ", cursorField='" + + cursorField + + '\'' + + ", cursor='" + + cursor + + '\'' + + ", cursorRecordCount='" + + cursorRecordCount + + '\'' + + '}' + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/DbSourceDiscoverUtil.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/DbSourceDiscoverUtil.kt new file mode 100644 index 000000000000..4bf46677fd3b --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/DbSourceDiscoverUtil.kt @@ -0,0 +1,210 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+ */
+package io.airbyte.cdk.integrations.source.relationaldb
+
+import com.google.common.collect.Lists
+import io.airbyte.protocol.models.CommonField
+import io.airbyte.protocol.models.Field
+import io.airbyte.protocol.models.JsonSchemaType
+import io.airbyte.protocol.models.v0.AirbyteCatalog
+import io.airbyte.protocol.models.v0.CatalogHelpers
+import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog
+import io.airbyte.protocol.models.v0.SyncMode
+import java.util.*
+import java.util.function.Consumer
+import java.util.function.Function
+import java.util.stream.Collectors
+import org.slf4j.Logger
+import org.slf4j.LoggerFactory
+
+/** Contains utilities and helper classes for discovering schemas in database sources. */
+object DbSourceDiscoverUtil {
+    private val LOGGER: Logger = LoggerFactory.getLogger(DbSourceDiscoverUtil::class.java)
+    private val AIRBYTE_METADATA: List<String> =
+        mutableListOf("_ab_cdc_lsn", "_ab_cdc_updated_at", "_ab_cdc_deleted_at")
+
+    /*
+     * This method logs schema drift between the source table and the catalog. This can happen if
+     * (i) the underlying table schema changed between syncs or (ii) the source connector's mapping
+     * of datatypes to Airbyte types changed between runs.
+     */
+    fun <T> logSourceSchemaChange(
+        fullyQualifiedTableNameToInfo: Map<String, TableInfo<CommonField<T>>>,
+        catalog: ConfiguredAirbyteCatalog,
+        airbyteTypeConverter: Function<T, JsonSchemaType>
+    ) {
+        for (airbyteStream in catalog.streams) {
+            val stream = airbyteStream.stream
+            val fullyQualifiedTableName = getFullyQualifiedTableName(stream.namespace, stream.name)
+            if (!fullyQualifiedTableNameToInfo.containsKey(fullyQualifiedTableName)) {
+                continue
+            }
+            val table = fullyQualifiedTableNameToInfo[fullyQualifiedTableName]!!
+            val fields =
+                table.fields
+                    .stream()
+                    .map { commonField: CommonField<T> ->
+                        toField(commonField, airbyteTypeConverter)
+                    }
+                    .distinct()
+                    .collect(Collectors.toList())
+            val currentJsonSchema = CatalogHelpers.fieldsToJsonSchema(fields)
+            val catalogSchema = stream.jsonSchema
+            val currentSchemaProperties = currentJsonSchema["properties"]
+            val catalogProperties = catalogSchema["properties"]
+            val mismatchedFields: MutableList<String> = ArrayList()
+            catalogProperties.fieldNames().forEachRemaining { fieldName: String ->
+                // Ignoring metadata fields since those are automatically added onto the catalog
+                // schema by Airbyte and don't exist in the source schema. They should not be
+                // considered a change.
+                if (AIRBYTE_METADATA.contains(fieldName)) {
+                    return@forEachRemaining
+                }
+                if (
+                    !currentSchemaProperties.has(fieldName) ||
+                        currentSchemaProperties[fieldName] != catalogProperties[fieldName]
+                ) {
+                    mismatchedFields.add(fieldName)
+                }
+            }
+
+            if (!mismatchedFields.isEmpty()) {
+                LOGGER.warn(
+                    "Source schema changed for table {}! Potential mismatches: {}. Actual schema: {}. Catalog schema: {}",
+                    fullyQualifiedTableName,
+                    java.lang.String.join(", ", mismatchedFields),
+                    currentJsonSchema,
+                    catalogSchema
+                )
+            }
+        }
+    }
+
+    fun <T> convertTableInfosToAirbyteCatalog(
+        tableInfos: List<TableInfo<CommonField<T>>>,
+        fullyQualifiedTableNameToPrimaryKeys: Map<String, MutableList<String>>,
+        airbyteTypeConverter: Function<T, JsonSchemaType>
+    ): AirbyteCatalog {
+        val tableInfoFieldList =
+            tableInfos
+                .stream()
+                .map { t: TableInfo<CommonField<T>> ->
+                    // Some databases return multiple copies of the same record for a column
+                    // (e.g. Redshift) because they have at-least-once delivery guarantees. We
+                    // want to dedupe these, but first we check that the records are actually
+                    // the same and provide a good error message if they are not.
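The drift check above compares the `properties` of the catalog schema against the freshly discovered schema, ignoring Airbyte's own metadata columns. A simplified model of that comparison, with plain maps standing in for the `JsonNode` property objects:

```kotlin
// Assumed simplification: property name -> type string, instead of JsonNode subtrees.
val airbyteMetadata = setOf("_ab_cdc_lsn", "_ab_cdc_updated_at", "_ab_cdc_deleted_at")

// A field mismatches when it is missing from, or typed differently in, the current schema.
fun mismatchedFields(
    catalogProperties: Map<String, String>,
    currentProperties: Map<String, String>
): List<String> =
    catalogProperties.keys
        .filter { it !in airbyteMetadata }
        .filter { field -> currentProperties[field] != catalogProperties[field] }
```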
+ assertColumnsWithSameNameAreSame(t.nameSpace, t.name, t.fields) + val fields = + t.fields + .stream() + .map { commonField: CommonField -> + toField(commonField, airbyteTypeConverter) + } + .distinct() + .collect(Collectors.toList()) + val fullyQualifiedTableName = getFullyQualifiedTableName(t.nameSpace, t.name) + val primaryKeys = + fullyQualifiedTableNameToPrimaryKeys.getOrDefault( + fullyQualifiedTableName, + emptyList() + ) + TableInfo( + nameSpace = t.nameSpace, + name = t.name, + fields = fields, + primaryKeys = primaryKeys, + cursorFields = t.cursorFields + ) + } + .collect(Collectors.toList()) + + val streams = + tableInfoFieldList + .stream() + .map { tableInfo: TableInfo -> + val primaryKeys = + tableInfo.primaryKeys + .stream() + .filter { obj: String? -> Objects.nonNull(obj) } + .map { listOf(it) } + .toList() + CatalogHelpers.createAirbyteStream( + tableInfo.name, + tableInfo.nameSpace, + tableInfo.fields + ) + .withSupportedSyncModes( + if (tableInfo.cursorFields != null && tableInfo.cursorFields.isEmpty()) + Lists.newArrayList(SyncMode.FULL_REFRESH) + else Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL) + ) + .withSourceDefinedPrimaryKey(primaryKeys) + } + .collect(Collectors.toList()) + return AirbyteCatalog().withStreams(streams) + } + + fun getFullyQualifiedTableName(nameSpace: String?, tableName: String): String { + return if (nameSpace != null) "$nameSpace.$tableName" else tableName + } + + private fun toField( + commonField: CommonField, + airbyteTypeConverter: Function + ): Field { + if ( + airbyteTypeConverter.apply(commonField.type) === JsonSchemaType.OBJECT && + commonField.properties != null && + !commonField.properties.isEmpty() + ) { + val properties = + commonField.properties + .stream() + .map { commField: CommonField -> + toField(commField, airbyteTypeConverter) + } + .toList() + return Field.of( + commonField.name, + airbyteTypeConverter.apply(commonField.type), + properties + ) + } else { + return Field.of(commonField.name, airbyteTypeConverter.apply(commonField.type)) + } + } + + private fun assertColumnsWithSameNameAreSame( + nameSpace: String, + tableName: String, + columns: List> + ) { + columns + .stream() + .collect(Collectors.groupingBy(Function { obj: CommonField -> obj.name })) + .values + .forEach( + Consumer { columnsWithSameName: List> -> + val comparisonColumn = columnsWithSameName[0] + columnsWithSameName.forEach( + Consumer { column: CommonField -> + if (column != comparisonColumn) { + throw RuntimeException( + String.format( + "Found multiple columns with same name: %s in table: %s.%s but the columns are not the same. columns: %s", + comparisonColumn.name, + nameSpace, + tableName, + columns + ) + ) + } + } + ) + } + ) + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/InvalidCursorInfoUtil.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/InvalidCursorInfoUtil.kt new file mode 100644 index 000000000000..d2c8e2b5ee01 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/InvalidCursorInfoUtil.kt @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+ */ +package io.airbyte.cdk.integrations.source.relationaldb + +import java.util.stream.Collectors + +object InvalidCursorInfoUtil { + fun getInvalidCursorConfigMessage(tablesWithInvalidCursor: List): String { + return ("The following tables have invalid columns selected as cursor, please select a column with a well-defined ordering with no null values as a cursor. " + + tablesWithInvalidCursor + .stream() + .map { obj: InvalidCursorInfo -> obj.toString() } + .collect(Collectors.joining(","))) + } + + class InvalidCursorInfo( + tableName: String?, + cursorColumnName: String, + cursorSqlType: String, + cause: String + ) { + override fun toString(): String { + return "{" + + "tableName='" + + tableName + + '\'' + + ", cursorColumnName='" + + cursorColumnName + + '\'' + + ", cursorSqlType=" + + cursorSqlType + + ", cause=" + + cause + + '}' + } + + val tableName: String? + val cursorColumnName: String + val cursorSqlType: String + val cause: String + + init { + this.tableName = tableName + this.cursorColumnName = cursorColumnName + this.cursorSqlType = cursorSqlType + this.cause = cause + } + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/RelationalDbQueryUtils.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/RelationalDbQueryUtils.kt new file mode 100644 index 000000000000..bd164a44486a --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/RelationalDbQueryUtils.kt @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.source.relationaldb + +import com.fasterxml.jackson.databind.JsonNode +import io.airbyte.cdk.db.SqlDatabase +import io.airbyte.commons.stream.AirbyteStreamUtils +import io.airbyte.commons.util.AutoCloseableIterator +import io.airbyte.commons.util.AutoCloseableIterators +import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream +import java.util.* +import java.util.stream.Collectors +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +/** Utility class for methods to query a relational db. */ +object RelationalDbQueryUtils { + private val LOGGER: Logger = LoggerFactory.getLogger(RelationalDbQueryUtils::class.java) + + fun getIdentifierWithQuoting(identifier: String, quoteString: String): String { + // double-quoted values within a database name or column name should be wrapped with extra + // quoteString + return if (identifier.startsWith(quoteString) && identifier.endsWith(quoteString)) { + quoteString + quoteString + identifier + quoteString + quoteString + } else { + quoteString + identifier + quoteString + } + } + + fun enquoteIdentifierList(identifiers: List, quoteString: String): String { + val joiner = StringJoiner(",") + for (identifier in identifiers) { + joiner.add(getIdentifierWithQuoting(identifier, quoteString)) + } + return joiner.toString() + } + + /** @return fully qualified table name with the schema (if a schema exists) in quotes. */ + fun getFullyQualifiedTableNameWithQuoting( + nameSpace: String?, + tableName: String, + quoteString: String + ): String { + return (if (nameSpace == null || nameSpace.isEmpty()) + getIdentifierWithQuoting(tableName, quoteString) + else + getIdentifierWithQuoting(nameSpace, quoteString) + + "." + + getIdentifierWithQuoting(tableName, quoteString)) + } + + /** @return fully qualified table name with the schema (if a schema exists) without quotes. 
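A short usage sketch for the quoting helpers above, assuming a double-quote `quoteString` (as in, for example, Postgres) and that this compiles alongside the `RelationalDbQueryUtils` object shown in the patch; expected output is in the trailing comments.

```kotlin
fun main() {
    val q = "\""
    // Plain identifiers get wrapped once; already-quoted ones get the extra doubling.
    println(RelationalDbQueryUtils.getIdentifierWithQuoting("users", q))                 // "users"
    println(RelationalDbQueryUtils.enquoteIdentifierList(listOf("id", "name"), q))       // "id","name"
    println(RelationalDbQueryUtils.getFullyQualifiedTableNameWithQuoting("public", "users", q)) // "public"."users"
}
```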
*/ + fun getFullyQualifiedTableName(schemaName: String?, tableName: String): String { + return if (schemaName != null) "$schemaName.$tableName" else tableName + } + + /** @return the input identifier with quotes. */ + fun enquoteIdentifier(identifier: String?, quoteString: String?): String { + return quoteString + identifier + quoteString + } + + fun queryTable( + database: Database, + sqlQuery: String?, + tableName: String?, + schemaName: String? + ): AutoCloseableIterator { + val airbyteStreamNameNamespacePair = + AirbyteStreamUtils.convertFromNameAndNamespace(tableName, schemaName) + return AutoCloseableIterators.lazyIterator( + { + try { + LOGGER.info("Queueing query: {}", sqlQuery) + val stream = database!!.unsafeQuery(sqlQuery) + return@lazyIterator AutoCloseableIterators.fromStream( + stream, + airbyteStreamNameNamespacePair + ) + } catch (e: Exception) { + throw RuntimeException(e) + } + }, + airbyteStreamNameNamespacePair + ) + } + + fun logStreamSyncStatus(streams: List, syncType: String?) { + if (streams.isEmpty()) { + LOGGER.info("No Streams will be synced via {}.", syncType) + } else { + LOGGER.info("Streams to be synced via {} : {}", syncType, streams.size) + LOGGER.info("Streams: {}", prettyPrintConfiguredAirbyteStreamList(streams)) + } + } + + fun prettyPrintConfiguredAirbyteStreamList(streamList: List): String { + return streamList + .stream() + .map { s: ConfiguredAirbyteStream -> + "%s.%s".formatted(s.stream.namespace, s.stream.name) + } + .collect(Collectors.joining(", ")) + } + + class TableSizeInfo(tableSize: Long, avgRowLength: Long) { + val tableSize: Long + val avgRowLength: Long + + init { + this.tableSize = tableSize + this.avgRowLength = avgRowLength + } + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/RelationalDbReadUtil.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/RelationalDbReadUtil.kt new file mode 100644 index 000000000000..492878a63851 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/RelationalDbReadUtil.kt @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+ */ +package io.airbyte.cdk.integrations.source.relationaldb + +import com.google.common.collect.Sets +import io.airbyte.commons.json.Jsons +import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair +import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog +import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream +import io.airbyte.protocol.models.v0.SyncMode +import java.util.stream.Collectors + +object RelationalDbReadUtil { + fun identifyStreamsToSnapshot( + catalog: ConfiguredAirbyteCatalog, + alreadySyncedStreams: Set + ): List { + val allStreams = AirbyteStreamNameNamespacePair.fromConfiguredCatalog(catalog) + val newlyAddedStreams: Set = + HashSet(Sets.difference(allStreams, alreadySyncedStreams)) + return catalog.streams + .stream() + .filter { c: ConfiguredAirbyteStream -> c.syncMode == SyncMode.INCREMENTAL } + .filter { stream: ConfiguredAirbyteStream -> + newlyAddedStreams.contains( + AirbyteStreamNameNamespacePair.fromAirbyteStream(stream.stream) + ) + } + .map { `object`: ConfiguredAirbyteStream -> Jsons.clone(`object`) } + .collect(Collectors.toList()) + } + + fun identifyStreamsForCursorBased( + catalog: ConfiguredAirbyteCatalog, + streamsForInitialLoad: List + ): List { + val initialLoadStreamsNamespacePairs = + streamsForInitialLoad + .stream() + .map { stream: ConfiguredAirbyteStream -> + AirbyteStreamNameNamespacePair.fromAirbyteStream(stream.stream) + } + .collect(Collectors.toSet()) + return catalog.streams + .stream() + .filter { c: ConfiguredAirbyteStream -> c.syncMode == SyncMode.INCREMENTAL } + .filter { stream: ConfiguredAirbyteStream -> + !initialLoadStreamsNamespacePairs.contains( + AirbyteStreamNameNamespacePair.fromAirbyteStream(stream.stream) + ) + } + .map { `object`: ConfiguredAirbyteStream -> Jsons.clone(`object`) } + .collect(Collectors.toList()) + } + + fun convertNameNamespacePairFromV0( + v1NameNamespacePair: io.airbyte.protocol.models.AirbyteStreamNameNamespacePair + ): AirbyteStreamNameNamespacePair { + return AirbyteStreamNameNamespacePair( + v1NameNamespacePair.name, + v1NameNamespacePair.namespace + ) + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/StateDecoratingIterator.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/StateDecoratingIterator.kt new file mode 100644 index 000000000000..7d7bc4498cde --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/StateDecoratingIterator.kt @@ -0,0 +1,241 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+ */ +package io.airbyte.cdk.integrations.source.relationaldb + +import com.google.common.collect.AbstractIterator +import io.airbyte.cdk.db.IncrementalUtils.compareCursors +import io.airbyte.cdk.integrations.source.relationaldb.state.StateManager +import io.airbyte.protocol.models.JsonSchemaPrimitiveUtil +import io.airbyte.protocol.models.v0.AirbyteMessage +import io.airbyte.protocol.models.v0.AirbyteStateStats +import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair +import java.util.* +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +@Deprecated("") +class StateDecoratingIterator( + private val messageIterator: Iterator, + private val stateManager: StateManager, + private val pair: AirbyteStreamNameNamespacePair, + private val cursorField: String, + private val initialCursor: String, + private val cursorType: JsonSchemaPrimitiveUtil.JsonSchemaPrimitive, + stateEmissionFrequency: Int +) : AbstractIterator(), MutableIterator { + private var currentMaxCursor: String? + private var currentMaxCursorRecordCount = 0L + private var hasEmittedFinalState = false + + /** + * These parameters are for intermediate state message emission. We can emit an intermediate + * state when the following two conditions are met. + * + * 1. The records are sorted by the cursor field. This is true when `stateEmissionFrequency` > + * 0. This logic is guaranteed in `AbstractJdbcSource#queryTableIncremental`, in which an "ORDER + * BY" clause is appended to the SQL query if `stateEmissionFrequency` > 0. + * + * 2. There is a cursor value that is ready for emission. A cursor value is "ready" if there is + * no more record with the same value. We cannot emit a cursor at will, because there may be + * multiple records with the same cursor value. If we emit a cursor ignoring this condition, + * should the sync fail right after the emission, the next sync may skip some records with the + * same cursor value due to "WHERE cursor_field > cursor" in + * `AbstractJdbcSource#queryTableIncremental`. + * + * The `intermediateStateMessage` is set to the latest state message that is ready for emission. + * For every `stateEmissionFrequency` messages, `emitIntermediateState` is set to true and the + * latest "ready" state will be emitted in the next `computeNext` call. + */ + private val stateEmissionFrequency: Int + private var totalRecordCount = 0 + + // In between each state message, recordCountInStateMessage will be reset to 0. + private var recordCountInStateMessage = 0 + private var emitIntermediateState = false + private var intermediateStateMessage: AirbyteMessage? = null + private var hasCaughtException = false + + /** + * @param stateManager Manager that maintains connector state + * @param pair Stream Name and Namespace (e.g. public.users) + * @param cursorField Path to the comparator field used to track the records read so far + * @param initialCursor name of the initial cursor column + * @param cursorType ENUM type of primitive values that can be used as a cursor for + * checkpointing + * @param stateEmissionFrequency If larger than 0, the records are sorted by the cursor field, + * and intermediate states will be emitted for every `stateEmissionFrequency` records. The order + * of the records is guaranteed in `AbstractJdbcSource#queryTableIncremental`, in which an + * "ORDER BY" clause is appended to the SQL query if `stateEmissionFrequency` > 0. 
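The "ready for emission" rule above can be illustrated with a toy function: a checkpoint at cursor value c is only safe once a record with cursor greater than c has been seen, because `WHERE cursor_field > c` on the next sync would otherwise skip rows that share the value c. This is a minimal sketch with integer cursors, not the CDK's implementation.

```kotlin
// Given cursors already sorted ascending, return the values that are safe to checkpoint.
fun safeCheckpoints(sortedCursors: List<Int>): List<Int> {
    val checkpoints = mutableListOf<Int>()
    var current: Int? = null
    for (c in sortedCursors) {
        val prev = current
        if (prev != null && c > prev) checkpoints.add(prev) // prev has no more duplicates: "ready"
        current = c
    }
    return checkpoints // the final cursor is only emitted in the final state message
}
// safeCheckpoints(listOf(1, 1, 2, 2, 3)) == listOf(1, 2)
```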
+     */
+    init {
+        this.currentMaxCursor = initialCursor
+        this.stateEmissionFrequency = stateEmissionFrequency
+    }
+
+    private fun getCursorCandidate(message: AirbyteMessage): String? {
+        val cursorCandidate = message.record.data[cursorField].asText()
+        return (if (cursorCandidate != null) replaceNull(cursorCandidate) else null)
+    }
+
+    private fun replaceNull(cursorCandidate: String): String {
+        if (cursorCandidate.contains("\u0000")) {
+            return cursorCandidate.replace("\u0000".toRegex(), "")
+        }
+        return cursorCandidate
+    }
+
+    /**
+     * Computes the next record retrieved from the source stream and emits a StateMessage
+     * containing data on the records read so far.
+     *
+     * If this method throws an exception, it will propagate outward to the `hasNext` or `next`
+     * invocation that invoked this method. Any further attempts to use the iterator will result in
+     * an [IllegalStateException].
+     *
+     * @return the next [AirbyteMessage], either a record or a state message covering the records
+     * read so far
+     */
+    override fun computeNext(): AirbyteMessage? {
+        if (hasCaughtException) {
+            // Mark the iterator as done, since the next call to messageIterator would result in
+            // an IllegalArgumentException, and reset the exception-caught state. This occurs when
+            // the previous iteration emitted state, so this iteration cycle indicates that
+            // iteration is complete.
+            hasCaughtException = false
+            return endOfData()
+        }
+
+        if (messageIterator.hasNext()) {
+            var optionalIntermediateMessage = intermediateMessage
+            if (optionalIntermediateMessage.isPresent) {
+                return optionalIntermediateMessage.get()
+            }
+
+            totalRecordCount++
+            recordCountInStateMessage++
+            // Use try-catch to catch exceptions that could occur when the connection to the
+            // database fails.
+            try {
+                val message = messageIterator.next()
+                if (message.record.data.hasNonNull(cursorField)) {
+                    val cursorCandidate = getCursorCandidate(message)
+                    val cursorComparison =
+                        compareCursors(currentMaxCursor, cursorCandidate, cursorType)
+                    if (cursorComparison < 0) {
+                        // Update the current max cursor only when current max cursor < cursor
+                        // candidate from the message.
+                        if (
+                            stateEmissionFrequency > 0 &&
+                                currentMaxCursor != initialCursor &&
+                                messageIterator.hasNext()
+                        ) {
+                            // Only create an intermediate state when it is not the first or last
+                            // record message. The last state message will be processed separately.
+                            intermediateStateMessage =
+                                createStateMessage(false, recordCountInStateMessage)
+                        }
+                        currentMaxCursor = cursorCandidate
+                        currentMaxCursorRecordCount = 1L
+                    } else if (cursorComparison == 0) {
+                        currentMaxCursorRecordCount++
+                    } else if (cursorComparison > 0 && stateEmissionFrequency > 0) {
+                        LOGGER.warn(
+                            "Intermediate state emission feature requires records to be processed in order according to the cursor value. Otherwise, " +
+                                "data loss can occur."
+                        )
+                    }
+                }
+
+                if (stateEmissionFrequency > 0 && totalRecordCount % stateEmissionFrequency == 0) {
+                    emitIntermediateState = true
+                }
+
+                return message
+            } catch (e: Exception) {
+                emitIntermediateState = true
+                hasCaughtException = true
+                LOGGER.error("Message iterator failed to read next record.", e)
+                optionalIntermediateMessage = intermediateMessage
+                return optionalIntermediateMessage.orElse(endOfData())
+            }
+        } else if (!hasEmittedFinalState) {
+            return createStateMessage(true, recordCountInStateMessage)
+        } else {
+            return endOfData()
+        }
+    }
+
+    protected val intermediateMessage: Optional<AirbyteMessage>
+        /**
+         * Returns an AirbyteStateMessage when in a ready state. A ready state means that it
+         * satisfies both conditions:
+         *
+         * the cursorField has changed (e.g. 08-22-2022 -> 08-23-2022) and there have been at
+         * least stateEmissionFrequency records since the last emission
+         *
+         * @return AirbyteStateMessage if one exists, otherwise an empty Optional indicating that
+         * state was not ready to be emitted
+         */
+        get() {
+            val message: AirbyteMessage? = intermediateStateMessage
+            if (emitIntermediateState && message != null) {
+                if (message.state != null) {
+                    message.state.sourceStats =
+                        AirbyteStateStats().withRecordCount(recordCountInStateMessage.toDouble())
+                }
+
+                intermediateStateMessage = null
+                recordCountInStateMessage = 0
+                emitIntermediateState = false
+                return Optional.of(message)
+            }
+            return Optional.empty()
+        }
+
+    /**
+     * Creates an AirbyteStateMessage while updating the cursor used to checkpoint the state of
+     * the records read so far
+     *
+     * @param isFinalState marker for whether the final state of the iterator has been reached
+     * @param recordCount count of read messages
+     * @return AirbyteMessage which includes information on the state of the records read so far
+     */
+    fun createStateMessage(isFinalState: Boolean, recordCount: Int): AirbyteMessage {
+        val stateMessage =
+            stateManager.updateAndEmit(pair, currentMaxCursor, currentMaxCursorRecordCount)
+        val cursorInfo = stateManager.getCursorInfo(pair)
+
+        // logging once every 100 messages to reduce log verbosity
+        if (recordCount % 100 == 0) {
+            LOGGER.info(
+                "State report for stream {} - original: {} = {} (count {}) -> latest: {} = {} (count {})",
+                pair,
+                cursorInfo.map { obj: CursorInfo -> obj.originalCursorField }.orElse(null),
+                cursorInfo.map { obj: CursorInfo -> obj.originalCursor }.orElse(null),
+                cursorInfo.map { obj: CursorInfo -> obj.originalCursorRecordCount }.orElse(null),
+                cursorInfo.map { obj: CursorInfo -> obj.cursorField }.orElse(null),
+                cursorInfo.map { obj: CursorInfo -> obj.cursor }.orElse(null),
+                cursorInfo.map { obj: CursorInfo -> obj.cursorRecordCount }.orElse(null)
+            )
+        }
+
+        stateMessage?.withSourceStats(AirbyteStateStats().withRecordCount(recordCount.toDouble()))
+        if (isFinalState) {
+            hasEmittedFinalState = true
+            if (stateManager.getCursor(pair).isEmpty) {
+                LOGGER.warn(
+                    "Cursor for stream {} was null.
This stream will replicate all records on the next run", + pair + ) + } + } + + return AirbyteMessage().withType(AirbyteMessage.Type.STATE).withState(stateMessage) + } + + companion object { + private val LOGGER: Logger = LoggerFactory.getLogger(StateDecoratingIterator::class.java) + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/TableInfo.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/TableInfo.kt new file mode 100644 index 000000000000..46ebe3bd96d8 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/TableInfo.kt @@ -0,0 +1,13 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.source.relationaldb + +/** This class encapsulates all externally relevant Table information. */ +data class TableInfo( + val nameSpace: String, + val name: String, + val fields: List, + val primaryKeys: List = emptyList(), + val cursorFields: List +) diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/AbstractStateManager.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/AbstractStateManager.kt new file mode 100644 index 000000000000..935f8c6d008d --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/AbstractStateManager.kt @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.source.relationaldb.state + +import io.airbyte.cdk.integrations.source.relationaldb.CursorInfo +import io.airbyte.protocol.models.v0.AirbyteStateMessage +import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair +import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog +import java.util.* +import java.util.function.Function +import java.util.function.Supplier + +/** + * Abstract implementation of the [StateManager] interface that provides common functionality for + * state manager implementations. + * + * @param The type associated with the state object managed by this manager. + * @param The type associated with the state object stored in the state managed by this manager. + * + */ +abstract class AbstractStateManager +@JvmOverloads +constructor( + catalog: ConfiguredAirbyteCatalog, + streamSupplier: Supplier>, + cursorFunction: Function?, + cursorFieldFunction: Function>?, + cursorRecordCountFunction: Function?, + namespacePairFunction: Function?, + onlyIncludeIncrementalStreams: Boolean = false +) : StateManager { + /** + * The [CursorManager] responsible for keeping track of the current cursor value for each stream + * managed by this state manager. 
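The abstract manager above is mostly wiring: it builds a cursor tracker from extractor functions and exposes the resulting map, leaving `toState` to subclasses. A stripped-down sketch of that delegation pattern, with deliberately simplified types (a generic state type `S` and string keys standing in for the protocol models):

```kotlin
import java.util.Optional

// Sketch only: MiniStateManager is illustrative, not the CDK's AbstractStateManager.
abstract class MiniStateManager<S>(
    states: List<S>,
    cursorOf: (S) -> String?,   // extracts the saved cursor value from a state entry
    keyOf: (S) -> String        // identifies the stream a state entry belongs to
) {
    // The equivalent of pairToCursorInfoMap: built once from the supplied state.
    protected val pairToCursor: Map<String, String?> =
        states.associate { keyOf(it) to cursorOf(it) }

    // Subclasses decide how to serialize state for a given stream (or globally).
    abstract fun toState(pair: Optional<String>): String
}
```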
+ */ + private val cursorManager: CursorManager<*> = + CursorManager( + catalog, + streamSupplier, + cursorFunction, + cursorFieldFunction, + cursorRecordCountFunction, + namespacePairFunction, + onlyIncludeIncrementalStreams + ) + + override val pairToCursorInfoMap: Map + get() = cursorManager.pairToCursorInfo + + abstract override fun toState( + pair: Optional + ): AirbyteStateMessage +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/CursorManager.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/CursorManager.kt new file mode 100644 index 000000000000..657e9437c603 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/CursorManager.kt @@ -0,0 +1,311 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.source.relationaldb.state + +import com.google.common.annotations.VisibleForTesting +import io.airbyte.cdk.integrations.source.relationaldb.CursorInfo +import io.airbyte.protocol.models.v0.* +import java.util.* +import java.util.concurrent.* +import java.util.function.Function +import java.util.function.Supplier +import java.util.stream.Collectors +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +/** + * Manages the map of streams to current cursor values for state management. + * + * @param The type that represents the stream object which holds the current cursor information + * in the state. + */ +class CursorManager( + catalog: ConfiguredAirbyteCatalog, + streamSupplier: Supplier>, + cursorFunction: Function?, + cursorFieldFunction: Function>?, + cursorRecordCountFunction: Function?, + namespacePairFunction: Function?, + onlyIncludeIncrementalStreams: Boolean +) { + /** + * Map of streams (name/namespace tuple) to the current cursor information stored in the state. + */ + val pairToCursorInfo: Map + + /** + * Constructs a new [CursorManager] based on the configured connector and current state + * information. + * + * @param catalog The connector's configured catalog. + * @param streamSupplier A [Supplier] that provides the cursor manager with the collection of + * streams tracked by the connector's state. + * @param cursorFunction A [Function] that extracts the current cursor from a stream stored in + * the connector's state. + * @param cursorFieldFunction A [Function] that extracts the cursor field name from a stream + * stored in the connector's state. + * @param cursorRecordCountFunction A [Function] that extracts the cursor record count for a + * stream stored in the connector's state. + * @param namespacePairFunction A [Function] that generates a [AirbyteStreamNameNamespacePair] + * that identifies each stream in the connector's state. + */ + init { + pairToCursorInfo = + createCursorInfoMap( + catalog, + streamSupplier, + cursorFunction, + cursorFieldFunction, + cursorRecordCountFunction, + namespacePairFunction, + onlyIncludeIncrementalStreams + ) + } + + /** + * Creates the cursor information map that associates stream name/namespace tuples with the + * current cursor information for that stream as stored in the connector's state. + * + * @param catalog The connector's configured catalog. + * @param streamSupplier A [Supplier] that provides the cursor manager with the collection of + * streams tracked by the connector's state. 
+ * @param cursorFunction A [Function] that extracts the current cursor from a stream stored in + * the connector's state. + * @param cursorFieldFunction A [Function] that extracts the cursor field name from a stream + * stored in the connector's state. + * @param cursorRecordCountFunction A [Function] that extracts the cursor record count for a + * stream stored in the connector's state. + * @param namespacePairFunction A [Function] that generates a [AirbyteStreamNameNamespacePair] + * that identifies each stream in the connector's state. + * @return A map of streams to current cursor information for the stream. + */ + @VisibleForTesting + protected fun createCursorInfoMap( + catalog: ConfiguredAirbyteCatalog, + streamSupplier: Supplier>, + cursorFunction: Function?, + cursorFieldFunction: Function>?, + cursorRecordCountFunction: Function?, + namespacePairFunction: Function?, + onlyIncludeIncrementalStreams: Boolean + ): Map { + val allStreamNames = + catalog.streams + .stream() + .filter { c: ConfiguredAirbyteStream -> + if (onlyIncludeIncrementalStreams) { + return@filter c.syncMode == SyncMode.INCREMENTAL + } + true + } + .map { obj: ConfiguredAirbyteStream -> obj.stream } + .map { stream: AirbyteStream? -> + AirbyteStreamNameNamespacePair.fromAirbyteStream(stream) + } + .collect(Collectors.toSet()) + allStreamNames.addAll( + streamSupplier + .get() + .stream() + .map(namespacePairFunction) + .filter { obj: AirbyteStreamNameNamespacePair? -> Objects.nonNull(obj) } + .collect(Collectors.toSet()) + ) + + val localMap: MutableMap = ConcurrentHashMap() + val pairToState = + streamSupplier + .get() + .stream() + .collect(Collectors.toMap(namespacePairFunction, Function.identity())) + val pairToConfiguredAirbyteStream = + catalog.streams + .stream() + .collect( + Collectors.toMap( + Function { stream: ConfiguredAirbyteStream? -> + AirbyteStreamNameNamespacePair.fromConfiguredAirbyteSteam(stream) + }, + Function.identity() + ) + ) + + for (pair in allStreamNames) { + val stateOptional: Optional = Optional.ofNullable(pairToState[pair]) + val streamOptional = Optional.ofNullable(pairToConfiguredAirbyteStream[pair]) + localMap[pair] = + createCursorInfoForStream( + pair, + stateOptional, + streamOptional, + cursorFunction, + cursorFieldFunction, + cursorRecordCountFunction + ) + } + + return localMap.toMap() + } + + /** + * Generates a [CursorInfo] object based on the data currently stored in the connector's state + * for the given stream. + * + * @param pair A [AirbyteStreamNameNamespacePair] that identifies a specific stream managed by + * the connector. + * @param stateOptional [Optional] containing the current state associated with the stream. + * @param streamOptional [Optional] containing the [ConfiguredAirbyteStream] associated with the + * stream. + * @param cursorFunction A [Function] that provides the current cursor from the state associated + * with the stream. + * @param cursorFieldFunction A [Function] that provides the cursor field name for the cursor + * stored in the state associated with the stream. + * @param cursorRecordCountFunction A [Function] that extracts the cursor record count for a + * stream stored in the connector's state. + * @return A [CursorInfo] object based on the data currently stored in the connector's state for + * the given stream. 
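The map construction described above starts from the union of two stream sets: the streams in the configured catalog (optionally restricted to incremental ones) and the streams already present in saved state. A small sketch of just that union step, with a simplified key type in place of `AirbyteStreamNameNamespacePair`:

```kotlin
// Hypothetical key type; the CDK uses AirbyteStreamNameNamespacePair.
data class StreamKey(val name: String, val namespace: String?)

fun trackedStreams(
    catalog: List<Pair<StreamKey, String>>, // stream key paired with its sync mode
    state: Collection<StreamKey>,
    onlyIncremental: Boolean
): Set<StreamKey> {
    val fromCatalog = catalog
        .filter { (_, mode) -> !onlyIncremental || mode == "incremental" }
        .map { it.first }
    // Streams known only from state are still tracked, so stale cursors stay visible.
    return (fromCatalog + state).toSet()
}
```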
+     */
+    internal fun createCursorInfoForStream(
+        pair: AirbyteStreamNameNamespacePair?,
+        stateOptional: Optional<S>,
+        streamOptional: Optional<ConfiguredAirbyteStream>,
+        cursorFunction: Function<S, String>?,
+        cursorFieldFunction: Function<S, List<String>>?,
+        cursorRecordCountFunction: Function<S, Long>?
+    ): CursorInfo {
+        val originalCursorField =
+            stateOptional
+                .map(cursorFieldFunction)
+                .flatMap { f: List<String> ->
+                    if (f.isNotEmpty()) Optional.of(f[0]) else Optional.empty()
+                }
+                .orElse(null)
+        val originalCursor = stateOptional.map(cursorFunction).orElse(null)
+        val originalCursorRecordCount = stateOptional.map(cursorRecordCountFunction).orElse(0L)
+
+        val cursor: String?
+        val cursorField: String?
+        val cursorRecordCount: Long
+
+        // if cursor field is set in catalog.
+        if (
+            streamOptional
+                .map<List<String>> { obj: ConfiguredAirbyteStream -> obj.cursorField }
+                .isPresent
+        ) {
+            cursorField =
+                streamOptional
+                    .map { obj: ConfiguredAirbyteStream -> obj.cursorField }
+                    .flatMap { f: List<String> ->
+                        if (f.size > 0) Optional.of(f[0]) else Optional.empty()
+                    }
+                    .orElse(null)
+            // if cursor field is set in state.
+            if (stateOptional.map<List<String>?>(cursorFieldFunction).isPresent) {
+                // if cursor field in catalog and state are the same.
+                if (
+                    stateOptional.map<List<String>?>(cursorFieldFunction) ==
+                        streamOptional.map<List<String>> { obj: ConfiguredAirbyteStream ->
+                            obj.cursorField
+                        }
+                ) {
+                    cursor = stateOptional.map(cursorFunction).orElse(null)
+                    cursorRecordCount = stateOptional.map(cursorRecordCountFunction).orElse(0L)
+                    // If a matching cursor is found in the state and its value is null, this
+                    // indicates a CDC stream, and we shouldn't log anything.
+                    if (cursor != null) {
+                        LOGGER.info(
+                            "Found matching cursor in state. Stream: {}. Cursor Field: {} Value: {} Count: {}",
+                            pair,
+                            cursorField,
+                            cursor,
+                            cursorRecordCount
+                        )
+                    }
+                    // if cursor field in catalog and state are different.
+                } else {
+                    cursor = null
+                    cursorRecordCount = 0L
+                    LOGGER.info(
+                        "Found cursor field. Does not match previous cursor field. Stream: {}. Original Cursor Field: {} (count {}). New Cursor Field: {}. Resetting cursor value.",
+                        pair,
+                        originalCursorField,
+                        originalCursorRecordCount,
+                        cursorField
+                    )
+                }
+                // if cursor field is not set in state but is set in catalog.
+            } else {
+                LOGGER.info(
+                    "Cursor field set in catalog but not present in state. Stream: {}, New Cursor Field: {}. Resetting cursor value",
+                    pair,
+                    cursorField
+                )
+                cursor = null
+                cursorRecordCount = 0L
+            }
+            // if cursor field is not set in catalog.
+        } else {
+            LOGGER.info(
+                "Cursor field set in state but not present in catalog. Stream: {}. Original Cursor Field: {}. Original value: {}. Resetting cursor.",
+                pair,
+                originalCursorField,
+                originalCursor
+            )
+            cursorField = null
+            cursor = null
+            cursorRecordCount = 0L
+        }
+
+        return CursorInfo(
+            originalCursorField,
+            originalCursor,
+            originalCursorRecordCount,
+            cursorField,
+            cursor,
+            cursorRecordCount
+        )
+    }
+
+    /**
+     * Retrieves an [Optional] possibly containing the current [CursorInfo] associated with the
+     * provided stream name/namespace tuple.
+     *
+     * @param pair The [AirbyteStreamNameNamespacePair] which identifies a stream.
+     * @return An [Optional] possibly containing the current [CursorInfo] associated with the
+     * provided stream name/namespace tuple.
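The branching above condenses to a small decision table: the catalog's cursor field always wins, and a saved cursor value survives only when the catalog and state agree on the field. A hedged sketch of that table, with simplified string inputs in place of the protocol models:

```kotlin
// Returns (resolved cursor field, resolved cursor value) for a stream.
fun resolveCursor(
    catalogField: String?,
    stateField: String?,
    stateCursor: String?
): Pair<String?, String?> = when {
    catalogField == null -> null to null                      // not set in catalog: reset both
    stateField == null -> catalogField to null                // not set in state: fresh cursor
    stateField == catalogField -> catalogField to stateCursor // fields match: keep saved value
    else -> catalogField to null                              // field changed: reset the value
}
```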
+ */ + fun getCursorInfo(pair: AirbyteStreamNameNamespacePair?): Optional { + return Optional.ofNullable(pairToCursorInfo[pair]) + } + + /** + * Retrieves an [Optional] possibly containing the cursor field name associated with the cursor + * tracked in the state associated with the provided stream name/namespace tuple. + * + * @param pair The [AirbyteStreamNameNamespacePair] which identifies a stream. + * @return An [Optional] possibly containing the cursor field name associated with the cursor + * tracked in the state associated with the provided stream name/namespace tuple. + */ + fun getCursorField(pair: AirbyteStreamNameNamespacePair?): Optional { + return getCursorInfo(pair).map { obj: CursorInfo -> obj.cursorField } + } + + /** + * Retrieves an [Optional] possibly containing the cursor value tracked in the state associated + * with the provided stream name/namespace tuple. + * + * @param pair The [AirbyteStreamNameNamespacePair] which identifies a stream. + * @return An [Optional] possibly containing the cursor value tracked in the state associated + * with the provided stream name/namespace tuple. + */ + fun getCursor(pair: AirbyteStreamNameNamespacePair?): Optional { + return getCursorInfo(pair).map { obj: CursorInfo -> obj.cursor } + } + + companion object { + private val LOGGER: Logger = LoggerFactory.getLogger(CursorManager::class.java) + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/CursorStateMessageProducer.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/CursorStateMessageProducer.kt new file mode 100644 index 000000000000..9f006f8f053d --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/CursorStateMessageProducer.kt @@ -0,0 +1,146 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.source.relationaldb.state + +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings +import io.airbyte.cdk.db.IncrementalUtils.compareCursors +import io.airbyte.cdk.db.IncrementalUtils.getCursorField +import io.airbyte.cdk.db.IncrementalUtils.getCursorType +import io.airbyte.protocol.models.v0.AirbyteMessage +import io.airbyte.protocol.models.v0.AirbyteStateMessage +import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair +import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream +import java.util.* +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +class CursorStateMessageProducer( + private val stateManager: StateManager?, + private val initialCursor: Optional +) : SourceStateMessageProducer { + private var currentMaxCursor: Optional + + // We keep this field to mark `cursor_record_count` and also to control logging frequency. + private var currentCursorRecordCount = 0 + private var intermediateStateMessage: AirbyteStateMessage? = null + + private var cursorOutOfOrderDetected = false + + init { + this.currentMaxCursor = initialCursor + } + + override fun generateStateMessageAtCheckpoint( + stream: ConfiguredAirbyteStream? + ): AirbyteStateMessage? { + // At this stage intermediate state message should never be null; otherwise it would have + // been + // blocked by shouldEmitStateMessage check. 
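Taken together, the overrides in this class suggest a small producer contract: observe records, stage a checkpoint when one becomes safe, hand it over on request, and emit a final state at the end. A generic sketch of that shape, with simplified names that are not the CDK's actual interface:

```kotlin
// Sketch only: R is the record type, S the state type.
interface MiniStateProducer<R, S> {
    fun processRecord(record: R): R // observe a record, possibly staging a checkpoint
    fun shouldEmitState(): Boolean  // true when a staged checkpoint exists
    fun stateAtCheckpoint(): S?     // consume and return the staged checkpoint
    fun finalState(): S?            // emitted once the stream is exhausted
}
```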
+        val message = intermediateStateMessage
+        intermediateStateMessage = null
+        if (cursorOutOfOrderDetected) {
+            LOGGER.warn(
+                "Intermediate state emission feature requires records to be processed in order according to the cursor value. Otherwise, " +
+                    "data loss can occur."
+            )
+        }
+        return message
+    }
+
+    /**
+     * Note: we do not try to catch exceptions here. If an error occurs, we should fail the sync;
+     * since a state message has already been saved, we should be able to resume in the next sync
+     * once the underlying issue is fixed, or if the issue is transient.
+     */
+    @SuppressFBWarnings("NP_PARAMETER_MUST_BE_NONNULL_BUT_MARKED_AS_NULLABLE")
+    override fun processRecordMessage(
+        stream: ConfiguredAirbyteStream?,
+        message: AirbyteMessage
+    ): AirbyteMessage {
+        val cursorField = getCursorField(stream!!)
+        if (message.record.data.hasNonNull(cursorField)) {
+            val cursorCandidate = getCursorCandidate(cursorField, message)
+            val cursorType = getCursorType(stream, cursorField)
+            val cursorComparison =
+                compareCursors(currentMaxCursor.orElse(null), cursorCandidate, cursorType)
+            if (cursorComparison < 0) {
+                // Advance the cursor, but include the current record message; its value will be
+                // used to create the next state message. Only update the current max cursor when
+                // current max cursor < cursor candidate from the message.
+                if (currentMaxCursor != initialCursor) {
+                    // Only create an intermediate state when it is not the first record.
+                    intermediateStateMessage = createStateMessage(stream)
+                }
+                currentMaxCursor = Optional.of(cursorCandidate!!)
+                currentCursorRecordCount = 1
+            } else if (cursorComparison > 0) {
+                cursorOutOfOrderDetected = true
+            } else {
+                currentCursorRecordCount++
+            }
+        }
+        LOGGER.debug("Processed a record message. Count: {}", currentCursorRecordCount)
+        return message
+    }
+
+    @SuppressFBWarnings("NP_PARAMETER_MUST_BE_NONNULL_BUT_MARKED_AS_NULLABLE")
+    override fun createFinalStateMessage(stream: ConfiguredAirbyteStream?): AirbyteStateMessage? {
+        return createStateMessage(stream!!)
+    }
+
+    /** Only emits a state message when there is an intermediate state message to send out. */
+    override fun shouldEmitStateMessage(stream: ConfiguredAirbyteStream?): Boolean {
+        return intermediateStateMessage != null
+    }
+
+    /**
+     * Creates an AirbyteStateMessage while updating the cursor used to checkpoint the state of
+     * records read so far.
+     *
+     * @return AirbyteStateMessage which includes information on the state of records read so far
+     */
+    private fun createStateMessage(stream: ConfiguredAirbyteStream): AirbyteStateMessage? {
+        val pair = AirbyteStreamNameNamespacePair(stream.stream.name, stream.stream.namespace)
+        LOGGER.debug(
+            "Creating state message for {}: cursor {} (count {})",
+            pair,
+            currentMaxCursor.orElse(null),
+            currentCursorRecordCount
+        )
+        val stateMessage =
+            stateManager!!.updateAndEmit(
+                pair,
+                currentMaxCursor.orElse(null),
+                currentCursorRecordCount.toLong()
+            )
+        val cursorInfo = stateManager.getCursorInfo(pair)
+
+        // logging once every 100 messages to reduce log verbosity
+        if (currentCursorRecordCount % LOG_FREQUENCY == 0) {
+            LOGGER.info("State report for stream {}: {}", pair, cursorInfo)
+        }
+
+        return stateMessage
+    }
+
+    private fun getCursorCandidate(cursorField: String, message: AirbyteMessage): String?
{ + val cursorCandidate = message.record.data[cursorField].asText() + return (if (cursorCandidate != null) replaceNull(cursorCandidate) else null) + } + + private fun replaceNull(cursorCandidate: String): String { + if (cursorCandidate.contains("\u0000")) { + return cursorCandidate.replace("\u0000".toRegex(), "") + } + return cursorCandidate + } + + companion object { + private val LOGGER: Logger = LoggerFactory.getLogger(CursorStateMessageProducer::class.java) + private const val LOG_FREQUENCY = 100 + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/FailedRecordIteratorException.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/FailedRecordIteratorException.kt new file mode 100644 index 000000000000..33434b23d9b7 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/FailedRecordIteratorException.kt @@ -0,0 +1,6 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.source.relationaldb.state + +class FailedRecordIteratorException(cause: Throwable?) : RuntimeException(cause) diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/GlobalStateManager.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/GlobalStateManager.kt new file mode 100644 index 000000000000..9329d6d66554 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/GlobalStateManager.kt @@ -0,0 +1,179 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.source.relationaldb.state + +import io.airbyte.cdk.integrations.source.relationaldb.CdcStateManager +import io.airbyte.cdk.integrations.source.relationaldb.models.CdcState +import io.airbyte.cdk.integrations.source.relationaldb.models.DbState +import io.airbyte.cdk.integrations.source.relationaldb.models.DbStreamState +import io.airbyte.commons.json.Jsons +import io.airbyte.protocol.models.v0.* +import java.util.* +import java.util.function.Supplier +import java.util.stream.Collectors + +/** + * Global implementation of the [StateManager] interface. + * + * This implementation generates a single, global state object for the state tracked by this + * manager. + */ +class GlobalStateManager( + airbyteStateMessage: AirbyteStateMessage, + catalog: ConfiguredAirbyteCatalog +) : + AbstractStateManager( + catalog, + getStreamsSupplier(airbyteStateMessage), + StateGeneratorUtils.CURSOR_FUNCTION, + StateGeneratorUtils.CURSOR_FIELD_FUNCTION, + StateGeneratorUtils.CURSOR_RECORD_COUNT_FUNCTION, + StateGeneratorUtils.NAME_NAMESPACE_PAIR_FUNCTION, + true + ) { + /** + * Legacy [CdcStateManager] used to manage state for connectors that support Change Data Capture + * (CDC). + */ + override val cdcStateManager: CdcStateManager + + /** + * Constructs a new [GlobalStateManager] that is seeded with the provided [AirbyteStateMessage]. + * + * @param airbyteStateMessage The initial state represented as an [AirbyteStateMessage]. + * @param catalog The [ConfiguredAirbyteCatalog] for the connector associated with this state + * manager. 
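+ *
+ * A seeding sketch (illustrative; `initialMessage` and `configuredCatalog` are assumed to be
+ * supplied by the connector's inputs):
+ * ```
+ * val manager = GlobalStateManager(initialMessage, configuredCatalog)
+ * val stateToEmit = manager.toState(Optional.empty())
+ * ```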
+ */ + init { + this.cdcStateManager = + CdcStateManager( + extractCdcState(airbyteStateMessage), + extractStreams(airbyteStateMessage), + airbyteStateMessage + ) + } + + override val rawStateMessages: List? + get() { + throw UnsupportedOperationException( + "Raw state retrieval not supported by global state manager." + ) + } + + override fun toState(pair: Optional): AirbyteStateMessage { + // Populate global state + val globalState = AirbyteGlobalState() + globalState.sharedState = Jsons.jsonNode(cdcStateManager.cdcState) + globalState.streamStates = StateGeneratorUtils.generateStreamStateList(pairToCursorInfoMap) + + // Generate the legacy state for backwards compatibility + val dbState = + StateGeneratorUtils.generateDbState(pairToCursorInfoMap) + .withCdc(true) + .withCdcState(cdcStateManager.cdcState) + + return AirbyteStateMessage() + .withType( + AirbyteStateMessage.AirbyteStateType.GLOBAL + ) // Temporarily include legacy state for backwards compatibility with the platform + .withData(Jsons.jsonNode(dbState)) + .withGlobal(globalState) + } + + /** + * Extracts the Change Data Capture (CDC) state stored in the initial state provided to this + * state manager. + * + * @param airbyteStateMessage The [AirbyteStateMessage] that contains the initial state provided + * to the state manager. + * @return The [CdcState] stored in the state, if any. Note that this will not be `null` but may + * be empty. + */ + private fun extractCdcState(airbyteStateMessage: AirbyteStateMessage?): CdcState? { + if (airbyteStateMessage!!.type == AirbyteStateMessage.AirbyteStateType.GLOBAL) { + return Jsons.`object`(airbyteStateMessage.global.sharedState, CdcState::class.java) + } else { + val legacyState = Jsons.`object`(airbyteStateMessage.data, DbState::class.java) + return legacyState?.cdcState + } + } + + private fun extractStreams( + airbyteStateMessage: AirbyteStateMessage? + ): Set { + if (airbyteStateMessage!!.type == AirbyteStateMessage.AirbyteStateType.GLOBAL) { + return airbyteStateMessage.global.streamStates + .stream() + .map { streamState: AirbyteStreamState -> + val cloned = Jsons.clone(streamState) + AirbyteStreamNameNamespacePair( + cloned.streamDescriptor.name, + cloned.streamDescriptor.namespace + ) + } + .collect(Collectors.toSet()) + } else { + val legacyState = Jsons.`object`(airbyteStateMessage.data, DbState::class.java) + return if (legacyState != null) + extractNamespacePairsFromDbStreamState(legacyState.streams) + else emptySet() + } + } + + private fun extractNamespacePairsFromDbStreamState( + streams: List + ): Set { + return streams + .stream() + .map { stream: DbStreamState -> + val cloned = Jsons.clone(stream) + AirbyteStreamNameNamespacePair(cloned.streamName, cloned.streamNamespace) + } + .collect(Collectors.toSet()) + } + + companion object { + /** + * Generates the [Supplier] that will be used to extract the streams from the incoming + * [AirbyteStateMessage]. + * + * @param airbyteStateMessage The [AirbyteStateMessage] supplied to this state manager with + * the initial state. + * @return A [Supplier] that will be used to fetch the streams present in the initial state. + */ + private fun getStreamsSupplier( + airbyteStateMessage: AirbyteStateMessage? + ): Supplier> { + /* + * If the incoming message has the state type set to GLOBAL, it is using the new format. Therefore, + * we can look for streams in the "global" field of the message. Otherwise, the message is still + * storing state in the legacy "data" field. 
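+ *
+ * For illustration, the two shapes look roughly like this (values are hypothetical):
+ * GLOBAL: {"type": "GLOBAL", "global": {"shared_state": {...}, "stream_states": [...]}}
+ * LEGACY: {"type": "LEGACY", "data": {"cdc": false, "streams": [{"stream_name": "users", ...}]}}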
+ */
+            return Supplier {
+                if (airbyteStateMessage!!.type == AirbyteStateMessage.AirbyteStateType.GLOBAL) {
+                    return@Supplier airbyteStateMessage.global.streamStates
+                } else if (airbyteStateMessage.data != null) {
+                    return@Supplier Jsons.`object`(airbyteStateMessage.data, DbState::class.java)
+                        .streams
+                        .stream()
+                        .map { s: DbStreamState ->
+                            AirbyteStreamState()
+                                .withStreamState(Jsons.jsonNode(s))
+                                .withStreamDescriptor(
+                                    StreamDescriptor()
+                                        .withNamespace(s.streamNamespace)
+                                        .withName(s.streamName)
+                                )
+                        }
+                        .collect(Collectors.toList())
+                } else {
+                    return@Supplier listOf()
+                }
+            }
+        }
+    }
+}
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/LegacyStateManager.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/LegacyStateManager.kt
new file mode 100644
index 000000000000..c379f25a9d1e
--- /dev/null
+++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/LegacyStateManager.kt
@@ -0,0 +1,118 @@
+/*
+ * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+ */
+package io.airbyte.cdk.integrations.source.relationaldb.state
+
+import io.airbyte.cdk.integrations.source.relationaldb.CdcStateManager
+import io.airbyte.cdk.integrations.source.relationaldb.models.DbState
+import io.airbyte.cdk.integrations.source.relationaldb.models.DbStreamState
+import io.airbyte.commons.json.Jsons
+import io.airbyte.protocol.models.v0.AirbyteStateMessage
+import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair
+import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog
+import java.util.*
+import java.util.function.Function
+import java.util.function.Supplier
+import org.slf4j.Logger
+import org.slf4j.LoggerFactory
+
+/**
+ * Legacy implementation (pre-per-stream state support) of the [StateManager] interface.
+ *
+ * This implementation assumes that the state matches the [DbState] object and effectively tracks
+ * state as global across the streams managed by a connector.
+ */
+@Deprecated(
+    """This manager may be removed in the future if/once all connectors support per-stream
+    state management."""
+)
+class LegacyStateManager(dbState: DbState, catalog: ConfiguredAirbyteCatalog) :
+    AbstractStateManager<DbState, DbStreamState>(
+        catalog,
+        Supplier { dbState.streams },
+        CURSOR_FUNCTION,
+        CURSOR_FIELD_FUNCTION,
+        CURSOR_RECORD_COUNT_FUNCTION,
+        NAME_NAMESPACE_PAIR_FUNCTION
+    ) {
+    /** Tracks whether the connector associated with this state manager supports CDC. */
+    private var isCdc: Boolean
+
+    /** [CdcStateManager] used to manage state for connectors that support CDC. */
+    override val cdcStateManager: CdcStateManager =
+        CdcStateManager(
+            dbState.cdcState,
+            AirbyteStreamNameNamespacePair.fromConfiguredCatalog(catalog),
+            null
+        )
+
+    /**
+     * Constructs a new [LegacyStateManager] that is seeded with the provided [DbState] instance.
+     *
+     * @param dbState The initial state represented as a [DbState] instance.
+     * @param catalog The [ConfiguredAirbyteCatalog] for the connector associated with this state
+     * manager.
+     */
+    init {
+        this.isCdc = dbState.cdc ?: false
+    }
+
+    override val rawStateMessages: List<AirbyteStateMessage>?
+        get() {
+            throw UnsupportedOperationException(
+                "Raw state retrieval not supported by legacy state manager."
+ ) + } + + override fun toState(pair: Optional): AirbyteStateMessage { + val dbState = + StateGeneratorUtils.generateDbState(pairToCursorInfoMap) + .withCdc(isCdc) + .withCdcState(cdcStateManager.cdcState) + + LOGGER.debug("Generated legacy state for {} streams", dbState.streams.size) + return AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.LEGACY) + .withData(Jsons.jsonNode(dbState)) + } + + override fun updateAndEmit( + pair: AirbyteStreamNameNamespacePair, + cursor: String? + ): AirbyteStateMessage? { + return updateAndEmit(pair, cursor, 0L) + } + + override fun updateAndEmit( + pair: AirbyteStreamNameNamespacePair, + cursor: String?, + cursorRecordCount: Long + ): AirbyteStateMessage? { + // cdc file gets updated by debezium so the "update" part is a no op. + if (!isCdc) { + return super.updateAndEmit(pair, cursor, cursorRecordCount) + } + + return toState(Optional.ofNullable(pair)) + } + + companion object { + private val LOGGER: Logger = LoggerFactory.getLogger(LegacyStateManager::class.java) + + /** [Function] that extracts the cursor from the stream state. */ + private val CURSOR_FUNCTION = DbStreamState::getCursor + + /** [Function] that extracts the cursor field(s) from the stream state. */ + private val CURSOR_FIELD_FUNCTION = DbStreamState::getCursorField + + private val CURSOR_RECORD_COUNT_FUNCTION = Function { stream: DbStreamState -> + Objects.requireNonNullElse(stream.cursorRecordCount, 0L) + } + + /** [Function] that creates an [AirbyteStreamNameNamespacePair] from the stream state. */ + private val NAME_NAMESPACE_PAIR_FUNCTION = + Function { s: DbStreamState -> + AirbyteStreamNameNamespacePair(s!!.streamName, s.streamNamespace) + } + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateIterator.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateIterator.kt new file mode 100644 index 000000000000..2aba6bd2ae30 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateIterator.kt @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.source.relationaldb.state + +import com.google.common.collect.AbstractIterator +import io.airbyte.protocol.models.v0.AirbyteMessage +import io.airbyte.protocol.models.v0.AirbyteStateStats +import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream +import java.time.Duration +import java.time.Instant +import java.time.OffsetDateTime +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +open class SourceStateIterator( + private val messageIterator: Iterator, + private val stream: ConfiguredAirbyteStream?, + private val sourceStateMessageProducer: SourceStateMessageProducer, + private val stateEmitFrequency: StateEmitFrequency +) : AbstractIterator(), MutableIterator { + private var hasEmittedFinalState = false + private var recordCount = 0L + private var lastCheckpoint: Instant = Instant.now() + + override fun computeNext(): AirbyteMessage? { + var iteratorHasNextValue = false + try { + iteratorHasNextValue = messageIterator.hasNext() + } catch (ex: Exception) { + // If the underlying iterator throws an exception, we want to fail the sync, expecting + // sync/attempt + // will be restarted and + // sync will resume from the last state message. 
+ throw FailedRecordIteratorException(ex) + } + if (iteratorHasNextValue) { + if ( + shouldEmitStateMessage() && + sourceStateMessageProducer.shouldEmitStateMessage(stream) + ) { + val stateMessage = + sourceStateMessageProducer.generateStateMessageAtCheckpoint(stream) + stateMessage!!.withSourceStats( + AirbyteStateStats().withRecordCount(recordCount.toDouble()) + ) + + recordCount = 0L + lastCheckpoint = Instant.now() + return AirbyteMessage().withType(AirbyteMessage.Type.STATE).withState(stateMessage) + } + // Use try-catch to catch Exception that could occur when connection to the database + // fails + try { + val message = messageIterator.next() + val processedMessage = + sourceStateMessageProducer.processRecordMessage(stream!!, message) + recordCount++ + return processedMessage + } catch (e: Exception) { + throw FailedRecordIteratorException(e) + } + } else if (!hasEmittedFinalState) { + hasEmittedFinalState = true + val finalStateMessageForStream = + sourceStateMessageProducer.createFinalStateMessage(stream!!) + finalStateMessageForStream!!.withSourceStats( + AirbyteStateStats().withRecordCount(recordCount.toDouble()) + ) + recordCount = 0L + return AirbyteMessage() + .withType(AirbyteMessage.Type.STATE) + .withState(finalStateMessageForStream) + } else { + return endOfData() + } + } + + // This method is used to check if we should emit a state message. If the record count is set to + // 0, + // we should not emit a state message. + // If the frequency is set to be zero, we should not use it. + private fun shouldEmitStateMessage(): Boolean { + if (stateEmitFrequency.syncCheckpointRecords == 0L) { + return false + } + if (recordCount >= stateEmitFrequency.syncCheckpointRecords) { + return true + } + if (!stateEmitFrequency.syncCheckpointDuration.isZero) { + return Duration.between(lastCheckpoint, OffsetDateTime.now()) + .compareTo(stateEmitFrequency.syncCheckpointDuration) > 0 + } + return false + } + + companion object { + private val LOGGER: Logger = LoggerFactory.getLogger(SourceStateIterator::class.java) + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateMessageProducer.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateMessageProducer.kt new file mode 100644 index 000000000000..7c2fd5bc7c44 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateMessageProducer.kt @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.source.relationaldb.state + +import io.airbyte.protocol.models.v0.AirbyteMessage +import io.airbyte.protocol.models.v0.AirbyteStateMessage +import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream + +/** + * To be used with SourceStateIterator. SourceStateIterator will iterate over the records and + * generate state messages when needed. This interface defines how would those state messages be + * generated, and how the incoming record messages will be processed. + * + * @param + */ +interface SourceStateMessageProducer { + /** Returns a state message that should be emitted at checkpoint. */ + fun generateStateMessageAtCheckpoint(stream: ConfiguredAirbyteStream?): AirbyteStateMessage? + + /** For the incoming record message, this method defines how the connector will consume it. 
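+ *
+ * A minimal pass-through sketch for a producer whose record type is [AirbyteMessage]
+ * (illustrative; a real implementation would also update its cursor bookkeeping here):
+ * ```
+ * override fun processRecordMessage(
+ *     stream: ConfiguredAirbyteStream?,
+ *     message: AirbyteMessage
+ * ): AirbyteMessage = message
+ * ```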
*/ + fun processRecordMessage(stream: ConfiguredAirbyteStream?, message: T): AirbyteMessage + + /** + * At the end of the iteration, this method will be called and it will generate the final state + * message. + * + * @return + */ + fun createFinalStateMessage(stream: ConfiguredAirbyteStream?): AirbyteStateMessage? + + /** + * Determines if the iterator has reached checkpoint or not per connector's definition. By + * default iterator will check if the number of records processed is greater than the checkpoint + * interval or last state message has already passed syncCheckpointDuration. + */ + fun shouldEmitStateMessage(stream: ConfiguredAirbyteStream?): Boolean +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateEmitFrequency.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateEmitFrequency.kt new file mode 100644 index 000000000000..6c2d0120cc6f --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateEmitFrequency.kt @@ -0,0 +1,16 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.source.relationaldb.state + +import java.time.Duration + +class StateEmitFrequency(syncCheckpointRecords: Long, syncCheckpointDuration: Duration) { + val syncCheckpointRecords: Long + val syncCheckpointDuration: Duration + + init { + this.syncCheckpointRecords = syncCheckpointRecords + this.syncCheckpointDuration = syncCheckpointDuration + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateGeneratorUtils.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateGeneratorUtils.kt new file mode 100644 index 000000000000..a9b61c9da642 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateGeneratorUtils.kt @@ -0,0 +1,323 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.source.relationaldb.state + +import com.fasterxml.jackson.databind.JsonNode +import com.google.common.collect.Lists +import io.airbyte.cdk.integrations.source.relationaldb.CursorInfo +import io.airbyte.cdk.integrations.source.relationaldb.models.CdcState +import io.airbyte.cdk.integrations.source.relationaldb.models.DbState +import io.airbyte.cdk.integrations.source.relationaldb.models.DbStreamState +import io.airbyte.commons.json.Jsons +import io.airbyte.configoss.StateType +import io.airbyte.configoss.StateWrapper +import io.airbyte.configoss.helpers.StateMessageHelper +import io.airbyte.protocol.models.v0.* +import java.util.* +import java.util.function.Function +import java.util.stream.Collectors +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +/** Collection of utilities that facilitate the generation of state objects. */ +object StateGeneratorUtils { + private val LOGGER: Logger = LoggerFactory.getLogger(StateGeneratorUtils::class.java) + + /** [Function] that extracts the cursor from the stream state. */ + val CURSOR_FUNCTION: Function = + Function { stream: AirbyteStreamState -> + val dbStreamState = extractState(stream) + dbStreamState.map { obj: DbStreamState -> obj.cursor }.orElse(null) + } + + /** [Function] that extracts the cursor field(s) from the stream state. 
*/ + val CURSOR_FIELD_FUNCTION: Function> = + Function { stream: AirbyteStreamState -> + val dbStreamState = extractState(stream) + if (dbStreamState.isPresent) { + return@Function dbStreamState.get().cursorField + } else { + return@Function listOf() + } + } + + val CURSOR_RECORD_COUNT_FUNCTION: Function = + Function { stream: AirbyteStreamState -> + val dbStreamState = extractState(stream) + dbStreamState.map { obj: DbStreamState -> obj.cursorRecordCount }.orElse(0L) + } + + /** [Function] that creates an [AirbyteStreamNameNamespacePair] from the stream state. */ + val NAME_NAMESPACE_PAIR_FUNCTION: + Function = + Function { s: AirbyteStreamState -> + if (isValidStreamDescriptor(s.streamDescriptor)) + AirbyteStreamNameNamespacePair( + s.streamDescriptor.name, + s.streamDescriptor.namespace + ) + else null + } + + /** + * Generates the stream state for the given stream and cursor information. + * + * @param airbyteStreamNameNamespacePair The stream. + * @param cursorInfo The current cursor. + * @return The [AirbyteStreamState] representing the current state of the stream. + */ + fun generateStreamState( + airbyteStreamNameNamespacePair: AirbyteStreamNameNamespacePair, + cursorInfo: CursorInfo + ): AirbyteStreamState { + return AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor() + .withName(airbyteStreamNameNamespacePair.name) + .withNamespace(airbyteStreamNameNamespacePair.namespace) + ) + .withStreamState( + Jsons.jsonNode(generateDbStreamState(airbyteStreamNameNamespacePair, cursorInfo)) + ) + } + + /** + * Generates a list of valid stream states from the provided stream and cursor information. A + * stream state is considered to be valid if the stream has a valid descriptor (see + * [.isValidStreamDescriptor] for more details). + * + * @param pairToCursorInfoMap The map of stream name/namespace tuple to the current cursor + * information for that stream + * @return The list of stream states derived from the state information extracted from the + * provided map. + */ + fun generateStreamStateList( + pairToCursorInfoMap: Map + ): List { + return pairToCursorInfoMap.entries + .stream() + .sorted(java.util.Map.Entry.comparingByKey()) + .map { e: Map.Entry -> + generateStreamState(e.key, e.value) + } + .filter { s: AirbyteStreamState -> isValidStreamDescriptor(s.streamDescriptor) } + .collect(Collectors.toList()) + } + + /** + * Generates the legacy global state for backwards compatibility. + * + * @param pairToCursorInfoMap The map of stream name/namespace tuple to the current cursor + * information for that stream + * @return The legacy [DbState]. + */ + fun generateDbState( + pairToCursorInfoMap: Map + ): DbState { + return DbState() + .withCdc(false) + .withStreams( + pairToCursorInfoMap.entries + .stream() + .sorted( + java.util.Map.Entry.comparingByKey() + ) // sort by stream name then namespace for sanity. + .map { e: Map.Entry -> + generateDbStreamState(e.key, e.value) + } + .collect(Collectors.toList()) + ) + } + + /** + * Generates the [DbStreamState] for the given stream and cursor. + * + * @param airbyteStreamNameNamespacePair The stream. + * @param cursorInfo The current cursor. + * @return The [DbStreamState]. 
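+ *
+ * For example (values hypothetical), a cursor field of `updated_at`, a cursor value of
+ * `2023-01-01` and a record count of 42 serialize roughly as:
+ * `{"stream_name": "users", "stream_namespace": "public", "cursor_field": ["updated_at"],
+ * "cursor": "2023-01-01", "cursor_record_count": 42}`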
+ */ + fun generateDbStreamState( + airbyteStreamNameNamespacePair: AirbyteStreamNameNamespacePair, + cursorInfo: CursorInfo + ): DbStreamState { + val state = + DbStreamState() + .withStreamName(airbyteStreamNameNamespacePair.name) + .withStreamNamespace(airbyteStreamNameNamespacePair.namespace) + .withCursorField( + if (cursorInfo.cursorField == null) emptyList() + else Lists.newArrayList(cursorInfo.cursorField) + ) + .withCursor(cursorInfo.cursor) + if (cursorInfo.cursorRecordCount > 0L) { + state.cursorRecordCount = cursorInfo.cursorRecordCount + } + return state + } + + /** + * Extracts the actual state from the [AirbyteStreamState] object. + * + * @param state The [AirbyteStreamState] that contains the actual stream state as JSON. + * @return An [Optional] possibly containing the deserialized representation of the stream state + * or an empty [Optional] if the state is not present or could not be deserialized. + */ + fun extractState(state: AirbyteStreamState): Optional { + try { + return Optional.ofNullable(Jsons.`object`(state.streamState, DbStreamState::class.java)) + } catch (e: IllegalArgumentException) { + LOGGER.error("Unable to extract state.", e) + return Optional.empty() + } + } + + /** + * Tests whether the provided [StreamDescriptor] is valid. A valid descriptor is defined as one + * that has a non-`null` name. + * + * See + * https://github.com/airbytehq/airbyte/blob/e63458fabb067978beb5eaa74d2bc130919b419f/docs/understanding-airbyte/airbyte-protocol.md + * for more details + * + * @param streamDescriptor A [StreamDescriptor] to be validated. + * @return `true` if the provided [StreamDescriptor] is valid or `false` if it is invalid. + */ + fun isValidStreamDescriptor(streamDescriptor: StreamDescriptor?): Boolean { + return if (streamDescriptor != null) { + streamDescriptor.name != null + } else { + false + } + } + + /** + * Converts a [AirbyteStateType.LEGACY] state message into a [AirbyteStateType.GLOBAL] message. + * + * @param airbyteStateMessage A [AirbyteStateType.LEGACY] state message. + * @return A [AirbyteStateType.GLOBAL] state message. + */ + fun convertLegacyStateToGlobalState( + airbyteStateMessage: AirbyteStateMessage + ): AirbyteStateMessage { + val dbState = Jsons.`object`(airbyteStateMessage.data, DbState::class.java) + val globalState = + AirbyteGlobalState() + .withSharedState(Jsons.jsonNode(dbState.cdcState)) + .withStreamStates( + dbState.streams + .stream() + .map { s: DbStreamState -> + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor() + .withName(s.streamName) + .withNamespace(s.streamNamespace) + ) + .withStreamState(Jsons.jsonNode(s)) + } + .collect(Collectors.toList()) + ) + return AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.GLOBAL) + .withGlobal(globalState) + } + + /** + * Converts a [AirbyteStateType.LEGACY] state message into a list of [AirbyteStateType.STREAM] + * messages. + * + * @param airbyteStateMessage A [AirbyteStateType.LEGACY] state message. + * @return A list [AirbyteStateType.STREAM] state messages. 
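+ *
+ * A usage sketch (`legacyMessage` is an assumed LEGACY-typed [AirbyteStateMessage]):
+ * ```
+ * val streamMessages = StateGeneratorUtils.convertLegacyStateToStreamState(legacyMessage)
+ * // one STREAM-typed message per entry in the legacy "streams" list
+ * ```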
+ */ + fun convertLegacyStateToStreamState( + airbyteStateMessage: AirbyteStateMessage + ): List { + return Jsons.`object`(airbyteStateMessage.data, DbState::class.java) + .streams + .stream() + .map { s: DbStreamState -> + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.STREAM) + .withStream( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor() + .withNamespace(s.streamNamespace) + .withName(s.streamName) + ) + .withStreamState(Jsons.jsonNode(s)) + ) + } + .collect(Collectors.toList()) + } + + fun convertStateMessage( + state: io.airbyte.protocol.models.AirbyteStateMessage + ): AirbyteStateMessage { + return Jsons.`object`(Jsons.jsonNode(state), AirbyteStateMessage::class.java) + } + + /** + * Deserializes the state represented as JSON into an object representation. + * + * @param initialStateJson The state as JSON. + * @Param supportedStateType the [AirbyteStateType] supported by this connector. + * @return The deserialized object representation of the state. + */ + fun deserializeInitialState( + initialStateJson: JsonNode?, + supportedStateType: AirbyteStateMessage.AirbyteStateType + ): List { + val typedState = StateMessageHelper.getTypedState(initialStateJson) + return typedState + .map { state: StateWrapper -> + when (state.stateType) { + StateType.GLOBAL -> java.util.List.of(convertStateMessage(state.global)) + StateType.STREAM -> state.stateMessages.map { convertStateMessage(it) } + else -> + java.util.List.of( + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.LEGACY) + .withData(state.legacyState) + ) + } + } + .orElse(generateEmptyInitialState(supportedStateType)) + } + + /** + * Generates an empty, initial state for use by the connector. + * + * @Param supportedStateType the [AirbyteStateType] supported by this connector. + * @return The empty, initial state. + */ + private fun generateEmptyInitialState( + supportedStateType: AirbyteStateMessage.AirbyteStateType + ): List { + // For backwards compatibility with existing connectors + if (supportedStateType == AirbyteStateMessage.AirbyteStateType.LEGACY) { + return java.util.List.of( + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.LEGACY) + .withData(Jsons.jsonNode(DbState())) + ) + } else if (supportedStateType == AirbyteStateMessage.AirbyteStateType.GLOBAL) { + val globalState = + AirbyteGlobalState() + .withSharedState(Jsons.jsonNode(CdcState())) + .withStreamStates(listOf()) + return java.util.List.of( + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.GLOBAL) + .withGlobal(globalState) + ) + } else { + return java.util.List.of( + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.STREAM) + .withStream(AirbyteStreamState()) + ) + } + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateManager.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateManager.kt new file mode 100644 index 000000000000..9588478c6ac5 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateManager.kt @@ -0,0 +1,176 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+ */
+package io.airbyte.cdk.integrations.source.relationaldb.state
+
+import com.google.common.base.Preconditions
+import io.airbyte.cdk.integrations.source.relationaldb.CdcStateManager
+import io.airbyte.cdk.integrations.source.relationaldb.CursorInfo
+import io.airbyte.protocol.models.v0.AirbyteStateMessage
+import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair
+import java.util.*
+import org.slf4j.Logger
+import org.slf4j.LoggerFactory
+
+/**
+ * Defines a manager for connector state. Connector state is used to keep track of the data synced
+ * by the connector.
+ *
+ * @param <T> The type of the state maintained by the manager.
+ * @param <S> The type of the stream(s) stored within the state maintained by the manager.
+ */
+interface StateManager {
+    /**
+     * Retrieves the [CdcStateManager] associated with the state manager.
+     *
+     * @return The [CdcStateManager]
+     * @throws UnsupportedOperationException if the state manager does not support tracking change
+     * data capture (CDC) state.
+     */
+    val cdcStateManager: CdcStateManager
+
+    /**
+     * Retrieves the raw state messages associated with the state manager. This is required for
+     * database-specific sync modes (e.g. Xmin) that need to handle and parse their own state.
+     *
+     * @return the list of airbyte state messages
+     * @throws UnsupportedOperationException if the state manager does not support retrieving raw
+     * state.
+     */
+    val rawStateMessages: List<AirbyteStateMessage>?
+
+    /**
+     * Retrieves the map of stream name/namespace tuple to the current cursor information for that
+     * stream.
+     *
+     * @return The map of stream name/namespace tuple to the current cursor information for that
+     * stream as maintained by this state manager.
+     */
+    val pairToCursorInfoMap: Map<AirbyteStreamNameNamespacePair, CursorInfo>
+
+    /**
+     * Generates an [AirbyteStateMessage] that represents the current state contained in the state
+     * manager.
+     *
+     * @param pair The [AirbyteStreamNameNamespacePair] that represents a stream managed by the
+     * state manager.
+     * @return The [AirbyteStateMessage] that represents the current state contained in the state
+     * manager.
+     */
+    fun toState(pair: Optional<AirbyteStreamNameNamespacePair>): AirbyteStateMessage
+
+    /**
+     * Retrieves an [Optional] possibly containing the cursor value tracked in the state associated
+     * with the provided stream name/namespace tuple.
+     *
+     * @param pair The [AirbyteStreamNameNamespacePair] which identifies a stream.
+     * @return An [Optional] possibly containing the cursor value tracked in the state associated
+     * with the provided stream name/namespace tuple.
+     */
+    fun getCursor(pair: AirbyteStreamNameNamespacePair?): Optional<String> {
+        return getCursorInfo(pair).map { obj: CursorInfo -> obj.cursor }
+    }
+
+    /**
+     * Retrieves an [Optional] possibly containing the cursor field name associated with the cursor
+     * tracked in the state associated with the provided stream name/namespace tuple.
+     *
+     * @param pair The [AirbyteStreamNameNamespacePair] which identifies a stream.
+     * @return An [Optional] possibly containing the cursor field name associated with the cursor
+     * tracked in the state associated with the provided stream name/namespace tuple.
+     */
+    fun getCursorField(pair: AirbyteStreamNameNamespacePair?): Optional<String>? {
+        return getCursorInfo(pair).map { obj: CursorInfo -> obj.cursorField }
+    }
+
+    /**
+     * Retrieves an [Optional] possibly containing the original cursor value tracked in the state
+     * associated with the provided stream name/namespace tuple.
+     *
+     * @param pair The [AirbyteStreamNameNamespacePair] which identifies a stream.
+ * @return An [Optional] possibly containing the original cursor value tracked in the state + * associated with the provided stream name/namespace tuple. + */ + fun getOriginalCursor(pair: AirbyteStreamNameNamespacePair?): Optional? { + return getCursorInfo(pair).map { obj: CursorInfo -> obj.originalCursor } + } + + /** + * Retrieves an [Optional] possibly containing the original cursor field name associated with + * the cursor tracked in the state associated with the provided stream name/namespace tuple. + * + * @param pair The [AirbyteStreamNameNamespacePair] which identifies a stream. + * @return An [Optional] possibly containing the original cursor field name associated with the + * cursor tracked in the state associated with the provided stream name/namespace tuple. + */ + fun getOriginalCursorField(pair: AirbyteStreamNameNamespacePair?): Optional? { + return getCursorInfo(pair).map { obj: CursorInfo -> obj.originalCursorField } + } + + /** + * Retrieves the current cursor information stored in the state manager for the steam + * name/namespace tuple. + * + * @param pair The [AirbyteStreamNameNamespacePair] that represents a stream managed by the + * state manager. + * @return [Optional] that potentially contains the current cursor information for the given + * stream name/namespace tuple. + */ + fun getCursorInfo(pair: AirbyteStreamNameNamespacePair?): Optional { + return Optional.ofNullable(pairToCursorInfoMap!![pair]) + } + + /** + * Emits the current state maintained by the manager as an [AirbyteStateMessage]. + * + * @param pair The [AirbyteStreamNameNamespacePair] that represents a stream managed by the + * state manager. + * @return An [AirbyteStateMessage] that represents the current state maintained by the state + * manager. + */ + fun emit(pair: Optional): AirbyteStateMessage? { + return toState(pair) + } + + /** + * Updates the cursor associated with the provided stream name/namespace pair and emits the + * current state maintained by the state manager. + * + * @param pair The [AirbyteStreamNameNamespacePair] that represents a stream managed by the + * state manager. + * @param cursor The new value for the cursor associated with the + * [AirbyteStreamNameNamespacePair] that represents a stream managed by the state manager. + * @return An [AirbyteStateMessage] that represents the current state maintained by the state + * manager. + */ + fun updateAndEmit(pair: AirbyteStreamNameNamespacePair, cursor: String?): AirbyteStateMessage? { + return updateAndEmit(pair, cursor, 0L) + } + + fun updateAndEmit( + pair: AirbyteStreamNameNamespacePair, + cursor: String?, + cursorRecordCount: Long + ): AirbyteStateMessage? 
{ + val cursorInfo = getCursorInfo(pair) + Preconditions.checkState( + cursorInfo.isPresent, + "Could not find cursor information for stream: $pair" + ) + cursorInfo.get().setCursor(cursor) + if (cursorRecordCount > 0L) { + cursorInfo.get().setCursorRecordCount(cursorRecordCount) + } + LOGGER.debug( + "Updating cursor value for {} to {} (count {})...", + pair, + cursor, + cursorRecordCount + ) + return emit(Optional.ofNullable(pair)) + } + + companion object { + val LOGGER: Logger = LoggerFactory.getLogger(StateManager::class.java) + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateManagerFactory.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateManagerFactory.kt new file mode 100644 index 000000000000..2d34be63c3b8 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateManagerFactory.kt @@ -0,0 +1,141 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.source.relationaldb.state + +import io.airbyte.cdk.integrations.source.relationaldb.models.DbState +import io.airbyte.commons.json.Jsons +import io.airbyte.protocol.models.v0.AirbyteStateMessage +import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +/** Factory class that creates [StateManager] instances based on the provided state. */ +object StateManagerFactory { + private val LOGGER: Logger = LoggerFactory.getLogger(StateManagerFactory::class.java) + + /** + * Creates a [StateManager] based on the provided state object and catalog. This method will + * handle the conversion of the provided state to match the requested state manager based on the + * provided [AirbyteStateType]. + * + * @param supportedStateType The type of state supported by the connector. + * @param initialState The deserialized initial state that will be provided to the selected + * [StateManager]. + * @param catalog The [ConfiguredAirbyteCatalog] for the connector that will utilize the state + * manager. + * @return A newly created [StateManager] implementation based on the provided state. 
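+ *
+ * A usage sketch (illustrative; `initialState` and `catalog` are assumed to come from the
+ * connector's configured inputs):
+ * ```
+ * val stateManager =
+ *     StateManagerFactory.createStateManager(
+ *         AirbyteStateMessage.AirbyteStateType.STREAM,
+ *         initialState,
+ *         catalog
+ *     )
+ * ```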
+ */
+    fun createStateManager(
+        supportedStateType: AirbyteStateMessage.AirbyteStateType?,
+        initialState: List<AirbyteStateMessage>?,
+        catalog: ConfiguredAirbyteCatalog
+    ): StateManager {
+        if (initialState != null && !initialState.isEmpty()) {
+            val airbyteStateMessage = initialState[0]
+            when (supportedStateType) {
+                AirbyteStateMessage.AirbyteStateType.LEGACY -> {
+                    LOGGER.info(
+                        "Legacy state manager selected to manage state object with type {}.",
+                        airbyteStateMessage!!.type
+                    )
+                    @Suppress("deprecation")
+                    val retVal: StateManager =
+                        LegacyStateManager(
+                            Jsons.`object`(airbyteStateMessage.data, DbState::class.java),
+                            catalog
+                        )
+                    return retVal
+                }
+                AirbyteStateMessage.AirbyteStateType.GLOBAL -> {
+                    LOGGER.info(
+                        "Global state manager selected to manage state object with type {}.",
+                        airbyteStateMessage!!.type
+                    )
+                    return GlobalStateManager(generateGlobalState(airbyteStateMessage), catalog)
+                }
+                AirbyteStateMessage.AirbyteStateType.STREAM -> {
+                    LOGGER.info(
+                        "Stream state manager selected to manage state object with type {}.",
+                        airbyteStateMessage!!.type
+                    )
+                    return StreamStateManager(generateStreamState(initialState), catalog)
+                }
+                else -> {
+                    LOGGER.info(
+                        "Stream state manager selected to manage state object with type {}.",
+                        airbyteStateMessage!!.type
+                    )
+                    return StreamStateManager(generateStreamState(initialState), catalog)
+                }
+            }
+        } else {
+            throw IllegalArgumentException(
+                "Failed to create state manager due to empty state list."
+            )
+        }
+    }
+
+    /**
+     * Handles the conversion between a different state type and the global state. This method
+     * handles the following transitions:
+     *
+     * * Stream -> Global (not supported, results in [IllegalArgumentException])
+     * * Legacy -> Global (supported)
+     * * Global -> Global (supported/no conversion required)
+     *
+     * @param airbyteStateMessage The current state that is to be converted to global state.
+     * @return The converted state message.
+     * @throws IllegalArgumentException if unable to convert between the given state type and
+     * global.
+     */
+    private fun generateGlobalState(airbyteStateMessage: AirbyteStateMessage): AirbyteStateMessage {
+        var globalStateMessage = airbyteStateMessage
+
+        when (airbyteStateMessage.type) {
+            AirbyteStateMessage.AirbyteStateType.STREAM ->
+                throw IllegalArgumentException(
+                    "Unable to convert connector state from stream to global. Please reset the connection to continue."
+                )
+            AirbyteStateMessage.AirbyteStateType.LEGACY -> {
+                globalStateMessage =
+                    StateGeneratorUtils.convertLegacyStateToGlobalState(airbyteStateMessage)
+                LOGGER.info(
+                    "Legacy state with type {} converted to global state.",
+                    airbyteStateMessage.type
+                )
+            }
+            AirbyteStateMessage.AirbyteStateType.GLOBAL -> {}
+            else -> {}
+        }
+        return globalStateMessage
+    }
+
+    /**
+     * Handles the conversion between a different state type and the stream state. This method
+     * handles the following transitions:
+     *
+     * * Global -> Stream (not supported, results in [IllegalArgumentException])
+     * * Legacy -> Stream (supported)
+     * * Stream -> Stream (supported/no conversion required)
+     *
+     * @param states The list of current states.
+     * @return The converted state messages.
+     * @throws IllegalArgumentException if unable to convert between the given state type and
+     * stream.
+ */ + private fun generateStreamState(states: List): List { + val airbyteStateMessage = states[0] + val streamStates: MutableList = ArrayList() + when (airbyteStateMessage!!.type) { + AirbyteStateMessage.AirbyteStateType.GLOBAL -> + throw IllegalArgumentException( + "Unable to convert connector state from global to stream. Please reset the connection to continue." + ) + AirbyteStateMessage.AirbyteStateType.LEGACY -> + streamStates.addAll( + StateGeneratorUtils.convertLegacyStateToStreamState(airbyteStateMessage) + ) + AirbyteStateMessage.AirbyteStateType.STREAM -> streamStates.addAll(states) + else -> streamStates.addAll(states) + } + return streamStates + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StreamStateManager.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StreamStateManager.kt new file mode 100644 index 000000000000..e09c7d90d03b --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StreamStateManager.kt @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.source.relationaldb.state + +import io.airbyte.cdk.integrations.source.relationaldb.CdcStateManager +import io.airbyte.commons.json.Jsons +import io.airbyte.protocol.models.v0.AirbyteStateMessage +import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair +import io.airbyte.protocol.models.v0.AirbyteStreamState +import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog +import java.util.* +import java.util.function.Supplier +import java.util.stream.Collectors +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +/** + * Per-stream implementation of the [StateManager] interface. + * + * This implementation generates a state object for each stream detected in catalog/map of known + * streams to cursor information stored in this manager. + */ +class StreamStateManager +/** + * Constructs a new [StreamStateManager] that is seeded with the provided [AirbyteStateMessage]. + * + * @param airbyteStateMessages The initial state represented as a list of [AirbyteStateMessage]s. + * @param catalog The [ConfiguredAirbyteCatalog] for the connector associated with this state + * manager. + */ +( + private val rawAirbyteStateMessages: List, + catalog: ConfiguredAirbyteCatalog +) : + AbstractStateManager( + catalog, + Supplier { + rawAirbyteStateMessages.stream().map { it.stream }.collect(Collectors.toList()) + }, + StateGeneratorUtils.CURSOR_FUNCTION, + StateGeneratorUtils.CURSOR_FIELD_FUNCTION, + StateGeneratorUtils.CURSOR_RECORD_COUNT_FUNCTION, + StateGeneratorUtils.NAME_NAMESPACE_PAIR_FUNCTION + ) { + override val cdcStateManager: CdcStateManager + get() { + throw UnsupportedOperationException( + "CDC state management not supported by stream state manager." + ) + } + + override val rawStateMessages: List? 
+ get() = rawAirbyteStateMessages + + override fun toState(pair: Optional): AirbyteStateMessage { + if (pair.isPresent) { + val pairToCursorInfoMap = pairToCursorInfoMap + val cursorInfo = Optional.ofNullable(pairToCursorInfoMap!![pair.get()]) + + if (cursorInfo.isPresent) { + LOGGER.debug("Generating state message for {}...", pair) + return AirbyteStateMessage() + .withType( + AirbyteStateMessage.AirbyteStateType.STREAM + ) // Temporarily include legacy state for backwards compatibility with the + // platform + .withData( + Jsons.jsonNode(StateGeneratorUtils.generateDbState(pairToCursorInfoMap)) + ) + .withStream( + StateGeneratorUtils.generateStreamState(pair.get(), cursorInfo.get()) + ) + } else { + LOGGER.warn( + "Cursor information could not be located in state for stream {}. Returning a new, empty state message...", + pair + ) + return AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.STREAM) + .withStream(AirbyteStreamState()) + } + } else { + LOGGER.warn("Stream not provided. Returning a new, empty state message...") + return AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.STREAM) + .withStream(AirbyteStreamState()) + } + } + + companion object { + private val LOGGER: Logger = LoggerFactory.getLogger(StreamStateManager::class.java) + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/debezium/AirbyteDebeziumHandlerTest.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/debezium/AirbyteDebeziumHandlerTest.java deleted file mode 100644 index 95b8e5e26d96..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/debezium/AirbyteDebeziumHandlerTest.java +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.debezium; - -import com.google.common.collect.Lists; -import io.airbyte.protocol.models.Field; -import io.airbyte.protocol.models.JsonSchemaType; -import io.airbyte.protocol.models.v0.AirbyteCatalog; -import io.airbyte.protocol.models.v0.CatalogHelpers; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog; -import io.airbyte.protocol.models.v0.SyncMode; -import java.util.List; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -public class AirbyteDebeziumHandlerTest { - - @Test - public void shouldUseCdcTestShouldReturnTrue() { - final AirbyteCatalog catalog = new AirbyteCatalog().withStreams(List.of( - CatalogHelpers.createAirbyteStream( - "MODELS_STREAM_NAME", - "MODELS_SCHEMA", - Field.of("COL_ID", JsonSchemaType.NUMBER), - Field.of("COL_MAKE_ID", JsonSchemaType.NUMBER), - Field.of("COL_MODEL", JsonSchemaType.STRING)) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL)) - .withSourceDefinedPrimaryKey(List.of(List.of("COL_ID"))))); - final ConfiguredAirbyteCatalog configuredCatalog = CatalogHelpers - .toDefaultConfiguredCatalog(catalog); - // set all streams to incremental. 
- configuredCatalog.getStreams().forEach(s -> s.setSyncMode(SyncMode.INCREMENTAL)); - - Assertions.assertTrue(AirbyteDebeziumHandler.isAnyStreamIncrementalSyncMode(configuredCatalog)); - } - - @Test - public void shouldUseCdcTestShouldReturnFalse() { - final AirbyteCatalog catalog = new AirbyteCatalog().withStreams(List.of( - CatalogHelpers.createAirbyteStream( - "MODELS_STREAM_NAME", - "MODELS_SCHEMA", - Field.of("COL_ID", JsonSchemaType.NUMBER), - Field.of("COL_MAKE_ID", JsonSchemaType.NUMBER), - Field.of("COL_MODEL", JsonSchemaType.STRING)) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL)) - .withSourceDefinedPrimaryKey(List.of(List.of("COL_ID"))))); - final ConfiguredAirbyteCatalog configuredCatalog = CatalogHelpers - .toDefaultConfiguredCatalog(catalog); - - Assertions.assertFalse(AirbyteDebeziumHandler.isAnyStreamIncrementalSyncMode(configuredCatalog)); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/debezium/AirbyteFileOffsetBackingStoreTest.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/debezium/AirbyteFileOffsetBackingStoreTest.java deleted file mode 100644 index 70fdefe0dd9e..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/debezium/AirbyteFileOffsetBackingStoreTest.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.debezium; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import io.airbyte.cdk.integrations.debezium.internals.AirbyteFileOffsetBackingStore; -import io.airbyte.commons.io.IOs; -import io.airbyte.commons.json.Jsons; -import io.airbyte.commons.resources.MoreResources; -import java.io.IOException; -import java.nio.file.Files; -import java.nio.file.Path; -import java.util.Map; -import java.util.Optional; -import org.junit.jupiter.api.Test; - -class AirbyteFileOffsetBackingStoreTest { - - @SuppressWarnings("UnstableApiUsage") - @Test - void test() throws IOException { - final Path testRoot = Files.createTempDirectory(Path.of("/tmp"), "offset-store-test"); - - final byte[] bytes = MoreResources.readBytes("test_debezium_offset.dat"); - final Path templateFilePath = testRoot.resolve("template_offset.dat"); - IOs.writeFile(templateFilePath, bytes); - - final Path writeFilePath = testRoot.resolve("offset.dat"); - final Path secondWriteFilePath = testRoot.resolve("offset_2.dat"); - - final AirbyteFileOffsetBackingStore offsetStore = new AirbyteFileOffsetBackingStore(templateFilePath, Optional.empty()); - final Map offset = offsetStore.read(); - - final AirbyteFileOffsetBackingStore offsetStore2 = new AirbyteFileOffsetBackingStore(writeFilePath, Optional.empty()); - offsetStore2.persist(Jsons.jsonNode(offset)); - final Map stateFromOffsetStore2 = offsetStore2.read(); - - final AirbyteFileOffsetBackingStore offsetStore3 = new AirbyteFileOffsetBackingStore(secondWriteFilePath, Optional.empty()); - offsetStore3.persist(Jsons.jsonNode(stateFromOffsetStore2)); - final Map stateFromOffsetStore3 = offsetStore3.read(); - - // verify that, after a round trip through the offset store, we get back the same data. - assertEquals(stateFromOffsetStore2, stateFromOffsetStore3); - // verify that the file written by the offset store is identical to the template file. 
- assertTrue(com.google.common.io.Files.equal(secondWriteFilePath.toFile(), writeFilePath.toFile())); - } - - @Test - void test2() throws IOException { - final Path testRoot = Files.createTempDirectory(Path.of("/tmp"), "offset-store-test"); - - final byte[] bytes = MoreResources.readBytes("test_debezium_offset.dat"); - final Path templateFilePath = testRoot.resolve("template_offset.dat"); - IOs.writeFile(templateFilePath, bytes); - - final Path writeFilePath = testRoot.resolve("offset.dat"); - final Path secondWriteFilePath = testRoot.resolve("offset_2.dat"); - - final AirbyteFileOffsetBackingStore offsetStore = new AirbyteFileOffsetBackingStore(templateFilePath, Optional.of("orders")); - final Map offset = offsetStore.read(); - - final AirbyteFileOffsetBackingStore offsetStore2 = new AirbyteFileOffsetBackingStore(writeFilePath, Optional.of("orders")); - offsetStore2.persist(Jsons.jsonNode(offset)); - final Map stateFromOffsetStore2 = offsetStore2.read(); - - final AirbyteFileOffsetBackingStore offsetStore3 = new AirbyteFileOffsetBackingStore(secondWriteFilePath, Optional.of("orders")); - offsetStore3.persist(Jsons.jsonNode(stateFromOffsetStore2)); - final Map stateFromOffsetStore3 = offsetStore3.read(); - - // verify that, after a round trip through the offset store, we get back the same data. - assertEquals(stateFromOffsetStore2, stateFromOffsetStore3); - // verify that the file written by the offset store is identical to the template file. - assertTrue(com.google.common.io.Files.equal(secondWriteFilePath.toFile(), writeFilePath.toFile())); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/debezium/DebeziumRecordPublisherTest.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/debezium/DebeziumRecordPublisherTest.java deleted file mode 100644 index be906557f431..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/debezium/DebeziumRecordPublisherTest.java +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
- */ - -package io.airbyte.cdk.integrations.debezium; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import com.google.common.collect.ImmutableList; -import io.airbyte.cdk.integrations.debezium.internals.RelationalDbDebeziumPropertiesManager; -import io.airbyte.protocol.models.Field; -import io.airbyte.protocol.models.JsonSchemaType; -import io.airbyte.protocol.models.v0.CatalogHelpers; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog; -import io.airbyte.protocol.models.v0.SyncMode; -import java.util.regex.Pattern; -import org.junit.jupiter.api.Test; - -class DebeziumRecordPublisherTest { - - @Test - public void testTableIncludelistCreation() { - final ConfiguredAirbyteCatalog catalog = new ConfiguredAirbyteCatalog().withStreams(ImmutableList.of( - CatalogHelpers.createConfiguredAirbyteStream("id_and_name", "public").withSyncMode(SyncMode.INCREMENTAL), - CatalogHelpers.createConfiguredAirbyteStream("id_,something", "public").withSyncMode(SyncMode.INCREMENTAL), - CatalogHelpers.createConfiguredAirbyteStream("n\"aMéS", "public").withSyncMode(SyncMode.INCREMENTAL))); - - final String expectedWhitelist = "\\Qpublic.id_and_name\\E,\\Qpublic.id_\\,something\\E,\\Qpublic.n\"aMéS\\E"; - final String actualWhitelist = RelationalDbDebeziumPropertiesManager.getTableIncludelist(catalog); - - assertEquals(expectedWhitelist, actualWhitelist); - } - - @Test - public void testTableIncludelistFiltersFullRefresh() { - final ConfiguredAirbyteCatalog catalog = new ConfiguredAirbyteCatalog().withStreams(ImmutableList.of( - CatalogHelpers.createConfiguredAirbyteStream("id_and_name", "public").withSyncMode(SyncMode.INCREMENTAL), - CatalogHelpers.createConfiguredAirbyteStream("id_and_name2", "public").withSyncMode(SyncMode.FULL_REFRESH))); - - final String expectedWhitelist = "\\Qpublic.id_and_name\\E"; - final String actualWhitelist = RelationalDbDebeziumPropertiesManager.getTableIncludelist(catalog); - - assertEquals(expectedWhitelist, actualWhitelist); - } - - @Test - public void testColumnIncludelistFiltersFullRefresh() { - final ConfiguredAirbyteCatalog catalog = new ConfiguredAirbyteCatalog().withStreams(ImmutableList.of( - CatalogHelpers.createConfiguredAirbyteStream( - "id_and_name", - "public", - Field.of("fld1", JsonSchemaType.NUMBER), Field.of("fld2", JsonSchemaType.STRING)).withSyncMode(SyncMode.INCREMENTAL), - CatalogHelpers.createConfiguredAirbyteStream("id_,something", "public").withSyncMode(SyncMode.INCREMENTAL), - CatalogHelpers.createConfiguredAirbyteStream("id_and_name2", "public").withSyncMode(SyncMode.FULL_REFRESH), - CatalogHelpers.createConfiguredAirbyteStream("n\"aMéS", "public").withSyncMode(SyncMode.INCREMENTAL))); - - final String expectedWhitelist = "\\Qpublic.id_and_name\\E\\.(\\Qfld2\\E|\\Qfld1\\E),\\Qpublic.id_\\,something\\E,\\Qpublic.n\"aMéS\\E"; - final String actualWhitelist = RelationalDbDebeziumPropertiesManager.getColumnIncludeList(catalog); - - assertEquals(expectedWhitelist, actualWhitelist); - } - - @Test - public void testColumnIncludeListEscaping() { - // final String a = "public\\.products\\*\\^\\$\\+-\\\\"; - // final String b = "public.products*^$+-\\"; - // final Pattern p = Pattern.compile(a, Pattern.UNIX_LINES); - // assertTrue(p.matcher(b).find()); - // assertTrue(Pattern.compile(Pattern.quote(b)).matcher(b).find()); - - final ConfiguredAirbyteCatalog catalog = new 
ConfiguredAirbyteCatalog().withStreams(ImmutableList.of( - CatalogHelpers.createConfiguredAirbyteStream( - "id_and_name", - "public", - Field.of("fld1", JsonSchemaType.NUMBER), Field.of("fld2", JsonSchemaType.STRING)).withSyncMode(SyncMode.INCREMENTAL))); - - final String anchored = "^" + RelationalDbDebeziumPropertiesManager.getColumnIncludeList(catalog) + "$"; - final Pattern pattern = Pattern.compile(anchored); - - assertTrue(pattern.matcher("public.id_and_name.fld1").find()); - assertTrue(pattern.matcher("public.id_and_name.fld2").find()); - assertFalse(pattern.matcher("ic.id_and_name.fl").find()); - assertFalse(pattern.matcher("ppppublic.id_and_name.fld2333").find()); - assertFalse(pattern.matcher("public.id_and_name.fld_wrong_wrong").find()); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/debezium/internals/AirbyteSchemaHistoryStorageTest.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/debezium/internals/AirbyteSchemaHistoryStorageTest.java deleted file mode 100644 index 482936bd54aa..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/debezium/internals/AirbyteSchemaHistoryStorageTest.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.debezium.internals; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import io.airbyte.cdk.integrations.debezium.internals.AirbyteSchemaHistoryStorage.SchemaHistory; -import io.airbyte.commons.json.Jsons; -import io.airbyte.commons.resources.MoreResources; -import java.io.IOException; -import java.util.Optional; -import org.junit.jupiter.api.Test; - -public class AirbyteSchemaHistoryStorageTest { - - @Test - public void testForContentBiggerThan1MBLimit() throws IOException { - final String contentReadDirectlyFromFile = MoreResources.readResource("dbhistory_greater_than_1_mb.dat"); - - final AirbyteSchemaHistoryStorage schemaHistoryStorageFromUncompressedContent = AirbyteSchemaHistoryStorage.initializeDBHistory( - new SchemaHistory<>(Optional.of(Jsons.jsonNode(contentReadDirectlyFromFile)), - false), - true); - final SchemaHistory schemaHistoryFromUncompressedContent = schemaHistoryStorageFromUncompressedContent.read(); - - assertTrue(schemaHistoryFromUncompressedContent.isCompressed()); - assertNotNull(schemaHistoryFromUncompressedContent.schema()); - assertEquals(contentReadDirectlyFromFile, schemaHistoryStorageFromUncompressedContent.readUncompressed()); - - final AirbyteSchemaHistoryStorage schemaHistoryStorageFromCompressedContent = AirbyteSchemaHistoryStorage.initializeDBHistory( - new SchemaHistory<>(Optional.of(Jsons.jsonNode(schemaHistoryFromUncompressedContent.schema())), - true), - true); - final SchemaHistory schemaHistoryFromCompressedContent = schemaHistoryStorageFromCompressedContent.read(); - - assertTrue(schemaHistoryFromCompressedContent.isCompressed()); - assertNotNull(schemaHistoryFromCompressedContent.schema()); - assertEquals(schemaHistoryFromUncompressedContent.schema(), schemaHistoryFromCompressedContent.schema()); - } - - @Test - public void sizeTest() throws IOException { - assertEquals(5.881045341491699, - 
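The commented-out experiment and the anchored assertions in testColumnIncludeListEscaping above both turn on Pattern.quote, which wraps a literal in \Q...\E so that regex metacharacters in table and column names match verbatim; the expected strings in the earlier include-list tests are exactly such quoted literals joined with commas (Debezium's list separator, which is why the real properties manager additionally escapes embedded commas as \,). A small sketch of the quoting behavior; buildIncludeList is an illustrative helper, and only Pattern.quote, Pattern.compile, and Pattern.matches are real JDK API:

    import java.util.List;
    import java.util.regex.Pattern;
    import java.util.stream.Collectors;

    public class IncludeListQuotingSketch {

      // Illustrative helper mirroring the comma-joined "\Q...\E" strings asserted above.
      static String buildIncludeList(final List<String> qualifiedNames) {
        return qualifiedNames.stream().map(Pattern::quote).collect(Collectors.joining(","));
      }

      public static void main(final String[] args) {
        // Unquoted, '.' is a metacharacter: this pattern also matches "public_products".
        System.out.println(Pattern.matches("public.products", "public_products")); // true
        // Quoted, every character is literal, even "*^$+-\".
        final String quoted = Pattern.quote("public.products*^$+-\\");
        System.out.println(Pattern.compile(quoted).matcher("public.products*^$+-\\").find()); // true
        System.out.println(buildIncludeList(List.of("public.id_and_name", "public.n\"aMéS")));
        // prints: \Qpublic.id_and_name\E,\Qpublic.n"aMéS\E
      }
    }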
AirbyteSchemaHistoryStorage.calculateSizeOfStringInMB(MoreResources.readResource("dbhistory_greater_than_1_mb.dat"))); - assertEquals(0.0038671493530273438, - AirbyteSchemaHistoryStorage.calculateSizeOfStringInMB(MoreResources.readResource("dbhistory_less_than_1_mb.dat"))); - } - - @Test - public void testForContentLessThan1MBLimit() throws IOException { - final String contentReadDirectlyFromFile = MoreResources.readResource("dbhistory_less_than_1_mb.dat"); - - final AirbyteSchemaHistoryStorage schemaHistoryStorageFromUncompressedContent = AirbyteSchemaHistoryStorage.initializeDBHistory( - new SchemaHistory<>(Optional.of(Jsons.jsonNode(contentReadDirectlyFromFile)), - false), - true); - final SchemaHistory schemaHistoryFromUncompressedContent = schemaHistoryStorageFromUncompressedContent.read(); - - assertFalse(schemaHistoryFromUncompressedContent.isCompressed()); - assertNotNull(schemaHistoryFromUncompressedContent.schema()); - assertEquals(contentReadDirectlyFromFile, schemaHistoryFromUncompressedContent.schema()); - - final AirbyteSchemaHistoryStorage schemaHistoryStorageFromCompressedContent = AirbyteSchemaHistoryStorage.initializeDBHistory( - new SchemaHistory<>(Optional.of(Jsons.jsonNode(schemaHistoryFromUncompressedContent.schema())), - false), - true); - final SchemaHistory schemaHistoryFromCompressedContent = schemaHistoryStorageFromCompressedContent.read(); - - assertFalse(schemaHistoryFromCompressedContent.isCompressed()); - assertNotNull(schemaHistoryFromCompressedContent.schema()); - assertEquals(schemaHistoryFromUncompressedContent.schema(), schemaHistoryFromCompressedContent.schema()); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumConverterUtilsTest.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumConverterUtilsTest.java deleted file mode 100644 index 59312c888703..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumConverterUtilsTest.java +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
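sizeTest above pins exact megabyte sizes for the two fixtures, which gate whether the schema history is compressed. One plausible way such a measurement works, and this is an assumption about calculateSizeOfStringInMB rather than its actual body, is the UTF-8 byte length divided by 1024 * 1024:

    import java.nio.charset.StandardCharsets;

    public class StringSizeSketch {

      // Assumed implementation: size of the UTF-8 encoding, in mebibytes.
      static double sizeOfStringInMB(final String s) {
        return s.getBytes(StandardCharsets.UTF_8).length / (1024.0 * 1024.0);
      }

      public static void main(final String[] args) {
        final String oneKiB = "x".repeat(1024);
        System.out.println(sizeOfStringInMB(oneKiB)); // 0.0009765625
        // 1024 KiB crosses the 1 MB threshold that triggers compression above.
        System.out.println(sizeOfStringInMB(oneKiB.repeat(1024)) >= 1.0); // true
      }
    }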
- */ - -package io.airbyte.cdk.integrations.debezium.internals; - -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -import io.debezium.spi.converter.RelationalColumn; -import java.sql.Timestamp; -import java.time.Duration; -import java.time.LocalDate; -import java.time.LocalDateTime; -import java.time.LocalTime; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; - -class DebeziumConverterUtilsTest { - - @Test - public void convertDefaultValueTest() { - - final RelationalColumn relationalColumn = mock(RelationalColumn.class); - - when(relationalColumn.isOptional()).thenReturn(true); - Object actualColumnDefaultValue = DebeziumConverterUtils.convertDefaultValue(relationalColumn); - Assertions.assertNull(actualColumnDefaultValue, "Default value for optional relational column should be null"); - - when(relationalColumn.isOptional()).thenReturn(false); - when(relationalColumn.hasDefaultValue()).thenReturn(false); - actualColumnDefaultValue = DebeziumConverterUtils.convertDefaultValue(relationalColumn); - Assertions.assertNull(actualColumnDefaultValue); - - when(relationalColumn.isOptional()).thenReturn(false); - when(relationalColumn.hasDefaultValue()).thenReturn(true); - final String expectedColumnDefaultValue = "default value"; - when(relationalColumn.defaultValue()).thenReturn(expectedColumnDefaultValue); - actualColumnDefaultValue = DebeziumConverterUtils.convertDefaultValue(relationalColumn); - Assertions.assertEquals(actualColumnDefaultValue, expectedColumnDefaultValue); - } - - @Test - public void convertLocalDate() { - final LocalDate localDate = LocalDate.of(2021, 1, 1); - - final String actual = DebeziumConverterUtils.convertDate(localDate); - Assertions.assertEquals("2021-01-01T00:00:00Z", actual); - } - - @Test - public void convertTLocalTime() { - final LocalTime localTime = LocalTime.of(8, 1, 1); - final String actual = DebeziumConverterUtils.convertDate(localTime); - Assertions.assertEquals("08:01:01", actual); - } - - @Test - public void convertLocalDateTime() { - final LocalDateTime localDateTime = LocalDateTime.of(2021, 1, 1, 8, 1, 1); - - final String actual = DebeziumConverterUtils.convertDate(localDateTime); - Assertions.assertEquals("2021-01-01T08:01:01Z", actual); - } - - @Test - @Disabled - public void convertDuration() { - final Duration duration = Duration.ofHours(100_000); - - final String actual = DebeziumConverterUtils.convertDate(duration); - Assertions.assertEquals("1981-05-29T20:00:00Z", actual); - } - - @Test - public void convertTimestamp() { - final LocalDateTime localDateTime = LocalDateTime.of(2021, 1, 1, 8, 1, 1); - final Timestamp timestamp = Timestamp.valueOf(localDateTime); - - final String actual = DebeziumConverterUtils.convertDate(timestamp); - Assertions.assertEquals("2021-01-01T08:01:01.000000Z", actual); - } - - @Test - @Disabled - public void convertNumber() { - final Number number = 100_000; - - final String actual = DebeziumConverterUtils.convertDate(number); - Assertions.assertEquals("1970-01-01T03:01:40Z", actual); - } - - @Test - public void convertStringDateFormat() { - final String stringValue = "2021-01-01T00:00:00Z"; - - final String actual = DebeziumConverterUtils.convertDate(stringValue); - Assertions.assertEquals("2021-01-01T00:00:00Z", actual); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumMessageProducerTest.java 
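The convert* tests above fix the output shapes for each temporal type: dates widen to a midnight UTC instant, times keep HH:mm:ss, and SQL timestamps keep microsecond precision. A sketch that reproduces those shapes with plain java.time formatting (the type dispatch is an assumption; only the expected strings come from the tests):

    import java.sql.Timestamp;
    import java.time.LocalDate;
    import java.time.LocalDateTime;
    import java.time.LocalTime;
    import java.time.format.DateTimeFormatter;

    public class ConvertDateSketch {

      static String convert(final Object value) {
        if (value instanceof LocalDate d) {
          // "2021-01-01" becomes "2021-01-01T00:00:00Z".
          return d.atStartOfDay().format(DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss'Z'"));
        }
        if (value instanceof LocalTime t) {
          return t.format(DateTimeFormatter.ofPattern("HH:mm:ss")); // "08:01:01"
        }
        if (value instanceof Timestamp ts) {
          // Microsecond precision: "2021-01-01T08:01:01.000000Z".
          return ts.toLocalDateTime().format(DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSSSSS'Z'"));
        }
        if (value instanceof LocalDateTime dt) {
          return dt.format(DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss'Z'")); // "2021-01-01T08:01:01Z"
        }
        return value.toString(); // already-formatted strings pass through unchanged
      }

      public static void main(final String[] args) {
        System.out.println(convert(LocalDate.of(2021, 1, 1)));
        System.out.println(convert(LocalTime.of(8, 1, 1)));
        System.out.println(convert(LocalDateTime.of(2021, 1, 1, 8, 1, 1)));
        System.out.println(convert(Timestamp.valueOf(LocalDateTime.of(2021, 1, 1, 8, 1, 1))));
      }
    }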
b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumMessageProducerTest.java deleted file mode 100644 index 5a9a4b9a9f84..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumMessageProducerTest.java +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Copyright (c) 2024 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.debezium.internals; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import io.airbyte.cdk.integrations.debezium.CdcStateHandler; -import io.airbyte.cdk.integrations.debezium.CdcTargetPosition; -import io.airbyte.protocol.models.v0.AirbyteMessage; -import io.airbyte.protocol.models.v0.AirbyteStateMessage; -import io.airbyte.protocol.models.v0.AirbyteStateMessage.AirbyteStateType; -import java.util.HashMap; -import java.util.Map; -import java.util.Optional; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -public class DebeziumMessageProducerTest { - - private DebeziumMessageProducer producer; - - CdcStateHandler cdcStateHandler; - CdcTargetPosition targetPosition; - DebeziumEventConverter eventConverter; - AirbyteFileOffsetBackingStore offsetManager; - AirbyteSchemaHistoryStorage schemaHistoryManager; - - private static Map OFFSET_MANAGER_READ = new HashMap<>(Map.of("key", "value")); - private static Map OFFSET_MANAGER_READ2 = new HashMap<>(Map.of("key2", "value2")); - - private static AirbyteSchemaHistoryStorage.SchemaHistory SCHEMA = new AirbyteSchemaHistoryStorage.SchemaHistory("schema", false); - - private static AirbyteStateMessage STATE_MESSAGE = new AirbyteStateMessage().withType(AirbyteStateType.GLOBAL); - - @BeforeEach - void setUp() { - cdcStateHandler = mock(CdcStateHandler.class); - when(cdcStateHandler.isCdcCheckpointEnabled()).thenReturn(true); - targetPosition = mock(CdcTargetPosition.class); - eventConverter = mock(DebeziumEventConverter.class); - offsetManager = mock(AirbyteFileOffsetBackingStore.class); - when(offsetManager.read()).thenReturn(OFFSET_MANAGER_READ); - schemaHistoryManager = mock(AirbyteSchemaHistoryStorage.class); - when(schemaHistoryManager.read()).thenReturn(SCHEMA); - producer = new DebeziumMessageProducer(cdcStateHandler, targetPosition, eventConverter, offsetManager, Optional.of(schemaHistoryManager)); - } - - @Test - void testProcessRecordMessage() { - ChangeEventWithMetadata message = mock(ChangeEventWithMetadata.class); - - when(targetPosition.isSameOffset(any(), any())).thenReturn(true); - producer.processRecordMessage(null, message); - verify(eventConverter).toAirbyteMessage(message); - assertFalse(producer.shouldEmitStateMessage(null)); - } - - @Test - void testProcessRecordMessageWithStateMessage() { - ChangeEventWithMetadata message = mock(ChangeEventWithMetadata.class); - - when(targetPosition.isSameOffset(any(), any())).thenReturn(false); - when(targetPosition.isEventAheadOffset(OFFSET_MANAGER_READ, message)).thenReturn(true); - producer.processRecordMessage(null, message); - verify(eventConverter).toAirbyteMessage(message); - assertTrue(producer.shouldEmitStateMessage(null)); - - 
when(cdcStateHandler.isCdcCheckpointEnabled()).thenReturn(false); - when(cdcStateHandler.saveState(eq(OFFSET_MANAGER_READ), eq(SCHEMA))).thenReturn(new AirbyteMessage().withState(STATE_MESSAGE)); - - assertEquals(producer.generateStateMessageAtCheckpoint(null), STATE_MESSAGE); - } - - @Test - void testGenerateFinalMessageNoProgress() { - when(cdcStateHandler.saveState(eq(OFFSET_MANAGER_READ), eq(SCHEMA))).thenReturn(new AirbyteMessage().withState(STATE_MESSAGE)); - - // initialOffset will be OFFSET_MANAGER_READ, final state would be OFFSET_MANAGER_READ2. - // Mock CDC handler will only accept OFFSET_MANAGER_READ. - when(offsetManager.read()).thenReturn(OFFSET_MANAGER_READ2); - - when(targetPosition.isSameOffset(OFFSET_MANAGER_READ, OFFSET_MANAGER_READ2)).thenReturn(true); - - assertEquals(producer.createFinalStateMessage(null), STATE_MESSAGE); - } - - @Test - void testGenerateFinalMessageWithProgress() { - when(cdcStateHandler.saveState(eq(OFFSET_MANAGER_READ2), eq(SCHEMA))).thenReturn(new AirbyteMessage().withState(STATE_MESSAGE)); - - // initialOffset will be OFFSET_MANAGER_READ, final state would be OFFSET_MANAGER_READ2. - // Mock CDC handler will only accept OFFSET_MANAGER_READ2. - when(offsetManager.read()).thenReturn(OFFSET_MANAGER_READ2); - when(targetPosition.isSameOffset(OFFSET_MANAGER_READ, OFFSET_MANAGER_READ2)).thenReturn(false); - - assertEquals(producer.createFinalStateMessage(null), STATE_MESSAGE); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumRecordIteratorTest.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumRecordIteratorTest.java deleted file mode 100644 index e386b100c647..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumRecordIteratorTest.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
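The producer tests above encode the checkpointing rule: a state message is worth emitting only when the current offset differs from the last checkpointed one and the newest event has actually moved past it. Compressed into a sketch (field and method names are illustrative, not the producer's internals; LSNs stand in for generic offsets):

    import java.util.Map;

    public class CheckpointDecisionSketch {

      private Map<String, String> lastCheckpointedOffset = Map.of("lsn", "100");

      // Stand-ins for CdcTargetPosition.isSameOffset / isEventAheadOffset.
      static boolean isSameOffset(final Map<String, String> a, final Map<String, String> b) {
        return a.equals(b);
      }

      static boolean isEventAhead(final Map<String, String> committed, final Map<String, String> event) {
        return Long.parseLong(event.get("lsn")) > Long.parseLong(committed.get("lsn"));
      }

      boolean shouldEmitState(final Map<String, String> currentOffset, final Map<String, String> eventOffset) {
        // Emit only when replication has really advanced; otherwise the checkpoint is redundant.
        return !isSameOffset(lastCheckpointedOffset, currentOffset)
            && isEventAhead(lastCheckpointedOffset, eventOffset);
      }

      public static void main(final String[] args) {
        final CheckpointDecisionSketch producer = new CheckpointDecisionSketch();
        System.out.println(producer.shouldEmitState(Map.of("lsn", "100"), Map.of("lsn", "100"))); // false
        System.out.println(producer.shouldEmitState(Map.of("lsn", "250"), Map.of("lsn", "250"))); // true
      }
    }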
- */ - -package io.airbyte.cdk.integrations.debezium.internals; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.mockito.Mockito.mock; - -import io.airbyte.cdk.integrations.debezium.CdcTargetPosition; -import io.debezium.engine.ChangeEvent; -import java.time.Duration; -import java.util.Collections; -import java.util.Map; -import java.util.concurrent.LinkedBlockingQueue; -import org.apache.kafka.connect.source.SourceRecord; -import org.junit.jupiter.api.Test; - -public class DebeziumRecordIteratorTest { - - @Test - public void getHeartbeatPositionTest() { - final DebeziumRecordIterator<Long> debeziumRecordIterator = new DebeziumRecordIterator<>(mock(LinkedBlockingQueue.class), - new CdcTargetPosition<>() { - - @Override - public boolean reachedTargetPosition(final ChangeEventWithMetadata changeEventWithMetadata) { - return false; - } - - @Override - public Long extractPositionFromHeartbeatOffset(final Map<String, ?> sourceOffset) { - return (long) sourceOffset.get("lsn"); - } - - }, - () -> false, - mock(DebeziumShutdownProcedure.class), - Duration.ZERO, - Duration.ZERO); - final Long lsn = debeziumRecordIterator.getHeartbeatPosition(new ChangeEvent<String, String>() { - - private final SourceRecord sourceRecord = new SourceRecord(null, Collections.singletonMap("lsn", 358824993496L), null, null, null); - - @Override - public String key() { - return null; - } - - @Override - public String value() { - return "{\"ts_ms\":1667616934701}"; - } - - @Override - public String destination() { - return null; - } - - public SourceRecord sourceRecord() { - return sourceRecord; - } - - }); - - assertEquals(lsn, 358824993496L); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumShutdownProcedureTest.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumShutdownProcedureTest.java deleted file mode 100644 index 335d157ed271..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/debezium/internals/DebeziumShutdownProcedureTest.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
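getHeartbeatPositionTest above makes a subtle point: for heartbeat events the replication position is read from the Kafka Connect SourceRecord's sourceOffset map, not from the event's JSON payload (which only carries ts_ms). Just that extraction, isolated (extractPosition is an illustrative name):

    import java.util.Map;

    public class HeartbeatPositionSketch {

      // The heartbeat's progress lives in the engine-level source offset map.
      static Long extractPosition(final Map<String, ?> sourceOffset) {
        return (Long) sourceOffset.get("lsn");
      }

      public static void main(final String[] args) {
        // Mirrors the fixture: value() would only carry ts_ms; the LSN rides in sourceOffset.
        final Map<String, ?> sourceOffset = Map.of("lsn", 358824993496L);
        System.out.println(extractPosition(sourceOffset)); // 358824993496
      }
    }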
- */ - -package io.airbyte.cdk.integrations.debezium.internals; - -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.atomic.AtomicInteger; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -public class DebeziumShutdownProcedureTest { - - @Test - public void test() throws InterruptedException { - final LinkedBlockingQueue<Integer> sourceQueue = new LinkedBlockingQueue<>(10); - final AtomicInteger recordsInserted = new AtomicInteger(); - final ExecutorService executorService = Executors.newSingleThreadExecutor(); - final DebeziumShutdownProcedure<Integer> debeziumShutdownProcedure = new DebeziumShutdownProcedure<>(sourceQueue, - executorService::shutdown, () -> recordsInserted.get() >= 99); - executorService.execute(() -> { - for (int i = 0; i < 100; i++) { - try { - sourceQueue.put(i); - recordsInserted.set(i); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - } - }); - - Thread.sleep(1000); - debeziumShutdownProcedure.initiateShutdownProcedure(); - - Assertions.assertEquals(99, recordsInserted.get()); - Assertions.assertEquals(0, sourceQueue.size()); - Assertions.assertEquals(100, debeziumShutdownProcedure.getRecordsRemainingAfterShutdown().size()); - - for (int i = 0; i < 100; i++) { - Assertions.assertEquals(i, debeziumShutdownProcedure.getRecordsRemainingAfterShutdown().poll()); - } - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/debezium/internals/RecordWaitTimeUtilTest.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/debezium/internals/RecordWaitTimeUtilTest.java deleted file mode 100644 index 64701dd40668..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/debezium/internals/RecordWaitTimeUtilTest.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
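The assertions above capture the contract: after initiateShutdownProcedure() the source queue is empty and every element produced, before or during shutdown, survives in the holdover queue in order. The JDK's BlockingQueue.drainTo gives the same guarantee in miniature (this sketch drains after the publisher has stopped; the real procedure also races a live publisher with a transfer thread):

    import java.util.concurrent.LinkedBlockingQueue;

    public class ShutdownDrainSketch {

      public static void main(final String[] args) throws InterruptedException {
        final LinkedBlockingQueue<Integer> sourceQueue = new LinkedBlockingQueue<>(10);
        final LinkedBlockingQueue<Integer> remaining = new LinkedBlockingQueue<>();

        final Thread publisher = new Thread(() -> {
          for (int i = 0; i < 10; i++) {
            try {
              sourceQueue.put(i);
            } catch (final InterruptedException e) {
              throw new RuntimeException(e);
            }
          }
        });
        publisher.start();
        publisher.join(); // "shutdown": the publisher has stopped producing

        // Drain everything left in the source queue into the holdover queue.
        sourceQueue.drainTo(remaining);

        System.out.println(sourceQueue.size()); // 0
        System.out.println(remaining.size());   // 10, in original order
      }
    }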
- */ - -package io.airbyte.cdk.integrations.debezium.internals; - -import static io.airbyte.cdk.integrations.debezium.internals.RecordWaitTimeUtil.MAX_FIRST_RECORD_WAIT_TIME; -import static io.airbyte.cdk.integrations.debezium.internals.RecordWaitTimeUtil.MIN_FIRST_RECORD_WAIT_TIME; -import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; - -import com.fasterxml.jackson.databind.JsonNode; -import io.airbyte.commons.json.Jsons; -import java.time.Duration; -import java.util.Collections; -import java.util.Map; -import java.util.Optional; -import org.junit.jupiter.api.Test; - -public class RecordWaitTimeUtilTest { - - @Test - void testGetFirstRecordWaitTime() { - final JsonNode emptyConfig = Jsons.jsonNode(Collections.emptyMap()); - assertDoesNotThrow(() -> RecordWaitTimeUtil.checkFirstRecordWaitTime(emptyConfig)); - assertEquals(Optional.empty(), RecordWaitTimeUtil.getFirstRecordWaitSeconds(emptyConfig)); - assertEquals(RecordWaitTimeUtil.DEFAULT_FIRST_RECORD_WAIT_TIME, RecordWaitTimeUtil.getFirstRecordWaitTime(emptyConfig)); - - final JsonNode normalConfig = Jsons.jsonNode(Map.of("replication_method", - Map.of("method", "CDC", "initial_waiting_seconds", 500))); - assertDoesNotThrow(() -> RecordWaitTimeUtil.checkFirstRecordWaitTime(normalConfig)); - assertEquals(Optional.of(500), RecordWaitTimeUtil.getFirstRecordWaitSeconds(normalConfig)); - assertEquals(Duration.ofSeconds(500), RecordWaitTimeUtil.getFirstRecordWaitTime(normalConfig)); - - final int tooShortTimeout = (int) MIN_FIRST_RECORD_WAIT_TIME.getSeconds() - 1; - final JsonNode tooShortConfig = Jsons.jsonNode(Map.of("replication_method", - Map.of("method", "CDC", "initial_waiting_seconds", tooShortTimeout))); - assertThrows(IllegalArgumentException.class, () -> RecordWaitTimeUtil.checkFirstRecordWaitTime(tooShortConfig)); - assertEquals(Optional.of(tooShortTimeout), RecordWaitTimeUtil.getFirstRecordWaitSeconds(tooShortConfig)); - assertEquals(MIN_FIRST_RECORD_WAIT_TIME, RecordWaitTimeUtil.getFirstRecordWaitTime(tooShortConfig)); - - final int tooLongTimeout = (int) MAX_FIRST_RECORD_WAIT_TIME.getSeconds() + 1; - final JsonNode tooLongConfig = Jsons.jsonNode(Map.of("replication_method", - Map.of("method", "CDC", "initial_waiting_seconds", tooLongTimeout))); - assertThrows(IllegalArgumentException.class, () -> RecordWaitTimeUtil.checkFirstRecordWaitTime(tooLongConfig)); - assertEquals(Optional.of(tooLongTimeout), RecordWaitTimeUtil.getFirstRecordWaitSeconds(tooLongConfig)); - assertEquals(MAX_FIRST_RECORD_WAIT_TIME, RecordWaitTimeUtil.getFirstRecordWaitTime(tooLongConfig)); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/jdbc/DefaultJdbcSourceAcceptanceTest.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/jdbc/DefaultJdbcSourceAcceptanceTest.java deleted file mode 100644 index 196a38c0d3ad..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/jdbc/DefaultJdbcSourceAcceptanceTest.java +++ /dev/null @@ -1,201 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
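testGetFirstRecordWaitTime above exercises two behaviors side by side: checkFirstRecordWaitTime rejects out-of-range values, while getFirstRecordWaitTime silently clamps them into [MIN, MAX] and falls back to a default when nothing is configured. The clamp-or-default rule as a sketch (the MIN/MAX/DEFAULT values below are placeholders, not the CDK's actual constants):

    import java.time.Duration;
    import java.util.Optional;

    public class WaitTimeClampSketch {

      // Placeholder bounds; the CDK defines its own MIN/MAX/DEFAULT durations.
      static final Duration MIN = Duration.ofMinutes(2);
      static final Duration MAX = Duration.ofMinutes(20);
      static final Duration DEFAULT = Duration.ofMinutes(5);

      static Duration firstRecordWaitTime(final Optional<Integer> configuredSeconds) {
        if (configuredSeconds.isEmpty()) {
          return DEFAULT; // nothing configured: use the default
        }
        final Duration configured = Duration.ofSeconds(configuredSeconds.get());
        if (configured.compareTo(MIN) < 0) {
          return MIN; // too short: clamp up
        }
        if (configured.compareTo(MAX) > 0) {
          return MAX; // too long: clamp down
        }
        return configured;
      }

      public static void main(final String[] args) {
        System.out.println(firstRecordWaitTime(Optional.empty()));     // PT5M
        System.out.println(firstRecordWaitTime(Optional.of(500)));     // PT8M20S
        System.out.println(firstRecordWaitTime(Optional.of(10)));      // PT2M (clamped up)
        System.out.println(firstRecordWaitTime(Optional.of(100_000))); // PT20M (clamped down)
      }
    }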
- */ - -package io.airbyte.cdk.integrations.source.jdbc; - -import static io.airbyte.cdk.integrations.source.jdbc.JdbcDataSourceUtils.assertCustomParametersDontOverwriteDefaultParameters; -import static org.junit.jupiter.api.Assertions.assertThrows; - -import com.fasterxml.jackson.databind.JsonNode; -import com.google.common.collect.ImmutableMap; -import io.airbyte.cdk.db.factory.DatabaseDriver; -import io.airbyte.cdk.db.jdbc.JdbcUtils; -import io.airbyte.cdk.db.jdbc.streaming.AdaptiveStreamingQueryConfig; -import io.airbyte.cdk.integrations.base.IntegrationRunner; -import io.airbyte.cdk.integrations.base.Source; -import io.airbyte.cdk.integrations.source.jdbc.test.JdbcSourceAcceptanceTest; -import io.airbyte.cdk.integrations.util.HostPortResolver; -import io.airbyte.cdk.testutils.TestDatabase; -import io.airbyte.commons.json.Jsons; -import io.airbyte.protocol.models.v0.AirbyteStateMessage.AirbyteStateType; -import java.sql.JDBCType; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.stream.Stream; -import org.jooq.SQLDialect; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.testcontainers.containers.PostgreSQLContainer; - -/** - * Runs the acceptance tests in the source-jdbc test module. We want this module to run these tests - * itself as a sanity check. The trade off here is that this class is duplicated from the one used - * in source-postgres. - */ -class DefaultJdbcSourceAcceptanceTest - extends JdbcSourceAcceptanceTest { - - private static PostgreSQLContainer PSQL_CONTAINER; - - @BeforeAll - static void init() { - PSQL_CONTAINER = new PostgreSQLContainer<>("postgres:13-alpine"); - PSQL_CONTAINER.start(); - CREATE_TABLE_WITHOUT_CURSOR_TYPE_QUERY = "CREATE TABLE %s (%s BIT(3) NOT NULL);"; - INSERT_TABLE_WITHOUT_CURSOR_TYPE_QUERY = "INSERT INTO %s VALUES(B'101');"; - } - - @Override - protected JsonNode config() { - return testdb.testConfigBuilder().build(); - } - - @Override - protected PostgresTestSource source() { - return new PostgresTestSource(); - } - - @Override - protected BareBonesTestDatabase createTestDatabase() { - return new BareBonesTestDatabase(PSQL_CONTAINER).initialized(); - } - - @Override - public boolean supportsSchemas() { - return true; - } - - public JsonNode getConfigWithConnectionProperties(final PostgreSQLContainer psqlDb, final String dbName, final String additionalParameters) { - return Jsons.jsonNode(ImmutableMap.builder() - .put(JdbcUtils.HOST_KEY, HostPortResolver.resolveHost(psqlDb)) - .put(JdbcUtils.PORT_KEY, HostPortResolver.resolvePort(psqlDb)) - .put(JdbcUtils.DATABASE_KEY, dbName) - .put(JdbcUtils.SCHEMAS_KEY, List.of(SCHEMA_NAME)) - .put(JdbcUtils.USERNAME_KEY, psqlDb.getUsername()) - .put(JdbcUtils.PASSWORD_KEY, psqlDb.getPassword()) - .put(JdbcUtils.CONNECTION_PROPERTIES_KEY, additionalParameters) - .build()); - } - - @AfterAll - static void cleanUp() { - PSQL_CONTAINER.close(); - } - - public static class PostgresTestSource extends AbstractJdbcSource implements Source { - - private static final Logger LOGGER = LoggerFactory.getLogger(PostgresTestSource.class); - - static final String DRIVER_CLASS = DatabaseDriver.POSTGRESQL.driverClassName; - - public PostgresTestSource() { - super(DRIVER_CLASS, AdaptiveStreamingQueryConfig::new, JdbcUtils.defaultSourceOperations); - } - - @Override - public JsonNode toDatabaseConfig(final JsonNode config) { - final ImmutableMap.Builder 
configBuilder = ImmutableMap.builder() - .put(JdbcUtils.USERNAME_KEY, config.get(JdbcUtils.USERNAME_KEY).asText()) - .put(JdbcUtils.JDBC_URL_KEY, String.format(DatabaseDriver.POSTGRESQL.urlFormatString, - config.get(JdbcUtils.HOST_KEY).asText(), - config.get(JdbcUtils.PORT_KEY).asInt(), - config.get(JdbcUtils.DATABASE_KEY).asText())); - - if (config.has(JdbcUtils.PASSWORD_KEY)) { - configBuilder.put(JdbcUtils.PASSWORD_KEY, config.get(JdbcUtils.PASSWORD_KEY).asText()); - } - - return Jsons.jsonNode(configBuilder.build()); - } - - @Override - public Set<String> getExcludedInternalNameSpaces() { - return Set.of("information_schema", "pg_catalog", "pg_internal", "catalog_history"); - } - - @Override - protected AirbyteStateType getSupportedStateType(final JsonNode config) { - return AirbyteStateType.STREAM; - } - - public static void main(final String[] args) throws Exception { - final Source source = new PostgresTestSource(); - LOGGER.info("starting source: {}", PostgresTestSource.class); - new IntegrationRunner(source).run(args); - LOGGER.info("completed source: {}", PostgresTestSource.class); - } - - } - - static protected class BareBonesTestDatabase - extends TestDatabase<PostgreSQLContainer<?>, BareBonesTestDatabase, BareBonesTestDatabase.BareBonesConfigBuilder> { - - public BareBonesTestDatabase(PostgreSQLContainer<?> container) { - super(container); - } - - @Override - protected Stream<Stream<String>> inContainerBootstrapCmd() { - final var sql = Stream.of( - String.format("CREATE DATABASE %s", getDatabaseName()), - String.format("CREATE USER %s PASSWORD '%s'", getUserName(), getPassword()), - String.format("GRANT ALL PRIVILEGES ON DATABASE %s TO %s", getDatabaseName(), getUserName()), - String.format("ALTER USER %s WITH SUPERUSER", getUserName())); - return Stream.of(Stream.concat( - Stream.of("psql", - "-d", container.getDatabaseName(), - "-U", container.getUsername(), - "-v", "ON_ERROR_STOP=1", - "-a"), - sql.flatMap(stmt -> Stream.of("-c", stmt)))); - } - - @Override - protected Stream<String> inContainerUndoBootstrapCmd() { - return Stream.empty(); - } - - @Override - public DatabaseDriver getDatabaseDriver() { - return DatabaseDriver.POSTGRESQL; - } - - @Override - public SQLDialect getSqlDialect() { - return SQLDialect.POSTGRES; - } - - @Override - public BareBonesConfigBuilder configBuilder() { - return new BareBonesConfigBuilder(this); - } - - static protected class BareBonesConfigBuilder extends TestDatabase.ConfigBuilder<BareBonesTestDatabase, BareBonesConfigBuilder> { - - private BareBonesConfigBuilder(BareBonesTestDatabase testDatabase) { - super(testDatabase); - } - - } - - } - - @Test - void testCustomParametersOverwriteDefaultParametersExpectException() { - final String connectionPropertiesUrl = "ssl=false"; - final JsonNode config = getConfigWithConnectionProperties(PSQL_CONTAINER, testdb.getDatabaseName(), connectionPropertiesUrl); - final Map<String, String> customParameters = JdbcUtils.parseJdbcParameters(config, JdbcUtils.CONNECTION_PROPERTIES_KEY, "&"); - final Map<String, String> defaultParameters = Map.of( - "ssl", "true", - "sslmode", "require"); - assertThrows(IllegalArgumentException.class, () -> { - assertCustomParametersDontOverwriteDefaultParameters(customParameters, defaultParameters); - }); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/jdbc/DefaultJdbcStressTest.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/jdbc/DefaultJdbcStressTest.java deleted file mode 100644 index 2a536228abac..000000000000 --- 
a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/jdbc/DefaultJdbcStressTest.java +++ /dev/null @@ -1,139 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.source.jdbc; - -import com.fasterxml.jackson.databind.JsonNode; -import com.google.common.collect.ImmutableMap; -import io.airbyte.cdk.db.factory.DatabaseDriver; -import io.airbyte.cdk.db.jdbc.JdbcUtils; -import io.airbyte.cdk.db.jdbc.streaming.AdaptiveStreamingQueryConfig; -import io.airbyte.cdk.integrations.base.IntegrationRunner; -import io.airbyte.cdk.integrations.base.Source; -import io.airbyte.cdk.integrations.source.jdbc.test.JdbcStressTest; -import io.airbyte.cdk.testutils.PostgreSQLContainerHelper; -import io.airbyte.commons.io.IOs; -import io.airbyte.commons.json.Jsons; -import io.airbyte.commons.string.Strings; -import java.sql.JDBCType; -import java.util.Optional; -import java.util.Set; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.testcontainers.containers.PostgreSQLContainer; -import org.testcontainers.utility.MountableFile; - -/** - * Runs the stress tests in the source-jdbc test module. We want this module to run these tests - * itself as a sanity check. The trade off here is that this class is duplicated from the one used - * in source-postgres. - */ -@Disabled -class DefaultJdbcStressTest extends JdbcStressTest { - - private static PostgreSQLContainer PSQL_DB; - - private JsonNode config; - - @BeforeAll - static void init() { - PSQL_DB = new PostgreSQLContainer<>("postgres:13-alpine"); - PSQL_DB.start(); - } - - @BeforeEach - public void setup() throws Exception { - final String dbName = Strings.addRandomSuffix("db", "_", 10); - - config = Jsons.jsonNode(ImmutableMap.of(JdbcUtils.HOST_KEY, "localhost", - JdbcUtils.PORT_KEY, 5432, - JdbcUtils.DATABASE_KEY, "charles", - JdbcUtils.USERNAME_KEY, "postgres", - JdbcUtils.PASSWORD_KEY, "")); - - config = Jsons.jsonNode(ImmutableMap.builder() - .put(JdbcUtils.HOST_KEY, PSQL_DB.getHost()) - .put(JdbcUtils.PORT_KEY, PSQL_DB.getFirstMappedPort()) - .put(JdbcUtils.DATABASE_KEY, dbName) - .put(JdbcUtils.USERNAME_KEY, PSQL_DB.getUsername()) - .put(JdbcUtils.PASSWORD_KEY, PSQL_DB.getPassword()) - .build()); - - final String initScriptName = "init_" + dbName.concat(".sql"); - final String tmpFilePath = IOs.writeFileToRandomTmpDir(initScriptName, "CREATE DATABASE " + dbName + ";"); - PostgreSQLContainerHelper.runSqlScript(MountableFile.forHostPath(tmpFilePath), PSQL_DB); - - super.setup(); - } - - @Override - public Optional getDefaultSchemaName() { - return Optional.of("public"); - } - - @Override - public AbstractJdbcSource getSource() { - return new PostgresTestSource(); - } - - @Override - public JsonNode getConfig() { - return config; - } - - @Override - public String getDriverClass() { - return PostgresTestSource.DRIVER_CLASS; - } - - @AfterAll - static void cleanUp() { - PSQL_DB.close(); - } - - private static class PostgresTestSource extends AbstractJdbcSource implements Source { - - private static final Logger LOGGER = LoggerFactory.getLogger(PostgresTestSource.class); - - static final String DRIVER_CLASS = DatabaseDriver.POSTGRESQL.driverClassName; - - public PostgresTestSource() { - super(DRIVER_CLASS, AdaptiveStreamingQueryConfig::new, JdbcUtils.defaultSourceOperations); - } - - 
@Override - public JsonNode toDatabaseConfig(final JsonNode config) { - final ImmutableMap.Builder<Object, Object> configBuilder = ImmutableMap.builder() - .put(JdbcUtils.USERNAME_KEY, config.get(JdbcUtils.USERNAME_KEY).asText()) - .put(JdbcUtils.JDBC_URL_KEY, String.format(DatabaseDriver.POSTGRESQL.urlFormatString, - config.get(JdbcUtils.HOST_KEY).asText(), - config.get(JdbcUtils.PORT_KEY).asInt(), - config.get(JdbcUtils.DATABASE_KEY).asText())); - - if (config.has(JdbcUtils.PASSWORD_KEY)) { - configBuilder.put(JdbcUtils.PASSWORD_KEY, config.get(JdbcUtils.PASSWORD_KEY).asText()); - } - - return Jsons.jsonNode(configBuilder.build()); - } - - @Override - public Set<String> getExcludedInternalNameSpaces() { - return Set.of("information_schema", "pg_catalog", "pg_internal", "catalog_history"); - } - - public static void main(final String[] args) throws Exception { - final Source source = new PostgresTestSource(); - LOGGER.info("starting source: {}", PostgresTestSource.class); - new IntegrationRunner(source).run(args); - LOGGER.info("completed source: {}", PostgresTestSource.class); - } - - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/jdbc/JdbcDataSourceUtilsTest.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/jdbc/JdbcDataSourceUtilsTest.java deleted file mode 100644 index 116d122d7d31..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/jdbc/JdbcDataSourceUtilsTest.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.source.jdbc; - -import static org.assertj.core.api.AssertionsForClassTypes.assertThat; -import static org.junit.Assert.assertTrue; -import static org.junit.jupiter.api.Assertions.assertThrows; - -import com.fasterxml.jackson.databind.JsonNode; -import io.airbyte.commons.json.Jsons; -import java.util.List; -import java.util.Map; -import org.junit.jupiter.api.Test; - -public class JdbcDataSourceUtilsTest { - - @Test - void test() { - final String validConfigString = "{\"jdbc_url_params\":\"key1=val1&key3=key3\",\"connection_properties\":\"key1=val1&key2=val2\"}"; - final JsonNode validConfig = Jsons.deserialize(validConfigString); - final Map<String, String> connectionProperties = JdbcDataSourceUtils.getConnectionProperties(validConfig); - final List<String> validKeys = List.of("key1", "key2", "key3"); - validKeys.forEach(key -> assertTrue(connectionProperties.containsKey(key))); - - // For an invalid config, there is a conflict between the values of keys in jdbc_url_params and - // connection_properties - final String invalidConfigString = "{\"jdbc_url_params\":\"key1=val2&key3=key3\",\"connection_properties\":\"key1=val1&key2=val2\"}"; - final JsonNode invalidConfig = Jsons.deserialize(invalidConfigString); - final Exception exception = assertThrows(IllegalArgumentException.class, () -> { - JdbcDataSourceUtils.getConnectionProperties(invalidConfig); - }); - - final String expectedMessage = "Cannot overwrite default JDBC parameter key1"; - assertThat(exception.getMessage()).isEqualTo(expectedMessage); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/jdbc/JdbcSourceStressTest.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/jdbc/JdbcSourceStressTest.java deleted file mode 100644 index 000a18062bc7..000000000000 --- 
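JdbcDataSourceUtilsTest above fixes the merge contract for getConnectionProperties: keys from jdbc_url_params and connection_properties are combined, duplicate keys are tolerated when the values agree, and a conflicting value for the same key must raise IllegalArgumentException. The rule in isolation (mergeParams is an illustrative name; the error message is the one asserted above):

    import java.util.HashMap;
    import java.util.Map;

    public class ParamMergeSketch {

      // Same key with the same value is fine; same key with a different value is a hard error.
      static Map<String, String> mergeParams(final Map<String, String> base, final Map<String, String> extra) {
        final Map<String, String> merged = new HashMap<>(base);
        extra.forEach((key, value) -> {
          final String existing = merged.putIfAbsent(key, value);
          if (existing != null && !existing.equals(value)) {
            throw new IllegalArgumentException("Cannot overwrite default JDBC parameter " + key);
          }
        });
        return merged;
      }

      public static void main(final String[] args) {
        // Agreeing duplicate (key1=val1 on both sides) merges cleanly.
        System.out.println(mergeParams(Map.of("key1", "val1", "key2", "val2"), Map.of("key1", "val1", "key3", "key3")));
        // Conflicting duplicate (key1=val1 vs key1=val2) throws.
        try {
          mergeParams(Map.of("key1", "val1"), Map.of("key1", "val2"));
        } catch (final IllegalArgumentException e) {
          System.out.println(e.getMessage()); // Cannot overwrite default JDBC parameter key1
        }
      }
    }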
a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/jdbc/JdbcSourceStressTest.java +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.source.jdbc; - -import com.fasterxml.jackson.databind.JsonNode; -import com.google.common.collect.ImmutableMap; -import io.airbyte.cdk.db.factory.DatabaseDriver; -import io.airbyte.cdk.db.jdbc.JdbcUtils; -import io.airbyte.cdk.db.jdbc.streaming.AdaptiveStreamingQueryConfig; -import io.airbyte.cdk.integrations.base.IntegrationRunner; -import io.airbyte.cdk.integrations.base.Source; -import io.airbyte.cdk.integrations.source.jdbc.test.JdbcStressTest; -import io.airbyte.cdk.testutils.PostgreSQLContainerHelper; -import io.airbyte.commons.io.IOs; -import io.airbyte.commons.json.Jsons; -import io.airbyte.commons.string.Strings; -import java.sql.JDBCType; -import java.util.Optional; -import java.util.Set; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.testcontainers.containers.PostgreSQLContainer; -import org.testcontainers.utility.MountableFile; - -/** - * Runs the stress tests in the source-jdbc test module. We want this module to run these tests - * itself as a sanity check. The trade off here is that this class is duplicated from the one used - * in source-postgres. - */ -@Disabled -class JdbcSourceStressTest extends JdbcStressTest { - - private static PostgreSQLContainer<?> PSQL_DB; - - private JsonNode config; - - @BeforeAll - static void init() { - PSQL_DB = new PostgreSQLContainer<>("postgres:13-alpine"); - PSQL_DB.start(); - } - - @BeforeEach - public void setup() throws Exception { - final String schemaName = Strings.addRandomSuffix("db", "_", 10); - - config = Jsons.jsonNode(ImmutableMap.builder() - .put(JdbcUtils.HOST_KEY, PSQL_DB.getHost()) - .put(JdbcUtils.PORT_KEY, PSQL_DB.getFirstMappedPort()) - .put(JdbcUtils.DATABASE_KEY, schemaName) - .put(JdbcUtils.USERNAME_KEY, PSQL_DB.getUsername()) - .put(JdbcUtils.PASSWORD_KEY, PSQL_DB.getPassword()) - .build()); - - final String initScriptName = "init_" + schemaName.concat(".sql"); - final String tmpFilePath = IOs.writeFileToRandomTmpDir(initScriptName, "CREATE DATABASE " + schemaName + ";"); - PostgreSQLContainerHelper.runSqlScript(MountableFile.forHostPath(tmpFilePath), PSQL_DB); - - super.setup(); - } - - @Override - public Optional<String> getDefaultSchemaName() { - return Optional.of("public"); - } - - @Override - public AbstractJdbcSource<JDBCType> getSource() { - return new PostgresTestSource(); - } - - @Override - public JsonNode getConfig() { - return config; - } - - @Override - public String getDriverClass() { - return PostgresTestSource.DRIVER_CLASS; - } - - @AfterAll - static void cleanUp() { - PSQL_DB.close(); - } - - private static class PostgresTestSource extends AbstractJdbcSource<JDBCType> implements Source { - - private static final Logger LOGGER = LoggerFactory.getLogger(PostgresTestSource.class); - - static final String DRIVER_CLASS = DatabaseDriver.POSTGRESQL.driverClassName; - - public PostgresTestSource() { - super(DRIVER_CLASS, AdaptiveStreamingQueryConfig::new, JdbcUtils.defaultSourceOperations); - } - - @Override - public JsonNode toDatabaseConfig(final JsonNode config) { - final ImmutableMap.Builder<Object, Object> configBuilder = ImmutableMap.builder() - .put(JdbcUtils.USERNAME_KEY, 
config.get(JdbcUtils.USERNAME_KEY).asText()) - .put(JdbcUtils.JDBC_URL_KEY, String.format(DatabaseDriver.POSTGRESQL.urlFormatString, - config.get(JdbcUtils.HOST_KEY).asText(), - config.get(JdbcUtils.PORT_KEY).asInt(), - config.get(JdbcUtils.DATABASE_KEY).asText())); - - if (config.has(JdbcUtils.PASSWORD_KEY)) { - configBuilder.put(JdbcUtils.PASSWORD_KEY, config.get(JdbcUtils.PASSWORD_KEY).asText()); - } - - return Jsons.jsonNode(configBuilder.build()); - } - - @Override - public Set getExcludedInternalNameSpaces() { - return Set.of("information_schema", "pg_catalog", "pg_internal", "catalog_history"); - } - - public static void main(final String[] args) throws Exception { - final Source source = new PostgresTestSource(); - LOGGER.info("starting source: {}", PostgresTestSource.class); - new IntegrationRunner(source).run(args); - LOGGER.info("completed source: {}", PostgresTestSource.class); - } - - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/relationaldb/AbstractDbSourceTest.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/relationaldb/AbstractDbSourceTest.java deleted file mode 100644 index 9e7bab7177f2..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/relationaldb/AbstractDbSourceTest.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.source.relationaldb; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.mockito.Mockito.CALLS_REAL_METHODS; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.withSettings; - -import com.fasterxml.jackson.databind.JsonNode; -import io.airbyte.cdk.integrations.source.relationaldb.state.StateGeneratorUtils; -import io.airbyte.commons.json.Jsons; -import io.airbyte.commons.resources.MoreResources; -import io.airbyte.protocol.models.v0.AirbyteStateMessage; -import io.airbyte.protocol.models.v0.AirbyteStateMessage.AirbyteStateType; -import java.io.IOException; -import java.util.List; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import uk.org.webcompere.systemstubs.environment.EnvironmentVariables; -import uk.org.webcompere.systemstubs.jupiter.SystemStub; -import uk.org.webcompere.systemstubs.jupiter.SystemStubsExtension; - -/** - * Test suite for the {@link AbstractDbSource} class. 
- */ -@ExtendWith(SystemStubsExtension.class) -public class AbstractDbSourceTest { - - @SystemStub - private EnvironmentVariables environmentVariables; - - @Test - void testDeserializationOfLegacyState() throws IOException { - final AbstractDbSource dbSource = mock(AbstractDbSource.class, withSettings().useConstructor("").defaultAnswer(CALLS_REAL_METHODS)); - final JsonNode config = mock(JsonNode.class); - - final String legacyStateJson = MoreResources.readResource("states/legacy.json"); - final JsonNode legacyState = Jsons.deserialize(legacyStateJson); - - final List result = StateGeneratorUtils.deserializeInitialState(legacyState, - dbSource.getSupportedStateType(config)); - assertEquals(1, result.size()); - assertEquals(AirbyteStateType.LEGACY, result.get(0).getType()); - } - - @Test - void testDeserializationOfGlobalState() throws IOException { - final AbstractDbSource dbSource = mock(AbstractDbSource.class, withSettings().useConstructor("").defaultAnswer(CALLS_REAL_METHODS)); - final JsonNode config = mock(JsonNode.class); - - final String globalStateJson = MoreResources.readResource("states/global.json"); - final JsonNode globalState = Jsons.deserialize(globalStateJson); - - final List result = - StateGeneratorUtils.deserializeInitialState(globalState, dbSource.getSupportedStateType(config)); - assertEquals(1, result.size()); - assertEquals(AirbyteStateType.GLOBAL, result.get(0).getType()); - } - - @Test - void testDeserializationOfStreamState() throws IOException { - final AbstractDbSource dbSource = mock(AbstractDbSource.class, withSettings().useConstructor("").defaultAnswer(CALLS_REAL_METHODS)); - final JsonNode config = mock(JsonNode.class); - - final String streamStateJson = MoreResources.readResource("states/per_stream.json"); - final JsonNode streamState = Jsons.deserialize(streamStateJson); - - final List result = - StateGeneratorUtils.deserializeInitialState(streamState, dbSource.getSupportedStateType(config)); - assertEquals(2, result.size()); - assertEquals(AirbyteStateType.STREAM, result.get(0).getType()); - } - - @Test - void testDeserializationOfNullState() throws IOException { - final AbstractDbSource dbSource = mock(AbstractDbSource.class, withSettings().useConstructor("").defaultAnswer(CALLS_REAL_METHODS)); - final JsonNode config = mock(JsonNode.class); - - final List result = StateGeneratorUtils.deserializeInitialState(null, dbSource.getSupportedStateType(config)); - assertEquals(1, result.size()); - assertEquals(dbSource.getSupportedStateType(config), result.get(0).getType()); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/relationaldb/state/CursorManagerTest.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/relationaldb/state/CursorManagerTest.java deleted file mode 100644 index 9f5dccbed7fc..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/relationaldb/state/CursorManagerTest.java +++ /dev/null @@ -1,159 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
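The four deserialization tests above pin the dispatch: a null initial state yields exactly one empty message of the connector's supported type, and otherwise the serialized shape decides LEGACY, GLOBAL, or STREAM (with per-stream input fanning out to one message per stream). A deliberately crude sketch of that dispatch; the shape probes below are assumptions for illustration, not how StateGeneratorUtils actually inspects the JSON:

    import java.util.List;

    public class InitialStateSketch {

      enum StateType { LEGACY, GLOBAL, STREAM }

      record StateMessage(StateType type) {}

      static List<StateMessage> deserializeInitialState(final String rawState, final StateType supportedType) {
        if (rawState == null) {
          return List.of(new StateMessage(supportedType)); // fresh sync: one empty state
        }
        if (rawState.contains("\"type\":\"GLOBAL\"")) { // crude shape probe, illustration only
          return List.of(new StateMessage(StateType.GLOBAL));
        }
        if (rawState.contains("\"type\":\"STREAM\"")) {
          return List.of(new StateMessage(StateType.STREAM));
        }
        return List.of(new StateMessage(StateType.LEGACY)); // anything else is treated as legacy
      }

      public static void main(final String[] args) {
        System.out.println(deserializeInitialState(null, StateType.STREAM));              // [StateMessage[type=STREAM]]
        System.out.println(deserializeInitialState("{\"cdc\":false}", StateType.STREAM)); // legacy fallback
      }
    }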
- */ - -package io.airbyte.cdk.integrations.source.relationaldb.state; - -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.CURSOR; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.CURSOR_FIELD1; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.CURSOR_FIELD2; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.CURSOR_RECORD_COUNT; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.NAME_NAMESPACE_PAIR1; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.NAME_NAMESPACE_PAIR2; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.getCatalog; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.getState; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.getStream; -import static org.junit.jupiter.api.Assertions.assertEquals; - -import io.airbyte.cdk.integrations.source.relationaldb.CursorInfo; -import io.airbyte.cdk.integrations.source.relationaldb.models.DbStreamState; -import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair; -import java.util.Collections; -import java.util.Optional; -import java.util.function.Function; -import org.junit.jupiter.api.Test; - -/** - * Test suite for the {@link CursorManager} class. - */ -public class CursorManagerTest { - - private static final Function CURSOR_RECORD_COUNT_FUNCTION = stream -> { - if (stream.getCursorRecordCount() != null) { - return stream.getCursorRecordCount(); - } else { - return 0L; - } - }; - - @Test - void testCreateCursorInfoCatalogAndStateSameCursorField() { - final CursorManager cursorManager = createCursorManager(CURSOR_FIELD1, CURSOR, NAME_NAMESPACE_PAIR1); - final CursorInfo actual = cursorManager.createCursorInfoForStream( - NAME_NAMESPACE_PAIR1, - getState(CURSOR_FIELD1, CURSOR, CURSOR_RECORD_COUNT), - getStream(CURSOR_FIELD1), - DbStreamState::getCursor, - DbStreamState::getCursorField, - CURSOR_RECORD_COUNT_FUNCTION); - assertEquals(new CursorInfo(CURSOR_FIELD1, CURSOR, CURSOR_RECORD_COUNT, CURSOR_FIELD1, CURSOR, CURSOR_RECORD_COUNT), actual); - } - - @Test - void testCreateCursorInfoCatalogAndStateSameCursorFieldButNoCursor() { - final CursorManager cursorManager = createCursorManager(CURSOR_FIELD1, null, NAME_NAMESPACE_PAIR1); - final CursorInfo actual = cursorManager.createCursorInfoForStream( - NAME_NAMESPACE_PAIR1, - getState(CURSOR_FIELD1, null), - getStream(CURSOR_FIELD1), - DbStreamState::getCursor, - DbStreamState::getCursorField, - CURSOR_RECORD_COUNT_FUNCTION); - assertEquals(new CursorInfo(CURSOR_FIELD1, null, CURSOR_FIELD1, null), actual); - } - - @Test - void testCreateCursorInfoCatalogAndStateChangeInCursorFieldName() { - final CursorManager cursorManager = createCursorManager(CURSOR_FIELD1, CURSOR, NAME_NAMESPACE_PAIR1); - final CursorInfo actual = cursorManager.createCursorInfoForStream( - NAME_NAMESPACE_PAIR1, - getState(CURSOR_FIELD1, CURSOR), - getStream(CURSOR_FIELD2), - DbStreamState::getCursor, - DbStreamState::getCursorField, - CURSOR_RECORD_COUNT_FUNCTION); - assertEquals(new CursorInfo(CURSOR_FIELD1, CURSOR, CURSOR_FIELD2, null), actual); - } - - @Test - void testCreateCursorInfoCatalogAndNoState() { - final CursorManager cursorManager = createCursorManager(CURSOR_FIELD1, CURSOR, NAME_NAMESPACE_PAIR1); - final CursorInfo actual = 
cursorManager.createCursorInfoForStream( - NAME_NAMESPACE_PAIR1, - Optional.empty(), - getStream(CURSOR_FIELD1), - DbStreamState::getCursor, - DbStreamState::getCursorField, - CURSOR_RECORD_COUNT_FUNCTION); - assertEquals(new CursorInfo(null, null, CURSOR_FIELD1, null), actual); - } - - @Test - void testCreateCursorInfoStateAndNoCatalog() { - final CursorManager cursorManager = createCursorManager(CURSOR_FIELD1, CURSOR, NAME_NAMESPACE_PAIR1); - final CursorInfo actual = cursorManager.createCursorInfoForStream( - NAME_NAMESPACE_PAIR1, - getState(CURSOR_FIELD1, CURSOR), - Optional.empty(), - DbStreamState::getCursor, - DbStreamState::getCursorField, - CURSOR_RECORD_COUNT_FUNCTION); - assertEquals(new CursorInfo(CURSOR_FIELD1, CURSOR, null, null), actual); - } - - // this is what full refresh looks like. - @Test - void testCreateCursorInfoNoCatalogAndNoState() { - final CursorManager cursorManager = createCursorManager(CURSOR_FIELD1, CURSOR, NAME_NAMESPACE_PAIR1); - final CursorInfo actual = cursorManager.createCursorInfoForStream( - NAME_NAMESPACE_PAIR1, - Optional.empty(), - Optional.empty(), - DbStreamState::getCursor, - DbStreamState::getCursorField, - CURSOR_RECORD_COUNT_FUNCTION); - assertEquals(new CursorInfo(null, null, null, null), actual); - } - - @Test - void testCreateCursorInfoStateAndCatalogButNoCursorField() { - final CursorManager cursorManager = createCursorManager(CURSOR_FIELD1, CURSOR, NAME_NAMESPACE_PAIR1); - final CursorInfo actual = cursorManager.createCursorInfoForStream( - NAME_NAMESPACE_PAIR1, - getState(CURSOR_FIELD1, CURSOR), - getStream(null), - DbStreamState::getCursor, - DbStreamState::getCursorField, - CURSOR_RECORD_COUNT_FUNCTION); - assertEquals(new CursorInfo(CURSOR_FIELD1, CURSOR, null, null), actual); - } - - @Test - void testGetters() { - final CursorManager cursorManager = createCursorManager(CURSOR_FIELD1, CURSOR, NAME_NAMESPACE_PAIR1); - final CursorInfo actualCursorInfo = new CursorInfo(CURSOR_FIELD1, CURSOR, null, null); - - assertEquals(Optional.of(actualCursorInfo), cursorManager.getCursorInfo(NAME_NAMESPACE_PAIR1)); - assertEquals(Optional.empty(), cursorManager.getCursorField(NAME_NAMESPACE_PAIR1)); - assertEquals(Optional.empty(), cursorManager.getCursor(NAME_NAMESPACE_PAIR1)); - - assertEquals(Optional.empty(), cursorManager.getCursorInfo(NAME_NAMESPACE_PAIR2)); - assertEquals(Optional.empty(), cursorManager.getCursorField(NAME_NAMESPACE_PAIR2)); - assertEquals(Optional.empty(), cursorManager.getCursor(NAME_NAMESPACE_PAIR2)); - } - - private CursorManager createCursorManager(final String cursorField, - final String cursor, - final AirbyteStreamNameNamespacePair nameNamespacePair) { - final DbStreamState dbStreamState = getState(cursorField, cursor).get(); - return new CursorManager<>( - getCatalog(cursorField).orElse(null), - () -> Collections.singleton(dbStreamState), - DbStreamState::getCursor, - DbStreamState::getCursorField, - CURSOR_RECORD_COUNT_FUNCTION, - s -> nameNamespacePair, - false); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/relationaldb/state/CursorStateMessageProducerTest.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/relationaldb/state/CursorStateMessageProducerTest.java deleted file mode 100644 index dc1ae098aee4..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/relationaldb/state/CursorStateMessageProducerTest.java +++ /dev/null @@ -1,449 +0,0 @@ -/* - * 
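The seven cases above reduce to one rule: a saved cursor value survives only when the stream's configured cursor field is unchanged; a renamed field, missing state, or missing catalog entry resets the new cursor to null. The decision table as a sketch (CursorInfo here is a local record mirroring the four-argument constructor used above; the record-count variant adds two more fields):

    import java.util.Objects;
    import java.util.Optional;

    public class CursorPairingSketch {

      record CursorInfo(String originalField, String originalCursor, String newField, String newCursor) {}

      static CursorInfo createCursorInfo(final Optional<String> stateField, final Optional<String> stateCursor,
                                         final Optional<String> catalogField) {
        final String originalField = stateField.orElse(null);
        final String originalCursor = stateCursor.orElse(null);
        final String newField = catalogField.orElse(null);
        // The old cursor is carried forward only when the cursor field did not change.
        final String newCursor =
            (newField != null && Objects.equals(originalField, newField)) ? originalCursor : null;
        return new CursorInfo(originalField, originalCursor, newField, newCursor);
      }

      public static void main(final String[] args) {
        // Same field: the cursor survives.
        System.out.println(createCursorInfo(Optional.of("id"), Optional.of("42"), Optional.of("id")));
        // Renamed field: the cursor resets.
        System.out.println(createCursorInfo(Optional.of("id"), Optional.of("42"), Optional.of("uuid")));
        // No state and no catalog entry: full refresh, everything null.
        System.out.println(createCursorInfo(Optional.empty(), Optional.empty(), Optional.empty()));
      }
    }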
Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.source.relationaldb.state; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertThrows; - -import com.fasterxml.jackson.databind.node.ObjectNode; -import io.airbyte.cdk.integrations.source.relationaldb.models.DbState; -import io.airbyte.cdk.integrations.source.relationaldb.models.DbStreamState; -import io.airbyte.commons.json.Jsons; -import io.airbyte.commons.util.MoreIterators; -import io.airbyte.protocol.models.Field; -import io.airbyte.protocol.models.JsonSchemaType; -import io.airbyte.protocol.models.v0.AirbyteMessage; -import io.airbyte.protocol.models.v0.AirbyteMessage.Type; -import io.airbyte.protocol.models.v0.AirbyteRecordMessage; -import io.airbyte.protocol.models.v0.AirbyteStateMessage; -import io.airbyte.protocol.models.v0.AirbyteStateStats; -import io.airbyte.protocol.models.v0.AirbyteStream; -import io.airbyte.protocol.models.v0.AirbyteStreamState; -import io.airbyte.protocol.models.v0.CatalogHelpers; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream; -import io.airbyte.protocol.models.v0.StreamDescriptor; -import java.sql.SQLException; -import java.time.Duration; -import java.util.Collections; -import java.util.Iterator; -import java.util.List; -import java.util.Optional; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.testcontainers.shaded.com.google.common.collect.ImmutableMap; - -class CursorStateMessageProducerTest { - - private static final String NAMESPACE = "public"; - private static final String STREAM_NAME = "shoes"; - private static final String UUID_FIELD_NAME = "ascending_inventory_uuid"; - - private static final ConfiguredAirbyteStream STREAM = CatalogHelpers.createConfiguredAirbyteStream( - STREAM_NAME, - NAMESPACE, - Field.of(UUID_FIELD_NAME, JsonSchemaType.STRING)) - .withCursorField(List.of(UUID_FIELD_NAME)); - - private static final AirbyteMessage EMPTY_STATE_MESSAGE = createEmptyStateMessage(0.0); - - private static final String RECORD_VALUE_1 = "abc"; - private static final AirbyteMessage RECORD_MESSAGE_1 = createRecordMessage(RECORD_VALUE_1); - - private static final String RECORD_VALUE_2 = "def"; - private static final AirbyteMessage RECORD_MESSAGE_2 = createRecordMessage(RECORD_VALUE_2); - - private static final String RECORD_VALUE_3 = "ghi"; - private static final AirbyteMessage RECORD_MESSAGE_3 = createRecordMessage(RECORD_VALUE_3); - - private static final String RECORD_VALUE_4 = "jkl"; - private static final AirbyteMessage RECORD_MESSAGE_4 = createRecordMessage(RECORD_VALUE_4); - - private static final String RECORD_VALUE_5 = "xyz"; - private static final AirbyteMessage RECORD_MESSAGE_5 = createRecordMessage(RECORD_VALUE_5); - - private static AirbyteMessage createRecordMessage(final String recordValue) { - return new AirbyteMessage() - .withType(Type.RECORD) - .withRecord(new AirbyteRecordMessage() - .withData(Jsons.jsonNode(ImmutableMap.of(UUID_FIELD_NAME, recordValue)))); - } - - private static AirbyteMessage createStateMessage(final String recordValue, final long cursorRecordCount, final double statsRecordCount) { - final DbStreamState dbStreamState = new DbStreamState() - .withCursorField(Collections.singletonList(UUID_FIELD_NAME)) - .withCursor(recordValue) - .withStreamName(STREAM_NAME) - 
.withStreamNamespace(NAMESPACE); - if (cursorRecordCount > 0) { - dbStreamState.withCursorRecordCount(cursorRecordCount); - } - final DbState dbState = new DbState().withCdc(false).withStreams(Collections.singletonList(dbStreamState)); - return new AirbyteMessage() - .withType(Type.STATE) - .withState(new AirbyteStateMessage() - .withType(AirbyteStateMessage.AirbyteStateType.STREAM) - .withStream(new AirbyteStreamState() - .withStreamDescriptor(new StreamDescriptor().withName(STREAM_NAME).withNamespace(NAMESPACE)) - .withStreamState(Jsons.jsonNode(dbStreamState))) - .withData(Jsons.jsonNode(dbState)) - .withSourceStats(new AirbyteStateStats().withRecordCount(statsRecordCount))); - } - - private static AirbyteMessage createEmptyStateMessage(final double statsRecordCount) { - final DbStreamState dbStreamState = new DbStreamState() - .withCursorField(Collections.singletonList(UUID_FIELD_NAME)) - .withStreamName(STREAM_NAME) - .withStreamNamespace(NAMESPACE); - - final DbState dbState = new DbState().withCdc(false).withStreams(Collections.singletonList(dbStreamState)); - return new AirbyteMessage() - .withType(Type.STATE) - .withState(new AirbyteStateMessage() - .withType(AirbyteStateMessage.AirbyteStateType.STREAM) - .withStream(new AirbyteStreamState() - .withStreamDescriptor(new StreamDescriptor().withName(STREAM_NAME).withNamespace(NAMESPACE)) - .withStreamState(Jsons.jsonNode(dbStreamState))) - .withData(Jsons.jsonNode(dbState)) - .withSourceStats(new AirbyteStateStats().withRecordCount(statsRecordCount))); - } - - private Iterator createExceptionIterator() { - return new Iterator<>() { - - final Iterator internalMessageIterator = MoreIterators.of(RECORD_MESSAGE_1, RECORD_MESSAGE_2, - RECORD_MESSAGE_2, RECORD_MESSAGE_3); - - @Override - public boolean hasNext() { - return true; - } - - @Override - public AirbyteMessage next() { - if (internalMessageIterator.hasNext()) { - return internalMessageIterator.next(); - } else { - // this line throws a RunTimeException wrapped around a SQLException to mimic the flow of when a - // SQLException is thrown and wrapped in - // StreamingJdbcDatabase#tryAdvance - throw new RuntimeException(new SQLException("Connection marked broken because of SQLSTATE(080006)", "08006")); - } - } - - }; - } - - private static Iterator messageIterator; - private StateManager stateManager; - - @BeforeEach - void setup() { - final AirbyteStream airbyteStream = new AirbyteStream().withNamespace(NAMESPACE).withName(STREAM_NAME); - final ConfiguredAirbyteStream configuredAirbyteStream = new ConfiguredAirbyteStream() - .withStream(airbyteStream) - .withCursorField(Collections.singletonList(UUID_FIELD_NAME)); - - stateManager = new StreamStateManager(Collections.emptyList(), - new ConfiguredAirbyteCatalog().withStreams(Collections.singletonList(configuredAirbyteStream))); - } - - @Test - void testWithoutInitialCursor() { - messageIterator = MoreIterators.of(RECORD_MESSAGE_1, RECORD_MESSAGE_2); - - final CursorStateMessageProducer producer = new CursorStateMessageProducer( - stateManager, - Optional.empty()); - - final SourceStateIterator iterator = new SourceStateIterator(messageIterator, STREAM, producer, new StateEmitFrequency(0, Duration.ZERO)); - - assertEquals(RECORD_MESSAGE_1, iterator.next()); - assertEquals(RECORD_MESSAGE_2, iterator.next()); - assertEquals(createStateMessage(RECORD_VALUE_2, 1, 2.0), iterator.next()); - assertFalse(iterator.hasNext()); - } - - @Test - void testWithInitialCursor() { - // record 1 and 2 has smaller cursor value, so at the end, the 
initial cursor is emitted with 0 - // record count - - messageIterator = MoreIterators.of(RECORD_MESSAGE_1, RECORD_MESSAGE_2); - - final CursorStateMessageProducer producer = new CursorStateMessageProducer( - stateManager, - Optional.of(RECORD_VALUE_5)); - - final SourceStateIterator iterator = new SourceStateIterator(messageIterator, STREAM, producer, new StateEmitFrequency(0, Duration.ZERO)); - - assertEquals(RECORD_MESSAGE_1, iterator.next()); - assertEquals(RECORD_MESSAGE_2, iterator.next()); - assertEquals(createStateMessage(RECORD_VALUE_5, 0, 2.0), iterator.next()); - assertFalse(iterator.hasNext()); - } - - @Test - void testCursorFieldIsEmpty() { - final AirbyteMessage recordMessage = Jsons.clone(RECORD_MESSAGE_1); - ((ObjectNode) recordMessage.getRecord().getData()).remove(UUID_FIELD_NAME); - final Iterator messageStream = MoreIterators.of(recordMessage); - - final CursorStateMessageProducer producer = new CursorStateMessageProducer( - stateManager, - Optional.empty()); - - final SourceStateIterator iterator = new SourceStateIterator(messageStream, STREAM, producer, new StateEmitFrequency(0, Duration.ZERO)); - - assertEquals(recordMessage, iterator.next()); - // null because no records with a cursor field were replicated for the stream. - assertEquals(createEmptyStateMessage(1.0), iterator.next()); - assertFalse(iterator.hasNext()); - } - - @Test - void testIteratorCatchesExceptionWhenEmissionFrequencyNonZero() { - final Iterator exceptionIterator = createExceptionIterator(); - - final CursorStateMessageProducer producer = new CursorStateMessageProducer( - stateManager, - Optional.of(RECORD_VALUE_1)); - - final SourceStateIterator iterator = new SourceStateIterator(exceptionIterator, STREAM, producer, new StateEmitFrequency(1, Duration.ZERO)); - - assertEquals(RECORD_MESSAGE_1, iterator.next()); - assertEquals(RECORD_MESSAGE_2, iterator.next()); - // continues to emit RECORD_MESSAGE_2 since cursorField has not changed thus not satisfying the - // condition of "ready" - assertEquals(RECORD_MESSAGE_2, iterator.next()); - assertEquals(RECORD_MESSAGE_3, iterator.next()); - // emits the first state message since the iterator has changed cursorFields (2 -> 3) and met the - // frequency minimum of 1 record - assertEquals(createStateMessage(RECORD_VALUE_2, 2, 4.0), iterator.next()); - // no further records to read since Exception was caught above and marked iterator as endOfData() - assertThrows(FailedRecordIteratorException.class, () -> iterator.hasNext()); - } - - @Test - void testIteratorCatchesExceptionWhenEmissionFrequencyZero() { - final Iterator exceptionIterator = createExceptionIterator(); - - final CursorStateMessageProducer producer = new CursorStateMessageProducer( - stateManager, - Optional.of(RECORD_VALUE_1)); - - final SourceStateIterator iterator = new SourceStateIterator(exceptionIterator, STREAM, producer, new StateEmitFrequency(0, Duration.ZERO)); - - assertEquals(RECORD_MESSAGE_1, iterator.next()); - assertEquals(RECORD_MESSAGE_2, iterator.next()); - assertEquals(RECORD_MESSAGE_2, iterator.next()); - assertEquals(RECORD_MESSAGE_3, iterator.next()); - - assertThrows(RuntimeException.class, () -> iterator.hasNext()); - } - - @Test - void testEmptyStream() { - - final CursorStateMessageProducer producer = new CursorStateMessageProducer( - stateManager, - Optional.empty()); - - final SourceStateIterator iterator = - new SourceStateIterator(Collections.emptyIterator(), STREAM, producer, new StateEmitFrequency(1, Duration.ZERO)); - - assertEquals(EMPTY_STATE_MESSAGE, 
iterator.next()); - assertFalse(iterator.hasNext()); - } - - @Test - void testUnicodeNull() { - final String recordValueWithNull = "abc\u0000"; - final AirbyteMessage recordMessageWithNull = createRecordMessage(recordValueWithNull); - - // UTF8 null \u0000 is removed from the cursor value in the state message - - messageIterator = MoreIterators.of(recordMessageWithNull); - - final CursorStateMessageProducer producer = new CursorStateMessageProducer( - stateManager, - Optional.empty()); - - final SourceStateIterator iterator = new SourceStateIterator(messageIterator, STREAM, producer, new StateEmitFrequency(0, Duration.ZERO)); - - assertEquals(recordMessageWithNull, iterator.next()); - assertEquals(createStateMessage(RECORD_VALUE_1, 1, 1.0), iterator.next()); - assertFalse(iterator.hasNext()); - } - - @Test - // When initial cursor is null, and emit state for every record - void testStateEmissionFrequency1() { - messageIterator = MoreIterators.of(RECORD_MESSAGE_1, RECORD_MESSAGE_2, RECORD_MESSAGE_3, RECORD_MESSAGE_4, RECORD_MESSAGE_5); - - final CursorStateMessageProducer producer = new CursorStateMessageProducer( - stateManager, - Optional.empty()); - - final SourceStateIterator iterator = new SourceStateIterator(messageIterator, STREAM, producer, new StateEmitFrequency(1, Duration.ZERO)); - - assertEquals(RECORD_MESSAGE_1, iterator.next()); - // should emit state 1, but it is unclear whether there will be more - // records with the same cursor value, so no state is ready for emission - assertEquals(RECORD_MESSAGE_2, iterator.next()); - // emit state 1 because it is the latest state ready for emission - assertEquals(createStateMessage(RECORD_VALUE_1, 1, 2.0), iterator.next()); - assertEquals(RECORD_MESSAGE_3, iterator.next()); - assertEquals(createStateMessage(RECORD_VALUE_2, 1, 1.0), iterator.next()); - assertEquals(RECORD_MESSAGE_4, iterator.next()); - assertEquals(createStateMessage(RECORD_VALUE_3, 1, 1.0), iterator.next()); - assertEquals(RECORD_MESSAGE_5, iterator.next()); - // state 4 is not emitted because there is no more record and only - // the final state should be emitted at this point; also the final - // state should only be emitted once - assertEquals(createStateMessage(RECORD_VALUE_5, 1, 1.0), iterator.next()); - assertFalse(iterator.hasNext()); - } - - @Test - // When initial cursor is null, and emit state for every 2 records - void testStateEmissionFrequency2() { - messageIterator = MoreIterators.of(RECORD_MESSAGE_1, RECORD_MESSAGE_2, RECORD_MESSAGE_3, RECORD_MESSAGE_4, RECORD_MESSAGE_5); - - final CursorStateMessageProducer producer = new CursorStateMessageProducer( - stateManager, - Optional.empty()); - - final SourceStateIterator iterator = new SourceStateIterator(messageIterator, STREAM, producer, new StateEmitFrequency(2, Duration.ZERO)); - - assertEquals(RECORD_MESSAGE_1, iterator.next()); - assertEquals(RECORD_MESSAGE_2, iterator.next()); - // emit state 1 because it is the latest state ready for emission - assertEquals(createStateMessage(RECORD_VALUE_1, 1, 2.0), iterator.next()); - assertEquals(RECORD_MESSAGE_3, iterator.next()); - assertEquals(RECORD_MESSAGE_4, iterator.next()); - // emit state 3 because it is the latest state ready for emission - assertEquals(createStateMessage(RECORD_VALUE_3, 1, 2.0), iterator.next()); - assertEquals(RECORD_MESSAGE_5, iterator.next()); - assertEquals(createStateMessage(RECORD_VALUE_5, 1, 1.0), iterator.next()); - assertFalse(iterator.hasNext()); - } - - @Test - // When initial cursor is not null - void 
testStateEmissionWhenInitialCursorIsNotNull() { - messageIterator = MoreIterators.of(RECORD_MESSAGE_2, RECORD_MESSAGE_3, RECORD_MESSAGE_4, RECORD_MESSAGE_5); - - final CursorStateMessageProducer producer = new CursorStateMessageProducer( - stateManager, - Optional.of(RECORD_VALUE_1)); - - final SourceStateIterator iterator = new SourceStateIterator(messageIterator, STREAM, producer, new StateEmitFrequency(1, Duration.ZERO)); - - assertEquals(RECORD_MESSAGE_2, iterator.next()); - assertEquals(RECORD_MESSAGE_3, iterator.next()); - assertEquals(createStateMessage(RECORD_VALUE_2, 1, 2.0), iterator.next()); - assertEquals(RECORD_MESSAGE_4, iterator.next()); - assertEquals(createStateMessage(RECORD_VALUE_3, 1, 1.0), iterator.next()); - assertEquals(RECORD_MESSAGE_5, iterator.next()); - assertEquals(createStateMessage(RECORD_VALUE_5, 1, 1.0), iterator.next()); - assertFalse(iterator.hasNext()); - } - - /** - * Incremental syncs will sort the table with the cursor field, and emit the max cursor for every N - * records. The purpose is to emit the states frequently, so that if any transient failure occurs - * during a long sync, the next run does not need to start from the beginning, but can resume from - * the last successful intermediate state committed on the destination. The next run will start with - * `cursorField > cursor`. However, it is possible that there are multiple records with the same - * cursor value. If the intermediate state is emitted before all these records have been synced to - * the destination, some of these records may be lost. - *
-   * <p>
-   * Here is an example:
-   *
-   * <pre>
-   * | Record ID | Cursor Field | Other Field | Note                          |
-   * | --------- | ------------ | ----------- | ----------------------------- |
-   * | 1         | F1=16        | F2="abc"    |                               |
-   * | 2         | F1=16        | F2="def"    | <- state emission and failure |
-   * | 3         | F1=16        | F2="ghi"    |                               |
-   * </pre>
-   *
-   * If the intermediate state is emitted for record 2 and the sync fails immediately such that the
-   * cursor value `16` is committed, but only records 1 and 2 are actually synced, the next run will
-   * start with `F1 > 16` and skip record 3.
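-   * <p>
-   * Subject to the configured state emission frequency, the guard these tests exercise can be
-   * sketched as follows (illustrative pseudo-code with hypothetical helper names, not the
-   * production implementation):
-   * <pre>
-   * String pendingCursor = null;                  // max cursor whose records are all synced
-   * for (Record record : recordsSortedByCursor) {
-   *   if (pendingCursor != null) {
-   *     if (!pendingCursor.equals(record.cursor())) {
-   *       emitIntermediateState(pendingCursor);   // safe: the cursor value has rolled over
-   *     }
-   *   }
-   *   sync(record);
-   *   pendingCursor = record.cursor();
-   * }
-   * emitFinalState(pendingCursor);                // the final state is emitted exactly once
-   * </pre>
-   * <p>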
- * So intermediate state emission should only happen when all records with the same cursor value has - * been synced to destination. Reference: - * link - */ - @Test - // When there are multiple records with the same cursor value - void testStateEmissionForRecordsSharingSameCursorValue() { - - messageIterator = MoreIterators.of( - RECORD_MESSAGE_2, RECORD_MESSAGE_2, - RECORD_MESSAGE_3, RECORD_MESSAGE_3, RECORD_MESSAGE_3, - RECORD_MESSAGE_4, - RECORD_MESSAGE_5, RECORD_MESSAGE_5); - - final CursorStateMessageProducer producer = new CursorStateMessageProducer( - stateManager, - Optional.of(RECORD_VALUE_1)); - - final SourceStateIterator iterator = new SourceStateIterator(messageIterator, STREAM, producer, new StateEmitFrequency(1, Duration.ZERO)); - - assertEquals(RECORD_MESSAGE_2, iterator.next()); - assertEquals(RECORD_MESSAGE_2, iterator.next()); - assertEquals(RECORD_MESSAGE_3, iterator.next()); - // state 2 is the latest state ready for emission because - // all records with the same cursor value have been emitted - assertEquals(createStateMessage(RECORD_VALUE_2, 2, 3.0), iterator.next()); - assertEquals(RECORD_MESSAGE_3, iterator.next()); - assertEquals(RECORD_MESSAGE_3, iterator.next()); - assertEquals(RECORD_MESSAGE_4, iterator.next()); - assertEquals(createStateMessage(RECORD_VALUE_3, 3, 3.0), iterator.next()); - assertEquals(RECORD_MESSAGE_5, iterator.next()); - assertEquals(createStateMessage(RECORD_VALUE_4, 1, 1.0), iterator.next()); - assertEquals(RECORD_MESSAGE_5, iterator.next()); - assertEquals(createStateMessage(RECORD_VALUE_5, 2, 1.0), iterator.next()); - assertFalse(iterator.hasNext()); - } - - @Test - void testStateEmissionForRecordsSharingSameCursorValueButDifferentStatsCount() { - messageIterator = MoreIterators.of( - RECORD_MESSAGE_2, RECORD_MESSAGE_2, - RECORD_MESSAGE_2, RECORD_MESSAGE_2, - RECORD_MESSAGE_3, RECORD_MESSAGE_3, RECORD_MESSAGE_3, - RECORD_MESSAGE_3, - RECORD_MESSAGE_3, RECORD_MESSAGE_3, RECORD_MESSAGE_3); - - final CursorStateMessageProducer producer = new CursorStateMessageProducer( - stateManager, - Optional.of(RECORD_VALUE_1)); - - final SourceStateIterator iterator = new SourceStateIterator(messageIterator, STREAM, producer, new StateEmitFrequency(10, Duration.ZERO)); - - assertEquals(RECORD_MESSAGE_2, iterator.next()); - assertEquals(RECORD_MESSAGE_2, iterator.next()); - assertEquals(RECORD_MESSAGE_2, iterator.next()); - assertEquals(RECORD_MESSAGE_2, iterator.next()); - assertEquals(RECORD_MESSAGE_3, iterator.next()); - assertEquals(RECORD_MESSAGE_3, iterator.next()); - assertEquals(RECORD_MESSAGE_3, iterator.next()); - assertEquals(RECORD_MESSAGE_3, iterator.next()); - assertEquals(RECORD_MESSAGE_3, iterator.next()); - assertEquals(RECORD_MESSAGE_3, iterator.next()); - // state 2 is the latest state ready for emission because - // all records with the same cursor value have been emitted - assertEquals(createStateMessage(RECORD_VALUE_2, 4, 10.0), iterator.next()); - assertEquals(RECORD_MESSAGE_3, iterator.next()); - assertEquals(createStateMessage(RECORD_VALUE_3, 7, 1.0), iterator.next()); - assertFalse(iterator.hasNext()); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/relationaldb/state/GlobalStateManagerTest.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/relationaldb/state/GlobalStateManagerTest.java deleted file mode 100644 index beee9c73aa89..000000000000 --- 
a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/relationaldb/state/GlobalStateManagerTest.java +++ /dev/null @@ -1,247 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.source.relationaldb.state; - -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.CURSOR_FIELD1; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.CURSOR_FIELD2; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.NAMESPACE; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.NAME_NAMESPACE_PAIR1; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.STREAM_NAME1; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.STREAM_NAME2; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.STREAM_NAME3; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.Mockito.mock; - -import io.airbyte.cdk.integrations.source.relationaldb.models.CdcState; -import io.airbyte.cdk.integrations.source.relationaldb.models.DbState; -import io.airbyte.cdk.integrations.source.relationaldb.models.DbStreamState; -import io.airbyte.commons.json.Jsons; -import io.airbyte.protocol.models.v0.AirbyteGlobalState; -import io.airbyte.protocol.models.v0.AirbyteStateMessage; -import io.airbyte.protocol.models.v0.AirbyteStateMessage.AirbyteStateType; -import io.airbyte.protocol.models.v0.AirbyteStream; -import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair; -import io.airbyte.protocol.models.v0.AirbyteStreamState; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream; -import io.airbyte.protocol.models.v0.StreamDescriptor; -import java.util.Collections; -import java.util.Comparator; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.stream.Collectors; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; - -/** - * Test suite for the {@link GlobalStateManager} class. 
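- * <p>
- * For orientation, a minimal sketch of the construction exercised below (mirroring the test
- * setup, not an additional API):
- * <pre>
- * final StateManager stateManager = new GlobalStateManager(
- *     new AirbyteStateMessage().withType(AirbyteStateType.GLOBAL).withGlobal(globalState),
- *     catalog);
- * stateManager.getCdcStateManager().getCdcState(); // shared CDC state round-trips intact
- * </pre>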
- */ -public class GlobalStateManagerTest { - - @Test - void testCdcStateManager() { - final ConfiguredAirbyteCatalog catalog = mock(ConfiguredAirbyteCatalog.class); - final CdcState cdcState = new CdcState().withState(Jsons.jsonNode(Map.of("foo", "bar", "baz", 5))); - final AirbyteGlobalState globalState = new AirbyteGlobalState().withSharedState(Jsons.jsonNode(cdcState)) - .withStreamStates(List.of(new AirbyteStreamState().withStreamDescriptor(new StreamDescriptor().withNamespace("namespace").withName("name")) - .withStreamState(Jsons.jsonNode(new DbStreamState())))); - final StateManager stateManager = - new GlobalStateManager(new AirbyteStateMessage().withType(AirbyteStateType.GLOBAL).withGlobal(globalState), catalog); - assertNotNull(stateManager.getCdcStateManager()); - assertEquals(cdcState, stateManager.getCdcStateManager().getCdcState()); - assertEquals(1, stateManager.getCdcStateManager().getInitialStreamsSynced().size()); - assertTrue(stateManager.getCdcStateManager().getInitialStreamsSynced().contains(new AirbyteStreamNameNamespacePair("name", "namespace"))); - } - - @Test - void testToStateFromLegacyState() { - final ConfiguredAirbyteCatalog catalog = new ConfiguredAirbyteCatalog() - .withStreams(List.of( - new ConfiguredAirbyteStream() - .withStream(new AirbyteStream().withName(STREAM_NAME1).withNamespace(NAMESPACE)) - .withCursorField(List.of(CURSOR_FIELD1)), - new ConfiguredAirbyteStream() - .withStream(new AirbyteStream().withName(STREAM_NAME2).withNamespace(NAMESPACE)) - .withCursorField(List.of(CURSOR_FIELD2)), - new ConfiguredAirbyteStream() - .withStream(new AirbyteStream().withName(STREAM_NAME3).withNamespace(NAMESPACE)))); - - final CdcState cdcState = new CdcState().withState(Jsons.jsonNode(Map.of("foo", "bar", "baz", 5))); - final DbState dbState = new DbState() - .withCdc(true) - .withCdcState(cdcState) - .withStreams(List.of( - new DbStreamState() - .withStreamName(STREAM_NAME1) - .withStreamNamespace(NAMESPACE) - .withCursorField(List.of(CURSOR_FIELD1)) - .withCursor("a"), - new DbStreamState() - .withStreamName(STREAM_NAME2) - .withStreamNamespace(NAMESPACE) - .withCursorField(List.of(CURSOR_FIELD2)), - new DbStreamState() - .withStreamName(STREAM_NAME3) - .withStreamNamespace(NAMESPACE)) - .stream().sorted(Comparator.comparing(DbStreamState::getStreamName)).collect(Collectors.toList())); - final StateManager stateManager = new GlobalStateManager(new AirbyteStateMessage().withData(Jsons.jsonNode(dbState)), catalog); - - final long expectedRecordCount = 19L; - final DbState expectedDbState = new DbState() - .withCdc(true) - .withCdcState(cdcState) - .withStreams(List.of( - new DbStreamState() - .withStreamName(STREAM_NAME1) - .withStreamNamespace(NAMESPACE) - .withCursorField(List.of(CURSOR_FIELD1)) - .withCursor("a") - .withCursorRecordCount(expectedRecordCount), - new DbStreamState() - .withStreamName(STREAM_NAME2) - .withStreamNamespace(NAMESPACE) - .withCursorField(List.of(CURSOR_FIELD2)), - new DbStreamState() - .withStreamName(STREAM_NAME3) - .withStreamNamespace(NAMESPACE)) - .stream().sorted(Comparator.comparing(DbStreamState::getStreamName)).collect(Collectors.toList())); - - final AirbyteGlobalState expectedGlobalState = new AirbyteGlobalState() - .withSharedState(Jsons.jsonNode(cdcState)) - .withStreamStates(List.of( - new AirbyteStreamState() - .withStreamDescriptor(new StreamDescriptor().withName(STREAM_NAME1).withNamespace(NAMESPACE)) - .withStreamState(Jsons.jsonNode(new DbStreamState() - .withStreamName(STREAM_NAME1) - 
.withStreamNamespace(NAMESPACE) - .withCursorField(List.of(CURSOR_FIELD1)) - .withCursor("a") - .withCursorRecordCount(expectedRecordCount))), - new AirbyteStreamState() - .withStreamDescriptor(new StreamDescriptor().withName(STREAM_NAME2).withNamespace(NAMESPACE)) - .withStreamState(Jsons.jsonNode(new DbStreamState() - .withStreamName(STREAM_NAME2) - .withStreamNamespace(NAMESPACE) - .withCursorField(List.of(CURSOR_FIELD2)))), - new AirbyteStreamState() - .withStreamDescriptor(new StreamDescriptor().withName(STREAM_NAME3).withNamespace(NAMESPACE)) - .withStreamState(Jsons.jsonNode(new DbStreamState() - .withStreamName(STREAM_NAME3) - .withStreamNamespace(NAMESPACE)))) - .stream().sorted(Comparator.comparing(o -> o.getStreamDescriptor().getName())).collect(Collectors.toList())); - final AirbyteStateMessage expected = new AirbyteStateMessage() - .withData(Jsons.jsonNode(expectedDbState)) - .withGlobal(expectedGlobalState) - .withType(AirbyteStateType.GLOBAL); - - final AirbyteStateMessage actualFirstEmission = stateManager.updateAndEmit(NAME_NAMESPACE_PAIR1, "a", expectedRecordCount); - assertEquals(expected, actualFirstEmission); - } - - // Discovered during CDK migration. - // Failure is: Could not find cursor information for stream: public_cars - @Disabled("Failing test.") - @Test - void testToState() { - final ConfiguredAirbyteCatalog catalog = new ConfiguredAirbyteCatalog() - .withStreams(List.of( - new ConfiguredAirbyteStream() - .withStream(new AirbyteStream().withName(STREAM_NAME1).withNamespace(NAMESPACE)) - .withCursorField(List.of(CURSOR_FIELD1)), - new ConfiguredAirbyteStream() - .withStream(new AirbyteStream().withName(STREAM_NAME2).withNamespace(NAMESPACE)) - .withCursorField(List.of(CURSOR_FIELD2)), - new ConfiguredAirbyteStream() - .withStream(new AirbyteStream().withName(STREAM_NAME3).withNamespace(NAMESPACE)))); - - final CdcState cdcState = new CdcState().withState(Jsons.jsonNode(Map.of("foo", "bar", "baz", 5))); - final AirbyteGlobalState globalState = new AirbyteGlobalState().withSharedState(Jsons.jsonNode(new DbState())).withStreamStates( - List.of(new AirbyteStreamState().withStreamDescriptor(new StreamDescriptor()).withStreamState(Jsons.jsonNode(new DbStreamState())))); - final StateManager stateManager = - new GlobalStateManager(new AirbyteStateMessage().withType(AirbyteStateType.GLOBAL).withGlobal(globalState), catalog); - stateManager.getCdcStateManager().setCdcState(cdcState); - - final DbState expectedDbState = new DbState() - .withCdc(true) - .withCdcState(cdcState) - .withStreams(List.of( - new DbStreamState() - .withStreamName(STREAM_NAME1) - .withStreamNamespace(NAMESPACE) - .withCursorField(List.of(CURSOR_FIELD1)) - .withCursor("a") - .withCursorRecordCount(1L), - new DbStreamState() - .withStreamName(STREAM_NAME2) - .withStreamNamespace(NAMESPACE) - .withCursorField(List.of(CURSOR_FIELD2)), - new DbStreamState() - .withStreamName(STREAM_NAME3) - .withStreamNamespace(NAMESPACE)) - .stream().sorted(Comparator.comparing(DbStreamState::getStreamName)).collect(Collectors.toList())); - - final AirbyteGlobalState expectedGlobalState = new AirbyteGlobalState() - .withSharedState(Jsons.jsonNode(cdcState)) - .withStreamStates(List.of( - new AirbyteStreamState() - .withStreamDescriptor(new StreamDescriptor().withName(STREAM_NAME1).withNamespace(NAMESPACE)) - .withStreamState(Jsons.jsonNode(new DbStreamState() - .withStreamName(STREAM_NAME1) - .withStreamNamespace(NAMESPACE) - .withCursorField(List.of(CURSOR_FIELD1)) - .withCursor("a") - .withCursorRecordCount(1L))), 
- new AirbyteStreamState() - .withStreamDescriptor(new StreamDescriptor().withName(STREAM_NAME2).withNamespace(NAMESPACE)) - .withStreamState(Jsons.jsonNode(new DbStreamState() - .withStreamName(STREAM_NAME2) - .withStreamNamespace(NAMESPACE) - .withCursorField(List.of(CURSOR_FIELD2)))), - new AirbyteStreamState() - .withStreamDescriptor(new StreamDescriptor().withName(STREAM_NAME3).withNamespace(NAMESPACE)) - .withStreamState(Jsons.jsonNode(new DbStreamState() - .withStreamName(STREAM_NAME3) - .withStreamNamespace(NAMESPACE)))) - .stream().sorted(Comparator.comparing(o -> o.getStreamDescriptor().getName())).collect(Collectors.toList())); - final AirbyteStateMessage expected = new AirbyteStateMessage() - .withData(Jsons.jsonNode(expectedDbState)) - .withGlobal(expectedGlobalState) - .withType(AirbyteStateType.GLOBAL); - - final AirbyteStateMessage actualFirstEmission = stateManager.updateAndEmit(NAME_NAMESPACE_PAIR1, "a", 1L); - assertEquals(expected, actualFirstEmission); - } - - @Test - void testToStateWithNoState() { - final ConfiguredAirbyteCatalog catalog = new ConfiguredAirbyteCatalog(); - final StateManager stateManager = - new GlobalStateManager(new AirbyteStateMessage(), catalog); - - final AirbyteStateMessage airbyteStateMessage = stateManager.toState(Optional.empty()); - assertNotNull(airbyteStateMessage); - assertEquals(AirbyteStateType.GLOBAL, airbyteStateMessage.getType()); - assertEquals(0, airbyteStateMessage.getGlobal().getStreamStates().size()); - } - - @Test - void testCdcStateManagerLegacyState() { - final ConfiguredAirbyteCatalog catalog = mock(ConfiguredAirbyteCatalog.class); - final CdcState cdcState = new CdcState().withState(Jsons.jsonNode(Map.of("foo", "bar", "baz", 5))); - final DbState dbState = new DbState().withCdcState(new CdcState().withState(Jsons.jsonNode(cdcState))) - .withStreams(List - .of(new DbStreamState().withStreamName("name").withStreamNamespace("namespace").withCursor("").withCursorField(Collections.emptyList()))) - .withCdc(true); - final StateManager stateManager = - new GlobalStateManager(new AirbyteStateMessage().withType(AirbyteStateType.LEGACY).withData(Jsons.jsonNode(dbState)), catalog); - assertNotNull(stateManager.getCdcStateManager()); - assertEquals(1, stateManager.getCdcStateManager().getInitialStreamsSynced().size()); - assertTrue(stateManager.getCdcStateManager().getInitialStreamsSynced().contains(new AirbyteStreamNameNamespacePair("name", "namespace"))); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/relationaldb/state/LegacyStateManagerTest.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/relationaldb/state/LegacyStateManagerTest.java deleted file mode 100644 index 25214d1c7701..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/relationaldb/state/LegacyStateManagerTest.java +++ /dev/null @@ -1,181 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
- */ - -package io.airbyte.cdk.integrations.source.relationaldb.state; - -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.CURSOR; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.CURSOR_FIELD1; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.CURSOR_FIELD2; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.NAMESPACE; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.NAME_NAMESPACE_PAIR1; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.NAME_NAMESPACE_PAIR2; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.STREAM_NAME1; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.STREAM_NAME2; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.STREAM_NAME3; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.mockito.Mockito.mock; - -import io.airbyte.cdk.integrations.source.relationaldb.models.CdcState; -import io.airbyte.cdk.integrations.source.relationaldb.models.DbState; -import io.airbyte.cdk.integrations.source.relationaldb.models.DbStreamState; -import io.airbyte.commons.json.Jsons; -import io.airbyte.protocol.models.v0.AirbyteStateMessage; -import io.airbyte.protocol.models.v0.AirbyteStateMessage.AirbyteStateType; -import io.airbyte.protocol.models.v0.AirbyteStream; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream; -import java.util.Comparator; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.stream.Collectors; -import org.junit.jupiter.api.Test; - -/** - * Test suite for the {@link LegacyStateManager} class. 
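- * <p>
- * A minimal sketch of the flow exercised below: the manager is seeded with a {@link DbState}
- * and re-emits LEGACY-typed state on each cursor update.
- * <pre>
- * final StateManager stateManager = new LegacyStateManager(new DbState(), catalog);
- * final AirbyteStateMessage emitted = stateManager.updateAndEmit(NAME_NAMESPACE_PAIR1, "a");
- * // emitted.getType() == AirbyteStateType.LEGACY
- * </pre>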
- */ -public class LegacyStateManagerTest { - - @Test - void testGetters() { - final DbState state = new DbState().withStreams(List.of( - new DbStreamState().withStreamName(STREAM_NAME1).withStreamNamespace(NAMESPACE).withCursorField(List.of(CURSOR_FIELD1)) - .withCursor(CURSOR), - new DbStreamState().withStreamName(STREAM_NAME2).withStreamNamespace(NAMESPACE))); - - final ConfiguredAirbyteCatalog catalog = new ConfiguredAirbyteCatalog() - .withStreams(List.of( - new ConfiguredAirbyteStream() - .withStream(new AirbyteStream().withName(STREAM_NAME1).withNamespace(NAMESPACE)) - .withCursorField(List.of(CURSOR_FIELD1)), - new ConfiguredAirbyteStream() - .withStream(new AirbyteStream().withName(STREAM_NAME2).withNamespace(NAMESPACE)))); - - final StateManager stateManager = new LegacyStateManager(state, catalog); - - assertEquals(Optional.of(CURSOR_FIELD1), stateManager.getOriginalCursorField(NAME_NAMESPACE_PAIR1)); - assertEquals(Optional.of(CURSOR), stateManager.getOriginalCursor(NAME_NAMESPACE_PAIR1)); - assertEquals(Optional.of(CURSOR_FIELD1), stateManager.getCursorField(NAME_NAMESPACE_PAIR1)); - assertEquals(Optional.of(CURSOR), stateManager.getCursor(NAME_NAMESPACE_PAIR1)); - - assertEquals(Optional.empty(), stateManager.getOriginalCursorField(NAME_NAMESPACE_PAIR2)); - assertEquals(Optional.empty(), stateManager.getOriginalCursor(NAME_NAMESPACE_PAIR2)); - assertEquals(Optional.empty(), stateManager.getCursorField(NAME_NAMESPACE_PAIR2)); - assertEquals(Optional.empty(), stateManager.getCursor(NAME_NAMESPACE_PAIR2)); - } - - @Test - void testToState() { - final ConfiguredAirbyteCatalog catalog = new ConfiguredAirbyteCatalog() - .withStreams(List.of( - new ConfiguredAirbyteStream() - .withStream(new AirbyteStream().withName(STREAM_NAME1).withNamespace(NAMESPACE)) - .withCursorField(List.of(CURSOR_FIELD1)), - new ConfiguredAirbyteStream() - .withStream(new AirbyteStream().withName(STREAM_NAME2).withNamespace(NAMESPACE)) - .withCursorField(List.of(CURSOR_FIELD2)), - new ConfiguredAirbyteStream() - .withStream(new AirbyteStream().withName(STREAM_NAME3).withNamespace(NAMESPACE)))); - - final StateManager stateManager = new LegacyStateManager(new DbState(), catalog); - - final AirbyteStateMessage expectedFirstEmission = new AirbyteStateMessage() - .withType(AirbyteStateType.LEGACY) - .withData(Jsons.jsonNode(new DbState().withStreams(List.of( - new DbStreamState().withStreamName(STREAM_NAME1).withStreamNamespace(NAMESPACE).withCursorField(List.of(CURSOR_FIELD1)) - .withCursor("a"), - new DbStreamState().withStreamName(STREAM_NAME2).withStreamNamespace(NAMESPACE).withCursorField(List.of(CURSOR_FIELD2)), - new DbStreamState().withStreamName(STREAM_NAME3).withStreamNamespace(NAMESPACE)) - .stream().sorted(Comparator.comparing(DbStreamState::getStreamName)).collect(Collectors.toList())) - .withCdc(false))); - final AirbyteStateMessage actualFirstEmission = stateManager.updateAndEmit(NAME_NAMESPACE_PAIR1, "a"); - assertEquals(expectedFirstEmission, actualFirstEmission); - final AirbyteStateMessage expectedSecondEmission = new AirbyteStateMessage() - .withType(AirbyteStateType.LEGACY) - .withData(Jsons.jsonNode(new DbState().withStreams(List.of( - new DbStreamState().withStreamName(STREAM_NAME1).withStreamNamespace(NAMESPACE).withCursorField(List.of(CURSOR_FIELD1)) - .withCursor("a"), - new DbStreamState().withStreamName(STREAM_NAME2).withStreamNamespace(NAMESPACE).withCursorField(List.of(CURSOR_FIELD2)) - .withCursor("b"), - new 
DbStreamState().withStreamName(STREAM_NAME3).withStreamNamespace(NAMESPACE)) - .stream().sorted(Comparator.comparing(DbStreamState::getStreamName)).collect(Collectors.toList())) - .withCdc(false))); - final AirbyteStateMessage actualSecondEmission = stateManager.updateAndEmit(NAME_NAMESPACE_PAIR2, "b"); - assertEquals(expectedSecondEmission, actualSecondEmission); - } - - @Test - void testToStateNullCursorField() { - final ConfiguredAirbyteCatalog catalog = new ConfiguredAirbyteCatalog() - .withStreams(List.of( - new ConfiguredAirbyteStream() - .withStream(new AirbyteStream().withName(STREAM_NAME1).withNamespace(NAMESPACE)) - .withCursorField(List.of(CURSOR_FIELD1)), - new ConfiguredAirbyteStream() - .withStream(new AirbyteStream().withName(STREAM_NAME2).withNamespace(NAMESPACE)))); - final StateManager stateManager = new LegacyStateManager(new DbState(), catalog); - - final AirbyteStateMessage expectedFirstEmission = new AirbyteStateMessage() - .withType(AirbyteStateType.LEGACY) - .withData(Jsons.jsonNode(new DbState().withStreams(List.of( - new DbStreamState().withStreamName(STREAM_NAME1).withStreamNamespace(NAMESPACE).withCursorField(List.of(CURSOR_FIELD1)) - .withCursor("a"), - new DbStreamState().withStreamName(STREAM_NAME2).withStreamNamespace(NAMESPACE)) - .stream().sorted(Comparator.comparing(DbStreamState::getStreamName)).collect(Collectors.toList())) - .withCdc(false))); - - final AirbyteStateMessage actualFirstEmission = stateManager.updateAndEmit(NAME_NAMESPACE_PAIR1, "a"); - assertEquals(expectedFirstEmission, actualFirstEmission); - } - - @Test - void testCursorNotUpdatedForCdc() { - final ConfiguredAirbyteCatalog catalog = new ConfiguredAirbyteCatalog() - .withStreams(List.of( - new ConfiguredAirbyteStream() - .withStream(new AirbyteStream().withName(STREAM_NAME1).withNamespace(NAMESPACE)) - .withCursorField(List.of(CURSOR_FIELD1)), - new ConfiguredAirbyteStream() - .withStream(new AirbyteStream().withName(STREAM_NAME2).withNamespace(NAMESPACE)))); - - final DbState state = new DbState(); - state.setCdc(true); - final StateManager stateManager = new LegacyStateManager(state, catalog); - - final AirbyteStateMessage expectedFirstEmission = new AirbyteStateMessage() - .withType(AirbyteStateType.LEGACY) - .withData(Jsons.jsonNode(new DbState().withStreams(List.of( - new DbStreamState().withStreamName(STREAM_NAME1).withStreamNamespace(NAMESPACE).withCursorField(List.of(CURSOR_FIELD1)) - .withCursor(null), - new DbStreamState().withStreamName(STREAM_NAME2).withStreamNamespace(NAMESPACE).withCursorField(List.of())) - .stream().sorted(Comparator.comparing(DbStreamState::getStreamName)).collect(Collectors.toList())) - .withCdc(true))); - final AirbyteStateMessage actualFirstEmission = stateManager.updateAndEmit(NAME_NAMESPACE_PAIR1, "a"); - assertEquals(expectedFirstEmission, actualFirstEmission); - final AirbyteStateMessage expectedSecondEmission = new AirbyteStateMessage() - .withType(AirbyteStateType.LEGACY) - .withData(Jsons.jsonNode(new DbState().withStreams(List.of( - new DbStreamState().withStreamName(STREAM_NAME1).withStreamNamespace(NAMESPACE).withCursorField(List.of(CURSOR_FIELD1)) - .withCursor(null), - new DbStreamState().withStreamName(STREAM_NAME2).withStreamNamespace(NAMESPACE).withCursorField(List.of()) - .withCursor(null)) - .stream().sorted(Comparator.comparing(DbStreamState::getStreamName)).collect(Collectors.toList())) - .withCdc(true))); - final AirbyteStateMessage actualSecondEmission = stateManager.updateAndEmit(NAME_NAMESPACE_PAIR2, "b"); - 
assertEquals(expectedSecondEmission, actualSecondEmission); - } - - @Test - void testCdcStateManager() { - final ConfiguredAirbyteCatalog catalog = mock(ConfiguredAirbyteCatalog.class); - final CdcState cdcState = new CdcState().withState(Jsons.jsonNode(Map.of("foo", "bar", "baz", 5))); - final DbState dbState = new DbState().withCdcState(cdcState).withStreams(List.of( - new DbStreamState().withStreamNamespace(NAMESPACE).withStreamName(STREAM_NAME1))); - final StateManager stateManager = new LegacyStateManager(dbState, catalog); - assertNotNull(stateManager.getCdcStateManager()); - assertEquals(cdcState, stateManager.getCdcStateManager().getCdcState()); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateIteratorTest.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateIteratorTest.java deleted file mode 100644 index 626cd52545a4..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateIteratorTest.java +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.source.relationaldb.state; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertThrows; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.atLeastOnce; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; - -import io.airbyte.protocol.models.v0.AirbyteMessage; -import io.airbyte.protocol.models.v0.AirbyteMessage.Type; -import io.airbyte.protocol.models.v0.AirbyteRecordMessage; -import io.airbyte.protocol.models.v0.AirbyteStateMessage; -import io.airbyte.protocol.models.v0.AirbyteStateStats; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream; -import java.time.Duration; -import java.util.Iterator; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -public class SourceStateIteratorTest { - - SourceStateMessageProducer mockProducer; - Iterator messageIterator; - ConfiguredAirbyteStream stream; - - SourceStateIterator sourceStateIterator; - - @BeforeEach - void setup() { - mockProducer = mock(SourceStateMessageProducer.class); - stream = mock(ConfiguredAirbyteStream.class); - messageIterator = mock(Iterator.class); - StateEmitFrequency stateEmitFrequency = new StateEmitFrequency(1L, Duration.ofSeconds(100L)); - sourceStateIterator = new SourceStateIterator(messageIterator, stream, mockProducer, stateEmitFrequency); - } - - // Provides a way to generate a record message and will verify corresponding spied functions have - // been called. 
- void processRecordMessage() { - doReturn(true).when(messageIterator).hasNext(); - doReturn(false).when(mockProducer).shouldEmitStateMessage(eq(stream)); - AirbyteMessage message = new AirbyteMessage().withType(Type.RECORD).withRecord(new AirbyteRecordMessage()); - doReturn(message).when(mockProducer).processRecordMessage(eq(stream), any()); - doReturn(message).when(messageIterator).next(); - - assertEquals(message, sourceStateIterator.computeNext()); - verify(mockProducer, atLeastOnce()).processRecordMessage(eq(stream), eq(message)); - } - - @Test - void testShouldProcessRecordMessage() { - processRecordMessage(); - } - - @Test - void testShouldEmitStateMessage() { - processRecordMessage(); - doReturn(true).when(mockProducer).shouldEmitStateMessage(eq(stream)); - final AirbyteStateMessage stateMessage = new AirbyteStateMessage(); - doReturn(stateMessage).when(mockProducer).generateStateMessageAtCheckpoint(stream); - AirbyteMessage expectedMessage = new AirbyteMessage().withType(Type.STATE).withState(stateMessage); - expectedMessage.getState().withSourceStats(new AirbyteStateStats().withRecordCount(1.0)); - assertEquals(expectedMessage, sourceStateIterator.computeNext()); - } - - @Test - void testShouldEmitFinalStateMessage() { - processRecordMessage(); - processRecordMessage(); - doReturn(false).when(messageIterator).hasNext(); - final AirbyteStateMessage stateMessage = new AirbyteStateMessage(); - doReturn(stateMessage).when(mockProducer).createFinalStateMessage(stream); - AirbyteMessage expectedMessage = new AirbyteMessage().withType(Type.STATE).withState(stateMessage); - expectedMessage.getState().withSourceStats(new AirbyteStateStats().withRecordCount(2.0)); - assertEquals(expectedMessage, sourceStateIterator.computeNext()); - } - - @Test - void testShouldSendEndOfData() { - processRecordMessage(); - doReturn(false).when(messageIterator).hasNext(); - doReturn(new AirbyteStateMessage()).when(mockProducer).createFinalStateMessage(stream); - sourceStateIterator.computeNext(); - - // After sending the final state, if iterator was called again, we will return null. - assertEquals(null, sourceStateIterator.computeNext()); - } - - @Test - void testShouldRethrowExceptions() { - processRecordMessage(); - doThrow(new ArrayIndexOutOfBoundsException("unexpected error")).when(messageIterator).hasNext(); - assertThrows(RuntimeException.class, () -> sourceStateIterator.computeNext()); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/relationaldb/state/StateGeneratorUtilsTest.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/relationaldb/state/StateGeneratorUtilsTest.java deleted file mode 100644 index 0f65df39d292..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/relationaldb/state/StateGeneratorUtilsTest.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.source.relationaldb.state; - -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import io.airbyte.protocol.models.v0.StreamDescriptor; -import org.junit.jupiter.api.Test; - -/** - * Test suite for the {@link StateGeneratorUtils} class. 
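- * <p>
- * The validity rule pinned down by this suite: a descriptor is valid only when it is non-null
- * and carries a non-null name; the namespace may be null or empty. For example:
- * <pre>
- * StateGeneratorUtils.isValidStreamDescriptor(new StreamDescriptor().withName("name")); // true
- * StateGeneratorUtils.isValidStreamDescriptor(new StreamDescriptor());                  // false
- * </pre>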
- */ -public class StateGeneratorUtilsTest { - - @Test - void testValidStreamDescriptor() { - final StreamDescriptor streamDescriptor1 = null; - final StreamDescriptor streamDescriptor2 = new StreamDescriptor(); - final StreamDescriptor streamDescriptor3 = new StreamDescriptor().withName("name"); - final StreamDescriptor streamDescriptor4 = new StreamDescriptor().withNamespace("namespace"); - final StreamDescriptor streamDescriptor5 = new StreamDescriptor().withName("name").withNamespace("namespace"); - final StreamDescriptor streamDescriptor6 = new StreamDescriptor().withName("name").withNamespace(""); - final StreamDescriptor streamDescriptor7 = new StreamDescriptor().withName("").withNamespace("namespace"); - final StreamDescriptor streamDescriptor8 = new StreamDescriptor().withName("").withNamespace(""); - - assertFalse(StateGeneratorUtils.isValidStreamDescriptor(streamDescriptor1)); - assertFalse(StateGeneratorUtils.isValidStreamDescriptor(streamDescriptor2)); - assertTrue(StateGeneratorUtils.isValidStreamDescriptor(streamDescriptor3)); - assertFalse(StateGeneratorUtils.isValidStreamDescriptor(streamDescriptor4)); - assertTrue(StateGeneratorUtils.isValidStreamDescriptor(streamDescriptor5)); - assertTrue(StateGeneratorUtils.isValidStreamDescriptor(streamDescriptor6)); - assertTrue(StateGeneratorUtils.isValidStreamDescriptor(streamDescriptor7)); - assertTrue(StateGeneratorUtils.isValidStreamDescriptor(streamDescriptor8)); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/relationaldb/state/StateManagerFactoryTest.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/relationaldb/state/StateManagerFactoryTest.java deleted file mode 100644 index 702429adc999..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/relationaldb/state/StateManagerFactoryTest.java +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.source.relationaldb.state; - -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -import io.airbyte.cdk.integrations.source.relationaldb.models.CdcState; -import io.airbyte.cdk.integrations.source.relationaldb.models.DbState; -import io.airbyte.cdk.integrations.source.relationaldb.models.DbStreamState; -import io.airbyte.commons.json.Jsons; -import io.airbyte.protocol.models.v0.AirbyteGlobalState; -import io.airbyte.protocol.models.v0.AirbyteStateMessage; -import io.airbyte.protocol.models.v0.AirbyteStateMessage.AirbyteStateType; -import io.airbyte.protocol.models.v0.AirbyteStreamState; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog; -import io.airbyte.protocol.models.v0.StreamDescriptor; -import java.util.List; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -/** - * Test suite for the {@link StateManagerFactory} class. 
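- * <p>
- * In brief, the dispatch behaviour pinned down below:
- * <pre>
- * // A null or empty state list is rejected with an IllegalArgumentException:
- * StateManagerFactory.createStateManager(AirbyteStateType.GLOBAL, List.of(), catalog);
- * // Otherwise the requested type selects the implementation:
- * // LEGACY -> LegacyStateManager, GLOBAL -> GlobalStateManager, STREAM -> StreamStateManager
- * </pre>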
- */ -public class StateManagerFactoryTest { - - private static final String NAMESPACE = "namespace"; - private static final String NAME = "name"; - - @Test - void testNullOrEmptyState() { - final ConfiguredAirbyteCatalog catalog = mock(ConfiguredAirbyteCatalog.class); - - Assertions.assertThrows(IllegalArgumentException.class, () -> { - StateManagerFactory.createStateManager(AirbyteStateType.GLOBAL, null, catalog); - }); - - Assertions.assertThrows(IllegalArgumentException.class, () -> { - StateManagerFactory.createStateManager(AirbyteStateType.GLOBAL, List.of(), catalog); - }); - - Assertions.assertThrows(IllegalArgumentException.class, () -> { - StateManagerFactory.createStateManager(AirbyteStateType.LEGACY, null, catalog); - }); - - Assertions.assertThrows(IllegalArgumentException.class, () -> { - StateManagerFactory.createStateManager(AirbyteStateType.LEGACY, List.of(), catalog); - }); - - Assertions.assertThrows(IllegalArgumentException.class, () -> { - StateManagerFactory.createStateManager(AirbyteStateType.STREAM, null, catalog); - }); - - Assertions.assertThrows(IllegalArgumentException.class, () -> { - StateManagerFactory.createStateManager(AirbyteStateType.STREAM, List.of(), catalog); - }); - } - - @Test - void testLegacyStateManagerCreationFromAirbyteStateMessage() { - final ConfiguredAirbyteCatalog catalog = mock(ConfiguredAirbyteCatalog.class); - final AirbyteStateMessage airbyteStateMessage = mock(AirbyteStateMessage.class); - when(airbyteStateMessage.getData()).thenReturn(Jsons.jsonNode(new DbState())); - - final StateManager stateManager = StateManagerFactory.createStateManager(AirbyteStateType.LEGACY, List.of(airbyteStateMessage), catalog); - - Assertions.assertNotNull(stateManager); - Assertions.assertEquals(LegacyStateManager.class, stateManager.getClass()); - } - - @Test - void testGlobalStateManagerCreation() { - final ConfiguredAirbyteCatalog catalog = mock(ConfiguredAirbyteCatalog.class); - final AirbyteGlobalState globalState = - new AirbyteGlobalState().withSharedState(Jsons.jsonNode(new DbState().withCdcState(new CdcState().withState(Jsons.jsonNode(new DbState()))))) - .withStreamStates(List.of(new AirbyteStreamState().withStreamDescriptor(new StreamDescriptor().withNamespace(NAMESPACE).withName(NAME)) - .withStreamState(Jsons.jsonNode(new DbStreamState())))); - final AirbyteStateMessage airbyteStateMessage = new AirbyteStateMessage().withType(AirbyteStateType.GLOBAL).withGlobal(globalState); - - final StateManager stateManager = StateManagerFactory.createStateManager(AirbyteStateType.GLOBAL, List.of(airbyteStateMessage), catalog); - - Assertions.assertNotNull(stateManager); - Assertions.assertEquals(GlobalStateManager.class, stateManager.getClass()); - } - - @Test - void testGlobalStateManagerCreationFromLegacyState() { - final ConfiguredAirbyteCatalog catalog = mock(ConfiguredAirbyteCatalog.class); - final CdcState cdcState = new CdcState(); - final DbState dbState = new DbState() - .withCdcState(cdcState) - .withStreams(List.of(new DbStreamState().withStreamName(NAME).withStreamNamespace(NAMESPACE))); - final AirbyteStateMessage airbyteStateMessage = - new AirbyteStateMessage().withType(AirbyteStateType.LEGACY).withData(Jsons.jsonNode(dbState)); - - final StateManager stateManager = StateManagerFactory.createStateManager(AirbyteStateType.GLOBAL, List.of(airbyteStateMessage), catalog); - - Assertions.assertNotNull(stateManager); - Assertions.assertEquals(GlobalStateManager.class, stateManager.getClass()); - } - - @Test - void 
testGlobalStateManagerCreationFromStreamState() { - final ConfiguredAirbyteCatalog catalog = mock(ConfiguredAirbyteCatalog.class); - final AirbyteStateMessage airbyteStateMessage = new AirbyteStateMessage().withType(AirbyteStateType.STREAM) - .withStream(new AirbyteStreamState().withStreamDescriptor(new StreamDescriptor().withName(NAME).withNamespace( - NAMESPACE)).withStreamState(Jsons.jsonNode(new DbStreamState()))); - - Assertions.assertThrows(IllegalArgumentException.class, - () -> StateManagerFactory.createStateManager(AirbyteStateType.GLOBAL, List.of(airbyteStateMessage), catalog)); - } - - @Test - void testGlobalStateManagerCreationWithLegacyDataPresent() { - final ConfiguredAirbyteCatalog catalog = mock(ConfiguredAirbyteCatalog.class); - final AirbyteGlobalState globalState = - new AirbyteGlobalState().withSharedState(Jsons.jsonNode(new DbState().withCdcState(new CdcState().withState(Jsons.jsonNode(new DbState()))))) - .withStreamStates(List.of(new AirbyteStreamState().withStreamDescriptor(new StreamDescriptor().withNamespace(NAMESPACE).withName(NAME)) - .withStreamState(Jsons.jsonNode(new DbStreamState())))); - final AirbyteStateMessage airbyteStateMessage = - new AirbyteStateMessage().withType(AirbyteStateType.GLOBAL).withGlobal(globalState).withData(Jsons.jsonNode(new DbState())); - - final StateManager stateManager = StateManagerFactory.createStateManager(AirbyteStateType.GLOBAL, List.of(airbyteStateMessage), catalog); - - Assertions.assertNotNull(stateManager); - Assertions.assertEquals(GlobalStateManager.class, stateManager.getClass()); - } - - @Test - void testStreamStateManagerCreation() { - final ConfiguredAirbyteCatalog catalog = mock(ConfiguredAirbyteCatalog.class); - final AirbyteStateMessage airbyteStateMessage = new AirbyteStateMessage().withType(AirbyteStateType.STREAM) - .withStream(new AirbyteStreamState().withStreamDescriptor(new StreamDescriptor().withName(NAME).withNamespace( - NAMESPACE)).withStreamState(Jsons.jsonNode(new DbStreamState()))); - - final StateManager stateManager = StateManagerFactory.createStateManager(AirbyteStateType.STREAM, List.of(airbyteStateMessage), catalog); - - Assertions.assertNotNull(stateManager); - Assertions.assertEquals(StreamStateManager.class, stateManager.getClass()); - } - - @Test - void testStreamStateManagerCreationFromLegacy() { - final ConfiguredAirbyteCatalog catalog = mock(ConfiguredAirbyteCatalog.class); - final CdcState cdcState = new CdcState(); - final DbState dbState = new DbState() - .withCdcState(cdcState) - .withStreams(List.of(new DbStreamState().withStreamName(NAME).withStreamNamespace(NAMESPACE))); - final AirbyteStateMessage airbyteStateMessage = - new AirbyteStateMessage().withType(AirbyteStateType.LEGACY).withData(Jsons.jsonNode(dbState)); - - final StateManager stateManager = StateManagerFactory.createStateManager(AirbyteStateType.STREAM, List.of(airbyteStateMessage), catalog); - - Assertions.assertNotNull(stateManager); - Assertions.assertEquals(StreamStateManager.class, stateManager.getClass()); - } - - @Test - void testStreamStateManagerCreationFromGlobal() { - final ConfiguredAirbyteCatalog catalog = mock(ConfiguredAirbyteCatalog.class); - final AirbyteGlobalState globalState = - new AirbyteGlobalState().withSharedState(Jsons.jsonNode(new DbState().withCdcState(new CdcState().withState(Jsons.jsonNode(new DbState()))))) - .withStreamStates(List.of(new AirbyteStreamState().withStreamDescriptor(new StreamDescriptor().withNamespace(NAMESPACE).withName(NAME)) - .withStreamState(Jsons.jsonNode(new 
DbStreamState())))); - final AirbyteStateMessage airbyteStateMessage = new AirbyteStateMessage().withType(AirbyteStateType.GLOBAL).withGlobal(globalState); - - Assertions.assertThrows(IllegalArgumentException.class, - () -> StateManagerFactory.createStateManager(AirbyteStateType.STREAM, List.of(airbyteStateMessage), catalog)); - } - - @Test - void testStreamStateManagerCreationWithLegacyDataPresent() { - final ConfiguredAirbyteCatalog catalog = mock(ConfiguredAirbyteCatalog.class); - final AirbyteStateMessage airbyteStateMessage = new AirbyteStateMessage().withType(AirbyteStateType.STREAM) - .withStream(new AirbyteStreamState().withStreamDescriptor(new StreamDescriptor().withName(NAME).withNamespace( - NAMESPACE)).withStreamState(Jsons.jsonNode(new DbStreamState()))) - .withData(Jsons.jsonNode(new DbState())); - - final StateManager stateManager = StateManagerFactory.createStateManager(AirbyteStateType.STREAM, List.of(airbyteStateMessage), catalog); - - Assertions.assertNotNull(stateManager); - Assertions.assertEquals(StreamStateManager.class, stateManager.getClass()); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/relationaldb/state/StateTestConstants.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/relationaldb/state/StateTestConstants.java deleted file mode 100644 index 0b6d0c4632d4..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/relationaldb/state/StateTestConstants.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.source.relationaldb.state; - -import io.airbyte.cdk.integrations.source.relationaldb.models.DbStreamState; -import io.airbyte.protocol.models.v0.AirbyteStream; -import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream; -import java.util.Collections; -import java.util.List; -import java.util.Optional; -import org.testcontainers.shaded.com.google.common.collect.Lists; - -/** - * Collection of constants for use in state management-related tests. 
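- * <p>
- * Typical usage in the suites above, e.g.
- * <pre>
- * final var state = StateTestConstants.getState(CURSOR_FIELD1, CURSOR);   // Optional of a DbStreamState
- * final var catalog = StateTestConstants.getCatalog(CURSOR_FIELD1);       // Optional of a one-stream catalog
- * </pre>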
- */ -public final class StateTestConstants { - - public static final String NAMESPACE = "public"; - public static final String STREAM_NAME1 = "cars"; - public static final AirbyteStreamNameNamespacePair NAME_NAMESPACE_PAIR1 = new AirbyteStreamNameNamespacePair(STREAM_NAME1, NAMESPACE); - public static final String STREAM_NAME2 = "bicycles"; - public static final AirbyteStreamNameNamespacePair NAME_NAMESPACE_PAIR2 = new AirbyteStreamNameNamespacePair(STREAM_NAME2, NAMESPACE); - public static final String STREAM_NAME3 = "stationary_bicycles"; - public static final String CURSOR_FIELD1 = "year"; - public static final String CURSOR_FIELD2 = "generation"; - public static final String CURSOR = "2000"; - public static final long CURSOR_RECORD_COUNT = 19L; - - private StateTestConstants() {} - - public static Optional getState(final String cursorField, final String cursor) { - return Optional.of(new DbStreamState() - .withStreamName(STREAM_NAME1) - .withCursorField(Lists.newArrayList(cursorField)) - .withCursor(cursor)); - } - - public static Optional getState(final String cursorField, final String cursor, final long cursorRecordCount) { - return Optional.of(new DbStreamState() - .withStreamName(STREAM_NAME1) - .withCursorField(Lists.newArrayList(cursorField)) - .withCursor(cursor) - .withCursorRecordCount(cursorRecordCount)); - } - - public static Optional getCatalog(final String cursorField) { - return Optional.of(new ConfiguredAirbyteCatalog() - .withStreams(List.of(getStream(cursorField).orElse(null)))); - } - - public static Optional getStream(final String cursorField) { - return Optional.of(new ConfiguredAirbyteStream() - .withStream(new AirbyteStream().withName(STREAM_NAME1)) - .withCursorField(cursorField == null ? Collections.emptyList() : Lists.newArrayList(cursorField))); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/relationaldb/state/StreamStateManagerTest.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/relationaldb/state/StreamStateManagerTest.java deleted file mode 100644 index 3ed37ec42308..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/integrations/source/relationaldb/state/StreamStateManagerTest.java +++ /dev/null @@ -1,277 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
- */ - -package io.airbyte.cdk.integrations.source.relationaldb.state; - -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.CURSOR; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.CURSOR_FIELD1; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.CURSOR_FIELD2; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.NAMESPACE; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.NAME_NAMESPACE_PAIR1; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.NAME_NAMESPACE_PAIR2; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.STREAM_NAME1; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.STREAM_NAME2; -import static io.airbyte.cdk.integrations.source.relationaldb.state.StateTestConstants.STREAM_NAME3; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.mockito.Mockito.mock; - -import com.google.common.collect.Lists; -import io.airbyte.cdk.integrations.source.relationaldb.models.DbState; -import io.airbyte.cdk.integrations.source.relationaldb.models.DbStreamState; -import io.airbyte.commons.json.Jsons; -import io.airbyte.protocol.models.v0.AirbyteStateMessage; -import io.airbyte.protocol.models.v0.AirbyteStateMessage.AirbyteStateType; -import io.airbyte.protocol.models.v0.AirbyteStream; -import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair; -import io.airbyte.protocol.models.v0.AirbyteStreamState; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream; -import io.airbyte.protocol.models.v0.StreamDescriptor; -import io.airbyte.protocol.models.v0.SyncMode; -import java.util.ArrayList; -import java.util.Comparator; -import java.util.List; -import java.util.Optional; -import java.util.stream.Collectors; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -/** - * Test suite for the {@link StreamStateManager} class. 
- */ -public class StreamStateManagerTest { - - @Test - void testCreationFromInvalidState() { - final AirbyteStateMessage airbyteStateMessage = new AirbyteStateMessage() - .withType(AirbyteStateType.STREAM) - .withStream(new AirbyteStreamState() - .withStreamDescriptor(new StreamDescriptor().withName(STREAM_NAME1).withNamespace(NAMESPACE)) - .withStreamState(Jsons.jsonNode("Not a state object"))); - final ConfiguredAirbyteCatalog catalog = mock(ConfiguredAirbyteCatalog.class); - - Assertions.assertDoesNotThrow(() -> { - final StateManager stateManager = new StreamStateManager(List.of(airbyteStateMessage), catalog); - assertNotNull(stateManager); - }); - } - - @Test - void testGetters() { - final List state = new ArrayList<>(); - state.add(createStreamState(STREAM_NAME1, NAMESPACE, List.of(CURSOR_FIELD1), CURSOR, 0L)); - state.add(createStreamState(STREAM_NAME2, NAMESPACE, List.of(), null, 0L)); - - final ConfiguredAirbyteCatalog catalog = new ConfiguredAirbyteCatalog() - .withStreams(List.of( - new ConfiguredAirbyteStream() - .withStream(new AirbyteStream().withName(STREAM_NAME1).withNamespace(NAMESPACE) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH))) - .withCursorField(List.of(CURSOR_FIELD1)), - new ConfiguredAirbyteStream() - .withStream(new AirbyteStream().withName(STREAM_NAME2).withNamespace(NAMESPACE) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH))))); - - final StateManager stateManager = new StreamStateManager(state, catalog); - - assertEquals(Optional.of(CURSOR_FIELD1), stateManager.getOriginalCursorField(NAME_NAMESPACE_PAIR1)); - assertEquals(Optional.of(CURSOR), stateManager.getOriginalCursor(NAME_NAMESPACE_PAIR1)); - assertEquals(Optional.of(CURSOR_FIELD1), stateManager.getCursorField(NAME_NAMESPACE_PAIR1)); - assertEquals(Optional.of(CURSOR), stateManager.getCursor(NAME_NAMESPACE_PAIR1)); - - assertEquals(Optional.empty(), stateManager.getOriginalCursorField(NAME_NAMESPACE_PAIR2)); - assertEquals(Optional.empty(), stateManager.getOriginalCursor(NAME_NAMESPACE_PAIR2)); - assertEquals(Optional.empty(), stateManager.getCursorField(NAME_NAMESPACE_PAIR2)); - assertEquals(Optional.empty(), stateManager.getCursor(NAME_NAMESPACE_PAIR2)); - } - - @Test - void testToState() { - final ConfiguredAirbyteCatalog catalog = new ConfiguredAirbyteCatalog() - .withStreams(List.of( - new ConfiguredAirbyteStream() - .withStream(new AirbyteStream().withName(STREAM_NAME1).withNamespace(NAMESPACE) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH))) - .withCursorField(List.of(CURSOR_FIELD1)), - new ConfiguredAirbyteStream() - .withStream(new AirbyteStream().withName(STREAM_NAME2).withNamespace(NAMESPACE) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH))) - .withCursorField(List.of(CURSOR_FIELD2)), - new ConfiguredAirbyteStream() - .withStream(new AirbyteStream().withName(STREAM_NAME3).withNamespace(NAMESPACE) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH))))); - - final StateManager stateManager = new StreamStateManager(createDefaultState(), catalog); - - final DbState expectedFirstDbState = new DbState() - .withCdc(false) - .withStreams(List.of( - new DbStreamState() - .withStreamName(STREAM_NAME1) - .withStreamNamespace(NAMESPACE) - .withCursorField(List.of(CURSOR_FIELD1)) - .withCursor("a"), - new DbStreamState() - .withStreamName(STREAM_NAME2) - .withStreamNamespace(NAMESPACE) - .withCursorField(List.of(CURSOR_FIELD2)), - new DbStreamState() - .withStreamName(STREAM_NAME3) - 
.withStreamNamespace(NAMESPACE)) - .stream().sorted(Comparator.comparing(DbStreamState::getStreamName)).collect(Collectors.toList())); - final AirbyteStateMessage expectedFirstEmission = - createStreamState(STREAM_NAME1, NAMESPACE, List.of(CURSOR_FIELD1), "a", 0L).withData(Jsons.jsonNode(expectedFirstDbState)); - - final AirbyteStateMessage actualFirstEmission = stateManager.updateAndEmit(NAME_NAMESPACE_PAIR1, "a"); - assertEquals(expectedFirstEmission, actualFirstEmission); - - final long expectedRecordCount = 17L; - final DbState expectedSecondDbState = new DbState() - .withCdc(false) - .withStreams(List.of( - new DbStreamState() - .withStreamName(STREAM_NAME1) - .withStreamNamespace(NAMESPACE) - .withCursorField(List.of(CURSOR_FIELD1)) - .withCursor("a"), - new DbStreamState() - .withStreamName(STREAM_NAME2) - .withStreamNamespace(NAMESPACE) - .withCursorField(List.of(CURSOR_FIELD2)) - .withCursor("b") - .withCursorRecordCount(expectedRecordCount), - new DbStreamState() - .withStreamName(STREAM_NAME3) - .withStreamNamespace(NAMESPACE)) - .stream().sorted(Comparator.comparing(DbStreamState::getStreamName)).collect(Collectors.toList())); - final AirbyteStateMessage expectedSecondEmission = - createStreamState(STREAM_NAME2, NAMESPACE, List.of(CURSOR_FIELD2), "b", expectedRecordCount).withData(Jsons.jsonNode(expectedSecondDbState)); - - final AirbyteStateMessage actualSecondEmission = stateManager.updateAndEmit(NAME_NAMESPACE_PAIR2, "b", expectedRecordCount); - assertEquals(expectedSecondEmission, actualSecondEmission); - } - - @Test - void testToStateWithoutCursorInfo() { - final ConfiguredAirbyteCatalog catalog = new ConfiguredAirbyteCatalog() - .withStreams(List.of( - new ConfiguredAirbyteStream() - .withStream(new AirbyteStream().withName(STREAM_NAME1).withNamespace(NAMESPACE) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH))) - .withCursorField(List.of(CURSOR_FIELD1)), - new ConfiguredAirbyteStream() - .withStream(new AirbyteStream().withName(STREAM_NAME2).withNamespace(NAMESPACE) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH))) - .withCursorField(List.of(CURSOR_FIELD2)), - new ConfiguredAirbyteStream() - .withStream(new AirbyteStream().withName(STREAM_NAME3).withNamespace(NAMESPACE) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH))))); - final AirbyteStreamNameNamespacePair airbyteStreamNameNamespacePair = new AirbyteStreamNameNamespacePair("other", "other"); - - final StateManager stateManager = new StreamStateManager(createDefaultState(), catalog); - final AirbyteStateMessage airbyteStateMessage = stateManager.toState(Optional.of(airbyteStreamNameNamespacePair)); - assertNotNull(airbyteStateMessage); - assertEquals(AirbyteStateType.STREAM, airbyteStateMessage.getType()); - assertNotNull(airbyteStateMessage.getStream()); - } - - @Test - void testToStateWithoutStreamPair() { - final ConfiguredAirbyteCatalog catalog = new ConfiguredAirbyteCatalog() - .withStreams(List.of( - new ConfiguredAirbyteStream() - .withStream(new AirbyteStream().withName(STREAM_NAME1).withNamespace(NAMESPACE) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH))) - .withCursorField(List.of(CURSOR_FIELD1)), - new ConfiguredAirbyteStream() - .withStream(new AirbyteStream().withName(STREAM_NAME2).withNamespace(NAMESPACE) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH))) - .withCursorField(List.of(CURSOR_FIELD2)), - new ConfiguredAirbyteStream() - .withStream(new 
AirbyteStream().withName(STREAM_NAME3).withNamespace(NAMESPACE) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH))))); - - final StateManager stateManager = new StreamStateManager(createDefaultState(), catalog); - final AirbyteStateMessage airbyteStateMessage = stateManager.toState(Optional.empty()); - assertNotNull(airbyteStateMessage); - assertEquals(AirbyteStateType.STREAM, airbyteStateMessage.getType()); - assertNotNull(airbyteStateMessage.getStream()); - assertNull(airbyteStateMessage.getStream().getStreamState()); - } - - @Test - void testToStateNullCursorField() { - final ConfiguredAirbyteCatalog catalog = new ConfiguredAirbyteCatalog() - .withStreams(List.of( - new ConfiguredAirbyteStream() - .withStream(new AirbyteStream().withName(STREAM_NAME1).withNamespace(NAMESPACE) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH))) - .withCursorField(List.of(CURSOR_FIELD1)), - new ConfiguredAirbyteStream() - .withStream(new AirbyteStream().withName(STREAM_NAME2).withNamespace(NAMESPACE) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH))))); - final StateManager stateManager = new StreamStateManager(createDefaultState(), catalog); - - final DbState expectedFirstDbState = new DbState() - .withCdc(false) - .withStreams(List.of( - new DbStreamState() - .withStreamName(STREAM_NAME1) - .withStreamNamespace(NAMESPACE) - .withCursorField(List.of(CURSOR_FIELD1)) - .withCursor("a"), - new DbStreamState() - .withStreamName(STREAM_NAME2) - .withStreamNamespace(NAMESPACE)) - .stream().sorted(Comparator.comparing(DbStreamState::getStreamName)).collect(Collectors.toList())); - - final AirbyteStateMessage expectedFirstEmission = - createStreamState(STREAM_NAME1, NAMESPACE, List.of(CURSOR_FIELD1), "a", 0L).withData(Jsons.jsonNode(expectedFirstDbState)); - final AirbyteStateMessage actualFirstEmission = stateManager.updateAndEmit(NAME_NAMESPACE_PAIR1, "a"); - assertEquals(expectedFirstEmission, actualFirstEmission); - } - - @Test - void testCdcStateManager() { - final ConfiguredAirbyteCatalog catalog = mock(ConfiguredAirbyteCatalog.class); - final StateManager stateManager = new StreamStateManager( - List.of(new AirbyteStateMessage().withType(AirbyteStateType.STREAM).withStream(new AirbyteStreamState())), catalog); - Assertions.assertThrows(UnsupportedOperationException.class, () -> stateManager.getCdcStateManager()); - } - - private List createDefaultState() { - return List.of(new AirbyteStateMessage().withType(AirbyteStateType.STREAM).withStream(new AirbyteStreamState())); - } - - private AirbyteStateMessage createStreamState(final String name, - final String namespace, - final List cursorFields, - final String cursorValue, - final long cursorRecordCount) { - final DbStreamState dbStreamState = new DbStreamState() - .withStreamName(name) - .withStreamNamespace(namespace); - - if (cursorFields != null && !cursorFields.isEmpty()) { - dbStreamState.withCursorField(cursorFields); - } - - if (cursorValue != null) { - dbStreamState.withCursor(cursorValue); - } - - if (cursorRecordCount > 0L) { - dbStreamState.withCursorRecordCount(cursorRecordCount); - } - - return new AirbyteStateMessage() - .withType(AirbyteStateType.STREAM) - .withStream(new AirbyteStreamState() - .withStreamDescriptor(new StreamDescriptor().withName(name).withNamespace(namespace)) - .withStreamState(Jsons.jsonNode(dbStreamState))); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/test/utils/DatabaseConnectionHelperTest.java 
b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/test/utils/DatabaseConnectionHelperTest.java deleted file mode 100644 index 9f7008f5f6c9..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/java/io/airbyte/cdk/test/utils/DatabaseConnectionHelperTest.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.testutils; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; - -import com.zaxxer.hikari.HikariDataSource; -import javax.sql.DataSource; -import org.jooq.DSLContext; -import org.jooq.SQLDialect; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import org.testcontainers.containers.PostgreSQLContainer; - -class DatabaseConnectionHelperTest { - - private static final String DATABASE_NAME = "airbyte_test_database"; - - protected static PostgreSQLContainer container; - - @BeforeAll - static void dbSetup() { - container = new PostgreSQLContainer<>("postgres:13-alpine") - .withDatabaseName(DATABASE_NAME) - .withUsername("docker") - .withPassword("docker"); - container.start(); - } - - @AfterAll - static void dbDown() { - container.close(); - } - - @Test - void testCreatingFromATestContainer() { - final DataSource dataSource = DatabaseConnectionHelper.createDataSource(container); - assertNotNull(dataSource); - assertEquals(HikariDataSource.class, dataSource.getClass()); - assertEquals(10, ((HikariDataSource) dataSource).getHikariConfigMXBean().getMaximumPoolSize()); - } - - @Test - void testCreatingADslContextFromATestContainer() { - final SQLDialect dialect = SQLDialect.POSTGRES; - final DSLContext dslContext = DatabaseConnectionHelper.createDslContext(container, dialect); - assertNotNull(dslContext); - assertEquals(dialect, dslContext.configuration().dialect()); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/AirbyteDebeziumHandlerTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/AirbyteDebeziumHandlerTest.kt new file mode 100644 index 000000000000..8732a0a6546e --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/AirbyteDebeziumHandlerTest.kt @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+ */ +package io.airbyte.cdk.integrations.debezium + +import com.google.common.collect.Lists +import io.airbyte.protocol.models.Field +import io.airbyte.protocol.models.JsonSchemaType +import io.airbyte.protocol.models.v0.AirbyteCatalog +import io.airbyte.protocol.models.v0.CatalogHelpers +import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream +import io.airbyte.protocol.models.v0.SyncMode +import java.util.List +import java.util.function.Consumer +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test + +class AirbyteDebeziumHandlerTest { + @Test + fun shouldUseCdcTestShouldReturnTrue() { + val catalog = + AirbyteCatalog() + .withStreams( + List.of( + CatalogHelpers.createAirbyteStream( + "MODELS_STREAM_NAME", + "MODELS_SCHEMA", + Field.of("COL_ID", JsonSchemaType.NUMBER), + Field.of("COL_MAKE_ID", JsonSchemaType.NUMBER), + Field.of("COL_MODEL", JsonSchemaType.STRING) + ) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL) + ) + .withSourceDefinedPrimaryKey(List.of(listOf("COL_ID"))) + ) + ) + val configuredCatalog = CatalogHelpers.toDefaultConfiguredCatalog(catalog) + // set all streams to incremental. + configuredCatalog.streams.forEach( + Consumer { s: ConfiguredAirbyteStream -> s.syncMode = SyncMode.INCREMENTAL } + ) + + Assertions.assertTrue( + AirbyteDebeziumHandler.isAnyStreamIncrementalSyncMode(configuredCatalog) + ) + } + + @Test + fun shouldUseCdcTestShouldReturnFalse() { + val catalog = + AirbyteCatalog() + .withStreams( + List.of( + CatalogHelpers.createAirbyteStream( + "MODELS_STREAM_NAME", + "MODELS_SCHEMA", + Field.of("COL_ID", JsonSchemaType.NUMBER), + Field.of("COL_MAKE_ID", JsonSchemaType.NUMBER), + Field.of("COL_MODEL", JsonSchemaType.STRING) + ) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL) + ) + .withSourceDefinedPrimaryKey(List.of(listOf("COL_ID"))) + ) + ) + val configuredCatalog = CatalogHelpers.toDefaultConfiguredCatalog(catalog) + + Assertions.assertFalse( + AirbyteDebeziumHandler.isAnyStreamIncrementalSyncMode(configuredCatalog) + ) + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/AirbyteFileOffsetBackingStoreTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/AirbyteFileOffsetBackingStoreTest.kt new file mode 100644 index 000000000000..aeba71586adb --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/AirbyteFileOffsetBackingStoreTest.kt @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+ */
+package io.airbyte.cdk.integrations.debezium
+
+import io.airbyte.cdk.integrations.debezium.internals.AirbyteFileOffsetBackingStore
+import io.airbyte.commons.io.IOs
+import io.airbyte.commons.json.Jsons
+import io.airbyte.commons.resources.MoreResources
+import java.io.IOException
+import java.nio.file.Files
+import java.nio.file.Path
+import java.util.*
+import org.junit.jupiter.api.Assertions
+import org.junit.jupiter.api.Test
+
+internal class AirbyteFileOffsetBackingStoreTest {
+ @Test
+ @Throws(IOException::class)
+ fun test() {
+ val testRoot = Files.createTempDirectory(Path.of("/tmp"), "offset-store-test")
+
+ val bytes = MoreResources.readBytes("test_debezium_offset.dat")
+ val templateFilePath = testRoot.resolve("template_offset.dat")
+ IOs.writeFile(templateFilePath, bytes)
+
+ val writeFilePath = testRoot.resolve("offset.dat")
+ val secondWriteFilePath = testRoot.resolve("offset_2.dat")
+
+ val offsetStore = AirbyteFileOffsetBackingStore(templateFilePath, Optional.empty())
+ val offset = offsetStore.read()
+
+ val offsetStore2 = AirbyteFileOffsetBackingStore(writeFilePath, Optional.empty())
+ offsetStore2.persist(Jsons.jsonNode(offset))
+ val stateFromOffsetStore2 = offsetStore2.read()
+
+ val offsetStore3 = AirbyteFileOffsetBackingStore(secondWriteFilePath, Optional.empty())
+ offsetStore3.persist(Jsons.jsonNode(stateFromOffsetStore2))
+ val stateFromOffsetStore3 = offsetStore3.read()
+
+ // verify that, after a round trip through the offset store, we get back the same data.
+ Assertions.assertEquals(stateFromOffsetStore2, stateFromOffsetStore3)
+ // verify that the two files written by the offset store are identical to each other.
+ Assertions.assertTrue(
+ com.google.common.io.Files.equal(secondWriteFilePath.toFile(), writeFilePath.toFile())
+ )
+ }
+
+ @Test
+ @Throws(IOException::class)
+ fun test2() {
+ val testRoot = Files.createTempDirectory(Path.of("/tmp"), "offset-store-test")
+
+ val bytes = MoreResources.readBytes("test_debezium_offset.dat")
+ val templateFilePath = testRoot.resolve("template_offset.dat")
+ IOs.writeFile(templateFilePath, bytes)
+
+ val writeFilePath = testRoot.resolve("offset.dat")
+ val secondWriteFilePath = testRoot.resolve("offset_2.dat")
+
+ val offsetStore = AirbyteFileOffsetBackingStore(templateFilePath, Optional.of("orders"))
+ val offset = offsetStore.read()
+
+ val offsetStore2 = AirbyteFileOffsetBackingStore(writeFilePath, Optional.of("orders"))
+ offsetStore2.persist(Jsons.jsonNode(offset))
+ val stateFromOffsetStore2 = offsetStore2.read()
+
+ val offsetStore3 = AirbyteFileOffsetBackingStore(secondWriteFilePath, Optional.of("orders"))
+ offsetStore3.persist(Jsons.jsonNode(stateFromOffsetStore2))
+ val stateFromOffsetStore3 = offsetStore3.read()
+
+ // verify that, after a round trip through the offset store, we get back the same data.
+ Assertions.assertEquals(stateFromOffsetStore2, stateFromOffsetStore3)
+ // verify that the two files written by the offset store are identical to each other.
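+ // A minimal sketch of the round-trip property these tests rely on (assuming read()
+ // returns the persisted key/value pairs unchanged; the names here are illustrative,
+ // not part of this change):
+ //
+ //   val store = AirbyteFileOffsetBackingStore(path, Optional.of("orders"))
+ //   store.persist(Jsons.jsonNode(mapOf("lsn" to "123")))
+ //   check(store.read() == mapOf("lsn" to "123"))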
+ Assertions.assertTrue( + com.google.common.io.Files.equal(secondWriteFilePath.toFile(), writeFilePath.toFile()) + ) + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/DebeziumRecordPublisherTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/DebeziumRecordPublisherTest.kt new file mode 100644 index 000000000000..8a23f58e748b --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/DebeziumRecordPublisherTest.kt @@ -0,0 +1,120 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.debezium + +import com.google.common.collect.ImmutableList +import io.airbyte.cdk.integrations.debezium.internals.RelationalDbDebeziumPropertiesManager +import io.airbyte.protocol.models.Field +import io.airbyte.protocol.models.JsonSchemaType +import io.airbyte.protocol.models.v0.CatalogHelpers +import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog +import io.airbyte.protocol.models.v0.SyncMode +import java.util.regex.Pattern +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test + +internal class DebeziumRecordPublisherTest { + @Test + fun testTableIncludelistCreation() { + val catalog = + ConfiguredAirbyteCatalog() + .withStreams( + ImmutableList.of( + CatalogHelpers.createConfiguredAirbyteStream("id_and_name", "public") + .withSyncMode(SyncMode.INCREMENTAL), + CatalogHelpers.createConfiguredAirbyteStream("id_,something", "public") + .withSyncMode(SyncMode.INCREMENTAL), + CatalogHelpers.createConfiguredAirbyteStream("n\"aMéS", "public") + .withSyncMode(SyncMode.INCREMENTAL) + ) + ) + + val expectedWhitelist = + "\\Qpublic.id_and_name\\E,\\Qpublic.id_\\,something\\E,\\Qpublic.n\"aMéS\\E" + val actualWhitelist = RelationalDbDebeziumPropertiesManager.getTableIncludelist(catalog) + + Assertions.assertEquals(expectedWhitelist, actualWhitelist) + } + + @Test + fun testTableIncludelistFiltersFullRefresh() { + val catalog = + ConfiguredAirbyteCatalog() + .withStreams( + ImmutableList.of( + CatalogHelpers.createConfiguredAirbyteStream("id_and_name", "public") + .withSyncMode(SyncMode.INCREMENTAL), + CatalogHelpers.createConfiguredAirbyteStream("id_and_name2", "public") + .withSyncMode(SyncMode.FULL_REFRESH) + ) + ) + + val expectedWhitelist = "\\Qpublic.id_and_name\\E" + val actualWhitelist = RelationalDbDebeziumPropertiesManager.getTableIncludelist(catalog) + + Assertions.assertEquals(expectedWhitelist, actualWhitelist) + } + + @Test + fun testColumnIncludelistFiltersFullRefresh() { + val catalog = + ConfiguredAirbyteCatalog() + .withStreams( + ImmutableList.of( + CatalogHelpers.createConfiguredAirbyteStream( + "id_and_name", + "public", + Field.of("fld1", JsonSchemaType.NUMBER), + Field.of("fld2", JsonSchemaType.STRING) + ) + .withSyncMode(SyncMode.INCREMENTAL), + CatalogHelpers.createConfiguredAirbyteStream("id_,something", "public") + .withSyncMode(SyncMode.INCREMENTAL), + CatalogHelpers.createConfiguredAirbyteStream("id_and_name2", "public") + .withSyncMode(SyncMode.FULL_REFRESH), + CatalogHelpers.createConfiguredAirbyteStream("n\"aMéS", "public") + .withSyncMode(SyncMode.INCREMENTAL) + ) + ) + + val expectedWhitelist = + "\\Qpublic.id_and_name\\E\\.(\\Qfld2\\E|\\Qfld1\\E),\\Qpublic.id_\\,something\\E,\\Qpublic.n\"aMéS\\E" + val actualWhitelist = RelationalDbDebeziumPropertiesManager.getColumnIncludeList(catalog) + + Assertions.assertEquals(expectedWhitelist, 
actualWhitelist) + } + + @Test + fun testColumnIncludeListEscaping() { + // final String a = "public\\.products\\*\\^\\$\\+-\\\\"; + // final String b = "public.products*^$+-\\"; + // final Pattern p = Pattern.compile(a, Pattern.UNIX_LINES); + // assertTrue(p.matcher(b).find()); + // assertTrue(Pattern.compile(Pattern.quote(b)).matcher(b).find()); + + val catalog = + ConfiguredAirbyteCatalog() + .withStreams( + ImmutableList.of( + CatalogHelpers.createConfiguredAirbyteStream( + "id_and_name", + "public", + Field.of("fld1", JsonSchemaType.NUMBER), + Field.of("fld2", JsonSchemaType.STRING) + ) + .withSyncMode(SyncMode.INCREMENTAL) + ) + ) + + val anchored = + "^" + RelationalDbDebeziumPropertiesManager.getColumnIncludeList(catalog) + "$" + val pattern = Pattern.compile(anchored) + + Assertions.assertTrue(pattern.matcher("public.id_and_name.fld1").find()) + Assertions.assertTrue(pattern.matcher("public.id_and_name.fld2").find()) + Assertions.assertFalse(pattern.matcher("ic.id_and_name.fl").find()) + Assertions.assertFalse(pattern.matcher("ppppublic.id_and_name.fld2333").find()) + Assertions.assertFalse(pattern.matcher("public.id_and_name.fld_wrong_wrong").find()) + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/AirbyteSchemaHistoryStorageTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/AirbyteSchemaHistoryStorageTest.kt new file mode 100644 index 000000000000..217f4d0dffca --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/AirbyteSchemaHistoryStorageTest.kt @@ -0,0 +1,113 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.debezium.internals + +import io.airbyte.commons.json.Jsons +import io.airbyte.commons.resources.MoreResources +import java.io.IOException +import java.util.* +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test + +class AirbyteSchemaHistoryStorageTest { + @Test + @Throws(IOException::class) + fun testForContentBiggerThan1MBLimit() { + val contentReadDirectlyFromFile = + MoreResources.readResource("dbhistory_greater_than_1_mb.dat") + + val schemaHistoryStorageFromUncompressedContent = + AirbyteSchemaHistoryStorage.initializeDBHistory( + AirbyteSchemaHistoryStorage.SchemaHistory( + Optional.of(Jsons.jsonNode(contentReadDirectlyFromFile)), + false + ), + true + ) + val schemaHistoryFromUncompressedContent = + schemaHistoryStorageFromUncompressedContent.read() + + Assertions.assertTrue(schemaHistoryFromUncompressedContent.isCompressed) + Assertions.assertNotNull(schemaHistoryFromUncompressedContent.schema) + Assertions.assertEquals( + contentReadDirectlyFromFile, + schemaHistoryStorageFromUncompressedContent.readUncompressed() + ) + + val schemaHistoryStorageFromCompressedContent = + AirbyteSchemaHistoryStorage.initializeDBHistory( + AirbyteSchemaHistoryStorage.SchemaHistory( + Optional.of(Jsons.jsonNode(schemaHistoryFromUncompressedContent.schema)), + true + ), + true + ) + val schemaHistoryFromCompressedContent = schemaHistoryStorageFromCompressedContent.read() + + Assertions.assertTrue(schemaHistoryFromCompressedContent.isCompressed) + Assertions.assertNotNull(schemaHistoryFromCompressedContent.schema) + Assertions.assertEquals( + schemaHistoryFromUncompressedContent.schema, + schemaHistoryFromCompressedContent.schema + ) + } + + @Test + @Throws(IOException::class) + fun sizeTest() { + 
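// The exact figures asserted below assume the helper measures the UTF-8 byte
+ // length of the string and divides by 1024 * 1024; a local sketch of that
+ // assumption (illustrative only; the real computation lives in
+ // AirbyteSchemaHistoryStorage.calculateSizeOfStringInMB):
+ fun approxSizeOfStringInMB(content: String): Double =
+ content.toByteArray(Charsets.UTF_8).size.toDouble() / (1024 * 1024)
+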
Assertions.assertEquals( + 5.881045341491699, + AirbyteSchemaHistoryStorage.calculateSizeOfStringInMB( + MoreResources.readResource("dbhistory_greater_than_1_mb.dat") + ) + ) + Assertions.assertEquals( + 0.0038671493530273438, + AirbyteSchemaHistoryStorage.calculateSizeOfStringInMB( + MoreResources.readResource("dbhistory_less_than_1_mb.dat") + ) + ) + } + + @Test + @Throws(IOException::class) + fun testForContentLessThan1MBLimit() { + val contentReadDirectlyFromFile = MoreResources.readResource("dbhistory_less_than_1_mb.dat") + + val schemaHistoryStorageFromUncompressedContent = + AirbyteSchemaHistoryStorage.initializeDBHistory( + AirbyteSchemaHistoryStorage.SchemaHistory( + Optional.of(Jsons.jsonNode(contentReadDirectlyFromFile)), + false + ), + true + ) + val schemaHistoryFromUncompressedContent = + schemaHistoryStorageFromUncompressedContent.read() + + Assertions.assertFalse(schemaHistoryFromUncompressedContent.isCompressed) + Assertions.assertNotNull(schemaHistoryFromUncompressedContent.schema) + Assertions.assertEquals( + contentReadDirectlyFromFile, + schemaHistoryFromUncompressedContent.schema + ) + + val schemaHistoryStorageFromCompressedContent = + AirbyteSchemaHistoryStorage.initializeDBHistory( + AirbyteSchemaHistoryStorage.SchemaHistory( + Optional.of(Jsons.jsonNode(schemaHistoryFromUncompressedContent.schema)), + false + ), + true + ) + val schemaHistoryFromCompressedContent = schemaHistoryStorageFromCompressedContent.read() + + Assertions.assertFalse(schemaHistoryFromCompressedContent.isCompressed) + Assertions.assertNotNull(schemaHistoryFromCompressedContent.schema) + Assertions.assertEquals( + schemaHistoryFromUncompressedContent.schema, + schemaHistoryFromCompressedContent.schema + ) + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumConverterUtilsTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumConverterUtilsTest.kt new file mode 100644 index 000000000000..0b288c96d8f5 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumConverterUtilsTest.kt @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+ */ +package io.airbyte.cdk.integrations.debezium.internals + +import io.debezium.spi.converter.RelationalColumn +import java.sql.Timestamp +import java.time.Duration +import java.time.LocalDate +import java.time.LocalDateTime +import java.time.LocalTime +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Disabled +import org.junit.jupiter.api.Test +import org.mockito.Mockito + +internal class DebeziumConverterUtilsTest { + @Test + fun convertDefaultValueTest() { + val relationalColumn = Mockito.mock(RelationalColumn::class.java) + + Mockito.`when`(relationalColumn.isOptional).thenReturn(true) + var actualColumnDefaultValue = DebeziumConverterUtils.convertDefaultValue(relationalColumn) + Assertions.assertNull( + actualColumnDefaultValue, + "Default value for optional relational column should be null" + ) + + Mockito.`when`(relationalColumn.isOptional).thenReturn(false) + Mockito.`when`(relationalColumn.hasDefaultValue()).thenReturn(false) + actualColumnDefaultValue = DebeziumConverterUtils.convertDefaultValue(relationalColumn) + Assertions.assertNull(actualColumnDefaultValue) + + Mockito.`when`(relationalColumn.isOptional).thenReturn(false) + Mockito.`when`(relationalColumn.hasDefaultValue()).thenReturn(true) + val expectedColumnDefaultValue = "default value" + Mockito.`when`(relationalColumn.defaultValue()).thenReturn(expectedColumnDefaultValue) + actualColumnDefaultValue = DebeziumConverterUtils.convertDefaultValue(relationalColumn) + Assertions.assertEquals(actualColumnDefaultValue, expectedColumnDefaultValue) + } + + @Test + fun convertLocalDate() { + val localDate = LocalDate.of(2021, 1, 1) + + val actual = DebeziumConverterUtils.convertDate(localDate) + Assertions.assertEquals("2021-01-01T00:00:00Z", actual) + } + + @Test + fun convertTLocalTime() { + val localTime = LocalTime.of(8, 1, 1) + val actual = DebeziumConverterUtils.convertDate(localTime) + Assertions.assertEquals("08:01:01", actual) + } + + @Test + fun convertLocalDateTime() { + val localDateTime = LocalDateTime.of(2021, 1, 1, 8, 1, 1) + + val actual = DebeziumConverterUtils.convertDate(localDateTime) + Assertions.assertEquals("2021-01-01T08:01:01Z", actual) + } + + @Test + @Disabled + fun convertDuration() { + val duration = Duration.ofHours(100000) + + val actual = DebeziumConverterUtils.convertDate(duration) + Assertions.assertEquals("1981-05-29T20:00:00Z", actual) + } + + @Test + fun convertTimestamp() { + val localDateTime = LocalDateTime.of(2021, 1, 1, 8, 1, 1) + val timestamp = Timestamp.valueOf(localDateTime) + + val actual = DebeziumConverterUtils.convertDate(timestamp) + Assertions.assertEquals("2021-01-01T08:01:01.000000Z", actual) + } + + @Test + @Disabled + fun convertNumber() { + val number: Number = 100000 + + val actual = DebeziumConverterUtils.convertDate(number) + Assertions.assertEquals("1970-01-01T03:01:40Z", actual) + } + + @Test + fun convertStringDateFormat() { + val stringValue = "2021-01-01T00:00:00Z" + + val actual = DebeziumConverterUtils.convertDate(stringValue) + Assertions.assertEquals("2021-01-01T00:00:00Z", actual) + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumMessageProducerTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumMessageProducerTest.kt new file mode 100644 index 000000000000..083dc6048381 --- /dev/null +++ 
b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumMessageProducerTest.kt @@ -0,0 +1,117 @@
+/*
+ * Copyright (c) 2024 Airbyte, Inc., all rights reserved.
+ */
+package io.airbyte.cdk.integrations.debezium.internals
+
+import io.airbyte.cdk.integrations.debezium.CdcStateHandler
+import io.airbyte.cdk.integrations.debezium.CdcTargetPosition
+import io.airbyte.protocol.models.v0.AirbyteMessage
+import io.airbyte.protocol.models.v0.AirbyteStateMessage
+import java.util.*
+import org.junit.Assert
+import org.junit.jupiter.api.BeforeEach
+import org.junit.jupiter.api.Test
+import org.mockito.Mockito
+import org.mockito.kotlin.any
+import org.mockito.kotlin.eq
+import org.mockito.kotlin.mock
+
+class DebeziumMessageProducerTest {
+ private var producer: DebeziumMessageProducer<*>? = null
+
+ lateinit var cdcStateHandler: CdcStateHandler
+ lateinit var targetPosition: CdcTargetPosition<Any>
+ lateinit var eventConverter: DebeziumEventConverter
+ lateinit var offsetManager: AirbyteFileOffsetBackingStore
+ lateinit var schemaHistoryManager: AirbyteSchemaHistoryStorage
+
+ @BeforeEach
+ fun setUp() {
+ cdcStateHandler = Mockito.mock(CdcStateHandler::class.java)
+ Mockito.`when`(cdcStateHandler.isCdcCheckpointEnabled).thenReturn(true)
+ targetPosition = mock()
+ eventConverter = Mockito.mock(DebeziumEventConverter::class.java)
+ offsetManager = Mockito.mock(AirbyteFileOffsetBackingStore::class.java)
+ Mockito.`when`<Map<String, String>>(offsetManager.read()).thenReturn(OFFSET_MANAGER_READ)
+ schemaHistoryManager = Mockito.mock(AirbyteSchemaHistoryStorage::class.java)
+ Mockito.`when`(schemaHistoryManager.read()).thenReturn(SCHEMA)
+ producer =
+ DebeziumMessageProducer(
+ cdcStateHandler,
+ targetPosition,
+ eventConverter!!,
+ offsetManager,
+ Optional.of(schemaHistoryManager)
+ )
+ }
+
+ @Test
+ fun testProcessRecordMessage() {
+ val message = Mockito.mock(ChangeEventWithMetadata::class.java)
+
+ Mockito.`when`(targetPosition!!.isSameOffset(any(), any())).thenReturn(true)
+ producer!!.processRecordMessage(null, message)
+ Mockito.verify(eventConverter).toAirbyteMessage(message)
+ Assert.assertFalse(producer!!.shouldEmitStateMessage(null))
+ }
+
+ @Test
+ fun testProcessRecordMessageWithStateMessage() {
+ val message = Mockito.mock(ChangeEventWithMetadata::class.java)
+
+ Mockito.`when`(targetPosition!!.isSameOffset(any(), any())).thenReturn(false)
+ Mockito.`when`(targetPosition!!.isEventAheadOffset(OFFSET_MANAGER_READ, message))
+ .thenReturn(true)
+ producer!!.processRecordMessage(null, message)
+ Mockito.verify(eventConverter!!).toAirbyteMessage(message)
+ Assert.assertTrue(producer!!.shouldEmitStateMessage(null))
+
+ Mockito.`when`(cdcStateHandler!!.isCdcCheckpointEnabled).thenReturn(false)
+ Mockito.`when`(cdcStateHandler!!.saveState(eq(OFFSET_MANAGER_READ), eq(SCHEMA)))
+ .thenReturn(AirbyteMessage().withState(STATE_MESSAGE))
+
+ Assert.assertEquals(producer!!.generateStateMessageAtCheckpoint(null), STATE_MESSAGE)
+ }
+
+ @Test
+ fun testGenerateFinalMessageNoProgress() {
+ Mockito.`when`(cdcStateHandler!!.saveState(eq(OFFSET_MANAGER_READ), eq(SCHEMA)))
+ .thenReturn(AirbyteMessage().withState(STATE_MESSAGE))
+
+ // initialOffset will be OFFSET_MANAGER_READ, final state would be OFFSET_MANAGER_READ2.
+ // Mock CDC handler will only accept OFFSET_MANAGER_READ.
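+ // In sketch form, the behaviour this test pins down (an assumption about the
+ // DebeziumMessageProducer contract, not its literal implementation): when no
+ // progress was made, the final state message is built from the initial offset:
+ //
+ //   val offsetToSave =
+ //       if (targetPosition.isSameOffset(initialOffset, finalOffset)) initialOffset
+ //       else finalOffset
+ //   return cdcStateHandler.saveState(offsetToSave, schemaHistory)!!.state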
+ Mockito.`when`<Map<String, String>>(offsetManager!!.read()).thenReturn(OFFSET_MANAGER_READ2)
+
+ Mockito.`when`(targetPosition!!.isSameOffset(OFFSET_MANAGER_READ, OFFSET_MANAGER_READ2))
+ .thenReturn(true)
+
+ Assert.assertEquals(producer!!.createFinalStateMessage(null), STATE_MESSAGE)
+ }
+
+ @Test
+ fun testGenerateFinalMessageWithProgress() {
+ Mockito.`when`(cdcStateHandler!!.saveState(eq(OFFSET_MANAGER_READ2), eq(SCHEMA)))
+ .thenReturn(AirbyteMessage().withState(STATE_MESSAGE))
+
+ // initialOffset will be OFFSET_MANAGER_READ, final state would be OFFSET_MANAGER_READ2.
+ // Mock CDC handler will only accept OFFSET_MANAGER_READ2.
+ Mockito.`when`<Map<String, String>>(offsetManager!!.read()).thenReturn(OFFSET_MANAGER_READ2)
+ Mockito.`when`(targetPosition!!.isSameOffset(OFFSET_MANAGER_READ, OFFSET_MANAGER_READ2))
+ .thenReturn(false)
+
+ Assert.assertEquals(producer!!.createFinalStateMessage(null), STATE_MESSAGE)
+ }
+
+ companion object {
+ private val OFFSET_MANAGER_READ: Map<String, String> =
+ HashMap(java.util.Map.of("key", "value"))
+ private val OFFSET_MANAGER_READ2: Map<String, String> =
+ HashMap(java.util.Map.of("key2", "value2"))
+
+ private val SCHEMA: AirbyteSchemaHistoryStorage.SchemaHistory<String> =
+ AirbyteSchemaHistoryStorage.SchemaHistory("schema", false)
+
+ private val STATE_MESSAGE: AirbyteStateMessage =
+ AirbyteStateMessage().withType(AirbyteStateMessage.AirbyteStateType.GLOBAL)
+ }
+}
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumRecordIteratorTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumRecordIteratorTest.kt new file mode 100644 index 000000000000..7e55c7241b17 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumRecordIteratorTest.kt @@ -0,0 +1,71 @@
+/*
+ * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+ */
+package io.airbyte.cdk.integrations.debezium.internals
+
+import io.airbyte.cdk.integrations.debezium.CdcTargetPosition
+import io.debezium.engine.ChangeEvent
+import java.time.Duration
+import java.util.*
+import org.apache.kafka.connect.source.SourceRecord
+import org.junit.jupiter.api.Assertions
+import org.junit.jupiter.api.Test
+import org.mockito.Mockito.mock
+
+class DebeziumRecordIteratorTest {
+ @Test
+ fun getHeartbeatPositionTest() {
+ val debeziumRecordIterator =
+ DebeziumRecordIterator(
+ mock(),
+ object : CdcTargetPosition<Long> {
+ override fun reachedTargetPosition(
+ changeEventWithMetadata: ChangeEventWithMetadata?
+ ): Boolean {
+ return false
+ }
+
+ override fun extractPositionFromHeartbeatOffset(
+ sourceOffset: Map<String?, *>?
+ ): Long {
+ return sourceOffset!!["lsn"] as Long
+ }
+ },
+ { false },
+ mock(),
+ Duration.ZERO,
+ Duration.ZERO
+ )
+ val lsn =
+ debeziumRecordIterator.getHeartbeatPosition(
+ object : ChangeEvent<String?, String?> {
+ private val sourceRecord =
+ SourceRecord(
+ null,
+ Collections.singletonMap("lsn", 358824993496L),
+ null,
+ null,
+ null
+ )
+
+ override fun key(): String? {
+ return null
+ }
+
+ override fun value(): String {
+ return "{\"ts_ms\":1667616934701}"
+ }
+
+ override fun destination(): String?
{
+ return null
+ }
+
+ fun sourceRecord(): SourceRecord {
+ return sourceRecord
+ }
+ }
+ )
+
+ Assertions.assertEquals(lsn, 358824993496L)
+ } }
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumShutdownProcedureTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumShutdownProcedureTest.kt new file mode 100644 index 000000000000..df7eb675bcc8 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumShutdownProcedureTest.kt @@ -0,0 +1,50 @@
+/*
+ * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+ */
+package io.airbyte.cdk.integrations.debezium.internals
+
+import java.util.concurrent.Executors
+import java.util.concurrent.LinkedBlockingQueue
+import java.util.concurrent.atomic.AtomicInteger
+import org.junit.jupiter.api.Assertions
+import org.junit.jupiter.api.Test
+
+class DebeziumShutdownProcedureTest {
+ @Test
+ @Throws(InterruptedException::class)
+ fun test() {
+ val sourceQueue = LinkedBlockingQueue<Int>(10)
+ val recordsInserted = AtomicInteger()
+ val executorService = Executors.newSingleThreadExecutor()
+ val debeziumShutdownProcedure =
+ DebeziumShutdownProcedure(
+ sourceQueue,
+ { executorService.shutdown() },
+ { recordsInserted.get() >= 99 }
+ )
+ executorService.execute {
+ for (i in 0..99) {
+ try {
+ sourceQueue.put(i)
+ recordsInserted.set(i)
+ } catch (e: InterruptedException) {
+ throw RuntimeException(e)
+ }
+ }
+ }
+
+ Thread.sleep(1000)
+ debeziumShutdownProcedure.initiateShutdownProcedure()
+
+ Assertions.assertEquals(99, recordsInserted.get())
+ Assertions.assertEquals(0, sourceQueue.size)
+ Assertions.assertEquals(100, debeziumShutdownProcedure.recordsRemainingAfterShutdown.size)
+
+ for (i in 0..99) {
+ Assertions.assertEquals(
+ i,
+ debeziumShutdownProcedure.recordsRemainingAfterShutdown.poll()
+ )
+ }
+ } }
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/RecordWaitTimeUtilTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/RecordWaitTimeUtilTest.kt new file mode 100644 index 000000000000..19aa9ece08af --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/RecordWaitTimeUtilTest.kt @@ -0,0 +1,84 @@
+/*
+ * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+ */
+package io.airbyte.cdk.integrations.debezium.internals
+
+import io.airbyte.commons.json.Jsons
+import java.time.Duration
+import java.util.*
+import java.util.Map
+import org.junit.jupiter.api.Assertions
+import org.junit.jupiter.api.Test
+
+class RecordWaitTimeUtilTest {
+ @Test
+ fun testGetFirstRecordWaitTime() {
+ val emptyConfig = Jsons.jsonNode(emptyMap<String, String>())
+ Assertions.assertDoesNotThrow { RecordWaitTimeUtil.checkFirstRecordWaitTime(emptyConfig) }
+ Assertions.assertEquals(
+ Optional.empty(),
+ RecordWaitTimeUtil.getFirstRecordWaitSeconds(emptyConfig)
+ )
+ Assertions.assertEquals(
+ RecordWaitTimeUtil.DEFAULT_FIRST_RECORD_WAIT_TIME,
+ RecordWaitTimeUtil.getFirstRecordWaitTime(emptyConfig)
+ )
+
+ val normalConfig =
+ Jsons.jsonNode(
+ Map.of(
+ "replication_method",
+ Map.of("method", "CDC", "initial_waiting_seconds", 500)
+ )
+ )
+ Assertions.assertDoesNotThrow { RecordWaitTimeUtil.checkFirstRecordWaitTime(normalConfig) }
+ Assertions.assertEquals(
+ Optional.of(500),
+ RecordWaitTimeUtil.getFirstRecordWaitSeconds(normalConfig)
+ )
+ Assertions.assertEquals(
+ Duration.ofSeconds(500),
+ RecordWaitTimeUtil.getFirstRecordWaitTime(normalConfig)
+ )
+
+ val tooShortTimeout = RecordWaitTimeUtil.MIN_FIRST_RECORD_WAIT_TIME.seconds.toInt() - 1
+ val tooShortConfig =
+ Jsons.jsonNode(
+ Map.of(
+ "replication_method",
+ Map.of("method", "CDC", "initial_waiting_seconds", tooShortTimeout)
+ )
+ )
+ Assertions.assertThrows(IllegalArgumentException::class.java) {
+ RecordWaitTimeUtil.checkFirstRecordWaitTime(tooShortConfig)
+ }
+ Assertions.assertEquals(
+ Optional.of(tooShortTimeout),
+ RecordWaitTimeUtil.getFirstRecordWaitSeconds(tooShortConfig)
+ )
+ Assertions.assertEquals(
+ RecordWaitTimeUtil.MIN_FIRST_RECORD_WAIT_TIME,
+ RecordWaitTimeUtil.getFirstRecordWaitTime(tooShortConfig)
+ )
+
+ val tooLongTimeout = RecordWaitTimeUtil.MAX_FIRST_RECORD_WAIT_TIME.seconds.toInt() + 1
+ val tooLongConfig =
+ Jsons.jsonNode(
+ Map.of(
+ "replication_method",
+ Map.of("method", "CDC", "initial_waiting_seconds", tooLongTimeout)
+ )
+ )
+ Assertions.assertThrows(IllegalArgumentException::class.java) {
+ RecordWaitTimeUtil.checkFirstRecordWaitTime(tooLongConfig)
+ }
+ Assertions.assertEquals(
+ Optional.of(tooLongTimeout),
+ RecordWaitTimeUtil.getFirstRecordWaitSeconds(tooLongConfig)
+ )
+ Assertions.assertEquals(
+ RecordWaitTimeUtil.MAX_FIRST_RECORD_WAIT_TIME,
+ RecordWaitTimeUtil.getFirstRecordWaitTime(tooLongConfig)
+ )
+ } }
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/AbstractDbSourceForTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/AbstractDbSourceForTest.kt new file mode 100644 index 000000000000..4a049fd570c3 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/AbstractDbSourceForTest.kt @@ -0,0 +1,20 @@
+/*
+ * Copyright (c) 2024 Airbyte, Inc., all rights reserved.
+ */
+
+package io.airbyte.cdk.integrations.source.jdbc
+
+import com.fasterxml.jackson.databind.JsonNode
+import io.airbyte.cdk.db.AbstractDatabase
+import io.airbyte.cdk.integrations.source.relationaldb.AbstractDbSource
+import io.airbyte.protocol.models.v0.AirbyteStateMessage
+
+abstract class AbstractDbSourceForTest<DataType, Database : AbstractDatabase>(
+ driverClassName: String
+) : AbstractDbSource<DataType, Database>(driverClassName) {
+ public override fun getSupportedStateType(
+ config: JsonNode?
+ ): AirbyteStateMessage.AirbyteStateType { + return super.getSupportedStateType(config) + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/DefaultJdbcSourceAcceptanceTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/DefaultJdbcSourceAcceptanceTest.kt new file mode 100644 index 000000000000..dd104bc6c110 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/DefaultJdbcSourceAcceptanceTest.kt @@ -0,0 +1,219 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.source.jdbc + +import com.fasterxml.jackson.databind.JsonNode +import com.google.common.collect.ImmutableMap +import io.airbyte.cdk.db.factory.DatabaseDriver +import io.airbyte.cdk.db.jdbc.JdbcUtils +import io.airbyte.cdk.db.jdbc.JdbcUtils.parseJdbcParameters +import io.airbyte.cdk.db.jdbc.streaming.AdaptiveStreamingQueryConfig +import io.airbyte.cdk.integrations.base.IntegrationRunner +import io.airbyte.cdk.integrations.base.Source +import io.airbyte.cdk.integrations.source.jdbc.DefaultJdbcSourceAcceptanceTest.BareBonesTestDatabase +import io.airbyte.cdk.integrations.source.jdbc.DefaultJdbcSourceAcceptanceTest.BareBonesTestDatabase.BareBonesConfigBuilder +import io.airbyte.cdk.integrations.source.jdbc.test.JdbcSourceAcceptanceTest +import io.airbyte.cdk.integrations.util.HostPortResolver.resolveHost +import io.airbyte.cdk.integrations.util.HostPortResolver.resolvePort +import io.airbyte.cdk.testutils.TestDatabase +import io.airbyte.commons.json.Jsons +import io.airbyte.protocol.models.v0.AirbyteStateMessage +import java.sql.JDBCType +import java.util.List +import java.util.Map +import java.util.function.Supplier +import java.util.stream.Stream +import org.jooq.SQLDialect +import org.junit.jupiter.api.AfterAll +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.BeforeAll +import org.junit.jupiter.api.Test +import org.slf4j.Logger +import org.slf4j.LoggerFactory +import org.testcontainers.containers.PostgreSQLContainer + +/** + * Runs the acceptance tests in the source-jdbc test module. We want this module to run these tests + * itself as a sanity check. The trade off here is that this class is duplicated from the one used + * in source-postgres. + */ +internal class DefaultJdbcSourceAcceptanceTest : + JdbcSourceAcceptanceTest< + DefaultJdbcSourceAcceptanceTest.PostgresTestSource, BareBonesTestDatabase>() { + override fun config(): JsonNode { + return testdb!!.testConfigBuilder()!!.build() + } + + override fun source(): PostgresTestSource { + return PostgresTestSource() + } + + override fun createTestDatabase(): BareBonesTestDatabase { + return BareBonesTestDatabase(PSQL_CONTAINER).initialized()!! + } + + public override fun supportsSchemas(): Boolean { + return true + } + + fun getConfigWithConnectionProperties( + psqlDb: PostgreSQLContainer<*>?, + dbName: String?, + additionalParameters: String? 
+ ): JsonNode {
+ return Jsons.jsonNode(
+ ImmutableMap.builder<Any, Any>()
+ .put(JdbcUtils.HOST_KEY, resolveHost(psqlDb))
+ .put(JdbcUtils.PORT_KEY, resolvePort(psqlDb))
+ .put(JdbcUtils.DATABASE_KEY, dbName)
+ .put(JdbcUtils.SCHEMAS_KEY, List.of(SCHEMA_NAME))
+ .put(JdbcUtils.USERNAME_KEY, psqlDb!!.username)
+ .put(JdbcUtils.PASSWORD_KEY, psqlDb.password)
+ .put(JdbcUtils.CONNECTION_PROPERTIES_KEY, additionalParameters)
+ .build()
+ )
+ }
+
+ class PostgresTestSource :
+ AbstractJdbcSource<JDBCType>(
+ DRIVER_CLASS,
+ Supplier { AdaptiveStreamingQueryConfig() },
+ JdbcUtils.defaultSourceOperations
+ ),
+ Source {
+ override fun toDatabaseConfig(config: JsonNode): JsonNode {
+ val configBuilder =
+ ImmutableMap.builder<Any, Any>()
+ .put(JdbcUtils.USERNAME_KEY, config[JdbcUtils.USERNAME_KEY].asText())
+ .put(
+ JdbcUtils.JDBC_URL_KEY,
+ String.format(
+ DatabaseDriver.POSTGRESQL.urlFormatString,
+ config[JdbcUtils.HOST_KEY].asText(),
+ config[JdbcUtils.PORT_KEY].asInt(),
+ config[JdbcUtils.DATABASE_KEY].asText()
+ )
+ )
+
+ if (config.has(JdbcUtils.PASSWORD_KEY)) {
+ configBuilder.put(JdbcUtils.PASSWORD_KEY, config[JdbcUtils.PASSWORD_KEY].asText())
+ }
+
+ return Jsons.jsonNode(configBuilder.build())
+ }
+
+ override val excludedInternalNameSpaces =
+ setOf("information_schema", "pg_catalog", "pg_internal", "catalog_history")
+
+ override fun getSupportedStateType(
+ config: JsonNode?
+ ): AirbyteStateMessage.AirbyteStateType {
+ return AirbyteStateMessage.AirbyteStateType.STREAM
+ }
+
+ companion object {
+ private val LOGGER: Logger = LoggerFactory.getLogger(PostgresTestSource::class.java)
+
+ val DRIVER_CLASS: String = DatabaseDriver.POSTGRESQL.driverClassName
+
+ @Throws(Exception::class)
+ @JvmStatic
+ fun main(args: Array<String>) {
+ val source: Source = PostgresTestSource()
+ LOGGER.info("starting source: {}", PostgresTestSource::class.java)
+ IntegrationRunner(source).run(args)
+ LOGGER.info("completed source: {}", PostgresTestSource::class.java)
+ }
+ }
+ }
+
+ class BareBonesTestDatabase(container: PostgreSQLContainer<*>) :
+ TestDatabase<PostgreSQLContainer<*>, BareBonesTestDatabase, BareBonesConfigBuilder>(
+ container
+ ) {
+ override fun inContainerBootstrapCmd(): Stream<Stream<String?>?>? {
+ val sql =
+ Stream.of(
+ String.format("CREATE DATABASE %s", databaseName),
+ String.format("CREATE USER %s PASSWORD '%s'", userName, password),
+ String.format(
+ "GRANT ALL PRIVILEGES ON DATABASE %s TO %s",
+ databaseName,
+ userName
+ ),
+ String.format("ALTER USER %s WITH SUPERUSER", userName)
+ )
+ return Stream.of(
+ Stream.concat(
+ Stream.of(
+ "psql",
+ "-d",
+ container!!.databaseName,
+ "-U",
+ container.username,
+ "-v",
+ "ON_ERROR_STOP=1",
+ "-a"
+ ),
+ sql.flatMap { stmt: String? -> Stream.of("-c", stmt) }
+ )
+ )
+ }
+
+ override fun inContainerUndoBootstrapCmd(): Stream<String>? {
+ return Stream.empty()
+ }
+
+ override val databaseDriver: DatabaseDriver
+ get() = DatabaseDriver.POSTGRESQL
+
+ override val sqlDialect: SQLDialect
+ get() = SQLDialect.POSTGRES
+
+ override fun configBuilder(): BareBonesConfigBuilder {
+ return BareBonesConfigBuilder(this)
+ }
+
+ class BareBonesConfigBuilder(testDatabase: BareBonesTestDatabase) :
+ ConfigBuilder<BareBonesTestDatabase, BareBonesConfigBuilder>(testDatabase)
+ }
+
+ @Test
+ fun testCustomParametersOverwriteDefaultParametersExpectException() {
+ val connectionPropertiesUrl = "ssl=false"
+ val config =
+ getConfigWithConnectionProperties(
+ PSQL_CONTAINER,
+ testdb!!.databaseName,
+ connectionPropertiesUrl
+ )
+ val customParameters = parseJdbcParameters(config, JdbcUtils.CONNECTION_PROPERTIES_KEY, "&")
+ val defaultParameters = Map.of("ssl", "true", "sslmode", "require")
+ Assertions.assertThrows(IllegalArgumentException::class.java) {
+ JdbcDataSourceUtils.assertCustomParametersDontOverwriteDefaultParameters(
+ customParameters,
+ defaultParameters
+ )
+ }
+ }
+
+ companion object {
+ private lateinit var PSQL_CONTAINER: PostgreSQLContainer<*>
+
+ @JvmStatic
+ @BeforeAll
+ fun init(): Unit {
+ PSQL_CONTAINER = PostgreSQLContainer("postgres:13-alpine")
+ PSQL_CONTAINER!!.start()
+ CREATE_TABLE_WITHOUT_CURSOR_TYPE_QUERY = "CREATE TABLE %s (%s BIT(3) NOT NULL);"
+ INSERT_TABLE_WITHOUT_CURSOR_TYPE_QUERY = "INSERT INTO %s VALUES(B'101');"
+ }
+
+ @JvmStatic
+ @AfterAll
+ fun cleanUp(): Unit {
+ PSQL_CONTAINER!!.close()
+ }
+ } }
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/DefaultJdbcStressTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/DefaultJdbcStressTest.kt new file mode 100644 index 000000000000..85902d6915ad --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/DefaultJdbcStressTest.kt @@ -0,0 +1,153 @@
+/*
+ * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+ */
+package io.airbyte.cdk.integrations.source.jdbc
+
+import com.fasterxml.jackson.databind.JsonNode
+import com.google.common.collect.ImmutableMap
+import io.airbyte.cdk.db.factory.DatabaseDriver
+import io.airbyte.cdk.db.jdbc.JdbcUtils
+import io.airbyte.cdk.db.jdbc.streaming.AdaptiveStreamingQueryConfig
+import io.airbyte.cdk.integrations.base.IntegrationRunner
+import io.airbyte.cdk.integrations.base.Source
+import io.airbyte.cdk.integrations.source.jdbc.test.JdbcStressTest
+import io.airbyte.cdk.testutils.PostgreSQLContainerHelper.runSqlScript
+import io.airbyte.commons.io.IOs
+import io.airbyte.commons.json.Jsons
+import io.airbyte.commons.string.Strings
+import java.sql.JDBCType
+import java.util.*
+import java.util.function.Supplier
+import org.junit.jupiter.api.AfterAll
+import org.junit.jupiter.api.BeforeAll
+import org.junit.jupiter.api.BeforeEach
+import org.junit.jupiter.api.Disabled
+import org.slf4j.Logger
+import org.slf4j.LoggerFactory
+import org.testcontainers.containers.PostgreSQLContainer
+import org.testcontainers.utility.MountableFile
+
+/**
+ * Runs the stress tests in the source-jdbc test module. We want this module to run these tests
+ * itself as a sanity check. The trade off here is that this class is duplicated from the one used
+ * in source-postgres.
+ */
+@Disabled
+internal class DefaultJdbcStressTest : JdbcStressTest() {
+ private var config: JsonNode?
= null
+
+ @BeforeEach
+ @Throws(Exception::class)
+ override fun setup() {
+ val dbName = Strings.addRandomSuffix("db", "_", 10)
+
+ config =
+ Jsons.jsonNode(
+ ImmutableMap.of(
+ JdbcUtils.HOST_KEY,
+ "localhost",
+ JdbcUtils.PORT_KEY,
+ 5432,
+ JdbcUtils.DATABASE_KEY,
+ "charles",
+ JdbcUtils.USERNAME_KEY,
+ "postgres",
+ JdbcUtils.PASSWORD_KEY,
+ ""
+ )
+ )
+
+ config =
+ Jsons.jsonNode(
+ ImmutableMap.builder<Any, Any>()
+ .put(JdbcUtils.HOST_KEY, PSQL_DB!!.host)
+ .put(JdbcUtils.PORT_KEY, PSQL_DB!!.firstMappedPort)
+ .put(JdbcUtils.DATABASE_KEY, dbName)
+ .put(JdbcUtils.USERNAME_KEY, PSQL_DB!!.username)
+ .put(JdbcUtils.PASSWORD_KEY, PSQL_DB!!.password)
+ .build()
+ )
+
+ val initScriptName = "init_$dbName.sql"
+ val tmpFilePath = IOs.writeFileToRandomTmpDir(initScriptName, "CREATE DATABASE $dbName;")
+ runSqlScript(MountableFile.forHostPath(tmpFilePath), PSQL_DB!!)
+
+ super.setup()
+ }
+
+ override val defaultSchemaName = Optional.of("public")
+
+ override fun getSource(): AbstractJdbcSource<JDBCType> {
+ return PostgresTestSource()
+ }
+
+ override fun getConfig(): JsonNode {
+ return config!!
+ }
+
+ override val driverClass = PostgresTestSource.DRIVER_CLASS
+
+ private class PostgresTestSource :
+ AbstractJdbcSource<JDBCType>(
+ DRIVER_CLASS,
+ Supplier { AdaptiveStreamingQueryConfig() },
+ JdbcUtils.defaultSourceOperations
+ ),
+ Source {
+ override fun toDatabaseConfig(config: JsonNode): JsonNode {
+ val configBuilder =
+ ImmutableMap.builder<Any, Any>()
+ .put(JdbcUtils.USERNAME_KEY, config[JdbcUtils.USERNAME_KEY].asText())
+ .put(
+ JdbcUtils.JDBC_URL_KEY,
+ String.format(
+ DatabaseDriver.POSTGRESQL.urlFormatString,
+ config[JdbcUtils.HOST_KEY].asText(),
+ config[JdbcUtils.PORT_KEY].asInt(),
+ config[JdbcUtils.DATABASE_KEY].asText()
+ )
+ )
+
+ if (config.has(JdbcUtils.PASSWORD_KEY)) {
+ configBuilder.put(JdbcUtils.PASSWORD_KEY, config[JdbcUtils.PASSWORD_KEY].asText())
+ }
+
+ return Jsons.jsonNode(configBuilder.build())
+ }
+
+ public override val excludedInternalNameSpaces =
+ setOf("information_schema", "pg_catalog", "pg_internal", "catalog_history")
+
+ companion object {
+ private val LOGGER: Logger = LoggerFactory.getLogger(PostgresTestSource::class.java)
+
+ val DRIVER_CLASS: String = DatabaseDriver.POSTGRESQL.driverClassName
+
+ @Throws(Exception::class)
+ @JvmStatic
+ fun main(args: Array<String>) {
+ val source: Source = PostgresTestSource()
+ LOGGER.info("starting source: {}", PostgresTestSource::class.java)
+ IntegrationRunner(source).run(args)
+ LOGGER.info("completed source: {}", PostgresTestSource::class.java)
+ }
+ }
+ }
+
+ companion object {
+ private var PSQL_DB: PostgreSQLContainer<Nothing>? = null
+
+ @BeforeAll
+ @JvmStatic
+ fun init() {
+ PSQL_DB = PostgreSQLContainer("postgres:13-alpine")
+ PSQL_DB!!.start()
+ }
+
+ @AfterAll
+ @JvmStatic
+ fun cleanUp() {
+ PSQL_DB!!.close()
+ }
+ } }
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcDataSourceUtilsTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcDataSourceUtilsTest.kt new file mode 100644 index 000000000000..6a8dc1ab3d8b --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcDataSourceUtilsTest.kt @@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+ */
+package io.airbyte.cdk.integrations.source.jdbc
+
+import io.airbyte.commons.json.Jsons
+import java.util.function.Consumer
+import org.assertj.core.api.AssertionsForClassTypes
+import org.junit.Assert
+import org.junit.jupiter.api.Assertions
+import org.junit.jupiter.api.Test
+
+class JdbcDataSourceUtilsTest {
+    @Test
+    fun test() {
+        val validConfigString =
+            "{\"jdbc_url_params\":\"key1=val1&key3=key3\",\"connection_properties\":\"key1=val1&key2=val2\"}"
+        val validConfig = Jsons.deserialize(validConfigString)
+        val connectionProperties = JdbcDataSourceUtils.getConnectionProperties(validConfig)
+        val validKeys = listOf("key1", "key2", "key3")
+        validKeys.forEach(
+            Consumer { key: String -> Assert.assertTrue(connectionProperties.containsKey(key)) }
+        )
+
+        // For an invalid config, the value of key1 conflicts between jdbc_url_params and
+        // connection_properties
+        val invalidConfigString =
+            "{\"jdbc_url_params\":\"key1=val2&key3=key3\",\"connection_properties\":\"key1=val1&key2=val2\"}"
+        val invalidConfig = Jsons.deserialize(invalidConfigString)
+        val exception: Exception =
+            Assertions.assertThrows(IllegalArgumentException::class.java) {
+                JdbcDataSourceUtils.getConnectionProperties(invalidConfig)
+            }
+
+        val expectedMessage = "Cannot overwrite default JDBC parameter key1"
+        AssertionsForClassTypes.assertThat(exception.message).isEqualTo(expectedMessage)
+    }
+}
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcSourceStressTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcSourceStressTest.kt
new file mode 100644
index 000000000000..a9a5b87afb2c
--- /dev/null
+++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcSourceStressTest.kt
@@ -0,0 +1,138 @@
+/*
+ * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+ */
+package io.airbyte.cdk.integrations.source.jdbc
+
+import com.fasterxml.jackson.databind.JsonNode
+import com.google.common.collect.ImmutableMap
+import io.airbyte.cdk.db.factory.DatabaseDriver
+import io.airbyte.cdk.db.jdbc.JdbcUtils
+import io.airbyte.cdk.db.jdbc.streaming.AdaptiveStreamingQueryConfig
+import io.airbyte.cdk.integrations.base.IntegrationRunner
+import io.airbyte.cdk.integrations.base.Source
+import io.airbyte.cdk.integrations.source.jdbc.test.JdbcStressTest
+import io.airbyte.cdk.testutils.PostgreSQLContainerHelper.runSqlScript
+import io.airbyte.commons.io.IOs
+import io.airbyte.commons.json.Jsons
+import io.airbyte.commons.string.Strings
+import java.sql.JDBCType
+import java.util.*
+import java.util.function.Supplier
+import org.junit.jupiter.api.AfterAll
+import org.junit.jupiter.api.BeforeAll
+import org.junit.jupiter.api.BeforeEach
+import org.junit.jupiter.api.Disabled
+import org.slf4j.Logger
+import org.slf4j.LoggerFactory
+import org.testcontainers.containers.PostgreSQLContainer
+import org.testcontainers.utility.MountableFile
+
+/**
+ * Runs the stress tests in the source-jdbc test module. We want this module to run these tests
+ * itself as a sanity check. The trade off here is that this class is duplicated from the one used
+ * in source-postgres.
+ */
+@Disabled
+internal class JdbcSourceStressTest : JdbcStressTest() {
+    private var config: JsonNode?
= null + + @BeforeEach + @Throws(Exception::class) + override fun setup() { + val schemaName = Strings.addRandomSuffix("db", "_", 10) + + config = + Jsons.jsonNode( + ImmutableMap.builder() + .put(JdbcUtils.HOST_KEY, PSQL_DB!!.host) + .put(JdbcUtils.PORT_KEY, PSQL_DB!!.firstMappedPort) + .put(JdbcUtils.DATABASE_KEY, schemaName) + .put(JdbcUtils.USERNAME_KEY, PSQL_DB!!.username) + .put(JdbcUtils.PASSWORD_KEY, PSQL_DB!!.password) + .build() + ) + + val initScriptName = "init_$schemaName.sql" + val tmpFilePath = + IOs.writeFileToRandomTmpDir(initScriptName, "CREATE DATABASE $schemaName;") + runSqlScript(MountableFile.forHostPath(tmpFilePath), PSQL_DB!!) + + super.setup() + } + + override val defaultSchemaName = Optional.of("public") + + override fun getSource(): AbstractJdbcSource { + return PostgresTestSource() + } + + override fun getConfig(): JsonNode { + return config!! + } + + override val driverClass = PostgresTestSource.DRIVER_CLASS + + private class PostgresTestSource : + AbstractJdbcSource( + DRIVER_CLASS, + Supplier { AdaptiveStreamingQueryConfig() }, + JdbcUtils.defaultSourceOperations + ), + Source { + override fun toDatabaseConfig(config: JsonNode): JsonNode { + val configBuilder = + ImmutableMap.builder() + .put(JdbcUtils.USERNAME_KEY, config[JdbcUtils.USERNAME_KEY].asText()) + .put( + JdbcUtils.JDBC_URL_KEY, + String.format( + DatabaseDriver.POSTGRESQL.urlFormatString, + config[JdbcUtils.HOST_KEY].asText(), + config[JdbcUtils.PORT_KEY].asInt(), + config[JdbcUtils.DATABASE_KEY].asText() + ) + ) + + if (config.has(JdbcUtils.PASSWORD_KEY)) { + configBuilder.put(JdbcUtils.PASSWORD_KEY, config[JdbcUtils.PASSWORD_KEY].asText()) + } + + return Jsons.jsonNode(configBuilder.build()) + } + + override val excludedInternalNameSpaces = + setOf("information_schema", "pg_catalog", "pg_internal", "catalog_history") + + companion object { + private val LOGGER: Logger = LoggerFactory.getLogger(PostgresTestSource::class.java) + + val DRIVER_CLASS: String = DatabaseDriver.POSTGRESQL.driverClassName + + @Throws(Exception::class) + @JvmStatic + fun main(args: Array) { + val source: Source = PostgresTestSource() + LOGGER.info("starting source: {}", PostgresTestSource::class.java) + IntegrationRunner(source).run(args) + LOGGER.info("completed source: {}", PostgresTestSource::class.java) + } + } + } + + companion object { + private lateinit var PSQL_DB: PostgreSQLContainer + + @BeforeAll + @JvmStatic + fun init() { + PSQL_DB = PostgreSQLContainer("postgres:13-alpine") + PSQL_DB!!.start() + } + + @AfterAll + @JvmStatic + fun cleanUp() { + PSQL_DB!!.close() + } + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/AbstractDbSourceTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/AbstractDbSourceTest.kt new file mode 100644 index 000000000000..a292255725f0 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/AbstractDbSourceTest.kt @@ -0,0 +1,110 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+ */ +package io.airbyte.cdk.integrations.source.relationaldb + +import com.fasterxml.jackson.databind.JsonNode +import io.airbyte.cdk.integrations.source.jdbc.AbstractDbSourceForTest +import io.airbyte.cdk.integrations.source.relationaldb.state.* +import io.airbyte.commons.json.Jsons +import io.airbyte.commons.resources.MoreResources +import io.airbyte.protocol.models.v0.AirbyteStateMessage +import java.io.IOException +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.ExtendWith +import org.mockito.Mockito +import uk.org.webcompere.systemstubs.environment.EnvironmentVariables +import uk.org.webcompere.systemstubs.jupiter.SystemStub +import uk.org.webcompere.systemstubs.jupiter.SystemStubsExtension + +/** Test suite for the [AbstractDbSource] class. */ +@ExtendWith(SystemStubsExtension::class) +class AbstractDbSourceTest { + @SystemStub private val environmentVariables: EnvironmentVariables? = null + + @Test + @Throws(IOException::class) + fun testDeserializationOfLegacyState() { + val dbSource = + Mockito.mock( + AbstractDbSourceForTest::class.java, + Mockito.withSettings().useConstructor("").defaultAnswer(Mockito.CALLS_REAL_METHODS) + ) + val config = Mockito.mock(JsonNode::class.java) + + val legacyStateJson = MoreResources.readResource("states/legacy.json") + val legacyState = Jsons.deserialize(legacyStateJson) + + val result = + StateGeneratorUtils.deserializeInitialState( + legacyState, + dbSource.getSupportedStateType(config) + ) + Assertions.assertEquals(1, result.size) + Assertions.assertEquals(AirbyteStateMessage.AirbyteStateType.LEGACY, result[0].type) + } + + @Test + @Throws(IOException::class) + fun testDeserializationOfGlobalState() { + val dbSource = + Mockito.mock( + AbstractDbSourceForTest::class.java, + Mockito.withSettings().useConstructor("").defaultAnswer(Mockito.CALLS_REAL_METHODS) + ) + val config = Mockito.mock(JsonNode::class.java) + + val globalStateJson = MoreResources.readResource("states/global.json") + val globalState = Jsons.deserialize(globalStateJson) + + val result = + StateGeneratorUtils.deserializeInitialState( + globalState, + dbSource.getSupportedStateType(config) + ) + Assertions.assertEquals(1, result.size) + Assertions.assertEquals(AirbyteStateMessage.AirbyteStateType.GLOBAL, result[0].type) + } + + @Test + @Throws(IOException::class) + fun testDeserializationOfStreamState() { + val dbSource = + Mockito.mock( + AbstractDbSourceForTest::class.java, + Mockito.withSettings().useConstructor("").defaultAnswer(Mockito.CALLS_REAL_METHODS) + ) + val config = Mockito.mock(JsonNode::class.java) + + val streamStateJson = MoreResources.readResource("states/per_stream.json") + val streamState = Jsons.deserialize(streamStateJson) + + val result = + StateGeneratorUtils.deserializeInitialState( + streamState, + dbSource.getSupportedStateType(config) + ) + Assertions.assertEquals(2, result.size) + Assertions.assertEquals(AirbyteStateMessage.AirbyteStateType.STREAM, result[0].type) + } + + @Test + @Throws(IOException::class) + fun testDeserializationOfNullState() { + val dbSource = + Mockito.mock( + AbstractDbSourceForTest::class.java, + Mockito.withSettings().useConstructor("").defaultAnswer(Mockito.CALLS_REAL_METHODS) + ) + val config = Mockito.mock(JsonNode::class.java) + + val result = + StateGeneratorUtils.deserializeInitialState( + null, + dbSource.getSupportedStateType(config) + ) + Assertions.assertEquals(1, result.size) + Assertions.assertEquals(dbSource.getSupportedStateType(config), 
result[0].type) + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/CursorManagerTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/CursorManagerTest.kt new file mode 100644 index 000000000000..c3905e5043ea --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/CursorManagerTest.kt @@ -0,0 +1,269 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.source.relationaldb.state + +import io.airbyte.cdk.integrations.source.relationaldb.CursorInfo +import io.airbyte.cdk.integrations.source.relationaldb.models.DbStreamState +import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair +import java.util.* +import java.util.function.Function +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test + +/** Test suite for the [CursorManager] class. */ +class CursorManagerTest { + @Test + fun testCreateCursorInfoCatalogAndStateSameCursorField() { + val cursorManager = + createCursorManager( + StateTestConstants.CURSOR_FIELD1, + StateTestConstants.CURSOR, + StateTestConstants.NAME_NAMESPACE_PAIR1 + ) + val actual = + cursorManager.createCursorInfoForStream( + StateTestConstants.NAME_NAMESPACE_PAIR1, + StateTestConstants.getState( + StateTestConstants.CURSOR_FIELD1, + StateTestConstants.CURSOR, + StateTestConstants.CURSOR_RECORD_COUNT + ), + StateTestConstants.getStream(StateTestConstants.CURSOR_FIELD1), + { obj: DbStreamState? -> obj!!.cursor }, + { obj: DbStreamState? -> obj!!.cursorField }, + CURSOR_RECORD_COUNT_FUNCTION + ) + Assertions.assertEquals( + CursorInfo( + StateTestConstants.CURSOR_FIELD1, + StateTestConstants.CURSOR, + StateTestConstants.CURSOR_RECORD_COUNT, + StateTestConstants.CURSOR_FIELD1, + StateTestConstants.CURSOR, + StateTestConstants.CURSOR_RECORD_COUNT + ), + actual + ) + } + + @Test + fun testCreateCursorInfoCatalogAndStateSameCursorFieldButNoCursor() { + val cursorManager = + createCursorManager( + StateTestConstants.CURSOR_FIELD1, + null, + StateTestConstants.NAME_NAMESPACE_PAIR1 + ) + val actual = + cursorManager.createCursorInfoForStream( + StateTestConstants.NAME_NAMESPACE_PAIR1, + StateTestConstants.getState(StateTestConstants.CURSOR_FIELD1, null), + StateTestConstants.getStream(StateTestConstants.CURSOR_FIELD1), + { obj: DbStreamState? -> obj!!.cursor }, + { obj: DbStreamState? -> obj!!.cursorField }, + CURSOR_RECORD_COUNT_FUNCTION + ) + Assertions.assertEquals( + CursorInfo( + StateTestConstants.CURSOR_FIELD1, + null, + StateTestConstants.CURSOR_FIELD1, + null + ), + actual + ) + } + + @Test + fun testCreateCursorInfoCatalogAndStateChangeInCursorFieldName() { + val cursorManager = + createCursorManager( + StateTestConstants.CURSOR_FIELD1, + StateTestConstants.CURSOR, + StateTestConstants.NAME_NAMESPACE_PAIR1 + ) + val actual = + cursorManager.createCursorInfoForStream( + StateTestConstants.NAME_NAMESPACE_PAIR1, + StateTestConstants.getState( + StateTestConstants.CURSOR_FIELD1, + StateTestConstants.CURSOR + ), + StateTestConstants.getStream(StateTestConstants.CURSOR_FIELD2), + { obj: DbStreamState? -> obj!!.cursor }, + { obj: DbStreamState? 
-> obj!!.cursorField }, + CURSOR_RECORD_COUNT_FUNCTION + ) + Assertions.assertEquals( + CursorInfo( + StateTestConstants.CURSOR_FIELD1, + StateTestConstants.CURSOR, + StateTestConstants.CURSOR_FIELD2, + null + ), + actual + ) + } + + @Test + fun testCreateCursorInfoCatalogAndNoState() { + val cursorManager = + createCursorManager( + StateTestConstants.CURSOR_FIELD1, + StateTestConstants.CURSOR, + StateTestConstants.NAME_NAMESPACE_PAIR1 + ) + val actual = + cursorManager.createCursorInfoForStream( + StateTestConstants.NAME_NAMESPACE_PAIR1, + Optional.empty(), + StateTestConstants.getStream(StateTestConstants.CURSOR_FIELD1), + Function { obj: DbStreamState? -> obj!!.cursor }, + Function { obj: DbStreamState? -> obj!!.cursorField }, + CURSOR_RECORD_COUNT_FUNCTION + ) + Assertions.assertEquals( + CursorInfo(null, null, StateTestConstants.CURSOR_FIELD1, null), + actual + ) + } + + @Test + fun testCreateCursorInfoStateAndNoCatalog() { + val cursorManager = + createCursorManager( + StateTestConstants.CURSOR_FIELD1, + StateTestConstants.CURSOR, + StateTestConstants.NAME_NAMESPACE_PAIR1 + ) + val actual = + cursorManager.createCursorInfoForStream( + StateTestConstants.NAME_NAMESPACE_PAIR1, + StateTestConstants.getState( + StateTestConstants.CURSOR_FIELD1, + StateTestConstants.CURSOR + ), + Optional.empty(), + { obj: DbStreamState? -> obj!!.cursor }, + { obj: DbStreamState? -> obj!!.cursorField }, + CURSOR_RECORD_COUNT_FUNCTION + ) + Assertions.assertEquals( + CursorInfo(StateTestConstants.CURSOR_FIELD1, StateTestConstants.CURSOR, null, null), + actual + ) + } + + // this is what full refresh looks like. + @Test + fun testCreateCursorInfoNoCatalogAndNoState() { + val cursorManager = + createCursorManager( + StateTestConstants.CURSOR_FIELD1, + StateTestConstants.CURSOR, + StateTestConstants.NAME_NAMESPACE_PAIR1 + ) + val actual = + cursorManager.createCursorInfoForStream( + StateTestConstants.NAME_NAMESPACE_PAIR1, + Optional.empty(), + Optional.empty(), + Function { obj: DbStreamState? -> obj!!.cursor }, + Function { obj: DbStreamState? -> obj!!.cursorField }, + CURSOR_RECORD_COUNT_FUNCTION + ) + Assertions.assertEquals(CursorInfo(null, null, null, null), actual) + } + + @Test + fun testCreateCursorInfoStateAndCatalogButNoCursorField() { + val cursorManager = + createCursorManager( + StateTestConstants.CURSOR_FIELD1, + StateTestConstants.CURSOR, + StateTestConstants.NAME_NAMESPACE_PAIR1 + ) + val actual = + cursorManager.createCursorInfoForStream( + StateTestConstants.NAME_NAMESPACE_PAIR1, + StateTestConstants.getState( + StateTestConstants.CURSOR_FIELD1, + StateTestConstants.CURSOR + ), + StateTestConstants.getStream(null), + { obj: DbStreamState? -> obj!!.cursor }, + { obj: DbStreamState? 
-> obj!!.cursorField }, + CURSOR_RECORD_COUNT_FUNCTION + ) + Assertions.assertEquals( + CursorInfo(StateTestConstants.CURSOR_FIELD1, StateTestConstants.CURSOR, null, null), + actual + ) + } + + @Test + fun testGetters() { + val cursorManager: CursorManager<*> = + createCursorManager( + StateTestConstants.CURSOR_FIELD1, + StateTestConstants.CURSOR, + StateTestConstants.NAME_NAMESPACE_PAIR1 + ) + val actualCursorInfo = + CursorInfo(StateTestConstants.CURSOR_FIELD1, StateTestConstants.CURSOR, null, null) + + Assertions.assertEquals( + Optional.of(actualCursorInfo), + cursorManager.getCursorInfo(StateTestConstants.NAME_NAMESPACE_PAIR1) + ) + Assertions.assertEquals( + Optional.empty(), + cursorManager.getCursorField(StateTestConstants.NAME_NAMESPACE_PAIR1) + ) + Assertions.assertEquals( + Optional.empty(), + cursorManager.getCursor(StateTestConstants.NAME_NAMESPACE_PAIR1) + ) + + Assertions.assertEquals( + Optional.empty(), + cursorManager.getCursorInfo(StateTestConstants.NAME_NAMESPACE_PAIR2) + ) + Assertions.assertEquals( + Optional.empty(), + cursorManager.getCursorField(StateTestConstants.NAME_NAMESPACE_PAIR2) + ) + Assertions.assertEquals( + Optional.empty(), + cursorManager.getCursor(StateTestConstants.NAME_NAMESPACE_PAIR2) + ) + } + + private fun createCursorManager( + cursorField: String?, + cursor: String?, + nameNamespacePair: AirbyteStreamNameNamespacePair? + ): CursorManager { + val dbStreamState = StateTestConstants.getState(cursorField, cursor).get() + return CursorManager( + StateTestConstants.getCatalog(cursorField).orElse(null), + { setOf(dbStreamState) }, + { obj: DbStreamState? -> obj!!.cursor }, + { obj: DbStreamState? -> obj!!.cursorField }, + CURSOR_RECORD_COUNT_FUNCTION, + { s: DbStreamState? -> nameNamespacePair }, + false + ) + } + + companion object { + private val CURSOR_RECORD_COUNT_FUNCTION = Function { stream: DbStreamState -> + if (stream!!.cursorRecordCount != null) { + return@Function stream.cursorRecordCount + } else { + return@Function 0L + } + } + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/CursorStateMessageProducerTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/CursorStateMessageProducerTest.kt new file mode 100644 index 000000000000..996b5e02c519 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/CursorStateMessageProducerTest.kt @@ -0,0 +1,540 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+ */
+package io.airbyte.cdk.integrations.source.relationaldb.state
+
+import com.fasterxml.jackson.databind.node.ObjectNode
+import io.airbyte.cdk.integrations.source.relationaldb.models.DbState
+import io.airbyte.cdk.integrations.source.relationaldb.models.DbStreamState
+import io.airbyte.commons.json.Jsons
+import io.airbyte.commons.util.MoreIterators
+import io.airbyte.protocol.models.Field
+import io.airbyte.protocol.models.JsonSchemaType
+import io.airbyte.protocol.models.v0.*
+import java.sql.SQLException
+import java.time.Duration
+import java.util.*
+import java.util.List
+import org.junit.jupiter.api.Assertions
+import org.junit.jupiter.api.BeforeEach
+import org.junit.jupiter.api.Test
+import org.testcontainers.shaded.com.google.common.collect.ImmutableMap
+
+internal class CursorStateMessageProducerTest {
+    private fun createExceptionIterator(): Iterator<AirbyteMessage> {
+        return object : Iterator<AirbyteMessage> {
+            val internalMessageIterator: Iterator<AirbyteMessage> =
+                MoreIterators.of(
+                    RECORD_MESSAGE_1,
+                    RECORD_MESSAGE_2,
+                    RECORD_MESSAGE_2,
+                    RECORD_MESSAGE_3
+                )
+
+            override fun hasNext(): Boolean {
+                return true
+            }
+
+            override fun next(): AirbyteMessage {
+                if (internalMessageIterator.hasNext()) {
+                    return internalMessageIterator.next()
+                } else {
+                    // this line throws a RuntimeException wrapped around a SQLException to mimic
+                    // the flow of when a SQLException is thrown and wrapped in
+                    // StreamingJdbcDatabase#tryAdvance
+                    throw RuntimeException(
+                        SQLException(
+                            "Connection marked broken because of SQLSTATE(080006)",
+                            "08006"
+                        )
+                    )
+                }
+            }
+        }
+    }
+
+    private var stateManager: StateManager? = null
+
+    @BeforeEach
+    fun setup() {
+        val airbyteStream = AirbyteStream().withNamespace(NAMESPACE).withName(STREAM_NAME)
+        val configuredAirbyteStream =
+            ConfiguredAirbyteStream()
+                .withStream(airbyteStream)
+                .withCursorField(listOf(UUID_FIELD_NAME))
+
+        stateManager =
+            StreamStateManager(
+                emptyList(),
+                ConfiguredAirbyteCatalog().withStreams(listOf(configuredAirbyteStream))
+            )
+    }
+
+    @Test
+    fun testWithoutInitialCursor() {
+        messageIterator = MoreIterators.of(RECORD_MESSAGE_1, RECORD_MESSAGE_2)
+
+        val producer = CursorStateMessageProducer(stateManager, Optional.empty())
+
+        val iterator: SourceStateIterator<*> =
+            SourceStateIterator(
+                messageIterator,
+                STREAM,
+                producer,
+                StateEmitFrequency(0, Duration.ZERO)
+            )
+
+        Assertions.assertEquals(RECORD_MESSAGE_1, iterator.next())
+        Assertions.assertEquals(RECORD_MESSAGE_2, iterator.next())
+        Assertions.assertEquals(createStateMessage(RECORD_VALUE_2, 1, 2.0), iterator.next())
+        Assertions.assertFalse(iterator.hasNext())
+    }
+
+    @Test
+    fun testWithInitialCursor() {
+        // records 1 and 2 have smaller cursor values, so at the end the initial cursor is
+        // emitted with a 0 record count
+
+        messageIterator = MoreIterators.of(RECORD_MESSAGE_1, RECORD_MESSAGE_2)
+
+        val producer = CursorStateMessageProducer(stateManager, Optional.of(RECORD_VALUE_5))
+
+        val iterator: SourceStateIterator<*> =
+            SourceStateIterator(
+                messageIterator,
+                STREAM,
+                producer,
+                StateEmitFrequency(0, Duration.ZERO)
+            )
+
+        Assertions.assertEquals(RECORD_MESSAGE_1, iterator.next())
+        Assertions.assertEquals(RECORD_MESSAGE_2, iterator.next())
+        Assertions.assertEquals(createStateMessage(RECORD_VALUE_5, 0, 2.0), iterator.next())
+        Assertions.assertFalse(iterator.hasNext())
+    }
+
+    @Test
+    fun testCursorFieldIsEmpty() {
+        val recordMessage = Jsons.clone(RECORD_MESSAGE_1)
+        (recordMessage.record.data as ObjectNode).remove(UUID_FIELD_NAME)
+        val messageStream =
MoreIterators.of(recordMessage) + + val producer = CursorStateMessageProducer(stateManager, Optional.empty()) + + val iterator: SourceStateIterator<*> = + SourceStateIterator( + messageStream, + STREAM, + producer, + StateEmitFrequency(0, Duration.ZERO) + ) + + Assertions.assertEquals(recordMessage, iterator.next()) + // null because no records with a cursor field were replicated for the stream. + Assertions.assertEquals(createEmptyStateMessage(1.0), iterator.next()) + Assertions.assertFalse(iterator.hasNext()) + } + + @Test + fun testIteratorCatchesExceptionWhenEmissionFrequencyNonZero() { + val exceptionIterator = createExceptionIterator() + + val producer = CursorStateMessageProducer(stateManager, Optional.of(RECORD_VALUE_1)) + + val iterator: SourceStateIterator<*> = + SourceStateIterator( + exceptionIterator, + STREAM, + producer, + StateEmitFrequency(1, Duration.ZERO) + ) + + Assertions.assertEquals(RECORD_MESSAGE_1, iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_2, iterator.next()) + // continues to emit RECORD_MESSAGE_2 since cursorField has not changed thus not satisfying + // the + // condition of "ready" + Assertions.assertEquals(RECORD_MESSAGE_2, iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_3, iterator.next()) + // emits the first state message since the iterator has changed cursorFields (2 -> 3) and + // met the + // frequency minimum of 1 record + Assertions.assertEquals(createStateMessage(RECORD_VALUE_2, 2, 4.0), iterator.next()) + // no further records to read since Exception was caught above and marked iterator as + // endOfData() + Assertions.assertThrows(FailedRecordIteratorException::class.java) { iterator.hasNext() } + } + + @Test + fun testIteratorCatchesExceptionWhenEmissionFrequencyZero() { + val exceptionIterator = createExceptionIterator() + + val producer = CursorStateMessageProducer(stateManager, Optional.of(RECORD_VALUE_1)) + + val iterator: SourceStateIterator<*> = + SourceStateIterator( + exceptionIterator, + STREAM, + producer, + StateEmitFrequency(0, Duration.ZERO) + ) + + Assertions.assertEquals(RECORD_MESSAGE_1, iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_2, iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_2, iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_3, iterator.next()) + + Assertions.assertThrows(RuntimeException::class.java) { iterator.hasNext() } + } + + @Test + fun testEmptyStream() { + val producer = CursorStateMessageProducer(stateManager, Optional.empty()) + + val iterator: SourceStateIterator<*> = + SourceStateIterator( + Collections.emptyIterator(), + STREAM, + producer, + StateEmitFrequency(1, Duration.ZERO) + ) + + Assertions.assertEquals(EMPTY_STATE_MESSAGE, iterator.next()) + Assertions.assertFalse(iterator.hasNext()) + } + + @Test + fun testUnicodeNull() { + val recordValueWithNull = "abc\u0000" + val recordMessageWithNull = createRecordMessage(recordValueWithNull) + + // UTF8 null \u0000 is removed from the cursor value in the state message + messageIterator = MoreIterators.of(recordMessageWithNull) + + val producer = CursorStateMessageProducer(stateManager, Optional.empty()) + + val iterator: SourceStateIterator<*> = + SourceStateIterator( + messageIterator, + STREAM, + producer, + StateEmitFrequency(0, Duration.ZERO) + ) + + Assertions.assertEquals(recordMessageWithNull, iterator.next()) + Assertions.assertEquals(createStateMessage(RECORD_VALUE_1, 1, 1.0), iterator.next()) + Assertions.assertFalse(iterator.hasNext()) + } + + @Test + fun testStateEmissionFrequency1() { + 
messageIterator = + MoreIterators.of( + RECORD_MESSAGE_1, + RECORD_MESSAGE_2, + RECORD_MESSAGE_3, + RECORD_MESSAGE_4, + RECORD_MESSAGE_5 + ) + + val producer = CursorStateMessageProducer(stateManager, Optional.empty()) + + val iterator: SourceStateIterator<*> = + SourceStateIterator( + messageIterator, + STREAM, + producer, + StateEmitFrequency(1, Duration.ZERO) + ) + + Assertions.assertEquals(RECORD_MESSAGE_1, iterator.next()) + // should emit state 1, but it is unclear whether there will be more + // records with the same cursor value, so no state is ready for emission + Assertions.assertEquals(RECORD_MESSAGE_2, iterator.next()) + // emit state 1 because it is the latest state ready for emission + Assertions.assertEquals(createStateMessage(RECORD_VALUE_1, 1, 2.0), iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_3, iterator.next()) + Assertions.assertEquals(createStateMessage(RECORD_VALUE_2, 1, 1.0), iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_4, iterator.next()) + Assertions.assertEquals(createStateMessage(RECORD_VALUE_3, 1, 1.0), iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_5, iterator.next()) + // state 4 is not emitted because there is no more record and only + // the final state should be emitted at this point; also the final + // state should only be emitted once + Assertions.assertEquals(createStateMessage(RECORD_VALUE_5, 1, 1.0), iterator.next()) + Assertions.assertFalse(iterator.hasNext()) + } + + @Test + fun testStateEmissionFrequency2() { + messageIterator = + MoreIterators.of( + RECORD_MESSAGE_1, + RECORD_MESSAGE_2, + RECORD_MESSAGE_3, + RECORD_MESSAGE_4, + RECORD_MESSAGE_5 + ) + + val producer = CursorStateMessageProducer(stateManager, Optional.empty()) + + val iterator: SourceStateIterator<*> = + SourceStateIterator( + messageIterator, + STREAM, + producer, + StateEmitFrequency(2, Duration.ZERO) + ) + + Assertions.assertEquals(RECORD_MESSAGE_1, iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_2, iterator.next()) + // emit state 1 because it is the latest state ready for emission + Assertions.assertEquals(createStateMessage(RECORD_VALUE_1, 1, 2.0), iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_3, iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_4, iterator.next()) + // emit state 3 because it is the latest state ready for emission + Assertions.assertEquals(createStateMessage(RECORD_VALUE_3, 1, 2.0), iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_5, iterator.next()) + Assertions.assertEquals(createStateMessage(RECORD_VALUE_5, 1, 1.0), iterator.next()) + Assertions.assertFalse(iterator.hasNext()) + } + + @Test + fun testStateEmissionWhenInitialCursorIsNotNull() { + messageIterator = + MoreIterators.of(RECORD_MESSAGE_2, RECORD_MESSAGE_3, RECORD_MESSAGE_4, RECORD_MESSAGE_5) + + val producer = CursorStateMessageProducer(stateManager, Optional.of(RECORD_VALUE_1)) + + val iterator: SourceStateIterator<*> = + SourceStateIterator( + messageIterator, + STREAM, + producer, + StateEmitFrequency(1, Duration.ZERO) + ) + + Assertions.assertEquals(RECORD_MESSAGE_2, iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_3, iterator.next()) + Assertions.assertEquals(createStateMessage(RECORD_VALUE_2, 1, 2.0), iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_4, iterator.next()) + Assertions.assertEquals(createStateMessage(RECORD_VALUE_3, 1, 1.0), iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_5, iterator.next()) + Assertions.assertEquals(createStateMessage(RECORD_VALUE_5, 1, 1.0), 
iterator.next()) + Assertions.assertFalse(iterator.hasNext()) + } + + /** + * Incremental syncs will sort the table with the cursor field, and emit the max cursor for + * every N records. The purpose is to emit the states frequently, so that if any transient + * failure occurs during a long sync, the next run does not need to start from the beginning, + * but can resume from the last successful intermediate state committed on the destination. The + * next run will start with `cursorField > cursor`. However, it is possible that there are + * multiple records with the same cursor value. If the intermediate state is emitted before all + * these records have been synced to the destination, some of these records may be lost. + * + * Here is an example: + * + *

+     * | Record ID | Cursor Field | Other Field | Note                          |
+     * | --------- | ------------ | ----------- | ----------------------------- |
+     * | 1         | F1=16        | F2="abc"    |                               |
+     * | 2         | F1=16        | F2="def"    | <- state emission and failure |
+     * | 3         | F1=16        | F2="ghi"    |                               |
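+     *
+     * (A sketch of the guard this implies, using hypothetical names rather than the actual
+     * implementation: only emit the state for cursor value `c` after observing a record whose
+     * cursor value is strictly greater than `c`, i.e.
+     * `if (currentCursorValue > pendingStateCursor) emitStateFor(pendingStateCursor)`.)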
* + * + * If the intermediate state is emitted for record 2 and the sync fails immediately such that + * the cursor value `16` is committed, but only record 1 and 2 are actually synced, the next run + * will start with `F1 > 16` and skip record 3. + * + * So intermediate state emission should only happen when all records with the same cursor value + * has been synced to destination. Reference: + * [link](https://github.com/airbytehq/airbyte/issues/15427) + */ + @Test + fun testStateEmissionForRecordsSharingSameCursorValue() { + messageIterator = + MoreIterators.of( + RECORD_MESSAGE_2, + RECORD_MESSAGE_2, + RECORD_MESSAGE_3, + RECORD_MESSAGE_3, + RECORD_MESSAGE_3, + RECORD_MESSAGE_4, + RECORD_MESSAGE_5, + RECORD_MESSAGE_5 + ) + + val producer = CursorStateMessageProducer(stateManager, Optional.of(RECORD_VALUE_1)) + + val iterator: SourceStateIterator<*> = + SourceStateIterator( + messageIterator, + STREAM, + producer, + StateEmitFrequency(1, Duration.ZERO) + ) + + Assertions.assertEquals(RECORD_MESSAGE_2, iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_2, iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_3, iterator.next()) + // state 2 is the latest state ready for emission because + // all records with the same cursor value have been emitted + Assertions.assertEquals(createStateMessage(RECORD_VALUE_2, 2, 3.0), iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_3, iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_3, iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_4, iterator.next()) + Assertions.assertEquals(createStateMessage(RECORD_VALUE_3, 3, 3.0), iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_5, iterator.next()) + Assertions.assertEquals(createStateMessage(RECORD_VALUE_4, 1, 1.0), iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_5, iterator.next()) + Assertions.assertEquals(createStateMessage(RECORD_VALUE_5, 2, 1.0), iterator.next()) + Assertions.assertFalse(iterator.hasNext()) + } + + @Test + fun testStateEmissionForRecordsSharingSameCursorValueButDifferentStatsCount() { + messageIterator = + MoreIterators.of( + RECORD_MESSAGE_2, + RECORD_MESSAGE_2, + RECORD_MESSAGE_2, + RECORD_MESSAGE_2, + RECORD_MESSAGE_3, + RECORD_MESSAGE_3, + RECORD_MESSAGE_3, + RECORD_MESSAGE_3, + RECORD_MESSAGE_3, + RECORD_MESSAGE_3, + RECORD_MESSAGE_3 + ) + + val producer = CursorStateMessageProducer(stateManager, Optional.of(RECORD_VALUE_1)) + + val iterator: SourceStateIterator<*> = + SourceStateIterator( + messageIterator, + STREAM, + producer, + StateEmitFrequency(10, Duration.ZERO) + ) + + Assertions.assertEquals(RECORD_MESSAGE_2, iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_2, iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_2, iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_2, iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_3, iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_3, iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_3, iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_3, iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_3, iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_3, iterator.next()) + // state 2 is the latest state ready for emission because + // all records with the same cursor value have been emitted + Assertions.assertEquals(createStateMessage(RECORD_VALUE_2, 4, 10.0), iterator.next()) + Assertions.assertEquals(RECORD_MESSAGE_3, iterator.next()) + Assertions.assertEquals(createStateMessage(RECORD_VALUE_3, 7, 1.0), iterator.next()) + 
Assertions.assertFalse(iterator.hasNext()) + } + + companion object { + private const val NAMESPACE = "public" + private const val STREAM_NAME = "shoes" + private const val UUID_FIELD_NAME = "ascending_inventory_uuid" + + private val STREAM: ConfiguredAirbyteStream = + CatalogHelpers.createConfiguredAirbyteStream( + STREAM_NAME, + NAMESPACE, + Field.of(UUID_FIELD_NAME, JsonSchemaType.STRING) + ) + .withCursorField(List.of(UUID_FIELD_NAME)) + + private val EMPTY_STATE_MESSAGE = createEmptyStateMessage(0.0) + + private const val RECORD_VALUE_1 = "abc" + private val RECORD_MESSAGE_1 = createRecordMessage(RECORD_VALUE_1) + + private const val RECORD_VALUE_2 = "def" + private val RECORD_MESSAGE_2 = createRecordMessage(RECORD_VALUE_2) + + private const val RECORD_VALUE_3 = "ghi" + private val RECORD_MESSAGE_3 = createRecordMessage(RECORD_VALUE_3) + + private const val RECORD_VALUE_4 = "jkl" + private val RECORD_MESSAGE_4 = createRecordMessage(RECORD_VALUE_4) + + private const val RECORD_VALUE_5 = "xyz" + private val RECORD_MESSAGE_5 = createRecordMessage(RECORD_VALUE_5) + + private fun createRecordMessage(recordValue: String): AirbyteMessage { + return AirbyteMessage() + .withType(AirbyteMessage.Type.RECORD) + .withRecord( + AirbyteRecordMessage() + .withData(Jsons.jsonNode(ImmutableMap.of(UUID_FIELD_NAME, recordValue))) + ) + } + + private fun createStateMessage( + recordValue: String, + cursorRecordCount: Long, + statsRecordCount: Double + ): AirbyteMessage { + val dbStreamState = + DbStreamState() + .withCursorField(listOf(UUID_FIELD_NAME)) + .withCursor(recordValue) + .withStreamName(STREAM_NAME) + .withStreamNamespace(NAMESPACE) + if (cursorRecordCount > 0) { + dbStreamState.withCursorRecordCount(cursorRecordCount) + } + val dbState = DbState().withCdc(false).withStreams(listOf(dbStreamState)) + return AirbyteMessage() + .withType(AirbyteMessage.Type.STATE) + .withState( + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.STREAM) + .withStream( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor() + .withName(STREAM_NAME) + .withNamespace(NAMESPACE) + ) + .withStreamState(Jsons.jsonNode(dbStreamState)) + ) + .withData(Jsons.jsonNode(dbState)) + .withSourceStats(AirbyteStateStats().withRecordCount(statsRecordCount)) + ) + } + + private fun createEmptyStateMessage(statsRecordCount: Double): AirbyteMessage { + val dbStreamState = + DbStreamState() + .withCursorField(listOf(UUID_FIELD_NAME)) + .withStreamName(STREAM_NAME) + .withStreamNamespace(NAMESPACE) + + val dbState = DbState().withCdc(false).withStreams(listOf(dbStreamState)) + return AirbyteMessage() + .withType(AirbyteMessage.Type.STATE) + .withState( + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.STREAM) + .withStream( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor() + .withName(STREAM_NAME) + .withNamespace(NAMESPACE) + ) + .withStreamState(Jsons.jsonNode(dbStreamState)) + ) + .withData(Jsons.jsonNode(dbState)) + .withSourceStats(AirbyteStateStats().withRecordCount(statsRecordCount)) + ) + } + + private lateinit var messageIterator: Iterator + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/GlobalStateManagerTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/GlobalStateManagerTest.kt new file mode 100644 index 000000000000..ec7521360f37 --- /dev/null +++ 
b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/GlobalStateManagerTest.kt @@ -0,0 +1,408 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.source.relationaldb.state + +import io.airbyte.cdk.integrations.source.relationaldb.models.CdcState +import io.airbyte.cdk.integrations.source.relationaldb.models.DbState +import io.airbyte.cdk.integrations.source.relationaldb.models.DbStreamState +import io.airbyte.commons.json.Jsons +import io.airbyte.protocol.models.v0.* +import java.util.* +import java.util.List +import java.util.Map +import java.util.stream.Collectors +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Disabled +import org.junit.jupiter.api.Test +import org.mockito.Mockito + +/** Test suite for the [GlobalStateManager] class. */ +class GlobalStateManagerTest { + @Test + fun testCdcStateManager() { + val catalog = Mockito.mock(ConfiguredAirbyteCatalog::class.java) + val cdcState = CdcState().withState(Jsons.jsonNode(Map.of("foo", "bar", "baz", 5))) + val globalState = + AirbyteGlobalState() + .withSharedState(Jsons.jsonNode(cdcState)) + .withStreamStates( + List.of( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor().withNamespace("namespace").withName("name") + ) + .withStreamState(Jsons.jsonNode(DbStreamState())) + ) + ) + val stateManager: StateManager = + GlobalStateManager( + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.GLOBAL) + .withGlobal(globalState), + catalog + ) + Assertions.assertNotNull(stateManager.cdcStateManager) + Assertions.assertEquals(cdcState, stateManager.cdcStateManager.cdcState) + Assertions.assertEquals(1, stateManager.cdcStateManager.initialStreamsSynced!!.size) + Assertions.assertTrue( + stateManager.cdcStateManager.initialStreamsSynced!!.contains( + AirbyteStreamNameNamespacePair("name", "namespace") + ) + ) + } + + @Test + fun testToStateFromLegacyState() { + val catalog = + ConfiguredAirbyteCatalog() + .withStreams( + List.of( + ConfiguredAirbyteStream() + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME1) + .withNamespace(StateTestConstants.NAMESPACE) + ) + .withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)), + ConfiguredAirbyteStream() + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME2) + .withNamespace(StateTestConstants.NAMESPACE) + ) + .withCursorField(List.of(StateTestConstants.CURSOR_FIELD2)), + ConfiguredAirbyteStream() + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME3) + .withNamespace(StateTestConstants.NAMESPACE) + ) + ) + ) + + val cdcState = CdcState().withState(Jsons.jsonNode(Map.of("foo", "bar", "baz", 5))) + val dbState = + DbState() + .withCdc(true) + .withCdcState(cdcState) + .withStreams( + List.of( + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME1) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)) + .withCursor("a"), + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME2) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField(List.of(StateTestConstants.CURSOR_FIELD2)), + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME3) + .withStreamNamespace(StateTestConstants.NAMESPACE) + ) + .stream() + .sorted(Comparator.comparing { obj: DbStreamState -> obj.streamName }) + .collect(Collectors.toList()) + ) + val stateManager: StateManager = + 
GlobalStateManager(AirbyteStateMessage().withData(Jsons.jsonNode(dbState)), catalog) + + val expectedRecordCount = 19L + val expectedDbState = + DbState() + .withCdc(true) + .withCdcState(cdcState) + .withStreams( + List.of( + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME1) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)) + .withCursor("a") + .withCursorRecordCount(expectedRecordCount), + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME2) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField(List.of(StateTestConstants.CURSOR_FIELD2)), + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME3) + .withStreamNamespace(StateTestConstants.NAMESPACE) + ) + .stream() + .sorted(Comparator.comparing { obj: DbStreamState -> obj.streamName }) + .collect(Collectors.toList()) + ) + + val expectedGlobalState = + AirbyteGlobalState() + .withSharedState(Jsons.jsonNode(cdcState)) + .withStreamStates( + List.of( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor() + .withName(StateTestConstants.STREAM_NAME1) + .withNamespace(StateTestConstants.NAMESPACE) + ) + .withStreamState( + Jsons.jsonNode( + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME1) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField( + List.of(StateTestConstants.CURSOR_FIELD1) + ) + .withCursor("a") + .withCursorRecordCount(expectedRecordCount) + ) + ), + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor() + .withName(StateTestConstants.STREAM_NAME2) + .withNamespace(StateTestConstants.NAMESPACE) + ) + .withStreamState( + Jsons.jsonNode( + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME2) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField( + List.of(StateTestConstants.CURSOR_FIELD2) + ) + ) + ), + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor() + .withName(StateTestConstants.STREAM_NAME3) + .withNamespace(StateTestConstants.NAMESPACE) + ) + .withStreamState( + Jsons.jsonNode( + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME3) + .withStreamNamespace(StateTestConstants.NAMESPACE) + ) + ) + ) + .stream() + .sorted( + Comparator.comparing { o: AirbyteStreamState -> + o.streamDescriptor.name + } + ) + .collect(Collectors.toList()) + ) + val expected = + AirbyteStateMessage() + .withData(Jsons.jsonNode(expectedDbState)) + .withGlobal(expectedGlobalState) + .withType(AirbyteStateMessage.AirbyteStateType.GLOBAL) + + val actualFirstEmission = + stateManager.updateAndEmit( + StateTestConstants.NAME_NAMESPACE_PAIR1, + "a", + expectedRecordCount + ) + Assertions.assertEquals(expected, actualFirstEmission) + } + + // Discovered during CDK migration. 
+ // Failure is: Could not find cursor information for stream: public_cars + @Disabled("Failing test.") + @Test + fun testToState() { + val catalog = + ConfiguredAirbyteCatalog() + .withStreams( + List.of( + ConfiguredAirbyteStream() + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME1) + .withNamespace(StateTestConstants.NAMESPACE) + ) + .withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)), + ConfiguredAirbyteStream() + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME2) + .withNamespace(StateTestConstants.NAMESPACE) + ) + .withCursorField(List.of(StateTestConstants.CURSOR_FIELD2)), + ConfiguredAirbyteStream() + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME3) + .withNamespace(StateTestConstants.NAMESPACE) + ) + ) + ) + + val cdcState = CdcState().withState(Jsons.jsonNode(Map.of("foo", "bar", "baz", 5))) + val globalState = + AirbyteGlobalState() + .withSharedState(Jsons.jsonNode(DbState())) + .withStreamStates( + List.of( + AirbyteStreamState() + .withStreamDescriptor(StreamDescriptor()) + .withStreamState(Jsons.jsonNode(DbStreamState())) + ) + ) + val stateManager: StateManager = + GlobalStateManager( + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.GLOBAL) + .withGlobal(globalState), + catalog + ) + stateManager.cdcStateManager.cdcState = cdcState + + val expectedDbState = + DbState() + .withCdc(true) + .withCdcState(cdcState) + .withStreams( + List.of( + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME1) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)) + .withCursor("a") + .withCursorRecordCount(1L), + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME2) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField(List.of(StateTestConstants.CURSOR_FIELD2)), + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME3) + .withStreamNamespace(StateTestConstants.NAMESPACE) + ) + .stream() + .sorted(Comparator.comparing { obj: DbStreamState -> obj.streamName }) + .collect(Collectors.toList()) + ) + + val expectedGlobalState = + AirbyteGlobalState() + .withSharedState(Jsons.jsonNode(cdcState)) + .withStreamStates( + List.of( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor() + .withName(StateTestConstants.STREAM_NAME1) + .withNamespace(StateTestConstants.NAMESPACE) + ) + .withStreamState( + Jsons.jsonNode( + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME1) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField( + List.of(StateTestConstants.CURSOR_FIELD1) + ) + .withCursor("a") + .withCursorRecordCount(1L) + ) + ), + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor() + .withName(StateTestConstants.STREAM_NAME2) + .withNamespace(StateTestConstants.NAMESPACE) + ) + .withStreamState( + Jsons.jsonNode( + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME2) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField( + List.of(StateTestConstants.CURSOR_FIELD2) + ) + ) + ), + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor() + .withName(StateTestConstants.STREAM_NAME3) + .withNamespace(StateTestConstants.NAMESPACE) + ) + .withStreamState( + Jsons.jsonNode( + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME3) + .withStreamNamespace(StateTestConstants.NAMESPACE) + ) + ) + ) + .stream() + .sorted( + Comparator.comparing { o: AirbyteStreamState -> + 
o.streamDescriptor.name + } + ) + .collect(Collectors.toList()) + ) + val expected = + AirbyteStateMessage() + .withData(Jsons.jsonNode(expectedDbState)) + .withGlobal(expectedGlobalState) + .withType(AirbyteStateMessage.AirbyteStateType.GLOBAL) + + val actualFirstEmission = + stateManager.updateAndEmit(StateTestConstants.NAME_NAMESPACE_PAIR1, "a", 1L) + Assertions.assertEquals(expected, actualFirstEmission) + } + + @Test + fun testToStateWithNoState() { + val catalog = ConfiguredAirbyteCatalog() + val stateManager: StateManager = GlobalStateManager(AirbyteStateMessage(), catalog) + + val airbyteStateMessage = stateManager.toState(Optional.empty()) + Assertions.assertNotNull(airbyteStateMessage) + Assertions.assertEquals( + AirbyteStateMessage.AirbyteStateType.GLOBAL, + airbyteStateMessage!!.type + ) + Assertions.assertEquals(0, airbyteStateMessage.global.streamStates.size) + } + + @Test + fun testCdcStateManagerLegacyState() { + val catalog = Mockito.mock(ConfiguredAirbyteCatalog::class.java) + val cdcState = CdcState().withState(Jsons.jsonNode(Map.of("foo", "bar", "baz", 5))) + val dbState = + DbState() + .withCdcState(CdcState().withState(Jsons.jsonNode(cdcState))) + .withStreams( + List.of( + DbStreamState() + .withStreamName("name") + .withStreamNamespace("namespace") + .withCursor("") + .withCursorField(emptyList()) + ) + ) + .withCdc(true) + val stateManager: StateManager = + GlobalStateManager( + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.LEGACY) + .withData(Jsons.jsonNode(dbState)), + catalog + ) + Assertions.assertNotNull(stateManager.cdcStateManager) + Assertions.assertEquals(1, stateManager.cdcStateManager.initialStreamsSynced!!.size) + Assertions.assertTrue( + stateManager.cdcStateManager.initialStreamsSynced!!.contains( + AirbyteStreamNameNamespacePair("name", "namespace") + ) + ) + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/LegacyStateManagerTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/LegacyStateManagerTest.kt new file mode 100644 index 000000000000..b6a585713b95 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/LegacyStateManagerTest.kt @@ -0,0 +1,384 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.source.relationaldb.state + +import io.airbyte.cdk.integrations.source.relationaldb.models.CdcState +import io.airbyte.cdk.integrations.source.relationaldb.models.DbState +import io.airbyte.cdk.integrations.source.relationaldb.models.DbStreamState +import io.airbyte.commons.json.Jsons +import io.airbyte.protocol.models.v0.AirbyteStateMessage +import io.airbyte.protocol.models.v0.AirbyteStream +import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog +import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream +import java.util.* +import java.util.List +import java.util.Map +import java.util.stream.Collectors +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test +import org.mockito.Mockito + +/** Test suite for the [LegacyStateManager] class. 
*/ +class LegacyStateManagerTest { + @Test + fun testGetters() { + val state = + DbState() + .withStreams( + List.of( + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME1) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)) + .withCursor(StateTestConstants.CURSOR), + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME2) + .withStreamNamespace(StateTestConstants.NAMESPACE) + ) + ) + + val catalog = + ConfiguredAirbyteCatalog() + .withStreams( + List.of( + ConfiguredAirbyteStream() + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME1) + .withNamespace(StateTestConstants.NAMESPACE) + ) + .withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)), + ConfiguredAirbyteStream() + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME2) + .withNamespace(StateTestConstants.NAMESPACE) + ) + ) + ) + + val stateManager: StateManager = LegacyStateManager(state, catalog) + + Assertions.assertEquals( + Optional.of(StateTestConstants.CURSOR_FIELD1), + stateManager.getOriginalCursorField(StateTestConstants.NAME_NAMESPACE_PAIR1) + ) + Assertions.assertEquals( + Optional.of(StateTestConstants.CURSOR), + stateManager.getOriginalCursor(StateTestConstants.NAME_NAMESPACE_PAIR1) + ) + Assertions.assertEquals( + Optional.of(StateTestConstants.CURSOR_FIELD1), + stateManager.getCursorField(StateTestConstants.NAME_NAMESPACE_PAIR1) + ) + Assertions.assertEquals( + Optional.of(StateTestConstants.CURSOR), + stateManager.getCursor(StateTestConstants.NAME_NAMESPACE_PAIR1) + ) + + Assertions.assertEquals( + Optional.empty(), + stateManager.getOriginalCursorField(StateTestConstants.NAME_NAMESPACE_PAIR2) + ) + Assertions.assertEquals( + Optional.empty(), + stateManager.getOriginalCursor(StateTestConstants.NAME_NAMESPACE_PAIR2) + ) + Assertions.assertEquals( + Optional.empty(), + stateManager.getCursorField(StateTestConstants.NAME_NAMESPACE_PAIR2) + ) + Assertions.assertEquals( + Optional.empty(), + stateManager.getCursor(StateTestConstants.NAME_NAMESPACE_PAIR2) + ) + } + + @Test + fun testToState() { + val catalog = + ConfiguredAirbyteCatalog() + .withStreams( + List.of( + ConfiguredAirbyteStream() + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME1) + .withNamespace(StateTestConstants.NAMESPACE) + ) + .withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)), + ConfiguredAirbyteStream() + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME2) + .withNamespace(StateTestConstants.NAMESPACE) + ) + .withCursorField(List.of(StateTestConstants.CURSOR_FIELD2)), + ConfiguredAirbyteStream() + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME3) + .withNamespace(StateTestConstants.NAMESPACE) + ) + ) + ) + + val stateManager: StateManager = LegacyStateManager(DbState(), catalog) + + val expectedFirstEmission = + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.LEGACY) + .withData( + Jsons.jsonNode( + DbState() + .withStreams( + List.of( + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME1) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField( + List.of(StateTestConstants.CURSOR_FIELD1) + ) + .withCursor("a"), + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME2) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField( + List.of(StateTestConstants.CURSOR_FIELD2) + ), + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME3) + 
.withStreamNamespace(StateTestConstants.NAMESPACE) + ) + .stream() + .sorted( + Comparator.comparing { obj: DbStreamState -> + obj.streamName + } + ) + .collect(Collectors.toList()) + ) + .withCdc(false) + ) + ) + val actualFirstEmission = + stateManager.updateAndEmit(StateTestConstants.NAME_NAMESPACE_PAIR1, "a") + Assertions.assertEquals(expectedFirstEmission, actualFirstEmission) + val expectedSecondEmission = + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.LEGACY) + .withData( + Jsons.jsonNode( + DbState() + .withStreams( + List.of( + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME1) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField( + List.of(StateTestConstants.CURSOR_FIELD1) + ) + .withCursor("a"), + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME2) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField( + List.of(StateTestConstants.CURSOR_FIELD2) + ) + .withCursor("b"), + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME3) + .withStreamNamespace(StateTestConstants.NAMESPACE) + ) + .stream() + .sorted( + Comparator.comparing { obj: DbStreamState -> + obj.streamName + } + ) + .collect(Collectors.toList()) + ) + .withCdc(false) + ) + ) + val actualSecondEmission = + stateManager.updateAndEmit(StateTestConstants.NAME_NAMESPACE_PAIR2, "b") + Assertions.assertEquals(expectedSecondEmission, actualSecondEmission) + } + + @Test + fun testToStateNullCursorField() { + val catalog = + ConfiguredAirbyteCatalog() + .withStreams( + List.of( + ConfiguredAirbyteStream() + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME1) + .withNamespace(StateTestConstants.NAMESPACE) + ) + .withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)), + ConfiguredAirbyteStream() + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME2) + .withNamespace(StateTestConstants.NAMESPACE) + ) + ) + ) + val stateManager: StateManager = LegacyStateManager(DbState(), catalog) + + val expectedFirstEmission = + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.LEGACY) + .withData( + Jsons.jsonNode( + DbState() + .withStreams( + List.of( + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME1) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField( + List.of(StateTestConstants.CURSOR_FIELD1) + ) + .withCursor("a"), + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME2) + .withStreamNamespace(StateTestConstants.NAMESPACE) + ) + .stream() + .sorted( + Comparator.comparing { obj: DbStreamState -> + obj.streamName + } + ) + .collect(Collectors.toList()) + ) + .withCdc(false) + ) + ) + + val actualFirstEmission = + stateManager.updateAndEmit(StateTestConstants.NAME_NAMESPACE_PAIR1, "a") + Assertions.assertEquals(expectedFirstEmission, actualFirstEmission) + } + + @Test + fun testCursorNotUpdatedForCdc() { + val catalog = + ConfiguredAirbyteCatalog() + .withStreams( + List.of( + ConfiguredAirbyteStream() + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME1) + .withNamespace(StateTestConstants.NAMESPACE) + ) + .withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)), + ConfiguredAirbyteStream() + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME2) + .withNamespace(StateTestConstants.NAMESPACE) + ) + ) + ) + + val state = DbState() + state.cdc = true + val stateManager: StateManager = LegacyStateManager(state, catalog) + + val expectedFirstEmission = + 
AirbyteStateMessage()
+                .withType(AirbyteStateMessage.AirbyteStateType.LEGACY)
+                .withData(
+                    Jsons.jsonNode(
+                        DbState()
+                            .withStreams(
+                                List.of(
+                                        DbStreamState()
+                                            .withStreamName(StateTestConstants.STREAM_NAME1)
+                                            .withStreamNamespace(StateTestConstants.NAMESPACE)
+                                            .withCursorField(
+                                                List.of(StateTestConstants.CURSOR_FIELD1)
+                                            )
+                                            .withCursor(null),
+                                        DbStreamState()
+                                            .withStreamName(StateTestConstants.STREAM_NAME2)
+                                            .withStreamNamespace(StateTestConstants.NAMESPACE)
+                                            .withCursorField(listOf())
+                                    )
+                                    .stream()
+                                    .sorted(
+                                        Comparator.comparing { obj: DbStreamState ->
+                                            obj.streamName
+                                        }
+                                    )
+                                    .collect(Collectors.toList())
+                            )
+                            .withCdc(true)
+                    )
+                )
+        val actualFirstEmission =
+            stateManager.updateAndEmit(StateTestConstants.NAME_NAMESPACE_PAIR1, "a")
+        Assertions.assertEquals(expectedFirstEmission, actualFirstEmission)
+        val expectedSecondEmission =
+            AirbyteStateMessage()
+                .withType(AirbyteStateMessage.AirbyteStateType.LEGACY)
+                .withData(
+                    Jsons.jsonNode(
+                        DbState()
+                            .withStreams(
+                                List.of(
+                                        DbStreamState()
+                                            .withStreamName(StateTestConstants.STREAM_NAME1)
+                                            .withStreamNamespace(StateTestConstants.NAMESPACE)
+                                            .withCursorField(
+                                                List.of(StateTestConstants.CURSOR_FIELD1)
+                                            )
+                                            .withCursor(null),
+                                        DbStreamState()
+                                            .withStreamName(StateTestConstants.STREAM_NAME2)
+                                            .withStreamNamespace(StateTestConstants.NAMESPACE)
+                                            .withCursorField(listOf())
+                                            .withCursor(null)
+                                    )
+                                    .stream()
+                                    .sorted(
+                                        Comparator.comparing { obj: DbStreamState ->
+                                            obj.streamName
+                                        }
+                                    )
+                                    .collect(Collectors.toList())
+                            )
+                            .withCdc(true)
+                    )
+                )
+        val actualSecondEmission =
+            stateManager.updateAndEmit(StateTestConstants.NAME_NAMESPACE_PAIR2, "b")
+        Assertions.assertEquals(expectedSecondEmission, actualSecondEmission)
+    }
+
+    @Test
+    fun testCdcStateManager() {
+        val catalog = Mockito.mock(ConfiguredAirbyteCatalog::class.java)
+        val cdcState = CdcState().withState(Jsons.jsonNode(Map.of("foo", "bar", "baz", 5)))
+        val dbState =
+            DbState()
+                .withCdcState(cdcState)
+                .withStreams(
+                    List.of(
+                        DbStreamState()
+                            .withStreamNamespace(StateTestConstants.NAMESPACE)
+                            .withStreamName(StateTestConstants.STREAM_NAME1)
+                    )
+                )
+        val stateManager: StateManager = LegacyStateManager(dbState, catalog)
+        Assertions.assertNotNull(stateManager.cdcStateManager)
+        Assertions.assertEquals(cdcState, stateManager.cdcStateManager.cdcState)
+    }
+}
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateIteratorForTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateIteratorForTest.kt
new file mode 100644
index 000000000000..209743976990
--- /dev/null
+++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateIteratorForTest.kt
@@ -0,0 +1,23 @@
+/*
+ * Copyright (c) 2024 Airbyte, Inc., all rights reserved.
+ */
+
+package io.airbyte.cdk.integrations.source.relationaldb.state
+
+import io.airbyte.protocol.models.v0.AirbyteMessage
+import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream
+
+class SourceStateIteratorForTest<T>(
+    messageIterator: Iterator<T>,
+    stream: ConfiguredAirbyteStream,
+    sourceStateMessageProducer: SourceStateMessageProducer<T>,
+    stateEmitFrequency: StateEmitFrequency
+) :
+    SourceStateIterator<T>(
+        messageIterator,
+        stream,
+        sourceStateMessageProducer,
+        stateEmitFrequency
+    ) {
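+    // Exposes the protected computeNext() as public so tests can drive the iterator directly.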
+    public override fun computeNext(): AirbyteMessage? = super.computeNext()
+}
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateIteratorTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateIteratorTest.kt
new file mode 100644
index 000000000000..fb34ea35822b
--- /dev/null
+++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateIteratorTest.kt
@@ -0,0 +1,103 @@
+/*
+ * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+ */
+package io.airbyte.cdk.integrations.source.relationaldb.state
+
+import io.airbyte.protocol.models.v0.*
+import java.time.Duration
+import org.junit.Assert
+import org.junit.jupiter.api.BeforeEach
+import org.junit.jupiter.api.Test
+import org.mockito.ArgumentMatchers
+import org.mockito.Mockito
+import org.mockito.Mockito.mock
+import org.mockito.kotlin.any
+import org.mockito.kotlin.eq
+
+class SourceStateIteratorTest {
+    lateinit var mockProducer: SourceStateMessageProducer<AirbyteMessage>
+    lateinit var messageIterator: Iterator<AirbyteMessage>
+    lateinit var stream: ConfiguredAirbyteStream
+
+    var sourceStateIterator: SourceStateIteratorForTest<*>? = null
+
+    @BeforeEach
+    fun setup() {
+        mockProducer = mock()
+        stream = mock()
+        messageIterator = mock()
+        val stateEmitFrequency = StateEmitFrequency(1L, Duration.ofSeconds(100L))
+        sourceStateIterator =
+            SourceStateIteratorForTest(messageIterator, stream, mockProducer, stateEmitFrequency)
+    }
+
+    // Generates a record message and verifies that the corresponding mocked producer functions
+    // have been called.
+    fun processRecordMessage() {
+        Mockito.doReturn(true).`when`(messageIterator).hasNext()
+        Mockito.doReturn(false)
+            .`when`(mockProducer)
+            .shouldEmitStateMessage(ArgumentMatchers.eq(stream))
+        val message =
+            AirbyteMessage().withType(AirbyteMessage.Type.RECORD).withRecord(AirbyteRecordMessage())
+        Mockito.doReturn(message).`when`(mockProducer).processRecordMessage(eq(stream), any())
+        Mockito.doReturn(message).`when`(messageIterator).next()
+
+        Assert.assertEquals(message, sourceStateIterator!!.computeNext())
+        Mockito.verify(mockProducer, Mockito.atLeastOnce())
+            .processRecordMessage(eq(stream), eq(message))
+    }
+
+    @Test
+    fun testShouldProcessRecordMessage() {
+        processRecordMessage()
+    }
+
+    @Test
+    fun testShouldEmitStateMessage() {
+        processRecordMessage()
+        Mockito.doReturn(true)
+            .`when`(mockProducer)
+            .shouldEmitStateMessage(ArgumentMatchers.eq(stream))
+        val stateMessage = AirbyteStateMessage()
+        Mockito.doReturn(stateMessage).`when`(mockProducer).generateStateMessageAtCheckpoint(stream)
+        val expectedMessage =
+            AirbyteMessage().withType(AirbyteMessage.Type.STATE).withState(stateMessage)
+        expectedMessage.state.withSourceStats(AirbyteStateStats().withRecordCount(1.0))
+        Assert.assertEquals(expectedMessage, sourceStateIterator!!.computeNext())
+    }
+
+    @Test
+    fun testShouldEmitFinalStateMessage() {
+        processRecordMessage()
+        processRecordMessage()
+        Mockito.doReturn(false).`when`(messageIterator).hasNext()
+        val stateMessage = AirbyteStateMessage()
+        Mockito.doReturn(stateMessage).`when`(mockProducer).createFinalStateMessage(stream)
+        val expectedMessage =
+            AirbyteMessage().withType(AirbyteMessage.Type.STATE).withState(stateMessage)
+        expectedMessage.state.withSourceStats(AirbyteStateStats().withRecordCount(2.0))
+        Assert.assertEquals(expectedMessage, sourceStateIterator!!.computeNext())
+    }
+
+    @Test
+    fun testShouldSendEndOfData() {
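+        // Consume one record, emit the final state, then verify the iterator signals end of data.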
+        processRecordMessage()
+        Mockito.doReturn(false).`when`(messageIterator).hasNext()
+        Mockito.doReturn(AirbyteStateMessage()).`when`(mockProducer).createFinalStateMessage(stream)
+        sourceStateIterator!!.computeNext()
+
+        // After the final state message has been emitted, calling the iterator again returns null.
+        Assert.assertEquals(null, sourceStateIterator!!.computeNext())
+    }
+
+    @Test
+    fun testShouldRethrowExceptions() {
+        processRecordMessage()
+        Mockito.doThrow(ArrayIndexOutOfBoundsException("unexpected error"))
+            .`when`(messageIterator)
+            .hasNext()
+        Assert.assertThrows(RuntimeException::class.java) { sourceStateIterator!!.computeNext() }
+    }
+}
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateGeneratorUtilsTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateGeneratorUtilsTest.kt
new file mode 100644
index 000000000000..e9334ff081f3
--- /dev/null
+++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateGeneratorUtilsTest.kt
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+ */
+package io.airbyte.cdk.integrations.source.relationaldb.state
+
+import io.airbyte.protocol.models.v0.StreamDescriptor
+import org.junit.jupiter.api.Assertions
+import org.junit.jupiter.api.Test
+
+/** Test suite for the [StateGeneratorUtils] class. */
+class StateGeneratorUtilsTest {
+    @Test
+    fun testValidStreamDescriptor() {
+        val streamDescriptor1: StreamDescriptor? = null
+        val streamDescriptor2 = StreamDescriptor()
+        val streamDescriptor3 = StreamDescriptor().withName("name")
+        val streamDescriptor4 = StreamDescriptor().withNamespace("namespace")
+        val streamDescriptor5 = StreamDescriptor().withName("name").withNamespace("namespace")
+        val streamDescriptor6 = StreamDescriptor().withName("name").withNamespace("")
+        val streamDescriptor7 = StreamDescriptor().withName("").withNamespace("namespace")
+        val streamDescriptor8 = StreamDescriptor().withName("").withNamespace("")
+
+        Assertions.assertFalse(StateGeneratorUtils.isValidStreamDescriptor(streamDescriptor1))
+        Assertions.assertFalse(StateGeneratorUtils.isValidStreamDescriptor(streamDescriptor2))
+        Assertions.assertTrue(StateGeneratorUtils.isValidStreamDescriptor(streamDescriptor3))
+        Assertions.assertFalse(StateGeneratorUtils.isValidStreamDescriptor(streamDescriptor4))
+        Assertions.assertTrue(StateGeneratorUtils.isValidStreamDescriptor(streamDescriptor5))
+        Assertions.assertTrue(StateGeneratorUtils.isValidStreamDescriptor(streamDescriptor6))
+        Assertions.assertTrue(StateGeneratorUtils.isValidStreamDescriptor(streamDescriptor7))
+        Assertions.assertTrue(StateGeneratorUtils.isValidStreamDescriptor(streamDescriptor8))
+    }
+}
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateManagerFactoryTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateManagerFactoryTest.kt
new file mode 100644
index 000000000000..ca8c76753b0c
--- /dev/null
+++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateManagerFactoryTest.kt
@@ -0,0 +1,322 @@
+/*
+ * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+ */ +package io.airbyte.cdk.integrations.source.relationaldb.state + +import io.airbyte.cdk.integrations.source.relationaldb.models.CdcState +import io.airbyte.cdk.integrations.source.relationaldb.models.DbState +import io.airbyte.cdk.integrations.source.relationaldb.models.DbStreamState +import io.airbyte.commons.json.Jsons +import io.airbyte.protocol.models.v0.* +import java.util.List +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test +import org.mockito.Mockito + +/** Test suite for the [StateManagerFactory] class. */ +class StateManagerFactoryTest { + @Test + fun testNullOrEmptyState() { + val catalog = Mockito.mock(ConfiguredAirbyteCatalog::class.java) + + Assertions.assertThrows(IllegalArgumentException::class.java) { + StateManagerFactory.createStateManager( + AirbyteStateMessage.AirbyteStateType.GLOBAL, + null, + catalog + ) + } + + Assertions.assertThrows(IllegalArgumentException::class.java) { + StateManagerFactory.createStateManager( + AirbyteStateMessage.AirbyteStateType.GLOBAL, + listOf(), + catalog + ) + } + + Assertions.assertThrows(IllegalArgumentException::class.java) { + StateManagerFactory.createStateManager( + AirbyteStateMessage.AirbyteStateType.LEGACY, + null, + catalog + ) + } + + Assertions.assertThrows(IllegalArgumentException::class.java) { + StateManagerFactory.createStateManager( + AirbyteStateMessage.AirbyteStateType.LEGACY, + listOf(), + catalog + ) + } + + Assertions.assertThrows(IllegalArgumentException::class.java) { + StateManagerFactory.createStateManager( + AirbyteStateMessage.AirbyteStateType.STREAM, + null, + catalog + ) + } + + Assertions.assertThrows(IllegalArgumentException::class.java) { + StateManagerFactory.createStateManager( + AirbyteStateMessage.AirbyteStateType.STREAM, + listOf(), + catalog + ) + } + } + + @Test + fun testLegacyStateManagerCreationFromAirbyteStateMessage() { + val catalog = Mockito.mock(ConfiguredAirbyteCatalog::class.java) + val airbyteStateMessage = Mockito.mock(AirbyteStateMessage::class.java) + Mockito.`when`(airbyteStateMessage.data).thenReturn(Jsons.jsonNode(DbState())) + + val stateManager = + StateManagerFactory.createStateManager( + AirbyteStateMessage.AirbyteStateType.LEGACY, + List.of(airbyteStateMessage), + catalog + ) + + Assertions.assertNotNull(stateManager) + Assertions.assertEquals(LegacyStateManager::class.java, stateManager.javaClass) + } + + @Test + fun testGlobalStateManagerCreation() { + val catalog = Mockito.mock(ConfiguredAirbyteCatalog::class.java) + val globalState = + AirbyteGlobalState() + .withSharedState( + Jsons.jsonNode( + DbState().withCdcState(CdcState().withState(Jsons.jsonNode(DbState()))) + ) + ) + .withStreamStates( + List.of( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor().withNamespace(NAMESPACE).withName(NAME) + ) + .withStreamState(Jsons.jsonNode(DbStreamState())) + ) + ) + val airbyteStateMessage = + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.GLOBAL) + .withGlobal(globalState) + + val stateManager = + StateManagerFactory.createStateManager( + AirbyteStateMessage.AirbyteStateType.GLOBAL, + List.of(airbyteStateMessage), + catalog + ) + + Assertions.assertNotNull(stateManager) + Assertions.assertEquals(GlobalStateManager::class.java, stateManager.javaClass) + } + + @Test + fun testGlobalStateManagerCreationFromLegacyState() { + val catalog = Mockito.mock(ConfiguredAirbyteCatalog::class.java) + val cdcState = CdcState() + val dbState = + DbState() + .withCdcState(cdcState) + .withStreams( + 
List.of(DbStreamState().withStreamName(NAME).withStreamNamespace(NAMESPACE)) + ) + val airbyteStateMessage = + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.LEGACY) + .withData(Jsons.jsonNode(dbState)) + + val stateManager = + StateManagerFactory.createStateManager( + AirbyteStateMessage.AirbyteStateType.GLOBAL, + List.of(airbyteStateMessage), + catalog + ) + + Assertions.assertNotNull(stateManager) + Assertions.assertEquals(GlobalStateManager::class.java, stateManager.javaClass) + } + + @Test + fun testGlobalStateManagerCreationFromStreamState() { + val catalog = Mockito.mock(ConfiguredAirbyteCatalog::class.java) + val airbyteStateMessage = + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.STREAM) + .withStream( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor().withName(NAME).withNamespace(NAMESPACE) + ) + .withStreamState(Jsons.jsonNode(DbStreamState())) + ) + + Assertions.assertThrows(IllegalArgumentException::class.java) { + StateManagerFactory.createStateManager( + AirbyteStateMessage.AirbyteStateType.GLOBAL, + List.of(airbyteStateMessage), + catalog + ) + } + } + + @Test + fun testGlobalStateManagerCreationWithLegacyDataPresent() { + val catalog = Mockito.mock(ConfiguredAirbyteCatalog::class.java) + val globalState = + AirbyteGlobalState() + .withSharedState( + Jsons.jsonNode( + DbState().withCdcState(CdcState().withState(Jsons.jsonNode(DbState()))) + ) + ) + .withStreamStates( + List.of( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor().withNamespace(NAMESPACE).withName(NAME) + ) + .withStreamState(Jsons.jsonNode(DbStreamState())) + ) + ) + val airbyteStateMessage = + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.GLOBAL) + .withGlobal(globalState) + .withData(Jsons.jsonNode(DbState())) + + val stateManager = + StateManagerFactory.createStateManager( + AirbyteStateMessage.AirbyteStateType.GLOBAL, + List.of(airbyteStateMessage), + catalog + ) + + Assertions.assertNotNull(stateManager) + Assertions.assertEquals(GlobalStateManager::class.java, stateManager.javaClass) + } + + @Test + fun testStreamStateManagerCreation() { + val catalog = Mockito.mock(ConfiguredAirbyteCatalog::class.java) + val airbyteStateMessage = + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.STREAM) + .withStream( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor().withName(NAME).withNamespace(NAMESPACE) + ) + .withStreamState(Jsons.jsonNode(DbStreamState())) + ) + + val stateManager = + StateManagerFactory.createStateManager( + AirbyteStateMessage.AirbyteStateType.STREAM, + List.of(airbyteStateMessage), + catalog + ) + + Assertions.assertNotNull(stateManager) + Assertions.assertEquals(StreamStateManager::class.java, stateManager.javaClass) + } + + @Test + fun testStreamStateManagerCreationFromLegacy() { + val catalog = Mockito.mock(ConfiguredAirbyteCatalog::class.java) + val cdcState = CdcState() + val dbState = + DbState() + .withCdcState(cdcState) + .withStreams( + List.of(DbStreamState().withStreamName(NAME).withStreamNamespace(NAMESPACE)) + ) + val airbyteStateMessage = + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.LEGACY) + .withData(Jsons.jsonNode(dbState)) + + val stateManager = + StateManagerFactory.createStateManager( + AirbyteStateMessage.AirbyteStateType.STREAM, + List.of(airbyteStateMessage), + catalog + ) + + Assertions.assertNotNull(stateManager) + Assertions.assertEquals(StreamStateManager::class.java, 
stateManager.javaClass) + } + + @Test + fun testStreamStateManagerCreationFromGlobal() { + val catalog = Mockito.mock(ConfiguredAirbyteCatalog::class.java) + val globalState = + AirbyteGlobalState() + .withSharedState( + Jsons.jsonNode( + DbState().withCdcState(CdcState().withState(Jsons.jsonNode(DbState()))) + ) + ) + .withStreamStates( + List.of( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor().withNamespace(NAMESPACE).withName(NAME) + ) + .withStreamState(Jsons.jsonNode(DbStreamState())) + ) + ) + val airbyteStateMessage = + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.GLOBAL) + .withGlobal(globalState) + + Assertions.assertThrows(IllegalArgumentException::class.java) { + StateManagerFactory.createStateManager( + AirbyteStateMessage.AirbyteStateType.STREAM, + List.of(airbyteStateMessage), + catalog + ) + } + } + + @Test + fun testStreamStateManagerCreationWithLegacyDataPresent() { + val catalog = Mockito.mock(ConfiguredAirbyteCatalog::class.java) + val airbyteStateMessage = + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.STREAM) + .withStream( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor().withName(NAME).withNamespace(NAMESPACE) + ) + .withStreamState(Jsons.jsonNode(DbStreamState())) + ) + .withData(Jsons.jsonNode(DbState())) + + val stateManager = + StateManagerFactory.createStateManager( + AirbyteStateMessage.AirbyteStateType.STREAM, + List.of(airbyteStateMessage), + catalog + ) + + Assertions.assertNotNull(stateManager) + Assertions.assertEquals(StreamStateManager::class.java, stateManager.javaClass) + } + + companion object { + private const val NAMESPACE = "namespace" + private const val NAME = "name" + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateTestConstants.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateTestConstants.kt new file mode 100644 index 000000000000..3ffd9781e760 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateTestConstants.kt @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.source.relationaldb.state + +import io.airbyte.cdk.integrations.source.relationaldb.models.DbStreamState +import io.airbyte.protocol.models.v0.AirbyteStream +import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair +import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog +import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream +import java.util.* +import java.util.List +import org.testcontainers.shaded.com.google.common.collect.Lists + +/** Collection of constants for use in state management-related tests. 
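+ *
+ * The helpers below build [DbStreamState] and [ConfiguredAirbyteStream] fixtures around the
+ * canonical "cars" stream ([STREAM_NAME1]).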
+ */
+object StateTestConstants {
+    const val NAMESPACE: String = "public"
+    const val STREAM_NAME1: String = "cars"
+    val NAME_NAMESPACE_PAIR1: AirbyteStreamNameNamespacePair =
+        AirbyteStreamNameNamespacePair(STREAM_NAME1, NAMESPACE)
+    const val STREAM_NAME2: String = "bicycles"
+    val NAME_NAMESPACE_PAIR2: AirbyteStreamNameNamespacePair =
+        AirbyteStreamNameNamespacePair(STREAM_NAME2, NAMESPACE)
+    const val STREAM_NAME3: String = "stationary_bicycles"
+    const val CURSOR_FIELD1: String = "year"
+    const val CURSOR_FIELD2: String = "generation"
+    const val CURSOR: String = "2000"
+    const val CURSOR_RECORD_COUNT: Long = 19L
+
+    fun getState(cursorField: String?, cursor: String?): Optional<DbStreamState> {
+        return Optional.of(
+            DbStreamState()
+                .withStreamName(STREAM_NAME1)
+                .withCursorField(Lists.newArrayList(cursorField))
+                .withCursor(cursor)
+        )
+    }
+
+    fun getState(
+        cursorField: String?,
+        cursor: String?,
+        cursorRecordCount: Long
+    ): Optional<DbStreamState> {
+        return Optional.of(
+            DbStreamState()
+                .withStreamName(STREAM_NAME1)
+                .withCursorField(Lists.newArrayList(cursorField))
+                .withCursor(cursor)
+                .withCursorRecordCount(cursorRecordCount)
+        )
+    }
+
+    fun getCatalog(cursorField: String?): Optional<ConfiguredAirbyteCatalog> {
+        return Optional.of(
+            ConfiguredAirbyteCatalog().withStreams(List.of(getStream(cursorField).orElse(null)))
+        )
+    }
+
+    fun getStream(cursorField: String?): Optional<ConfiguredAirbyteStream> {
+        return Optional.of(
+            ConfiguredAirbyteStream()
+                .withStream(AirbyteStream().withName(STREAM_NAME1))
+                .withCursorField(
+                    if (cursorField == null) emptyList() else Lists.newArrayList(cursorField)
+                )
+        )
+    }
+}
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StreamStateManagerTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StreamStateManagerTest.kt
new file mode 100644
index 000000000000..6fba4dda3a85
--- /dev/null
+++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StreamStateManagerTest.kt
@@ -0,0 +1,473 @@
+/*
+ * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+ */
+package io.airbyte.cdk.integrations.source.relationaldb.state
+
+import com.google.common.collect.Lists
+import io.airbyte.cdk.integrations.source.relationaldb.models.DbState
+import io.airbyte.cdk.integrations.source.relationaldb.models.DbStreamState
+import io.airbyte.commons.json.Jsons
+import io.airbyte.protocol.models.v0.*
+import java.util.*
+import java.util.stream.Collectors
+import org.junit.jupiter.api.Assertions
+import org.junit.jupiter.api.Test
+import org.mockito.Mockito
+
+/** Test suite for the [StreamStateManager] class.
+ */
+class StreamStateManagerTest {
+    @Test
+    fun testCreationFromInvalidState() {
+        val airbyteStateMessage =
+            AirbyteStateMessage()
+                .withType(AirbyteStateMessage.AirbyteStateType.STREAM)
+                .withStream(
+                    AirbyteStreamState()
+                        .withStreamDescriptor(
+                            StreamDescriptor()
+                                .withName(StateTestConstants.STREAM_NAME1)
+                                .withNamespace(StateTestConstants.NAMESPACE)
+                        )
+                        .withStreamState(Jsons.jsonNode("Not a state object"))
+                )
+        val catalog = Mockito.mock(ConfiguredAirbyteCatalog::class.java)
+
+        Assertions.assertDoesNotThrow {
+            val stateManager: StateManager =
+                StreamStateManager(java.util.List.of(airbyteStateMessage), catalog)
+            Assertions.assertNotNull(stateManager)
+        }
+    }
+
+    @Test
+    fun testGetters() {
+        val state: MutableList<AirbyteStateMessage> = ArrayList()
+        state.add(
+            createStreamState(
+                StateTestConstants.STREAM_NAME1,
+                StateTestConstants.NAMESPACE,
+                java.util.List.of(StateTestConstants.CURSOR_FIELD1),
+                StateTestConstants.CURSOR,
+                0L
+            )
+        )
+        state.add(
+            createStreamState(
+                StateTestConstants.STREAM_NAME2,
+                StateTestConstants.NAMESPACE,
+                listOf(),
+                null,
+                0L
+            )
+        )
+
+        val catalog =
+            ConfiguredAirbyteCatalog()
+                .withStreams(
+                    java.util.List.of(
+                        ConfiguredAirbyteStream()
+                            .withStream(
+                                AirbyteStream()
+                                    .withName(StateTestConstants.STREAM_NAME1)
+                                    .withNamespace(StateTestConstants.NAMESPACE)
+                                    .withSupportedSyncModes(
+                                        Lists.newArrayList(SyncMode.FULL_REFRESH)
+                                    )
+                            )
+                            .withCursorField(java.util.List.of(StateTestConstants.CURSOR_FIELD1)),
+                        ConfiguredAirbyteStream()
+                            .withStream(
+                                AirbyteStream()
+                                    .withName(StateTestConstants.STREAM_NAME2)
+                                    .withNamespace(StateTestConstants.NAMESPACE)
+                                    .withSupportedSyncModes(
+                                        Lists.newArrayList(SyncMode.FULL_REFRESH)
+                                    )
+                            )
+                    )
+                )
+
+        val stateManager: StateManager = StreamStateManager(state, catalog)
+
+        Assertions.assertEquals(
+            Optional.of(StateTestConstants.CURSOR_FIELD1),
+            stateManager.getOriginalCursorField(StateTestConstants.NAME_NAMESPACE_PAIR1)
+        )
+        Assertions.assertEquals(
+            Optional.of(StateTestConstants.CURSOR),
+            stateManager.getOriginalCursor(StateTestConstants.NAME_NAMESPACE_PAIR1)
+        )
+        Assertions.assertEquals(
+            Optional.of(StateTestConstants.CURSOR_FIELD1),
+            stateManager.getCursorField(StateTestConstants.NAME_NAMESPACE_PAIR1)
+        )
+        Assertions.assertEquals(
+            Optional.of(StateTestConstants.CURSOR),
+            stateManager.getCursor(StateTestConstants.NAME_NAMESPACE_PAIR1)
+        )
+
+        Assertions.assertEquals(
+            Optional.empty(),
+            stateManager.getOriginalCursorField(StateTestConstants.NAME_NAMESPACE_PAIR2)
+        )
+        Assertions.assertEquals(
+            Optional.empty(),
+            stateManager.getOriginalCursor(StateTestConstants.NAME_NAMESPACE_PAIR2)
+        )
+        Assertions.assertEquals(
+            Optional.empty(),
+            stateManager.getCursorField(StateTestConstants.NAME_NAMESPACE_PAIR2)
+        )
+        Assertions.assertEquals(
+            Optional.empty(),
+            stateManager.getCursor(StateTestConstants.NAME_NAMESPACE_PAIR2)
+        )
+    }
+
+    @Test
+    fun testToState() {
+        val catalog =
+            ConfiguredAirbyteCatalog()
+                .withStreams(
+                    java.util.List.of(
+                        ConfiguredAirbyteStream()
+                            .withStream(
+                                AirbyteStream()
+                                    .withName(StateTestConstants.STREAM_NAME1)
+                                    .withNamespace(StateTestConstants.NAMESPACE)
+                                    .withSupportedSyncModes(
+                                        Lists.newArrayList(SyncMode.FULL_REFRESH)
+                                    )
+                            )
+                            .withCursorField(java.util.List.of(StateTestConstants.CURSOR_FIELD1)),
+                        ConfiguredAirbyteStream()
+                            .withStream(
+                                AirbyteStream()
+                                    .withName(StateTestConstants.STREAM_NAME2)
+                                    .withNamespace(StateTestConstants.NAMESPACE)
+                                    .withSupportedSyncModes(
+                                        Lists.newArrayList(SyncMode.FULL_REFRESH)
+                                    )
+                            )
.withCursorField(java.util.List.of(StateTestConstants.CURSOR_FIELD2)), + ConfiguredAirbyteStream() + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME3) + .withNamespace(StateTestConstants.NAMESPACE) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH) + ) + ) + ) + ) + + val stateManager: StateManager = StreamStateManager(createDefaultState(), catalog) + + val expectedFirstDbState = + DbState() + .withCdc(false) + .withStreams( + java.util.List.of( + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME1) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField( + java.util.List.of(StateTestConstants.CURSOR_FIELD1) + ) + .withCursor("a"), + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME2) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField( + java.util.List.of(StateTestConstants.CURSOR_FIELD2) + ), + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME3) + .withStreamNamespace(StateTestConstants.NAMESPACE) + ) + .stream() + .sorted(Comparator.comparing { obj: DbStreamState -> obj.streamName }) + .collect(Collectors.toList()) + ) + val expectedFirstEmission = + createStreamState( + StateTestConstants.STREAM_NAME1, + StateTestConstants.NAMESPACE, + java.util.List.of(StateTestConstants.CURSOR_FIELD1), + "a", + 0L + ) + .withData(Jsons.jsonNode(expectedFirstDbState)) + + val actualFirstEmission = + stateManager.updateAndEmit(StateTestConstants.NAME_NAMESPACE_PAIR1, "a") + Assertions.assertEquals(expectedFirstEmission, actualFirstEmission) + + val expectedRecordCount = 17L + val expectedSecondDbState = + DbState() + .withCdc(false) + .withStreams( + java.util.List.of( + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME1) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField( + java.util.List.of(StateTestConstants.CURSOR_FIELD1) + ) + .withCursor("a"), + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME2) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField( + java.util.List.of(StateTestConstants.CURSOR_FIELD2) + ) + .withCursor("b") + .withCursorRecordCount(expectedRecordCount), + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME3) + .withStreamNamespace(StateTestConstants.NAMESPACE) + ) + .stream() + .sorted(Comparator.comparing { obj: DbStreamState -> obj.streamName }) + .collect(Collectors.toList()) + ) + val expectedSecondEmission = + createStreamState( + StateTestConstants.STREAM_NAME2, + StateTestConstants.NAMESPACE, + java.util.List.of(StateTestConstants.CURSOR_FIELD2), + "b", + expectedRecordCount + ) + .withData(Jsons.jsonNode(expectedSecondDbState)) + + val actualSecondEmission = + stateManager.updateAndEmit( + StateTestConstants.NAME_NAMESPACE_PAIR2, + "b", + expectedRecordCount + ) + Assertions.assertEquals(expectedSecondEmission, actualSecondEmission) + } + + @Test + fun testToStateWithoutCursorInfo() { + val catalog = + ConfiguredAirbyteCatalog() + .withStreams( + java.util.List.of( + ConfiguredAirbyteStream() + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME1) + .withNamespace(StateTestConstants.NAMESPACE) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH) + ) + ) + .withCursorField(java.util.List.of(StateTestConstants.CURSOR_FIELD1)), + ConfiguredAirbyteStream() + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME2) + .withNamespace(StateTestConstants.NAMESPACE) + .withSupportedSyncModes( + 
Lists.newArrayList(SyncMode.FULL_REFRESH) + ) + ) + .withCursorField(java.util.List.of(StateTestConstants.CURSOR_FIELD2)), + ConfiguredAirbyteStream() + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME3) + .withNamespace(StateTestConstants.NAMESPACE) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH) + ) + ) + ) + ) + val airbyteStreamNameNamespacePair = AirbyteStreamNameNamespacePair("other", "other") + + val stateManager: StateManager = StreamStateManager(createDefaultState(), catalog) + val airbyteStateMessage = stateManager.toState(Optional.of(airbyteStreamNameNamespacePair)) + Assertions.assertNotNull(airbyteStateMessage) + Assertions.assertEquals( + AirbyteStateMessage.AirbyteStateType.STREAM, + airbyteStateMessage.type + ) + Assertions.assertNotNull(airbyteStateMessage.stream) + } + + @Test + fun testToStateWithoutStreamPair() { + val catalog = + ConfiguredAirbyteCatalog() + .withStreams( + java.util.List.of( + ConfiguredAirbyteStream() + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME1) + .withNamespace(StateTestConstants.NAMESPACE) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH) + ) + ) + .withCursorField(java.util.List.of(StateTestConstants.CURSOR_FIELD1)), + ConfiguredAirbyteStream() + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME2) + .withNamespace(StateTestConstants.NAMESPACE) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH) + ) + ) + .withCursorField(java.util.List.of(StateTestConstants.CURSOR_FIELD2)), + ConfiguredAirbyteStream() + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME3) + .withNamespace(StateTestConstants.NAMESPACE) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH) + ) + ) + ) + ) + + val stateManager: StateManager = StreamStateManager(createDefaultState(), catalog) + val airbyteStateMessage = stateManager.toState(Optional.empty()) + Assertions.assertNotNull(airbyteStateMessage) + Assertions.assertEquals( + AirbyteStateMessage.AirbyteStateType.STREAM, + airbyteStateMessage.type + ) + Assertions.assertNotNull(airbyteStateMessage.stream) + Assertions.assertNull(airbyteStateMessage.stream.streamState) + } + + @Test + fun testToStateNullCursorField() { + val catalog = + ConfiguredAirbyteCatalog() + .withStreams( + java.util.List.of( + ConfiguredAirbyteStream() + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME1) + .withNamespace(StateTestConstants.NAMESPACE) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH) + ) + ) + .withCursorField(java.util.List.of(StateTestConstants.CURSOR_FIELD1)), + ConfiguredAirbyteStream() + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME2) + .withNamespace(StateTestConstants.NAMESPACE) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH) + ) + ) + ) + ) + val stateManager: StateManager = StreamStateManager(createDefaultState(), catalog) + + val expectedFirstDbState = + DbState() + .withCdc(false) + .withStreams( + java.util.List.of( + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME1) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField( + java.util.List.of(StateTestConstants.CURSOR_FIELD1) + ) + .withCursor("a"), + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME2) + .withStreamNamespace(StateTestConstants.NAMESPACE) + ) + .stream() + .sorted(Comparator.comparing { obj: DbStreamState -> obj.streamName }) + 
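+                        // Sorted by stream name so the expected stream list is deterministic.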
+                        .collect(Collectors.toList())
+                )
+
+        val expectedFirstEmission =
+            createStreamState(
+                    StateTestConstants.STREAM_NAME1,
+                    StateTestConstants.NAMESPACE,
+                    java.util.List.of(StateTestConstants.CURSOR_FIELD1),
+                    "a",
+                    0L
+                )
+                .withData(Jsons.jsonNode(expectedFirstDbState))
+        val actualFirstEmission =
+            stateManager.updateAndEmit(StateTestConstants.NAME_NAMESPACE_PAIR1, "a")
+        Assertions.assertEquals(expectedFirstEmission, actualFirstEmission)
+    }
+
+    @Test
+    fun testCdcStateManager() {
+        val catalog = Mockito.mock(ConfiguredAirbyteCatalog::class.java)
+        val stateManager: StateManager =
+            StreamStateManager(
+                java.util.List.of(
+                    AirbyteStateMessage()
+                        .withType(AirbyteStateMessage.AirbyteStateType.STREAM)
+                        .withStream(AirbyteStreamState())
+                ),
+                catalog
+            )
+        Assertions.assertThrows(UnsupportedOperationException::class.java) {
+            stateManager.cdcStateManager
+        }
+    }
+
+    private fun createDefaultState(): List<AirbyteStateMessage> {
+        return java.util.List.of(
+            AirbyteStateMessage()
+                .withType(AirbyteStateMessage.AirbyteStateType.STREAM)
+                .withStream(AirbyteStreamState())
+        )
+    }
+
+    private fun createStreamState(
+        name: String?,
+        namespace: String?,
+        cursorFields: List<String>?,
+        cursorValue: String?,
+        cursorRecordCount: Long
+    ): AirbyteStateMessage {
+        val dbStreamState = DbStreamState().withStreamName(name).withStreamNamespace(namespace)
+
+        if (cursorFields != null && !cursorFields.isEmpty()) {
+            dbStreamState.withCursorField(cursorFields)
+        }
+
+        if (cursorValue != null) {
+            dbStreamState.withCursor(cursorValue)
+        }
+
+        if (cursorRecordCount > 0L) {
+            dbStreamState.withCursorRecordCount(cursorRecordCount)
+        }
+
+        return AirbyteStateMessage()
+            .withType(AirbyteStateMessage.AirbyteStateType.STREAM)
+            .withStream(
+                AirbyteStreamState()
+                    .withStreamDescriptor(
+                        StreamDescriptor().withName(name).withNamespace(namespace)
+                    )
+                    .withStreamState(Jsons.jsonNode(dbStreamState))
+            )
+    }
+}
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/test/utils/DatabaseConnectionHelperTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/test/utils/DatabaseConnectionHelperTest.kt
new file mode 100644
index 000000000000..e0759df75609
--- /dev/null
+++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/test/utils/DatabaseConnectionHelperTest.kt
@@ -0,0 +1,58 @@
+/*
+ * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+ */ +package io.airbyte.cdk.testutils + +import com.zaxxer.hikari.HikariDataSource +import io.airbyte.cdk.testutils.DatabaseConnectionHelper.createDataSource +import io.airbyte.cdk.testutils.DatabaseConnectionHelper.createDslContext +import org.jooq.SQLDialect +import org.junit.jupiter.api.AfterAll +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.BeforeAll +import org.junit.jupiter.api.Test +import org.testcontainers.containers.PostgreSQLContainer + +internal class DatabaseConnectionHelperTest { + @Test + fun testCreatingFromATestContainer() { + val dataSource = createDataSource(container) + Assertions.assertNotNull(dataSource) + Assertions.assertEquals(HikariDataSource::class.java, dataSource!!.javaClass) + Assertions.assertEquals( + 10, + (dataSource as HikariDataSource?)!!.hikariConfigMXBean.maximumPoolSize + ) + } + + @Test + fun testCreatingADslContextFromATestContainer() { + val dialect = SQLDialect.POSTGRES + val dslContext = createDslContext(container, dialect) + Assertions.assertNotNull(dslContext) + Assertions.assertEquals(dialect, dslContext!!.configuration().dialect()) + } + + companion object { + private const val DATABASE_NAME = "airbyte_test_database" + + protected var container: PostgreSQLContainer<*>? = null + + @BeforeAll + @JvmStatic + fun dbSetup() { + container = + PostgreSQLContainer("postgres:13-alpine") + .withDatabaseName(DATABASE_NAME) + .withUsername("docker") + .withPassword("docker") + container!!.start() + } + + @AfterAll + @JvmStatic + fun dbDown() { + container!!.close() + } + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/debezium/CdcSourceTest.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/debezium/CdcSourceTest.java deleted file mode 100644 index 729d774f33a8..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/debezium/CdcSourceTest.java +++ /dev/null @@ -1,864 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
- */ - -package io.airbyte.cdk.integrations.debezium; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.node.ObjectNode; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Lists; -import com.google.common.collect.Sets; -import com.google.common.collect.Streams; -import io.airbyte.cdk.integrations.base.Source; -import io.airbyte.cdk.testutils.TestDatabase; -import io.airbyte.commons.json.Jsons; -import io.airbyte.commons.util.AutoCloseableIterator; -import io.airbyte.commons.util.AutoCloseableIterators; -import io.airbyte.protocol.models.Field; -import io.airbyte.protocol.models.JsonSchemaType; -import io.airbyte.protocol.models.v0.AirbyteCatalog; -import io.airbyte.protocol.models.v0.AirbyteConnectionStatus; -import io.airbyte.protocol.models.v0.AirbyteMessage; -import io.airbyte.protocol.models.v0.AirbyteMessage.Type; -import io.airbyte.protocol.models.v0.AirbyteRecordMessage; -import io.airbyte.protocol.models.v0.AirbyteStateMessage; -import io.airbyte.protocol.models.v0.AirbyteStream; -import io.airbyte.protocol.models.v0.AirbyteStreamState; -import io.airbyte.protocol.models.v0.CatalogHelpers; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream; -import io.airbyte.protocol.models.v0.StreamDescriptor; -import io.airbyte.protocol.models.v0.SyncMode; -import java.util.ArrayList; -import java.util.Collections; -import java.util.Comparator; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.stream.Collectors; -import java.util.stream.Stream; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public abstract class CdcSourceTest> { - - static private final Logger LOGGER = LoggerFactory.getLogger(CdcSourceTest.class); - - static protected final String MODELS_STREAM_NAME = "models"; - static protected final Set STREAM_NAMES = Set.of(MODELS_STREAM_NAME); - static protected final String COL_ID = "id"; - static protected final String COL_MAKE_ID = "make_id"; - static protected final String COL_MODEL = "model"; - - static protected final List MODEL_RECORDS = ImmutableList.of( - Jsons.jsonNode(ImmutableMap.of(COL_ID, 11, COL_MAKE_ID, 1, COL_MODEL, "Fiesta")), - Jsons.jsonNode(ImmutableMap.of(COL_ID, 12, COL_MAKE_ID, 1, COL_MODEL, "Focus")), - Jsons.jsonNode(ImmutableMap.of(COL_ID, 13, COL_MAKE_ID, 1, COL_MODEL, "Ranger")), - Jsons.jsonNode(ImmutableMap.of(COL_ID, 14, COL_MAKE_ID, 2, COL_MODEL, "GLA")), - Jsons.jsonNode(ImmutableMap.of(COL_ID, 15, COL_MAKE_ID, 2, COL_MODEL, "A 220")), - Jsons.jsonNode(ImmutableMap.of(COL_ID, 16, COL_MAKE_ID, 2, COL_MODEL, "E 350"))); - - static protected final String RANDOM_TABLE_NAME = MODELS_STREAM_NAME + "_random"; - - static protected final List MODEL_RECORDS_RANDOM = MODEL_RECORDS.stream() - .map(r -> Jsons.jsonNode(ImmutableMap.of( - COL_ID + "_random", r.get(COL_ID).asInt() * 1000, - COL_MAKE_ID + "_random", r.get(COL_MAKE_ID), - COL_MODEL + "_random", r.get(COL_MODEL).asText() 
+ "-random"))) - .toList(); - - protected T testdb; - - protected String createTableSqlFmt() { - return "CREATE TABLE %s.%s(%s);"; - } - - protected String createSchemaSqlFmt() { - return "CREATE SCHEMA %s;"; - } - - protected String modelsSchema() { - return "models_schema"; - } - - /** - * The schema of a random table which is used as a new table in snapshot test - */ - protected String randomSchema() { - return "models_schema_random"; - } - - protected AirbyteCatalog getCatalog() { - return new AirbyteCatalog().withStreams(List.of( - CatalogHelpers.createAirbyteStream( - MODELS_STREAM_NAME, - modelsSchema(), - Field.of(COL_ID, JsonSchemaType.INTEGER), - Field.of(COL_MAKE_ID, JsonSchemaType.INTEGER), - Field.of(COL_MODEL, JsonSchemaType.STRING)) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL)) - .withSourceDefinedPrimaryKey(List.of(List.of(COL_ID))))); - } - - protected ConfiguredAirbyteCatalog getConfiguredCatalog() { - final var configuredCatalog = CatalogHelpers.toDefaultConfiguredCatalog(getCatalog()); - configuredCatalog.getStreams().forEach(s -> s.setSyncMode(SyncMode.INCREMENTAL)); - return configuredCatalog; - } - - protected abstract T createTestDatabase(); - - protected abstract S source(); - - protected abstract JsonNode config(); - - protected abstract CdcTargetPosition cdcLatestTargetPosition(); - - protected abstract CdcTargetPosition extractPosition(final JsonNode record); - - protected abstract void assertNullCdcMetaData(final JsonNode data); - - protected abstract void assertCdcMetaData(final JsonNode data, final boolean deletedAtNull); - - protected abstract void removeCDCColumns(final ObjectNode data); - - protected abstract void addCdcMetadataColumns(final AirbyteStream stream); - - protected abstract void addCdcDefaultCursorField(final AirbyteStream stream); - - protected abstract void assertExpectedStateMessages(final List stateMessages); - - // TODO: this assertion should be added into test cases in this class, we will need to implement - // corresponding iterator for other connectors before - // doing so. - protected void assertExpectedStateMessageCountMatches(final List stateMessages, long totalCount) { - // Do nothing. - } - - @BeforeEach - protected void setup() { - testdb = createTestDatabase(); - createTables(); - populateTables(); - } - - protected void createTables() { - // create and populate actual table - final var actualColumns = ImmutableMap.of( - COL_ID, "INTEGER", - COL_MAKE_ID, "INTEGER", - COL_MODEL, "VARCHAR(200)"); - testdb - .with(createSchemaSqlFmt(), modelsSchema()) - .with(createTableSqlFmt(), modelsSchema(), MODELS_STREAM_NAME, columnClause(actualColumns, Optional.of(COL_ID))); - - // Create random table. - // This table is not part of Airbyte sync. It is being created just to make sure the schemas not - // being synced by Airbyte are not causing issues with our debezium logic. 
- final var randomColumns = ImmutableMap.of( - COL_ID + "_random", "INTEGER", - COL_MAKE_ID + "_random", "INTEGER", - COL_MODEL + "_random", "VARCHAR(200)"); - if (!randomSchema().equals(modelsSchema())) { - testdb.with(createSchemaSqlFmt(), randomSchema()); - } - testdb.with(createTableSqlFmt(), randomSchema(), RANDOM_TABLE_NAME, columnClause(randomColumns, Optional.of(COL_ID + "_random"))); - } - - protected void populateTables() { - for (final JsonNode recordJson : MODEL_RECORDS) { - writeModelRecord(recordJson); - } - - for (final JsonNode recordJson : MODEL_RECORDS_RANDOM) { - writeRecords(recordJson, randomSchema(), RANDOM_TABLE_NAME, - COL_ID + "_random", COL_MAKE_ID + "_random", COL_MODEL + "_random"); - } - } - - @AfterEach - protected void tearDown() { - try { - testdb.close(); - } catch (Throwable e) { - LOGGER.error("exception during teardown", e); - } - } - - protected String columnClause(final Map columnsWithDataType, final Optional primaryKey) { - final StringBuilder columnClause = new StringBuilder(); - int i = 0; - for (final Map.Entry column : columnsWithDataType.entrySet()) { - columnClause.append(column.getKey()); - columnClause.append(" "); - columnClause.append(column.getValue()); - if (i < (columnsWithDataType.size() - 1)) { - columnClause.append(","); - columnClause.append(" "); - } - i++; - } - primaryKey.ifPresent(s -> columnClause.append(", PRIMARY KEY (").append(s).append(")")); - - return columnClause.toString(); - } - - protected void writeModelRecord(final JsonNode recordJson) { - writeRecords(recordJson, modelsSchema(), MODELS_STREAM_NAME, COL_ID, COL_MAKE_ID, COL_MODEL); - } - - protected void writeRecords( - final JsonNode recordJson, - final String dbName, - final String streamName, - final String idCol, - final String makeIdCol, - final String modelCol) { - testdb.with("INSERT INTO %s.%s (%s, %s, %s) VALUES (%s, %s, '%s');", dbName, streamName, - idCol, makeIdCol, modelCol, - recordJson.get(idCol).asInt(), recordJson.get(makeIdCol).asInt(), - recordJson.get(modelCol).asText()); - } - - protected void deleteMessageOnIdCol(final String streamName, final String idCol, final int idValue) { - testdb.with("DELETE FROM %s.%s WHERE %s = %s", modelsSchema(), streamName, idCol, idValue); - } - - protected void deleteCommand(final String streamName) { - testdb.with("DELETE FROM %s.%s", modelsSchema(), streamName); - } - - protected void updateCommand(final String streamName, final String modelCol, final String modelVal, final String idCol, final int idValue) { - testdb.with("UPDATE %s.%s SET %s = '%s' WHERE %s = %s", modelsSchema(), streamName, - modelCol, modelVal, COL_ID, 11); - } - - static protected Set removeDuplicates(final Set messages) { - final Set existingDataRecordsWithoutUpdated = new HashSet<>(); - final Set output = new HashSet<>(); - - for (final AirbyteRecordMessage message : messages) { - final ObjectNode node = message.getData().deepCopy(); - node.remove("_ab_cdc_updated_at"); - - if (existingDataRecordsWithoutUpdated.contains(node)) { - LOGGER.info("Removing duplicate node: " + node); - } else { - output.add(message); - existingDataRecordsWithoutUpdated.add(node); - } - } - - return output; - } - - protected Set extractRecordMessages(final List messages) { - final Map> recordsPerStream = extractRecordMessagesStreamWise(messages); - final Set consolidatedRecords = new HashSet<>(); - recordsPerStream.values().forEach(consolidatedRecords::addAll); - return consolidatedRecords; - } - - protected Map> extractRecordMessagesStreamWise(final List 
messages) { - final Map> recordsPerStream = new HashMap<>(); - for (final AirbyteMessage message : messages) { - if (message.getType() == Type.RECORD) { - AirbyteRecordMessage recordMessage = message.getRecord(); - recordsPerStream.computeIfAbsent(recordMessage.getStream(), (c) -> new ArrayList<>()).add(recordMessage); - } - } - - final Map> recordsPerStreamWithNoDuplicates = new HashMap<>(); - for (final Map.Entry> element : recordsPerStream.entrySet()) { - final String streamName = element.getKey(); - final List records = element.getValue(); - final Set recordMessageSet = new HashSet<>(records); - assertEquals(records.size(), recordMessageSet.size(), - "Expected no duplicates in airbyte record message output for a single sync."); - recordsPerStreamWithNoDuplicates.put(streamName, recordMessageSet); - } - - return recordsPerStreamWithNoDuplicates; - } - - protected List extractStateMessages(final List messages) { - return messages.stream().filter(r -> r.getType() == Type.STATE).map(AirbyteMessage::getState) - .collect(Collectors.toList()); - } - - protected void assertExpectedRecords(final Set expectedRecords, final Set actualRecords) { - // assume all streams are cdc. - assertExpectedRecords(expectedRecords, actualRecords, actualRecords.stream().map(AirbyteRecordMessage::getStream).collect(Collectors.toSet())); - } - - private void assertExpectedRecords(final Set expectedRecords, - final Set actualRecords, - final Set cdcStreams) { - assertExpectedRecords(expectedRecords, actualRecords, cdcStreams, STREAM_NAMES, modelsSchema()); - } - - protected void assertExpectedRecords(final Set expectedRecords, - final Set actualRecords, - final Set cdcStreams, - final Set streamNames, - final String namespace) { - final Set actualData = actualRecords - .stream() - .map(recordMessage -> { - assertTrue(streamNames.contains(recordMessage.getStream())); - assertNotNull(recordMessage.getEmittedAt()); - - assertEquals(namespace, recordMessage.getNamespace()); - - final JsonNode data = recordMessage.getData(); - - if (cdcStreams.contains(recordMessage.getStream())) { - assertCdcMetaData(data, true); - } else { - assertNullCdcMetaData(data); - } - - removeCDCColumns((ObjectNode) data); - - return data; - }) - .collect(Collectors.toSet()); - - assertEquals(expectedRecords, actualData); - } - - @Test - // On the first sync, produce returns records that exist in the database. - void testExistingData() throws Exception { - final CdcTargetPosition targetPosition = cdcLatestTargetPosition(); - final AutoCloseableIterator read = source().read(config(), getConfiguredCatalog(), null); - final List actualRecords = AutoCloseableIterators.toListAndClose(read); - - final Set recordMessages = extractRecordMessages(actualRecords); - final List stateMessages = extractStateMessages(actualRecords); - - assertNotNull(targetPosition); - recordMessages.forEach(record -> { - compareTargetPositionFromTheRecordsWithTargetPostionGeneratedBeforeSync(targetPosition, record); - }); - - assertExpectedRecords(new HashSet<>(MODEL_RECORDS), recordMessages); - assertExpectedStateMessages(stateMessages); - assertExpectedStateMessageCountMatches(stateMessages, MODEL_RECORDS.size()); - } - - protected void compareTargetPositionFromTheRecordsWithTargetPostionGeneratedBeforeSync(final CdcTargetPosition targetPosition, - final AirbyteRecordMessage record) { - assertEquals(extractPosition(record.getData()), targetPosition); - } - - @Test - // When a record is deleted, produces a deletion record. 
- public void testDelete() throws Exception { - final AutoCloseableIterator read1 = source() - .read(config(), getConfiguredCatalog(), null); - final List actualRecords1 = AutoCloseableIterators.toListAndClose(read1); - final List stateMessages1 = extractStateMessages(actualRecords1); - assertExpectedStateMessages(stateMessages1); - - deleteMessageOnIdCol(MODELS_STREAM_NAME, COL_ID, 11); - waitForCdcRecords(modelsSchema(), MODELS_STREAM_NAME, 1); - - final JsonNode state = Jsons.jsonNode(Collections.singletonList(stateMessages1.get(stateMessages1.size() - 1))); - final AutoCloseableIterator read2 = source() - .read(config(), getConfiguredCatalog(), state); - final List actualRecords2 = AutoCloseableIterators.toListAndClose(read2); - final List recordMessages2 = new ArrayList<>( - extractRecordMessages(actualRecords2)); - final List stateMessages2 = extractStateMessages(actualRecords2); - assertExpectedStateMessagesFromIncrementalSync(stateMessages2); - assertExpectedStateMessageCountMatches(stateMessages2, 1); - assertEquals(1, recordMessages2.size()); - assertEquals(11, recordMessages2.get(0).getData().get(COL_ID).asInt()); - assertCdcMetaData(recordMessages2.get(0).getData(), false); - } - - protected void assertExpectedStateMessagesFromIncrementalSync(final List stateMessages) { - assertExpectedStateMessages(stateMessages); - } - - @Test - // When a record is updated, produces an update record. - public void testUpdate() throws Exception { - final String updatedModel = "Explorer"; - final AutoCloseableIterator read1 = source() - .read(config(), getConfiguredCatalog(), null); - final List actualRecords1 = AutoCloseableIterators.toListAndClose(read1); - final List stateMessages1 = extractStateMessages(actualRecords1); - assertExpectedStateMessages(stateMessages1); - - updateCommand(MODELS_STREAM_NAME, COL_MODEL, updatedModel, COL_ID, 11); - waitForCdcRecords(modelsSchema(), MODELS_STREAM_NAME, 1); - - final JsonNode state = Jsons.jsonNode(Collections.singletonList(stateMessages1.get(stateMessages1.size() - 1))); - final AutoCloseableIterator read2 = source() - .read(config(), getConfiguredCatalog(), state); - final List actualRecords2 = AutoCloseableIterators.toListAndClose(read2); - final List recordMessages2 = new ArrayList<>( - extractRecordMessages(actualRecords2)); - final List stateMessages2 = extractStateMessages(actualRecords2); - assertExpectedStateMessagesFromIncrementalSync(stateMessages2); - assertEquals(1, recordMessages2.size()); - assertEquals(11, recordMessages2.get(0).getData().get(COL_ID).asInt()); - assertEquals(updatedModel, recordMessages2.get(0).getData().get(COL_MODEL).asText()); - assertCdcMetaData(recordMessages2.get(0).getData(), true); - assertExpectedStateMessageCountMatches(stateMessages2, 1); - } - - @SuppressWarnings({"BusyWait", "CodeBlock2Expr"}) - @Test - // Verify that when data is inserted into the database while a sync is happening and after the first - // sync, it all gets replicated. - protected void testRecordsProducedDuringAndAfterSync() throws Exception { - int recordsCreatedBeforeTestCount = MODEL_RECORDS.size(); - int expectedRecords = recordsCreatedBeforeTestCount; - int expectedRecordsInCdc = 0; - final int recordsToCreate = 20; - // first batch of records. 20 created here and 6 created in setup method. 
- for (int recordsCreated = 0; recordsCreated < recordsToCreate; recordsCreated++) { - final JsonNode record = - Jsons.jsonNode(ImmutableMap - .of(COL_ID, 100 + recordsCreated, COL_MAKE_ID, 1, COL_MODEL, - "F-" + recordsCreated)); - writeModelRecord(record); - expectedRecords++; - expectedRecordsInCdc++; - } - waitForCdcRecords(modelsSchema(), MODELS_STREAM_NAME, expectedRecordsInCdc); - - final AutoCloseableIterator firstBatchIterator = source() - .read(config(), getConfiguredCatalog(), null); - final List dataFromFirstBatch = AutoCloseableIterators - .toListAndClose(firstBatchIterator); - final List stateAfterFirstBatch = extractStateMessages(dataFromFirstBatch); - assertExpectedStateMessagesForRecordsProducedDuringAndAfterSync(stateAfterFirstBatch); - final Set recordsFromFirstBatch = extractRecordMessages( - dataFromFirstBatch); - assertEquals(expectedRecords, recordsFromFirstBatch.size()); - - // second batch of records again 20 being created - for (int recordsCreated = 0; recordsCreated < recordsToCreate; recordsCreated++) { - final JsonNode record = - Jsons.jsonNode(ImmutableMap - .of(COL_ID, 200 + recordsCreated, COL_MAKE_ID, 1, COL_MODEL, - "F-" + recordsCreated)); - writeModelRecord(record); - expectedRecords++; - expectedRecordsInCdc++; - } - waitForCdcRecords(modelsSchema(), MODELS_STREAM_NAME, expectedRecordsInCdc); - - final JsonNode state = Jsons.jsonNode(Collections.singletonList(stateAfterFirstBatch.get(stateAfterFirstBatch.size() - 1))); - final AutoCloseableIterator secondBatchIterator = source() - .read(config(), getConfiguredCatalog(), state); - final List dataFromSecondBatch = AutoCloseableIterators - .toListAndClose(secondBatchIterator); - - final List stateAfterSecondBatch = extractStateMessages(dataFromSecondBatch); - assertExpectedStateMessagesFromIncrementalSync(stateAfterSecondBatch); - - final Set recordsFromSecondBatch = extractRecordMessages( - dataFromSecondBatch); - assertEquals(recordsToCreate, recordsFromSecondBatch.size(), - "Expected 20 records to be replicated in the second sync."); - - // sometimes there can be more than one of these at the end of the snapshot and just before the - // first incremental. - final Set recordsFromFirstBatchWithoutDuplicates = removeDuplicates( - recordsFromFirstBatch); - final Set recordsFromSecondBatchWithoutDuplicates = removeDuplicates( - recordsFromSecondBatch); - - assertTrue(recordsCreatedBeforeTestCount < recordsFromFirstBatchWithoutDuplicates.size(), - "Expected first sync to include records created while the test was running."); - assertEquals(expectedRecords, - recordsFromFirstBatchWithoutDuplicates.size() + recordsFromSecondBatchWithoutDuplicates - .size()); - } - - protected void assertExpectedStateMessagesForRecordsProducedDuringAndAfterSync(final List stateAfterFirstBatch) { - assertExpectedStateMessages(stateAfterFirstBatch); - } - - @Test - // When both incremental CDC and full refresh are configured for different streams in a sync, the - // data is replicated as expected. 
-  public void testCdcAndFullRefreshInSameSync() throws Exception {
-    final ConfiguredAirbyteCatalog configuredCatalog = Jsons.clone(getConfiguredCatalog());
-
-    final List<JsonNode> MODEL_RECORDS_2 = ImmutableList.of(
-        Jsons.jsonNode(ImmutableMap.of(COL_ID, 110, COL_MAKE_ID, 1, COL_MODEL, "Fiesta-2")),
-        Jsons.jsonNode(ImmutableMap.of(COL_ID, 120, COL_MAKE_ID, 1, COL_MODEL, "Focus-2")),
-        Jsons.jsonNode(ImmutableMap.of(COL_ID, 130, COL_MAKE_ID, 1, COL_MODEL, "Ranger-2")),
-        Jsons.jsonNode(ImmutableMap.of(COL_ID, 140, COL_MAKE_ID, 2, COL_MODEL, "GLA-2")),
-        Jsons.jsonNode(ImmutableMap.of(COL_ID, 150, COL_MAKE_ID, 2, COL_MODEL, "A 220-2")),
-        Jsons.jsonNode(ImmutableMap.of(COL_ID, 160, COL_MAKE_ID, 2, COL_MODEL, "E 350-2")));
-
-    final var columns = ImmutableMap.of(COL_ID, "INTEGER", COL_MAKE_ID, "INTEGER", COL_MODEL, "VARCHAR(200)");
-    testdb.with(createTableSqlFmt(), modelsSchema(), MODELS_STREAM_NAME + "_2", columnClause(columns, Optional.of(COL_ID)));
-
-    for (final JsonNode recordJson : MODEL_RECORDS_2) {
-      writeRecords(recordJson, modelsSchema(), MODELS_STREAM_NAME + "_2", COL_ID, COL_MAKE_ID, COL_MODEL);
-    }
-
-    final ConfiguredAirbyteStream airbyteStream = new ConfiguredAirbyteStream()
-        .withStream(CatalogHelpers.createAirbyteStream(
-            MODELS_STREAM_NAME + "_2",
-            modelsSchema(),
-            Field.of(COL_ID, JsonSchemaType.INTEGER),
-            Field.of(COL_MAKE_ID, JsonSchemaType.INTEGER),
-            Field.of(COL_MODEL, JsonSchemaType.STRING))
-            .withSupportedSyncModes(
-                Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL))
-            .withSourceDefinedPrimaryKey(List.of(List.of(COL_ID))));
-    airbyteStream.setSyncMode(SyncMode.FULL_REFRESH);
-
-    final List<ConfiguredAirbyteStream> streams = configuredCatalog.getStreams();
-    streams.add(airbyteStream);
-    configuredCatalog.withStreams(streams);
-
-    final AutoCloseableIterator<AirbyteMessage> read1 = source()
-        .read(config(), configuredCatalog, null);
-    final List<AirbyteMessage> actualRecords1 = AutoCloseableIterators.toListAndClose(read1);
-
-    final Set<AirbyteRecordMessage> recordMessages1 = extractRecordMessages(actualRecords1);
-    final List<AirbyteStateMessage> stateMessages1 = extractStateMessages(actualRecords1);
-    final HashSet<String> names = new HashSet<>(STREAM_NAMES);
-    names.add(MODELS_STREAM_NAME + "_2");
-    assertExpectedStateMessages(stateMessages1);
-    // Full refresh does not get any state messages.
- assertExpectedStateMessageCountMatches(stateMessages1, MODEL_RECORDS_2.size()); - assertExpectedRecords(Streams.concat(MODEL_RECORDS_2.stream(), MODEL_RECORDS.stream()) - .collect(Collectors.toSet()), - recordMessages1, - Collections.singleton(MODELS_STREAM_NAME), - names, - modelsSchema()); - - final JsonNode puntoRecord = Jsons - .jsonNode(ImmutableMap.of(COL_ID, 100, COL_MAKE_ID, 3, COL_MODEL, "Punto")); - writeModelRecord(puntoRecord); - waitForCdcRecords(modelsSchema(), MODELS_STREAM_NAME, 1); - - final JsonNode state = Jsons.jsonNode(Collections.singletonList(stateMessages1.get(stateMessages1.size() - 1))); - final AutoCloseableIterator read2 = source() - .read(config(), configuredCatalog, state); - final List actualRecords2 = AutoCloseableIterators.toListAndClose(read2); - - final Set recordMessages2 = extractRecordMessages(actualRecords2); - final List stateMessages2 = extractStateMessages(actualRecords2); - assertExpectedStateMessagesFromIncrementalSync(stateMessages2); - assertExpectedStateMessageCountMatches(stateMessages2, 1); - assertExpectedRecords( - Streams.concat(MODEL_RECORDS_2.stream(), Stream.of(puntoRecord)) - .collect(Collectors.toSet()), - recordMessages2, - Collections.singleton(MODELS_STREAM_NAME), - names, - modelsSchema()); - } - - @Test - // When no records exist, no records are returned. - public void testNoData() throws Exception { - - deleteCommand(MODELS_STREAM_NAME); - waitForCdcRecords(modelsSchema(), MODELS_STREAM_NAME, MODEL_RECORDS.size()); - final AutoCloseableIterator read = source().read(config(), getConfiguredCatalog(), null); - final List actualRecords = AutoCloseableIterators.toListAndClose(read); - - final Set recordMessages = extractRecordMessages(actualRecords); - final List stateMessages = extractStateMessages(actualRecords); - assertExpectedRecords(Collections.emptySet(), recordMessages); - assertExpectedStateMessagesForNoData(stateMessages); - assertExpectedStateMessageCountMatches(stateMessages, 0); - } - - protected void assertExpectedStateMessagesForNoData(final List stateMessages) { - assertExpectedStateMessages(stateMessages); - } - - @Test - // When no changes have been made to the database since the previous sync, no records are returned. 
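A note on shared plumbing before the next test: every assertion block above splits the raw `read` output with `extractRecordMessages` and `extractStateMessages`. Their implementations are not part of this diff; a plausible shape, assuming only the v0 protocol models, is:

import io.airbyte.protocol.models.v0.AirbyteMessage;
import io.airbyte.protocol.models.v0.AirbyteMessage.Type;
import io.airbyte.protocol.models.v0.AirbyteRecordMessage;
import io.airbyte.protocol.models.v0.AirbyteStateMessage;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

// Plausible sketch of the extraction helpers used throughout this fixture;
// the real implementations live elsewhere and may differ in detail.
final class MessageExtractors {

  private MessageExtractors() {}

  static Set<AirbyteRecordMessage> extractRecordMessages(final List<AirbyteMessage> messages) {
    return messages.stream()
        .filter(m -> m.getType() == Type.RECORD)
        .map(AirbyteMessage::getRecord)
        .collect(Collectors.toSet());
  }

  static List<AirbyteStateMessage> extractStateMessages(final List<AirbyteMessage> messages) {
    return messages.stream()
        .filter(m -> m.getType() == Type.STATE)
        .map(AirbyteMessage::getState)
        .collect(Collectors.toList());
  }

}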
- public void testNoDataOnSecondSync() throws Exception { - final AutoCloseableIterator read1 = source() - .read(config(), getConfiguredCatalog(), null); - final List actualRecords1 = AutoCloseableIterators.toListAndClose(read1); - final List stateMessagesFromFirstSync = extractStateMessages(actualRecords1); - final JsonNode state = Jsons.jsonNode(Collections.singletonList(stateMessagesFromFirstSync.get(stateMessagesFromFirstSync.size() - 1))); - - final AutoCloseableIterator read2 = source() - .read(config(), getConfiguredCatalog(), state); - final List actualRecords2 = AutoCloseableIterators.toListAndClose(read2); - - final Set recordMessages2 = extractRecordMessages(actualRecords2); - final List stateMessages2 = extractStateMessages(actualRecords2); - - assertExpectedRecords(Collections.emptySet(), recordMessages2); - assertExpectedStateMessagesFromIncrementalSync(stateMessages2); - assertExpectedStateMessageCountMatches(stateMessages2, 0); - } - - @Test - public void testCheck() throws Exception { - final AirbyteConnectionStatus status = source().check(config()); - assertEquals(status.getStatus(), AirbyteConnectionStatus.Status.SUCCEEDED); - } - - @Test - public void testDiscover() throws Exception { - final AirbyteCatalog expectedCatalog = expectedCatalogForDiscover(); - final AirbyteCatalog actualCatalog = source().discover(config()); - - assertEquals( - expectedCatalog.getStreams().stream().sorted(Comparator.comparing(AirbyteStream::getName)) - .collect(Collectors.toList()), - actualCatalog.getStreams().stream().sorted(Comparator.comparing(AirbyteStream::getName)) - .collect(Collectors.toList())); - } - - @Test - public void newTableSnapshotTest() throws Exception { - final AutoCloseableIterator firstBatchIterator = source() - .read(config(), getConfiguredCatalog(), null); - final List dataFromFirstBatch = AutoCloseableIterators - .toListAndClose(firstBatchIterator); - final Set recordsFromFirstBatch = extractRecordMessages( - dataFromFirstBatch); - final List stateAfterFirstBatch = extractStateMessages(dataFromFirstBatch); - assertExpectedStateMessages(stateAfterFirstBatch); - assertExpectedStateMessageCountMatches(stateAfterFirstBatch, MODEL_RECORDS.size()); - - final AirbyteStateMessage stateMessageEmittedAfterFirstSyncCompletion = stateAfterFirstBatch.get(stateAfterFirstBatch.size() - 1); - assertEquals(AirbyteStateMessage.AirbyteStateType.GLOBAL, stateMessageEmittedAfterFirstSyncCompletion.getType()); - assertNotNull(stateMessageEmittedAfterFirstSyncCompletion.getGlobal().getSharedState()); - final Set streamsInStateAfterFirstSyncCompletion = stateMessageEmittedAfterFirstSyncCompletion.getGlobal().getStreamStates() - .stream() - .map(AirbyteStreamState::getStreamDescriptor) - .collect(Collectors.toSet()); - assertEquals(1, streamsInStateAfterFirstSyncCompletion.size()); - assertTrue(streamsInStateAfterFirstSyncCompletion.contains(new StreamDescriptor().withName(MODELS_STREAM_NAME).withNamespace(modelsSchema()))); - assertNotNull(stateMessageEmittedAfterFirstSyncCompletion.getData()); - - assertEquals((MODEL_RECORDS.size()), recordsFromFirstBatch.size()); - assertExpectedRecords(new HashSet<>(MODEL_RECORDS), recordsFromFirstBatch); - - final JsonNode state = stateAfterFirstBatch.get(stateAfterFirstBatch.size() - 1).getData(); - - final ConfiguredAirbyteCatalog newTables = CatalogHelpers - .toDefaultConfiguredCatalog(new AirbyteCatalog().withStreams(List.of( - CatalogHelpers.createAirbyteStream( - RANDOM_TABLE_NAME, - randomSchema(), - Field.of(COL_ID + "_random", 
JsonSchemaType.NUMBER), - Field.of(COL_MAKE_ID + "_random", JsonSchemaType.NUMBER), - Field.of(COL_MODEL + "_random", JsonSchemaType.STRING)) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL)) - .withSourceDefinedPrimaryKey(List.of(List.of(COL_ID + "_random")))))); - - newTables.getStreams().forEach(s -> s.setSyncMode(SyncMode.INCREMENTAL)); - final List combinedStreams = new ArrayList<>(); - combinedStreams.addAll(getConfiguredCatalog().getStreams()); - combinedStreams.addAll(newTables.getStreams()); - - final ConfiguredAirbyteCatalog updatedCatalog = new ConfiguredAirbyteCatalog().withStreams(combinedStreams); - - /* - * Write 20 records to the existing table - */ - final Set recordsWritten = new HashSet<>(); - for (int recordsCreated = 0; recordsCreated < 20; recordsCreated++) { - final JsonNode record = - Jsons.jsonNode(ImmutableMap - .of(COL_ID, 100 + recordsCreated, COL_MAKE_ID, 1, COL_MODEL, - "F-" + recordsCreated)); - recordsWritten.add(record); - writeModelRecord(record); - } - - final AutoCloseableIterator secondBatchIterator = source() - .read(config(), updatedCatalog, state); - final List dataFromSecondBatch = AutoCloseableIterators - .toListAndClose(secondBatchIterator); - - final List stateAfterSecondBatch = extractStateMessages(dataFromSecondBatch); - assertStateMessagesForNewTableSnapshotTest(stateAfterSecondBatch, stateMessageEmittedAfterFirstSyncCompletion); - - final Map> recordsStreamWise = extractRecordMessagesStreamWise(dataFromSecondBatch); - assertTrue(recordsStreamWise.containsKey(MODELS_STREAM_NAME)); - assertTrue(recordsStreamWise.containsKey(RANDOM_TABLE_NAME)); - - final Set recordsForModelsStreamFromSecondBatch = recordsStreamWise.get(MODELS_STREAM_NAME); - final Set recordsForModelsRandomStreamFromSecondBatch = recordsStreamWise.get(RANDOM_TABLE_NAME); - - assertEquals((MODEL_RECORDS_RANDOM.size()), recordsForModelsRandomStreamFromSecondBatch.size()); - assertEquals(20, recordsForModelsStreamFromSecondBatch.size()); - assertExpectedRecords(new HashSet<>(MODEL_RECORDS_RANDOM), recordsForModelsRandomStreamFromSecondBatch, - recordsForModelsRandomStreamFromSecondBatch.stream().map(AirbyteRecordMessage::getStream).collect( - Collectors.toSet()), - Sets - .newHashSet(RANDOM_TABLE_NAME), - randomSchema()); - assertExpectedRecords(recordsWritten, recordsForModelsStreamFromSecondBatch); - - /* - * Write 20 records to both the tables - */ - final Set recordsWrittenInRandomTable = new HashSet<>(); - recordsWritten.clear(); - for (int recordsCreated = 30; recordsCreated < 50; recordsCreated++) { - final JsonNode record = - Jsons.jsonNode(ImmutableMap - .of(COL_ID, 100 + recordsCreated, COL_MAKE_ID, 1, COL_MODEL, - "F-" + recordsCreated)); - writeModelRecord(record); - recordsWritten.add(record); - - final JsonNode record2 = Jsons - .jsonNode(ImmutableMap - .of(COL_ID + "_random", 11000 + recordsCreated, COL_MAKE_ID + "_random", 1 + recordsCreated, COL_MODEL + "_random", - "Fiesta-random" + recordsCreated)); - writeRecords(record2, randomSchema(), RANDOM_TABLE_NAME, - COL_ID + "_random", COL_MAKE_ID + "_random", COL_MODEL + "_random"); - recordsWrittenInRandomTable.add(record2); - } - - final JsonNode state2 = stateAfterSecondBatch.get(stateAfterSecondBatch.size() - 1).getData(); - final AutoCloseableIterator thirdBatchIterator = source() - .read(config(), updatedCatalog, state2); - final List dataFromThirdBatch = AutoCloseableIterators - .toListAndClose(thirdBatchIterator); - - final List stateAfterThirdBatch = 
        extractStateMessages(dataFromThirdBatch);
-    assertTrue(stateAfterThirdBatch.size() >= 1);
-
-    final AirbyteStateMessage stateMessageEmittedAfterThirdSyncCompletion = stateAfterThirdBatch.get(stateAfterThirdBatch.size() - 1);
-    assertEquals(AirbyteStateMessage.AirbyteStateType.GLOBAL, stateMessageEmittedAfterThirdSyncCompletion.getType());
-    assertNotEquals(stateMessageEmittedAfterThirdSyncCompletion.getGlobal().getSharedState(),
-        stateAfterSecondBatch.get(stateAfterSecondBatch.size() - 1).getGlobal().getSharedState());
-    final Set<StreamDescriptor> streamsInSyncCompletionStateAfterThirdSync = stateMessageEmittedAfterThirdSyncCompletion.getGlobal().getStreamStates()
-        .stream()
-        .map(AirbyteStreamState::getStreamDescriptor)
-        .collect(Collectors.toSet());
-    assertTrue(
-        streamsInSyncCompletionStateAfterThirdSync.contains(
-            new StreamDescriptor().withName(RANDOM_TABLE_NAME).withNamespace(randomSchema())));
-    assertTrue(
-        streamsInSyncCompletionStateAfterThirdSync.contains(new StreamDescriptor().withName(MODELS_STREAM_NAME).withNamespace(modelsSchema())));
-    assertNotNull(stateMessageEmittedAfterThirdSyncCompletion.getData());
-
-    final Map<String, Set<AirbyteRecordMessage>> recordsStreamWiseFromThirdBatch = extractRecordMessagesStreamWise(dataFromThirdBatch);
-    assertTrue(recordsStreamWiseFromThirdBatch.containsKey(MODELS_STREAM_NAME));
-    assertTrue(recordsStreamWiseFromThirdBatch.containsKey(RANDOM_TABLE_NAME));
-
-    final Set<AirbyteRecordMessage> recordsForModelsStreamFromThirdBatch = recordsStreamWiseFromThirdBatch.get(MODELS_STREAM_NAME);
-    final Set<AirbyteRecordMessage> recordsForModelsRandomStreamFromThirdBatch = recordsStreamWiseFromThirdBatch.get(RANDOM_TABLE_NAME);
-
-    assertEquals(20, recordsForModelsStreamFromThirdBatch.size());
-    assertEquals(20, recordsForModelsRandomStreamFromThirdBatch.size());
-    assertExpectedRecords(recordsWritten, recordsForModelsStreamFromThirdBatch);
-    assertExpectedRecords(recordsWrittenInRandomTable, recordsForModelsRandomStreamFromThirdBatch,
-        recordsForModelsRandomStreamFromThirdBatch.stream().map(AirbyteRecordMessage::getStream).collect(
-            Collectors.toSet()),
-        Sets
-            .newHashSet(RANDOM_TABLE_NAME),
-        randomSchema());
-  }
-
-  protected void assertStateMessagesForNewTableSnapshotTest(final List<AirbyteStateMessage> stateMessages,
-                                                            final AirbyteStateMessage stateMessageEmittedAfterFirstSyncCompletion) {
-    assertEquals(2, stateMessages.size());
-    final AirbyteStateMessage stateMessageEmittedAfterSnapshotCompletionInSecondSync = stateMessages.get(0);
-    assertEquals(AirbyteStateMessage.AirbyteStateType.GLOBAL, stateMessageEmittedAfterSnapshotCompletionInSecondSync.getType());
-    assertEquals(stateMessageEmittedAfterFirstSyncCompletion.getGlobal().getSharedState(),
-        stateMessageEmittedAfterSnapshotCompletionInSecondSync.getGlobal().getSharedState());
-    final Set<StreamDescriptor> streamsInSnapshotState = stateMessageEmittedAfterSnapshotCompletionInSecondSync.getGlobal().getStreamStates()
-        .stream()
-        .map(AirbyteStreamState::getStreamDescriptor)
-        .collect(Collectors.toSet());
-    assertEquals(2, streamsInSnapshotState.size());
-    assertTrue(
-        streamsInSnapshotState.contains(new StreamDescriptor().withName(RANDOM_TABLE_NAME).withNamespace(randomSchema())));
-    assertTrue(streamsInSnapshotState.contains(new StreamDescriptor().withName(MODELS_STREAM_NAME).withNamespace(modelsSchema())));
-    assertNotNull(stateMessageEmittedAfterSnapshotCompletionInSecondSync.getData());
-
-    final AirbyteStateMessage stateMessageEmittedAfterSecondSyncCompletion = stateMessages.get(1);
-    assertEquals(AirbyteStateMessage.AirbyteStateType.GLOBAL, stateMessageEmittedAfterSecondSyncCompletion.getType());
-    assertNotEquals(stateMessageEmittedAfterFirstSyncCompletion.getGlobal().getSharedState(),
-        stateMessageEmittedAfterSecondSyncCompletion.getGlobal().getSharedState());
-    final Set<StreamDescriptor> streamsInSyncCompletionState = stateMessageEmittedAfterSecondSyncCompletion.getGlobal().getStreamStates()
-        .stream()
-        .map(AirbyteStreamState::getStreamDescriptor)
-        .collect(Collectors.toSet());
-    assertEquals(2, streamsInSyncCompletionState.size());
-    assertTrue(
-        streamsInSyncCompletionState.contains(
-            new StreamDescriptor().withName(RANDOM_TABLE_NAME).withNamespace(randomSchema())));
-    assertTrue(streamsInSyncCompletionState.contains(new StreamDescriptor().withName(MODELS_STREAM_NAME).withNamespace(modelsSchema())));
-    assertNotNull(stateMessageEmittedAfterSecondSyncCompletion.getData());
-  }
-
-  protected AirbyteCatalog expectedCatalogForDiscover() {
-    final AirbyteCatalog expectedCatalog = Jsons.clone(getCatalog());
-
-    final var columns = ImmutableMap.of(COL_ID, "INTEGER", COL_MAKE_ID, "INTEGER", COL_MODEL, "VARCHAR(200)");
-    testdb.with(createTableSqlFmt(), modelsSchema(), MODELS_STREAM_NAME + "_2", columnClause(columns, Optional.empty()));
-
-    final List<AirbyteStream> streams = expectedCatalog.getStreams();
-    // stream with PK
-    streams.get(0).setSourceDefinedCursor(true);
-    addCdcMetadataColumns(streams.get(0));
-    addCdcDefaultCursorField(streams.get(0));
-
-    final AirbyteStream streamWithoutPK = CatalogHelpers.createAirbyteStream(
-        MODELS_STREAM_NAME + "_2",
-        modelsSchema(),
-        Field.of(COL_ID, JsonSchemaType.INTEGER),
-        Field.of(COL_MAKE_ID, JsonSchemaType.INTEGER),
-        Field.of(COL_MODEL, JsonSchemaType.STRING));
-    streamWithoutPK.setSourceDefinedPrimaryKey(Collections.emptyList());
-    streamWithoutPK.setSupportedSyncModes(List.of(SyncMode.FULL_REFRESH));
-    addCdcDefaultCursorField(streamWithoutPK);
-    addCdcMetadataColumns(streamWithoutPK);
-
-    final AirbyteStream randomStream = CatalogHelpers.createAirbyteStream(
-        RANDOM_TABLE_NAME,
-        randomSchema(),
-        Field.of(COL_ID + "_random", JsonSchemaType.INTEGER),
-        Field.of(COL_MAKE_ID + "_random", JsonSchemaType.INTEGER),
-        Field.of(COL_MODEL + "_random", JsonSchemaType.STRING))
-        .withSourceDefinedCursor(true)
-        .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL))
-        .withSourceDefinedPrimaryKey(List.of(List.of(COL_ID + "_random")));
-
-    addCdcDefaultCursorField(randomStream);
-    addCdcMetadataColumns(randomStream);
-
-    streams.add(streamWithoutPK);
-    streams.add(randomStream);
-    expectedCatalog.withStreams(streams);
-    return expectedCatalog;
-  }
-
-  protected void waitForCdcRecords(String schemaName, String tableName, int recordCount)
-      throws Exception {}
-
-}
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/debug/DebugUtil.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/debug/DebugUtil.java
deleted file mode 100644
index 836f6cf50347..000000000000
--- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/debug/DebugUtil.java
+++ /dev/null
@@ -1,59 +0,0 @@
-/*
- * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.cdk.integrations.debug;
-
-import com.fasterxml.jackson.databind.JsonNode;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import com.fasterxml.jackson.databind.node.ObjectNode;
-import io.airbyte.cdk.integrations.base.Source;
-import io.airbyte.commons.json.Jsons;
-import io.airbyte.commons.resources.MoreResources;
-import io.airbyte.commons.util.AutoCloseableIterator;
-import io.airbyte.protocol.models.v0.AirbyteMessage;
-import io.airbyte.protocol.models.v0.AirbyteStateMessage;
-import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog;
-import java.util.Collections;
-
-/**
- * Utility class defined to debug a source. Copy over any relevant configurations, catalogs & state
- * in the resources/debug_resources directory.
- */
-public class DebugUtil {
-
-  @SuppressWarnings({"unchecked", "deprecation", "resource"})
-  public static void debug(final Source debugSource) throws Exception {
-    final JsonNode debugConfig = DebugUtil.getConfig();
-    final ConfiguredAirbyteCatalog configuredAirbyteCatalog = DebugUtil.getCatalog();
-    JsonNode state;
-    try {
-      state = DebugUtil.getState();
-    } catch (final Exception e) {
-      state = null;
-    }
-
-    debugSource.check(debugConfig);
-    debugSource.discover(debugConfig);
-
-    final AutoCloseableIterator<AirbyteMessage> messageIterator = debugSource.read(debugConfig, configuredAirbyteCatalog, state);
-    messageIterator.forEachRemaining(message -> {});
-  }
-
-  private static JsonNode getConfig() throws Exception {
-    final JsonNode originalConfig = new ObjectMapper().readTree(MoreResources.readResource("debug_resources/config.json"));
-    final JsonNode debugConfig = ((ObjectNode) originalConfig.deepCopy()).put("debug_mode", true);
-    return debugConfig;
-  }
-
-  private static ConfiguredAirbyteCatalog getCatalog() throws Exception {
-    final String catalog = MoreResources.readResource("debug_resources/configured_catalog.json");
-    return Jsons.deserialize(catalog, ConfiguredAirbyteCatalog.class);
-  }
-
-  private static JsonNode getState() throws Exception {
-    final AirbyteStateMessage message = Jsons.deserialize(MoreResources.readResource("debug_resources/state.json"), AirbyteStateMessage.class);
-    return Jsons.jsonNode(Collections.singletonList(message));
-  }
-
-}
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/source/jdbc/test/JdbcSourceAcceptanceTest.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/source/jdbc/test/JdbcSourceAcceptanceTest.java
deleted file mode 100644
index aac25c5d87b0..000000000000
--- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/source/jdbc/test/JdbcSourceAcceptanceTest.java
+++ /dev/null
@@ -1,1108 +0,0 @@
-/*
- * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
- */ - -package io.airbyte.cdk.integrations.source.jdbc.test; - -import static io.airbyte.cdk.integrations.source.relationaldb.RelationalDbQueryUtils.enquoteIdentifier; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.Mockito.doCallRealMethod; -import static org.mockito.Mockito.spy; - -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.node.ObjectNode; -import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; -import io.airbyte.cdk.db.factory.DatabaseDriver; -import io.airbyte.cdk.db.jdbc.JdbcUtils; -import io.airbyte.cdk.integrations.base.Source; -import io.airbyte.cdk.integrations.source.relationaldb.RelationalDbQueryUtils; -import io.airbyte.cdk.integrations.source.relationaldb.models.DbState; -import io.airbyte.cdk.integrations.source.relationaldb.models.DbStreamState; -import io.airbyte.cdk.testutils.TestDatabase; -import io.airbyte.commons.json.Jsons; -import io.airbyte.commons.resources.MoreResources; -import io.airbyte.commons.util.MoreIterators; -import io.airbyte.protocol.models.Field; -import io.airbyte.protocol.models.JsonSchemaType; -import io.airbyte.protocol.models.v0.AirbyteCatalog; -import io.airbyte.protocol.models.v0.AirbyteConnectionStatus; -import io.airbyte.protocol.models.v0.AirbyteConnectionStatus.Status; -import io.airbyte.protocol.models.v0.AirbyteMessage; -import io.airbyte.protocol.models.v0.AirbyteMessage.Type; -import io.airbyte.protocol.models.v0.AirbyteRecordMessage; -import io.airbyte.protocol.models.v0.AirbyteStateMessage; -import io.airbyte.protocol.models.v0.AirbyteStateMessage.AirbyteStateType; -import io.airbyte.protocol.models.v0.AirbyteStateStats; -import io.airbyte.protocol.models.v0.AirbyteStream; -import io.airbyte.protocol.models.v0.AirbyteStreamState; -import io.airbyte.protocol.models.v0.CatalogHelpers; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream; -import io.airbyte.protocol.models.v0.ConnectorSpecification; -import io.airbyte.protocol.models.v0.DestinationSyncMode; -import io.airbyte.protocol.models.v0.StreamDescriptor; -import io.airbyte.protocol.models.v0.SyncMode; -import java.math.BigDecimal; -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.Comparator; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.stream.Collectors; -import org.hamcrest.Matchers; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -/** - * Tests that should be run on all Sources that extend the AbstractJdbcSource. 
- */ -@SuppressFBWarnings( - value = {"MS_SHOULD_BE_FINAL"}, - justification = "The static variables are updated in subclasses for convenience, and cannot be final.") -abstract public class JdbcSourceAcceptanceTest> { - - static protected String SCHEMA_NAME = "jdbc_integration_test1"; - static protected String SCHEMA_NAME2 = "jdbc_integration_test2"; - static protected Set TEST_SCHEMAS = Set.of(SCHEMA_NAME, SCHEMA_NAME2); - - static protected String TABLE_NAME = "id_and_name"; - static protected String TABLE_NAME_WITH_SPACES = "id and name"; - static protected String TABLE_NAME_WITHOUT_PK = "id_and_name_without_pk"; - static protected String TABLE_NAME_COMPOSITE_PK = "full_name_composite_pk"; - static protected String TABLE_NAME_WITHOUT_CURSOR_TYPE = "table_without_cursor_type"; - static protected String TABLE_NAME_WITH_NULLABLE_CURSOR_TYPE = "table_with_null_cursor_type"; - // this table is used in testing incremental sync with concurrent insertions - static protected String TABLE_NAME_AND_TIMESTAMP = "name_and_timestamp"; - - static protected String COL_ID = "id"; - static protected String COL_NAME = "name"; - static protected String COL_UPDATED_AT = "updated_at"; - static protected String COL_FIRST_NAME = "first_name"; - static protected String COL_LAST_NAME = "last_name"; - static protected String COL_LAST_NAME_WITH_SPACE = "last name"; - static protected String COL_CURSOR = "cursor_field"; - static protected String COL_TIMESTAMP = "timestamp"; - static protected String COL_TIMESTAMP_TYPE = "TIMESTAMP"; - static protected Number ID_VALUE_1 = 1; - static protected Number ID_VALUE_2 = 2; - static protected Number ID_VALUE_3 = 3; - static protected Number ID_VALUE_4 = 4; - static protected Number ID_VALUE_5 = 5; - - static protected String DROP_SCHEMA_QUERY = "DROP SCHEMA IF EXISTS %s CASCADE"; - static protected String COLUMN_CLAUSE_WITH_PK = "id INTEGER, name VARCHAR(200) NOT NULL, updated_at DATE NOT NULL"; - static protected String COLUMN_CLAUSE_WITHOUT_PK = "id INTEGER, name VARCHAR(200) NOT NULL, updated_at DATE NOT NULL"; - static protected String COLUMN_CLAUSE_WITH_COMPOSITE_PK = - "first_name VARCHAR(200) NOT NULL, last_name VARCHAR(200) NOT NULL, updated_at DATE NOT NULL"; - - static protected String CREATE_TABLE_WITHOUT_CURSOR_TYPE_QUERY = "CREATE TABLE %s (%s bit NOT NULL);"; - static protected String INSERT_TABLE_WITHOUT_CURSOR_TYPE_QUERY = "INSERT INTO %s VALUES(0);"; - static protected String CREATE_TABLE_WITH_NULLABLE_CURSOR_TYPE_QUERY = "CREATE TABLE %s (%s VARCHAR(20));"; - static protected String INSERT_TABLE_WITH_NULLABLE_CURSOR_TYPE_QUERY = "INSERT INTO %s VALUES('Hello world :)');"; - static protected String INSERT_TABLE_NAME_AND_TIMESTAMP_QUERY = "INSERT INTO %s (name, timestamp) VALUES ('%s', '%s')"; - - protected T testdb; - - protected String streamName() { - return TABLE_NAME; - } - - /** - * A valid configuration to connect to a test database. - * - * @return config - */ - abstract protected JsonNode config(); - - /** - * An instance of the source that should be tests. - * - * @return abstract jdbc source - */ - abstract protected S source(); - - /** - * Creates a TestDatabase instance to be used in {@link #setup()}. - * - * @return TestDatabase instance to use for test case. - */ - abstract protected T createTestDatabase(); - - /** - * These tests write records without specifying a namespace (schema name). They will be written into - * whatever the default schema is for the database. When they are discovered they will be namespaced - * by the schema name (e.g. 
<schema_name>.<table_name>). Thus the source needs to tell the
-   * tests what that default schema name is. If the database does not support schemas, then the
-   * database name should be used instead.
-   *
-   * @return name that will be used to namespace the record.
-   */
-  abstract protected boolean supportsSchemas();
-
-  protected String createTableQuery(final String tableName, final String columnClause, final String primaryKeyClause) {
-    return String.format("CREATE TABLE %s(%s %s %s)",
-        tableName, columnClause, primaryKeyClause.equals("") ? "" : ",", primaryKeyClause);
-  }
-
-  protected String primaryKeyClause(final List<String> columns) {
-    if (columns.isEmpty()) {
-      return "";
-    }
-
-    final StringBuilder clause = new StringBuilder();
-    clause.append("PRIMARY KEY (");
-    for (int i = 0; i < columns.size(); i++) {
-      clause.append(columns.get(i));
-      if (i != (columns.size() - 1)) {
-        clause.append(",");
-      }
-    }
-    clause.append(")");
-    return clause.toString();
-  }
-
-  @BeforeEach
-  public void setup() throws Exception {
-    testdb = createTestDatabase();
-    if (supportsSchemas()) {
-      createSchemas();
-    }
-    if (testdb.getDatabaseDriver().equals(DatabaseDriver.ORACLE)) {
-      testdb.with("ALTER SESSION SET NLS_DATE_FORMAT = 'YYYY-MM-DD'");
-    }
-    testdb
-        .with(createTableQuery(getFullyQualifiedTableName(TABLE_NAME), COLUMN_CLAUSE_WITH_PK, primaryKeyClause(Collections.singletonList("id"))))
-        .with("INSERT INTO %s(id, name, updated_at) VALUES (1, 'picard', '2004-10-19')", getFullyQualifiedTableName(TABLE_NAME))
-        .with("INSERT INTO %s(id, name, updated_at) VALUES (2, 'crusher', '2005-10-19')", getFullyQualifiedTableName(TABLE_NAME))
-        .with("INSERT INTO %s(id, name, updated_at) VALUES (3, 'vash', '2006-10-19')", getFullyQualifiedTableName(TABLE_NAME))
-        .with(createTableQuery(getFullyQualifiedTableName(TABLE_NAME_WITHOUT_PK), COLUMN_CLAUSE_WITHOUT_PK, ""))
-        .with("INSERT INTO %s(id, name, updated_at) VALUES (1, 'picard', '2004-10-19')", getFullyQualifiedTableName(TABLE_NAME_WITHOUT_PK))
-        .with("INSERT INTO %s(id, name, updated_at) VALUES (2, 'crusher', '2005-10-19')", getFullyQualifiedTableName(TABLE_NAME_WITHOUT_PK))
-        .with("INSERT INTO %s(id, name, updated_at) VALUES (3, 'vash', '2006-10-19')", getFullyQualifiedTableName(TABLE_NAME_WITHOUT_PK))
-        .with(createTableQuery(getFullyQualifiedTableName(TABLE_NAME_COMPOSITE_PK), COLUMN_CLAUSE_WITH_COMPOSITE_PK,
-            primaryKeyClause(List.of("first_name", "last_name"))))
-        .with("INSERT INTO %s(first_name, last_name, updated_at) VALUES ('first', 'picard', '2004-10-19')",
-            getFullyQualifiedTableName(TABLE_NAME_COMPOSITE_PK))
-        .with("INSERT INTO %s(first_name, last_name, updated_at) VALUES ('second', 'crusher', '2005-10-19')",
-            getFullyQualifiedTableName(TABLE_NAME_COMPOSITE_PK))
-        .with("INSERT INTO %s(first_name, last_name, updated_at) VALUES ('third', 'vash', '2006-10-19')",
-            getFullyQualifiedTableName(TABLE_NAME_COMPOSITE_PK));
-  }
-
-  protected void maybeSetShorterConnectionTimeout(final JsonNode config) {
-    // Optionally implement this to speed up test cases which will result in a connection timeout.
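    // For example, a subclass might shorten the JDBC connect timeout through the config (a
    // sketch; the "connectTimeout" parameter below is driver-specific, and it assumes the
    // connector honors JdbcUtils.JDBC_URL_PARAMS_KEY):
    //
    //   @Override
    //   protected void maybeSetShorterConnectionTimeout(final JsonNode config) {
    //     ((ObjectNode) config).put(JdbcUtils.JDBC_URL_PARAMS_KEY, "connectTimeout=1000");
    //   }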
- } - - @AfterEach - public void tearDown() { - testdb.close(); - } - - @Test - void testSpec() throws Exception { - final ConnectorSpecification actual = source().spec(); - final String resourceString = MoreResources.readResource("spec.json"); - final ConnectorSpecification expected = Jsons.deserialize(resourceString, ConnectorSpecification.class); - - assertEquals(expected, actual); - } - - @Test - void testCheckSuccess() throws Exception { - final AirbyteConnectionStatus actual = source().check(config()); - final AirbyteConnectionStatus expected = new AirbyteConnectionStatus().withStatus(Status.SUCCEEDED); - assertEquals(expected, actual); - } - - @Test - protected void testCheckFailure() throws Exception { - final var config = config(); - maybeSetShorterConnectionTimeout(config); - ((ObjectNode) config).put(JdbcUtils.PASSWORD_KEY, "fake"); - final AirbyteConnectionStatus actual = source().check(config); - assertEquals(Status.FAILED, actual.getStatus()); - } - - @Test - void testDiscover() throws Exception { - final AirbyteCatalog actual = filterOutOtherSchemas(source().discover(config())); - final AirbyteCatalog expected = getCatalog(getDefaultNamespace()); - assertEquals(expected.getStreams().size(), actual.getStreams().size()); - actual.getStreams().forEach(actualStream -> { - final Optional expectedStream = - expected.getStreams().stream() - .filter(stream -> stream.getNamespace().equals(actualStream.getNamespace()) && stream.getName().equals(actualStream.getName())) - .findAny(); - assertTrue(expectedStream.isPresent(), String.format("Unexpected stream %s", actualStream.getName())); - assertEquals(expectedStream.get(), actualStream); - }); - } - - @Test - protected void testDiscoverWithNonCursorFields() throws Exception { - testdb.with(CREATE_TABLE_WITHOUT_CURSOR_TYPE_QUERY, getFullyQualifiedTableName(TABLE_NAME_WITHOUT_CURSOR_TYPE), COL_CURSOR) - .with(INSERT_TABLE_WITHOUT_CURSOR_TYPE_QUERY, getFullyQualifiedTableName(TABLE_NAME_WITHOUT_CURSOR_TYPE)); - final AirbyteCatalog actual = filterOutOtherSchemas(source().discover(config())); - final AirbyteStream stream = - actual.getStreams().stream().filter(s -> s.getName().equalsIgnoreCase(TABLE_NAME_WITHOUT_CURSOR_TYPE)).findFirst().orElse(null); - assertNotNull(stream); - assertEquals(TABLE_NAME_WITHOUT_CURSOR_TYPE.toLowerCase(), stream.getName().toLowerCase()); - assertEquals(1, stream.getSupportedSyncModes().size()); - assertEquals(SyncMode.FULL_REFRESH, stream.getSupportedSyncModes().get(0)); - } - - @Test - protected void testDiscoverWithNullableCursorFields() throws Exception { - testdb.with(CREATE_TABLE_WITH_NULLABLE_CURSOR_TYPE_QUERY, getFullyQualifiedTableName(TABLE_NAME_WITH_NULLABLE_CURSOR_TYPE), COL_CURSOR) - .with(INSERT_TABLE_WITH_NULLABLE_CURSOR_TYPE_QUERY, getFullyQualifiedTableName(TABLE_NAME_WITH_NULLABLE_CURSOR_TYPE)); - final AirbyteCatalog actual = filterOutOtherSchemas(source().discover(config())); - final AirbyteStream stream = - actual.getStreams().stream().filter(s -> s.getName().equalsIgnoreCase(TABLE_NAME_WITH_NULLABLE_CURSOR_TYPE)).findFirst().orElse(null); - assertNotNull(stream); - assertEquals(TABLE_NAME_WITH_NULLABLE_CURSOR_TYPE.toLowerCase(), stream.getName().toLowerCase()); - assertEquals(2, stream.getSupportedSyncModes().size()); - assertTrue(stream.getSupportedSyncModes().contains(SyncMode.FULL_REFRESH)); - assertTrue(stream.getSupportedSyncModes().contains(SyncMode.INCREMENTAL)); - } - - protected AirbyteCatalog filterOutOtherSchemas(final AirbyteCatalog catalog) { - if (supportsSchemas()) { - 
final AirbyteCatalog filteredCatalog = Jsons.clone(catalog); - filteredCatalog.setStreams(filteredCatalog.getStreams() - .stream() - .filter(stream -> TEST_SCHEMAS.stream().anyMatch(schemaName -> stream.getNamespace().startsWith(schemaName))) - .collect(Collectors.toList())); - return filteredCatalog; - } else { - return catalog; - } - - } - - @Test - protected void testDiscoverWithMultipleSchemas() throws Exception { - // clickhouse and mysql do not have a concept of schemas, so this test does not make sense for them. - switch (testdb.getDatabaseDriver()) { - case MYSQL, CLICKHOUSE, TERADATA: - return; - } - - // add table and data to a separate schema. - testdb.with("CREATE TABLE %s(id VARCHAR(200) NOT NULL, name VARCHAR(200) NOT NULL)", - RelationalDbQueryUtils.getFullyQualifiedTableName(SCHEMA_NAME2, TABLE_NAME)) - .with("INSERT INTO %s(id, name) VALUES ('1','picard')", - RelationalDbQueryUtils.getFullyQualifiedTableName(SCHEMA_NAME2, TABLE_NAME)) - .with("INSERT INTO %s(id, name) VALUES ('2', 'crusher')", - RelationalDbQueryUtils.getFullyQualifiedTableName(SCHEMA_NAME2, TABLE_NAME)) - .with("INSERT INTO %s(id, name) VALUES ('3', 'vash')", - RelationalDbQueryUtils.getFullyQualifiedTableName(SCHEMA_NAME2, TABLE_NAME)); - - final AirbyteCatalog actual = source().discover(config()); - - final AirbyteCatalog expected = getCatalog(getDefaultNamespace()); - final List catalogStreams = new ArrayList<>(); - catalogStreams.addAll(expected.getStreams()); - catalogStreams.add(CatalogHelpers - .createAirbyteStream(TABLE_NAME, - SCHEMA_NAME2, - Field.of(COL_ID, JsonSchemaType.STRING), - Field.of(COL_NAME, JsonSchemaType.STRING)) - .withSupportedSyncModes(List.of(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL))); - expected.setStreams(catalogStreams); - // sort streams by name so that we are comparing lists with the same order. - final Comparator schemaTableCompare = Comparator.comparing(stream -> stream.getNamespace() + "." 
+ stream.getName()); - expected.getStreams().sort(schemaTableCompare); - actual.getStreams().sort(schemaTableCompare); - assertEquals(expected, filterOutOtherSchemas(actual)); - } - - @Test - void testReadSuccess() throws Exception { - final List actualMessages = - MoreIterators.toList( - source().read(config(), getConfiguredCatalogWithOneStream(getDefaultNamespace()), null)); - - setEmittedAtToNull(actualMessages); - final List expectedMessages = getTestMessages(); - assertThat(expectedMessages, Matchers.containsInAnyOrder(actualMessages.toArray())); - assertThat(actualMessages, Matchers.containsInAnyOrder(expectedMessages.toArray())); - } - - @Test - protected void testReadOneColumn() throws Exception { - final ConfiguredAirbyteCatalog catalog = CatalogHelpers - .createConfiguredAirbyteCatalog(streamName(), getDefaultNamespace(), Field.of(COL_ID, JsonSchemaType.NUMBER)); - final List actualMessages = MoreIterators - .toList(source().read(config(), catalog, null)); - - setEmittedAtToNull(actualMessages); - - final List expectedMessages = getAirbyteMessagesReadOneColumn(); - assertEquals(expectedMessages.size(), actualMessages.size()); - assertTrue(expectedMessages.containsAll(actualMessages)); - assertTrue(actualMessages.containsAll(expectedMessages)); - } - - protected List getAirbyteMessagesReadOneColumn() { - final List expectedMessages = getTestMessages().stream() - .map(Jsons::clone) - .peek(m -> { - ((ObjectNode) m.getRecord().getData()).remove(COL_NAME); - ((ObjectNode) m.getRecord().getData()).remove(COL_UPDATED_AT); - ((ObjectNode) m.getRecord().getData()).replace(COL_ID, - convertIdBasedOnDatabase(m.getRecord().getData().get(COL_ID).asInt())); - }) - .collect(Collectors.toList()); - return expectedMessages; - } - - @Test - protected void testReadMultipleTables() throws Exception { - final ConfiguredAirbyteCatalog catalog = getConfiguredCatalogWithOneStream( - getDefaultNamespace()); - final List expectedMessages = new ArrayList<>(getTestMessages()); - - for (int i = 2; i < 10; i++) { - final String streamName2 = streamName() + i; - final String tableName = getFullyQualifiedTableName(TABLE_NAME + i); - testdb.with(createTableQuery(tableName, "id INTEGER, name VARCHAR(200)", "")) - .with("INSERT INTO %s(id, name) VALUES (1,'picard')", tableName) - .with("INSERT INTO %s(id, name) VALUES (2, 'crusher')", tableName) - .with("INSERT INTO %s(id, name) VALUES (3, 'vash')", tableName); - catalog.getStreams().add(CatalogHelpers.createConfiguredAirbyteStream( - streamName2, - getDefaultNamespace(), - Field.of(COL_ID, JsonSchemaType.NUMBER), - Field.of(COL_NAME, JsonSchemaType.STRING))); - - expectedMessages.addAll(getAirbyteMessagesSecondSync(streamName2)); - } - - final List actualMessages = MoreIterators - .toList(source().read(config(), catalog, null)); - - setEmittedAtToNull(actualMessages); - - assertEquals(expectedMessages.size(), actualMessages.size()); - assertTrue(expectedMessages.containsAll(actualMessages)); - assertTrue(actualMessages.containsAll(expectedMessages)); - } - - protected List getAirbyteMessagesSecondSync(final String streamName) { - return getTestMessages() - .stream() - .map(Jsons::clone) - .peek(m -> { - m.getRecord().setStream(streamName); - m.getRecord().setNamespace(getDefaultNamespace()); - ((ObjectNode) m.getRecord().getData()).remove(COL_UPDATED_AT); - ((ObjectNode) m.getRecord().getData()).replace(COL_ID, - convertIdBasedOnDatabase(m.getRecord().getData().get(COL_ID).asInt())); - }) - .collect(Collectors.toList()); - - } - - @Test - protected void 
testTablesWithQuoting() throws Exception { - final ConfiguredAirbyteStream streamForTableWithSpaces = createTableWithSpaces(); - - final ConfiguredAirbyteCatalog catalog = new ConfiguredAirbyteCatalog() - .withStreams(List.of( - getConfiguredCatalogWithOneStream(getDefaultNamespace()).getStreams().get(0), - streamForTableWithSpaces)); - final List actualMessages = MoreIterators - .toList(source().read(config(), catalog, null)); - - setEmittedAtToNull(actualMessages); - - final List expectedMessages = new ArrayList<>(getTestMessages()); - expectedMessages.addAll(getAirbyteMessagesForTablesWithQuoting(streamForTableWithSpaces)); - - assertEquals(expectedMessages.size(), actualMessages.size()); - assertTrue(expectedMessages.containsAll(actualMessages)); - assertTrue(actualMessages.containsAll(expectedMessages)); - } - - protected List getAirbyteMessagesForTablesWithQuoting(final ConfiguredAirbyteStream streamForTableWithSpaces) { - return getTestMessages() - .stream() - .map(Jsons::clone) - .peek(m -> { - m.getRecord().setStream(streamForTableWithSpaces.getStream().getName()); - ((ObjectNode) m.getRecord().getData()).set(COL_LAST_NAME_WITH_SPACE, - ((ObjectNode) m.getRecord().getData()).remove(COL_NAME)); - ((ObjectNode) m.getRecord().getData()).remove(COL_UPDATED_AT); - ((ObjectNode) m.getRecord().getData()).replace(COL_ID, - convertIdBasedOnDatabase(m.getRecord().getData().get(COL_ID).asInt())); - }) - .collect(Collectors.toList()); - } - - @SuppressWarnings("ResultOfMethodCallIgnored") - @Test - void testReadFailure() { - final ConfiguredAirbyteStream spiedAbStream = spy( - getConfiguredCatalogWithOneStream(getDefaultNamespace()).getStreams().get(0)); - final ConfiguredAirbyteCatalog catalog = new ConfiguredAirbyteCatalog() - .withStreams(List.of(spiedAbStream)); - doCallRealMethod().doThrow(new RuntimeException()).when(spiedAbStream).getStream(); - - assertThrows(RuntimeException.class, () -> source().read(config(), catalog, null)); - } - - @Test - void testIncrementalNoPreviousState() throws Exception { - incrementalCursorCheck( - COL_ID, - null, - "3", - getTestMessages()); - } - - @Test - void testIncrementalIntCheckCursor() throws Exception { - incrementalCursorCheck( - COL_ID, - "2", - "3", - List.of(getTestMessages().get(2))); - } - - @Test - void testIncrementalStringCheckCursor() throws Exception { - incrementalCursorCheck( - COL_NAME, - "patent", - "vash", - List.of(getTestMessages().get(0), getTestMessages().get(2))); - } - - @Test - void testIncrementalStringCheckCursorSpaceInColumnName() throws Exception { - final ConfiguredAirbyteStream streamWithSpaces = createTableWithSpaces(); - - final List expectedRecordMessages = getAirbyteMessagesCheckCursorSpaceInColumnName(streamWithSpaces); - incrementalCursorCheck( - COL_LAST_NAME_WITH_SPACE, - COL_LAST_NAME_WITH_SPACE, - "patent", - "vash", - expectedRecordMessages, - streamWithSpaces); - } - - protected List getAirbyteMessagesCheckCursorSpaceInColumnName(final ConfiguredAirbyteStream streamWithSpaces) { - final AirbyteMessage firstMessage = getTestMessages().get(0); - firstMessage.getRecord().setStream(streamWithSpaces.getStream().getName()); - ((ObjectNode) firstMessage.getRecord().getData()).remove(COL_UPDATED_AT); - ((ObjectNode) firstMessage.getRecord().getData()).set(COL_LAST_NAME_WITH_SPACE, - ((ObjectNode) firstMessage.getRecord().getData()).remove(COL_NAME)); - - final AirbyteMessage secondMessage = getTestMessages().get(2); - secondMessage.getRecord().setStream(streamWithSpaces.getStream().getName()); - ((ObjectNode) 
secondMessage.getRecord().getData()).remove(COL_UPDATED_AT); - ((ObjectNode) secondMessage.getRecord().getData()).set(COL_LAST_NAME_WITH_SPACE, - ((ObjectNode) secondMessage.getRecord().getData()).remove(COL_NAME)); - - return List.of(firstMessage, secondMessage); - } - - @Test - void testIncrementalDateCheckCursor() throws Exception { - incrementalDateCheck(); - } - - protected void incrementalDateCheck() throws Exception { - incrementalCursorCheck( - COL_UPDATED_AT, - "2005-10-18", - "2006-10-19", - List.of(getTestMessages().get(1), getTestMessages().get(2))); - } - - @Test - void testIncrementalCursorChanges() throws Exception { - incrementalCursorCheck( - COL_ID, - COL_NAME, - // cheesing this value a little bit. in the correct implementation this initial cursor value should - // be ignored because the cursor field changed. setting it to a value that if used, will cause - // records to (incorrectly) be filtered out. - "data", - "vash", - getTestMessages()); - } - - @Test - protected void testReadOneTableIncrementallyTwice() throws Exception { - final var config = config(); - final String namespace = getDefaultNamespace(); - final ConfiguredAirbyteCatalog configuredCatalog = getConfiguredCatalogWithOneStream(namespace); - configuredCatalog.getStreams().forEach(airbyteStream -> { - airbyteStream.setSyncMode(SyncMode.INCREMENTAL); - airbyteStream.setCursorField(List.of(COL_ID)); - airbyteStream.setDestinationSyncMode(DestinationSyncMode.APPEND); - }); - - final List actualMessagesFirstSync = MoreIterators - .toList(source().read(config, configuredCatalog, createEmptyState(streamName(), namespace))); - - final Optional stateAfterFirstSyncOptional = actualMessagesFirstSync.stream() - .filter(r -> r.getType() == Type.STATE).findFirst(); - assertTrue(stateAfterFirstSyncOptional.isPresent()); - - executeStatementReadIncrementallyTwice(); - - final List actualMessagesSecondSync = MoreIterators - .toList(source().read(config, configuredCatalog, extractState(stateAfterFirstSyncOptional.get()))); - - assertEquals(2, - (int) actualMessagesSecondSync.stream().filter(r -> r.getType() == Type.RECORD).count()); - final List expectedMessages = getExpectedAirbyteMessagesSecondSync(namespace); - - setEmittedAtToNull(actualMessagesSecondSync); - - assertEquals(expectedMessages.size(), actualMessagesSecondSync.size()); - assertTrue(expectedMessages.containsAll(actualMessagesSecondSync)); - assertTrue(actualMessagesSecondSync.containsAll(expectedMessages)); - } - - protected void executeStatementReadIncrementallyTwice() { - testdb - .with("INSERT INTO %s (id, name, updated_at) VALUES (4, 'riker', '2006-10-19')", getFullyQualifiedTableName(TABLE_NAME)) - .with("INSERT INTO %s (id, name, updated_at) VALUES (5, 'data', '2006-10-19')", getFullyQualifiedTableName(TABLE_NAME)); - } - - protected List getExpectedAirbyteMessagesSecondSync(final String namespace) { - final List expectedMessages = new ArrayList<>(); - expectedMessages.add(new AirbyteMessage().withType(Type.RECORD) - .withRecord(new AirbyteRecordMessage().withStream(streamName()).withNamespace(namespace) - .withData(Jsons.jsonNode(Map - .of(COL_ID, ID_VALUE_4, - COL_NAME, "riker", - COL_UPDATED_AT, "2006-10-19"))))); - expectedMessages.add(new AirbyteMessage().withType(Type.RECORD) - .withRecord(new AirbyteRecordMessage().withStream(streamName()).withNamespace(namespace) - .withData(Jsons.jsonNode(Map - .of(COL_ID, ID_VALUE_5, - COL_NAME, "data", - COL_UPDATED_AT, "2006-10-19"))))); - final DbStreamState state = new DbStreamState() - 
.withStreamName(streamName()) - .withStreamNamespace(namespace) - .withCursorField(List.of(COL_ID)) - .withCursor("5") - .withCursorRecordCount(1L); - expectedMessages.addAll(createExpectedTestMessages(List.of(state), 2L)); - return expectedMessages; - } - - @Test - protected void testReadMultipleTablesIncrementally() throws Exception { - final String tableName2 = TABLE_NAME + 2; - final String streamName2 = streamName() + 2; - final String fqTableName2 = getFullyQualifiedTableName(tableName2); - testdb.with(createTableQuery(fqTableName2, "id INTEGER, name VARCHAR(200)", "")) - .with("INSERT INTO %s(id, name) VALUES (1,'picard')", fqTableName2) - .with("INSERT INTO %s(id, name) VALUES (2, 'crusher')", fqTableName2) - .with("INSERT INTO %s(id, name) VALUES (3, 'vash')", fqTableName2); - - final String namespace = getDefaultNamespace(); - final ConfiguredAirbyteCatalog configuredCatalog = getConfiguredCatalogWithOneStream( - namespace); - configuredCatalog.getStreams().add(CatalogHelpers.createConfiguredAirbyteStream( - streamName2, - namespace, - Field.of(COL_ID, JsonSchemaType.NUMBER), - Field.of(COL_NAME, JsonSchemaType.STRING))); - configuredCatalog.getStreams().forEach(airbyteStream -> { - airbyteStream.setSyncMode(SyncMode.INCREMENTAL); - airbyteStream.setCursorField(List.of(COL_ID)); - airbyteStream.setDestinationSyncMode(DestinationSyncMode.APPEND); - }); - - final List actualMessagesFirstSync = MoreIterators - .toList(source().read(config(), configuredCatalog, createEmptyState(streamName(), namespace))); - - // get last state message. - final Optional stateAfterFirstSyncOptional = actualMessagesFirstSync.stream() - .filter(r -> r.getType() == Type.STATE) - .reduce((first, second) -> second); - assertTrue(stateAfterFirstSyncOptional.isPresent()); - - // we know the second streams messages are the same as the first minus the updated at column. so we - // cheat and generate the expected messages off of the first expected messages. 
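  // In isolation, that clone-and-retarget trick looks roughly like this (a sketch; the real
  // helper, getAirbyteMessagesSecondStreamWithNamespace, is defined a few lines below and
  // additionally normalizes the id column per database):
  //
  //   getTestMessages().stream()
  //       .map(Jsons::clone)                               // deep-copy; keep the originals intact
  //       .peek(m -> m.getRecord().setStream(streamName2)) // repoint at the second stream
  //       .peek(m -> ((ObjectNode) m.getRecord().getData()).remove(COL_UPDATED_AT))
  //       .collect(Collectors.toList());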
- final List secondStreamExpectedMessages = getAirbyteMessagesSecondStreamWithNamespace(streamName2); - - // Represents the state after the first stream has been updated - final List expectedStateStreams1 = List.of( - new DbStreamState() - .withStreamName(streamName()) - .withStreamNamespace(namespace) - .withCursorField(List.of(COL_ID)) - .withCursor("3") - .withCursorRecordCount(1L), - new DbStreamState() - .withStreamName(streamName2) - .withStreamNamespace(namespace) - .withCursorField(List.of(COL_ID))); - - // Represents the state after both streams have been updated - final List expectedStateStreams2 = List.of( - new DbStreamState() - .withStreamName(streamName()) - .withStreamNamespace(namespace) - .withCursorField(List.of(COL_ID)) - .withCursor("3") - .withCursorRecordCount(1L), - new DbStreamState() - .withStreamName(streamName2) - .withStreamNamespace(namespace) - .withCursorField(List.of(COL_ID)) - .withCursor("3") - .withCursorRecordCount(1L)); - - final List expectedMessagesFirstSync = new ArrayList<>(getTestMessages()); - expectedMessagesFirstSync.add(createStateMessage(expectedStateStreams1.get(0), expectedStateStreams1, 3L)); - expectedMessagesFirstSync.addAll(secondStreamExpectedMessages); - expectedMessagesFirstSync.add(createStateMessage(expectedStateStreams2.get(1), expectedStateStreams2, 3L)); - - setEmittedAtToNull(actualMessagesFirstSync); - - assertEquals(expectedMessagesFirstSync.size(), actualMessagesFirstSync.size()); - assertTrue(expectedMessagesFirstSync.containsAll(actualMessagesFirstSync)); - assertTrue(actualMessagesFirstSync.containsAll(expectedMessagesFirstSync)); - } - - protected List getAirbyteMessagesSecondStreamWithNamespace(final String streamName2) { - return getTestMessages() - .stream() - .map(Jsons::clone) - .peek(m -> { - m.getRecord().setStream(streamName2); - ((ObjectNode) m.getRecord().getData()).remove(COL_UPDATED_AT); - ((ObjectNode) m.getRecord().getData()).replace(COL_ID, - convertIdBasedOnDatabase(m.getRecord().getData().get(COL_ID).asInt())); - }) - .collect(Collectors.toList()); - } - - // when initial and final cursor fields are the same. - protected void incrementalCursorCheck( - final String cursorField, - final String initialCursorValue, - final String endCursorValue, - final List expectedRecordMessages) - throws Exception { - incrementalCursorCheck(cursorField, cursorField, initialCursorValue, endCursorValue, - expectedRecordMessages); - } - - // See https://github.com/airbytehq/airbyte/issues/14732 for rationale and details. 
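That issue motivates the cursor_record_count bookkeeping the next test exercises: when several rows share the maximum cursor value, the state also records how many were seen, and a later sync re-reads the cursor value inclusively whenever that count no longer matches the database. A minimal sketch of the rule, with illustrative names rather than the CDK's actual query builder:

// Illustrative only: the decision behind re-emitting rows that share a cursor value.
final class CursorRereadRule {

  private CursorRereadRule() {}

  static String cursorPredicate(final String cursorColumn,
                                final String savedCursor,
                                final long savedCursorRecordCount,
                                final long rowsAtCursorInDb) {
    // A mismatch means rows with the saved cursor value landed after the last sync,
    // so read from the cursor inclusively; this is why c, d, e, f reappear below.
    final String op = rowsAtCursorInDb != savedCursorRecordCount ? ">=" : ">";
    return String.format("%s %s '%s'", cursorColumn, op, savedCursor);
  }

}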
- @Test - public void testIncrementalWithConcurrentInsertion() throws Exception { - final String namespace = getDefaultNamespace(); - final String fullyQualifiedTableName = getFullyQualifiedTableName(TABLE_NAME_AND_TIMESTAMP); - final String columnDefinition = String.format("name VARCHAR(200) NOT NULL, %s %s NOT NULL", COL_TIMESTAMP, COL_TIMESTAMP_TYPE); - - // 1st sync - testdb.with(createTableQuery(fullyQualifiedTableName, columnDefinition, "")) - .with(INSERT_TABLE_NAME_AND_TIMESTAMP_QUERY, fullyQualifiedTableName, "a", "2021-01-01 00:00:00") - .with(INSERT_TABLE_NAME_AND_TIMESTAMP_QUERY, fullyQualifiedTableName, "b", "2021-01-01 00:00:00"); - - final ConfiguredAirbyteCatalog configuredCatalog = CatalogHelpers.toDefaultConfiguredCatalog( - new AirbyteCatalog().withStreams(List.of( - CatalogHelpers.createAirbyteStream( - TABLE_NAME_AND_TIMESTAMP, - namespace, - Field.of(COL_NAME, JsonSchemaType.STRING), - Field.of(COL_TIMESTAMP, JsonSchemaType.STRING_TIMESTAMP_WITHOUT_TIMEZONE))))); - - configuredCatalog.getStreams().forEach(airbyteStream -> { - airbyteStream.setSyncMode(SyncMode.INCREMENTAL); - airbyteStream.setCursorField(List.of(COL_TIMESTAMP)); - airbyteStream.setDestinationSyncMode(DestinationSyncMode.APPEND); - }); - - final List firstSyncActualMessages = MoreIterators.toList( - source().read(config(), configuredCatalog, createEmptyState(TABLE_NAME_AND_TIMESTAMP, namespace))); - - // cursor after 1st sync: 2021-01-01 00:00:00, count 2 - final Optional firstSyncStateOptional = firstSyncActualMessages.stream().filter(r -> r.getType() == Type.STATE).findFirst(); - assertTrue(firstSyncStateOptional.isPresent()); - final JsonNode firstSyncState = getStateData(firstSyncStateOptional.get(), TABLE_NAME_AND_TIMESTAMP); - assertEquals(firstSyncState.get("cursor_field").elements().next().asText(), COL_TIMESTAMP); - assertTrue(firstSyncState.get("cursor").asText().contains("2021-01-01")); - assertTrue(firstSyncState.get("cursor").asText().contains("00:00:00")); - assertEquals(2L, firstSyncState.get("cursor_record_count").asLong()); - - final List firstSyncNames = firstSyncActualMessages.stream() - .filter(r -> r.getType() == Type.RECORD) - .map(r -> r.getRecord().getData().get(COL_NAME).asText()) - .toList(); - // some databases don't make insertion order guarantee when equal ordering value - if (testdb.getDatabaseDriver().equals(DatabaseDriver.TERADATA) || testdb.getDatabaseDriver().equals(DatabaseDriver.ORACLE)) { - assertThat(List.of("a", "b"), Matchers.containsInAnyOrder(firstSyncNames.toArray())); - } else { - assertEquals(List.of("a", "b"), firstSyncNames); - } - - // 2nd sync - testdb.with(INSERT_TABLE_NAME_AND_TIMESTAMP_QUERY, fullyQualifiedTableName, "c", "2021-01-02 00:00:00"); - - final List secondSyncActualMessages = MoreIterators.toList( - source().read(config(), configuredCatalog, createState(TABLE_NAME_AND_TIMESTAMP, namespace, firstSyncState))); - - // cursor after 2nd sync: 2021-01-02 00:00:00, count 1 - final Optional secondSyncStateOptional = secondSyncActualMessages.stream().filter(r -> r.getType() == Type.STATE).findFirst(); - assertTrue(secondSyncStateOptional.isPresent()); - final JsonNode secondSyncState = getStateData(secondSyncStateOptional.get(), TABLE_NAME_AND_TIMESTAMP); - assertEquals(secondSyncState.get("cursor_field").elements().next().asText(), COL_TIMESTAMP); - assertTrue(secondSyncState.get("cursor").asText().contains("2021-01-02")); - assertTrue(secondSyncState.get("cursor").asText().contains("00:00:00")); - assertEquals(1L, 
secondSyncState.get("cursor_record_count").asLong()); - - final List secondSyncNames = secondSyncActualMessages.stream() - .filter(r -> r.getType() == Type.RECORD) - .map(r -> r.getRecord().getData().get(COL_NAME).asText()) - .toList(); - assertEquals(List.of("c"), secondSyncNames); - - // 3rd sync has records with duplicated cursors - testdb.with(INSERT_TABLE_NAME_AND_TIMESTAMP_QUERY, fullyQualifiedTableName, "d", "2021-01-02 00:00:00") - .with(INSERT_TABLE_NAME_AND_TIMESTAMP_QUERY, fullyQualifiedTableName, "e", "2021-01-02 00:00:00") - .with(INSERT_TABLE_NAME_AND_TIMESTAMP_QUERY, fullyQualifiedTableName, "f", "2021-01-03 00:00:00"); - - final List thirdSyncActualMessages = MoreIterators.toList( - source().read(config(), configuredCatalog, createState(TABLE_NAME_AND_TIMESTAMP, namespace, secondSyncState))); - - // Cursor after 3rd sync is: 2021-01-03 00:00:00, count 1. - final Optional thirdSyncStateOptional = thirdSyncActualMessages.stream().filter(r -> r.getType() == Type.STATE).findFirst(); - assertTrue(thirdSyncStateOptional.isPresent()); - final JsonNode thirdSyncState = getStateData(thirdSyncStateOptional.get(), TABLE_NAME_AND_TIMESTAMP); - assertEquals(thirdSyncState.get("cursor_field").elements().next().asText(), COL_TIMESTAMP); - assertTrue(thirdSyncState.get("cursor").asText().contains("2021-01-03")); - assertTrue(thirdSyncState.get("cursor").asText().contains("00:00:00")); - assertEquals(1L, thirdSyncState.get("cursor_record_count").asLong()); - - // The c, d, e, f are duplicated records from this sync, because the cursor - // record count in the database is different from that in the state. - final List thirdSyncExpectedNames = thirdSyncActualMessages.stream() - .filter(r -> r.getType() == Type.RECORD) - .map(r -> r.getRecord().getData().get(COL_NAME).asText()) - .toList(); - - // teradata doesn't make insertion order guarantee when equal ordering value - if (testdb.getDatabaseDriver().equals(DatabaseDriver.TERADATA)) { - assertThat(List.of("c", "d", "e", "f"), Matchers.containsInAnyOrder(thirdSyncExpectedNames.toArray())); - } else { - assertEquals(List.of("c", "d", "e", "f"), thirdSyncExpectedNames); - } - } - - protected JsonNode getStateData(final AirbyteMessage airbyteMessage, final String streamName) { - for (final JsonNode stream : airbyteMessage.getState().getData().get("streams")) { - if (stream.get("stream_name").asText().equals(streamName)) { - return stream; - } - } - throw new IllegalArgumentException("Stream not found in state message: " + streamName); - } - - private void incrementalCursorCheck( - final String initialCursorField, - final String cursorField, - final String initialCursorValue, - final String endCursorValue, - final List expectedRecordMessages) - throws Exception { - incrementalCursorCheck(initialCursorField, cursorField, initialCursorValue, endCursorValue, - expectedRecordMessages, - getConfiguredCatalogWithOneStream(getDefaultNamespace()).getStreams().get(0)); - } - - protected void incrementalCursorCheck( - final String initialCursorField, - final String cursorField, - final String initialCursorValue, - final String endCursorValue, - final List expectedRecordMessages, - final ConfiguredAirbyteStream airbyteStream) - throws Exception { - airbyteStream.setSyncMode(SyncMode.INCREMENTAL); - airbyteStream.setCursorField(List.of(cursorField)); - airbyteStream.setDestinationSyncMode(DestinationSyncMode.APPEND); - - final ConfiguredAirbyteCatalog configuredCatalog = new ConfiguredAirbyteCatalog() - .withStreams(List.of(airbyteStream)); - - final 
DbStreamState dbStreamState = buildStreamState(airbyteStream, initialCursorField, initialCursorValue); - - final List actualMessages = MoreIterators - .toList(source().read(config(), configuredCatalog, Jsons.jsonNode(createState(List.of(dbStreamState))))); - - setEmittedAtToNull(actualMessages); - - final List expectedStreams = List.of(buildStreamState(airbyteStream, cursorField, endCursorValue)); - - final List expectedMessages = new ArrayList<>(expectedRecordMessages); - expectedMessages.addAll(createExpectedTestMessages(expectedStreams, expectedRecordMessages.size())); - - assertEquals(expectedMessages.size(), actualMessages.size()); - assertTrue(expectedMessages.containsAll(actualMessages)); - assertTrue(actualMessages.containsAll(expectedMessages)); - } - - protected DbStreamState buildStreamState(final ConfiguredAirbyteStream configuredAirbyteStream, - final String cursorField, - final String cursorValue) { - return new DbStreamState() - .withStreamName(configuredAirbyteStream.getStream().getName()) - .withStreamNamespace(configuredAirbyteStream.getStream().getNamespace()) - .withCursorField(List.of(cursorField)) - .withCursor(cursorValue) - .withCursorRecordCount(1L); - } - - // get catalog and perform a defensive copy. - protected ConfiguredAirbyteCatalog getConfiguredCatalogWithOneStream(final String defaultNamespace) { - final ConfiguredAirbyteCatalog catalog = CatalogHelpers.toDefaultConfiguredCatalog(getCatalog(defaultNamespace)); - // Filter to only keep the main stream name as configured stream - catalog.withStreams( - catalog.getStreams().stream().filter(s -> s.getStream().getName().equals(streamName())) - .collect(Collectors.toList())); - return catalog; - } - - protected AirbyteCatalog getCatalog(final String defaultNamespace) { - return new AirbyteCatalog().withStreams(List.of( - CatalogHelpers.createAirbyteStream( - TABLE_NAME, - defaultNamespace, - Field.of(COL_ID, JsonSchemaType.INTEGER), - Field.of(COL_NAME, JsonSchemaType.STRING), - Field.of(COL_UPDATED_AT, JsonSchemaType.STRING)) - .withSupportedSyncModes(List.of(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL)) - .withSourceDefinedPrimaryKey(List.of(List.of(COL_ID))), - CatalogHelpers.createAirbyteStream( - TABLE_NAME_WITHOUT_PK, - defaultNamespace, - Field.of(COL_ID, JsonSchemaType.INTEGER), - Field.of(COL_NAME, JsonSchemaType.STRING), - Field.of(COL_UPDATED_AT, JsonSchemaType.STRING)) - .withSupportedSyncModes(List.of(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL)) - .withSourceDefinedPrimaryKey(Collections.emptyList()), - CatalogHelpers.createAirbyteStream( - TABLE_NAME_COMPOSITE_PK, - defaultNamespace, - Field.of(COL_FIRST_NAME, JsonSchemaType.STRING), - Field.of(COL_LAST_NAME, JsonSchemaType.STRING), - Field.of(COL_UPDATED_AT, JsonSchemaType.STRING)) - .withSupportedSyncModes(List.of(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL)) - .withSourceDefinedPrimaryKey( - List.of(List.of(COL_FIRST_NAME), List.of(COL_LAST_NAME))))); - } - - protected List getTestMessages() { - return List.of( - new AirbyteMessage().withType(Type.RECORD) - .withRecord(new AirbyteRecordMessage().withStream(streamName()).withNamespace(getDefaultNamespace()) - .withData(Jsons.jsonNode(Map - .of(COL_ID, ID_VALUE_1, - COL_NAME, "picard", - COL_UPDATED_AT, "2004-10-19")))), - new AirbyteMessage().withType(Type.RECORD) - .withRecord(new AirbyteRecordMessage().withStream(streamName()).withNamespace(getDefaultNamespace()) - .withData(Jsons.jsonNode(Map - .of(COL_ID, ID_VALUE_2, - COL_NAME, "crusher", - COL_UPDATED_AT, - "2005-10-19")))), - new 
AirbyteMessage().withType(Type.RECORD) - .withRecord(new AirbyteRecordMessage().withStream(streamName()).withNamespace(getDefaultNamespace()) - .withData(Jsons.jsonNode(Map - .of(COL_ID, ID_VALUE_3, - COL_NAME, "vash", - COL_UPDATED_AT, "2006-10-19"))))); - } - - protected List createExpectedTestMessages(final List states, final long numRecords) { - return states.stream() - .map(s -> new AirbyteMessage().withType(Type.STATE) - .withState( - new AirbyteStateMessage().withType(AirbyteStateType.STREAM) - .withStream(new AirbyteStreamState() - .withStreamDescriptor(new StreamDescriptor().withNamespace(s.getStreamNamespace()).withName(s.getStreamName())) - .withStreamState(Jsons.jsonNode(s))) - .withData(Jsons.jsonNode(new DbState().withCdc(false).withStreams(states))) - .withSourceStats(new AirbyteStateStats().withRecordCount((double) numRecords)))) - .collect( - Collectors.toList()); - } - - protected List createState(final List states) { - return states.stream() - .map(s -> new AirbyteStateMessage().withType(AirbyteStateType.STREAM) - .withStream(new AirbyteStreamState() - .withStreamDescriptor(new StreamDescriptor().withNamespace(s.getStreamNamespace()).withName(s.getStreamName())) - .withStreamState(Jsons.jsonNode(s)))) - .collect( - Collectors.toList()); - } - - protected ConfiguredAirbyteStream createTableWithSpaces() throws SQLException { - final String tableNameWithSpaces = TABLE_NAME_WITH_SPACES + "2"; - final String streamName2 = tableNameWithSpaces; - - try (final var connection = testdb.getDataSource().getConnection()) { - final String identifierQuoteString = connection.getMetaData().getIdentifierQuoteString(); - connection.createStatement() - .execute( - createTableQuery(getFullyQualifiedTableName( - enquoteIdentifier(tableNameWithSpaces, identifierQuoteString)), - "id INTEGER, " + enquoteIdentifier(COL_LAST_NAME_WITH_SPACE, identifierQuoteString) - + " VARCHAR(200)", - "")); - connection.createStatement() - .execute(String.format("INSERT INTO %s(id, %s) VALUES (1,'picard')", - getFullyQualifiedTableName( - enquoteIdentifier(tableNameWithSpaces, identifierQuoteString)), - enquoteIdentifier(COL_LAST_NAME_WITH_SPACE, identifierQuoteString))); - connection.createStatement() - .execute(String.format("INSERT INTO %s(id, %s) VALUES (2, 'crusher')", - getFullyQualifiedTableName( - enquoteIdentifier(tableNameWithSpaces, identifierQuoteString)), - enquoteIdentifier(COL_LAST_NAME_WITH_SPACE, identifierQuoteString))); - connection.createStatement() - .execute(String.format("INSERT INTO %s(id, %s) VALUES (3, 'vash')", - getFullyQualifiedTableName( - enquoteIdentifier(tableNameWithSpaces, identifierQuoteString)), - enquoteIdentifier(COL_LAST_NAME_WITH_SPACE, identifierQuoteString))); - } - - return CatalogHelpers.createConfiguredAirbyteStream( - streamName2, - getDefaultNamespace(), - Field.of(COL_ID, JsonSchemaType.NUMBER), - Field.of(COL_LAST_NAME_WITH_SPACE, JsonSchemaType.STRING)); - } - - public String getFullyQualifiedTableName(final String tableName) { - return RelationalDbQueryUtils.getFullyQualifiedTableName(getDefaultSchemaName(), tableName); - } - - protected void createSchemas() { - if (supportsSchemas()) { - for (final String schemaName : TEST_SCHEMAS) { - testdb.with("CREATE SCHEMA %s;", schemaName); - } - } - } - - private JsonNode convertIdBasedOnDatabase(final int idValue) { - return switch (testdb.getDatabaseDriver()) { - case ORACLE, SNOWFLAKE -> Jsons.jsonNode(BigDecimal.valueOf(idValue)); - default -> Jsons.jsonNode(idValue); - }; - } - - private String 
getDefaultSchemaName() { - return supportsSchemas() ? SCHEMA_NAME : null; - } - - protected String getDefaultNamespace() { - return switch (testdb.getDatabaseDriver()) { - // mysql does not support schemas, it namespaces using database names instead. - case MYSQL, CLICKHOUSE, TERADATA -> testdb.getDatabaseName(); - default -> SCHEMA_NAME; - }; - } - - protected static void setEmittedAtToNull(final Iterable messages) { - for (final AirbyteMessage actualMessage : messages) { - if (actualMessage.getRecord() != null) { - actualMessage.getRecord().setEmittedAt(null); - } - } - } - - /** - * Creates empty state with the provided stream name and namespace. - * - * @param streamName The stream name. - * @param streamNamespace The stream namespace. - * @return {@link JsonNode} representation of the generated empty state. - */ - protected JsonNode createEmptyState(final String streamName, final String streamNamespace) { - final AirbyteStateMessage airbyteStateMessage = new AirbyteStateMessage() - .withType(AirbyteStateType.STREAM) - .withStream(new AirbyteStreamState().withStreamDescriptor(new StreamDescriptor().withName(streamName).withNamespace(streamNamespace))); - return Jsons.jsonNode(List.of(airbyteStateMessage)); - - } - - protected JsonNode createState(final String streamName, final String streamNamespace, final JsonNode stateData) { - final AirbyteStateMessage airbyteStateMessage = new AirbyteStateMessage() - .withType(AirbyteStateType.STREAM) - .withStream( - new AirbyteStreamState() - .withStreamDescriptor(new StreamDescriptor().withName(streamName).withNamespace(streamNamespace)) - .withStreamState(stateData)); - return Jsons.jsonNode(List.of(airbyteStateMessage)); - } - - protected JsonNode extractState(final AirbyteMessage airbyteMessage) { - return Jsons.jsonNode(List.of(airbyteMessage.getState())); - } - - protected AirbyteMessage createStateMessage(final DbStreamState dbStreamState, final List legacyStates, final long recordCount) { - return new AirbyteMessage().withType(Type.STATE) - .withState( - new AirbyteStateMessage().withType(AirbyteStateType.STREAM) - .withStream(new AirbyteStreamState() - .withStreamDescriptor(new StreamDescriptor().withNamespace(dbStreamState.getStreamNamespace()) - .withName(dbStreamState.getStreamName())) - .withStreamState(Jsons.jsonNode(dbStreamState))) - .withData(Jsons.jsonNode(new DbState().withCdc(false).withStreams(legacyStates))) - .withSourceStats(new AirbyteStateStats().withRecordCount((double) recordCount))); - } - - protected List extractSpecificFieldFromCombinedMessages(final List messages, - final String streamName, - final String field) { - return extractStateMessage(messages).stream() - .filter(s -> s.getStream().getStreamDescriptor().getName().equals(streamName)) - .map(s -> s.getStream().getStreamState().get(field) != null ? 
s.getStream().getStreamState().get(field).asText() : "").toList(); - } - - protected List filterRecords(final List messages) { - return messages.stream().filter(r -> r.getType() == Type.RECORD) - .collect(Collectors.toList()); - } - - protected List extractStateMessage(final List messages) { - return messages.stream().filter(r -> r.getType() == Type.STATE).map(AirbyteMessage::getState) - .collect(Collectors.toList()); - } - - protected List extractStateMessage(final List messages, final String streamName) { - return messages.stream().filter(r -> r.getType() == Type.STATE && - r.getState().getStream().getStreamDescriptor().getName().equals(streamName)).map(AirbyteMessage::getState) - .collect(Collectors.toList()); - } - - protected AirbyteMessage createRecord(final String stream, final String namespace, final Map data) { - return new AirbyteMessage().withType(Type.RECORD) - .withRecord(new AirbyteRecordMessage().withData(Jsons.jsonNode(data)).withStream(stream).withNamespace(namespace)); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/source/jdbc/test/JdbcStressTest.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/source/jdbc/test/JdbcStressTest.java deleted file mode 100644 index 9c626a9ac911..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/source/jdbc/test/JdbcStressTest.java +++ /dev/null @@ -1,226 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.source.jdbc.test; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -import com.fasterxml.jackson.databind.JsonNode; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Lists; -import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; -import io.airbyte.cdk.db.factory.DataSourceFactory; -import io.airbyte.cdk.db.jdbc.DefaultJdbcDatabase; -import io.airbyte.cdk.db.jdbc.JdbcDatabase; -import io.airbyte.cdk.db.jdbc.JdbcUtils; -import io.airbyte.cdk.integrations.source.jdbc.AbstractJdbcSource; -import io.airbyte.commons.json.Jsons; -import io.airbyte.commons.stream.MoreStreams; -import io.airbyte.commons.string.Strings; -import io.airbyte.protocol.models.Field; -import io.airbyte.protocol.models.JsonSchemaType; -import io.airbyte.protocol.models.v0.AirbyteCatalog; -import io.airbyte.protocol.models.v0.AirbyteMessage; -import io.airbyte.protocol.models.v0.AirbyteMessage.Type; -import io.airbyte.protocol.models.v0.AirbyteRecordMessage; -import io.airbyte.protocol.models.v0.CatalogHelpers; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream; -import io.airbyte.protocol.models.v0.DestinationSyncMode; -import io.airbyte.protocol.models.v0.SyncMode; -import java.math.BigDecimal; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.BitSet; -import java.util.Collections; -import java.util.Iterator; -import java.util.List; -import java.util.Optional; -import org.junit.jupiter.api.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Runs a "large" amount of data through a JdbcSource to ensure that it streams / chunks records. - */ -// todo (cgardens) - this needs more love and thought. we should be able to test this without having -// to rewrite so much data. 
it is enough for now to sanity check that our JdbcSources can actually -// handle more data than fits in memory. -@SuppressFBWarnings( - value = {"MS_SHOULD_BE_FINAL"}, - justification = "The static variables are updated in sub classes for convenience, and cannot be final.") -public abstract class JdbcStressTest { - - private static final Logger LOGGER = LoggerFactory.getLogger(JdbcStressTest.class); - - // this will get rounded down to the nearest 1000th. - private static final long TOTAL_RECORDS = 10_000_000L; - private static final int BATCH_SIZE = 1000; - public static String TABLE_NAME = "id_and_name"; - public static String COL_ID = "id"; - public static String COL_NAME = "name"; - public static String COL_ID_TYPE = "BIGINT"; - public static String INSERT_STATEMENT = "(%s,'picard-%s')"; - - private static String streamName; - - private BitSet bitSet; - private JsonNode config; - private AbstractJdbcSource source; - - /** - * These tests write records without specifying a namespace (schema name). They will be written into - * whatever the default schema is for the database. When they are discovered they will be namespaced - * by the schema name (e.g. <schema_name>.<table_name>). Thus the source needs to tell the - * tests what that default schema name is. If the database does not support schemas, then the database - * name should be used instead. - * - * @return name that will be used to namespace the record. - */ - public abstract Optional getDefaultSchemaName(); - - /** - * A valid configuration to connect to a test database. - * - * @return config - */ - public abstract JsonNode getConfig(); - - /** - * Fully qualified class name of the JDBC driver for the database. - * - * @return driver - */ - public abstract String getDriverClass(); - - /** - * An instance of the source that should be tested. - * - * @return source - */ - public abstract AbstractJdbcSource getSource(); - - protected String createTableQuery(final String tableName, final String columnClause) { - return String.format("CREATE TABLE %s(%s)", - tableName, columnClause); - } - - public void setup() throws Exception { - LOGGER.info("running for driver:" + getDriverClass()); - bitSet = new BitSet((int) TOTAL_RECORDS); - - source = getSource(); - streamName = getDefaultSchemaName().map(val -> val + "." + TABLE_NAME).orElse(TABLE_NAME); - config = getConfig(); - - final JsonNode jdbcConfig = source.toDatabaseConfig(config); - final JdbcDatabase database = new DefaultJdbcDatabase( - DataSourceFactory.create( - jdbcConfig.get(JdbcUtils.USERNAME_KEY).asText(), - jdbcConfig.has(JdbcUtils.PASSWORD_KEY) ? 
jdbcConfig.get(JdbcUtils.PASSWORD_KEY).asText() : null, - getDriverClass(), - jdbcConfig.get(JdbcUtils.JDBC_URL_KEY).asText())); - - database.execute(connection -> connection.createStatement().execute( - createTableQuery("id_and_name", String.format("id %s, name VARCHAR(200)", COL_ID_TYPE)))); - final long batchCount = TOTAL_RECORDS / BATCH_SIZE; - LOGGER.info("writing {} batches of {}", batchCount, BATCH_SIZE); - for (int i = 0; i < batchCount; i++) { - if (i % 1000 == 0) - LOGGER.info("writing batch: " + i); - final List insert = new ArrayList<>(); - for (int j = 0; j < BATCH_SIZE; j++) { - final int recordNumber = (i * BATCH_SIZE) + j; - insert.add(String.format(INSERT_STATEMENT, recordNumber, recordNumber)); - } - - final String sql = prepareInsertStatement(insert); - database.execute(connection -> connection.createStatement().execute(sql)); - } - - } - - // todo (cgardens) - restructure these tests so that testFullRefresh() and testIncremental() can be - // separate tests. currently constrained by only wanting to set up the fixture in the database once, - // but it is not trivial to move them to @BeforeAll because it is static and we are doing - // inheritance. Not impossible, just needs to be done thoughtfully and for all JdbcSources. - @Test - public void stressTest() throws Exception { - testFullRefresh(); - testIncremental(); - } - - private void testFullRefresh() throws Exception { - runTest(getConfiguredCatalogFullRefresh(), "full_refresh"); - } - - private void testIncremental() throws Exception { - runTest(getConfiguredCatalogIncremental(), "incremental"); - } - - private void runTest(final ConfiguredAirbyteCatalog configuredCatalog, final String testName) throws Exception { - LOGGER.info("running stress test for: " + testName); - final Iterator read = source.read(config, configuredCatalog, Jsons.jsonNode(Collections.emptyMap())); - final long actualCount = MoreStreams.toStream(read) - .filter(m -> m.getType() == Type.RECORD) - .peek(m -> { - if (m.getRecord().getData().get(COL_ID).asLong() % 100000 == 0) { - LOGGER.info("reading batch: " + m.getRecord().getData().get(COL_ID).asLong() / 1000); - } - }) - .peek(m -> assertExpectedMessage(m)) - .count(); - final long expectedRoundedRecordsCount = TOTAL_RECORDS - TOTAL_RECORDS % 1000; - LOGGER.info("expected records count: " + TOTAL_RECORDS); - LOGGER.info("actual records count: " + actualCount); - assertEquals(expectedRoundedRecordsCount, actualCount, "testing: " + testName); - assertEquals(expectedRoundedRecordsCount, bitSet.cardinality(), "testing: " + testName); - } - - // each is roughly 106 bytes. - private void assertExpectedMessage(final AirbyteMessage actualMessage) { - final long recordNumber = actualMessage.getRecord().getData().get(COL_ID).asLong(); - bitSet.set((int) recordNumber); - actualMessage.getRecord().setEmittedAt(null); - - final Number expectedRecordNumber = - getDriverClass().toLowerCase().contains("oracle") ? 
new BigDecimal(recordNumber) - : recordNumber; - - final AirbyteMessage expectedMessage = new AirbyteMessage().withType(Type.RECORD) - .withRecord(new AirbyteRecordMessage().withStream(streamName) - .withData(Jsons.jsonNode( - ImmutableMap.of(COL_ID, expectedRecordNumber, COL_NAME, "picard-" + recordNumber)))); - assertEquals(expectedMessage, actualMessage); - } - - private static ConfiguredAirbyteCatalog getConfiguredCatalogFullRefresh() { - return CatalogHelpers.toDefaultConfiguredCatalog(getCatalog()); - } - - private static ConfiguredAirbyteCatalog getConfiguredCatalogIncremental() { - return new ConfiguredAirbyteCatalog() - .withStreams(Collections.singletonList(new ConfiguredAirbyteStream().withStream(getCatalog().getStreams().get(0)) - .withCursorField(Collections.singletonList(COL_ID)) - .withSyncMode(SyncMode.INCREMENTAL) - .withDestinationSyncMode(DestinationSyncMode.APPEND))); - } - - private static AirbyteCatalog getCatalog() { - return new AirbyteCatalog().withStreams(Lists.newArrayList(CatalogHelpers.createAirbyteStream( - streamName, - Field.of(COL_ID, JsonSchemaType.NUMBER), - Field.of(COL_NAME, JsonSchemaType.STRING)) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL)))); - } - - private String prepareInsertStatement(final List inserts) { - if (getDriverClass().toLowerCase().contains("oracle")) { - return String.format("INSERT ALL %s SELECT * FROM dual", Strings.join(inserts, " ")); - } - return String.format("INSERT INTO id_and_name (id, name) VALUES %s", Strings.join(inserts, ", ")); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/AbstractSourceConnectorTest.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/AbstractSourceConnectorTest.java deleted file mode 100644 index 2393c4dcc595..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/AbstractSourceConnectorTest.java +++ /dev/null @@ -1,291 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
- */ - -package io.airbyte.cdk.integrations.standardtest.source; - -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import com.fasterxml.jackson.databind.JsonNode; -import io.airbyte.api.client.AirbyteApiClient; -import io.airbyte.api.client.generated.SourceApi; -import io.airbyte.api.client.model.generated.DiscoverCatalogResult; -import io.airbyte.api.client.model.generated.SourceDiscoverSchemaWriteRequestBody; -import io.airbyte.commons.features.EnvVariableFeatureFlags; -import io.airbyte.commons.features.FeatureFlags; -import io.airbyte.commons.json.Jsons; -import io.airbyte.configoss.JobGetSpecConfig; -import io.airbyte.configoss.StandardCheckConnectionInput; -import io.airbyte.configoss.StandardCheckConnectionOutput; -import io.airbyte.configoss.StandardDiscoverCatalogInput; -import io.airbyte.configoss.State; -import io.airbyte.configoss.WorkerSourceConfig; -import io.airbyte.protocol.models.v0.AirbyteCatalog; -import io.airbyte.protocol.models.v0.AirbyteMessage; -import io.airbyte.protocol.models.v0.AirbyteMessage.Type; -import io.airbyte.protocol.models.v0.AirbyteRecordMessage; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog; -import io.airbyte.protocol.models.v0.ConnectorSpecification; -import io.airbyte.workers.exception.TestHarnessException; -import io.airbyte.workers.general.DefaultCheckConnectionTestHarness; -import io.airbyte.workers.general.DefaultDiscoverCatalogTestHarness; -import io.airbyte.workers.general.DefaultGetSpecTestHarness; -import io.airbyte.workers.helper.CatalogClientConverters; -import io.airbyte.workers.helper.ConnectorConfigUpdater; -import io.airbyte.workers.helper.EntrypointEnvChecker; -import io.airbyte.workers.internal.AirbyteSource; -import io.airbyte.workers.internal.DefaultAirbyteSource; -import io.airbyte.workers.process.AirbyteIntegrationLauncher; -import io.airbyte.workers.process.DockerProcessFactory; -import io.airbyte.workers.process.ProcessFactory; -import java.nio.file.Files; -import java.nio.file.Path; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.UUID; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.mockito.ArgumentCaptor; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * This abstract class contains helpful functionality and boilerplate for testing a source - * connector. 
- */ -public abstract class AbstractSourceConnectorTest { - - protected static final Logger LOGGER = LoggerFactory.getLogger(AbstractSourceConnectorTest.class); - private TestDestinationEnv environment; - private Path jobRoot; - protected Path localRoot; - private ProcessFactory processFactory; - - private static final String JOB_ID = String.valueOf(0L); - private static final int JOB_ATTEMPT = 0; - - private static final UUID CATALOG_ID = UUID.randomUUID(); - - private static final UUID SOURCE_ID = UUID.randomUUID(); - - private static final String CPU_REQUEST_FIELD_NAME = "cpuRequest"; - private static final String CPU_LIMIT_FIELD_NAME = "cpuLimit"; - private static final String MEMORY_REQUEST_FIELD_NAME = "memoryRequest"; - private static final String MEMORY_LIMIT_FIELD_NAME = "memoryLimit"; - - /** - * Name of the docker image that the tests will run against. - * - * @return docker image name - */ - protected abstract String getImageName(); - - /** - * Configuration specific to the integration. Will be passed to integration where appropriate in - * each test. Should be valid. - * - * @return integration-specific configuration - */ - protected abstract JsonNode getConfig() throws Exception; - - /** - * Function that performs any setup of external resources required for the test. e.g. instantiate a - * postgres database. This function will be called before EACH test. - * - * @param environment - information about the test environment. - * @throws Exception - can throw any exception, test framework will handle. - */ - protected abstract void setupEnvironment(TestDestinationEnv environment) throws Exception; - - /** - * Function that performs any clean up of external resources required for the test. e.g. delete a - * postgres database. This function will be called after EACH test. It MUST remove all data in the - * destination so that there is no contamination across tests. - * - * @param testEnv - information about the test environment. - * @throws Exception - can throw any exception, test framework will handle. 
- */ - protected abstract void tearDown(TestDestinationEnv testEnv) throws Exception; - - private AirbyteApiClient mAirbyteApiClient; - - private SourceApi mSourceApi; - - private ConnectorConfigUpdater mConnectorConfigUpdater; - - protected AirbyteCatalog getLastPersistedCatalog() { - return convertProtocolObject( - CatalogClientConverters.toAirbyteProtocol(discoverWriteRequest.getValue().getCatalog()), AirbyteCatalog.class); - } - - private final ArgumentCaptor discoverWriteRequest = - ArgumentCaptor.forClass(SourceDiscoverSchemaWriteRequestBody.class); - - @BeforeEach - public void setUpInternal() throws Exception { - final Path testDir = Path.of("/tmp/airbyte_tests/"); - Files.createDirectories(testDir); - final Path workspaceRoot = Files.createTempDirectory(testDir, "test"); - jobRoot = Files.createDirectories(Path.of(workspaceRoot.toString(), "job")); - localRoot = Files.createTempDirectory(testDir, "output"); - environment = new TestDestinationEnv(localRoot); - setupEnvironment(environment); - mAirbyteApiClient = mock(AirbyteApiClient.class); - mSourceApi = mock(SourceApi.class); - when(mAirbyteApiClient.getSourceApi()).thenReturn(mSourceApi); - when(mSourceApi.writeDiscoverCatalogResult(any())) - .thenReturn(new DiscoverCatalogResult().catalogId(CATALOG_ID)); - mConnectorConfigUpdater = mock(ConnectorConfigUpdater.class); - var envMap = new HashMap<>(new TestEnvConfigs().getJobDefaultEnvMap()); - envMap.put(EnvVariableFeatureFlags.DEPLOYMENT_MODE, featureFlags().deploymentMode()); - processFactory = new DockerProcessFactory( - workspaceRoot, - workspaceRoot.toString(), - localRoot.toString(), - "host", - envMap); - - postSetup(); - } - - /** - * Override this method if you want to do any per-test setup that depends on being able to e.g. - * {@link #runRead(ConfiguredAirbyteCatalog)}. 
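As an aside, a minimal concrete subclass of this harness only has to supply the four hooks above (getImageName, getConfig, setupEnvironment, tearDown). The sketch below is hypothetical and not part of this PR: the image tag, the Testcontainers dependency, and the exact config keys are all assumptions.

import com.fasterxml.jackson.databind.JsonNode;
import io.airbyte.cdk.db.jdbc.JdbcUtils;
import io.airbyte.commons.json.Jsons;
import java.util.Map;
import org.testcontainers.containers.PostgreSQLContainer;

// Hypothetical subclass for illustration only; names and config keys are assumed.
public class ExamplePostgresSourceConnectorTest extends AbstractSourceConnectorTest {

  // Assumes the Testcontainers dependency is on the test classpath.
  private PostgreSQLContainer<?> container;

  @Override
  protected String getImageName() {
    return "airbyte/source-postgres:dev"; // illustrative image tag
  }

  @Override
  protected JsonNode getConfig() {
    // Build a connector config pointing at the throwaway database.
    return Jsons.jsonNode(Map.of(
        JdbcUtils.HOST_KEY, container.getHost(),
        JdbcUtils.PORT_KEY, container.getFirstMappedPort(),
        JdbcUtils.DATABASE_KEY, container.getDatabaseName(),
        JdbcUtils.USERNAME_KEY, container.getUsername(),
        JdbcUtils.PASSWORD_KEY, container.getPassword()));
  }

  @Override
  protected void setupEnvironment(final TestDestinationEnv environment) {
    container = new PostgreSQLContainer<>("postgres:13-alpine");
    container.start();
  }

  @Override
  protected void tearDown(final TestDestinationEnv testEnv) {
    container.stop();
  }
}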
- */ - protected void postSetup() throws Exception {} - - @AfterEach - public void tearDownInternal() throws Exception { - tearDown(environment); - } - - protected FeatureFlags featureFlags() { - return new EnvVariableFeatureFlags(); - } - - protected ConnectorSpecification runSpec() throws TestHarnessException { - final io.airbyte.protocol.models.ConnectorSpecification spec = new DefaultGetSpecTestHarness( - new AirbyteIntegrationLauncher(JOB_ID, JOB_ATTEMPT, getImageName(), processFactory, null, null, false, - featureFlags())) - .run(new JobGetSpecConfig().withDockerImage(getImageName()), jobRoot).getSpec(); - return convertProtocolObject(spec, ConnectorSpecification.class); - } - - protected StandardCheckConnectionOutput runCheck() throws Exception { - return new DefaultCheckConnectionTestHarness( - new AirbyteIntegrationLauncher(JOB_ID, JOB_ATTEMPT, getImageName(), processFactory, null, null, false, - featureFlags()), - mConnectorConfigUpdater) - .run(new StandardCheckConnectionInput().withConnectionConfiguration(getConfig()), jobRoot).getCheckConnection(); - } - - protected String runCheckAndGetStatusAsString(final JsonNode config) throws Exception { - return new DefaultCheckConnectionTestHarness( - new AirbyteIntegrationLauncher(JOB_ID, JOB_ATTEMPT, getImageName(), processFactory, null, null, false, - featureFlags()), - mConnectorConfigUpdater) - .run(new StandardCheckConnectionInput().withConnectionConfiguration(config), jobRoot).getCheckConnection().getStatus().toString(); - } - - protected UUID runDiscover() throws Exception { - final UUID toReturn = new DefaultDiscoverCatalogTestHarness( - mAirbyteApiClient, - new AirbyteIntegrationLauncher(JOB_ID, JOB_ATTEMPT, getImageName(), processFactory, null, null, false, - featureFlags()), - mConnectorConfigUpdater) - .run(new StandardDiscoverCatalogInput().withSourceId(SOURCE_ID.toString()).withConnectionConfiguration(getConfig()), jobRoot) - .getDiscoverCatalogId(); - verify(mSourceApi).writeDiscoverCatalogResult(discoverWriteRequest.capture()); - return toReturn; - } - - protected void checkEntrypointEnvVariable() throws Exception { - final String entrypoint = EntrypointEnvChecker.getEntrypointEnvVariable( - processFactory, - JOB_ID, - JOB_ATTEMPT, - jobRoot, - getImageName()); - - assertNotNull(entrypoint); - assertFalse(entrypoint.isBlank()); - } - - protected List runRead(final ConfiguredAirbyteCatalog configuredCatalog) throws Exception { - return runRead(configuredCatalog, null); - } - - // todo (cgardens) - assume no state since we are all full refresh right now. - protected List runRead(final ConfiguredAirbyteCatalog catalog, final JsonNode state) throws Exception { - final WorkerSourceConfig sourceConfig = new WorkerSourceConfig() - .withSourceConnectionConfiguration(getConfig()) - .withState(state == null ? 
null : new State().withState(state)) - .withCatalog(convertProtocolObject(catalog, io.airbyte.protocol.models.ConfiguredAirbyteCatalog.class)); - - final AirbyteSource source = new DefaultAirbyteSource( - new AirbyteIntegrationLauncher(JOB_ID, JOB_ATTEMPT, getImageName(), processFactory, null, null, false, - featureFlags()), - featureFlags()); - final List messages = new ArrayList<>(); - source.start(sourceConfig, jobRoot); - while (!source.isFinished()) { - source.attemptRead().ifPresent(m -> messages.add(convertProtocolObject(m, AirbyteMessage.class))); - } - source.close(); - - return messages; - } - - protected Map runReadVerifyNumberOfReceivedMsgs(final ConfiguredAirbyteCatalog catalog, - final JsonNode state, - final Map mapOfExpectedRecordsCount) - throws Exception { - - final WorkerSourceConfig sourceConfig = new WorkerSourceConfig() - .withSourceConnectionConfiguration(getConfig()) - .withState(state == null ? null : new State().withState(state)) - .withCatalog(convertProtocolObject(catalog, io.airbyte.protocol.models.ConfiguredAirbyteCatalog.class)); - - final AirbyteSource source = prepareAirbyteSource(); - source.start(sourceConfig, jobRoot); - - while (!source.isFinished()) { - final Optional airbyteMessageOptional = source.attemptRead().map(m -> convertProtocolObject(m, AirbyteMessage.class)); - if (airbyteMessageOptional.isPresent() && airbyteMessageOptional.get().getType().equals(Type.RECORD)) { - final AirbyteMessage airbyteMessage = airbyteMessageOptional.get(); - final AirbyteRecordMessage record = airbyteMessage.getRecord(); - - final String streamName = record.getStream(); - mapOfExpectedRecordsCount.put(streamName, mapOfExpectedRecordsCount.get(streamName) - 1); - } - } - source.close(); - return mapOfExpectedRecordsCount; - } - - private AirbyteSource prepareAirbyteSource() { - final var integrationLauncher = new AirbyteIntegrationLauncher( - JOB_ID, - JOB_ATTEMPT, - getImageName(), - processFactory, - null, - null, - false, - featureFlags()); - return new DefaultAirbyteSource(integrationLauncher, featureFlags()); - } - - private static V0 convertProtocolObject(final V1 v1, final Class klass) { - return Jsons.object(Jsons.jsonNode(v1), klass); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/AbstractSourceDatabaseTypeTest.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/AbstractSourceDatabaseTypeTest.java deleted file mode 100644 index 9dcb95773cdb..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/AbstractSourceDatabaseTypeTest.java +++ /dev/null @@ -1,374 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
- */ - -package io.airbyte.cdk.integrations.standardtest.source; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import com.fasterxml.jackson.databind.JsonNode; -import com.google.common.collect.Lists; -import io.airbyte.cdk.db.Database; -import io.airbyte.commons.json.Jsons; -import io.airbyte.protocol.models.Field; -import io.airbyte.protocol.models.JsonSchemaType; -import io.airbyte.protocol.models.v0.AirbyteMessage; -import io.airbyte.protocol.models.v0.AirbyteMessage.Type; -import io.airbyte.protocol.models.v0.AirbyteStateMessage; -import io.airbyte.protocol.models.v0.AirbyteStream; -import io.airbyte.protocol.models.v0.CatalogHelpers; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream; -import io.airbyte.protocol.models.v0.DestinationSyncMode; -import io.airbyte.protocol.models.v0.SyncMode; -import java.io.IOException; -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; -import org.apache.commons.lang3.StringUtils; -import org.junit.jupiter.api.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * This abstract class contains common helpers and boilerplate for comprehensively testing that all - * data types in a source can be read and handled correctly by the connector and within Airbyte's - * type system. - */ -public abstract class AbstractSourceDatabaseTypeTest extends AbstractSourceConnectorTest { - - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractSourceDatabaseTypeTest.class); - - protected final List testDataHolders = new ArrayList<>(); - protected Database database; - - /** - * The column name will be used for a PK column in the test tables. Override it if the default name is - * not valid for your source. - * - * @return Id column name - */ - protected String getIdColumnName() { - return "id"; - } - - /** - * The column name will be used for a test column in the test tables. Override it if the default name is - * not valid for your source. - * - * @return Test column name - */ - protected String getTestColumnName() { - return "test_column"; - } - - /** - * Set up the test database. All tables and data described in the registered tests will be put there. - * - * @return configured test database - * @throws Exception - might throw any exception during initialization. - */ - protected abstract Database setupDatabase() throws Exception; - - /** - * Put all required tests here using method {@link #addDataTypeTestData(TestDataHolder)} - */ - protected abstract void initTests(); - - @Override - protected void setupEnvironment(final TestDestinationEnv environment) throws Exception { - database = setupDatabase(); - initTests(); - createTables(); - populateTables(); - } - - /** - * Provide a source namespace. It is the allocated place for table creation. It is also known as a - * "Database Schema" or "Dataset". - * - * @return source name space - */ - protected abstract String getNameSpace(); - - /** - * Test the 'discover' command. TODO (liren): Some existing databases may fail testDataTypes(), so - * it is turned off by default. It should be enabled for all databases eventually. 
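To make the initTests() contract above concrete: registrations typically go through the CDK's TestDataHolder builder. A minimal sketch follows; the SQL type, insert values, and expected values are invented, and the exact builder methods may vary by CDK version.

// Sketch of an initTests() implementation in a subclass; values are illustrative.
@Override
protected void initTests() {
  addDataTypeTestData(
      TestDataHolder.builder()
          .sourceType("smallint")                       // database column type under test
          .airbyteType(JsonSchemaType.INTEGER)          // expected Airbyte type mapping
          .addInsertValues("-32768", "32767", "null")   // literals inserted into the table
          .addExpectedValues("-32768", "32767", null)   // values expected back from the sync
          .build());
}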
- */ - protected boolean testCatalog() { - return false; - } - - /** - * The test checks that the types from the catalog match the ones discovered from the source. This - * test is disabled by default. To enable it you need to override the testCatalog() function. - */ - @Test - @SuppressWarnings("unchecked") - public void testDataTypes() throws Exception { - if (testCatalog()) { - runDiscover(); - final Map streams = getLastPersistedCatalog().getStreams().stream() - .collect(Collectors.toMap(AirbyteStream::getName, s -> s)); - - // testDataHolders should be initialized using the `addDataTypeTestData` function - testDataHolders.forEach(testDataHolder -> { - final AirbyteStream airbyteStream = streams.get(testDataHolder.getNameWithTestPrefix()); - final Map jsonSchemaTypeMap = (Map) Jsons.deserialize( - airbyteStream.getJsonSchema().get("properties").get(getTestColumnName()).toString(), Map.class); - assertEquals(testDataHolder.getAirbyteType().getJsonSchemaTypeMap(), jsonSchemaTypeMap, - "Expected column type for " + testDataHolder.getNameWithTestPrefix()); - }); - } - } - - /** - * The test checks that the connector can fetch prepared data without failure. It uses a prepared - * catalog and reads the source using that catalog, then makes sure that the expected values are the - * ones inserted in the source. - */ - @Test - public void testDataContent() throws Exception { - // Class used to make error reporting easier - class MissedRecords { - - // Stream that is missing any value - public String streamName; - // Values that have not been gathered from the source - public List missedValues; - - public MissedRecords(String streamName, List missedValues) { - this.streamName = streamName; - this.missedValues = missedValues; - } - - } - - class UnexpectedRecord { - - public final String streamName; - public final String unexpectedValue; - - public UnexpectedRecord(String streamName, String unexpectedValue) { - this.streamName = streamName; - this.unexpectedValue = unexpectedValue; - } - - } - - final ConfiguredAirbyteCatalog catalog = getConfiguredCatalog(); - final List allMessages = runRead(catalog); - - final List recordMessages = allMessages.stream().filter(m -> m.getType() == Type.RECORD).toList(); - final Map> expectedValues = new HashMap<>(); - final Map> missedValuesByStream = new HashMap<>(); - final Map> unexpectedValuesByStream = new HashMap<>(); - final Map testByName = new HashMap<>(); - - // If there is no expected value in the test set we don't include it in the list to be asserted - // (even if the table contains records) - testDataHolders.forEach(testDataHolder -> { - if (!testDataHolder.getExpectedValues().isEmpty()) { - expectedValues.put(testDataHolder.getNameWithTestPrefix(), testDataHolder.getExpectedValues()); - testByName.put(testDataHolder.getNameWithTestPrefix(), testDataHolder); - } else { - LOGGER.warn("Missing expected values for type: " + testDataHolder.getSourceType()); - } - }); - - for (final AirbyteMessage message : recordMessages) { - final String streamName = message.getRecord().getStream(); - final List expectedValuesForStream = expectedValues.get(streamName); - if (expectedValuesForStream != null) { - final String value = getValueFromJsonNode(message.getRecord().getData().get(getTestColumnName())); - if (!expectedValuesForStream.contains(value)) { - unexpectedValuesByStream.putIfAbsent(streamName, new ArrayList<>()); - unexpectedValuesByStream.get(streamName).add(new UnexpectedRecord(streamName, value)); - } else { - 
expectedValuesForStream.remove(value); - } - } - } - - // Gather all the missing values, so we don't stop the test at the first missed one - expectedValues.forEach((streamName, values) -> { - if (!values.isEmpty()) { - missedValuesByStream.putIfAbsent(streamName, new ArrayList<>()); - missedValuesByStream.get(streamName).add(new MissedRecords(streamName, values)); - } - }); - - Map> errorsByStream = new HashMap<>(); - for (String streamName : unexpectedValuesByStream.keySet()) { - errorsByStream.putIfAbsent(streamName, new ArrayList<>()); - TestDataHolder test = testByName.get(streamName); - List unexpectedValues = unexpectedValuesByStream.get(streamName); - for (UnexpectedRecord unexpectedValue : unexpectedValues) { - errorsByStream.get(streamName).add( - "The stream '%s' checking type '%s' initialized at %s got unexpected values: %s".formatted(streamName, test.getSourceType(), - test.getDeclarationLocation(), unexpectedValue)); - } - } - - for (String streamName : missedValuesByStream.keySet()) { - errorsByStream.putIfAbsent(streamName, new ArrayList<>()); - TestDataHolder test = testByName.get(streamName); - List missedValues = missedValuesByStream.get(streamName); - for (MissedRecords missedValue : missedValues) { - errorsByStream.get(streamName).add( - "The stream '%s' checking type '%s' initialized at %s is missing values: %s".formatted(streamName, test.getSourceType(), - test.getDeclarationLocation(), missedValue)); - } - } - - List errorStrings = new ArrayList<>(); - for (List errors : errorsByStream.values()) { - errorStrings.add(StringUtils.join(errors, "\n")); - } - - assertTrue(errorsByStream.isEmpty(), StringUtils.join(errorStrings, "\n")); - } - - protected String getValueFromJsonNode(final JsonNode jsonNode) throws IOException { - if (jsonNode != null) { - if (jsonNode.isArray()) { - return jsonNode.toString(); - } - - String value = (jsonNode.isBinary() ? Arrays.toString(jsonNode.binaryValue()) : jsonNode.asText()); - value = (value != null && value.equals("null") ? null : value); - return value; - } - return null; - } - - /** - * Creates all tables and inserts the data described in the registered data type tests. - * - * @throws Exception might raise an exception if the configuration goes wrong or the table creation/insert - * scripts fail. - */ - - protected void createTables() throws Exception { - for (final TestDataHolder test : testDataHolders) { - database.query(ctx -> { - ctx.fetch(test.getCreateSqlQuery()); - LOGGER.info("Table {} is created.", test.getNameWithTestPrefix()); - return null; - }); - } - } - - protected void populateTables() throws Exception { - for (final TestDataHolder test : testDataHolders) { - database.query(ctx -> { - test.getInsertSqlQueries().forEach(ctx::fetch); - LOGGER.info("Inserted {} rows in table {}", test.getInsertSqlQueries().size(), test.getNameWithTestPrefix()); - - return null; - }); - } - } - - /** - * Configures streams for all registered data type tests. 
- * - * @return configured catalog - */ - protected ConfiguredAirbyteCatalog getConfiguredCatalog() { - return new ConfiguredAirbyteCatalog().withStreams( - testDataHolders - .stream() - .map(test -> new ConfiguredAirbyteStream() - .withSyncMode(SyncMode.INCREMENTAL) - .withCursorField(Lists.newArrayList(getIdColumnName())) - .withDestinationSyncMode(DestinationSyncMode.APPEND) - .withStream(CatalogHelpers.createAirbyteStream( - String.format("%s", test.getNameWithTestPrefix()), - String.format("%s", getNameSpace()), - Field.of(getIdColumnName(), JsonSchemaType.INTEGER), - Field.of(getTestColumnName(), test.getAirbyteType())) - .withSourceDefinedCursor(true) - .withSourceDefinedPrimaryKey(List.of(List.of(getIdColumnName()))) - .withSupportedSyncModes( - Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL)))) - .collect(Collectors.toList())); - } - - /** - * Register your test in the run scope. A table with one column of the specified type will be created - * for each test. Note: if you register more than one test with the same type name, they will be - * run as independent tests with their own streams. - * - * @param test comprehensive data type test - */ - public void addDataTypeTestData(final TestDataHolder test) { - testDataHolders.add(test); - test.setTestNumber(testDataHolders.stream().filter(t -> t.getSourceType().equals(test.getSourceType())).count()); - test.setNameSpace(getNameSpace()); - test.setIdColumnName(getIdColumnName()); - test.setTestColumnName(getTestColumnName()); - test.setDeclarationLocation(Thread.currentThread().getStackTrace()); - } - - private String formatCollection(final Collection collection) { - return collection.stream().map(s -> "`" + s + "`").collect(Collectors.joining(", ")); - } - - /** - * Builds a table with all registered test cases with values using Markdown syntax (can be used on - * GitHub). 
- * - * @return formatted list of test cases - */ - public String getMarkdownTestTable() { - final StringBuilder table = new StringBuilder() - .append("|**Data Type**|**Insert values**|**Expected values**|**Comment**|**Common test result**|\n") - .append("|----|----|----|----|----|\n"); - - testDataHolders.forEach(test -> table.append(String.format("| %s | %s | %s | %s | %s |\n", - test.getSourceType(), - formatCollection(test.getValues()), - formatCollection(test.getExpectedValues()), - "", - "Ok"))); - return table.toString(); - } - - protected void printMarkdownTestTable() { - LOGGER.info(getMarkdownTestTable()); - } - - protected ConfiguredAirbyteStream createDummyTableWithData(final Database database) throws SQLException { - database.query(ctx -> { - ctx.fetch("CREATE TABLE " + getNameSpace() + ".random_dummy_table(id INTEGER PRIMARY KEY, test_column VARCHAR(63));"); - ctx.fetch("INSERT INTO " + getNameSpace() + ".random_dummy_table VALUES (2, 'Random Data');"); - return null; - }); - - return new ConfiguredAirbyteStream().withSyncMode(SyncMode.INCREMENTAL) - .withCursorField(Lists.newArrayList("id")) - .withDestinationSyncMode(DestinationSyncMode.APPEND) - .withStream(CatalogHelpers.createAirbyteStream( - "random_dummy_table", - getNameSpace(), - Field.of("id", JsonSchemaType.INTEGER), - Field.of("test_column", JsonSchemaType.STRING)) - .withSourceDefinedCursor(true) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL)) - .withSourceDefinedPrimaryKey(List.of(List.of("id")))); - - } - - protected List extractStateMessages(final List messages) { - return messages.stream().filter(r -> r.getType() == Type.STATE).map(AirbyteMessage::getState) - .collect(Collectors.toList()); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/PythonSourceAcceptanceTest.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/PythonSourceAcceptanceTest.java deleted file mode 100644 index f5caa2ad9978..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/PythonSourceAcceptanceTest.java +++ /dev/null @@ -1,144 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.standardtest.source; - -import static org.junit.jupiter.api.Assertions.assertTrue; - -import com.fasterxml.jackson.databind.JsonNode; -import com.google.common.collect.Lists; -import com.google.common.collect.Streams; -import io.airbyte.commons.io.IOs; -import io.airbyte.commons.io.LineGobbler; -import io.airbyte.commons.json.Jsons; -import io.airbyte.protocol.models.v0.AirbyteMessage; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog; -import io.airbyte.protocol.models.v0.ConnectorSpecification; -import io.airbyte.workers.TestHarnessUtils; -import java.io.IOException; -import java.nio.file.Files; -import java.nio.file.Path; -import java.util.List; -import java.util.concurrent.TimeUnit; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Extends TestSource such that it can be called using resources pulled from the file system. Will - * also add the ability to execute arbitrary scripts in the next version. 
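The class below is parameterized through two public static fields rather than constructor arguments, so a runner has to set them before JUnit executes the inherited test cases. A hypothetical wiring (both image names are placeholders, not from this PR):

// Hypothetical harness wiring; both image names are invented placeholders.
@BeforeAll
static void configureHarness() {
  PythonSourceAcceptanceTest.IMAGE_NAME = "airbyte/source-example:dev";
  PythonSourceAcceptanceTest.PYTHON_CONTAINER_NAME = "airbyte/source-acceptance-test:dev";
}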
- */ -public class PythonSourceAcceptanceTest extends SourceAcceptanceTest { - - private static final Logger LOGGER = LoggerFactory.getLogger(PythonSourceAcceptanceTest.class); - private static final String OUTPUT_FILENAME = "output.json"; - - public static String IMAGE_NAME; - public static String PYTHON_CONTAINER_NAME; - - private Path testRoot; - - @Override - protected String getImageName() { - return IMAGE_NAME; - } - - @Override - protected ConnectorSpecification getSpec() throws IOException { - return runExecutable(Command.GET_SPEC, ConnectorSpecification.class); - } - - @Override - protected JsonNode getConfig() throws IOException { - return runExecutable(Command.GET_CONFIG); - } - - @Override - protected ConfiguredAirbyteCatalog getConfiguredCatalog() throws IOException { - return runExecutable(Command.GET_CONFIGURED_CATALOG, ConfiguredAirbyteCatalog.class); - } - - @Override - protected JsonNode getState() throws IOException { - return runExecutable(Command.GET_STATE); - } - - @Override - protected void assertFullRefreshMessages(final List allMessages) throws IOException { - final List regexTests = Streams.stream(runExecutable(Command.GET_REGEX_TESTS).withArray("tests").elements()) - .map(JsonNode::textValue).toList(); - final List stringMessages = allMessages.stream().map(Jsons::serialize).toList(); - LOGGER.info("Running " + regexTests.size() + " regex tests..."); - regexTests.forEach(regex -> { - LOGGER.info("Looking for [" + regex + "]"); - assertTrue(stringMessages.stream().anyMatch(line -> line.matches(regex)), "Failed to find regex: " + regex); - }); - } - - @Override - protected void setupEnvironment(final TestDestinationEnv environment) throws Exception { - testRoot = Files.createTempDirectory(Files.createDirectories(Path.of("/tmp/standard_test")), "pytest"); - runExecutableVoid(Command.SETUP); - } - - @Override - protected void tearDown(final TestDestinationEnv testEnv) throws Exception { - runExecutableVoid(Command.TEARDOWN); - } - - private enum Command { - GET_SPEC, - GET_CONFIG, - GET_CONFIGURED_CATALOG, - GET_STATE, - GET_REGEX_TESTS, - SETUP, - TEARDOWN - } - - private T runExecutable(final Command cmd, final Class klass) throws IOException { - return Jsons.object(runExecutable(cmd), klass); - } - - private JsonNode runExecutable(final Command cmd) throws IOException { - return Jsons.deserialize(IOs.readFile(runExecutableInternal(cmd), OUTPUT_FILENAME)); - } - - private void runExecutableVoid(final Command cmd) throws IOException { - runExecutableInternal(cmd); - } - - private Path runExecutableInternal(final Command cmd) throws IOException { - LOGGER.info("testRoot = " + testRoot); - final List dockerCmd = - Lists.newArrayList( - "docker", - "run", - "--rm", - "-i", - "-v", - String.format("%s:%s", testRoot, "/test_root"), - "-w", - testRoot.toString(), - "--network", - "host", - PYTHON_CONTAINER_NAME, - cmd.toString().toLowerCase(), - "--out", - "/test_root"); - - final Process process = new ProcessBuilder(dockerCmd).start(); - LineGobbler.gobble(process.getErrorStream(), LOGGER::error); - LineGobbler.gobble(process.getInputStream(), LOGGER::info); - - TestHarnessUtils.gentleClose(process, 1, TimeUnit.MINUTES); - - final int exitCode = process.exitValue(); - if (exitCode != 0) { - throw new RuntimeException("python execution failed"); - } - - return testRoot; - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/SourceAcceptanceTest.java 
b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/SourceAcceptanceTest.java deleted file mode 100644 index 9e77e0037d35..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/SourceAcceptanceTest.java +++ /dev/null @@ -1,393 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.standardtest.source; - -import static io.airbyte.protocol.models.v0.SyncMode.FULL_REFRESH; -import static io.airbyte.protocol.models.v0.SyncMode.INCREMENTAL; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.node.ObjectNode; -import com.google.common.collect.Iterables; -import com.google.common.collect.Sets; -import io.airbyte.commons.json.Jsons; -import io.airbyte.configoss.StandardCheckConnectionOutput.Status; -import io.airbyte.protocol.models.v0.AirbyteCatalog; -import io.airbyte.protocol.models.v0.AirbyteMessage; -import io.airbyte.protocol.models.v0.AirbyteMessage.Type; -import io.airbyte.protocol.models.v0.AirbyteRecordMessage; -import io.airbyte.protocol.models.v0.AirbyteStateMessage; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream; -import io.airbyte.protocol.models.v0.ConnectorSpecification; -import io.airbyte.protocol.models.v0.DestinationSyncMode; -import io.airbyte.protocol.models.v0.SyncMode; -import java.util.Collection; -import java.util.HashMap; -import java.util.List; -import java.util.Objects; -import java.util.Set; -import java.util.UUID; -import java.util.stream.Collectors; -import org.junit.jupiter.api.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public abstract class SourceAcceptanceTest extends AbstractSourceConnectorTest { - - public static final String CDC_LSN = "_ab_cdc_lsn"; - public static final String CDC_UPDATED_AT = "_ab_cdc_updated_at"; - public static final String CDC_DELETED_AT = "_ab_cdc_deleted_at"; - public static final String CDC_LOG_FILE = "_ab_cdc_log_file"; - public static final String CDC_LOG_POS = "_ab_cdc_log_pos"; - public static final String CDC_DEFAULT_CURSOR = "_ab_cdc_cursor"; - public static final String CDC_EVENT_SERIAL_NO = "_ab_cdc_event_serial_no"; - - private static final Logger LOGGER = LoggerFactory.getLogger(SourceAcceptanceTest.class); - - /** - * TODO hack: Various Singer integrations use cursor fields inclusively i.e: they output records - * whose cursor field >= the provided cursor value. This leads to the last record in a sync to - * always be the first record in the next sync. This is a fine assumption from a product POV since - * we offer at-least-once delivery. But for simplicity, the incremental test suite currently assumes - * that the second incremental read should output no records when provided the state from the first - * sync. This works for many integrations but not some Singer ones, so we hardcode the list of - * integrations to skip over when performing those tests. 
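A concrete illustration of that inclusivity problem, with invented values:

// Illustration only. Suppose the first sync reads rows with cursor values
// 1, 2, 3 and emits state { "cursor": 3 }. A Singer-style source then filters
// the second sync with `cursor_field >= 3`, so row 3 is emitted again even
// though nothing changed. The generic second-read assertion in this suite
// expects zero records, which is why the images below are skipped.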
- */ - private final Set IMAGES_TO_SKIP_SECOND_INCREMENTAL_READ = Sets.newHashSet( - "airbyte/source-intercom-singer", - "airbyte/source-exchangeratesapi-singer", - "airbyte/source-hubspot", - "airbyte/source-iterable", - "airbyte/source-marketo-singer", - "airbyte/source-twilio-singer", - "airbyte/source-mixpanel-singer", - "airbyte/source-braintree-singer", - "airbyte/source-stripe-singer", - "airbyte/source-exchange-rates", - "airbyte/source-stripe", - "airbyte/source-github-singer", - "airbyte/source-gitlab-singer", - "airbyte/source-google-workspace-admin-reports", - "airbyte/source-zendesk-talk", - "airbyte/source-zendesk-support-singer", - "airbyte/source-quickbooks-singer", - "airbyte/source-jira"); - - /** - * FIXME: Some sources can't guarantee that there will be no events between two sequential syncs - */ - private final Set IMAGES_TO_SKIP_IDENTICAL_FULL_REFRESHES = Sets.newHashSet( - "airbyte/source-google-workspace-admin-reports", "airbyte/source-kafka"); - - /** - * Specification for integration. Will be passed to integration where appropriate in each test. - * Should be valid. - * - * @return the connector specification - */ - protected abstract ConnectorSpecification getSpec() throws Exception; - - /** - * The catalog to use to validate the output of read operations. This will be used as follows: - *

- * Full Refresh syncs will be tested on all the input streams which support it. Incremental syncs: - * if the stream declares a source-defined cursor, it will be tested with an incremental sync using - * the default cursor; if the stream requires a user-defined cursor, it will be tested with the - * input cursor. In both cases, the input {@link #getState()} will be used as the input state. - * - * @return - * @throws Exception - */ - protected abstract ConfiguredAirbyteCatalog getConfiguredCatalog() throws Exception; - - /** - * @return a JSON file representing the state file to use when testing incremental syncs - */ - protected abstract JsonNode getState() throws Exception; - - /** - * Verify that a spec operation issued to the connector returns a valid spec. - */ - @Test - public void testGetSpec() throws Exception { - assertEquals(getSpec(), runSpec(), "Expected spec output by integration to be equal to spec provided by test runner"); - } - - /** - * Verify that a check operation issued to the connector with the input config file returns a - * success response. - */ - @Test - public void testCheckConnection() throws Exception { - assertEquals(Status.SUCCEEDED, runCheck().getStatus(), "Expected check connection operation to succeed"); - } - - // /** - // * Verify that when given invalid credentials, that check connection returns a failed response. - // * Assume that the {@link TestSource#getFailCheckConfig()} is invalid. - // */ - // @Test - // public void testCheckConnectionInvalidCredentials() throws Exception { - // final OutputAndStatus output = runCheck(); - // assertTrue(output.getOutput().isPresent()); - // assertEquals(Status.FAILED, output.getOutput().get().getStatus()); - // } - - /** - * Verifies that when a discover operation is run on the connector using the given config file, a valid - * catalog is output by the connector. - */ - @Test - public void testDiscover() throws Exception { - final UUID discoverOutput = runDiscover(); - final AirbyteCatalog discoveredCatalog = getLastPersistedCatalog(); - assertNotNull(discoveredCatalog, "Expected discover to produce a catalog"); - verifyCatalog(discoveredCatalog); - } - - /** - * Override this method to check the actual catalog. - */ - protected void verifyCatalog(final AirbyteCatalog catalog) throws Exception { - // do nothing by default - } - - /** - * Configures all streams in the input catalog to full refresh mode and verifies that a read operation - * produces some RECORD messages. - */ - @Test - public void testFullRefreshRead() throws Exception { - if (!sourceSupportsFullRefresh()) { - LOGGER.info("Test skipped. Source does not support full refresh."); - return; - } - - final ConfiguredAirbyteCatalog catalog = withFullRefreshSyncModes(getConfiguredCatalog()); - final List allMessages = runRead(catalog); - - assertFalse(filterRecords(allMessages).isEmpty(), "Expected a full refresh sync to produce records"); - assertFullRefreshMessages(allMessages); - } - - /** - * Override this method to perform more specific assertions on the messages. - */ - protected void assertFullRefreshMessages(final List allMessages) throws Exception { - // do nothing by default - } - - /** - * Configures all streams in the input catalog to full refresh mode and performs two read operations - * on all streams which support full refresh syncs. It then verifies that the RECORD messages output - * from both were identical. 
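The assertSameRecords helper used just below is defined later in the file, outside this hunk. For orientation, one plausible shape for such a comparison, ignoring record ordering and emission time, is sketched here; this is an assumption about the idea, not the actual helper:

// A sketch only (not necessarily the real assertSameRecords): compare the two
// runs' records after clearing the emission timestamp, in both directions so
// that ordering differences do not fail the test.
private static void assertSameRecordsIgnoringOrder(final List<AirbyteRecordMessage> first,
                                                   final List<AirbyteRecordMessage> second,
                                                   final String message) {
  final List<AirbyteRecordMessage> left = first.stream()
      .map(r -> Jsons.clone(r).withEmittedAt(null))
      .collect(Collectors.toList());
  final List<AirbyteRecordMessage> right = second.stream()
      .map(r -> Jsons.clone(r).withEmittedAt(null))
      .collect(Collectors.toList());
  assertTrue(left.size() == right.size()
      && left.containsAll(right)
      && right.containsAll(left), message);
}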
- */
- @Test
- public void testIdenticalFullRefreshes() throws Exception {
- if (!sourceSupportsFullRefresh()) {
- LOGGER.info("Test skipped. Source does not support full refresh.");
- return;
- }
-
- if (IMAGES_TO_SKIP_IDENTICAL_FULL_REFRESHES.contains(getImageName().split(":")[0])) {
- return;
- }
-
- final ConfiguredAirbyteCatalog configuredCatalog = withFullRefreshSyncModes(getConfiguredCatalog());
- final List<AirbyteRecordMessage> recordMessagesFirstRun = filterRecords(runRead(configuredCatalog));
- final List<AirbyteRecordMessage> recordMessagesSecondRun = filterRecords(runRead(configuredCatalog));
- // the worker validates the messages as they pass through, so we do not need to validate them
- // again here (as long as we use the worker, which we will not want to do long term).
- assertFalse(recordMessagesFirstRun.isEmpty(), "Expected first full refresh to produce records");
- assertFalse(recordMessagesSecondRun.isEmpty(), "Expected second full refresh to produce records");
-
- assertSameRecords(recordMessagesFirstRun, recordMessagesSecondRun, "Expected two full refresh syncs to produce the same records");
- }
-
- /**
- * This test verifies that all streams in the input catalog which support incremental sync can do so
- * correctly. It does this by running two read operations on the connector's Docker image: the first
- * takes the configured catalog and config provided to this test as input. It then verifies that the
- * sync produced a non-zero number of RECORD and STATE messages.
- *
- * The second read takes the same catalog and config used in the first test, plus the last STATE - * message output by the first read operation as the input state file. It verifies that no records - * are produced (since we read all records in the first sync). - *
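- * The first read uses {@link #getState()} as its input state. A connector might supply a saved
- * cursor such as the following (the shape is connector-specific and purely illustrative):
- *
- * <pre>{@code
- * protected JsonNode getState() {
- *   return Jsons.deserialize("{\"users\": {\"updated_at\": \"2021-01-01T00:00:00Z\"}}");
- * }
- * }</pre>
- *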
- * This test is performed only for streams which support incremental. Streams which do not support - * incremental sync are ignored. If no streams in the input catalog support incremental sync, this - * test is skipped. - */ - @Test - public void testIncrementalSyncWithState() throws Exception { - if (!sourceSupportsIncremental()) { - return; - } - - final ConfiguredAirbyteCatalog configuredCatalog = withSourceDefinedCursors(getConfiguredCatalog()); - // only sync incremental streams - configuredCatalog.setStreams( - configuredCatalog.getStreams().stream().filter(s -> s.getSyncMode() == INCREMENTAL).collect(Collectors.toList())); - - final List airbyteMessages = runRead(configuredCatalog, getState()); - final List recordMessages = filterRecords(airbyteMessages); - final List stateMessages = airbyteMessages - .stream() - .filter(m -> m.getType() == Type.STATE) - .map(AirbyteMessage::getState) - .collect(Collectors.toList()); - assertFalse(recordMessages.isEmpty(), "Expected the first incremental sync to produce records"); - assertFalse(stateMessages.isEmpty(), "Expected incremental sync to produce STATE messages"); - // TODO validate exact records - - if (IMAGES_TO_SKIP_SECOND_INCREMENTAL_READ.contains(getImageName().split(":")[0])) { - return; - } - - // when we run incremental sync again there should be no new records. Run a sync with the latest - // state message and assert no records were emitted. - JsonNode latestState = null; - for (final AirbyteStateMessage stateMessage : stateMessages) { - if (stateMessage.getType().equals(AirbyteStateMessage.AirbyteStateType.STREAM)) { - latestState = Jsons.jsonNode(stateMessages); - break; - } else if (stateMessage.getType().equals(AirbyteStateMessage.AirbyteStateType.GLOBAL)) { - latestState = Jsons.jsonNode(List.of(Iterables.getLast(stateMessages))); - break; - } else { - throw new RuntimeException("Unknown state type " + stateMessage.getType()); - } - } - - assert Objects.nonNull(latestState); - final List secondSyncRecords = filterRecords(runRead(configuredCatalog, latestState)); - assertTrue( - secondSyncRecords.isEmpty(), - "Expected the second incremental sync to produce no records when given the first sync's output state."); - } - - /** - * If the source does not support incremental sync, this test is skipped. - *
- * Otherwise, this test runs two syncs: one where all streams provided in the input catalog sync in
- * full refresh mode, and another where all the streams in the input catalog which support
- * incremental sync in incremental mode (streams which don't support incremental sync in full
- * refresh mode). Then, the test asserts that the two syncs produced the same RECORD messages. Any
- * other type of message is disregarded.
- */
- @Test
- public void testEmptyStateIncrementalIdenticalToFullRefresh() throws Exception {
- if (!sourceSupportsIncremental()) {
- return;
- }
-
- if (!sourceSupportsFullRefresh()) {
- LOGGER.info("Test skipped. Source does not support full refresh.");
- return;
- }
-
- final ConfiguredAirbyteCatalog configuredCatalog = getConfiguredCatalog();
- final ConfiguredAirbyteCatalog fullRefreshCatalog = withFullRefreshSyncModes(configuredCatalog);
-
- final List<AirbyteRecordMessage> fullRefreshRecords = filterRecords(runRead(fullRefreshCatalog));
- final List<AirbyteRecordMessage> emptyStateRecords = filterRecords(runRead(configuredCatalog, Jsons.jsonNode(new HashMap<>())));
- assertFalse(fullRefreshRecords.isEmpty(), "Expected a full refresh sync to produce records");
- assertFalse(emptyStateRecords.isEmpty(), "Expected an incremental sync with no input state to produce records");
- assertSameRecords(fullRefreshRecords, emptyStateRecords,
- "Expected a full refresh sync and incremental sync with no input state to produce identical records");
- }
-
- /**
- * In order to launch a source on Kubernetes in a pod, we need to be able to wrap the entrypoint.
- * The source connector must specify its entrypoint in the AIRBYTE_ENTRYPOINT variable. This test
- * ensures that the entrypoint environment variable is set.
- */
- @Test
- public void testEntrypointEnvVar() throws Exception {
- checkEntrypointEnvVariable();
- }
-
- protected static List<AirbyteRecordMessage> filterRecords(final Collection<AirbyteMessage> messages) {
- return messages.stream()
- .filter(m -> m.getType() == Type.RECORD)
- .map(AirbyteMessage::getRecord)
- .collect(Collectors.toList());
- }
-
- protected ConfiguredAirbyteCatalog withSourceDefinedCursors(final ConfiguredAirbyteCatalog catalog) {
- final ConfiguredAirbyteCatalog clone = Jsons.clone(catalog);
- for (final ConfiguredAirbyteStream configuredStream : clone.getStreams()) {
- if (configuredStream.getSyncMode() == INCREMENTAL
- && configuredStream.getStream().getSourceDefinedCursor() != null
- && configuredStream.getStream().getSourceDefinedCursor()) {
- configuredStream.setCursorField(configuredStream.getStream().getDefaultCursorField());
- }
- }
- return clone;
- }
-
- protected ConfiguredAirbyteCatalog withFullRefreshSyncModes(final ConfiguredAirbyteCatalog catalog) {
- final ConfiguredAirbyteCatalog clone = Jsons.clone(catalog);
- for (final ConfiguredAirbyteStream configuredStream : clone.getStreams()) {
- if (configuredStream.getStream().getSupportedSyncModes().contains(FULL_REFRESH)) {
- configuredStream.setSyncMode(FULL_REFRESH);
- configuredStream.setDestinationSyncMode(DestinationSyncMode.OVERWRITE);
- }
- }
- return clone;
- }
-
- private boolean sourceSupportsIncremental() throws Exception {
- return sourceSupports(INCREMENTAL);
- }
-
- private boolean sourceSupportsFullRefresh() throws Exception {
- return sourceSupports(FULL_REFRESH);
- }
-
- private boolean sourceSupports(final SyncMode syncMode) throws Exception {
- final ConfiguredAirbyteCatalog catalog = getConfiguredCatalog();
- for (final ConfiguredAirbyteStream stream : catalog.getStreams()) {
- if (stream.getStream().getSupportedSyncModes().contains(syncMode)) {
- return true;
- } - } - return false; - } - - private void assertSameRecords(final List expected, final List actual, final String message) { - final List prunedExpected = expected.stream().map(this::pruneEmittedAt).collect(Collectors.toList()); - final List prunedActual = actual - .stream() - .map(this::pruneEmittedAt) - .map(this::pruneCdcMetadata) - .collect(Collectors.toList()); - assertEquals(prunedExpected.size(), prunedActual.size(), message); - assertTrue(prunedExpected.containsAll(prunedActual), message); - assertTrue(prunedActual.containsAll(prunedExpected), message); - } - - private AirbyteRecordMessage pruneEmittedAt(final AirbyteRecordMessage m) { - return Jsons.clone(m).withEmittedAt(null); - } - - private AirbyteRecordMessage pruneCdcMetadata(final AirbyteRecordMessage m) { - final AirbyteRecordMessage clone = Jsons.clone(m); - ((ObjectNode) clone.getData()).remove(CDC_LSN); - ((ObjectNode) clone.getData()).remove(CDC_LOG_FILE); - ((ObjectNode) clone.getData()).remove(CDC_LOG_POS); - ((ObjectNode) clone.getData()).remove(CDC_UPDATED_AT); - ((ObjectNode) clone.getData()).remove(CDC_DELETED_AT); - ((ObjectNode) clone.getData()).remove(CDC_EVENT_SERIAL_NO); - ((ObjectNode) clone.getData()).remove(CDC_DEFAULT_CURSOR); - return clone; - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/TestDataHolder.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/TestDataHolder.java deleted file mode 100644 index 8c8e0b103b30..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/TestDataHolder.java +++ /dev/null @@ -1,244 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.standardtest.source; - -import io.airbyte.protocol.models.JsonSchemaType; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -public class TestDataHolder { - - private static final String DEFAULT_CREATE_TABLE_SQL = "CREATE TABLE %1$s(%2$s INTEGER PRIMARY KEY, %3$s %4$s)"; - private static final String DEFAULT_INSERT_SQL = "INSERT INTO %1$s VALUES (%2$s, %3$s)"; - - private final String sourceType; - private final JsonSchemaType airbyteType; - private final List values; - private final List expectedValues; - private final String createTablePatternSql; - private final String insertPatternSql; - private final String fullSourceDataType; - private String nameSpace; - private long testNumber; - private String idColumnName; - private String testColumnName; - - private StackTraceElement[] declarationLocation; - - TestDataHolder(final String sourceType, - final JsonSchemaType airbyteType, - final List values, - final List expectedValues, - final String createTablePatternSql, - final String insertPatternSql, - final String fullSourceDataType) { - this.sourceType = sourceType; - this.airbyteType = airbyteType; - this.values = values; - this.expectedValues = expectedValues; - this.createTablePatternSql = createTablePatternSql; - this.insertPatternSql = insertPatternSql; - this.fullSourceDataType = fullSourceDataType; - } - - /** - * The builder allows to setup any comprehensive data type test. 
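- * For example, a test for a varchar column might be registered like this (the type, values and
- * expectations are purely illustrative):
- *
- * <pre>{@code
- * TestDataHolder.builder()
- *     .sourceType("varchar")
- *     .fullSourceDataType("varchar(50)")
- *     .airbyteType(JsonSchemaType.STRING)
- *     .addInsertValues("'foo'", "NULL")
- *     .addExpectedValues("foo")
- *     .addNullExpectedValue()
- *     .build();
- * }</pre>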
- *
- * @return builder used to set up the test
- */
- public static TestDataHolderBuilder builder() {
- return new TestDataHolderBuilder();
- }
-
- public static class TestDataHolderBuilder {
-
- private String sourceType;
- private JsonSchemaType airbyteType;
- private final List<String> values = new ArrayList<>();
- private final List<String> expectedValues = new ArrayList<>();
- private String createTablePatternSql;
- private String insertPatternSql;
- private String fullSourceDataType;
-
- TestDataHolderBuilder() {
- this.createTablePatternSql = DEFAULT_CREATE_TABLE_SQL;
- this.insertPatternSql = DEFAULT_INSERT_SQL;
- }
-
- /**
- * The name of the source data type. Duplicates by name will be tested independently from each
- * other. Note that this name will be used for connector setup and table creation. If the source
- * syntax requires more details (e.g. the "varchar" type requires a length, "varchar(50)"), you can
- * additionally set a custom data type syntax via the
- * {@link TestDataHolderBuilder#fullSourceDataType(String)} method.
- *
- * @param sourceType source data type name
- * @return builder
- */
- public TestDataHolderBuilder sourceType(final String sourceType) {
- this.sourceType = sourceType;
- if (fullSourceDataType == null)
- fullSourceDataType = sourceType;
- return this;
- }
-
- /**
- * The corresponding Airbyte data type. It is required for the proper configuration of the
- * {@link ConfiguredAirbyteStream}.
- *
- * @param airbyteType Airbyte data type
- * @return builder
- */
- public TestDataHolderBuilder airbyteType(final JsonSchemaType airbyteType) {
- this.airbyteType = airbyteType;
- return this;
- }
-
- /**
- * Sets a custom create table script pattern. Use it if your source uses an atypical table creation
- * SQL. The default pattern is described in {@link #DEFAULT_CREATE_TABLE_SQL}. Note! The pattern
- * should contain four String placeholders for: the namespace.table name (as one placeholder
- * together), the id column name, the test column name, and the test column data type.
- *
- * @param createTablePatternSql creation table sql pattern
- * @return builder
- */
- public TestDataHolderBuilder createTablePatternSql(final String createTablePatternSql) {
- this.createTablePatternSql = createTablePatternSql;
- return this;
- }
-
- /**
- * Sets a custom insert record script pattern. Use it if your source uses an atypical insert record
- * SQL. The default pattern is described in {@link #DEFAULT_INSERT_SQL}. Note! The pattern should
- * contain two String placeholders for the table name and value.
- *
- * @param insertPatternSql creation table sql pattern
- * @return builder
- */
- public TestDataHolderBuilder insertPatternSql(final String insertPatternSql) {
- this.insertPatternSql = insertPatternSql;
- return this;
- }
-
- /**
- * Allows setting an extended data type for the table creation. E.g. the "varchar" type in MySQL
- * requires a length; in this case fullSourceDataType would be "varchar(50)".
- *
- * @param fullSourceDataType actual string for the column data type description
- * @return builder
- */
- public TestDataHolderBuilder fullSourceDataType(final String fullSourceDataType) {
- this.fullSourceDataType = fullSourceDataType;
- return this;
- }
-
- /**
- * Adds value(s) to the scope of a corresponding test. The values will be inserted into the created
- * table. Note! The value will be inserted into the insert script without any transformations. Make
- * sure that the value is in line with the source syntax.
- * - * @param insertValue test value - * @return builder - */ - public TestDataHolderBuilder addInsertValues(final String... insertValue) { - this.values.addAll(Arrays.asList(insertValue)); - return this; - } - - /** - * Adds expected value(s) to the test scope. If you add at least one value, it will check that all - * values are provided by corresponding streamer. - * - * @param expectedValue value which should be provided by a streamer - * @return builder - */ - public TestDataHolderBuilder addExpectedValues(final String... expectedValue) { - this.expectedValues.addAll(Arrays.asList(expectedValue)); - return this; - } - - /** - * Add NULL value to the expected value list. If you need to add only one value and it's NULL, you - * have to use this method instead of {@link #addExpectedValues(String...)} - * - * @return builder - */ - public TestDataHolderBuilder addNullExpectedValue() { - this.expectedValues.add(null); - return this; - } - - public TestDataHolder build() { - return new TestDataHolder(sourceType, airbyteType, values, expectedValues, createTablePatternSql, insertPatternSql, fullSourceDataType); - } - - } - - void setNameSpace(final String nameSpace) { - this.nameSpace = nameSpace; - } - - void setTestNumber(final long testNumber) { - this.testNumber = testNumber; - } - - void setIdColumnName(final String idColumnName) { - this.idColumnName = idColumnName; - } - - void setTestColumnName(final String testColumnName) { - this.testColumnName = testColumnName; - } - - public String getSourceType() { - return sourceType; - } - - public JsonSchemaType getAirbyteType() { - return airbyteType; - } - - public List getExpectedValues() { - return expectedValues; - } - - public List getValues() { - return values; - } - - public String getNameSpace() { - return nameSpace; - } - - public String getNameWithTestPrefix() { - // source type may include space (e.g. "character varying") - return nameSpace + "_" + testNumber + "_" + sourceType.replaceAll("\\s", "_"); - } - - public String getCreateSqlQuery() { - return String.format(createTablePatternSql, (nameSpace != null ? nameSpace + "." : "") + getNameWithTestPrefix(), idColumnName, testColumnName, - fullSourceDataType); - } - - void setDeclarationLocation(StackTraceElement[] declarationLocation) { - this.declarationLocation = declarationLocation; - } - - public String getDeclarationLocation() { - return Arrays.asList(declarationLocation).subList(2, 3).toString(); - } - - public List getInsertSqlQueries() { - final List insertSqls = new ArrayList<>(); - int rowId = 1; - for (final String value : values) { - insertSqls.add(String.format(insertPatternSql, (nameSpace != null ? nameSpace + "." : "") + getNameWithTestPrefix(), rowId++, value)); - } - return insertSqls; - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/TestDestinationEnv.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/TestDestinationEnv.java deleted file mode 100644 index 451cb4864b8c..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/TestDestinationEnv.java +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
- */ - -package io.airbyte.cdk.integrations.standardtest.source; - -import java.nio.file.Path; - -public class TestDestinationEnv { - - private final Path localRoot; - - public TestDestinationEnv(final Path localRoot) { - this.localRoot = localRoot; - } - - public Path getLocalRoot() { - return localRoot; - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/TestEnvConfigs.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/TestEnvConfigs.java deleted file mode 100644 index 88992d8da6c4..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/TestEnvConfigs.java +++ /dev/null @@ -1,135 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.standardtest.source; - -import com.google.common.base.Preconditions; -import io.airbyte.commons.lang.Exceptions; -import io.airbyte.commons.map.MoreMaps; -import io.airbyte.commons.version.AirbyteVersion; -import java.util.Map; -import java.util.Map.Entry; -import java.util.Objects; -import java.util.Set; -import java.util.function.Function; -import java.util.function.Supplier; -import java.util.stream.Collectors; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * This class passes environment variable to the DockerProcessFactory that runs the source in the - * SourceAcceptanceTest. - */ -// todo (cgardens) - this cloud_deployment implicit interface is going to bite us. -public class TestEnvConfigs { - - private static final Logger LOGGER = LoggerFactory.getLogger(TestEnvConfigs.class); - - // env variable names - public static final String AIRBYTE_ROLE = "AIRBYTE_ROLE"; - public static final String AIRBYTE_VERSION = "AIRBYTE_VERSION"; - public static final String WORKER_ENVIRONMENT = "WORKER_ENVIRONMENT"; - public static final String DEPLOYMENT_MODE = "DEPLOYMENT_MODE"; - public static final String JOB_DEFAULT_ENV_PREFIX = "JOB_DEFAULT_ENV_"; - - public static final Map> JOB_SHARED_ENVS = Map.of( - AIRBYTE_VERSION, (instance) -> instance.getAirbyteVersion().serialize(), - AIRBYTE_ROLE, TestEnvConfigs::getAirbyteRole, - DEPLOYMENT_MODE, (instance) -> instance.getDeploymentMode().name(), - WORKER_ENVIRONMENT, (instance) -> instance.getWorkerEnvironment().name()); - - enum DeploymentMode { - OSS, - CLOUD - } - - enum WorkerEnvironment { - DOCKER, - KUBERNETES - } - - private final Function getEnv; - private final Supplier> getAllEnvKeys; - - public TestEnvConfigs() { - this(System.getenv()); - } - - private TestEnvConfigs(final Map envMap) { - getEnv = envMap::get; - getAllEnvKeys = envMap::keySet; - } - - // CORE - // General - public String getAirbyteRole() { - return getEnv(AIRBYTE_ROLE); - } - - public AirbyteVersion getAirbyteVersion() { - return new AirbyteVersion(getEnsureEnv(AIRBYTE_VERSION)); - } - - public DeploymentMode getDeploymentMode() { - return getEnvOrDefault(DEPLOYMENT_MODE, DeploymentMode.OSS, s -> { - try { - return DeploymentMode.valueOf(s); - } catch (final IllegalArgumentException e) { - LOGGER.info(s + " not recognized, defaulting to " + DeploymentMode.OSS); - return DeploymentMode.OSS; - } - }); - } - - public WorkerEnvironment getWorkerEnvironment() { - return getEnvOrDefault(WORKER_ENVIRONMENT, WorkerEnvironment.DOCKER, s -> WorkerEnvironment.valueOf(s.toUpperCase())); - } - - /** - * There are two types of environment variables available 
to the job container:
- *
- * <ul>
- * <li>Exclusive variables prefixed with JOB_DEFAULT_ENV_PREFIX</li>
- * <li>Shared variables defined in JOB_SHARED_ENVS</li>
- * </ul>
- */ - public Map getJobDefaultEnvMap() { - final Map jobPrefixedEnvMap = getAllEnvKeys.get().stream() - .filter(key -> key.startsWith(JOB_DEFAULT_ENV_PREFIX)) - .collect(Collectors.toMap(key -> key.replace(JOB_DEFAULT_ENV_PREFIX, ""), getEnv)); - // This method assumes that these shared env variables are not critical to the execution - // of the jobs, and only serve as metadata. So any exception is swallowed and default to - // an empty string. Change this logic if this assumption no longer holds. - final Map jobSharedEnvMap = JOB_SHARED_ENVS.entrySet().stream().collect(Collectors.toMap( - Entry::getKey, - entry -> Exceptions.swallowWithDefault(() -> Objects.requireNonNullElse(entry.getValue().apply(this), ""), ""))); - return MoreMaps.merge(jobPrefixedEnvMap, jobSharedEnvMap); - } - - public T getEnvOrDefault(final String key, final T defaultValue, final Function parser) { - return getEnvOrDefault(key, defaultValue, parser, false); - } - - public T getEnvOrDefault(final String key, final T defaultValue, final Function parser, final boolean isSecret) { - final String value = getEnv.apply(key); - if (value != null && !value.isEmpty()) { - return parser.apply(value); - } else { - LOGGER.info("Using default value for environment variable {}: '{}'", key, isSecret ? "*****" : defaultValue); - return defaultValue; - } - } - - public String getEnv(final String name) { - return getEnv.apply(name); - } - - public String getEnsureEnv(final String name) { - final String value = getEnv(name); - Preconditions.checkArgument(value != null, "'%s' environment variable cannot be null", name); - - return value; - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/TestPythonSourceMain.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/TestPythonSourceMain.java deleted file mode 100644 index f00f0f2a7e19..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/TestPythonSourceMain.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.standardtest.source; - -import net.sourceforge.argparse4j.ArgumentParsers; -import net.sourceforge.argparse4j.inf.ArgumentParser; -import net.sourceforge.argparse4j.inf.ArgumentParserException; -import net.sourceforge.argparse4j.inf.Namespace; - -/** - * Parse command line arguments and inject them into the test class before running the test. Then - * runs the tests. 
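- *
- * For example, a hypothetical invocation might look like this (image names are placeholders):
- *
- * <pre>{@code
- * TestPythonSourceMain --imageName airbyte/source-foo:dev \
- *     --pythonContainerName airbyte/source-foo-python:dev
- * }</pre>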
- */ -public class TestPythonSourceMain { - - public static void main(final String[] args) { - final ArgumentParser parser = ArgumentParsers.newFor(TestPythonSourceMain.class.getName()).build() - .defaultHelp(true) - .description("Run standard source tests"); - - parser.addArgument("--imageName") - .help("Name of the integration image"); - - parser.addArgument("--pythonContainerName") - .help("Name of the python integration image"); - - Namespace ns = null; - try { - ns = parser.parseArgs(args); - } catch (final ArgumentParserException e) { - parser.handleError(e); - System.exit(1); - } - - final String imageName = ns.getString("imageName"); - final String pythonContainerName = ns.getString("pythonContainerName"); - - PythonSourceAcceptanceTest.IMAGE_NAME = imageName; - PythonSourceAcceptanceTest.PYTHON_CONTAINER_NAME = pythonContainerName; - - TestRunner.runTestClass(PythonSourceAcceptanceTest.class); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/TestRunner.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/TestRunner.java deleted file mode 100644 index 1f27307421fc..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/TestRunner.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.standardtest.source; - -import static org.junit.platform.engine.discovery.DiscoverySelectors.selectClass; - -import java.io.PrintWriter; -import java.nio.charset.StandardCharsets; -import org.junit.platform.launcher.Launcher; -import org.junit.platform.launcher.LauncherDiscoveryRequest; -import org.junit.platform.launcher.TestPlan; -import org.junit.platform.launcher.core.LauncherDiscoveryRequestBuilder; -import org.junit.platform.launcher.core.LauncherFactory; -import org.junit.platform.launcher.listeners.SummaryGeneratingListener; - -public class TestRunner { - - public static void runTestClass(final Class testClass) { - final LauncherDiscoveryRequest request = LauncherDiscoveryRequestBuilder.request() - .selectors(selectClass(testClass)) - .build(); - - final TestPlan plan = LauncherFactory.create().discover(request); - final Launcher launcher = LauncherFactory.create(); - - // Register a listener of your choice - final SummaryGeneratingListener listener = new SummaryGeneratingListener(); - - launcher.execute(plan, listener); - - listener.getSummary().printFailuresTo(new PrintWriter(System.out, false, StandardCharsets.UTF_8)); - listener.getSummary().printTo(new PrintWriter(System.out, false, StandardCharsets.UTF_8)); - - if (listener.getSummary().getTestsFailedCount() > 0) { - System.out.println( - "There are failing tests. 
See https://docs.airbyte.io/contributing-to-airbyte/building-new-connector/standard-source-tests " + - "for more information about the standard source test suite."); - System.exit(1); - } - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/fs/ExecutableTestSource.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/fs/ExecutableTestSource.java deleted file mode 100644 index 9df6e564d945..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/fs/ExecutableTestSource.java +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.standardtest.source.fs; - -import com.fasterxml.jackson.databind.JsonNode; -import io.airbyte.cdk.integrations.standardtest.source.SourceAcceptanceTest; -import io.airbyte.cdk.integrations.standardtest.source.TestDestinationEnv; -import io.airbyte.commons.io.IOs; -import io.airbyte.commons.json.Jsons; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog; -import io.airbyte.protocol.models.v0.ConnectorSpecification; -import java.nio.file.Path; -import javax.annotation.Nullable; - -/** - * Extends TestSource such that it can be called using resources pulled from the file system. Will - * also add the ability to execute arbitrary scripts in the next version. - */ -public class ExecutableTestSource extends SourceAcceptanceTest { - - public static class TestConfig { - - private final String imageName; - private final Path specPath; - private final Path configPath; - private final Path catalogPath; - - private final Path statePath; - - public TestConfig(final String imageName, final Path specPath, final Path configPath, final Path catalogPath, final Path statePath) { - this.imageName = imageName; - this.specPath = specPath; - this.configPath = configPath; - this.catalogPath = catalogPath; - this.statePath = statePath; - } - - public String getImageName() { - return imageName; - } - - public Path getSpecPath() { - return specPath; - } - - public Path getConfigPath() { - return configPath; - } - - public Path getCatalogPath() { - return catalogPath; - } - - @Nullable - public Path getStatePath() { - return statePath; - } - - } - - public static TestConfig TEST_CONFIG; - - @Override - protected ConnectorSpecification getSpec() { - return Jsons.deserialize(IOs.readFile(TEST_CONFIG.getSpecPath()), ConnectorSpecification.class); - } - - @Override - protected String getImageName() { - return TEST_CONFIG.getImageName(); - } - - @Override - protected JsonNode getConfig() { - return Jsons.deserialize(IOs.readFile(TEST_CONFIG.getConfigPath())); - } - - @Override - protected ConfiguredAirbyteCatalog getConfiguredCatalog() { - return Jsons.deserialize(IOs.readFile(TEST_CONFIG.getCatalogPath()), ConfiguredAirbyteCatalog.class); - } - - @Override - protected JsonNode getState() { - if (TEST_CONFIG.getStatePath() != null) { - return Jsons.deserialize(IOs.readFile(TEST_CONFIG.getStatePath())); - } else { - return Jsons.deserialize("{}"); - } - - } - - @Override - protected void setupEnvironment(final TestDestinationEnv environment) throws Exception { - // no-op, for now - } - - @Override - protected void tearDown(final TestDestinationEnv testEnv) throws Exception { - // no-op, for now - } - -} diff --git 
a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/fs/TestSourceMain.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/fs/TestSourceMain.java deleted file mode 100644 index 7eb5958b424e..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/fs/TestSourceMain.java +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.standardtest.source.fs; - -import io.airbyte.cdk.integrations.standardtest.source.TestRunner; -import java.nio.file.Path; -import net.sourceforge.argparse4j.ArgumentParsers; -import net.sourceforge.argparse4j.inf.ArgumentParser; -import net.sourceforge.argparse4j.inf.ArgumentParserException; -import net.sourceforge.argparse4j.inf.Namespace; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Parse command line arguments and inject them into the test class before running the test. Then - * runs the tests. - */ -public class TestSourceMain { - - private static final Logger LOGGER = LoggerFactory.getLogger(TestSourceMain.class); - - public static void main(final String[] args) { - final ArgumentParser parser = ArgumentParsers.newFor(TestSourceMain.class.getName()).build() - .defaultHelp(true) - .description("Run standard source tests"); - - parser.addArgument("--imageName") - .required(true) - .help("Name of the source connector image e.g: airbyte/source-mailchimp"); - - parser.addArgument("--spec") - .required(true) - .help("Path to file that contains spec json"); - - parser.addArgument("--config") - .required(true) - .help("Path to file that contains config json"); - - parser.addArgument("--catalog") - .required(true) - .help("Path to file that contains catalog json"); - - parser.addArgument("--state") - .required(false) - .help("Path to the file containing state"); - - Namespace ns = null; - try { - ns = parser.parseArgs(args); - } catch (final ArgumentParserException e) { - parser.handleError(e); - System.exit(1); - } - - final String imageName = ns.getString("imageName"); - final String specFile = ns.getString("spec"); - final String configFile = ns.getString("config"); - final String catalogFile = ns.getString("catalog"); - final String stateFile = ns.getString("state"); - - ExecutableTestSource.TEST_CONFIG = new ExecutableTestSource.TestConfig( - imageName, - Path.of(specFile), - Path.of(configFile), - Path.of(catalogFile), - stateFile != null ? Path.of(stateFile) : null); - - TestRunner.runTestClass(ExecutableTestSource.class); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/performancetest/AbstractSourceBasePerformanceTest.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/performancetest/AbstractSourceBasePerformanceTest.java deleted file mode 100644 index c8a4ddaa52f9..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/performancetest/AbstractSourceBasePerformanceTest.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
- */
-
-package io.airbyte.cdk.integrations.standardtest.source.performancetest;
-
-import io.airbyte.cdk.integrations.standardtest.source.AbstractSourceConnectorTest;
-import io.airbyte.cdk.integrations.standardtest.source.TestDestinationEnv;
-
-/**
- * This abstract class contains common methods for both suites: the Fill DB scripts and the
- * Performance tests.
- */
-public abstract class AbstractSourceBasePerformanceTest extends AbstractSourceConnectorTest {
-
- private static final String TEST_COLUMN_NAME = "test_column";
- private static final String TEST_STREAM_NAME_TEMPLATE = "test_%S";
-
- /**
- * The column name will be used for a test column in the test tables. Override it if the default
- * name is not valid for your source.
- *
- * @return Test column name
- */
- protected String getTestColumnName() {
- return TEST_COLUMN_NAME;
- }
-
- /**
- * The stream name template will be used for the test tables. Override it if the default name is
- * not valid for your source.
- *
- * @return Test stream name template
- */
- protected String getTestStreamNameTemplate() {
- return TEST_STREAM_NAME_TEMPLATE;
- }
-
- @Override
- protected void setupEnvironment(final TestDestinationEnv environment) throws Exception {
- // DO NOTHING. Mandatory to override. DB will be set up as part of each test.
- }
-
-}
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/performancetest/AbstractSourceFillDbWithTestData.java b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/performancetest/AbstractSourceFillDbWithTestData.java
deleted file mode 100644
index b8066a7aae8e..000000000000
--- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/performancetest/AbstractSourceFillDbWithTestData.java
+++ /dev/null
@@ -1,132 +0,0 @@
-/*
- * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.cdk.integrations.standardtest.source.performancetest;
-
-import io.airbyte.cdk.db.Database;
-import java.util.StringJoiner;
-import java.util.stream.Stream;
-import org.junit.jupiter.api.Disabled;
-import org.junit.jupiter.params.ParameterizedTest;
-import org.junit.jupiter.params.provider.Arguments;
-import org.junit.jupiter.params.provider.MethodSource;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-/**
- * This abstract class contains common methods for Fill DB scripts.
- */
-public abstract class AbstractSourceFillDbWithTestData extends AbstractSourceBasePerformanceTest {
-
- private static final String CREATE_DB_TABLE_TEMPLATE = "CREATE TABLE %s.%s(id INTEGER PRIMARY KEY, %s)";
- private static final String INSERT_INTO_DB_TABLE_QUERY_TEMPLATE = "INSERT INTO %s.%s (%s) VALUES %s";
- private static final String TEST_DB_FIELD_TYPE = "varchar(10)";
-
- protected static final Logger c = LoggerFactory.getLogger(AbstractSourceFillDbWithTestData.class);
- private static final String TEST_VALUE_TEMPLATE_POSTGRES = "\'Value id_placeholder\'";
-
- /**
- * Sets up the test database. All tables and data described in the registered tests will be put
- * there.
- *
- * @return configured test database
- * @throws Exception - might throw any exception during initialization.
- */
- protected abstract Database setupDatabase(String dbName) throws Exception;
-
- /**
- * This test adds test data to a new DB. 1. Set the DB creds in the static variables above. 2. Set
- * the desired number of streams, columns and records. 3.
Run the test.
- */
- @Disabled
- @ParameterizedTest
- @MethodSource("provideParameters")
- public void addTestData(final String dbName,
- final String schemaName,
- final int numberOfDummyRecords,
- final int numberOfBatches,
- final int numberOfColumns,
- final int numberOfStreams)
- throws Exception {

- final Database database = setupDatabase(dbName);
-
- database.query(ctx -> {
- for (int currentSteamNumber = 0; currentSteamNumber < numberOfStreams; currentSteamNumber++) {
-
- final String currentTableName = String.format(getTestStreamNameTemplate(), currentSteamNumber);
-
- ctx.fetch(prepareCreateTableQuery(schemaName, numberOfColumns, currentTableName));
- for (int i = 0; i < numberOfBatches; i++) {
- final String insertQueryTemplate = prepareInsertQueryTemplate(schemaName, i,
- numberOfColumns,
- numberOfDummyRecords);
- ctx.fetch(String.format(insertQueryTemplate, currentTableName));
- }
-
- c.info("Finished processing for stream " + currentSteamNumber);
- }
- return null;
- });
- }
-
- /**
- * This is a data provider for the fill DB script. Each arguments group will be run as a separate
- * test. Set the "testArgs" in the test class of your DB in the @BeforeTest method.
- *
- * 1st arg - a name of the DB that will be used in the jdbc connection string. 2nd arg - a
- * schemaName that will be used as a NameSpace in the Configured Airbyte Catalog. 3rd arg - a number
- * of expected records retrieved in each stream. 4th arg - a number of columns in each stream\table
- * that will be used for the Airbyte Catalog configuration. 5th arg - a number of streams to read in
- * the configured Airbyte Catalog. Each stream\table in the DB should be named like "test_0",
- * "test_1",..., test_n.
- *
- * Stream.of( Arguments.of("your_db_name", "your_schema_name", 100, 2, 240, 1000) );
- */
- protected abstract Stream<Arguments> provideParameters();
-
- protected String prepareCreateTableQuery(final String dbSchemaName,
- final int numberOfColumns,
- final String currentTableName) {
-
- final StringJoiner sj = new StringJoiner(",");
- for (int i = 0; i < numberOfColumns; i++) {
- sj.add(String.format(" %s%s %s", getTestColumnName(), i, TEST_DB_FIELD_TYPE));
- }
-
- return String.format(CREATE_DB_TABLE_TEMPLATE, dbSchemaName, currentTableName, sj.toString());
- }
-
- protected String prepareInsertQueryTemplate(final String dbSchemaName,
- final int batchNumber,
- final int numberOfColumns,
- final int recordsNumber) {
-
- final StringJoiner fieldsNames = new StringJoiner(",");
- fieldsNames.add("id");
-
- final StringJoiner baseInsertQuery = new StringJoiner(",");
- baseInsertQuery.add("id_placeholder");
-
- for (int i = 0; i < numberOfColumns; i++) {
- fieldsNames.add(getTestColumnName() + i);
- baseInsertQuery.add(TEST_VALUE_TEMPLATE_POSTGRES);
- }
-
- final StringJoiner insertGroupValuesJoiner = new StringJoiner(",");
-
- final int batchMessages = batchNumber * 100;
-
- for (int currentRecordNumber = batchMessages;
- currentRecordNumber < recordsNumber + batchMessages;
- currentRecordNumber++) {
- insertGroupValuesJoiner
- .add("(" + baseInsertQuery.toString()
- .replaceAll("id_placeholder", String.valueOf(currentRecordNumber)) + ")");
- }
-
- return String
- .format(INSERT_INTO_DB_TABLE_QUERY_TEMPLATE, dbSchemaName, "%s", fieldsNames.toString(),
- insertGroupValuesJoiner.toString());
- }
-
-}
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/performancetest/AbstractSourcePerformanceTest.java
b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/performancetest/AbstractSourcePerformanceTest.java deleted file mode 100644 index c4279364c5ad..000000000000 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/java/io/airbyte/cdk/integrations/standardtest/source/performancetest/AbstractSourcePerformanceTest.java +++ /dev/null @@ -1,170 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.integrations.standardtest.source.performancetest; - -import static org.junit.jupiter.api.Assertions.fail; - -import com.fasterxml.jackson.databind.JsonNode; -import com.google.common.collect.Lists; -import io.airbyte.cdk.integrations.standardtest.source.TestDestinationEnv; -import io.airbyte.protocol.models.Field; -import io.airbyte.protocol.models.JsonSchemaType; -import io.airbyte.protocol.models.v0.AirbyteStream; -import io.airbyte.protocol.models.v0.CatalogHelpers; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog; -import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream; -import io.airbyte.protocol.models.v0.DestinationSyncMode; -import io.airbyte.protocol.models.v0.SyncMode; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import java.util.stream.Collectors; -import java.util.stream.Stream; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * This abstract class contains common methods for Performance tests. - */ -public abstract class AbstractSourcePerformanceTest extends AbstractSourceBasePerformanceTest { - - protected static final Logger c = LoggerFactory.getLogger(AbstractSourcePerformanceTest.class); - private static final String ID_COLUMN_NAME = "id"; - protected JsonNode config; - - /** - * Setup the test database. All tables and data described in the registered tests will be put there. - * - * @throws Exception - might throw any exception during initialization. - */ - protected abstract void setupDatabase(String dbName) throws Exception; - - @Override - protected JsonNode getConfig() { - return config; - } - - @Override - protected void tearDown(final TestDestinationEnv testEnv) {} - - /** - * This is a data provider for performance tests, Each argument's group would be ran as a separate - * test. Set the "testArgs" in test class of your DB in @BeforeTest method. - * - * 1st arg - a name of DB that will be used in jdbc connection string. 2nd arg - a schemaName that - * will be ised as a NameSpace in Configured Airbyte Catalog. 3rd arg - a number of expected records - * retrieved in each stream. 4th arg - a number of columns in each stream\table that will be used - * for Airbyte Cataloq configuration 5th arg - a number of streams to read in configured airbyte - * Catalog. Each stream\table in DB should be names like "test_0", "test_1",..., test_n. 
- *
- * Example: Stream.of( Arguments.of("test1000tables240columns200recordsDb", "dbo", 200, 240, 1000),
- * Arguments.of("test5000tables240columns200recordsDb", "dbo", 200, 240, 1000),
- * Arguments.of("newregular25tables50000records", "dbo", 50052, 8, 25),
- * Arguments.of("newsmall1000tableswith10000rows", "dbo", 10011, 8, 1000) );
- */
- protected abstract Stream<Arguments> provideParameters();
-
- @ParameterizedTest
- @MethodSource("provideParameters")
- public void testPerformance(final String dbName,
- final String schemaName,
- final int numberOfDummyRecords,
- final int numberOfColumns,
- final int numberOfStreams)
- throws Exception {
-
- setupDatabase(dbName);
-
- final ConfiguredAirbyteCatalog catalog = getConfiguredCatalog(schemaName, numberOfStreams,
- numberOfColumns);
- final Map<String, Integer> mapOfExpectedRecordsCount = prepareMapWithExpectedRecords(
- numberOfStreams, numberOfDummyRecords);
- final Map<String, Integer> checkStatusMap = runReadVerifyNumberOfReceivedMsgs(catalog, null,
- mapOfExpectedRecordsCount);
- validateNumberOfReceivedMsgs(checkStatusMap);
-
- }
-
- /**
- * The column name will be used for a PK column in the test tables. Override it if the default name
- * is not valid for your source.
- *
- * @return Id column name
- */
- protected String getIdColumnName() {
- return ID_COLUMN_NAME;
- }
-
- protected void validateNumberOfReceivedMsgs(final Map<String, Integer> checkStatusMap) {
- // Iterate through all streams and collect those where not every expected message arrived.
- final Map<String, Integer> failedStreamsMap = checkStatusMap.entrySet().stream()
- .filter(el -> el.getValue() != 0).collect(Collectors.toMap(Entry::getKey, Entry::getValue));
-
- if (!failedStreamsMap.isEmpty()) {
- fail("Not all messages were delivered. " + failedStreamsMap.toString());
- }
- c.info("Finished all checks, no issues found for {} streams", checkStatusMap.size());
- }
-
- protected Map<String, Integer> prepareMapWithExpectedRecords(final int streamNumber,
- final int expectedRecordsNumberInEachStream) {
- final Map<String, Integer> resultMap = new HashMap<>(); // streamName -> expected records in stream
-
- for (int currentStream = 0; currentStream < streamNumber; currentStream++) {
- final String streamName = String.format(getTestStreamNameTemplate(), currentStream);
- resultMap.put(streamName, expectedRecordsNumberInEachStream);
- }
- return resultMap;
- }
-
- /**
- * Configures the streams for the generated performance test catalog.
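- * For example, with two data columns each generated stream ends up with fields like the
- * following (the names come from {@link #getIdColumnName()} and {@link #getTestColumnName()}):
- *
- * <pre>{@code
- * id (NUMBER), test_column0 (STRING), test_column1 (STRING)
- * }</pre>
- *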
- * - * @return configured catalog - */ - protected ConfiguredAirbyteCatalog getConfiguredCatalog(final String nameSpace, - final int numberOfStreams, - final int numberOfColumns) { - final List streams = new ArrayList<>(); - - for (int currentStream = 0; currentStream < numberOfStreams; currentStream++) { - - // CREATE TABLE test.test_1_int(id INTEGER PRIMARY KEY, test_column int) - final List fields = new ArrayList<>(); - - fields.add(Field.of(getIdColumnName(), JsonSchemaType.NUMBER)); - for (int currentColumnNumber = 0; - currentColumnNumber < numberOfColumns; - currentColumnNumber++) { - fields.add(Field.of(getTestColumnName() + currentColumnNumber, JsonSchemaType.STRING)); - } - - final AirbyteStream airbyteStream = CatalogHelpers - .createAirbyteStream(String.format(getTestStreamNameTemplate(), currentStream), - nameSpace, fields) - .withSourceDefinedCursor(true) - .withSourceDefinedPrimaryKey(List.of(List.of(getIdColumnName()))) - .withSupportedSyncModes( - Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL)); - - final ConfiguredAirbyteStream configuredAirbyteStream = new ConfiguredAirbyteStream() - .withSyncMode(SyncMode.INCREMENTAL) - .withCursorField(Lists.newArrayList(getIdColumnName())) - .withDestinationSyncMode(DestinationSyncMode.APPEND) - .withStream(airbyteStream); - - streams.add(configuredAirbyteStream); - - } - - return new ConfiguredAirbyteCatalog().withStreams(streams); - } - -} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/debezium/CdcSourceTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/debezium/CdcSourceTest.kt new file mode 100644 index 000000000000..9afef4bfcd9f --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/debezium/CdcSourceTest.kt @@ -0,0 +1,1100 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+ */ +package io.airbyte.cdk.integrations.debezium + +import com.fasterxml.jackson.databind.JsonNode +import com.fasterxml.jackson.databind.node.ObjectNode +import com.google.common.collect.* +import io.airbyte.cdk.integrations.base.Source +import io.airbyte.cdk.testutils.TestDatabase +import io.airbyte.commons.json.Jsons +import io.airbyte.commons.util.AutoCloseableIterators +import io.airbyte.protocol.models.Field +import io.airbyte.protocol.models.JsonSchemaType +import io.airbyte.protocol.models.v0.* +import java.util.* +import java.util.function.Consumer +import java.util.stream.Collectors +import java.util.stream.Stream +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +abstract class CdcSourceTest> { + protected lateinit var testdb: T + + protected fun createTableSqlFmt(): String { + return "CREATE TABLE %s.%s(%s);" + } + + protected fun createSchemaSqlFmt(): String { + return "CREATE SCHEMA %s;" + } + + protected fun modelsSchema(): String { + return "models_schema" + } + + /** The schema of a random table which is used as a new table in snapshot test */ + protected fun randomSchema(): String { + return "models_schema_random" + } + + protected val catalog: AirbyteCatalog + get() = + AirbyteCatalog() + .withStreams( + java.util.List.of( + CatalogHelpers.createAirbyteStream( + MODELS_STREAM_NAME, + modelsSchema(), + Field.of(COL_ID, JsonSchemaType.INTEGER), + Field.of(COL_MAKE_ID, JsonSchemaType.INTEGER), + Field.of(COL_MODEL, JsonSchemaType.STRING) + ) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL) + ) + .withSourceDefinedPrimaryKey( + java.util.List.of(java.util.List.of(COL_ID)) + ) + ) + ) + + protected val configuredCatalog: ConfiguredAirbyteCatalog + get() { + val configuredCatalog = CatalogHelpers.toDefaultConfiguredCatalog(catalog) + configuredCatalog.streams.forEach( + Consumer { s: ConfiguredAirbyteStream -> s.syncMode = SyncMode.INCREMENTAL } + ) + return configuredCatalog + } + + protected abstract fun createTestDatabase(): T + + protected abstract fun source(): S + + protected abstract fun config(): JsonNode? + + protected abstract fun cdcLatestTargetPosition(): CdcTargetPosition<*> + + protected abstract fun extractPosition(record: JsonNode?): CdcTargetPosition<*>? + + protected abstract fun assertNullCdcMetaData(data: JsonNode?) + + protected abstract fun assertCdcMetaData(data: JsonNode?, deletedAtNull: Boolean) + + protected abstract fun removeCDCColumns(data: ObjectNode?) + + protected abstract fun addCdcMetadataColumns(stream: AirbyteStream?) + + protected abstract fun addCdcDefaultCursorField(stream: AirbyteStream?) + + protected abstract fun assertExpectedStateMessages(stateMessages: List?) + + // TODO: this assertion should be added into test cases in this class, we will need to implement + // corresponding iterator for other connectors before + // doing so. + protected fun assertExpectedStateMessageCountMatches( + stateMessages: List?, + totalCount: Long + ) { + // Do nothing. 
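+ // A subclass that implements this check would typically assert that stateMessages contains
+ // exactly totalCount entries (illustrative; the exact bookkeeping is connector-specific).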
+ }
+
+ @BeforeEach
+ protected fun setup() {
+ testdb = createTestDatabase()
+ createTables()
+ populateTables()
+ }
+
+ protected fun createTables() {
+ // create and populate actual table
+ val actualColumns =
+ ImmutableMap.of(COL_ID, "INTEGER", COL_MAKE_ID, "INTEGER", COL_MODEL, "VARCHAR(200)")
+ testdb
+ .with(createSchemaSqlFmt(), modelsSchema())
+ .with(
+ createTableSqlFmt(),
+ modelsSchema(),
+ MODELS_STREAM_NAME,
+ columnClause(actualColumns, Optional.of(COL_ID))
+ )
+
+ // Create random table.
+ // This table is not part of the Airbyte sync. It is created just to make sure that schemas
+ // not being synced by Airbyte do not cause issues with our debezium logic.
+ val randomColumns =
+ ImmutableMap.of(
+ COL_ID + "_random",
+ "INTEGER",
+ COL_MAKE_ID + "_random",
+ "INTEGER",
+ COL_MODEL + "_random",
+ "VARCHAR(200)"
+ )
+ if (randomSchema() != modelsSchema()) {
+ testdb!!.with(createSchemaSqlFmt(), randomSchema())
+ }
+ testdb!!.with(
+ createTableSqlFmt(),
+ randomSchema(),
+ RANDOM_TABLE_NAME,
+ columnClause(randomColumns, Optional.of(COL_ID + "_random"))
+ )
+ }
+
+ protected fun populateTables() {
+ for (recordJson in MODEL_RECORDS) {
+ writeModelRecord(recordJson)
+ }
+
+ for (recordJson in MODEL_RECORDS_RANDOM) {
+ writeRecords(
+ recordJson,
+ randomSchema(),
+ RANDOM_TABLE_NAME,
+ COL_ID + "_random",
+ COL_MAKE_ID + "_random",
+ COL_MODEL + "_random"
+ )
+ }
+ }
+
+ @AfterEach
+ protected fun tearDown() {
+ try {
+ testdb!!.close()
+ } catch (e: Throwable) {
+ LOGGER.error("exception during teardown", e)
+ }
+ }
+
+ protected fun columnClause(
+ columnsWithDataType: Map<String, String>,
+ primaryKey: Optional<String>
+ ): String {
+ val columnClause = StringBuilder()
+ var i = 0
+ for ((key, value) in columnsWithDataType) {
+ columnClause.append(key)
+ columnClause.append(" ")
+ columnClause.append(value)
+ if (i < (columnsWithDataType.size - 1)) {
+ columnClause.append(",")
+ columnClause.append(" ")
+ }
+ i++
+ }
+ primaryKey.ifPresent { s: String? ->
+ columnClause.append(", PRIMARY KEY (").append(s).append(")")
+ }
+
+ return columnClause.toString()
+ }
+
+ protected fun writeModelRecord(recordJson: JsonNode) {
+ writeRecords(recordJson, modelsSchema(), MODELS_STREAM_NAME, COL_ID, COL_MAKE_ID, COL_MODEL)
+ }
+
+ protected fun writeRecords(
+ recordJson: JsonNode,
+ dbName: String?,
+ streamName: String?,
+ idCol: String?,
+ makeIdCol: String?,
+ modelCol: String?
+ ) {
+ testdb!!.with(
+ "INSERT INTO %s.%s (%s, %s, %s) VALUES (%s, %s, '%s');",
+ dbName,
+ streamName,
+ idCol,
+ makeIdCol,
+ modelCol,
+ recordJson[idCol].asInt(),
+ recordJson[makeIdCol].asInt(),
+ recordJson[modelCol].asText()
+ )
+ }
+
+ protected fun deleteMessageOnIdCol(streamName: String?, idCol: String?, idValue: Int) {
+ testdb!!.with("DELETE FROM %s.%s WHERE %s = %s", modelsSchema(), streamName, idCol, idValue)
+ }
+
+ protected fun deleteCommand(streamName: String?) {
+ testdb!!.with("DELETE FROM %s.%s", modelsSchema(), streamName)
+ }
+
+ protected fun updateCommand(
+ streamName: String?,
+ modelCol: String?,
+ modelVal: String?,
+ idCol: String?,
+ idValue: Int
+ ) {
+ testdb!!.with(
+ "UPDATE %s.%s SET %s = '%s' WHERE %s = %s",
+ modelsSchema(),
+ streamName,
+ modelCol,
+ modelVal,
+ idCol,
+ idValue
+ )
+ }
+
+ protected fun extractRecordMessages(messages: List<AirbyteMessage>): Set<AirbyteRecordMessage> {
+ val recordsPerStream = extractRecordMessagesStreamWise(messages)
+ val consolidatedRecords: MutableSet<AirbyteRecordMessage> = HashSet()
+ recordsPerStream.values.forEach(
+ Consumer { c: Set<AirbyteRecordMessage>? -> consolidatedRecords.addAll(c!!)
} + ) + return consolidatedRecords + } + + protected fun extractRecordMessagesStreamWise( + messages: List + ): Map> { + val recordsPerStream: MutableMap> = HashMap() + for (message in messages) { + if (message.type == AirbyteMessage.Type.RECORD) { + val recordMessage = message.record + recordsPerStream + .computeIfAbsent(recordMessage.stream) { c: String? -> ArrayList() } + .add(recordMessage) + } + } + + val recordsPerStreamWithNoDuplicates: MutableMap> = + HashMap() + for ((streamName, records) in recordsPerStream) { + val recordMessageSet: Set = HashSet(records) + Assertions.assertEquals( + records.size, + recordMessageSet.size, + "Expected no duplicates in airbyte record message output for a single sync." + ) + recordsPerStreamWithNoDuplicates[streamName] = recordMessageSet + } + + return recordsPerStreamWithNoDuplicates + } + + protected fun extractStateMessages(messages: List): List { + return messages + .stream() + .filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.STATE } + .map { obj: AirbyteMessage -> obj.state } + .collect(Collectors.toList()) + } + + protected fun assertExpectedRecords( + expectedRecords: Set, + actualRecords: Set + ) { + // assume all streams are cdc. + assertExpectedRecords( + expectedRecords, + actualRecords, + actualRecords + .stream() + .map { obj: AirbyteRecordMessage -> obj.stream } + .collect(Collectors.toSet()) + ) + } + + private fun assertExpectedRecords( + expectedRecords: Set, + actualRecords: Set, + cdcStreams: Set + ) { + assertExpectedRecords( + expectedRecords, + actualRecords, + cdcStreams, + STREAM_NAMES, + modelsSchema() + ) + } + + protected fun assertExpectedRecords( + expectedRecords: Set?, + actualRecords: Set, + cdcStreams: Set, + streamNames: Set, + namespace: String? + ) { + val actualData = + actualRecords + .stream() + .map { recordMessage: AirbyteRecordMessage -> + Assertions.assertTrue(streamNames.contains(recordMessage.stream)) + Assertions.assertNotNull(recordMessage.emittedAt) + + Assertions.assertEquals(namespace, recordMessage.namespace) + + val data = recordMessage.data + + if (cdcStreams.contains(recordMessage.stream)) { + assertCdcMetaData(data, true) + } else { + assertNullCdcMetaData(data) + } + + removeCDCColumns(data as ObjectNode) + data + } + .collect(Collectors.toSet()) + + Assertions.assertEquals(expectedRecords, actualData) + } + + @Test + @Throws(Exception::class) + fun testExistingData() { + val targetPosition = cdcLatestTargetPosition() + val read = source()!!.read(config()!!, configuredCatalog, null) + val actualRecords = AutoCloseableIterators.toListAndClose(read) + + val recordMessages = extractRecordMessages(actualRecords) + val stateMessages = extractStateMessages(actualRecords) + + Assertions.assertNotNull(targetPosition) + recordMessages.forEach( + Consumer { record: AirbyteRecordMessage -> + compareTargetPositionFromTheRecordsWithTargetPostionGeneratedBeforeSync( + targetPosition, + record + ) + } + ) + + assertExpectedRecords(HashSet(MODEL_RECORDS), recordMessages) + assertExpectedStateMessages(stateMessages) + assertExpectedStateMessageCountMatches(stateMessages, MODEL_RECORDS.size.toLong()) + } + + protected fun compareTargetPositionFromTheRecordsWithTargetPostionGeneratedBeforeSync( + targetPosition: CdcTargetPosition<*>?, + record: AirbyteRecordMessage + ) { + Assertions.assertEquals(extractPosition(record.data), targetPosition) + } + + @Test // When a record is deleted, produces a deletion record. 
+ @Throws(Exception::class) + fun testDelete() { + val read1 = source().read(config()!!, configuredCatalog, null) + val actualRecords1 = AutoCloseableIterators.toListAndClose(read1) + val stateMessages1 = extractStateMessages(actualRecords1) + assertExpectedStateMessages(stateMessages1) + + deleteMessageOnIdCol(MODELS_STREAM_NAME, COL_ID, 11) + waitForCdcRecords(modelsSchema(), MODELS_STREAM_NAME, 1) + + val state = Jsons.jsonNode(listOf(stateMessages1[stateMessages1.size - 1])) + val read2 = source().read(config()!!, configuredCatalog, state) + val actualRecords2 = AutoCloseableIterators.toListAndClose(read2) + val recordMessages2: List = + ArrayList(extractRecordMessages(actualRecords2)) + val stateMessages2 = extractStateMessages(actualRecords2) + assertExpectedStateMessagesFromIncrementalSync(stateMessages2) + assertExpectedStateMessageCountMatches(stateMessages2, 1) + Assertions.assertEquals(1, recordMessages2.size) + Assertions.assertEquals(11, recordMessages2[0].data[COL_ID].asInt()) + assertCdcMetaData(recordMessages2[0].data, false) + } + + protected fun assertExpectedStateMessagesFromIncrementalSync( + stateMessages: List? + ) { + assertExpectedStateMessages(stateMessages) + } + + @Test // When a record is updated, produces an update record. + @Throws(Exception::class) + fun testUpdate() { + val updatedModel = "Explorer" + val read1 = source().read(config()!!, configuredCatalog, null) + val actualRecords1 = AutoCloseableIterators.toListAndClose(read1) + val stateMessages1 = extractStateMessages(actualRecords1) + assertExpectedStateMessages(stateMessages1) + + updateCommand(MODELS_STREAM_NAME, COL_MODEL, updatedModel, COL_ID, 11) + waitForCdcRecords(modelsSchema(), MODELS_STREAM_NAME, 1) + + val state = Jsons.jsonNode(listOf(stateMessages1[stateMessages1.size - 1])) + val read2 = source().read(config()!!, configuredCatalog, state) + val actualRecords2 = AutoCloseableIterators.toListAndClose(read2) + val recordMessages2: List = + ArrayList(extractRecordMessages(actualRecords2)) + val stateMessages2 = extractStateMessages(actualRecords2) + assertExpectedStateMessagesFromIncrementalSync(stateMessages2) + Assertions.assertEquals(1, recordMessages2.size) + Assertions.assertEquals(11, recordMessages2[0].data[COL_ID].asInt()) + Assertions.assertEquals(updatedModel, recordMessages2[0].data[COL_MODEL].asText()) + assertCdcMetaData(recordMessages2[0].data, true) + assertExpectedStateMessageCountMatches(stateMessages2, 1) + } + + @Test // Verify that when data is inserted into the database while a sync is happening and after + // the first + // sync, it all gets replicated. + @Throws(Exception::class) + protected fun testRecordsProducedDuringAndAfterSync() { + val recordsCreatedBeforeTestCount = MODEL_RECORDS.size + var expectedRecords = recordsCreatedBeforeTestCount + var expectedRecordsInCdc = 0 + val recordsToCreate = 20 + // first batch of records. 20 created here and 6 created in setup method. 
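+        // After this loop, expectedRecords is 26 (6 from setup + 20 here) and
+        // expectedRecordsInCdc is 20; the first sync should emit all 26 records.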
+ for (recordsCreated in 0 until recordsToCreate) { + val record = + Jsons.jsonNode( + ImmutableMap.of( + COL_ID, + 100 + recordsCreated, + COL_MAKE_ID, + 1, + COL_MODEL, + "F-$recordsCreated" + ) + ) + writeModelRecord(record) + expectedRecords++ + expectedRecordsInCdc++ + } + waitForCdcRecords(modelsSchema(), MODELS_STREAM_NAME, expectedRecordsInCdc) + + val firstBatchIterator = source().read(config()!!, configuredCatalog, null) + val dataFromFirstBatch = AutoCloseableIterators.toListAndClose(firstBatchIterator) + val stateAfterFirstBatch = extractStateMessages(dataFromFirstBatch) + assertExpectedStateMessagesForRecordsProducedDuringAndAfterSync(stateAfterFirstBatch) + val recordsFromFirstBatch = extractRecordMessages(dataFromFirstBatch) + Assertions.assertEquals(expectedRecords, recordsFromFirstBatch.size) + + // second batch of records again 20 being created + for (recordsCreated in 0 until recordsToCreate) { + val record = + Jsons.jsonNode( + ImmutableMap.of( + COL_ID, + 200 + recordsCreated, + COL_MAKE_ID, + 1, + COL_MODEL, + "F-$recordsCreated" + ) + ) + writeModelRecord(record) + expectedRecords++ + expectedRecordsInCdc++ + } + waitForCdcRecords(modelsSchema(), MODELS_STREAM_NAME, expectedRecordsInCdc) + + val state = Jsons.jsonNode(listOf(stateAfterFirstBatch[stateAfterFirstBatch.size - 1])) + val secondBatchIterator = source().read(config()!!, configuredCatalog, state) + val dataFromSecondBatch = AutoCloseableIterators.toListAndClose(secondBatchIterator) + + val stateAfterSecondBatch = extractStateMessages(dataFromSecondBatch) + assertExpectedStateMessagesFromIncrementalSync(stateAfterSecondBatch) + + val recordsFromSecondBatch = extractRecordMessages(dataFromSecondBatch) + Assertions.assertEquals( + recordsToCreate, + recordsFromSecondBatch.size, + "Expected 20 records to be replicated in the second sync." + ) + + // sometimes there can be more than one of these at the end of the snapshot and just before + // the + // first incremental. + val recordsFromFirstBatchWithoutDuplicates = removeDuplicates(recordsFromFirstBatch) + val recordsFromSecondBatchWithoutDuplicates = removeDuplicates(recordsFromSecondBatch) + + Assertions.assertTrue( + recordsCreatedBeforeTestCount < recordsFromFirstBatchWithoutDuplicates.size, + "Expected first sync to include records created while the test was running." + ) + Assertions.assertEquals( + expectedRecords, + recordsFromFirstBatchWithoutDuplicates.size + + recordsFromSecondBatchWithoutDuplicates.size + ) + } + + protected fun assertExpectedStateMessagesForRecordsProducedDuringAndAfterSync( + stateAfterFirstBatch: List? + ) { + assertExpectedStateMessages(stateAfterFirstBatch) + } + + @Test // When both incremental CDC and full refresh are configured for different streams in a + // sync, the + // data is replicated as expected. 
+ @Throws(Exception::class) + fun testCdcAndFullRefreshInSameSync() { + val configuredCatalog = Jsons.clone(configuredCatalog) + + val MODEL_RECORDS_2: List = + ImmutableList.of( + Jsons.jsonNode(ImmutableMap.of(COL_ID, 110, COL_MAKE_ID, 1, COL_MODEL, "Fiesta-2")), + Jsons.jsonNode(ImmutableMap.of(COL_ID, 120, COL_MAKE_ID, 1, COL_MODEL, "Focus-2")), + Jsons.jsonNode(ImmutableMap.of(COL_ID, 130, COL_MAKE_ID, 1, COL_MODEL, "Ranger-2")), + Jsons.jsonNode(ImmutableMap.of(COL_ID, 140, COL_MAKE_ID, 2, COL_MODEL, "GLA-2")), + Jsons.jsonNode(ImmutableMap.of(COL_ID, 150, COL_MAKE_ID, 2, COL_MODEL, "A 220-2")), + Jsons.jsonNode(ImmutableMap.of(COL_ID, 160, COL_MAKE_ID, 2, COL_MODEL, "E 350-2")) + ) + + val columns = + ImmutableMap.of(COL_ID, "INTEGER", COL_MAKE_ID, "INTEGER", COL_MODEL, "VARCHAR(200)") + testdb!!.with( + createTableSqlFmt(), + modelsSchema(), + MODELS_STREAM_NAME + "_2", + columnClause(columns, Optional.of(COL_ID)) + ) + + for (recordJson in MODEL_RECORDS_2) { + writeRecords( + recordJson, + modelsSchema(), + MODELS_STREAM_NAME + "_2", + COL_ID, + COL_MAKE_ID, + COL_MODEL + ) + } + + val airbyteStream = + ConfiguredAirbyteStream() + .withStream( + CatalogHelpers.createAirbyteStream( + MODELS_STREAM_NAME + "_2", + modelsSchema(), + Field.of(COL_ID, JsonSchemaType.INTEGER), + Field.of(COL_MAKE_ID, JsonSchemaType.INTEGER), + Field.of(COL_MODEL, JsonSchemaType.STRING) + ) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL) + ) + .withSourceDefinedPrimaryKey(java.util.List.of(java.util.List.of(COL_ID))) + ) + airbyteStream.syncMode = SyncMode.FULL_REFRESH + + val streams = configuredCatalog.streams + streams.add(airbyteStream) + configuredCatalog.withStreams(streams) + + val read1 = source().read(config()!!, configuredCatalog, null) + val actualRecords1 = AutoCloseableIterators.toListAndClose(read1) + + val recordMessages1 = extractRecordMessages(actualRecords1) + val stateMessages1 = extractStateMessages(actualRecords1) + val names = HashSet(STREAM_NAMES) + names.add(MODELS_STREAM_NAME + "_2") + assertExpectedStateMessages(stateMessages1) + // Full refresh does not get any state messages. + assertExpectedStateMessageCountMatches(stateMessages1, MODEL_RECORDS_2.size.toLong()) + assertExpectedRecords( + Streams.concat(MODEL_RECORDS_2.stream(), MODEL_RECORDS.stream()) + .collect(Collectors.toSet()), + recordMessages1, + setOf(MODELS_STREAM_NAME), + names, + modelsSchema() + ) + + val puntoRecord = + Jsons.jsonNode(ImmutableMap.of(COL_ID, 100, COL_MAKE_ID, 3, COL_MODEL, "Punto")) + writeModelRecord(puntoRecord) + waitForCdcRecords(modelsSchema(), MODELS_STREAM_NAME, 1) + + val state = Jsons.jsonNode(listOf(stateMessages1[stateMessages1.size - 1])) + val read2 = source().read(config()!!, configuredCatalog, state) + val actualRecords2 = AutoCloseableIterators.toListAndClose(read2) + + val recordMessages2 = extractRecordMessages(actualRecords2) + val stateMessages2 = extractStateMessages(actualRecords2) + assertExpectedStateMessagesFromIncrementalSync(stateMessages2) + assertExpectedStateMessageCountMatches(stateMessages2, 1) + assertExpectedRecords( + Streams.concat(MODEL_RECORDS_2.stream(), Stream.of(puntoRecord)) + .collect(Collectors.toSet()), + recordMessages2, + setOf(MODELS_STREAM_NAME), + names, + modelsSchema() + ) + } + + @Test // When no records exist, no records are returned. 
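+    // (testNoData below first deletes the six seed rows; waitForCdcRecords is passed
+    // MODEL_RECORDS.size because each delete is expected to produce one CDC event.)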
+ @Throws(Exception::class) + fun testNoData() { + deleteCommand(MODELS_STREAM_NAME) + waitForCdcRecords(modelsSchema(), MODELS_STREAM_NAME, MODEL_RECORDS.size) + val read = source()!!.read(config()!!, configuredCatalog, null) + val actualRecords = AutoCloseableIterators.toListAndClose(read) + + val recordMessages = extractRecordMessages(actualRecords) + val stateMessages = extractStateMessages(actualRecords) + assertExpectedRecords(emptySet(), recordMessages) + assertExpectedStateMessagesForNoData(stateMessages) + assertExpectedStateMessageCountMatches(stateMessages, 0) + } + + protected fun assertExpectedStateMessagesForNoData(stateMessages: List?) { + assertExpectedStateMessages(stateMessages) + } + + @Test // When no changes have been made to the database since the previous sync, no records are + // returned. + @Throws(Exception::class) + fun testNoDataOnSecondSync() { + val read1 = source().read(config()!!, configuredCatalog, null) + val actualRecords1 = AutoCloseableIterators.toListAndClose(read1) + val stateMessagesFromFirstSync = extractStateMessages(actualRecords1) + val state = + Jsons.jsonNode(listOf(stateMessagesFromFirstSync[stateMessagesFromFirstSync.size - 1])) + + val read2 = source().read(config()!!, configuredCatalog, state) + val actualRecords2 = AutoCloseableIterators.toListAndClose(read2) + + val recordMessages2 = extractRecordMessages(actualRecords2) + val stateMessages2 = extractStateMessages(actualRecords2) + + assertExpectedRecords(emptySet(), recordMessages2) + assertExpectedStateMessagesFromIncrementalSync(stateMessages2) + assertExpectedStateMessageCountMatches(stateMessages2, 0) + } + + @Test + @Throws(Exception::class) + fun testCheck() { + val status = source()!!.check(config()!!) + Assertions.assertEquals(status!!.status, AirbyteConnectionStatus.Status.SUCCEEDED) + } + + @Test + @Throws(Exception::class) + fun testDiscover() { + val expectedCatalog = expectedCatalogForDiscover() + val actualCatalog = source()!!.discover(config()!!) + + Assertions.assertEquals( + expectedCatalog.streams + .stream() + .sorted(Comparator.comparing { obj: AirbyteStream -> obj.name }) + .collect(Collectors.toList()), + actualCatalog!! 
+ .streams + .stream() + .sorted(Comparator.comparing { obj: AirbyteStream -> obj.name }) + .collect(Collectors.toList()) + ) + } + + @Test + @Throws(Exception::class) + fun newTableSnapshotTest() { + val firstBatchIterator = source().read(config()!!, configuredCatalog, null) + val dataFromFirstBatch = AutoCloseableIterators.toListAndClose(firstBatchIterator) + val recordsFromFirstBatch = extractRecordMessages(dataFromFirstBatch) + val stateAfterFirstBatch = extractStateMessages(dataFromFirstBatch) + assertExpectedStateMessages(stateAfterFirstBatch) + assertExpectedStateMessageCountMatches(stateAfterFirstBatch, MODEL_RECORDS.size.toLong()) + + val stateMessageEmittedAfterFirstSyncCompletion = + stateAfterFirstBatch[stateAfterFirstBatch.size - 1] + Assertions.assertEquals( + AirbyteStateMessage.AirbyteStateType.GLOBAL, + stateMessageEmittedAfterFirstSyncCompletion.type + ) + Assertions.assertNotNull(stateMessageEmittedAfterFirstSyncCompletion.global.sharedState) + val streamsInStateAfterFirstSyncCompletion = + stateMessageEmittedAfterFirstSyncCompletion.global.streamStates + .stream() + .map { obj: AirbyteStreamState -> obj.streamDescriptor } + .collect(Collectors.toSet()) + Assertions.assertEquals(1, streamsInStateAfterFirstSyncCompletion.size) + Assertions.assertTrue( + streamsInStateAfterFirstSyncCompletion.contains( + StreamDescriptor().withName(MODELS_STREAM_NAME).withNamespace(modelsSchema()) + ) + ) + Assertions.assertNotNull(stateMessageEmittedAfterFirstSyncCompletion.data) + + Assertions.assertEquals((MODEL_RECORDS.size), recordsFromFirstBatch.size) + assertExpectedRecords(HashSet(MODEL_RECORDS), recordsFromFirstBatch) + + val state = stateAfterFirstBatch[stateAfterFirstBatch.size - 1].data + + val newTables = + CatalogHelpers.toDefaultConfiguredCatalog( + AirbyteCatalog() + .withStreams( + java.util.List.of( + CatalogHelpers.createAirbyteStream( + RANDOM_TABLE_NAME, + randomSchema(), + Field.of(COL_ID + "_random", JsonSchemaType.NUMBER), + Field.of(COL_MAKE_ID + "_random", JsonSchemaType.NUMBER), + Field.of(COL_MODEL + "_random", JsonSchemaType.STRING) + ) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL) + ) + .withSourceDefinedPrimaryKey( + java.util.List.of(java.util.List.of(COL_ID + "_random")) + ) + ) + ) + ) + + newTables.streams.forEach( + Consumer { s: ConfiguredAirbyteStream -> s.syncMode = SyncMode.INCREMENTAL } + ) + val combinedStreams: MutableList = ArrayList() + combinedStreams.addAll(configuredCatalog.streams) + combinedStreams.addAll(newTables.streams) + + val updatedCatalog = ConfiguredAirbyteCatalog().withStreams(combinedStreams) + + /* + * Write 20 records to the existing table + */ + val recordsWritten: MutableSet = HashSet() + for (recordsCreated in 0..19) { + val record = + Jsons.jsonNode( + ImmutableMap.of( + COL_ID, + 100 + recordsCreated, + COL_MAKE_ID, + 1, + COL_MODEL, + "F-$recordsCreated" + ) + ) + recordsWritten.add(record) + writeModelRecord(record) + } + + val secondBatchIterator = source().read(config()!!, updatedCatalog, state) + val dataFromSecondBatch = AutoCloseableIterators.toListAndClose(secondBatchIterator) + + val stateAfterSecondBatch = extractStateMessages(dataFromSecondBatch) + assertStateMessagesForNewTableSnapshotTest( + stateAfterSecondBatch, + stateMessageEmittedAfterFirstSyncCompletion + ) + + val recordsStreamWise = extractRecordMessagesStreamWise(dataFromSecondBatch) + Assertions.assertTrue(recordsStreamWise.containsKey(MODELS_STREAM_NAME)) + 
Assertions.assertTrue(recordsStreamWise.containsKey(RANDOM_TABLE_NAME)) + + val recordsForModelsStreamFromSecondBatch = recordsStreamWise[MODELS_STREAM_NAME]!! + val recordsForModelsRandomStreamFromSecondBatch = recordsStreamWise[RANDOM_TABLE_NAME]!! + + Assertions.assertEquals( + (MODEL_RECORDS_RANDOM.size), + recordsForModelsRandomStreamFromSecondBatch.size + ) + Assertions.assertEquals(20, recordsForModelsStreamFromSecondBatch.size) + assertExpectedRecords( + HashSet(MODEL_RECORDS_RANDOM), + recordsForModelsRandomStreamFromSecondBatch, + recordsForModelsRandomStreamFromSecondBatch + .stream() + .map { obj: AirbyteRecordMessage -> obj.stream } + .collect(Collectors.toSet()), + Sets.newHashSet(RANDOM_TABLE_NAME), + randomSchema() + ) + assertExpectedRecords(recordsWritten, recordsForModelsStreamFromSecondBatch) + + /* + * Write 20 records to both the tables + */ + val recordsWrittenInRandomTable: MutableSet = HashSet() + recordsWritten.clear() + for (recordsCreated in 30..49) { + val record = + Jsons.jsonNode( + ImmutableMap.of( + COL_ID, + 100 + recordsCreated, + COL_MAKE_ID, + 1, + COL_MODEL, + "F-$recordsCreated" + ) + ) + writeModelRecord(record) + recordsWritten.add(record) + + val record2 = + Jsons.jsonNode( + ImmutableMap.of( + COL_ID + "_random", + 11000 + recordsCreated, + COL_MAKE_ID + "_random", + 1 + recordsCreated, + COL_MODEL + "_random", + "Fiesta-random$recordsCreated" + ) + ) + writeRecords( + record2, + randomSchema(), + RANDOM_TABLE_NAME, + COL_ID + "_random", + COL_MAKE_ID + "_random", + COL_MODEL + "_random" + ) + recordsWrittenInRandomTable.add(record2) + } + + val state2 = stateAfterSecondBatch[stateAfterSecondBatch.size - 1].data + val thirdBatchIterator = source().read(config()!!, updatedCatalog, state2) + val dataFromThirdBatch = AutoCloseableIterators.toListAndClose(thirdBatchIterator) + + val stateAfterThirdBatch = extractStateMessages(dataFromThirdBatch) + Assertions.assertTrue(stateAfterThirdBatch.size >= 1) + + val stateMessageEmittedAfterThirdSyncCompletion = + stateAfterThirdBatch[stateAfterThirdBatch.size - 1] + Assertions.assertEquals( + AirbyteStateMessage.AirbyteStateType.GLOBAL, + stateMessageEmittedAfterThirdSyncCompletion.type + ) + Assertions.assertNotEquals( + stateMessageEmittedAfterThirdSyncCompletion.global.sharedState, + stateAfterSecondBatch[stateAfterSecondBatch.size - 1].global.sharedState + ) + val streamsInSyncCompletionStateAfterThirdSync = + stateMessageEmittedAfterThirdSyncCompletion.global.streamStates + .stream() + .map { obj: AirbyteStreamState -> obj.streamDescriptor } + .collect(Collectors.toSet()) + Assertions.assertTrue( + streamsInSyncCompletionStateAfterThirdSync.contains( + StreamDescriptor().withName(RANDOM_TABLE_NAME).withNamespace(randomSchema()) + ) + ) + Assertions.assertTrue( + streamsInSyncCompletionStateAfterThirdSync.contains( + StreamDescriptor().withName(MODELS_STREAM_NAME).withNamespace(modelsSchema()) + ) + ) + Assertions.assertNotNull(stateMessageEmittedAfterThirdSyncCompletion.data) + + val recordsStreamWiseFromThirdBatch = extractRecordMessagesStreamWise(dataFromThirdBatch) + Assertions.assertTrue(recordsStreamWiseFromThirdBatch.containsKey(MODELS_STREAM_NAME)) + Assertions.assertTrue(recordsStreamWiseFromThirdBatch.containsKey(RANDOM_TABLE_NAME)) + + val recordsForModelsStreamFromThirdBatch = + recordsStreamWiseFromThirdBatch[MODELS_STREAM_NAME]!! + val recordsForModelsRandomStreamFromThirdBatch = + recordsStreamWiseFromThirdBatch[RANDOM_TABLE_NAME]!! 
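+        // Third sync: each table received 20 new rows and neither stream should be
+        // re-snapshotted, so exactly 20 records per stream are expected below.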
+
+        Assertions.assertEquals(20, recordsForModelsStreamFromThirdBatch.size)
+        Assertions.assertEquals(20, recordsForModelsRandomStreamFromThirdBatch.size)
+        assertExpectedRecords(recordsWritten, recordsForModelsStreamFromThirdBatch)
+        assertExpectedRecords(
+            recordsWrittenInRandomTable,
+            recordsForModelsRandomStreamFromThirdBatch,
+            recordsForModelsRandomStreamFromThirdBatch
+                .stream()
+                .map { obj: AirbyteRecordMessage -> obj.stream }
+                .collect(Collectors.toSet()),
+            Sets.newHashSet(RANDOM_TABLE_NAME),
+            randomSchema()
+        )
+    }
+
+    protected fun assertStateMessagesForNewTableSnapshotTest(
+        stateMessages: List<AirbyteStateMessage>,
+        stateMessageEmittedAfterFirstSyncCompletion: AirbyteStateMessage
+    ) {
+        Assertions.assertEquals(2, stateMessages.size)
+        val stateMessageEmittedAfterSnapshotCompletionInSecondSync = stateMessages[0]
+        Assertions.assertEquals(
+            AirbyteStateMessage.AirbyteStateType.GLOBAL,
+            stateMessageEmittedAfterSnapshotCompletionInSecondSync.type
+        )
+        Assertions.assertEquals(
+            stateMessageEmittedAfterFirstSyncCompletion.global.sharedState,
+            stateMessageEmittedAfterSnapshotCompletionInSecondSync.global.sharedState
+        )
+        val streamsInSnapshotState =
+            stateMessageEmittedAfterSnapshotCompletionInSecondSync.global.streamStates
+                .stream()
+                .map { obj: AirbyteStreamState -> obj.streamDescriptor }
+                .collect(Collectors.toSet())
+        Assertions.assertEquals(2, streamsInSnapshotState.size)
+        Assertions.assertTrue(
+            streamsInSnapshotState.contains(
+                StreamDescriptor().withName(RANDOM_TABLE_NAME).withNamespace(randomSchema())
+            )
+        )
+        Assertions.assertTrue(
+            streamsInSnapshotState.contains(
+                StreamDescriptor().withName(MODELS_STREAM_NAME).withNamespace(modelsSchema())
+            )
+        )
+        Assertions.assertNotNull(stateMessageEmittedAfterSnapshotCompletionInSecondSync.data)
+
+        val stateMessageEmittedAfterSecondSyncCompletion = stateMessages[1]
+        Assertions.assertEquals(
+            AirbyteStateMessage.AirbyteStateType.GLOBAL,
+            stateMessageEmittedAfterSecondSyncCompletion.type
+        )
+        Assertions.assertNotEquals(
+            stateMessageEmittedAfterFirstSyncCompletion.global.sharedState,
+            stateMessageEmittedAfterSecondSyncCompletion.global.sharedState
+        )
+        val streamsInSyncCompletionState =
+            stateMessageEmittedAfterSecondSyncCompletion.global.streamStates
+                .stream()
+                .map { obj: AirbyteStreamState -> obj.streamDescriptor }
+                .collect(Collectors.toSet())
+        Assertions.assertEquals(2, streamsInSyncCompletionState.size)
+        Assertions.assertTrue(
+            streamsInSyncCompletionState.contains(
+                StreamDescriptor().withName(RANDOM_TABLE_NAME).withNamespace(randomSchema())
+            )
+        )
+        Assertions.assertTrue(
+            streamsInSyncCompletionState.contains(
+                StreamDescriptor().withName(MODELS_STREAM_NAME).withNamespace(modelsSchema())
+            )
+        )
+        Assertions.assertNotNull(stateMessageEmittedAfterSecondSyncCompletion.data)
+    }
+
+    protected fun expectedCatalogForDiscover(): AirbyteCatalog {
+        val expectedCatalog = Jsons.clone(catalog)
+
+        val columns =
+            ImmutableMap.of(COL_ID, "INTEGER", COL_MAKE_ID, "INTEGER", COL_MODEL, "VARCHAR(200)")
+        testdb!!.with(
+            createTableSqlFmt(),
+            modelsSchema(),
+            MODELS_STREAM_NAME + "_2",
+            columnClause(columns, Optional.empty())
+        )
+
+        val streams = expectedCatalog.streams
+        // stream with PK
+        streams[0].sourceDefinedCursor = true
+        addCdcMetadataColumns(streams[0])
+        addCdcDefaultCursorField(streams[0])
+
+        val streamWithoutPK =
+            CatalogHelpers.createAirbyteStream(
+                MODELS_STREAM_NAME + "_2",
+                modelsSchema(),
+                Field.of(COL_ID, JsonSchemaType.INTEGER),
+                Field.of(COL_MAKE_ID, JsonSchemaType.INTEGER),
+
Field.of(COL_MODEL, JsonSchemaType.STRING) + ) + streamWithoutPK.sourceDefinedPrimaryKey = emptyList() + streamWithoutPK.supportedSyncModes = java.util.List.of(SyncMode.FULL_REFRESH) + addCdcDefaultCursorField(streamWithoutPK) + addCdcMetadataColumns(streamWithoutPK) + + val randomStream = + CatalogHelpers.createAirbyteStream( + RANDOM_TABLE_NAME, + randomSchema(), + Field.of(COL_ID + "_random", JsonSchemaType.INTEGER), + Field.of(COL_MAKE_ID + "_random", JsonSchemaType.INTEGER), + Field.of(COL_MODEL + "_random", JsonSchemaType.STRING) + ) + .withSourceDefinedCursor(true) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL) + ) + .withSourceDefinedPrimaryKey( + java.util.List.of(java.util.List.of(COL_ID + "_random")) + ) + + addCdcDefaultCursorField(randomStream) + addCdcMetadataColumns(randomStream) + + streams.add(streamWithoutPK) + streams.add(randomStream) + expectedCatalog.withStreams(streams) + return expectedCatalog + } + + @Throws(Exception::class) + protected fun waitForCdcRecords(schemaName: String?, tableName: String?, recordCount: Int) {} + + companion object { + private val LOGGER: Logger = LoggerFactory.getLogger(CdcSourceTest::class.java) + + protected const val MODELS_STREAM_NAME: String = "models" + protected val STREAM_NAMES: Set = java.util.Set.of(MODELS_STREAM_NAME) + protected const val COL_ID: String = "id" + protected const val COL_MAKE_ID: String = "make_id" + protected const val COL_MODEL: String = "model" + + protected val MODEL_RECORDS: List = + ImmutableList.of( + Jsons.jsonNode(ImmutableMap.of(COL_ID, 11, COL_MAKE_ID, 1, COL_MODEL, "Fiesta")), + Jsons.jsonNode(ImmutableMap.of(COL_ID, 12, COL_MAKE_ID, 1, COL_MODEL, "Focus")), + Jsons.jsonNode(ImmutableMap.of(COL_ID, 13, COL_MAKE_ID, 1, COL_MODEL, "Ranger")), + Jsons.jsonNode(ImmutableMap.of(COL_ID, 14, COL_MAKE_ID, 2, COL_MODEL, "GLA")), + Jsons.jsonNode(ImmutableMap.of(COL_ID, 15, COL_MAKE_ID, 2, COL_MODEL, "A 220")), + Jsons.jsonNode(ImmutableMap.of(COL_ID, 16, COL_MAKE_ID, 2, COL_MODEL, "E 350")) + ) + + protected const val RANDOM_TABLE_NAME: String = MODELS_STREAM_NAME + "_random" + + protected val MODEL_RECORDS_RANDOM: List = + MODEL_RECORDS.stream() + .map { r: JsonNode -> + Jsons.jsonNode( + ImmutableMap.of( + COL_ID + "_random", + r[COL_ID].asInt() * 1000, + COL_MAKE_ID + "_random", + r[COL_MAKE_ID], + COL_MODEL + "_random", + r[COL_MODEL].asText() + "-random" + ) + ) + } + .toList() + + protected fun removeDuplicates( + messages: Set + ): Set { + val existingDataRecordsWithoutUpdated: MutableSet = HashSet() + val output: MutableSet = HashSet() + + for (message in messages) { + val node = message.data.deepCopy() + node.remove("_ab_cdc_updated_at") + + if (existingDataRecordsWithoutUpdated.contains(node)) { + LOGGER.info("Removing duplicate node: $node") + } else { + output.add(message) + existingDataRecordsWithoutUpdated.add(node) + } + } + + return output + } + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/debug/DebugUtil.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/debug/DebugUtil.kt new file mode 100644 index 000000000000..d04a8ea5e014 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/debug/DebugUtil.kt @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+ */
+package io.airbyte.cdk.integrations.debug
+
+import com.fasterxml.jackson.databind.JsonNode
+import com.fasterxml.jackson.databind.ObjectMapper
+import com.fasterxml.jackson.databind.node.ObjectNode
+import io.airbyte.cdk.integrations.base.Source
+import io.airbyte.commons.json.Jsons
+import io.airbyte.commons.resources.MoreResources
+import io.airbyte.protocol.models.v0.AirbyteMessage
+import io.airbyte.protocol.models.v0.AirbyteStateMessage
+import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog
+
+/**
+ * Utility class defined to debug a source. Copy over any relevant configurations, catalogs & state
+ * in the resources/debug_resources directory.
+ */
+object DebugUtil {
+    @Suppress("deprecation")
+    @Throws(Exception::class)
+    fun debug(debugSource: Source) {
+        val debugConfig = config
+        val configuredAirbyteCatalog = catalog
+        // State is optional: fall back to a full sync if no state file is provided.
+        val state =
+            try {
+                state
+            } catch (e: Exception) {
+                null
+            }
+
+        debugSource.check(debugConfig)
+        debugSource.discover(debugConfig)
+
+        val messageIterator = debugSource.read(debugConfig, configuredAirbyteCatalog, state)
+        // Drain the iterator; set a breakpoint in this no-op consumer to inspect messages.
+        messageIterator.forEachRemaining { message: AirbyteMessage? -> }
+    }
+
+    @get:Throws(Exception::class)
+    private val config: JsonNode
+        get() {
+            val originalConfig =
+                ObjectMapper().readTree(MoreResources.readResource("debug_resources/config.json"))
+            val debugConfig: JsonNode =
+                (originalConfig.deepCopy() as ObjectNode).put("debug_mode", true)
+            return debugConfig
+        }
+
+    @get:Throws(Exception::class)
+    private val catalog: ConfiguredAirbyteCatalog
+        get() {
+            val catalog = MoreResources.readResource("debug_resources/configured_catalog.json")
+            return Jsons.deserialize(catalog, ConfiguredAirbyteCatalog::class.java)
+        }
+
+    @get:Throws(Exception::class)
+    private val state: JsonNode
+        get() {
+            val message =
+                Jsons.deserialize(
+                    MoreResources.readResource("debug_resources/state.json"),
+                    AirbyteStateMessage::class.java
+                )
+            return Jsons.jsonNode(listOf(message))
+        }
+}
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/source/jdbc/test/JdbcSourceAcceptanceTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/source/jdbc/test/JdbcSourceAcceptanceTest.kt
new file mode 100644
index 000000000000..d8444667cfe2
--- /dev/null
+++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/source/jdbc/test/JdbcSourceAcceptanceTest.kt
@@ -0,0 +1,1661 @@
+/*
+ * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+ */
+package io.airbyte.cdk.integrations.source.jdbc.test
+
+import com.fasterxml.jackson.databind.JsonNode
+import com.fasterxml.jackson.databind.node.ObjectNode
+import edu.umd.cs.findbugs.annotations.SuppressFBWarnings
+import io.airbyte.cdk.db.factory.DatabaseDriver
+import io.airbyte.cdk.db.jdbc.JdbcUtils
+import io.airbyte.cdk.integrations.base.Source
+import io.airbyte.cdk.integrations.source.relationaldb.RelationalDbQueryUtils
+import io.airbyte.cdk.integrations.source.relationaldb.models.DbState
+import io.airbyte.cdk.integrations.source.relationaldb.models.DbStreamState
+import io.airbyte.cdk.testutils.TestDatabase
+import io.airbyte.commons.json.Jsons
+import io.airbyte.commons.resources.MoreResources
+import io.airbyte.commons.util.MoreIterators
+import io.airbyte.protocol.models.Field
+import io.airbyte.protocol.models.JsonSchemaType
+import io.airbyte.protocol.models.v0.*
+import java.math.BigDecimal
+import java.sql.SQLException
+import java.util.*
+import java.util.function.Consumer
+import java.util.stream.Collectors
+import org.hamcrest.MatcherAssert
+import org.hamcrest.Matchers
+import org.junit.jupiter.api.AfterEach
+import org.junit.jupiter.api.Assertions
+import org.junit.jupiter.api.BeforeEach
+import org.junit.jupiter.api.Test
+import org.mockito.Mockito
+
+/** Tests that should be run on all Sources that extend the AbstractJdbcSource. */
+@SuppressFBWarnings(
+    value = ["MS_SHOULD_BE_FINAL"],
+    justification =
+        "The static variables are updated in subclasses for convenience, and cannot be final."
+)
+abstract class JdbcSourceAcceptanceTest<S : Source, T : TestDatabase<*, T, *>> {
+    protected lateinit var testdb: T
+
+    protected fun streamName(): String {
+        return TABLE_NAME
+    }
+
+    /**
+     * A valid configuration to connect to a test database.
+     *
+     * @return config
+     */
+    protected abstract fun config(): JsonNode
+
+    /**
+     * An instance of the source that should be tested.
+     *
+     * @return abstract jdbc source
+     */
+    protected abstract fun source(): S
+
+    /**
+     * Creates a TestDatabase instance to be used in [.setup].
+     *
+     * @return TestDatabase instance to use for test case.
+     */
+    protected abstract fun createTestDatabase(): T
+
+    /**
+     * These tests write records without specifying a namespace (schema name). They will be written
+     * into whatever the default schema is for the database. When they are discovered they will be
+     * namespaced by the schema name (e.g. `schema_name`.`table_name`). Thus the source needs to
+     * tell the tests what that default schema name is. If the database does not support schemas,
+     * then the database name should be used instead.
+     *
+     * @return true if the database supports schemas, false otherwise.
+ */ + protected abstract fun supportsSchemas(): Boolean + + protected fun createTableQuery( + tableName: String?, + columnClause: String?, + primaryKeyClause: String + ): String { + return String.format( + "CREATE TABLE %s(%s %s %s)", + tableName, + columnClause, + if (primaryKeyClause == "") "" else ",", + primaryKeyClause + ) + } + + protected fun primaryKeyClause(columns: List): String { + if (columns.isEmpty()) { + return "" + } + + val clause = StringBuilder() + clause.append("PRIMARY KEY (") + for (i in columns.indices) { + clause.append(columns[i]) + if (i != (columns.size - 1)) { + clause.append(",") + } + } + clause.append(")") + return clause.toString() + } + + @BeforeEach + @Throws(Exception::class) + fun setup() { + testdb = createTestDatabase() + if (supportsSchemas()) { + createSchemas() + } + if (testdb!!.databaseDriver == DatabaseDriver.ORACLE) { + testdb!!.with("ALTER SESSION SET NLS_DATE_FORMAT = 'YYYY-MM-DD'") + } + testdb + .with( + createTableQuery( + getFullyQualifiedTableName(TABLE_NAME), + COLUMN_CLAUSE_WITH_PK, + primaryKeyClause(listOf("id")) + ) + ) + .with( + "INSERT INTO %s(id, name, updated_at) VALUES (1, 'picard', '2004-10-19')", + getFullyQualifiedTableName(TABLE_NAME) + ) + .with( + "INSERT INTO %s(id, name, updated_at) VALUES (2, 'crusher', '2005-10-19')", + getFullyQualifiedTableName(TABLE_NAME) + ) + .with( + "INSERT INTO %s(id, name, updated_at) VALUES (3, 'vash', '2006-10-19')", + getFullyQualifiedTableName(TABLE_NAME) + ) + .with( + createTableQuery( + getFullyQualifiedTableName(TABLE_NAME_WITHOUT_PK), + COLUMN_CLAUSE_WITHOUT_PK, + "" + ) + ) + .with( + "INSERT INTO %s(id, name, updated_at) VALUES (1, 'picard', '2004-10-19')", + getFullyQualifiedTableName(TABLE_NAME_WITHOUT_PK) + ) + .with( + "INSERT INTO %s(id, name, updated_at) VALUES (2, 'crusher', '2005-10-19')", + getFullyQualifiedTableName(TABLE_NAME_WITHOUT_PK) + ) + .with( + "INSERT INTO %s(id, name, updated_at) VALUES (3, 'vash', '2006-10-19')", + getFullyQualifiedTableName(TABLE_NAME_WITHOUT_PK) + ) + .with( + createTableQuery( + getFullyQualifiedTableName(TABLE_NAME_COMPOSITE_PK), + COLUMN_CLAUSE_WITH_COMPOSITE_PK, + primaryKeyClause(listOf("first_name", "last_name")) + ) + ) + .with( + "INSERT INTO %s(first_name, last_name, updated_at) VALUES ('first', 'picard', '2004-10-19')", + getFullyQualifiedTableName(TABLE_NAME_COMPOSITE_PK) + ) + .with( + "INSERT INTO %s(first_name, last_name, updated_at) VALUES ('second', 'crusher', '2005-10-19')", + getFullyQualifiedTableName(TABLE_NAME_COMPOSITE_PK) + ) + .with( + "INSERT INTO %s(first_name, last_name, updated_at) VALUES ('third', 'vash', '2006-10-19')", + getFullyQualifiedTableName(TABLE_NAME_COMPOSITE_PK) + ) + } + + protected fun maybeSetShorterConnectionTimeout(config: JsonNode?) { + // Optionally implement this to speed up test cases which will result in a connection + // timeout. 
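+        // A minimal sketch of an override, assuming the connector's spec accepts extra JDBC
+        // URL parameters under JdbcUtils.JDBC_URL_PARAMS_KEY (adjust to the actual spec):
+        //   (config as ObjectNode).put(JdbcUtils.JDBC_URL_PARAMS_KEY, "connectTimeout=1000")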
+ } + + @AfterEach + fun tearDown() { + testdb!!.close() + } + + @Test + @Throws(Exception::class) + fun testSpec() { + val actual = source()!!.spec() + val resourceString = MoreResources.readResource("spec.json") + val expected = Jsons.deserialize(resourceString, ConnectorSpecification::class.java) + + Assertions.assertEquals(expected, actual) + } + + @Test + @Throws(Exception::class) + fun testCheckSuccess() { + val actual = source()!!.check(config()) + val expected = + AirbyteConnectionStatus().withStatus(AirbyteConnectionStatus.Status.SUCCEEDED) + Assertions.assertEquals(expected, actual) + } + + @Test + @Throws(Exception::class) + protected fun testCheckFailure() { + val config = config() + maybeSetShorterConnectionTimeout(config) + (config as ObjectNode).put(JdbcUtils.PASSWORD_KEY, "fake") + val actual = source()!!.check(config) + Assertions.assertEquals(AirbyteConnectionStatus.Status.FAILED, actual!!.status) + } + + @Test + @Throws(Exception::class) + fun testDiscover() { + val actual = filterOutOtherSchemas(source()!!.discover(config())) + val expected = getCatalog(defaultNamespace) + Assertions.assertEquals(expected.streams.size, actual!!.streams.size) + actual.streams.forEach( + Consumer { actualStream: AirbyteStream -> + val expectedStream = + expected.streams + .stream() + .filter { stream: AirbyteStream -> + stream.namespace == actualStream.namespace && + stream.name == actualStream.name + } + .findAny() + Assertions.assertTrue( + expectedStream.isPresent, + String.format("Unexpected stream %s", actualStream.name) + ) + Assertions.assertEquals(expectedStream.get(), actualStream) + } + ) + } + + @Test + @Throws(Exception::class) + protected fun testDiscoverWithNonCursorFields() { + testdb!! + .with( + CREATE_TABLE_WITHOUT_CURSOR_TYPE_QUERY, + getFullyQualifiedTableName(TABLE_NAME_WITHOUT_CURSOR_TYPE), + COL_CURSOR + ) + .with( + INSERT_TABLE_WITHOUT_CURSOR_TYPE_QUERY, + getFullyQualifiedTableName(TABLE_NAME_WITHOUT_CURSOR_TYPE) + ) + val actual = filterOutOtherSchemas(source()!!.discover(config())) + val stream = + actual!! + .streams + .stream() + .filter { s: AirbyteStream -> + s.name.equals(TABLE_NAME_WITHOUT_CURSOR_TYPE, ignoreCase = true) + } + .findFirst() + .orElse(null) + Assertions.assertNotNull(stream) + Assertions.assertEquals( + TABLE_NAME_WITHOUT_CURSOR_TYPE.lowercase(Locale.getDefault()), + stream.name.lowercase(Locale.getDefault()) + ) + Assertions.assertEquals(1, stream.supportedSyncModes.size) + Assertions.assertEquals(SyncMode.FULL_REFRESH, stream.supportedSyncModes[0]) + } + + @Test + @Throws(Exception::class) + protected fun testDiscoverWithNullableCursorFields() { + testdb!! + .with( + CREATE_TABLE_WITH_NULLABLE_CURSOR_TYPE_QUERY, + getFullyQualifiedTableName(TABLE_NAME_WITH_NULLABLE_CURSOR_TYPE), + COL_CURSOR + ) + .with( + INSERT_TABLE_WITH_NULLABLE_CURSOR_TYPE_QUERY, + getFullyQualifiedTableName(TABLE_NAME_WITH_NULLABLE_CURSOR_TYPE) + ) + val actual = filterOutOtherSchemas(source()!!.discover(config())) + val stream = + actual!! 
+ .streams + .stream() + .filter { s: AirbyteStream -> + s.name.equals(TABLE_NAME_WITH_NULLABLE_CURSOR_TYPE, ignoreCase = true) + } + .findFirst() + .orElse(null) + Assertions.assertNotNull(stream) + Assertions.assertEquals( + TABLE_NAME_WITH_NULLABLE_CURSOR_TYPE.lowercase(Locale.getDefault()), + stream.name.lowercase(Locale.getDefault()) + ) + Assertions.assertEquals(2, stream.supportedSyncModes.size) + Assertions.assertTrue(stream.supportedSyncModes.contains(SyncMode.FULL_REFRESH)) + Assertions.assertTrue(stream.supportedSyncModes.contains(SyncMode.INCREMENTAL)) + } + + protected fun filterOutOtherSchemas(catalog: AirbyteCatalog?): AirbyteCatalog? { + if (supportsSchemas()) { + val filteredCatalog = Jsons.clone(catalog) + filteredCatalog!!.streams = + filteredCatalog.streams + .stream() + .filter { stream: AirbyteStream -> + TEST_SCHEMAS.stream().anyMatch { schemaName: String? -> + stream.namespace.startsWith(schemaName!!) + } + } + .collect(Collectors.toList()) + return filteredCatalog + } else { + return catalog + } + } + + @Test + @Throws(Exception::class) + protected fun testDiscoverWithMultipleSchemas() { + // clickhouse and mysql do not have a concept of schemas, so this test does not make sense + // for them. + when (testdb!!.databaseDriver) { + DatabaseDriver.MYSQL, + DatabaseDriver.CLICKHOUSE, + DatabaseDriver.TERADATA -> return + else -> {} + } + // add table and data to a separate schema. + testdb!! + .with( + "CREATE TABLE %s(id VARCHAR(200) NOT NULL, name VARCHAR(200) NOT NULL)", + RelationalDbQueryUtils.getFullyQualifiedTableName(SCHEMA_NAME2, TABLE_NAME) + ) + .with( + "INSERT INTO %s(id, name) VALUES ('1','picard')", + RelationalDbQueryUtils.getFullyQualifiedTableName(SCHEMA_NAME2, TABLE_NAME) + ) + .with( + "INSERT INTO %s(id, name) VALUES ('2', 'crusher')", + RelationalDbQueryUtils.getFullyQualifiedTableName(SCHEMA_NAME2, TABLE_NAME) + ) + .with( + "INSERT INTO %s(id, name) VALUES ('3', 'vash')", + RelationalDbQueryUtils.getFullyQualifiedTableName(SCHEMA_NAME2, TABLE_NAME) + ) + + val actual = source()!!.discover(config()) + + val expected = getCatalog(defaultNamespace) + val catalogStreams: MutableList = ArrayList() + catalogStreams.addAll(expected.streams) + catalogStreams.add( + CatalogHelpers.createAirbyteStream( + TABLE_NAME, + SCHEMA_NAME2, + Field.of(COL_ID, JsonSchemaType.STRING), + Field.of(COL_NAME, JsonSchemaType.STRING) + ) + .withSupportedSyncModes( + java.util.List.of(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL) + ) + ) + expected.streams = catalogStreams + // sort streams by name so that we are comparing lists with the same order. + val schemaTableCompare = + Comparator.comparing { stream: AirbyteStream -> stream.namespace + "." 
+ stream.name } + expected.streams.sortWith(schemaTableCompare) + actual!!.streams.sortWith(schemaTableCompare) + Assertions.assertEquals(expected, filterOutOtherSchemas(actual)) + } + + @Test + @Throws(Exception::class) + fun testReadSuccess() { + val actualMessages = + MoreIterators.toList( + source()!!.read(config(), getConfiguredCatalogWithOneStream(defaultNamespace), null) + ) + + setEmittedAtToNull(actualMessages) + val expectedMessages = testMessages + MatcherAssert.assertThat( + expectedMessages, + Matchers.containsInAnyOrder(*actualMessages.toTypedArray()) + ) + MatcherAssert.assertThat( + actualMessages, + Matchers.containsInAnyOrder(*expectedMessages.toTypedArray()) + ) + } + + @Test + @Throws(Exception::class) + protected fun testReadOneColumn() { + val catalog = + CatalogHelpers.createConfiguredAirbyteCatalog( + streamName(), + defaultNamespace, + Field.of(COL_ID, JsonSchemaType.NUMBER) + ) + val actualMessages = MoreIterators.toList(source()!!.read(config(), catalog, null)) + + setEmittedAtToNull(actualMessages) + + val expectedMessages = airbyteMessagesReadOneColumn + Assertions.assertEquals(expectedMessages.size, actualMessages.size) + Assertions.assertTrue(expectedMessages.containsAll(actualMessages)) + Assertions.assertTrue(actualMessages.containsAll(expectedMessages)) + } + + protected val airbyteMessagesReadOneColumn: List + get() { + val expectedMessages = + testMessages + .stream() + .map { `object`: AirbyteMessage -> Jsons.clone(`object`) } + .peek { m: AirbyteMessage -> + (m.record.data as ObjectNode).remove(COL_NAME) + (m.record.data as ObjectNode).remove(COL_UPDATED_AT) + (m.record.data as ObjectNode).replace( + COL_ID, + convertIdBasedOnDatabase(m.record.data[COL_ID].asInt()) + ) + } + .collect(Collectors.toList()) + return expectedMessages + } + + @Test + @Throws(Exception::class) + protected fun testReadMultipleTables() { + val catalog = getConfiguredCatalogWithOneStream(defaultNamespace) + val expectedMessages: MutableList = ArrayList(testMessages) + + for (i in 2..9) { + val streamName2 = streamName() + i + val tableName = getFullyQualifiedTableName(TABLE_NAME + i) + testdb!! 
+ .with(createTableQuery(tableName, "id INTEGER, name VARCHAR(200)", "")) + .with("INSERT INTO %s(id, name) VALUES (1,'picard')", tableName) + .with("INSERT INTO %s(id, name) VALUES (2, 'crusher')", tableName) + .with("INSERT INTO %s(id, name) VALUES (3, 'vash')", tableName) + catalog.streams.add( + CatalogHelpers.createConfiguredAirbyteStream( + streamName2, + defaultNamespace, + Field.of(COL_ID, JsonSchemaType.NUMBER), + Field.of(COL_NAME, JsonSchemaType.STRING) + ) + ) + + expectedMessages.addAll(getAirbyteMessagesSecondSync(streamName2)) + } + + val actualMessages = MoreIterators.toList(source()!!.read(config(), catalog, null)) + + setEmittedAtToNull(actualMessages) + + Assertions.assertEquals(expectedMessages.size, actualMessages.size) + Assertions.assertTrue(expectedMessages.containsAll(actualMessages)) + Assertions.assertTrue(actualMessages.containsAll(expectedMessages)) + } + + protected fun getAirbyteMessagesSecondSync(streamName: String?): List { + return testMessages + .stream() + .map { `object`: AirbyteMessage -> Jsons.clone(`object`) } + .peek { m: AirbyteMessage -> + m.record.stream = streamName + m.record.namespace = defaultNamespace + (m.record.data as ObjectNode).remove(COL_UPDATED_AT) + (m.record.data as ObjectNode).replace( + COL_ID, + convertIdBasedOnDatabase(m.record.data[COL_ID].asInt()) + ) + } + .collect(Collectors.toList()) + } + + @Test + @Throws(Exception::class) + protected fun testTablesWithQuoting() { + val streamForTableWithSpaces = createTableWithSpaces() + + val catalog = + ConfiguredAirbyteCatalog() + .withStreams( + java.util.List.of( + getConfiguredCatalogWithOneStream(defaultNamespace).streams[0], + streamForTableWithSpaces + ) + ) + val actualMessages = MoreIterators.toList(source()!!.read(config(), catalog, null)) + + setEmittedAtToNull(actualMessages) + + val expectedMessages: MutableList = ArrayList(testMessages) + expectedMessages.addAll(getAirbyteMessagesForTablesWithQuoting(streamForTableWithSpaces)) + + Assertions.assertEquals(expectedMessages.size, actualMessages.size) + Assertions.assertTrue(expectedMessages.containsAll(actualMessages)) + Assertions.assertTrue(actualMessages.containsAll(expectedMessages)) + } + + protected fun getAirbyteMessagesForTablesWithQuoting( + streamForTableWithSpaces: ConfiguredAirbyteStream + ): List { + return testMessages + .stream() + .map { `object`: AirbyteMessage -> Jsons.clone(`object`) } + .peek { m: AirbyteMessage -> + m.record.stream = streamForTableWithSpaces.stream.name + (m.record.data as ObjectNode).set( + COL_LAST_NAME_WITH_SPACE, + (m.record.data as ObjectNode).remove(COL_NAME) + ) + (m.record.data as ObjectNode).remove(COL_UPDATED_AT) + (m.record.data as ObjectNode).replace( + COL_ID, + convertIdBasedOnDatabase(m.record.data[COL_ID].asInt()) + ) + } + .collect(Collectors.toList()) + } + + @Test + fun testReadFailure() { + val spiedAbStream = + Mockito.spy(getConfiguredCatalogWithOneStream(defaultNamespace).streams[0]) + val catalog = ConfiguredAirbyteCatalog().withStreams(java.util.List.of(spiedAbStream)) + Mockito.doCallRealMethod().doThrow(RuntimeException()).`when`(spiedAbStream).stream + + Assertions.assertThrows(RuntimeException::class.java) { + source()!!.read(config(), catalog, null) + } + } + + @Test + @Throws(Exception::class) + fun testIncrementalNoPreviousState() { + incrementalCursorCheck(COL_ID, null, "3", testMessages) + } + + @Test + @Throws(Exception::class) + fun testIncrementalIntCheckCursor() { + incrementalCursorCheck(COL_ID, "2", "3", java.util.List.of(testMessages[2])) + } + + 
@Test
+    @Throws(Exception::class)
+    fun testIncrementalStringCheckCursor() {
+        incrementalCursorCheck(
+            COL_NAME,
+            "patent",
+            "vash",
+            java.util.List.of(testMessages[0], testMessages[2])
+        )
+    }
+
+    @Test
+    @Throws(Exception::class)
+    fun testIncrementalStringCheckCursorSpaceInColumnName() {
+        val streamWithSpaces = createTableWithSpaces()
+
+        val expectedRecordMessages =
+            getAirbyteMessagesCheckCursorSpaceInColumnName(streamWithSpaces)
+        incrementalCursorCheck(
+            COL_LAST_NAME_WITH_SPACE,
+            COL_LAST_NAME_WITH_SPACE,
+            "patent",
+            "vash",
+            expectedRecordMessages,
+            streamWithSpaces
+        )
+    }
+
+    protected fun getAirbyteMessagesCheckCursorSpaceInColumnName(
+        streamWithSpaces: ConfiguredAirbyteStream
+    ): List<AirbyteMessage> {
+        val firstMessage = testMessages[0]
+        firstMessage.record.stream = streamWithSpaces.stream.name
+        (firstMessage.record.data as ObjectNode).remove(COL_UPDATED_AT)
+        (firstMessage.record.data as ObjectNode).set<JsonNode>(
+            COL_LAST_NAME_WITH_SPACE,
+            (firstMessage.record.data as ObjectNode).remove(COL_NAME)
+        )
+
+        val secondMessage = testMessages[2]
+        secondMessage.record.stream = streamWithSpaces.stream.name
+        (secondMessage.record.data as ObjectNode).remove(COL_UPDATED_AT)
+        (secondMessage.record.data as ObjectNode).set<JsonNode>(
+            COL_LAST_NAME_WITH_SPACE,
+            (secondMessage.record.data as ObjectNode).remove(COL_NAME)
+        )
+
+        return java.util.List.of(firstMessage, secondMessage)
+    }
+
+    @Test
+    @Throws(Exception::class)
+    fun testIncrementalDateCheckCursor() {
+        incrementalDateCheck()
+    }
+
+    @Throws(Exception::class)
+    protected fun incrementalDateCheck() {
+        incrementalCursorCheck(
+            COL_UPDATED_AT,
+            "2005-10-18",
+            "2006-10-19",
+            java.util.List.of(testMessages[1], testMessages[2])
+        )
+    }
+
+    @Test
+    @Throws(Exception::class)
+    fun testIncrementalCursorChanges() {
+        incrementalCursorCheck(
+            COL_ID,
+            COL_NAME, // Cheating with this value a little: in a correct implementation the
+            // initial cursor value should be ignored because the cursor field changed. It is
+            // set to a value that, if (incorrectly) used, would filter records out.
+ "data", + "vash", + testMessages + ) + } + + @Test + @Throws(Exception::class) + protected fun testReadOneTableIncrementallyTwice() { + val config = config() + val namespace = defaultNamespace + val configuredCatalog = getConfiguredCatalogWithOneStream(namespace) + configuredCatalog.streams.forEach( + Consumer { airbyteStream: ConfiguredAirbyteStream -> + airbyteStream.syncMode = SyncMode.INCREMENTAL + airbyteStream.cursorField = java.util.List.of(COL_ID) + airbyteStream.destinationSyncMode = DestinationSyncMode.APPEND + } + ) + + val actualMessagesFirstSync = + MoreIterators.toList( + source()!!.read( + config, + configuredCatalog, + createEmptyState(streamName(), namespace) + ) + ) + + val stateAfterFirstSyncOptional = + actualMessagesFirstSync + .stream() + .filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.STATE } + .findFirst() + Assertions.assertTrue(stateAfterFirstSyncOptional.isPresent) + + executeStatementReadIncrementallyTwice() + + val actualMessagesSecondSync = + MoreIterators.toList( + source()!!.read( + config, + configuredCatalog, + extractState(stateAfterFirstSyncOptional.get()) + ) + ) + + Assertions.assertEquals( + 2, + actualMessagesSecondSync + .stream() + .filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.RECORD } + .count() + .toInt() + ) + val expectedMessages = getExpectedAirbyteMessagesSecondSync(namespace) + + setEmittedAtToNull(actualMessagesSecondSync) + + Assertions.assertEquals(expectedMessages.size, actualMessagesSecondSync.size) + Assertions.assertTrue(expectedMessages.containsAll(actualMessagesSecondSync)) + Assertions.assertTrue(actualMessagesSecondSync.containsAll(expectedMessages)) + } + + protected fun executeStatementReadIncrementallyTwice() { + testdb + .with( + "INSERT INTO %s (id, name, updated_at) VALUES (4, 'riker', '2006-10-19')", + getFullyQualifiedTableName(TABLE_NAME) + ) + .with( + "INSERT INTO %s (id, name, updated_at) VALUES (5, 'data', '2006-10-19')", + getFullyQualifiedTableName(TABLE_NAME) + ) + } + + protected fun getExpectedAirbyteMessagesSecondSync(namespace: String?): List { + val expectedMessages: MutableList = ArrayList() + expectedMessages.add( + AirbyteMessage() + .withType(AirbyteMessage.Type.RECORD) + .withRecord( + AirbyteRecordMessage() + .withStream(streamName()) + .withNamespace(namespace) + .withData( + Jsons.jsonNode( + java.util.Map.of( + COL_ID, + ID_VALUE_4, + COL_NAME, + "riker", + COL_UPDATED_AT, + "2006-10-19" + ) + ) + ) + ) + ) + expectedMessages.add( + AirbyteMessage() + .withType(AirbyteMessage.Type.RECORD) + .withRecord( + AirbyteRecordMessage() + .withStream(streamName()) + .withNamespace(namespace) + .withData( + Jsons.jsonNode( + java.util.Map.of( + COL_ID, + ID_VALUE_5, + COL_NAME, + "data", + COL_UPDATED_AT, + "2006-10-19" + ) + ) + ) + ) + ) + val state = + DbStreamState() + .withStreamName(streamName()) + .withStreamNamespace(namespace) + .withCursorField(java.util.List.of(COL_ID)) + .withCursor("5") + .withCursorRecordCount(1L) + expectedMessages.addAll(createExpectedTestMessages(java.util.List.of(state), 2L)) + return expectedMessages + } + + @Test + @Throws(Exception::class) + protected fun testReadMultipleTablesIncrementally() { + val tableName2 = TABLE_NAME + 2 + val streamName2 = streamName() + 2 + val fqTableName2 = getFullyQualifiedTableName(tableName2) + testdb!! 
+ .with(createTableQuery(fqTableName2, "id INTEGER, name VARCHAR(200)", "")) + .with("INSERT INTO %s(id, name) VALUES (1,'picard')", fqTableName2) + .with("INSERT INTO %s(id, name) VALUES (2, 'crusher')", fqTableName2) + .with("INSERT INTO %s(id, name) VALUES (3, 'vash')", fqTableName2) + + val namespace = defaultNamespace + val configuredCatalog = getConfiguredCatalogWithOneStream(namespace) + configuredCatalog.streams.add( + CatalogHelpers.createConfiguredAirbyteStream( + streamName2, + namespace, + Field.of(COL_ID, JsonSchemaType.NUMBER), + Field.of(COL_NAME, JsonSchemaType.STRING) + ) + ) + configuredCatalog.streams.forEach( + Consumer { airbyteStream: ConfiguredAirbyteStream -> + airbyteStream.syncMode = SyncMode.INCREMENTAL + airbyteStream.cursorField = java.util.List.of(COL_ID) + airbyteStream.destinationSyncMode = DestinationSyncMode.APPEND + } + ) + + val actualMessagesFirstSync = + MoreIterators.toList( + source()!!.read( + config(), + configuredCatalog, + createEmptyState(streamName(), namespace) + ) + ) + + // get last state message. + val stateAfterFirstSyncOptional = + actualMessagesFirstSync + .stream() + .filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.STATE } + .reduce { first: AirbyteMessage?, second: AirbyteMessage -> second } + Assertions.assertTrue(stateAfterFirstSyncOptional.isPresent) + + // we know the second streams messages are the same as the first minus the updated at + // column. so we + // cheat and generate the expected messages off of the first expected messages. + val secondStreamExpectedMessages = getAirbyteMessagesSecondStreamWithNamespace(streamName2) + + // Represents the state after the first stream has been updated + val expectedStateStreams1 = + java.util.List.of( + DbStreamState() + .withStreamName(streamName()) + .withStreamNamespace(namespace) + .withCursorField(java.util.List.of(COL_ID)) + .withCursor("3") + .withCursorRecordCount(1L), + DbStreamState() + .withStreamName(streamName2) + .withStreamNamespace(namespace) + .withCursorField(java.util.List.of(COL_ID)) + ) + + // Represents the state after both streams have been updated + val expectedStateStreams2 = + java.util.List.of( + DbStreamState() + .withStreamName(streamName()) + .withStreamNamespace(namespace) + .withCursorField(java.util.List.of(COL_ID)) + .withCursor("3") + .withCursorRecordCount(1L), + DbStreamState() + .withStreamName(streamName2) + .withStreamNamespace(namespace) + .withCursorField(java.util.List.of(COL_ID)) + .withCursor("3") + .withCursorRecordCount(1L) + ) + + val expectedMessagesFirstSync: MutableList = ArrayList(testMessages) + expectedMessagesFirstSync.add( + createStateMessage(expectedStateStreams1[0], expectedStateStreams1, 3L) + ) + expectedMessagesFirstSync.addAll(secondStreamExpectedMessages) + expectedMessagesFirstSync.add( + createStateMessage(expectedStateStreams2[1], expectedStateStreams2, 3L) + ) + + setEmittedAtToNull(actualMessagesFirstSync) + + Assertions.assertEquals(expectedMessagesFirstSync.size, actualMessagesFirstSync.size) + Assertions.assertTrue(expectedMessagesFirstSync.containsAll(actualMessagesFirstSync)) + Assertions.assertTrue(actualMessagesFirstSync.containsAll(expectedMessagesFirstSync)) + } + + protected fun getAirbyteMessagesSecondStreamWithNamespace( + streamName2: String? 
+ ): List { + return testMessages + .stream() + .map { `object`: AirbyteMessage -> Jsons.clone(`object`) } + .peek { m: AirbyteMessage -> + m.record.stream = streamName2 + (m.record.data as ObjectNode).remove(COL_UPDATED_AT) + (m.record.data as ObjectNode).replace( + COL_ID, + convertIdBasedOnDatabase(m.record.data[COL_ID].asInt()) + ) + } + .collect(Collectors.toList()) + } + + // when initial and final cursor fields are the same. + @Throws(Exception::class) + protected fun incrementalCursorCheck( + cursorField: String, + initialCursorValue: String?, + endCursorValue: String, + expectedRecordMessages: List + ) { + incrementalCursorCheck( + cursorField, + cursorField, + initialCursorValue, + endCursorValue, + expectedRecordMessages + ) + } + + // See https://github.com/airbytehq/airbyte/issues/14732 for rationale and details. + @Test + @Throws(Exception::class) + fun testIncrementalWithConcurrentInsertion() { + val namespace = defaultNamespace + val fullyQualifiedTableName = getFullyQualifiedTableName(TABLE_NAME_AND_TIMESTAMP) + val columnDefinition = + String.format( + "name VARCHAR(200) NOT NULL, %s %s NOT NULL", + COL_TIMESTAMP, + COL_TIMESTAMP_TYPE + ) + + // 1st sync + testdb!! + .with(createTableQuery(fullyQualifiedTableName, columnDefinition, "")) + .with( + INSERT_TABLE_NAME_AND_TIMESTAMP_QUERY, + fullyQualifiedTableName, + "a", + "2021-01-01 00:00:00" + ) + .with( + INSERT_TABLE_NAME_AND_TIMESTAMP_QUERY, + fullyQualifiedTableName, + "b", + "2021-01-01 00:00:00" + ) + + val configuredCatalog = + CatalogHelpers.toDefaultConfiguredCatalog( + AirbyteCatalog() + .withStreams( + java.util.List.of( + CatalogHelpers.createAirbyteStream( + TABLE_NAME_AND_TIMESTAMP, + namespace, + Field.of(COL_NAME, JsonSchemaType.STRING), + Field.of( + COL_TIMESTAMP, + JsonSchemaType.STRING_TIMESTAMP_WITHOUT_TIMEZONE + ) + ) + ) + ) + ) + + configuredCatalog.streams.forEach( + Consumer { airbyteStream: ConfiguredAirbyteStream -> + airbyteStream.syncMode = SyncMode.INCREMENTAL + airbyteStream.cursorField = java.util.List.of(COL_TIMESTAMP) + airbyteStream.destinationSyncMode = DestinationSyncMode.APPEND + } + ) + + val firstSyncActualMessages = + MoreIterators.toList( + source()!!.read( + config(), + configuredCatalog, + createEmptyState(TABLE_NAME_AND_TIMESTAMP, namespace) + ) + ) + + // cursor after 1st sync: 2021-01-01 00:00:00, count 2 + val firstSyncStateOptional = + firstSyncActualMessages + .stream() + .filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.STATE } + .findFirst() + Assertions.assertTrue(firstSyncStateOptional.isPresent) + val firstSyncState = getStateData(firstSyncStateOptional.get(), TABLE_NAME_AND_TIMESTAMP) + Assertions.assertEquals( + firstSyncState["cursor_field"].elements().next().asText(), + COL_TIMESTAMP + ) + Assertions.assertTrue(firstSyncState["cursor"].asText().contains("2021-01-01")) + Assertions.assertTrue(firstSyncState["cursor"].asText().contains("00:00:00")) + Assertions.assertEquals(2L, firstSyncState["cursor_record_count"].asLong()) + + val firstSyncNames = + firstSyncActualMessages + .stream() + .filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.RECORD } + .map { r: AirbyteMessage -> r.record.data[COL_NAME].asText() } + .toList() + // some databases don't make insertion order guarantee when equal ordering value + if ( + testdb!!.databaseDriver == DatabaseDriver.TERADATA || + testdb!!.databaseDriver == DatabaseDriver.ORACLE + ) { + MatcherAssert.assertThat( + listOf("a", "b"), + Matchers.containsInAnyOrder(*firstSyncNames.toTypedArray()) + ) + } else 
{ + Assertions.assertEquals(listOf("a", "b"), firstSyncNames) + } + + // 2nd sync + testdb!!.with( + INSERT_TABLE_NAME_AND_TIMESTAMP_QUERY, + fullyQualifiedTableName, + "c", + "2021-01-02 00:00:00" + ) + + val secondSyncActualMessages = + MoreIterators.toList( + source()!!.read( + config(), + configuredCatalog, + createState(TABLE_NAME_AND_TIMESTAMP, namespace, firstSyncState) + ) + ) + + // cursor after 2nd sync: 2021-01-02 00:00:00, count 1 + val secondSyncStateOptional = + secondSyncActualMessages + .stream() + .filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.STATE } + .findFirst() + Assertions.assertTrue(secondSyncStateOptional.isPresent) + val secondSyncState = getStateData(secondSyncStateOptional.get(), TABLE_NAME_AND_TIMESTAMP) + Assertions.assertEquals( + secondSyncState["cursor_field"].elements().next().asText(), + COL_TIMESTAMP + ) + Assertions.assertTrue(secondSyncState["cursor"].asText().contains("2021-01-02")) + Assertions.assertTrue(secondSyncState["cursor"].asText().contains("00:00:00")) + Assertions.assertEquals(1L, secondSyncState["cursor_record_count"].asLong()) + + val secondSyncNames = + secondSyncActualMessages + .stream() + .filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.RECORD } + .map { r: AirbyteMessage -> r.record.data[COL_NAME].asText() } + .toList() + Assertions.assertEquals(listOf("c"), secondSyncNames) + + // 3rd sync has records with duplicated cursors + testdb!! + .with( + INSERT_TABLE_NAME_AND_TIMESTAMP_QUERY, + fullyQualifiedTableName, + "d", + "2021-01-02 00:00:00" + ) + .with( + INSERT_TABLE_NAME_AND_TIMESTAMP_QUERY, + fullyQualifiedTableName, + "e", + "2021-01-02 00:00:00" + ) + .with( + INSERT_TABLE_NAME_AND_TIMESTAMP_QUERY, + fullyQualifiedTableName, + "f", + "2021-01-03 00:00:00" + ) + + val thirdSyncActualMessages = + MoreIterators.toList( + source()!!.read( + config(), + configuredCatalog, + createState(TABLE_NAME_AND_TIMESTAMP, namespace, secondSyncState) + ) + ) + + // Cursor after 3rd sync is: 2021-01-03 00:00:00, count 1. + val thirdSyncStateOptional = + thirdSyncActualMessages + .stream() + .filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.STATE } + .findFirst() + Assertions.assertTrue(thirdSyncStateOptional.isPresent) + val thirdSyncState = getStateData(thirdSyncStateOptional.get(), TABLE_NAME_AND_TIMESTAMP) + Assertions.assertEquals( + thirdSyncState["cursor_field"].elements().next().asText(), + COL_TIMESTAMP + ) + Assertions.assertTrue(thirdSyncState["cursor"].asText().contains("2021-01-03")) + Assertions.assertTrue(thirdSyncState["cursor"].asText().contains("00:00:00")) + Assertions.assertEquals(1L, thirdSyncState["cursor_record_count"].asLong()) + + // The c, d, e, f are duplicated records from this sync, because the cursor + // record count in the database is different from that in the state. 
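+        // (The state stores cursor 2021-01-02 with a record count of 1, while the database now
+        // holds three rows at that timestamp, so the source re-reads from the cursor value
+        // inclusively; presumably it trades duplicate records for completeness here.)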
+ val thirdSyncExpectedNames = + thirdSyncActualMessages + .stream() + .filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.RECORD } + .map { r: AirbyteMessage -> r.record.data[COL_NAME].asText() } + .toList() + + // teradata doesn't make insertion order guarantee when equal ordering value + if (testdb!!.databaseDriver == DatabaseDriver.TERADATA) { + MatcherAssert.assertThat( + listOf("c", "d", "e", "f"), + Matchers.containsInAnyOrder(*thirdSyncExpectedNames.toTypedArray()) + ) + } else { + Assertions.assertEquals(listOf("c", "d", "e", "f"), thirdSyncExpectedNames) + } + } + + protected fun getStateData(airbyteMessage: AirbyteMessage, streamName: String): JsonNode { + for (stream in airbyteMessage.state.data["streams"]) { + if (stream["stream_name"].asText() == streamName) { + return stream + } + } + throw IllegalArgumentException("Stream not found in state message: $streamName") + } + + @Throws(Exception::class) + private fun incrementalCursorCheck( + initialCursorField: String, + cursorField: String, + initialCursorValue: String?, + endCursorValue: String, + expectedRecordMessages: List + ) { + incrementalCursorCheck( + initialCursorField, + cursorField, + initialCursorValue, + endCursorValue, + expectedRecordMessages, + getConfiguredCatalogWithOneStream(defaultNamespace).streams[0] + ) + } + + @Throws(Exception::class) + protected fun incrementalCursorCheck( + initialCursorField: String?, + cursorField: String, + initialCursorValue: String?, + endCursorValue: String?, + expectedRecordMessages: List, + airbyteStream: ConfiguredAirbyteStream + ) { + airbyteStream.syncMode = SyncMode.INCREMENTAL + airbyteStream.cursorField = java.util.List.of(cursorField) + airbyteStream.destinationSyncMode = DestinationSyncMode.APPEND + + val configuredCatalog = + ConfiguredAirbyteCatalog().withStreams(java.util.List.of(airbyteStream)) + + val dbStreamState = buildStreamState(airbyteStream, initialCursorField, initialCursorValue) + + val actualMessages = + MoreIterators.toList( + source()!!.read( + config(), + configuredCatalog, + Jsons.jsonNode(createState(java.util.List.of(dbStreamState))) + ) + ) + + setEmittedAtToNull(actualMessages) + + val expectedStreams = + java.util.List.of(buildStreamState(airbyteStream, cursorField, endCursorValue)) + + val expectedMessages: MutableList = ArrayList(expectedRecordMessages) + expectedMessages.addAll( + createExpectedTestMessages(expectedStreams, expectedRecordMessages.size.toLong()) + ) + + Assertions.assertEquals(expectedMessages.size, actualMessages.size) + Assertions.assertTrue(expectedMessages.containsAll(actualMessages)) + Assertions.assertTrue(actualMessages.containsAll(expectedMessages)) + } + + protected fun buildStreamState( + configuredAirbyteStream: ConfiguredAirbyteStream, + cursorField: String?, + cursorValue: String? + ): DbStreamState { + return DbStreamState() + .withStreamName(configuredAirbyteStream.stream.name) + .withStreamNamespace(configuredAirbyteStream.stream.namespace) + .withCursorField(java.util.List.of(cursorField)) + .withCursor(cursorValue) + .withCursorRecordCount(1L) + } + + // get catalog and perform a defensive copy. + protected fun getConfiguredCatalogWithOneStream( + defaultNamespace: String? 
+ ): ConfiguredAirbyteCatalog { + val catalog = CatalogHelpers.toDefaultConfiguredCatalog(getCatalog(defaultNamespace)) + // Filter to only keep the main stream name as configured stream + catalog.withStreams( + catalog.streams + .stream() + .filter { s: ConfiguredAirbyteStream -> s.stream.name == streamName() } + .collect(Collectors.toList()) + ) + return catalog + } + + protected fun getCatalog(defaultNamespace: String?): AirbyteCatalog { + return AirbyteCatalog() + .withStreams( + java.util.List.of( + CatalogHelpers.createAirbyteStream( + TABLE_NAME, + defaultNamespace, + Field.of(COL_ID, JsonSchemaType.INTEGER), + Field.of(COL_NAME, JsonSchemaType.STRING), + Field.of(COL_UPDATED_AT, JsonSchemaType.STRING) + ) + .withSupportedSyncModes( + java.util.List.of(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL) + ) + .withSourceDefinedPrimaryKey(java.util.List.of(java.util.List.of(COL_ID))), + CatalogHelpers.createAirbyteStream( + TABLE_NAME_WITHOUT_PK, + defaultNamespace, + Field.of(COL_ID, JsonSchemaType.INTEGER), + Field.of(COL_NAME, JsonSchemaType.STRING), + Field.of(COL_UPDATED_AT, JsonSchemaType.STRING) + ) + .withSupportedSyncModes( + java.util.List.of(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL) + ) + .withSourceDefinedPrimaryKey(emptyList()), + CatalogHelpers.createAirbyteStream( + TABLE_NAME_COMPOSITE_PK, + defaultNamespace, + Field.of(COL_FIRST_NAME, JsonSchemaType.STRING), + Field.of(COL_LAST_NAME, JsonSchemaType.STRING), + Field.of(COL_UPDATED_AT, JsonSchemaType.STRING) + ) + .withSupportedSyncModes( + java.util.List.of(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL) + ) + .withSourceDefinedPrimaryKey( + java.util.List.of( + java.util.List.of(COL_FIRST_NAME), + java.util.List.of(COL_LAST_NAME) + ) + ) + ) + ) + } + + protected val testMessages: List + get() = + java.util.List.of( + AirbyteMessage() + .withType(AirbyteMessage.Type.RECORD) + .withRecord( + AirbyteRecordMessage() + .withStream(streamName()) + .withNamespace(defaultNamespace) + .withData( + Jsons.jsonNode( + java.util.Map.of( + COL_ID, + ID_VALUE_1, + COL_NAME, + "picard", + COL_UPDATED_AT, + "2004-10-19" + ) + ) + ) + ), + AirbyteMessage() + .withType(AirbyteMessage.Type.RECORD) + .withRecord( + AirbyteRecordMessage() + .withStream(streamName()) + .withNamespace(defaultNamespace) + .withData( + Jsons.jsonNode( + java.util.Map.of( + COL_ID, + ID_VALUE_2, + COL_NAME, + "crusher", + COL_UPDATED_AT, + "2005-10-19" + ) + ) + ) + ), + AirbyteMessage() + .withType(AirbyteMessage.Type.RECORD) + .withRecord( + AirbyteRecordMessage() + .withStream(streamName()) + .withNamespace(defaultNamespace) + .withData( + Jsons.jsonNode( + java.util.Map.of( + COL_ID, + ID_VALUE_3, + COL_NAME, + "vash", + COL_UPDATED_AT, + "2006-10-19" + ) + ) + ) + ) + ) + + protected fun createExpectedTestMessages( + states: List, + numRecords: Long + ): List { + return states + .stream() + .map { s: DbStreamState -> + AirbyteMessage() + .withType(AirbyteMessage.Type.STATE) + .withState( + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.STREAM) + .withStream( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor() + .withNamespace(s.streamNamespace) + .withName(s.streamName) + ) + .withStreamState(Jsons.jsonNode(s)) + ) + .withData(Jsons.jsonNode(DbState().withCdc(false).withStreams(states))) + .withSourceStats( + AirbyteStateStats().withRecordCount(numRecords.toDouble()) + ) + ) + } + .collect(Collectors.toList()) + } + + protected fun createState(states: List): List { + return states + .stream() + .map { s: DbStreamState -> 
+ AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.STREAM) + .withStream( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor() + .withNamespace(s.streamNamespace) + .withName(s.streamName) + ) + .withStreamState(Jsons.jsonNode(s)) + ) + } + .collect(Collectors.toList()) + } + + @Throws(SQLException::class) + protected fun createTableWithSpaces(): ConfiguredAirbyteStream { + val tableNameWithSpaces = TABLE_NAME_WITH_SPACES + "2" + val streamName2 = tableNameWithSpaces + + testdb!!.getDataSource()!!.connection.use { connection -> + val identifierQuoteString = connection.metaData.identifierQuoteString + connection + .createStatement() + .execute( + createTableQuery( + getFullyQualifiedTableName( + RelationalDbQueryUtils.enquoteIdentifier( + tableNameWithSpaces, + identifierQuoteString + ) + ), + "id INTEGER, " + + RelationalDbQueryUtils.enquoteIdentifier( + COL_LAST_NAME_WITH_SPACE, + identifierQuoteString + ) + + " VARCHAR(200)", + "" + ) + ) + connection + .createStatement() + .execute( + String.format( + "INSERT INTO %s(id, %s) VALUES (1,'picard')", + getFullyQualifiedTableName( + RelationalDbQueryUtils.enquoteIdentifier( + tableNameWithSpaces, + identifierQuoteString + ) + ), + RelationalDbQueryUtils.enquoteIdentifier( + COL_LAST_NAME_WITH_SPACE, + identifierQuoteString + ) + ) + ) + connection + .createStatement() + .execute( + String.format( + "INSERT INTO %s(id, %s) VALUES (2, 'crusher')", + getFullyQualifiedTableName( + RelationalDbQueryUtils.enquoteIdentifier( + tableNameWithSpaces, + identifierQuoteString + ) + ), + RelationalDbQueryUtils.enquoteIdentifier( + COL_LAST_NAME_WITH_SPACE, + identifierQuoteString + ) + ) + ) + connection + .createStatement() + .execute( + String.format( + "INSERT INTO %s(id, %s) VALUES (3, 'vash')", + getFullyQualifiedTableName( + RelationalDbQueryUtils.enquoteIdentifier( + tableNameWithSpaces, + identifierQuoteString + ) + ), + RelationalDbQueryUtils.enquoteIdentifier( + COL_LAST_NAME_WITH_SPACE, + identifierQuoteString + ) + ) + ) + } + return CatalogHelpers.createConfiguredAirbyteStream( + streamName2, + defaultNamespace, + Field.of(COL_ID, JsonSchemaType.NUMBER), + Field.of(COL_LAST_NAME_WITH_SPACE, JsonSchemaType.STRING) + ) + } + + fun getFullyQualifiedTableName(tableName: String): String { + return RelationalDbQueryUtils.getFullyQualifiedTableName(defaultSchemaName, tableName) + } + + protected fun createSchemas() { + if (supportsSchemas()) { + for (schemaName in TEST_SCHEMAS) { + testdb!!.with("CREATE SCHEMA %s;", schemaName) + } + } + } + + private fun convertIdBasedOnDatabase(idValue: Int): JsonNode { + return when (testdb!!.databaseDriver) { + DatabaseDriver.ORACLE, + DatabaseDriver.SNOWFLAKE -> Jsons.jsonNode(BigDecimal.valueOf(idValue.toLong())) + else -> Jsons.jsonNode(idValue) + } + } + + private val defaultSchemaName: String? + get() = if (supportsSchemas()) SCHEMA_NAME else null + + protected val defaultNamespace: String + get() = + when (testdb!!.databaseDriver) { + DatabaseDriver.MYSQL, + DatabaseDriver.CLICKHOUSE, + DatabaseDriver.TERADATA -> testdb!!.databaseName!! + else -> SCHEMA_NAME + } + + /** + * Creates empty state with the provided stream name and namespace. + * + * @param streamName The stream name. + * @param streamNamespace The stream namespace. + * @return [JsonNode] representation of the generated empty state. 
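+     *
+     * A sketch of the resulting JSON, assuming these protocol models serialize with snake_case
+     * field names: `[{"type":"STREAM","stream":{"stream_descriptor":{"name":"...","namespace":"..."}}}]`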
+ */ + protected fun createEmptyState(streamName: String?, streamNamespace: String?): JsonNode { + val airbyteStateMessage = + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.STREAM) + .withStream( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor().withName(streamName).withNamespace(streamNamespace) + ) + ) + return Jsons.jsonNode(java.util.List.of(airbyteStateMessage)) + } + + protected fun createState( + streamName: String?, + streamNamespace: String?, + stateData: JsonNode? + ): JsonNode { + val airbyteStateMessage = + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.STREAM) + .withStream( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor().withName(streamName).withNamespace(streamNamespace) + ) + .withStreamState(stateData) + ) + return Jsons.jsonNode(java.util.List.of(airbyteStateMessage)) + } + + protected fun extractState(airbyteMessage: AirbyteMessage): JsonNode { + return Jsons.jsonNode(java.util.List.of(airbyteMessage.state)) + } + + protected fun createStateMessage( + dbStreamState: DbStreamState, + legacyStates: List?, + recordCount: Long + ): AirbyteMessage { + return AirbyteMessage() + .withType(AirbyteMessage.Type.STATE) + .withState( + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.STREAM) + .withStream( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor() + .withNamespace(dbStreamState.streamNamespace) + .withName(dbStreamState.streamName) + ) + .withStreamState(Jsons.jsonNode(dbStreamState)) + ) + .withData(Jsons.jsonNode(DbState().withCdc(false).withStreams(legacyStates))) + .withSourceStats(AirbyteStateStats().withRecordCount(recordCount.toDouble())) + ) + } + + protected fun extractSpecificFieldFromCombinedMessages( + messages: List, + streamName: String, + field: String? 
+ ): List { + return extractStateMessage(messages) + .stream() + .filter { s: AirbyteStateMessage -> s.stream.streamDescriptor.name == streamName } + .map { s: AirbyteStateMessage -> + if (s.stream.streamState[field] != null) s.stream.streamState[field].asText() + else "" + } + .toList() + } + + protected fun filterRecords(messages: List): List { + return messages + .stream() + .filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.RECORD } + .collect(Collectors.toList()) + } + + protected fun extractStateMessage(messages: List): List { + return messages + .stream() + .filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.STATE } + .map { obj: AirbyteMessage -> obj.state } + .collect(Collectors.toList()) + } + + protected fun extractStateMessage( + messages: List, + streamName: String + ): List { + return messages + .stream() + .filter { r: AirbyteMessage -> + r.type == AirbyteMessage.Type.STATE && + r.state.stream.streamDescriptor.name == streamName + } + .map { obj: AirbyteMessage -> obj.state } + .collect(Collectors.toList()) + } + + protected fun createRecord( + stream: String?, + namespace: String?, + data: Map + ): AirbyteMessage { + return AirbyteMessage() + .withType(AirbyteMessage.Type.RECORD) + .withRecord( + AirbyteRecordMessage() + .withData(Jsons.jsonNode(data)) + .withStream(stream) + .withNamespace(namespace) + ) + } + + companion object { + @JvmStatic protected var SCHEMA_NAME: String = "jdbc_integration_test1" + protected var SCHEMA_NAME2: String = "jdbc_integration_test2" + protected var TEST_SCHEMAS: Set = java.util.Set.of(SCHEMA_NAME, SCHEMA_NAME2) + + protected var TABLE_NAME: String = "id_and_name" + protected var TABLE_NAME_WITH_SPACES: String = "id and name" + protected var TABLE_NAME_WITHOUT_PK: String = "id_and_name_without_pk" + protected var TABLE_NAME_COMPOSITE_PK: String = "full_name_composite_pk" + protected var TABLE_NAME_WITHOUT_CURSOR_TYPE: String = "table_without_cursor_type" + protected var TABLE_NAME_WITH_NULLABLE_CURSOR_TYPE: String = "table_with_null_cursor_type" + + // this table is used in testing incremental sync with concurrent insertions + protected var TABLE_NAME_AND_TIMESTAMP: String = "name_and_timestamp" + + protected var COL_ID: String = "id" + protected var COL_NAME: String = "name" + protected var COL_UPDATED_AT: String = "updated_at" + protected var COL_FIRST_NAME: String = "first_name" + protected var COL_LAST_NAME: String = "last_name" + protected var COL_LAST_NAME_WITH_SPACE: String = "last name" + protected var COL_CURSOR: String = "cursor_field" + protected var COL_TIMESTAMP: String = "timestamp" + protected var COL_TIMESTAMP_TYPE: String = "TIMESTAMP" + protected var ID_VALUE_1: Number = 1 + protected var ID_VALUE_2: Number = 2 + protected var ID_VALUE_3: Number = 3 + protected var ID_VALUE_4: Number = 4 + protected var ID_VALUE_5: Number = 5 + + protected var DROP_SCHEMA_QUERY: String = "DROP SCHEMA IF EXISTS %s CASCADE" + protected var COLUMN_CLAUSE_WITH_PK: String = + "id INTEGER, name VARCHAR(200) NOT NULL, updated_at DATE NOT NULL" + protected var COLUMN_CLAUSE_WITHOUT_PK: String = + "id INTEGER, name VARCHAR(200) NOT NULL, updated_at DATE NOT NULL" + protected var COLUMN_CLAUSE_WITH_COMPOSITE_PK: String = + "first_name VARCHAR(200) NOT NULL, last_name VARCHAR(200) NOT NULL, updated_at DATE NOT NULL" + + @JvmField + var CREATE_TABLE_WITHOUT_CURSOR_TYPE_QUERY: String = "CREATE TABLE %s (%s bit NOT NULL);" + @JvmField var INSERT_TABLE_WITHOUT_CURSOR_TYPE_QUERY: String = "INSERT INTO %s VALUES(0);" + protected var 
CREATE_TABLE_WITH_NULLABLE_CURSOR_TYPE_QUERY: String =
+            "CREATE TABLE %s (%s VARCHAR(20));"
+        protected var INSERT_TABLE_WITH_NULLABLE_CURSOR_TYPE_QUERY: String =
+            "INSERT INTO %s VALUES('Hello world :)');"
+        protected var INSERT_TABLE_NAME_AND_TIMESTAMP_QUERY: String =
+            "INSERT INTO %s (name, timestamp) VALUES ('%s', '%s')"
+
+        protected fun setEmittedAtToNull(messages: Iterable<AirbyteMessage>) {
+            for (actualMessage in messages) {
+                if (actualMessage.record != null) {
+                    actualMessage.record.emittedAt = null
+                }
+            }
+        }
+    }
+}
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/source/jdbc/test/JdbcStressTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/source/jdbc/test/JdbcStressTest.kt
new file mode 100644
index 000000000000..c6b3e45735f2
--- /dev/null
+++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/source/jdbc/test/JdbcStressTest.kt
@@ -0,0 +1,274 @@
+/*
+ * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+ */
+package io.airbyte.cdk.integrations.source.jdbc.test
+
+import com.fasterxml.jackson.databind.JsonNode
+import com.google.common.collect.ImmutableMap
+import com.google.common.collect.Lists
+import edu.umd.cs.findbugs.annotations.SuppressFBWarnings
+import io.airbyte.cdk.db.factory.DataSourceFactory.create
+import io.airbyte.cdk.db.jdbc.DefaultJdbcDatabase
+import io.airbyte.cdk.db.jdbc.JdbcDatabase
+import io.airbyte.cdk.db.jdbc.JdbcUtils
+import io.airbyte.cdk.integrations.source.jdbc.AbstractJdbcSource
+import io.airbyte.commons.functional.CheckedConsumer
+import io.airbyte.commons.json.Jsons
+import io.airbyte.commons.stream.MoreStreams
+import io.airbyte.commons.string.Strings
+import io.airbyte.protocol.models.Field
+import io.airbyte.protocol.models.JsonSchemaType
+import io.airbyte.protocol.models.v0.*
+import java.math.BigDecimal
+import java.sql.Connection
+import java.util.*
+import org.junit.jupiter.api.Assertions
+import org.junit.jupiter.api.Test
+import org.slf4j.Logger
+import org.slf4j.LoggerFactory
+
+/**
+ * Runs a "large" amount of data through a JdbcSource to ensure that it streams / chunks records.
+ */
+// todo (cgardens) - this needs more love and thought. we should be able to test this without
+// having to rewrite so much data. it is enough for now to sanity check that our JdbcSources can
+// actually handle more data than fits in memory.
+@SuppressFBWarnings(
+    value = ["MS_SHOULD_BE_FINAL"],
+    justification =
+        "The static variables are updated in sub classes for convenience, and cannot be final."
+)
+abstract class JdbcStressTest {
+    private var bitSet: BitSet? = null
+    private lateinit var config: JsonNode
+    private var source: AbstractJdbcSource<*>? = null
+
+    /**
+     * These tests write records without specifying a namespace (schema name). They will be written
+     * into whatever the default schema is for the database. When they are discovered they will be
+     * namespaced by the schema name (e.g. `<schema-name>.<table-name>`). Thus the source needs to
+     * tell the tests what that default schema name is. If the database does not support schemas,
+     * then the database name should be used instead.
+     *
+     * @return name that will be used to namespace the record.
+     */
+    abstract val defaultSchemaName: Optional<String>
+
+    /**
+     * A valid configuration to connect to a test database.
+     *
+     * @return config
+     */
+    abstract fun getConfig(): JsonNode
+
+    /**
+     * Fully qualified class name of the JDBC driver for the database.
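+     * For a Postgres source, for instance, this would typically be "org.postgresql.Driver"
+     * (illustrative; return whatever driver class the database under test uses).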
+     *
+     * @return driver
+     */
+    abstract val driverClass: String
+
+    /**
+     * An instance of the source that should be tested.
+     *
+     * @return source
+     */
+    abstract fun getSource(): AbstractJdbcSource<*>?
+
+    protected fun createTableQuery(tableName: String?, columnClause: String?): String {
+        return String.format("CREATE TABLE %s(%s)", tableName, columnClause)
+    }
+
+    @Throws(Exception::class)
+    open fun setup() {
+        LOGGER.info("running for driver: " + driverClass)
+        bitSet = BitSet(TOTAL_RECORDS.toInt())
+
+        source = getSource()
+        streamName =
+            defaultSchemaName.map { `val`: String -> `val` + "." + TABLE_NAME }.orElse(TABLE_NAME)
+        config = getConfig()
+
+        val jdbcConfig = source!!.toDatabaseConfig(config)
+        val database: JdbcDatabase =
+            DefaultJdbcDatabase(
+                create(
+                    jdbcConfig[JdbcUtils.USERNAME_KEY].asText(),
+                    if (jdbcConfig.has(JdbcUtils.PASSWORD_KEY))
+                        jdbcConfig[JdbcUtils.PASSWORD_KEY].asText()
+                    else null,
+                    driverClass,
+                    jdbcConfig[JdbcUtils.JDBC_URL_KEY].asText()
+                )
+            )
+
+        database.execute(
+            CheckedConsumer { connection: Connection ->
+                connection
+                    .createStatement()
+                    .execute(
+                        createTableQuery(
+                            "id_and_name",
+                            String.format("id %s, name VARCHAR(200)", COL_ID_TYPE)
+                        )
+                    )
+            }
+        )
+        val batchCount = TOTAL_RECORDS / BATCH_SIZE
+        LOGGER.info("writing {} batches of {}", batchCount, BATCH_SIZE)
+        for (i in 0 until batchCount) {
+            if (i % 1000 == 0L) LOGGER.info("writing batch: $i")
+            val insert: MutableList<String> = ArrayList()
+            for (j in 0 until BATCH_SIZE) {
+                val recordNumber = (i * BATCH_SIZE + j).toInt()
+                insert.add(String.format(INSERT_STATEMENT, recordNumber, recordNumber))
+            }
+
+            val sql = prepareInsertStatement(insert)
+            database.execute(
+                CheckedConsumer { connection: Connection ->
+                    connection.createStatement().execute(sql)
+                }
+            )
+        }
+    }
+
+    // todo (cgardens) - restructure these tests so that testFullRefresh() and testIncremental()
+    // can be separate tests. currently constrained by only wanting to set up the fixture in the
+    // database once, but it is not trivial to move them to @BeforeAll because it is static and we
+    // are doing inheritance. Not impossible, just needs to be done thoughtfully and for all
+    // JdbcSources.
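+    // Rough sizing: TOTAL_RECORDS (10,000,000) at ~106 bytes per record message is on the order
+    // of 1 GB, which should exceed the heap available to a typical test JVM, so a source that
+    // buffers rather than streams is expected to fail here.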
+    @Test
+    @Throws(Exception::class)
+    fun stressTest() {
+        testFullRefresh()
+        testIncremental()
+    }
+
+    @Throws(Exception::class)
+    private fun testFullRefresh() {
+        runTest(configuredCatalogFullRefresh, "full_refresh")
+    }
+
+    @Throws(Exception::class)
+    private fun testIncremental() {
+        runTest(configuredCatalogIncremental, "incremental")
+    }
+
+    @Throws(Exception::class)
+    private fun runTest(configuredCatalog: ConfiguredAirbyteCatalog, testName: String) {
+        LOGGER.info("running stress test for: $testName")
+        val read: Iterator<AirbyteMessage> =
+            source!!.read(config, configuredCatalog, Jsons.jsonNode(emptyMap()))
+        val actualCount =
+            MoreStreams.toStream(read)
+                .filter { m: AirbyteMessage -> m.type == AirbyteMessage.Type.RECORD }
+                .peek { m: AirbyteMessage ->
+                    if (m.record.data[COL_ID].asLong() % 100000 == 0L) {
+                        LOGGER.info("reading batch: " + m.record.data[COL_ID].asLong() / 1000)
+                    }
+                }
+                .peek { m: AirbyteMessage -> assertExpectedMessage(m) }
+                .count()
+        val expectedRoundedRecordsCount = TOTAL_RECORDS - TOTAL_RECORDS % 1000
+        LOGGER.info("expected records count: " + TOTAL_RECORDS)
+        LOGGER.info("actual records count: $actualCount")
+        Assertions.assertEquals(expectedRoundedRecordsCount, actualCount, "testing: $testName")
+        Assertions.assertEquals(
+            expectedRoundedRecordsCount,
+            bitSet!!.cardinality().toLong(),
+            "testing: $testName"
+        )
+    }
+
+    // Each record message is roughly 106 bytes.
+    private fun assertExpectedMessage(actualMessage: AirbyteMessage) {
+        val recordNumber = actualMessage.record.data[COL_ID].asLong()
+        bitSet!!.set(recordNumber.toInt())
+        actualMessage.record.emittedAt = null
+
+        val expectedRecordNumber: Number =
+            if (driverClass.lowercase(Locale.getDefault()).contains("oracle"))
+                BigDecimal(recordNumber)
+            else recordNumber
+
+        val expectedMessage =
+            AirbyteMessage()
+                .withType(AirbyteMessage.Type.RECORD)
+                .withRecord(
+                    AirbyteRecordMessage()
+                        .withStream(streamName)
+                        .withData(
+                            Jsons.jsonNode(
+                                ImmutableMap.of(
+                                    COL_ID,
+                                    expectedRecordNumber,
+                                    COL_NAME,
+                                    "picard-$recordNumber"
+                                )
+                            )
+                        )
+                )
+        Assertions.assertEquals(expectedMessage, actualMessage)
+    }
+
+    private fun prepareInsertStatement(inserts: List<String>): String {
+        if (driverClass.lowercase(Locale.getDefault()).contains("oracle")) {
+            return String.format("INSERT ALL %s SELECT * FROM dual", Strings.join(inserts, " "))
+        }
+        return String.format(
+            "INSERT INTO id_and_name (id, name) VALUES %s",
+            Strings.join(inserts, ", ")
+        )
+    }
+
+    companion object {
+        private val LOGGER: Logger = LoggerFactory.getLogger(JdbcStressTest::class.java)
+
+        // this will get rounded down to the nearest 1000th.
+        private const val TOTAL_RECORDS = 10000000L
+        private const val BATCH_SIZE = 1000
+        var TABLE_NAME: String = "id_and_name"
+        var COL_ID: String = "id"
+        var COL_NAME: String = "name"
+        var COL_ID_TYPE: String = "BIGINT"
+        var INSERT_STATEMENT: String = "(%s,'picard-%s')"
+
+        private var streamName: String?
= null + + private val configuredCatalogFullRefresh: ConfiguredAirbyteCatalog + get() = CatalogHelpers.toDefaultConfiguredCatalog(catalog) + + private val configuredCatalogIncremental: ConfiguredAirbyteCatalog + get() = + ConfiguredAirbyteCatalog() + .withStreams( + listOf( + ConfiguredAirbyteStream() + .withStream(catalog.streams[0]) + .withCursorField(listOf(COL_ID)) + .withSyncMode(SyncMode.INCREMENTAL) + .withDestinationSyncMode(DestinationSyncMode.APPEND) + ) + ) + + private val catalog: AirbyteCatalog + get() = + AirbyteCatalog() + .withStreams( + Lists.newArrayList( + CatalogHelpers.createAirbyteStream( + streamName, + Field.of(COL_ID, JsonSchemaType.NUMBER), + Field.of(COL_NAME, JsonSchemaType.STRING) + ) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL) + ) + ) + ) + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/AbstractSourceConnectorTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/AbstractSourceConnectorTest.kt new file mode 100644 index 000000000000..136596b3dd43 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/AbstractSourceConnectorTest.kt @@ -0,0 +1,377 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.standardtest.source + +import com.fasterxml.jackson.databind.JsonNode +import io.airbyte.api.client.AirbyteApiClient +import io.airbyte.api.client.generated.SourceApi +import io.airbyte.api.client.model.generated.DiscoverCatalogResult +import io.airbyte.api.client.model.generated.SourceDiscoverSchemaWriteRequestBody +import io.airbyte.commons.features.EnvVariableFeatureFlags +import io.airbyte.commons.features.FeatureFlags +import io.airbyte.commons.json.Jsons +import io.airbyte.configoss.* +import io.airbyte.protocol.models.v0.AirbyteCatalog +import io.airbyte.protocol.models.v0.AirbyteMessage +import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog +import io.airbyte.protocol.models.v0.ConnectorSpecification +import io.airbyte.workers.exception.TestHarnessException +import io.airbyte.workers.general.DefaultCheckConnectionTestHarness +import io.airbyte.workers.general.DefaultDiscoverCatalogTestHarness +import io.airbyte.workers.general.DefaultGetSpecTestHarness +import io.airbyte.workers.helper.CatalogClientConverters +import io.airbyte.workers.helper.ConnectorConfigUpdater +import io.airbyte.workers.helper.EntrypointEnvChecker +import io.airbyte.workers.internal.AirbyteSource +import io.airbyte.workers.internal.DefaultAirbyteSource +import io.airbyte.workers.process.AirbyteIntegrationLauncher +import io.airbyte.workers.process.DockerProcessFactory +import io.airbyte.workers.process.ProcessFactory +import java.nio.file.Files +import java.nio.file.Path +import java.util.* +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.BeforeEach +import org.mockito.ArgumentCaptor +import org.mockito.ArgumentMatchers +import org.mockito.Mockito +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +/** + * This abstract class contains helpful functionality and boilerplate for testing a source + * connector. + */ +abstract class AbstractSourceConnectorTest { + private var environment: TestDestinationEnv? = null + private var jobRoot: Path? = null + protected var localRoot: Path? 
= null + private var processFactory: ProcessFactory? = null + + /** Name of the docker image that the tests will run against. */ + protected abstract val imageName: String + + @get:Throws(Exception::class) + protected abstract val config: JsonNode? + /** + * Configuration specific to the integration. Will be passed to integration where + * appropriate in each test. Should be valid. + * + * @return integration-specific configuration + */ + get + + /** + * Function that performs any setup of external resources required for the test. e.g. + * instantiate a postgres database. This function will be called before EACH test. + * + * @param environment + * - information about the test environment. + * @throws Exception + * - can throw any exception, test framework will handle. + */ + @Throws(Exception::class) + protected abstract fun setupEnvironment(environment: TestDestinationEnv?) + + /** + * Function that performs any clean up of external resources required for the test. e.g. delete + * a postgres database. This function will be called after EACH test. It MUST remove all data in + * the destination so that there is no contamination across tests. + * + * @param testEnv + * - information about the test environment. + * @throws Exception + * - can throw any exception, test framework will handle. + */ + @Throws(Exception::class) protected abstract fun tearDown(testEnv: TestDestinationEnv?) + + private lateinit var mAirbyteApiClient: AirbyteApiClient + + private lateinit var mSourceApi: SourceApi + + private var mConnectorConfigUpdater: ConnectorConfigUpdater? = null + + protected val lastPersistedCatalog: AirbyteCatalog + get() = + convertProtocolObject( + CatalogClientConverters.toAirbyteProtocol(discoverWriteRequest.value.catalog), + AirbyteCatalog::class.java + ) + + private val discoverWriteRequest: ArgumentCaptor = + ArgumentCaptor.forClass(SourceDiscoverSchemaWriteRequestBody::class.java) + + @BeforeEach + @Throws(Exception::class) + fun setUpInternal() { + val testDir = Path.of("/tmp/airbyte_tests/") + Files.createDirectories(testDir) + val workspaceRoot = Files.createTempDirectory(testDir, "test") + jobRoot = Files.createDirectories(Path.of(workspaceRoot.toString(), "job")) + localRoot = Files.createTempDirectory(testDir, "output") + environment = TestDestinationEnv(localRoot) + setupEnvironment(environment) + mAirbyteApiClient = Mockito.mock(AirbyteApiClient::class.java) + mSourceApi = Mockito.mock(SourceApi::class.java) + Mockito.`when`(mAirbyteApiClient.getSourceApi()).thenReturn(mSourceApi) + Mockito.`when`(mSourceApi.writeDiscoverCatalogResult(ArgumentMatchers.any())) + .thenReturn(DiscoverCatalogResult().catalogId(CATALOG_ID)) + mConnectorConfigUpdater = Mockito.mock(ConnectorConfigUpdater::class.java) + val envMap = HashMap(TestEnvConfigs().jobDefaultEnvMap) + envMap[EnvVariableFeatureFlags.DEPLOYMENT_MODE] = featureFlags().deploymentMode() + processFactory = + DockerProcessFactory( + workspaceRoot, + workspaceRoot.toString(), + localRoot.toString(), + "host", + envMap + ) + + postSetup() + } + + /** + * Override this method if you want to do any per-test setup that depends on being able to e.g. + * [.runRead]. 
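+     * For example, an implementation might run a single warm-up read here and cache the result
+     * for use in several tests (illustrative).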
+ */ + @Throws(Exception::class) protected fun postSetup() {} + + @AfterEach + @Throws(Exception::class) + fun tearDownInternal() { + tearDown(environment) + } + + protected fun featureFlags(): FeatureFlags { + return EnvVariableFeatureFlags() + } + + @Throws(TestHarnessException::class) + protected fun runSpec(): ConnectorSpecification { + val spec = + DefaultGetSpecTestHarness( + AirbyteIntegrationLauncher( + JOB_ID, + JOB_ATTEMPT, + imageName, + processFactory, + null, + null, + false, + featureFlags() + ) + ) + .run(JobGetSpecConfig().withDockerImage(imageName), jobRoot) + .spec + return convertProtocolObject(spec, ConnectorSpecification::class.java) + } + + @Throws(Exception::class) + protected fun runCheck(): StandardCheckConnectionOutput { + return DefaultCheckConnectionTestHarness( + AirbyteIntegrationLauncher( + JOB_ID, + JOB_ATTEMPT, + imageName, + processFactory, + null, + null, + false, + featureFlags() + ), + mConnectorConfigUpdater + ) + .run(StandardCheckConnectionInput().withConnectionConfiguration(config), jobRoot) + .checkConnection + } + + @Throws(Exception::class) + protected fun runCheckAndGetStatusAsString(config: JsonNode?): String { + return DefaultCheckConnectionTestHarness( + AirbyteIntegrationLauncher( + JOB_ID, + JOB_ATTEMPT, + imageName, + processFactory, + null, + null, + false, + featureFlags() + ), + mConnectorConfigUpdater + ) + .run(StandardCheckConnectionInput().withConnectionConfiguration(config), jobRoot) + .checkConnection + .status + .toString() + } + + @Throws(Exception::class) + protected fun runDiscover(): UUID { + val toReturn = + DefaultDiscoverCatalogTestHarness( + mAirbyteApiClient, + AirbyteIntegrationLauncher( + JOB_ID, + JOB_ATTEMPT, + imageName, + processFactory, + null, + null, + false, + featureFlags() + ), + mConnectorConfigUpdater + ) + .run( + StandardDiscoverCatalogInput() + .withSourceId(SOURCE_ID.toString()) + .withConnectionConfiguration(config), + jobRoot + ) + .discoverCatalogId + Mockito.verify(mSourceApi).writeDiscoverCatalogResult(discoverWriteRequest.capture()) + return toReturn + } + + @Throws(Exception::class) + protected fun checkEntrypointEnvVariable() { + val entrypoint = + EntrypointEnvChecker.getEntrypointEnvVariable( + processFactory, + JOB_ID, + JOB_ATTEMPT, + jobRoot, + imageName + ) + + Assertions.assertNotNull(entrypoint) + Assertions.assertFalse(entrypoint.isBlank()) + } + + @Throws(Exception::class) + protected fun runRead(configuredCatalog: ConfiguredAirbyteCatalog?): List { + return runRead(configuredCatalog, null) + } + + // todo (cgardens) - assume no state since we are all full refresh right now. + @Throws(Exception::class) + protected fun runRead( + catalog: ConfiguredAirbyteCatalog?, + state: JsonNode? 
+ ): List { + val sourceConfig = + WorkerSourceConfig() + .withSourceConnectionConfiguration(config) + .withState(if (state == null) null else State().withState(state)) + .withCatalog( + convertProtocolObject( + catalog, + io.airbyte.protocol.models.ConfiguredAirbyteCatalog::class.java + ) + ) + + val source: AirbyteSource = + DefaultAirbyteSource( + AirbyteIntegrationLauncher( + JOB_ID, + JOB_ATTEMPT, + imageName, + processFactory, + null, + null, + false, + featureFlags() + ), + featureFlags() + ) + val messages: MutableList = ArrayList() + source.start(sourceConfig, jobRoot) + while (!source.isFinished) { + source.attemptRead().ifPresent { m: io.airbyte.protocol.models.AirbyteMessage -> + messages.add(convertProtocolObject(m, AirbyteMessage::class.java)) + } + } + source.close() + + return messages + } + + @Throws(Exception::class) + protected fun runReadVerifyNumberOfReceivedMsgs( + catalog: ConfiguredAirbyteCatalog, + state: JsonNode?, + mapOfExpectedRecordsCount: MutableMap + ): Map { + val sourceConfig = + WorkerSourceConfig() + .withSourceConnectionConfiguration(config) + .withState(if (state == null) null else State().withState(state)) + .withCatalog( + convertProtocolObject( + catalog, + io.airbyte.protocol.models.ConfiguredAirbyteCatalog::class.java + ) + ) + + val source = prepareAirbyteSource() + source.start(sourceConfig, jobRoot) + + while (!source.isFinished) { + val airbyteMessageOptional = + source.attemptRead().map { m: io.airbyte.protocol.models.AirbyteMessage -> + convertProtocolObject(m, AirbyteMessage::class.java) + } + if ( + airbyteMessageOptional.isPresent && + airbyteMessageOptional.get().type == AirbyteMessage.Type.RECORD + ) { + val airbyteMessage = airbyteMessageOptional.get() + val record = airbyteMessage.record + + val streamName = record.stream + mapOfExpectedRecordsCount[streamName] = mapOfExpectedRecordsCount[streamName]!! - 1 + } + } + source.close() + return mapOfExpectedRecordsCount + } + + private fun prepareAirbyteSource(): AirbyteSource { + val integrationLauncher = + AirbyteIntegrationLauncher( + JOB_ID, + JOB_ATTEMPT, + imageName, + processFactory, + null, + null, + false, + featureFlags() + ) + return DefaultAirbyteSource(integrationLauncher, featureFlags()) + } + + companion object { + protected val LOGGER: Logger = + LoggerFactory.getLogger(AbstractSourceConnectorTest::class.java) + private const val JOB_ID = 0L.toString() + private const val JOB_ATTEMPT = 0 + + private val CATALOG_ID: UUID = UUID.randomUUID() + + private val SOURCE_ID: UUID = UUID.randomUUID() + + private const val CPU_REQUEST_FIELD_NAME = "cpuRequest" + private const val CPU_LIMIT_FIELD_NAME = "cpuLimit" + private const val MEMORY_REQUEST_FIELD_NAME = "memoryRequest" + private const val MEMORY_LIMIT_FIELD_NAME = "memoryLimit" + + private fun convertProtocolObject(v1: V1, klass: Class): V0 { + return Jsons.`object`(Jsons.jsonNode(v1), klass) + } + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/AbstractSourceDatabaseTypeTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/AbstractSourceDatabaseTypeTest.kt new file mode 100644 index 000000000000..573b8852d4fa --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/AbstractSourceDatabaseTypeTest.kt @@ -0,0 +1,426 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+ */
+package io.airbyte.cdk.integrations.standardtest.source
+
+import com.fasterxml.jackson.databind.JsonNode
+import com.google.common.collect.Lists
+import io.airbyte.cdk.db.Database
+import io.airbyte.commons.json.Jsons
+import io.airbyte.protocol.models.Field
+import io.airbyte.protocol.models.JsonSchemaType
+import io.airbyte.protocol.models.v0.*
+import java.io.IOException
+import java.sql.SQLException
+import java.util.function.Consumer
+import java.util.function.Function
+import java.util.stream.Collectors
+import org.apache.commons.lang3.StringUtils
+import org.jooq.DSLContext
+import org.junit.jupiter.api.Assertions
+import org.junit.jupiter.api.Test
+import org.slf4j.Logger
+import org.slf4j.LoggerFactory
+
+/**
+ * This abstract class contains common helpers and boilerplate for comprehensively testing that all
+ * data types in a source can be read and handled correctly by the connector and within Airbyte's
+ * type system.
+ */
+abstract class AbstractSourceDatabaseTypeTest : AbstractSourceConnectorTest() {
+    protected val testDataHolders: MutableList<TestDataHolder> = ArrayList()
+    protected var database: Database? = null
+
+    protected val idColumnName: String
+        /**
+         * The column name will be used for a PK column in the test tables. Override it if the
+         * default name is not valid for your source.
+         *
+         * @return Id column name
+         */
+        get() = "id"
+
+    protected val testColumnName: String
+        /**
+         * The column name will be used for a test column in the test tables. Override it if the
+         * default name is not valid for your source.
+         *
+         * @return Test column name
+         */
+        get() = "test_column"
+
+    /**
+     * Setup the test database. All tables and data described in the registered tests will be put
+     * there.
+     *
+     * @return configured test database
+     * @throws Exception
+     * - might throw any exception during initialization.
+     */
+    @Throws(Exception::class) protected abstract fun setupDatabase(): Database?
+
+    /** Put all required tests here using method [addDataTypeTestData] */
+    protected abstract fun initTests()
+
+    @Throws(Exception::class)
+    override fun setupEnvironment(environment: TestDestinationEnv?) {
+        database = setupDatabase()
+        initTests()
+        createTables()
+        populateTables()
+    }
+
+    protected abstract val nameSpace: String
+        /**
+         * Provide a source namespace. It's the allocated place for table creation. It is also
+         * known as a "Database Schema" or "Dataset".
+         *
+         * @return source name space
+         */
+        get
+
+    /**
+     * Test the 'discover' command. TODO (liren): Some existing databases may fail testDataTypes(),
+     * so it is turned off by default. It should be enabled for all databases eventually.
+     */
+    protected fun testCatalog(): Boolean {
+        return false
+    }
+
+    /**
+     * The test checks that the types from the catalog match the ones discovered from the source.
+     * This test is disabled by default. To enable it you need to override the testCatalog()
+     * function.
+     */
+    @Test
+    @Throws(Exception::class)
+    fun testDataTypes() {
+        if (testCatalog()) {
+            runDiscover()
+            val streams =
+                lastPersistedCatalog.streams
+                    .stream()
+                    .collect(
+                        Collectors.toMap(
+                            Function { obj: AirbyteStream -> obj.name },
+                            Function { s: AirbyteStream? -> s }
+                        )
+                    )
+
+            // testDataHolders should be initialized using the `addDataTypeTestData` function
+            testDataHolders.forEach(
+                Consumer { testDataHolder: TestDataHolder ->
+                    val airbyteStream = streams[testDataHolder.nameWithTestPrefix]
+                    val jsonSchemaTypeMap =
+                        Jsons.deserialize(
+                            airbyteStream!!.jsonSchema["properties"][testColumnName].toString(),
+                            MutableMap::class.java
+                        ) as Map<String, Any>
+                    Assertions.assertEquals(
+                        testDataHolder.airbyteType.jsonSchemaTypeMap,
+                        jsonSchemaTypeMap,
+                        "Expected column type for " + testDataHolder.nameWithTestPrefix
+                    )
+                }
+            )
+        }
+    }
+
+    /**
+     * The test checks that the connector can fetch prepared data without failure. It uses a
+     * prepared catalog, reads the source using that catalog, and then makes sure that the expected
+     * values are the ones inserted in the source.
+     */
+    @Test
+    @Throws(Exception::class)
+    fun testDataContent() {
+        // Classes used to make error reporting easier
+        class MissedRecords(
+            // Stream that is missing any value
+            var streamName: String?,
+            // The values that have not been gathered from the source
+            var missedValues: MutableList<String?>?
+        )
+
+        class UnexpectedRecord(val streamName: String, val unexpectedValue: String?)
+
+        val catalog = configuredCatalog
+        val allMessages = runRead(catalog)
+
+        val recordMessages =
+            allMessages!!
+                .stream()
+                .filter { m: AirbyteMessage? -> m!!.type == AirbyteMessage.Type.RECORD }
+                .toList()
+        val expectedValues: MutableMap<String, MutableList<String?>?> = HashMap()
+        val missedValuesByStream: MutableMap<String, MutableList<MissedRecords>> = HashMap()
+        val unexpectedValuesByStream: MutableMap<String, MutableList<UnexpectedRecord>> = HashMap()
+        val testByName: MutableMap<String, TestDataHolder> = HashMap()
+
+        // If there is no expected value in the test set, we don't include it in the list to be
+        // asserted (even if the table contains records).
+        testDataHolders.forEach(
+            Consumer { testDataHolder: TestDataHolder ->
+                if (!testDataHolder.expectedValues.isEmpty()) {
+                    expectedValues[testDataHolder.nameWithTestPrefix] =
+                        testDataHolder.expectedValues
+                    testByName[testDataHolder.nameWithTestPrefix] = testDataHolder
+                } else {
+                    LOGGER.warn("Missing expected values for type: " + testDataHolder.sourceType)
+                }
+            }
+        )
+
+        for (message in recordMessages) {
+            val streamName = message!!.record.stream
+            val expectedValuesForStream = expectedValues[streamName]
+            if (expectedValuesForStream != null) {
+                val value = getValueFromJsonNode(message.record.data[testColumnName])
+                if (!expectedValuesForStream.contains(value)) {
+                    unexpectedValuesByStream.putIfAbsent(streamName, ArrayList())
+                    unexpectedValuesByStream[streamName]!!.add(UnexpectedRecord(streamName, value))
+                } else {
+                    expectedValuesForStream.remove(value)
+                }
+            }
+        }
+
+        // Gather all the missing values, so we don't stop the test at the first missed one
+        expectedValues.forEach { (streamName: String, values: MutableList<String?>?) ->
+            if (!values!!.isEmpty()) {
+                missedValuesByStream.putIfAbsent(streamName, ArrayList())
+                missedValuesByStream[streamName]!!.add(MissedRecords(streamName, values))
+            }
+        }
+
+        val errorsByStream: MutableMap<String, MutableList<String>> = HashMap()
+        for (streamName in unexpectedValuesByStream.keys) {
+            errorsByStream.putIfAbsent(streamName, ArrayList())
+            val test = testByName.getValue(streamName)
+            val unexpectedValues: List<UnexpectedRecord> = unexpectedValuesByStream[streamName]!!
+            for (unexpectedValue in unexpectedValues) {
+                errorsByStream[streamName]!!.add(
+                    "The stream '%s' checking type '%s' initialized at %s got unexpected values: %s".formatted(
+                        streamName,
+                        test.sourceType,
+                        test.declarationLocation,
+                        unexpectedValue
+                    )
+                )
+            }
+        }
+
+        for (streamName in missedValuesByStream.keys) {
+            errorsByStream.putIfAbsent(streamName, ArrayList())
+            val test = testByName.getValue(streamName)
+            val missedValues: List<MissedRecords> = missedValuesByStream[streamName]!!
+            for (missedValue in missedValues) {
+                errorsByStream[streamName]!!.add(
+                    "The stream '%s' checking type '%s' initialized at %s is missing values: %s".formatted(
+                        streamName,
+                        test.sourceType,
+                        test.declarationLocation,
+                        missedValue
+                    )
+                )
+            }
+        }
+
+        val errorStrings: MutableList<String> = ArrayList()
+        for (errors in errorsByStream.values) {
+            errorStrings.add(StringUtils.join(errors, "\n"))
+        }
+
+        Assertions.assertTrue(errorsByStream.isEmpty(), StringUtils.join(errorStrings, "\n"))
+    }
+
+    @Throws(IOException::class)
+    protected fun getValueFromJsonNode(jsonNode: JsonNode?): String? {
+        if (jsonNode != null) {
+            if (jsonNode.isArray) {
+                return jsonNode.toString()
+            }
+
+            var value =
+                (if (jsonNode.isBinary) jsonNode.binaryValue().contentToString()
+                else jsonNode.asText())
+            value = (if (value != null && value == "null") null else value)
+            return value
+        }
+        return null
+    }
+
+    /**
+     * Creates all tables and inserts the data described in the registered data type tests.
+     *
+     * @throws Exception might raise an exception if the configuration goes wrong or the table
+     * creation/insert scripts fail.
+     */
+    @Throws(Exception::class)
+    protected fun createTables() {
+        for (test in testDataHolders) {
+            database!!.query { ctx: DSLContext? ->
+                ctx!!.fetch(test.createSqlQuery)
+                LOGGER.info("Table {} is created.", test.nameWithTestPrefix)
+                null
+            }
+        }
+    }
+
+    @Throws(Exception::class)
+    protected fun populateTables() {
+        for (test in testDataHolders) {
+            database!!.query { ctx: DSLContext? ->
+                test.insertSqlQueries.forEach(Consumer { sql: String? -> ctx!!.fetch(sql) })
+                LOGGER.info(
+                    "Inserted {} rows in table {}",
+                    test.insertSqlQueries.size,
+                    test.nameWithTestPrefix
+                )
+                null
+            }
+        }
+    }
+
+    protected val configuredCatalog: ConfiguredAirbyteCatalog
+        /**
+         * Configures streams for all registered data type tests.
+         *
+         * @return configured catalog
+         */
+        get() =
+            ConfiguredAirbyteCatalog()
+                .withStreams(
+                    testDataHolders
+                        .stream()
+                        .map { test: TestDataHolder ->
+                            ConfiguredAirbyteStream()
+                                .withSyncMode(SyncMode.INCREMENTAL)
+                                .withCursorField(Lists.newArrayList(idColumnName))
+                                .withDestinationSyncMode(DestinationSyncMode.APPEND)
+                                .withStream(
+                                    CatalogHelpers.createAirbyteStream(
+                                        String.format("%s", test.nameWithTestPrefix),
+                                        String.format("%s", nameSpace),
+                                        Field.of(idColumnName, JsonSchemaType.INTEGER),
+                                        Field.of(testColumnName, test.airbyteType)
+                                    )
+                                        .withSourceDefinedCursor(true)
+                                        .withSourceDefinedPrimaryKey(
+                                            java.util.List.of(java.util.List.of(idColumnName))
+                                        )
+                                        .withSupportedSyncModes(
+                                            Lists.newArrayList(
+                                                SyncMode.FULL_REFRESH,
+                                                SyncMode.INCREMENTAL
+                                            )
+                                        )
+                                )
+                        }
+                        .collect(Collectors.toList())
+                )
+
+    /**
+     * Register your test in the run scope. For each registered test, a table with one column of
+     * the specified type will be created. Note! If you register more than one test with the same
+     * type name, they will be run as independent tests with their own streams.
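+     *
+     * A typical registration might look like the following (illustrative; the exact builder
+     * methods come from [TestDataHolder]): `addDataTypeTestData(
+     * TestDataHolder.builder().sourceType("varchar").airbyteType(JsonSchemaType.STRING)
+     * .addInsertValues("'qwe'").addExpectedValues("qwe").build())`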
+ * + * @param test comprehensive data type test + */ + fun addDataTypeTestData(test: TestDataHolder) { + testDataHolders.add(test) + test.setTestNumber( + testDataHolders + .stream() + .filter { t: TestDataHolder -> t.sourceType == test.sourceType } + .count() + ) + test.nameSpace = nameSpace + test.setIdColumnName(idColumnName) + test.setTestColumnName(testColumnName) + test.setDeclarationLocation(Thread.currentThread().stackTrace) + } + + private fun formatCollection(collection: Collection?): String { + return collection!!.stream().map { s: String? -> "`$s`" }.collect(Collectors.joining(", ")) + } + + val markdownTestTable: String + /** + * Builds a table with all registered test cases with values using Markdown syntax (can be + * used in the github). + * + * @return formatted list of test cases + */ + get() { + val table = + StringBuilder() + .append( + "|**Data Type**|**Insert values**|**Expected values**|**Comment**|**Common test result**|\n" + ) + .append("|----|----|----|----|----|\n") + + testDataHolders.forEach( + Consumer { test: TestDataHolder -> + table.append( + String.format( + "| %s | %s | %s | %s | %s |\n", + test.sourceType, + formatCollection(test.values), + formatCollection(test.expectedValues), + "", + "Ok" + ) + ) + } + ) + return table.toString() + } + + protected fun printMarkdownTestTable() { + LOGGER.info(markdownTestTable) + } + + @Throws(SQLException::class) + protected fun createDummyTableWithData(database: Database): ConfiguredAirbyteStream { + database.query { ctx: DSLContext? -> + ctx!!.fetch( + "CREATE TABLE " + + nameSpace + + ".random_dummy_table(id INTEGER PRIMARY KEY, test_column VARCHAR(63));" + ) + ctx.fetch("INSERT INTO " + nameSpace + ".random_dummy_table VALUES (2, 'Random Data');") + null + } + + return ConfiguredAirbyteStream() + .withSyncMode(SyncMode.INCREMENTAL) + .withCursorField(Lists.newArrayList("id")) + .withDestinationSyncMode(DestinationSyncMode.APPEND) + .withStream( + CatalogHelpers.createAirbyteStream( + "random_dummy_table", + nameSpace, + Field.of("id", JsonSchemaType.INTEGER), + Field.of("test_column", JsonSchemaType.STRING) + ) + .withSourceDefinedCursor(true) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL) + ) + .withSourceDefinedPrimaryKey(java.util.List.of(listOf("id"))) + ) + } + + protected fun extractStateMessages(messages: List): List { + return messages + .stream() + .filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.STATE } + .map { obj: AirbyteMessage -> obj.state } + .collect(Collectors.toList()) + } + + companion object { + private val LOGGER: Logger = + LoggerFactory.getLogger(AbstractSourceDatabaseTypeTest::class.java) + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/PythonSourceAcceptanceTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/PythonSourceAcceptanceTest.kt new file mode 100644 index 000000000000..0ebdc2addeab --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/PythonSourceAcceptanceTest.kt @@ -0,0 +1,161 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+ */ +package io.airbyte.cdk.integrations.standardtest.source + +import com.fasterxml.jackson.databind.JsonNode +import com.google.common.collect.Lists +import com.google.common.collect.Streams +import io.airbyte.commons.io.IOs +import io.airbyte.commons.io.LineGobbler +import io.airbyte.commons.json.Jsons +import io.airbyte.protocol.models.v0.AirbyteMessage +import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog +import io.airbyte.protocol.models.v0.ConnectorSpecification +import io.airbyte.workers.TestHarnessUtils +import java.io.IOException +import java.nio.file.Files +import java.nio.file.Path +import java.util.* +import java.util.concurrent.TimeUnit +import java.util.function.Consumer +import org.junit.jupiter.api.Assertions +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +/** + * Extends TestSource such that it can be called using resources pulled from the file system. Will + * also add the ability to execute arbitrary scripts in the next version. + */ +class PythonSourceAcceptanceTest : SourceAcceptanceTest() { + private var testRoot: Path? = null + + @get:Throws(IOException::class) + override val spec: ConnectorSpecification + get() = runExecutable(Command.GET_SPEC, ConnectorSpecification::class.java) + + @get:Throws(IOException::class) + override val config: JsonNode? + get() = runExecutable(Command.GET_CONFIG) + + @get:Throws(IOException::class) + override val configuredCatalog: ConfiguredAirbyteCatalog + get() = runExecutable(Command.GET_CONFIGURED_CATALOG, ConfiguredAirbyteCatalog::class.java) + + @get:Throws(IOException::class) + override val state: JsonNode? + get() = runExecutable(Command.GET_STATE) + + @Throws(IOException::class) + override fun assertFullRefreshMessages(allMessages: List?) { + val regexTests = + Streams.stream( + runExecutable(Command.GET_REGEX_TESTS).withArray("tests").elements() + ) + .map { obj: JsonNode -> obj.textValue() } + .toList() + val stringMessages = + allMessages!! + .stream() + .map { `object`: AirbyteMessage? -> Jsons.serialize(`object`) } + .toList() + LOGGER.info("Running " + regexTests.size + " regex tests...") + regexTests.forEach( + Consumer { regex: String -> + LOGGER.info("Looking for [$regex]") + Assertions.assertTrue( + stringMessages.stream().anyMatch { line: String -> + line.matches(regex.toRegex()) + }, + "Failed to find regex: $regex" + ) + } + ) + } + + override val imageName: String + get() = IMAGE_NAME + + @Throws(Exception::class) + override fun setupEnvironment(environment: TestDestinationEnv?) { + testRoot = + Files.createTempDirectory( + Files.createDirectories(Path.of("/tmp/standard_test")), + "pytest" + ) + runExecutableVoid(Command.SETUP) + } + + @Throws(Exception::class) + override fun tearDown(testEnv: TestDestinationEnv?) { + runExecutableVoid(Command.TEARDOWN) + } + + private enum class Command { + GET_SPEC, + GET_CONFIG, + GET_CONFIGURED_CATALOG, + GET_STATE, + GET_REGEX_TESTS, + SETUP, + TEARDOWN + } + + @Throws(IOException::class) + private fun runExecutable(cmd: Command, klass: Class): T { + return Jsons.`object`(runExecutable(cmd), klass) + } + + @Throws(IOException::class) + private fun runExecutable(cmd: Command): JsonNode { + return Jsons.deserialize(IOs.readFile(runExecutableInternal(cmd), OUTPUT_FILENAME)) + } + + @Throws(IOException::class) + private fun runExecutableVoid(cmd: Command) { + runExecutableInternal(cmd) + } + + @Throws(IOException::class) + private fun runExecutableInternal(cmd: Command): Path? 
{ + LOGGER.info("testRoot = $testRoot") + val dockerCmd: List = + Lists.newArrayList( + "docker", + "run", + "--rm", + "-i", + "-v", + String.format("%s:%s", testRoot, "/test_root"), + "-w", + testRoot.toString(), + "--network", + "host", + PYTHON_CONTAINER_NAME, + cmd.toString().lowercase(Locale.getDefault()), + "--out", + "/test_root" + ) + + val process = ProcessBuilder(dockerCmd).start() + LineGobbler.gobble(process.errorStream) { msg: String? -> LOGGER.error(msg) } + LineGobbler.gobble(process.inputStream) { msg: String? -> LOGGER.info(msg) } + + TestHarnessUtils.gentleClose(process, 1, TimeUnit.MINUTES) + + val exitCode = process.exitValue() + if (exitCode != 0) { + throw RuntimeException("python execution failed") + } + + return testRoot + } + + companion object { + private val LOGGER: Logger = LoggerFactory.getLogger(PythonSourceAcceptanceTest::class.java) + private const val OUTPUT_FILENAME = "output.json" + + lateinit var IMAGE_NAME: String + var PYTHON_CONTAINER_NAME: String? = null + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/SourceAcceptanceTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/SourceAcceptanceTest.kt new file mode 100644 index 000000000000..8045d5377a09 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/SourceAcceptanceTest.kt @@ -0,0 +1,457 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.standardtest.source + +import com.fasterxml.jackson.databind.JsonNode +import com.fasterxml.jackson.databind.node.ObjectNode +import com.google.common.collect.Iterables +import com.google.common.collect.Sets +import io.airbyte.commons.json.Jsons +import io.airbyte.configoss.StandardCheckConnectionOutput +import io.airbyte.protocol.models.v0.* +import java.util.* +import java.util.stream.Collectors +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +abstract class SourceAcceptanceTest : AbstractSourceConnectorTest() { + /** + * TODO hack: Various Singer integrations use cursor fields inclusively i.e: they output records + * whose cursor field >= the provided cursor value. This leads to the last record in a sync to + * always be the first record in the next sync. This is a fine assumption from a product POV + * since we offer at-least-once delivery. But for simplicity, the incremental test suite + * currently assumes that the second incremental read should output no records when provided the + * state from the first sync. This works for many integrations but not some Singer ones, so we + * hardcode the list of integrations to skip over when performing those tests. 
+     */
+    private val IMAGES_TO_SKIP_SECOND_INCREMENTAL_READ: Set =
+        Sets.newHashSet(
+            "airbyte/source-intercom-singer",
+            "airbyte/source-exchangeratesapi-singer",
+            "airbyte/source-hubspot",
+            "airbyte/source-iterable",
+            "airbyte/source-marketo-singer",
+            "airbyte/source-twilio-singer",
+            "airbyte/source-mixpanel-singer",
+            "airbyte/source-braintree-singer",
+            "airbyte/source-stripe-singer",
+            "airbyte/source-exchange-rates",
+            "airbyte/source-stripe",
+            "airbyte/source-github-singer",
+            "airbyte/source-gitlab-singer",
+            "airbyte/source-google-workspace-admin-reports",
+            "airbyte/source-zendesk-talk",
+            "airbyte/source-zendesk-support-singer",
+            "airbyte/source-quickbooks-singer",
+            "airbyte/source-jira"
+        )
+
+    /**
+     * FIXME: Some sources can't guarantee that there will be no events between two sequential
+     * syncs.
+     */
+    private val IMAGES_TO_SKIP_IDENTICAL_FULL_REFRESHES: Set =
+        Sets.newHashSet("airbyte/source-google-workspace-admin-reports", "airbyte/source-kafka")
+
+    @get:Throws(Exception::class)
+    protected abstract val spec: ConnectorSpecification
+        /**
+         * Specification for integration. Will be passed to integration where appropriate in each
+         * test. Should be valid.
+         *
+         * @return the connector specification
+         */
+        get
+
+    @get:Throws(Exception::class)
+    protected abstract val configuredCatalog: ConfiguredAirbyteCatalog
+        /**
+         * The catalog to use to validate the output of read operations. This will be used as
+         * follows:
+         *
+         * Full Refresh syncs will be tested on all the input streams which support it.
+         * Incremental syncs:
+         * - if the stream declares a source-defined cursor, it will be tested with an incremental
+         *   sync using the default cursor.
+         * - if the stream requires a user-defined cursor, it will be tested with the input cursor.
+         *
+         * In both cases, the input [.getState] will be used as the input state.
+         *
+         * @return the configured catalog
+         * @throws Exception
+         */
+        get
+
+    @get:Throws(Exception::class)
+    protected abstract val state: JsonNode?
+        /** @return a JSON file representing the state file to use when testing incremental syncs */
+        get
+
+    /** Verify that a spec operation issued to the connector returns a valid spec. */
+    @Test
+    @Throws(Exception::class)
+    fun testGetSpec() {
+        Assertions.assertEquals(
+            spec,
+            runSpec(),
+            "Expected spec output by integration to be equal to spec provided by test runner"
+        )
+    }
+
+    /**
+     * Verify that a check operation issued to the connector with the input config file returns a
+     * success response.
+     */
+    @Test
+    @Throws(Exception::class)
+    fun testCheckConnection() {
+        Assertions.assertEquals(
+            StandardCheckConnectionOutput.Status.SUCCEEDED,
+            runCheck().status,
+            "Expected check connection operation to succeed"
+        )
+    }
+
+    //    /**
+    //     * Verify that when given invalid credentials, that check connection returns a failed
+    // response.
+    //     * Assume that the {@link TestSource#getFailCheckConfig()} is invalid.
+    //     */
+    //    @Test
+    //    public void testCheckConnectionInvalidCredentials() throws Exception {
+    //        final OutputAndStatus output = runCheck();
+    //        assertTrue(output.getOutput().isPresent());
+    //        assertEquals(Status.FAILED, output.getOutput().get().getStatus());
+    //    }
+
+    /**
+     * Verifies that when a discover operation is run on the connector using the given config
+     * file, a valid catalog is output by the connector.
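+     *
+     * A hypothetical override of [verifyCatalog] could then assert stream-level expectations,
+     * e.g. `Assertions.assertTrue(catalog!!.streams.any { it.name == "users" })` for a source
+     * that is known to expose a "users" stream.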
+     */
+    @Test
+    @Throws(Exception::class)
+    fun testDiscover() {
+        val discoverOutput = runDiscover()
+        val discoveredCatalog = lastPersistedCatalog
+        Assertions.assertNotNull(discoveredCatalog, "Expected discover to produce a catalog")
+        verifyCatalog(discoveredCatalog)
+    }
+
+    /** Override this method to check the actual catalog. */
+    @Throws(Exception::class)
+    protected fun verifyCatalog(catalog: AirbyteCatalog?) {
+        // do nothing by default
+    }
+
+    /**
+     * Configures all streams in the input catalog to full refresh mode, then verifies that a read
+     * operation produces some RECORD messages.
+     */
+    @Test
+    @Throws(Exception::class)
+    fun testFullRefreshRead() {
+        if (!sourceSupportsFullRefresh()) {
+            LOGGER.info("Test skipped. Source does not support full refresh.")
+            return
+        }
+
+        val catalog = withFullRefreshSyncModes(configuredCatalog)
+        val allMessages = runRead(catalog)
+
+        Assertions.assertFalse(
+            filterRecords(allMessages).isEmpty(),
+            "Expected a full refresh sync to produce records"
+        )
+        assertFullRefreshMessages(allMessages)
+    }
+
+    /** Override this method to perform more specific assertions on the messages. */
+    @Throws(Exception::class)
+    protected open fun assertFullRefreshMessages(allMessages: List?) {
+        // do nothing by default
+    }
+
+    /**
+     * Configures all streams in the input catalog to full refresh mode, performs two read
+     * operations on all streams which support full refresh syncs, and then verifies that the
+     * RECORD messages output from both were identical.
+     */
+    @Test
+    @Throws(Exception::class)
+    fun testIdenticalFullRefreshes() {
+        if (!sourceSupportsFullRefresh()) {
+            LOGGER.info("Test skipped. Source does not support full refresh.")
+            return
+        }
+
+        if (
+            IMAGES_TO_SKIP_IDENTICAL_FULL_REFRESHES.contains(
+                imageName.split(":".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray()[0]
+            )
+        ) {
+            return
+        }
+
+        val configuredCatalog = withFullRefreshSyncModes(configuredCatalog)
+        val recordMessagesFirstRun = filterRecords(runRead(configuredCatalog))
+        val recordMessagesSecondRun = filterRecords(runRead(configuredCatalog))
+        // The worker validates the messages, so we do not need to validate them again here (as
+        // long as we use the worker, which we will not want to do long term).
+        Assertions.assertFalse(
+            recordMessagesFirstRun.isEmpty(),
+            "Expected first full refresh to produce records"
+        )
+        Assertions.assertFalse(
+            recordMessagesSecondRun.isEmpty(),
+            "Expected second full refresh to produce records"
+        )
+
+        assertSameRecords(
+            recordMessagesFirstRun,
+            recordMessagesSecondRun,
+            "Expected two full refresh syncs to produce the same records"
+        )
+    }
+
+    /**
+     * This test verifies that all streams in the input catalog which support incremental sync can
+     * do so correctly. It does this by running two read operations on the connector's Docker
+     * image: the first takes the configured catalog and config provided to this test as input. It
+     * then verifies that the sync produced a non-zero number of RECORD and STATE messages.
+     *
+     * The second read takes the same catalog and config used in the first test, plus the last
+     * STATE message output by the first read operation as the input state file. It verifies that
+     * no records are produced (since we read all records in the first sync).
+     *
+     * This test is performed only for streams which support incremental. Streams which do not
+     * support incremental sync are ignored. If no streams in the input catalog support incremental
+     * sync, this test is skipped.
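+     *
+     * For illustration (hypothetical legacy-state shape), the first read might end with a message
+     * like `{"type": "STATE", "state": {"data": {"users": {"updated_at": "2021-01-01"}}}}`;
+     * feeding that state into the second read should then yield zero RECORD messages.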
+     */
+    @Test
+    @Throws(Exception::class)
+    fun testIncrementalSyncWithState() {
+        if (!sourceSupportsIncremental()) {
+            return
+        }
+
+        val configuredCatalog = withSourceDefinedCursors(configuredCatalog)
+        // only sync incremental streams
+        configuredCatalog.streams =
+            configuredCatalog.streams
+                .stream()
+                .filter { s: ConfiguredAirbyteStream -> s.syncMode == SyncMode.INCREMENTAL }
+                .collect(Collectors.toList())
+
+        val airbyteMessages = runRead(configuredCatalog, state)
+        val recordMessages = filterRecords(airbyteMessages)
+        val stateMessages =
+            airbyteMessages
+                .stream()
+                .filter { m: AirbyteMessage? -> m!!.type == AirbyteMessage.Type.STATE }
+                .map { obj: AirbyteMessage? -> obj!!.state }
+                .collect(Collectors.toList())
+        Assertions.assertFalse(
+            recordMessages.isEmpty(),
+            "Expected the first incremental sync to produce records"
+        )
+        Assertions.assertFalse(
+            stateMessages.isEmpty(),
+            "Expected incremental sync to produce STATE messages"
+        )
+
+        // TODO validate exact records
+        if (
+            IMAGES_TO_SKIP_SECOND_INCREMENTAL_READ.contains(
+                imageName.split(":".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray()[0]
+            )
+        ) {
+            return
+        }
+
+        // When we run the incremental sync again there should be no new records. Run a sync with
+        // the latest state message and assert no records were emitted.
+        var latestState: JsonNode? = null
+        for (stateMessage in stateMessages) {
+            if (stateMessage.type == AirbyteStateMessage.AirbyteStateType.STREAM) {
+                latestState = Jsons.jsonNode(stateMessages)
+                break
+            } else if (stateMessage.type == AirbyteStateMessage.AirbyteStateType.GLOBAL) {
+                latestState = Jsons.jsonNode(java.util.List.of(Iterables.getLast(stateMessages)))
+                break
+            } else {
+                throw RuntimeException("Unknown state type " + stateMessage.type)
+            }
+        }
+
+        assert(Objects.nonNull(latestState))
+        val secondSyncRecords = filterRecords(runRead(configuredCatalog, latestState))
+        Assertions.assertTrue(
+            secondSyncRecords.isEmpty(),
+            "Expected the second incremental sync to produce no records when given the first sync's output state."
+        )
+    }
+
+    /**
+     * If the source does not support incremental sync, this test is skipped.
+     *
+     * Otherwise, this test runs two syncs: one where all streams provided in the input catalog
+     * sync in full refresh mode, and another where all the streams in the input catalog which
+     * support incremental sync in incremental mode (streams which don't support incremental sync
+     * in full refresh mode). Then, the test asserts that the two syncs produced the same RECORD
+     * messages. Any other type of message is disregarded.
+     */
+    @Test
+    @Throws(Exception::class)
+    fun testEmptyStateIncrementalIdenticalToFullRefresh() {
+        if (!sourceSupportsIncremental()) {
+            return
+        }
+
+        if (!sourceSupportsFullRefresh()) {
+            LOGGER.info("Test skipped. Source does not support full refresh.")
+            return
+        }
+
+        val configuredCatalog = configuredCatalog
+        val fullRefreshCatalog = withFullRefreshSyncModes(configuredCatalog)
+
+        val fullRefreshRecords = filterRecords(runRead(fullRefreshCatalog))
+        val emptyStateRecords =
+            filterRecords(runRead(configuredCatalog, Jsons.jsonNode(HashMap())))
+        Assertions.assertFalse(
+            fullRefreshRecords.isEmpty(),
+            "Expected a full refresh sync to produce records"
+        )
+        Assertions.assertFalse(
+            emptyStateRecords.isEmpty(),
+            "Expected an incremental sync with empty input state to produce records"
+        )
+        assertSameRecords(
+            fullRefreshRecords,
+            emptyStateRecords,
+            "Expected a full refresh sync and incremental sync with no input state to produce identical records"
+        )
+    }
+
+    /**
+     * In order to launch a source on Kubernetes in a pod, we need to be able to wrap the
+     * entrypoint. The source connector must specify its entrypoint in the AIRBYTE_ENTRYPOINT
+     * variable. This test ensures that the entrypoint environment variable is set.
+     */
+    @Test
+    @Throws(Exception::class)
+    fun testEntrypointEnvVar() {
+        checkEntrypointEnvVariable()
+    }
+
+    protected fun withSourceDefinedCursors(
+        catalog: ConfiguredAirbyteCatalog
+    ): ConfiguredAirbyteCatalog {
+        val clone = Jsons.clone(catalog)
+        for (configuredStream in clone.streams) {
+            if (
+                configuredStream.syncMode == SyncMode.INCREMENTAL &&
+                    configuredStream.stream.sourceDefinedCursor != null &&
+                    configuredStream.stream.sourceDefinedCursor
+            ) {
+                configuredStream.cursorField = configuredStream.stream.defaultCursorField
+            }
+        }
+        return clone
+    }
+
+    protected fun withFullRefreshSyncModes(
+        catalog: ConfiguredAirbyteCatalog
+    ): ConfiguredAirbyteCatalog {
+        val clone = Jsons.clone(catalog)
+        for (configuredStream in clone.streams) {
+            if (configuredStream.stream.supportedSyncModes.contains(SyncMode.FULL_REFRESH)) {
+                configuredStream.syncMode = SyncMode.FULL_REFRESH
+                configuredStream.destinationSyncMode = DestinationSyncMode.OVERWRITE
+            }
+        }
+        return clone
+    }
+
+    @Throws(Exception::class)
+    private fun sourceSupportsIncremental(): Boolean {
+        return sourceSupports(SyncMode.INCREMENTAL)
+    }
+
+    @Throws(Exception::class)
+    private fun sourceSupportsFullRefresh(): Boolean {
+        return sourceSupports(SyncMode.FULL_REFRESH)
+    }
+
+    @Throws(Exception::class)
+    private fun sourceSupports(syncMode: SyncMode): Boolean {
+        val catalog = configuredCatalog
+        for (stream in catalog.streams) {
+            if (stream.stream.supportedSyncModes.contains(syncMode)) {
+                return true
+            }
+        }
+        return false
+    }
+
+    private fun assertSameRecords(
+        expected: List,
+        actual: List,
+        message: String
+    ) {
+        val prunedExpected =
+            expected
+                .stream()
+                .map { m: AirbyteRecordMessage -> this.pruneEmittedAt(m) }
+                .collect(Collectors.toList())
+        val prunedActual =
+            actual
+                .stream()
+                .map { m: AirbyteRecordMessage -> this.pruneEmittedAt(m) }
+                .map { m: AirbyteRecordMessage -> this.pruneCdcMetadata(m) }
+                .collect(Collectors.toList())
+        Assertions.assertEquals(prunedExpected.size, prunedActual.size, message)
+        Assertions.assertTrue(prunedExpected.containsAll(prunedActual), message)
+        Assertions.assertTrue(prunedActual.containsAll(prunedExpected), message)
+    }
+
+    private fun pruneEmittedAt(m: AirbyteRecordMessage): AirbyteRecordMessage {
+        return Jsons.clone(m).withEmittedAt(null)
+    }
+
+    private fun pruneCdcMetadata(m: AirbyteRecordMessage): AirbyteRecordMessage {
+        val clone = Jsons.clone(m)
+        (clone.data as ObjectNode).remove(CDC_LSN)
+        (clone.data as ObjectNode).remove(CDC_LOG_FILE)
+        (clone.data as ObjectNode).remove(CDC_LOG_POS)
+        (clone.data as ObjectNode).remove(CDC_UPDATED_AT)
+        (clone.data as ObjectNode).remove(CDC_DELETED_AT)
+        (clone.data as ObjectNode).remove(CDC_EVENT_SERIAL_NO)
+        (clone.data as ObjectNode).remove(CDC_DEFAULT_CURSOR)
+        return clone
+    }
+
+    companion object {
+        const val CDC_LSN: String = "_ab_cdc_lsn"
+        const val CDC_UPDATED_AT: String = "_ab_cdc_updated_at"
+        const val CDC_DELETED_AT: String = "_ab_cdc_deleted_at"
+        const val CDC_LOG_FILE: String = "_ab_cdc_log_file"
+        const val CDC_LOG_POS: String = "_ab_cdc_log_pos"
+        const val CDC_DEFAULT_CURSOR: String = "_ab_cdc_cursor"
+        const val CDC_EVENT_SERIAL_NO: String = "_ab_cdc_event_serial_no"
+
+        private val LOGGER: Logger = LoggerFactory.getLogger(SourceAcceptanceTest::class.java)
+
+        protected fun filterRecords(
+            messages: Collection?
+        ): List {
+            return messages!!
+                .stream()
+                .filter { m: AirbyteMessage? -> m!!.type == AirbyteMessage.Type.RECORD }
+                .map { obj: AirbyteMessage? -> obj!!.record }
+                .collect(Collectors.toList())
+        }
+    }
+}
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/TestDataHolder.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/TestDataHolder.kt
new file mode 100644
index 000000000000..c14e9f7e33a4
--- /dev/null
+++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/TestDataHolder.kt
@@ -0,0 +1,218 @@
+/*
+ * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+ */
+package io.airbyte.cdk.integrations.standardtest.source
+
+import io.airbyte.protocol.models.JsonSchemaType
+import java.util.*
+
+class TestDataHolder
+internal constructor(
+    val sourceType: String?,
+    val airbyteType: JsonSchemaType,
+    val values: List,
+    val expectedValues: MutableList,
+    private val createTablePatternSql: String,
+    private val insertPatternSql: String,
+    private val fullSourceDataType: String?
+) {
+    var nameSpace: String? = null
+    private var testNumber: Long = 0
+    private var idColumnName: String? = null
+    private var testColumnName: String? = null
+
+    var declarationLocation: String = ""
+        private set
+
+    class TestDataHolderBuilder internal constructor() {
+        private var sourceType: String? = null
+        private lateinit var airbyteType: JsonSchemaType
+        private val values: MutableList = ArrayList()
+        private val expectedValues: MutableList = ArrayList()
+        private var createTablePatternSql: String
+        private var insertPatternSql: String
+        private var fullSourceDataType: String? = null
+
+        init {
+            this.createTablePatternSql = DEFAULT_CREATE_TABLE_SQL
+            this.insertPatternSql = DEFAULT_INSERT_SQL
+        }
+
+        /**
+         * The name of the source data type. Duplicates by name will be tested independently from
+         * each other. Note that this name will be used for connector setup and table creation. If
+         * the source syntax requires more details (e.g. the "varchar" type requires a length,
+         * "varchar(50)"), you can additionally set a custom data type syntax via the
+         * [TestDataHolderBuilder.fullSourceDataType] method.
+         *
+         * @param sourceType source data type name
+         * @return builder
+         */
+        fun sourceType(sourceType: String?): TestDataHolderBuilder {
+            this.sourceType = sourceType
+            if (fullSourceDataType == null) fullSourceDataType = sourceType
+            return this
+        }
+
+        /**
+         * The corresponding Airbyte data type.
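+         * For example, a "varchar" source type would typically map to [JsonSchemaType.STRING].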
+         * This is required for the proper configuration of [ConfiguredAirbyteStream].
+         *
+         * @param airbyteType Airbyte data type
+         * @return builder
+         */
+        fun airbyteType(airbyteType: JsonSchemaType): TestDataHolderBuilder {
+            this.airbyteType = airbyteType
+            return this
+        }
+
+        /**
+         * Sets a custom create table script pattern. Use it if your source uses an atypical table
+         * creation SQL. The default pattern is described in [.DEFAULT_CREATE_TABLE_SQL]. Note! The
+         * pattern should contain four String placeholders: the namespace.table name (as one
+         * placeholder together), the id column name, the test column name, and the test column
+         * data type.
+         *
+         * @param createTablePatternSql creation table sql pattern
+         * @return builder
+         */
+        fun createTablePatternSql(createTablePatternSql: String): TestDataHolderBuilder {
+            this.createTablePatternSql = createTablePatternSql
+            return this
+        }
+
+        /**
+         * Sets a custom insert record script pattern. Use it if your source uses an atypical
+         * insert record SQL. The default pattern is described in [.DEFAULT_INSERT_SQL]. Note! The
+         * pattern should contain three String placeholders: the table name, the row id, and the
+         * value.
+         *
+         * @param insertPatternSql creation table sql pattern
+         * @return builder
+         */
+        fun insertPatternSql(insertPatternSql: String): TestDataHolderBuilder {
+            this.insertPatternSql = insertPatternSql
+            return this
+        }
+
+        /**
+         * Allows setting an extended data type for the table creation. E.g. the "varchar" type in
+         * MySQL requires a length; in this case fullSourceDataType will be "varchar(50)".
+         *
+         * @param fullSourceDataType actual string for the column data type description
+         * @return builder
+         */
+        fun fullSourceDataType(fullSourceDataType: String?): TestDataHolderBuilder {
+            this.fullSourceDataType = fullSourceDataType
+            return this
+        }
+
+        /**
+         * Adds value(s) to the scope of a corresponding test. The values will be inserted into the
+         * created table. Note! The value will be inserted into the insert script without any
+         * transformations. Make sure that the value is in line with the source syntax.
+         *
+         * @param insertValue test value
+         * @return builder
+         */
+        fun addInsertValues(vararg insertValue: String): TestDataHolderBuilder {
+            values.addAll(Arrays.asList(*insertValue))
+            return this
+        }
+
+        /**
+         * Adds expected value(s) to the test scope. If you add at least one value, it will check
+         * that all values are provided by the corresponding stream.
+         *
+         * @param expectedValue value which should be provided by the stream
+         * @return builder
+         */
+        fun addExpectedValues(vararg expectedValue: String?): TestDataHolderBuilder {
+            expectedValues.addAll(Arrays.asList(*expectedValue))
+            return this
+        }
+
+        /**
+         * Adds a NULL value to the expected value list. If you need to add only one value and it's
+         * NULL, you must use this method instead of [.addExpectedValues].
+         *
+         * @return builder
+         */
+        fun addNullExpectedValue(): TestDataHolderBuilder {
+            expectedValues.add(null)
+            return this
+        }
+
+        fun build(): TestDataHolder {
+            return TestDataHolder(
+                sourceType,
+                airbyteType,
+                values,
+                expectedValues,
+                createTablePatternSql,
+                insertPatternSql,
+                fullSourceDataType
+            )
+        }
+    }
+
+    fun setTestNumber(testNumber: Long) {
+        this.testNumber = testNumber
+    }
+
+    fun setIdColumnName(idColumnName: String?) {
+        this.idColumnName = idColumnName
+    }
+
+    fun setTestColumnName(testColumnName: String?) {
+        this.testColumnName = testColumnName
+    }
+
+    val nameWithTestPrefix: String
+        get() = // source type may include space (e.g. "character varying")
"character varying") + nameSpace + "_" + testNumber + "_" + sourceType!!.replace("\\s".toRegex(), "_") + + val createSqlQuery: String + get() = + String.format( + createTablePatternSql, + (if (nameSpace != null) "$nameSpace." else "") + this.nameWithTestPrefix, + idColumnName, + testColumnName, + fullSourceDataType + ) + + fun setDeclarationLocation(declarationLocation: Array) { + this.declarationLocation = Arrays.asList(*declarationLocation).subList(2, 3).toString() + } + + val insertSqlQueries: List + get() { + val insertSqls: MutableList = ArrayList() + var rowId = 1 + for (value in values) { + insertSqls.add( + String.format( + insertPatternSql, + (if (nameSpace != null) "$nameSpace." else "") + this.nameWithTestPrefix, + rowId++, + value + ) + ) + } + return insertSqls + } + + companion object { + private const val DEFAULT_CREATE_TABLE_SQL = + "CREATE TABLE %1\$s(%2\$s INTEGER PRIMARY KEY, %3\$s %4\$s)" + private const val DEFAULT_INSERT_SQL = "INSERT INTO %1\$s VALUES (%2\$s, %3\$s)" + + /** + * The builder allows to setup any comprehensive data type test. + * + * @return builder for setup comprehensive test + */ + fun builder(): TestDataHolderBuilder { + return TestDataHolderBuilder() + } + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/TestDestinationEnv.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/TestDestinationEnv.kt new file mode 100644 index 000000000000..c58b6150f12a --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/TestDestinationEnv.kt @@ -0,0 +1,8 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.standardtest.source + +import java.nio.file.Path + +class TestDestinationEnv(val localRoot: Path?) diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/TestEnvConfigs.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/TestEnvConfigs.kt new file mode 100644 index 000000000000..73c05019e02a --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/TestEnvConfigs.kt @@ -0,0 +1,160 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.standardtest.source + +import com.google.common.base.Preconditions +import io.airbyte.commons.lang.Exceptions +import io.airbyte.commons.map.MoreMaps +import io.airbyte.commons.version.AirbyteVersion +import java.util.* +import java.util.function.Function +import java.util.function.Supplier +import java.util.stream.Collectors +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +/** + * This class passes environment variable to the DockerProcessFactory that runs the source in the + * SourceAcceptanceTest. + */ +// todo (cgardens) - this cloud_deployment implicit interface is going to bite us. +class TestEnvConfigs private constructor(envMap: Map) { + enum class DeploymentMode { + OSS, + CLOUD + } + + enum class WorkerEnvironment { + DOCKER, + KUBERNETES + } + + private val getEnv = Function { key: String -> envMap[key] } + private val getAllEnvKeys = Supplier { envMap.keys } + + constructor() : this(System.getenv()) + + val airbyteRole: String? 
+ // CORE + get() = getEnv(AIRBYTE_ROLE) + + val airbyteVersion: AirbyteVersion + get() = AirbyteVersion(getEnsureEnv(AIRBYTE_VERSION)) + + val deploymentMode: DeploymentMode + get() = + getEnvOrDefault(DEPLOYMENT_MODE, DeploymentMode.OSS) { s: String -> + try { + return@getEnvOrDefault DeploymentMode.valueOf(s) + } catch (e: IllegalArgumentException) { + LOGGER.info(s + " not recognized, defaulting to " + DeploymentMode.OSS) + return@getEnvOrDefault DeploymentMode.OSS + } + } + + val workerEnvironment: WorkerEnvironment + get() = + getEnvOrDefault(WORKER_ENVIRONMENT, WorkerEnvironment.DOCKER) { s: String -> + WorkerEnvironment.valueOf(s.uppercase(Locale.getDefault())) + } + + val jobDefaultEnvMap: Map + /** + * There are two types of environment variables available to the job container: + * + * * Exclusive variables prefixed with JOB_DEFAULT_ENV_PREFIX + * * Shared variables defined in JOB_SHARED_ENVS + */ + get() { + val jobPrefixedEnvMap = + getAllEnvKeys + .get() + .stream() + .filter { key: String -> key.startsWith(JOB_DEFAULT_ENV_PREFIX) } + .collect( + Collectors.toMap( + Function { key: String -> key.replace(JOB_DEFAULT_ENV_PREFIX, "") }, + getEnv + ) + ) + // This method assumes that these shared env variables are not critical to the execution + // of the jobs, and only serve as metadata. So any exception is swallowed and default to + // an empty string. Change this logic if this assumption no longer holds. + val jobSharedEnvMap = + JOB_SHARED_ENVS.entries + .stream() + .collect( + Collectors.toMap( + Function { obj: Map.Entry> -> + obj.key + }, + Function { entry: Map.Entry> + -> + Exceptions.swallowWithDefault( + { Objects.requireNonNullElse(entry.value.apply(this), "") }, + "" + ) + } + ) + ) + return MoreMaps.merge(jobPrefixedEnvMap, jobSharedEnvMap) + } + + fun getEnvOrDefault(key: String, defaultValue: T, parser: Function): T { + return getEnvOrDefault(key, defaultValue, parser, false) + } + + fun getEnvOrDefault( + key: String, + defaultValue: T, + parser: Function, + isSecret: Boolean + ): T { + val value = getEnv.apply(key) + if (value != null && !value.isEmpty()) { + return parser.apply(value) + } else { + LOGGER.info( + "Using default value for environment variable {}: '{}'", + key, + if (isSecret) "*****" else defaultValue + ) + return defaultValue + } + } + + fun getEnv(name: String): String? { + return getEnv.apply(name) + } + + fun getEnsureEnv(name: String): String? 
{ + val value = getEnv(name) + Preconditions.checkArgument(value != null, "'%s' environment variable cannot be null", name) + + return value + } + + companion object { + private val LOGGER: Logger = LoggerFactory.getLogger(TestEnvConfigs::class.java) + + // env variable names + const val AIRBYTE_ROLE: String = "AIRBYTE_ROLE" + const val AIRBYTE_VERSION: String = "AIRBYTE_VERSION" + const val WORKER_ENVIRONMENT: String = "WORKER_ENVIRONMENT" + const val DEPLOYMENT_MODE: String = "DEPLOYMENT_MODE" + const val JOB_DEFAULT_ENV_PREFIX: String = "JOB_DEFAULT_ENV_" + + val JOB_SHARED_ENVS: Map> = + java.util.Map.of( + AIRBYTE_VERSION, + Function { instance: TestEnvConfigs -> instance.airbyteVersion.serialize() }, + AIRBYTE_ROLE, + Function { obj: TestEnvConfigs -> obj.airbyteRole }, + DEPLOYMENT_MODE, + Function { instance: TestEnvConfigs -> instance.deploymentMode.name }, + WORKER_ENVIRONMENT, + Function { instance: TestEnvConfigs -> instance.workerEnvironment.name } + ) + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/TestPythonSourceMain.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/TestPythonSourceMain.kt new file mode 100644 index 000000000000..9a9c0b90ee52 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/TestPythonSourceMain.kt @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.standardtest.source + +import net.sourceforge.argparse4j.ArgumentParsers +import net.sourceforge.argparse4j.inf.ArgumentParserException +import net.sourceforge.argparse4j.inf.Namespace + +/** + * Parse command line arguments and inject them into the test class before running the test. Then + * runs the tests. + */ +object TestPythonSourceMain { + @JvmStatic + fun main(args: Array) { + val parser = + ArgumentParsers.newFor(TestPythonSourceMain::class.java.name) + .build() + .defaultHelp(true) + .description("Run standard source tests") + + parser.addArgument("--imageName").help("Name of the integration image") + + parser.addArgument("--pythonContainerName").help("Name of the python integration image") + + var ns: Namespace? = null + try { + ns = parser.parseArgs(args) + } catch (e: ArgumentParserException) { + parser.handleError(e) + System.exit(1) + } + + val imageName = ns!!.getString("imageName") + val pythonContainerName = ns.getString("pythonContainerName") + + PythonSourceAcceptanceTest.Companion.IMAGE_NAME = imageName + PythonSourceAcceptanceTest.Companion.PYTHON_CONTAINER_NAME = pythonContainerName + + TestRunner.runTestClass(PythonSourceAcceptanceTest::class.java) + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/TestRunner.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/TestRunner.kt new file mode 100644 index 000000000000..c28e8c07b99e --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/TestRunner.kt @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+ */ +package io.airbyte.cdk.integrations.standardtest.source + +import java.io.PrintWriter +import java.nio.charset.StandardCharsets +import org.junit.platform.engine.discovery.DiscoverySelectors +import org.junit.platform.launcher.core.LauncherDiscoveryRequestBuilder +import org.junit.platform.launcher.core.LauncherFactory +import org.junit.platform.launcher.listeners.SummaryGeneratingListener + +object TestRunner { + fun runTestClass(testClass: Class<*>?) { + val request = + LauncherDiscoveryRequestBuilder.request() + .selectors(DiscoverySelectors.selectClass(testClass)) + .build() + + val plan = LauncherFactory.create().discover(request) + val launcher = LauncherFactory.create() + + // Register a listener of your choice + val listener = SummaryGeneratingListener() + + launcher.execute(plan, listener) + + listener.summary.printFailuresTo(PrintWriter(System.out, false, StandardCharsets.UTF_8)) + listener.summary.printTo(PrintWriter(System.out, false, StandardCharsets.UTF_8)) + + if (listener.summary.testsFailedCount > 0) { + println( + "There are failing tests. See https://docs.airbyte.io/contributing-to-airbyte/building-new-connector/standard-source-tests " + + "for more information about the standard source test suite." + ) + System.exit(1) + } + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/fs/ExecutableTestSource.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/fs/ExecutableTestSource.kt new file mode 100644 index 000000000000..6b79f0863bc5 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/fs/ExecutableTestSource.kt @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.standardtest.source.fs + +import com.fasterxml.jackson.databind.JsonNode +import io.airbyte.cdk.integrations.standardtest.source.SourceAcceptanceTest +import io.airbyte.cdk.integrations.standardtest.source.TestDestinationEnv +import io.airbyte.commons.io.IOs +import io.airbyte.commons.json.Jsons +import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog +import io.airbyte.protocol.models.v0.ConnectorSpecification +import java.nio.file.Path + +/** + * Extends TestSource such that it can be called using resources pulled from the file system. Will + * also add the ability to execute arbitrary scripts in the next version. + */ +class ExecutableTestSource : SourceAcceptanceTest() { + class TestConfig( + val imageName: String, + val specPath: Path, + val configPath: Path, + val catalogPath: Path, + val statePath: Path? + ) + + override val spec: ConnectorSpecification + get() = + Jsons.deserialize( + IOs.readFile(TEST_CONFIG!!.specPath), + ConnectorSpecification::class.java + ) + + override val imageName: String + get() = TEST_CONFIG!!.imageName + + override val config: JsonNode? + get() = Jsons.deserialize(IOs.readFile(TEST_CONFIG!!.configPath)) + + override val configuredCatalog: ConfiguredAirbyteCatalog + get() = + Jsons.deserialize( + IOs.readFile(TEST_CONFIG!!.catalogPath), + ConfiguredAirbyteCatalog::class.java + ) + + override val state: JsonNode? + get() = + if (TEST_CONFIG!!.statePath != null) { + Jsons.deserialize(IOs.readFile(TEST_CONFIG!!.statePath)) + } else { + Jsons.deserialize("{}") + } + + @Throws(Exception::class) + override fun setupEnvironment(environment: TestDestinationEnv?) 
{ + // no-op, for now + } + + @Throws(Exception::class) + override fun tearDown(testEnv: TestDestinationEnv?) { + // no-op, for now + } + + companion object { + var TEST_CONFIG: TestConfig? = null + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/fs/TestSourceMain.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/fs/TestSourceMain.kt new file mode 100644 index 000000000000..b1552e38d7c2 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/fs/TestSourceMain.kt @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.cdk.integrations.standardtest.source.fs + +import io.airbyte.cdk.integrations.standardtest.source.TestRunner +import java.nio.file.Path +import net.sourceforge.argparse4j.ArgumentParsers +import net.sourceforge.argparse4j.inf.ArgumentParserException +import net.sourceforge.argparse4j.inf.Namespace +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +/** + * Parse command line arguments and inject them into the test class before running the test. Then + * runs the tests. + */ +object TestSourceMain { + private val LOGGER: Logger = LoggerFactory.getLogger(TestSourceMain::class.java) + + @JvmStatic + fun main(args: Array) { + val parser = + ArgumentParsers.newFor(TestSourceMain::class.java.name) + .build() + .defaultHelp(true) + .description("Run standard source tests") + + parser + .addArgument("--imageName") + .required(true) + .help("Name of the source connector image e.g: airbyte/source-mailchimp") + + parser.addArgument("--spec").required(true).help("Path to file that contains spec json") + + parser.addArgument("--config").required(true).help("Path to file that contains config json") + + parser + .addArgument("--catalog") + .required(true) + .help("Path to file that contains catalog json") + + parser.addArgument("--state").required(false).help("Path to the file containing state") + + var ns: Namespace? = null + try { + ns = parser.parseArgs(args) + } catch (e: ArgumentParserException) { + parser.handleError(e) + System.exit(1) + } + + val imageName = ns!!.getString("imageName") + val specFile = ns.getString("spec") + val configFile = ns.getString("config") + val catalogFile = ns.getString("catalog") + val stateFile = ns.getString("state") + + ExecutableTestSource.Companion.TEST_CONFIG = + ExecutableTestSource.TestConfig( + imageName, + Path.of(specFile), + Path.of(configFile), + Path.of(catalogFile), + if (stateFile != null) Path.of(stateFile) else null + ) + + TestRunner.runTestClass(ExecutableTestSource::class.java) + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/performancetest/AbstractSourceBasePerformanceTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/performancetest/AbstractSourceBasePerformanceTest.kt new file mode 100644 index 000000000000..6f94ccff21e8 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/performancetest/AbstractSourceBasePerformanceTest.kt @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+ */
+package io.airbyte.cdk.integrations.standardtest.source.performancetest
+
+import io.airbyte.cdk.integrations.standardtest.source.AbstractSourceConnectorTest
+import io.airbyte.cdk.integrations.standardtest.source.TestDestinationEnv
+
+/**
+ * This abstract class contains common methods for both the Fill DB scripts and the Performance
+ * tests.
+ */
+abstract class AbstractSourceBasePerformanceTest : AbstractSourceConnectorTest() {
+    /**
+     * The column name will be used for a test column in the test tables. Override it if the
+     * default name is not valid for your source.
+     */
+    protected val testColumnName
+        get() = TEST_COLUMN_NAME
+
+    /**
+     * The stream name template will be used for the test tables. Override it if the default name
+     * is not valid for your source.
+     */
+    protected val testStreamNameTemplate
+        get() = TEST_STREAM_NAME_TEMPLATE
+
+    @Throws(Exception::class)
+    override fun setupEnvironment(environment: TestDestinationEnv?) {
+        // DO NOTHING. Mandatory to override. DB will be setup as part of each test
+    }
+
+    companion object {
+        protected const val TEST_COLUMN_NAME: String = "test_column"
+        protected const val TEST_STREAM_NAME_TEMPLATE: String = "test_%S"
+    }
+}
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/performancetest/AbstractSourceFillDbWithTestData.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/performancetest/AbstractSourceFillDbWithTestData.kt
new file mode 100644
index 000000000000..37e88b9e0269
--- /dev/null
+++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/performancetest/AbstractSourceFillDbWithTestData.kt
@@ -0,0 +1,154 @@
+/*
+ * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+ */
+package io.airbyte.cdk.integrations.standardtest.source.performancetest
+
+import io.airbyte.cdk.db.Database
+import java.util.*
+import java.util.stream.Stream
+import org.jooq.DSLContext
+import org.junit.jupiter.api.Disabled
+import org.junit.jupiter.api.TestInstance
+import org.junit.jupiter.params.ParameterizedTest
+import org.junit.jupiter.params.provider.Arguments
+import org.junit.jupiter.params.provider.MethodSource
+import org.slf4j.Logger
+import org.slf4j.LoggerFactory
+
+/** This abstract class contains common methods for the Fill DB scripts. */
+@TestInstance(TestInstance.Lifecycle.PER_CLASS)
+abstract class AbstractSourceFillDbWithTestData : AbstractSourceBasePerformanceTest() {
+    /**
+     * Sets up the test database. All tables and data described in the registered tests will be
+     * put there.
+     *
+     * @return configured test database
+     * @throws Exception
+     * - might throw any exception during initialization.
+     */
+    @Throws(Exception::class) protected abstract fun setupDatabase(dbName: String?): Database
+
+    /**
+     * This test adds test data to a new DB. 1. Set the DB creds in the static variables above 2.
+     * Set the desired number of streams, columns and records 3. Run the test
+     */
+    @Disabled
+    @ParameterizedTest
+    @MethodSource("provideParameters")
+    @Throws(Exception::class)
+    fun addTestData(
+        dbName: String?,
+        schemaName: String?,
+        numberOfDummyRecords: Int,
+        numberOfBatches: Int,
+        numberOfColumns: Int,
+        numberOfStreams: Int
+    ) {
+        val database = setupDatabase(dbName)
+
+        database.query { ctx: DSLContext? ->
+            for (currentStreamNumber in 0 until numberOfStreams) {
+                val currentTableName = String.format(testStreamNameTemplate, currentStreamNumber)
+
+                ctx!!.fetch(prepareCreateTableQuery(schemaName, numberOfColumns, currentTableName))
+                for (i in 0 until numberOfBatches) {
+                    val insertQueryTemplate =
+                        prepareInsertQueryTemplate(
+                            schemaName,
+                            i,
+                            numberOfColumns,
+                            numberOfDummyRecords
+                        )
+                    ctx.fetch(String.format(insertQueryTemplate, currentTableName))
+                }
+
+                c.info("Finished processing for stream $currentStreamNumber")
+            }
+            null
+        }
+    }
+
+    /**
+     * This is a data provider for the fill DB script. Each group of arguments will be run as a
+     * separate test. Set the "testArgs" in the test class of your DB in a @BeforeTest method.
+     *
+     * 1st arg - a name of the DB that will be used in the jdbc connection string. 2nd arg - a
+     * schemaName that will be used as a NameSpace in the Configured Airbyte Catalog. 3rd arg - a
+     * number of dummy records to insert in each batch. 4th arg - a number of batches to insert
+     * per stream. 5th arg - a number of columns in each stream\table that will be used for the
+     * Airbyte Catalog configuration. 6th arg - a number of streams to read in the configured
+     * airbyte Catalog. Each stream\table in the DB should be named like "test_0", "test_1",...,
+     * test_n.
+     *
+     * Stream.of( Arguments.of("your_db_name", "your_schema_name", 100, 2, 240, 1000) );
+     */
+    protected abstract fun provideParameters(): Stream?
+
+    protected fun prepareCreateTableQuery(
+        dbSchemaName: String?,
+        numberOfColumns: Int,
+        currentTableName: String?
+    ): String {
+        val sj = StringJoiner(",")
+        for (i in 0 until numberOfColumns) {
+            sj.add(String.format(" %s%s %s", testColumnName, i, TEST_DB_FIELD_TYPE))
+        }
+
+        return String.format(
+            CREATE_DB_TABLE_TEMPLATE,
+            dbSchemaName,
+            currentTableName,
+            sj.toString()
+        )
+    }
+
+    protected fun prepareInsertQueryTemplate(
+        dbSchemaName: String?,
+        batchNumber: Int,
+        numberOfColumns: Int,
+        recordsNumber: Int
+    ): String {
+        val fieldsNames = StringJoiner(",")
+        fieldsNames.add("id")
+
+        val baseInsertQuery = StringJoiner(",")
+        baseInsertQuery.add("id_placeholder")
+
+        for (i in 0 until numberOfColumns) {
+            fieldsNames.add(testColumnName + i)
+            baseInsertQuery.add(TEST_VALUE_TEMPLATE_POSTGRES)
+        }
+
+        val insertGroupValuesJoiner = StringJoiner(",")
+
+        val batchMessages = batchNumber * 100
+
+        for (currentRecordNumber in batchMessages until recordsNumber + batchMessages) {
+            insertGroupValuesJoiner.add(
+                "(" +
+                    baseInsertQuery
+                        .toString()
+                        .replace("id_placeholder".toRegex(), currentRecordNumber.toString()) +
+                    ")"
+            )
+        }
+
+        return String.format(
+            INSERT_INTO_DB_TABLE_QUERY_TEMPLATE,
+            dbSchemaName,
+            "%s",
+            fieldsNames.toString(),
+            insertGroupValuesJoiner.toString()
+        )
+    }
+
+    companion object {
+        private const val CREATE_DB_TABLE_TEMPLATE =
+            "CREATE TABLE %s.%s(id INTEGER PRIMARY KEY, %s)"
+        private const val INSERT_INTO_DB_TABLE_QUERY_TEMPLATE = "INSERT INTO %s.%s (%s) VALUES %s"
+        private const val TEST_DB_FIELD_TYPE = "varchar(10)"
+
+        protected val c: Logger =
+            LoggerFactory.getLogger(AbstractSourceFillDbWithTestData::class.java)
+        private const val TEST_VALUE_TEMPLATE_POSTGRES = "\'Value id_placeholder\'"
+    }
+}
diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/performancetest/AbstractSourcePerformanceTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/performancetest/AbstractSourcePerformanceTest.kt
new file mode 100644
index 000000000000..4980d3ced7cb
--- /dev/null
+++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/performancetest/AbstractSourcePerformanceTest.kt
@@ -0,0 +1,169 @@
+/*
+ * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+ */
+package io.airbyte.cdk.integrations.standardtest.source.performancetest
+
+import com.fasterxml.jackson.databind.JsonNode
+import com.google.common.collect.Lists
+import io.airbyte.cdk.integrations.standardtest.source.TestDestinationEnv
+import io.airbyte.protocol.models.Field
+import io.airbyte.protocol.models.JsonSchemaType
+import io.airbyte.protocol.models.v0.*
+import java.util.function.Function
+import java.util.stream.Collectors
+import java.util.stream.Stream
+import org.junit.jupiter.api.Assertions
+import org.junit.jupiter.api.TestInstance
+import org.junit.jupiter.params.ParameterizedTest
+import org.junit.jupiter.params.provider.Arguments
+import org.junit.jupiter.params.provider.MethodSource
+import org.slf4j.Logger
+import org.slf4j.LoggerFactory
+
+/** This abstract class contains common methods for the Performance tests. */
+@TestInstance(TestInstance.Lifecycle.PER_CLASS)
+abstract class AbstractSourcePerformanceTest : AbstractSourceBasePerformanceTest() {
+    override var config: JsonNode? = null
+
+    /**
+     * The column name will be used for a PK column in the test tables. Override it if the default
+     * name is not valid for your source.
+     */
+    protected val idColumnName: String = "id"
+
+    /**
+     * Sets up the test database. All tables and data described in the registered tests will be
+     * put there.
+     *
+     * @throws Exception
+     * - might throw any exception during initialization.
+     */
+    @Throws(Exception::class) protected abstract fun setupDatabase(dbName: String?)
+
+    override fun tearDown(testEnv: TestDestinationEnv?) {}
+
+    /**
+     * This is a data provider for the performance tests. Each group of arguments will be run as a
+     * separate test. Set the "testArgs" in the test class of your DB in a @BeforeTest method.
+     *
+     * 1st arg - a name of the DB that will be used in the jdbc connection string. 2nd arg - a
+     * schemaName that will be used as a NameSpace in the Configured Airbyte Catalog. 3rd arg - a
+     * number of expected records retrieved in each stream. 4th arg - a number of columns in each
+     * stream\table that will be used for the Airbyte Catalog configuration. 5th arg - a number of
+     * streams to read in the configured airbyte Catalog. Each stream\table in the DB should be
+     * named like "test_0", "test_1",..., test_n.
+     *
+     * Example: Stream.of( Arguments.of("test1000tables240columns200recordsDb", "dbo", 200, 240,
+     * 1000), Arguments.of("test5000tables240columns200recordsDb", "dbo", 200, 240, 1000),
+     * Arguments.of("newregular25tables50000records", "dbo", 50052, 8, 25),
+     * Arguments.of("newsmall1000tableswith10000rows", "dbo", 10011, 8, 1000) );
+     */
+    protected abstract fun provideParameters(): Stream?
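+    // A hypothetical override sketch (DB name, schema and sizes are illustrative only, mirroring
+    // the example in the KDoc above):
+    // override fun provideParameters(): Stream<Arguments> =
+    //     Stream.of(Arguments.of("my_perf_db", "dbo", 200, 240, 1000))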
+
+    @ParameterizedTest
+    @MethodSource("provideParameters")
+    @Throws(Exception::class)
+    fun testPerformance(
+        dbName: String?,
+        schemaName: String?,
+        numberOfDummyRecords: Int,
+        numberOfColumns: Int,
+        numberOfStreams: Int
+    ) {
+        setupDatabase(dbName)
+
+        val catalog = getConfiguredCatalog(schemaName, numberOfStreams, numberOfColumns)
+        val mapOfExpectedRecordsCount =
+            prepareMapWithExpectedRecords(numberOfStreams, numberOfDummyRecords)
+        val checkStatusMap =
+            runReadVerifyNumberOfReceivedMsgs(catalog, null, mapOfExpectedRecordsCount)
+        validateNumberOfReceivedMsgs(checkStatusMap)
+    }
+
+    protected fun validateNumberOfReceivedMsgs(checkStatusMap: Map?) {
+        // Iterate through the map of streams and collect those whose expected record count was
+        // not fully met (a non-zero value means some messages are still missing).
+        val failedStreamsMap =
+            checkStatusMap!!
+                .entries
+                .stream()
+                .filter { el: Map.Entry -> el.value != 0 }
+                .collect(
+                    Collectors.toMap(
+                        Function { obj: Map.Entry -> obj.key },
+                        Function { obj: Map.Entry -> obj.value }
+                    )
+                )
+
+        if (!failedStreamsMap.isEmpty()) {
+            Assertions.fail("Not all messages were delivered. $failedStreamsMap")
+        }
+        c.info("Finished all checks, no issues found for {} streams", checkStatusMap.size)
+    }
+
+    protected fun prepareMapWithExpectedRecords(
+        streamNumber: Int,
+        expectedRecordsNumberInEachStream: Int
+    ): MutableMap {
+        val resultMap: MutableMap = HashMap() // streamName&expected records in stream
+
+        for (currentStream in 0 until streamNumber) {
+            val streamName = String.format(testStreamNameTemplate, currentStream)
+            resultMap[streamName] = expectedRecordsNumberInEachStream
+        }
+        return resultMap
+    }
+
+    /**
+     * Configures streams for all registered data type tests.
+     *
+     * @return configured catalog
+     */
+    protected fun getConfiguredCatalog(
+        nameSpace: String?,
+        numberOfStreams: Int,
+        numberOfColumns: Int
+    ): ConfiguredAirbyteCatalog {
+        val streams: MutableList = ArrayList()
+
+        for (currentStream in 0 until numberOfStreams) {
+            // CREATE TABLE test.test_1_int(id INTEGER PRIMARY KEY, test_column int)
+
+            val fields: MutableList = ArrayList()
+
+            fields.add(Field.of(this.idColumnName, JsonSchemaType.NUMBER))
+            for (currentColumnNumber in 0 until numberOfColumns) {
+                fields.add(Field.of(testColumnName + currentColumnNumber, JsonSchemaType.STRING))
+            }
+
+            val airbyteStream =
+                CatalogHelpers.createAirbyteStream(
+                        String.format(testStreamNameTemplate, currentStream),
+                        nameSpace,
+                        fields
+                    )
+                    .withSourceDefinedCursor(true)
+                    .withSourceDefinedPrimaryKey(
+                        java.util.List.of(java.util.List.of(this.idColumnName))
+                    )
+                    .withSupportedSyncModes(
+                        Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL)
+                    )
+
+            val configuredAirbyteStream =
+                ConfiguredAirbyteStream()
+                    .withSyncMode(SyncMode.INCREMENTAL)
+                    .withCursorField(Lists.newArrayList(this.idColumnName))
+                    .withDestinationSyncMode(DestinationSyncMode.APPEND)
+                    .withStream(airbyteStream)
+
+            streams.add(configuredAirbyteStream)
+        }
+
+        return ConfiguredAirbyteCatalog().withStreams(streams)
+    }
+
+    companion object {
+        protected val c: Logger = LoggerFactory.getLogger(AbstractSourcePerformanceTest::class.java)
+    }
+}
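
For reference, a minimal concrete performance test wired onto the fixtures above might look like the sketch below. The connector image, config fields, and parameter values are hypothetical; a real subclass supplies whatever connection config its source actually expects and points setupDatabase at a pre-filled database:

    package io.airbyte.cdk.integrations.standardtest.source.performancetest

    import io.airbyte.commons.json.Jsons
    import java.util.stream.Stream
    import org.junit.jupiter.params.provider.Arguments

    class MyDbSourcePerformanceTest : AbstractSourcePerformanceTest() {
        // Hypothetical connector image under test.
        override val imageName = "airbyte/source-mydb:dev"

        override fun setupDatabase(dbName: String?) {
            // Hypothetical config shape; point the connector at the pre-filled database.
            config =
                Jsons.jsonNode(mapOf("host" to "localhost", "port" to 5432, "database" to dbName))
        }

        // One run: DB "my_perf_db", schema "dbo", 200 records, 240 columns, 1000 streams, per the
        // provideParameters() contract documented above.
        override fun provideParameters(): Stream<Arguments> =
            Stream.of(Arguments.of("my_perf_db", "dbo", 200, 240, 1000))
    }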