diff --git a/src/main/java/io/debezium/connector/jdbc/JdbcChangeEventSink.java b/src/main/java/io/debezium/connector/jdbc/JdbcChangeEventSink.java index 8373e849..3ba3446f 100644 --- a/src/main/java/io/debezium/connector/jdbc/JdbcChangeEventSink.java +++ b/src/main/java/io/debezium/connector/jdbc/JdbcChangeEventSink.java @@ -8,6 +8,7 @@ import static io.debezium.connector.jdbc.JdbcSinkConnectorConfig.SchemaEvolutionMode.NONE; import java.sql.SQLException; +import java.sql.Statement; import java.util.Collection; import java.util.HashMap; import java.util.List; @@ -21,7 +22,6 @@ import org.hibernate.StatelessSession; import org.hibernate.Transaction; import org.hibernate.dialect.DatabaseVersion; -import org.hibernate.query.NativeQuery; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -265,11 +265,17 @@ private TableDescriptor checkAndApplyTableChangesIfNeeded(TableId tableId, SinkR } private boolean hasTable(TableId tableId) { - return session.doReturningWork((connection) -> dialect.tableExists(connection, tableId)); + return session.doReturningWork((connection) -> { + dialect.prepareConnection(connection); + return dialect.tableExists(connection, tableId); + }); } private TableDescriptor readTable(TableId tableId) { - return session.doReturningWork((connection) -> dialect.readTable(connection, tableId)); + return session.doReturningWork((connection) -> { + dialect.prepareConnection(connection); + return dialect.readTable(connection, tableId); + }); } private TableDescriptor createTable(TableId tableId, SinkRecordDescriptor record) throws SQLException { @@ -284,7 +290,7 @@ private TableDescriptor createTable(TableId tableId, SinkRecordDescriptor record try { final String createSql = dialect.getCreateTableStatement(record, tableId); LOGGER.trace("SQL: {}", createSql); - session.createNativeQuery(createSql, Object.class).executeUpdate(); + executeNativeQuery(createSql); transaction.commit(); } catch (Exception e) { @@ -333,7 +339,7 @@ private TableDescriptor alterTableIfNeeded(TableId tableId, SinkRecordDescriptor try { final String alterSql = dialect.getAlterTableStatement(table, record, missingFields); LOGGER.trace("SQL: {}", alterSql); - session.createNativeQuery(alterSql, Object.class).executeUpdate(); + executeNativeQuery(alterSql); transaction.commit(); } catch (Exception e) { @@ -371,9 +377,7 @@ private void writeTruncate(String sql) throws SQLException { final Transaction transaction = session.beginTransaction(); try { LOGGER.trace("SQL: {}", sql); - final NativeQuery query = session.createNativeQuery(sql, Object.class); - - query.executeUpdate(); + executeNativeQuery(sql); transaction.commit(); } catch (Exception e) { @@ -381,4 +385,13 @@ private void writeTruncate(String sql) throws SQLException { throw e; } } + + private void executeNativeQuery(String sql) throws SQLException { + session.doWork(connection -> { + dialect.prepareConnection(connection); + try (Statement statement = connection.createStatement()) { + statement.execute(sql); + } + }); + } } diff --git a/src/main/java/io/debezium/connector/jdbc/JdbcSinkConnectorConfig.java b/src/main/java/io/debezium/connector/jdbc/JdbcSinkConnectorConfig.java index 68bdd5b1..8f74df42 100644 --- a/src/main/java/io/debezium/connector/jdbc/JdbcSinkConnectorConfig.java +++ b/src/main/java/io/debezium/connector/jdbc/JdbcSinkConnectorConfig.java @@ -66,6 +66,7 @@ public class JdbcSinkConnectorConfig { public static final String DATABASE_TIME_ZONE = "database.time_zone"; public static final String POSTGRES_POSTGIS_SCHEMA = "dialect.postgres.postgis.schema"; public static final String SQLSERVER_IDENTITY_INSERT = "dialect.sqlserver.identity.insert"; + public static final String STARROCKS_CATALOG_NAME = "dialect.starrocks.catalog_name"; public static final String BATCH_SIZE = "batch.size"; public static final String FIELD_INCLUDE_LIST = "field.include.list"; public static final String FIELD_EXCLUDE_LIST = "field.exclude.list"; @@ -277,6 +278,14 @@ public class JdbcSinkConnectorConfig { .withDefault(false) .withDescription("Allowing to insert explicit value for identity column in table for SQLSERVER."); + public static final Field STARROCKS_CATALOG_NAME_FIELD = Field.create(STARROCKS_CATALOG_NAME) + .withDisplayName("Specifies the catalog name to use when connecting to StarRocks") + .withType(Type.STRING) + .withGroup(Field.createGroupEntry(Field.Group.CONNECTOR_ADVANCED, 4)) + .withWidth(ConfigDef.Width.SHORT) + .withImportance(ConfigDef.Importance.LOW) + .withDescription("The default catalog to use when connecting to StarRocks"); + public static final Field BATCH_SIZE_FIELD = Field.create(BATCH_SIZE) .withDisplayName("Specifies how many records to attempt to batch together into the destination table, when possible. " + "You can also configure the connector’s underlying consumer’s max.poll.records using consumer.override.max.poll.records in the connector configuration.") @@ -331,6 +340,7 @@ public class JdbcSinkConnectorConfig { DATABASE_TIME_ZONE_FIELD, POSTGRES_POSTGIS_SCHEMA_FIELD, SQLSERVER_IDENTITY_INSERT_FIELD, + STARROCKS_CATALOG_NAME_FIELD, BATCH_SIZE_FIELD, FIELD_INCLUDE_LIST_FIELD, FIELD_EXCLUDE_LIST_FIELD) @@ -505,6 +515,7 @@ public String getValue() { private final String databaseTimezone; private final String postgresPostgisSchema; private final boolean sqlServerIdentityInsert; + private final String starRocksCatalogName; private FieldNameFilter fieldsFilter; private final long batchSize; @@ -525,6 +536,7 @@ public JdbcSinkConnectorConfig(Map props) { this.databaseTimezone = config.getString(DATABASE_TIME_ZONE_FIELD); this.postgresPostgisSchema = config.getString(POSTGRES_POSTGIS_SCHEMA_FIELD); this.sqlServerIdentityInsert = config.getBoolean(SQLSERVER_IDENTITY_INSERT_FIELD); + this.starRocksCatalogName = config.getString(STARROCKS_CATALOG_NAME_FIELD); this.batchSize = config.getLong(BATCH_SIZE_FIELD); String fieldExcludeList = config.getString(FIELD_EXCLUDE_LIST); @@ -623,6 +635,10 @@ public String getPostgresPostgisSchema() { return postgresPostgisSchema; } + public String getStarRocksCatalogName() { + return starRocksCatalogName; + } + /** makes {@link org.hibernate.cfg.Configuration} from connector config * * @return {@link org.hibernate.cfg.Configuration} diff --git a/src/main/java/io/debezium/connector/jdbc/RecordWriter.java b/src/main/java/io/debezium/connector/jdbc/RecordWriter.java index 7d6c8406..c3d9c0ff 100644 --- a/src/main/java/io/debezium/connector/jdbc/RecordWriter.java +++ b/src/main/java/io/debezium/connector/jdbc/RecordWriter.java @@ -62,6 +62,8 @@ public void write(List records, String sqlStatement) { private Work processBatch(List records, String sqlStatement) { return conn -> { + // Allow doing some prep work for certain dialects/databases + dialect.prepareConnection(conn); try (PreparedStatement prepareStatement = conn.prepareStatement(sqlStatement)) { diff --git a/src/main/java/io/debezium/connector/jdbc/dialect/DatabaseDialect.java b/src/main/java/io/debezium/connector/jdbc/dialect/DatabaseDialect.java index 066956fc..824194fd 100644 --- a/src/main/java/io/debezium/connector/jdbc/dialect/DatabaseDialect.java +++ b/src/main/java/io/debezium/connector/jdbc/dialect/DatabaseDialect.java @@ -369,4 +369,12 @@ default String getTimeQueryBinding() { * @return the list of bounded values */ List bindValue(FieldDescriptor field, int startIndex, Object value); + + /** + * Prepares the connection for use + * + * @param connection the connection, should never be {@code null} + */ + default void prepareConnection(Connection connection) throws SQLException { + } } diff --git a/src/main/java/io/debezium/connector/jdbc/dialect/GeneralDatabaseDialect.java b/src/main/java/io/debezium/connector/jdbc/dialect/GeneralDatabaseDialect.java index be845f53..7b7c599a 100644 --- a/src/main/java/io/debezium/connector/jdbc/dialect/GeneralDatabaseDialect.java +++ b/src/main/java/io/debezium/connector/jdbc/dialect/GeneralDatabaseDialect.java @@ -566,6 +566,7 @@ protected String getDatabaseTimeZone(SessionFactory sessionFactory) { if (query.isPresent()) { try (StatelessSession session = sessionFactory.openStatelessSession()) { return session.doReturningWork((connection) -> { + prepareConnection(connection); try (Statement st = connection.createStatement()) { try (ResultSet rs = st.executeQuery(query.get())) { if (rs.next()) { diff --git a/src/main/java/io/debezium/connector/jdbc/dialect/mysql/MySqlDatabaseDialect.java b/src/main/java/io/debezium/connector/jdbc/dialect/mysql/MySqlDatabaseDialect.java index 2c4f5633..f6e5a00e 100644 --- a/src/main/java/io/debezium/connector/jdbc/dialect/mysql/MySqlDatabaseDialect.java +++ b/src/main/java/io/debezium/connector/jdbc/dialect/mysql/MySqlDatabaseDialect.java @@ -5,8 +5,10 @@ */ package io.debezium.connector.jdbc.dialect.mysql; +import java.sql.Connection; import java.sql.ResultSet; import java.sql.SQLException; +import java.sql.Statement; import java.time.ZonedDateTime; import java.time.format.DateTimeFormatter; import java.time.format.DateTimeFormatterBuilder; @@ -19,6 +21,8 @@ import org.hibernate.StatelessSession; import org.hibernate.dialect.Dialect; import org.hibernate.dialect.MySQLDialect; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import io.debezium.connector.jdbc.JdbcSinkConnectorConfig; import io.debezium.connector.jdbc.SinkRecordDescriptor; @@ -37,6 +41,8 @@ */ public class MySqlDatabaseDialect extends GeneralDatabaseDialect { + private static final Logger LOGGER = LoggerFactory.getLogger(MySqlDatabaseDialect.class); + private static final List NO_DEFAULT_VALUE_TYPES = Arrays.asList( "tinytext", "mediumtext", "longtext", "text", "tinyblob", "mediumblob", "longblob"); @@ -65,10 +71,16 @@ public DatabaseDialect instantiate(JdbcSinkConnectorConfig config, SessionFactor } private final boolean connectionTimeZoneSet; + private final String starRocksCatalogName; private MySqlDatabaseDialect(JdbcSinkConnectorConfig config, SessionFactory sessionFactory) { super(config, sessionFactory); + this.starRocksCatalogName = config.getStarRocksCatalogName(); + if (!Strings.isNullOrBlank(this.starRocksCatalogName)) { + LOGGER.info("Using Starrocks default catalog: {}", this.starRocksCatalogName); + } + try (StatelessSession session = sessionFactory.openStatelessSession()) { this.connectionTimeZoneSet = session.doReturningWork((connection) -> connection.getMetaData().getURL().contains("connectionTimeZone=")); } @@ -190,4 +202,14 @@ protected void addColumnDefaultValue(SinkRecordDescriptor.FieldDescriptor field, } super.addColumnDefaultValue(field, columnSpec); } + + @Override + public void prepareConnection(Connection connection) throws SQLException { + if (!Strings.isNullOrBlank(starRocksCatalogName)) { + LOGGER.debug("Setting connection database as 'USE {}'", starRocksCatalogName); + try (Statement statement = connection.createStatement()) { + statement.execute(String.format("USE %s", starRocksCatalogName)); + } + } + } }