
Commit

Destination snowflake: mostly done implementations for sqlgenerator+destinationhandler (#28677)

* csv sheet generator supports 1s1t

* create+insert raw tables 1s1t

* add skeletons

* start writing tests

* progress in creating raw tables

* fix tests

* add s3 test; better csv generation

* handle case-sensitive column names

* also add gcs test

* hook T+D into the destination

* fix redshift; simplify

* Delete unused files?

* disable test; enable cleanup

* initialize config singleton in tests

* logistics

* header

* simplify

* fix unit tests

* correctly disable tests

* use default null for loaded_at

* fix test

* autoformat

* cython >.>

* more singleton init

* literally how?

* basic destinationhandler impl

* use raw string for type >.>

* add toDialectType

* basic createTable impl

* better sql query

* comment

* unused variables

* recorddiffer can be case-sensitive

* misc fixes

* add expected_records

* move constants to base-java

* use ternary

* fix tests

* resolve todo

* T+D can trigger on first commit

* fix test teardown

* implement softReset

* implement overwriteFinalTable

* better type stuff; check table schema

* fix

* derp

* implement updateTable?

* derp

* random wip stuff

* fix insertRaw

* theoretically implement stuff?

* stuff

* put suffix at the end

* different uuids

* fix expected records

* move tdtest resources into dat folder

* use resource files

* stuff

* move code around

* more stuff

* rename final table

* stuff

* cdc immediate deletion

* cdcComplexUpdate

* cleanup

* botched rebase

* more tests

* move back to old file

* Automated Commit - Format and Process Resources Changes

* add comments

* Automated Commit - Format and Process Resources Changes

* fix merge

* move expected_records into dat folder

* wip implement sqlgenerator test

* basic implementation

* tons of fixes, still tons more to go

* more stuff

* fix more things

* hacky convert temporal types to varchar

* test data fix

* fix variant parsing

* fix number

* fix time parsing; fix test data

* typo

* fix input data

* progress

* switch back to float

* add more test files

* swap int -> number

* fix PK null check

* fix overwriteTable

* better test

* Automated Commit - Format and Process Resources Changes

* type aliases, one more test

* also verify numeric precision/scale

* logistics

---------

Co-authored-by: edgao <edgao@users.noreply.github.com>
edgao and edgao authored Aug 7, 2023
1 parent 7454935 commit 2866ed6
Showing 42 changed files with 1,039 additions and 72 deletions.
@@ -35,7 +35,7 @@ public class StagingDatabaseCsvSheetGenerator implements CsvSheetGenerator {

public StagingDatabaseCsvSheetGenerator() {
use1s1t = TypingAndDedupingFlag.isDestinationV2();
this.header = use1s1t ? JavaBaseConstants.V2_COLUMN_NAMES : JavaBaseConstants.LEGACY_COLUMN_NAMES;
this.header = use1s1t ? JavaBaseConstants.V2_RAW_TABLE_COLUMN_NAMES : JavaBaseConstants.LEGACY_RAW_TABLE_COLUMNS;
}

// TODO is this even used anywhere?
@@ -21,7 +21,7 @@ private JavaBaseConstants() {}
public static final String COLUMN_NAME_AB_ID = "_airbyte_ab_id";
public static final String COLUMN_NAME_EMITTED_AT = "_airbyte_emitted_at";
public static final String COLUMN_NAME_DATA = "_airbyte_data";
public static final List<String> LEGACY_COLUMN_NAMES = List.of(
public static final List<String> LEGACY_RAW_TABLE_COLUMNS = List.of(
COLUMN_NAME_AB_ID,
COLUMN_NAME_DATA,
COLUMN_NAME_EMITTED_AT);
@@ -30,11 +30,14 @@ private JavaBaseConstants() {}
public static final String COLUMN_NAME_AB_RAW_ID = "_airbyte_raw_id";
public static final String COLUMN_NAME_AB_LOADED_AT = "_airbyte_loaded_at";
public static final String COLUMN_NAME_AB_EXTRACTED_AT = "_airbyte_extracted_at";
public static final List<String> V2_COLUMN_NAMES = List.of(
public static final List<String> V2_RAW_TABLE_COLUMN_NAMES = List.of(
COLUMN_NAME_AB_RAW_ID,
COLUMN_NAME_AB_EXTRACTED_AT,
COLUMN_NAME_AB_LOADED_AT,
COLUMN_NAME_DATA);
public static final List<String> V2_FINAL_TABLE_METADATA_COLUMNS = List.of(
COLUMN_NAME_AB_RAW_ID,
COLUMN_NAME_AB_EXTRACTED_AT);

public static final String AIRBYTE_NAMESPACE_SCHEMA = "airbyte";
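For orientation, here is a hedged sketch (not taken from this commit) of the kind of Snowflake raw-table DDL the V2 column constants above imply; the column types, the table name, and the "airbyte_internal" schema are assumptions.

```java
// Hedged sketch: building a plausible V2 raw-table DDL string from the constants above.
// The TIMESTAMP_TZ/VARIANT types, the table name, and the "airbyte_internal" schema are
// assumptions, not values taken from this commit.
String createRawTable = """
    CREATE TABLE IF NOT EXISTS "airbyte_internal"."users_raw" (
      "%s" VARCHAR NOT NULL,
      "%s" TIMESTAMP_TZ NOT NULL,
      "%s" TIMESTAMP_TZ, -- defaults to NULL until typing+deduping processes the row
      "%s" VARIANT NOT NULL
    )
    """.formatted(
    JavaBaseConstants.COLUMN_NAME_AB_RAW_ID,
    JavaBaseConstants.COLUMN_NAME_AB_EXTRACTED_AT,
    JavaBaseConstants.COLUMN_NAME_AB_LOADED_AT,
    JavaBaseConstants.COLUMN_NAME_DATA);
```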

@@ -113,7 +113,7 @@ public abstract class BaseSqlGeneratorIntegrationTest<DialectTableDefinition> {
* Do any setup work to create a namespace for this test run. For example, this might create a
* BigQuery dataset, or a Snowflake schema.
*/
protected abstract void createNamespace(String namespace);
protected abstract void createNamespace(String namespace) throws Exception;

/**
* Create a raw table using the StreamId's rawTableId.
@@ -146,7 +146,7 @@ protected abstract void insertFinalTableRecords(boolean includeCdcDeletedAt, Str
* Clean up all resources in the namespace. For example, this might delete the BigQuery dataset
* created in {@link #createNamespace(String)}.
*/
protected abstract void teardownNamespace(String namespace);
protected abstract void teardownNamespace(String namespace) throws Exception;

/**
* This test implementation is extremely destination-specific, but all destinations must implement
@@ -159,7 +159,7 @@ protected abstract void insertFinalTableRecords(boolean includeCdcDeletedAt, Str
public abstract void testCreateTableIncremental() throws Exception;

@BeforeEach
public void setup() {
public void setup() throws Exception {
generator = getSqlGenerator();
destinationHandler = getDestinationHandler();
ColumnId id1 = generator.buildColumnId("id1");
@@ -229,7 +229,7 @@ public void setup() {
}

@AfterEach
public void teardown() {
public void teardown() throws Exception {
teardownNamespace(namespace);
}

@@ -337,6 +337,10 @@ public void incrementalAppend() throws Exception {
dumpFinalTableRecords(streamId, ""));
}

/**
* Create a nonempty users_final_tmp table. Overwrite users_final from users_final_tmp. Verify that
* users_final now exists and contains nonzero records.
*/
@Test
public void overwriteFinalTable() throws Exception {
createFinalTable(false, streamId, "_tmp");
@@ -357,9 +361,7 @@ public void overwriteFinalTable() throws Exception {
final String sql = generator.overwriteFinalTable(streamId, "_tmp");
destinationHandler.execute(sql);

DIFFER.diffFinalTableRecords(
records,
dumpFinalTableRecords(streamId, ""));
assertEquals(1, dumpFinalTableRecords(streamId, "").size());
}

@Test
@@ -632,15 +634,17 @@ private void verifyRecordCounts(int expectedRawRecords,
assertAll(
() -> assertEquals(
expectedRawRecords,
actualRawRecords.size()),
actualRawRecords.size(),
"Raw record count was incorrect"),
() -> assertEquals(
0,
actualRawRecords.stream()
.filter(record -> !record.hasNonNull("_airbyte_loaded_at"))
.count()),
() -> assertEquals(
expectedFinalRecords,
actualFinalRecords.size()));
actualFinalRecords.size(),
"Final record count was incorrect"));
}

}
@@ -1,2 +1,2 @@
{"_airbyte_raw_id": "d5790c04-52df-42f3-8f77-a543268822a7", "_airbyte_extracted_at": "2022-12-31T00:00:00Z", "_airbyte_meta": {}, "id1": 1, "id2": 100, "updated_at": "2022-12-31T00:00:00Z", "string": "spooky ghost"}
{"_airbyte_raw_id": "e3b03d92-0f7c-49e5-b203-573dbb7bd1cb", "_airbyte_extracted_at": "2022-12-31T00:00:00Z", "_airbyte_meta": {}, "id1": 5, "id2": 100, "updated_at": "2022-12-31T01:00:00Z", "string": "will be deleted'"}
{"_airbyte_raw_id": "e3b03d92-0f7c-49e5-b203-573dbb7bd1cb", "_airbyte_extracted_at": "2022-12-31T00:00:00Z", "_airbyte_meta": {}, "id1": 5, "id2": 100, "updated_at": "2022-12-31T01:00:00Z", "string": "will be deleted"}
@@ -85,7 +85,8 @@ public static void copyIntoTableFromStage(final JdbcDatabase database,
AirbyteStreamNameNamespacePair streamId = new AirbyteStreamNameNamespacePair(streamNamespace, streamName);
if (!typerDeduperValve.containsKey(streamId)) {
typerDeduperValve.addStream(streamId);
} else if (typerDeduperValve.readyToTypeAndDedupe(streamId)) {
}
if (typerDeduperValve.readyToTypeAndDedupe(streamId)) {
typerDeduper.typeAndDedupe(streamId.getNamespace(), streamId.getName());
typerDeduperValve.updateTimeAndIncreaseInterval(streamId);
}
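The change above splits the old `else if` into two independent checks, which is what the "T+D can trigger on first commit" item in the commit log refers to: with the `else if`, adding a stream to the valve and checking its readiness were mutually exclusive within one flush, so a brand-new stream always waited a full interval before its first typing pass. A simplified, hypothetical valve (all names and the starting interval below are assumptions, not the destination's real valve class) shows why the separate check matters.

```java
import java.time.Duration;
import java.util.HashMap;
import java.util.Map;

// Simplified, hypothetical valve sketch: addStream() marks the stream as immediately
// ready, so checking readyToTypeAndDedupe() right after addStream() — instead of in an
// else-branch — allows the very first flush to also type and dedupe.
class TypingValveSketch {

  private final Map<String, Long> nextRunAtMs = new HashMap<>();
  private final Map<String, Long> intervalMs = new HashMap<>();

  void addStream(String stream) {
    nextRunAtMs.put(stream, 0L); // ready right away
    intervalMs.put(stream, Duration.ofMinutes(15).toMillis());
  }

  boolean readyToTypeAndDedupe(String stream) {
    return System.currentTimeMillis() >= nextRunAtMs.getOrDefault(stream, Long.MAX_VALUE);
  }

  void updateTimeAndIncreaseInterval(String stream) {
    long interval = intervalMs.merge(stream, 2L, (a, b) -> a * b); // back off: double each run
    nextRunAtMs.put(stream, System.currentTimeMillis() + interval);
  }
}
```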
@@ -238,7 +238,7 @@ from unnest([
protected void insertRawTableRecords(StreamId streamId, List<JsonNode> records) throws InterruptedException {
String recordsText = records.stream()
// For each record, convert it to a string like "(rawId, extractedAt, loadedAt, data)"
.map(record -> JavaBaseConstants.V2_COLUMN_NAMES.stream()
.map(record -> JavaBaseConstants.V2_RAW_TABLE_COLUMN_NAMES.stream()
.map(record::get)
.map(r -> {
if (r == null) {
@@ -49,7 +49,7 @@ RUN tar xf ${APPLICATION}.tar --strip-components=1
ENV ENABLE_SENTRY true


LABEL io.airbyte.version=1.2.8
LABEL io.airbyte.version=1.2.9
LABEL io.airbyte.name=airbyte/destination-snowflake

ENV AIRBYTE_ENTRYPOINT "/airbyte/run_with_normalization.sh"
@@ -35,6 +35,7 @@ dependencies {
// TODO (edgao) explain how you built this jar
implementation files('lib/snowflake-jdbc.jar')
implementation 'org.apache.commons:commons-csv:1.4'
implementation 'org.apache.commons:commons-text:1.10.0'
implementation 'com.github.alexmojaki:s3-stream-upload:2.2.2'
implementation "io.aesy:datasize:1.0.0"
implementation 'com.zaxxer:HikariCP:5.0.1'
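A hedged guess at why `org.apache.commons:commons-text` is being added: it provides StringSubstitutor, a common way to template long typing/deduping SQL statements, which would fit the new SQL generator. Everything below (table names and the query itself) is illustrative only, not code from this commit.

```java
import java.util.Map;
import org.apache.commons.text.StringSubstitutor;

// Illustrative only: one plausible use of commons-text in a SQL generator.
String sql = new StringSubstitutor(Map.of(
    "raw_table", "\"airbyte_internal\".\"users_raw\"",
    "final_table", "\"PUBLIC\".\"USERS\""))
    .replace("INSERT INTO ${final_table} SELECT \"_airbyte_data\":\"id1\" AS \"ID1\" FROM ${raw_table};");
```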
@@ -2,7 +2,7 @@ data:
connectorSubtype: database
connectorType: destination
definitionId: 424892c4-daac-4491-b35d-c6688ba547ba
dockerImageTag: 1.2.8
dockerImageTag: 1.2.9
dockerRepository: airbyte/destination-snowflake
githubIssueLabel: destination-snowflake
icon: snowflake.svg
@@ -18,6 +18,7 @@
import io.airbyte.commons.json.Jsons;
import io.airbyte.db.factory.DataSourceFactory;
import io.airbyte.db.jdbc.JdbcDatabase;
import io.airbyte.db.jdbc.JdbcUtils;
import io.airbyte.integrations.base.AirbyteMessageConsumer;
import io.airbyte.integrations.base.Destination;
import io.airbyte.integrations.base.TypingAndDedupingFlag;
@@ -34,7 +35,6 @@
import io.airbyte.integrations.destination.s3.csv.CsvSerializedBuffer;
import io.airbyte.integrations.destination.snowflake.typing_deduping.SnowflakeDestinationHandler;
import io.airbyte.integrations.destination.snowflake.typing_deduping.SnowflakeSqlGenerator;
import io.airbyte.integrations.destination.snowflake.typing_deduping.SnowflakeTableDefinition;
import io.airbyte.integrations.destination.staging.StagingConsumerFactory;
import io.airbyte.protocol.models.v0.AirbyteConnectionStatus;
import io.airbyte.protocol.models.v0.AirbyteMessage;
@@ -123,7 +123,7 @@ public static Storage getStorageClient(final GcsConfig gcsConfig) throws IOExcep
}

@Override
protected DataSource getDataSource(final JsonNode config) {
public DataSource getDataSource(final JsonNode config) {
return SnowflakeDatabase.createDataSource(config, airbyteEnvironment);
}

@@ -152,17 +152,20 @@ public AirbyteMessageConsumer getConsumer(final JsonNode config,
SnowflakeSqlGenerator sqlGenerator = new SnowflakeSqlGenerator();
final ParsedCatalog parsedCatalog;
TyperDeduper typerDeduper;
JdbcDatabase database = getDatabase(getDataSource(config));
if (TypingAndDedupingFlag.isDestinationV2()) {
String databaseName = config.get(JdbcUtils.DATABASE_KEY).asText();
SnowflakeDestinationHandler snowflakeDestinationHandler = new SnowflakeDestinationHandler(databaseName, database);
parsedCatalog = new CatalogParser(sqlGenerator).parseCatalog(catalog);
typerDeduper = new DefaultTyperDeduper<>(sqlGenerator, new SnowflakeDestinationHandler(getDatabase(getDataSource(config))), parsedCatalog);
typerDeduper = new DefaultTyperDeduper<>(sqlGenerator, snowflakeDestinationHandler, parsedCatalog);
} else {
parsedCatalog = null;
typerDeduper = new NoopTyperDeduper();
}

return new StagingConsumerFactory().create(
outputRecordCollector,
getDatabase(getDataSource(config)),
database,
new SnowflakeGcsStagingSqlOperations(getNamingResolver(), gcsConfig),
getNamingResolver(),
CsvSerializedBuffer.createFunction(null, () -> new FileBuffer(CsvSerializedBuffer.CSV_GZ_SUFFIX, getNumberOfFileBuffers(config))),
@@ -10,6 +10,7 @@
import io.airbyte.commons.json.Jsons;
import io.airbyte.db.factory.DataSourceFactory;
import io.airbyte.db.jdbc.JdbcDatabase;
import io.airbyte.db.jdbc.JdbcUtils;
import io.airbyte.integrations.base.AirbyteMessageConsumer;
import io.airbyte.integrations.base.Destination;
import io.airbyte.integrations.base.SerializedAirbyteMessageConsumer;
@@ -126,17 +127,20 @@ public AirbyteMessageConsumer getConsumer(final JsonNode config,
SnowflakeSqlGenerator sqlGenerator = new SnowflakeSqlGenerator();
final ParsedCatalog parsedCatalog;
TyperDeduper typerDeduper;
JdbcDatabase database = getDatabase(getDataSource(config));
if (TypingAndDedupingFlag.isDestinationV2()) {
String databaseName = config.get(JdbcUtils.DATABASE_KEY).asText();
SnowflakeDestinationHandler snowflakeDestinationHandler = new SnowflakeDestinationHandler(databaseName, database);
parsedCatalog = new CatalogParser(sqlGenerator).parseCatalog(catalog);
typerDeduper = new DefaultTyperDeduper<>(sqlGenerator, new SnowflakeDestinationHandler(getDatabase(getDataSource(config))), parsedCatalog);
typerDeduper = new DefaultTyperDeduper<>(sqlGenerator, snowflakeDestinationHandler, parsedCatalog);
} else {
parsedCatalog = null;
typerDeduper = new NoopTyperDeduper();
}

return new StagingConsumerFactory().create(
outputRecordCollector,
getDatabase(getDataSource(config)),
database,
new SnowflakeInternalStagingSqlOperations(getNamingResolver()),
getNamingResolver(),
CsvSerializedBuffer.createFunction(null, () -> new FileBuffer(CsvSerializedBuffer.CSV_GZ_SUFFIX, getNumberOfFileBuffers(config))),
@@ -155,17 +159,20 @@ public SerializedAirbyteMessageConsumer getSerializedMessageConsumer(final JsonN
SnowflakeSqlGenerator sqlGenerator = new SnowflakeSqlGenerator();
final ParsedCatalog parsedCatalog;
TyperDeduper typerDeduper;
JdbcDatabase database = getDatabase(getDataSource(config));
if (TypingAndDedupingFlag.isDestinationV2()) {
String databaseName = config.get(JdbcUtils.DATABASE_KEY).asText();
SnowflakeDestinationHandler snowflakeDestinationHandler = new SnowflakeDestinationHandler(databaseName, database);
parsedCatalog = new CatalogParser(sqlGenerator).parseCatalog(catalog);
typerDeduper = new DefaultTyperDeduper<>(sqlGenerator, new SnowflakeDestinationHandler(getDatabase(getDataSource(config))), parsedCatalog);
typerDeduper = new DefaultTyperDeduper<>(sqlGenerator, snowflakeDestinationHandler, parsedCatalog);
} else {
parsedCatalog = null;
typerDeduper = new NoopTyperDeduper();
}

return new StagingConsumerFactory().createAsync(
outputRecordCollector,
getDatabase(getDataSource(config)),
database,
new SnowflakeInternalStagingSqlOperations(getNamingResolver()),
getNamingResolver(),
config,
@@ -10,6 +10,7 @@
import io.airbyte.commons.json.Jsons;
import io.airbyte.db.factory.DataSourceFactory;
import io.airbyte.db.jdbc.JdbcDatabase;
import io.airbyte.db.jdbc.JdbcUtils;
import io.airbyte.integrations.base.AirbyteMessageConsumer;
import io.airbyte.integrations.base.Destination;
import io.airbyte.integrations.base.TypingAndDedupingFlag;
@@ -142,17 +143,20 @@ public AirbyteMessageConsumer getConsumer(final JsonNode config,
SnowflakeSqlGenerator sqlGenerator = new SnowflakeSqlGenerator();
final ParsedCatalog parsedCatalog;
TyperDeduper typerDeduper;
JdbcDatabase database = getDatabase(getDataSource(config));
if (TypingAndDedupingFlag.isDestinationV2()) {
String databaseName = config.get(JdbcUtils.DATABASE_KEY).asText();
SnowflakeDestinationHandler snowflakeDestinationHandler = new SnowflakeDestinationHandler(databaseName, database);
parsedCatalog = new CatalogParser(sqlGenerator).parseCatalog(catalog);
typerDeduper = new DefaultTyperDeduper<>(sqlGenerator, new SnowflakeDestinationHandler(getDatabase(getDataSource(config))), parsedCatalog);
typerDeduper = new DefaultTyperDeduper<>(sqlGenerator, snowflakeDestinationHandler, parsedCatalog);
} else {
parsedCatalog = null;
typerDeduper = new NoopTyperDeduper();
}

return new StagingConsumerFactory().create(
outputRecordCollector,
getDatabase(getDataSource(config)),
database,
new SnowflakeS3StagingSqlOperations(getNamingResolver(), s3Config.getS3Client(), s3Config, encryptionConfig),
getNamingResolver(),
CsvSerializedBuffer.createFunction(null, () -> new FileBuffer(CsvSerializedBuffer.CSV_GZ_SUFFIX, getNumberOfFileBuffers(config))),
@@ -0,0 +1,8 @@
package io.airbyte.integrations.destination.snowflake.typing_deduping;

/**
* type is notably _not_ a {@link net.snowflake.client.jdbc.SnowflakeType}. That
* enum doesn't contain all the types that snowflake supports (specifically NUMBER).
*/
public record SnowflakeColumn(String name, String type) {
}
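A tiny usage sketch (values are illustrative) of the point the comment makes: keeping the type as a raw string lets the record carry anything Snowflake reports, including NUMBER and VARIANT.

```java
// Illustrative values only: a plain String can hold Snowflake types that the JDBC
// SnowflakeType enum cannot represent, such as NUMBER.
SnowflakeColumn id = new SnowflakeColumn("ID1", "NUMBER");
SnowflakeColumn data = new SnowflakeColumn("_airbyte_data", "VARIANT");
```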
@@ -4,30 +4,80 @@
import io.airbyte.integrations.base.destination.typing_deduping.DestinationHandler;
import io.airbyte.integrations.base.destination.typing_deduping.StreamId;
import java.sql.SQLException;
import java.util.LinkedHashMap;
import java.util.Optional;
import java.util.UUID;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SnowflakeDestinationHandler implements DestinationHandler<SnowflakeTableDefinition> {

private static final Logger LOGGER = LoggerFactory.getLogger(SnowflakeDestinationHandler.class);

private final String databaseName;
private final JdbcDatabase database;

public SnowflakeDestinationHandler(JdbcDatabase database) {
public SnowflakeDestinationHandler(String databaseName, JdbcDatabase database) {
this.databaseName = databaseName;
this.database = database;
}

@Override
public Optional<SnowflakeTableDefinition> findExistingTable(StreamId id) throws SQLException {
// TODO only fetch metadata once
database.getMetaData();
return Optional.empty();
// The obvious database.getMetaData().getColumns() solution doesn't work, because JDBC translates VARIANT as VARCHAR
LinkedHashMap<String, String> columns = database.queryJsons(
"""
SELECT column_name, data_type
FROM information_schema.columns
WHERE table_catalog = ?
AND table_schema = ?
AND table_name = ?
ORDER BY ordinal_position;
""",
databaseName,
id.finalNamespace(),
id.finalName()
).stream()
.collect(LinkedHashMap::new,
(map, row) -> map.put(row.get("COLUMN_NAME").asText(), row.get("DATA_TYPE").asText()),
LinkedHashMap::putAll);
// TODO query for indexes/partitioning/etc

if (columns.isEmpty()) {
return Optional.empty();
} else {
return Optional.of(new SnowflakeTableDefinition(columns));
}
}

@Override
public boolean isFinalTableEmpty(StreamId id) {
return false;
public boolean isFinalTableEmpty(StreamId id) throws SQLException {
int rowCount = database.queryInt(
"""
SELECT row_count
FROM information_schema.tables
WHERE table_catalog = ?
AND table_schema = ?
AND table_name = ?
""",
databaseName,
id.finalNamespace(),
id.finalName());
return rowCount == 0;
}

@Override
public void execute(String sql) throws Exception {
if ("".equals(sql)) {
return;
}
final UUID queryId = UUID.randomUUID();
LOGGER.info("Executing sql {}: {}", queryId, sql);
long startTime = System.currentTimeMillis();

database.execute(sql);

LOGGER.info("Sql {} completed in {} ms", queryId, System.currentTimeMillis() - startTime);
}

}
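A hedged sketch of how a caller might combine the new handler methods (the caller logic and the table name are assumptions, not part of this commit): findExistingTable() reads the final table's columns from information_schema, isFinalTableEmpty() guards destructive operations, and execute() wraps database.execute(sql) with a query id and timing log.

```java
// Hypothetical caller, assuming `destinationHandler` and `streamId` are in scope.
Optional<SnowflakeTableDefinition> existing = destinationHandler.findExistingTable(streamId);
if (existing.isPresent() && destinationHandler.isFinalTableEmpty(streamId)) {
  // Safe to rebuild: the table exists but holds no rows yet. Table name is illustrative.
  destinationHandler.execute("DROP TABLE IF EXISTS \"PUBLIC\".\"USERS\"");
}
```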