From a1b912e9ed54bf257246203eccf2789d5f6ad269 Mon Sep 17 00:00:00 2001 From: Francesco Casula Date: Wed, 11 Oct 2023 17:05:01 +0200 Subject: [PATCH] feat: use append vs merge option from backend config --- go.mod | 2 +- go.sum | 4 +- testhelper/clone.go | 20 + warehouse/integrations/bigquery/bigquery.go | 113 +-- .../integrations/bigquery/bigquery_test.go | 525 +++++++------- .../bigquery/testdata/template.json | 8 +- warehouse/integrations/deltalake/deltalake.go | 30 +- .../integrations/deltalake/deltalake_test.go | 650 +++++++++--------- .../deltalake/testdata/template.json | 3 +- warehouse/integrations/mssql/mssql.go | 10 +- warehouse/integrations/postgres/load.go | 133 ++-- warehouse/integrations/postgres/load_test.go | 1 + warehouse/integrations/postgres/postgres.go | 2 + .../integrations/postgres/postgres_test.go | 133 ++-- warehouse/integrations/redshift/redshift.go | 313 +++------ .../integrations/redshift/redshift_test.go | 239 ++++--- warehouse/integrations/snowflake/snowflake.go | 66 +- .../integrations/snowflake/snowflake_test.go | 197 +++--- .../snowflake/testdata/template.json | 12 +- warehouse/internal/model/warehouse.go | 17 +- warehouse/logfield/logfield.go | 1 + warehouse/router/router.go | 2 +- warehouse/router/upload.go | 4 + warehouse/utils/querytype.go | 1 + warehouse/utils/querytype_test.go | 1 + warehouse/utils/utils_test.go | 8 +- 26 files changed, 1220 insertions(+), 1275 deletions(-) create mode 100644 testhelper/clone.go diff --git a/go.mod b/go.mod index 46d7c1babc..be1182facb 100644 --- a/go.mod +++ b/go.mod @@ -81,7 +81,7 @@ require ( github.com/rs/cors v1.10.1 github.com/rudderlabs/analytics-go v3.3.3+incompatible github.com/rudderlabs/compose-test v0.1.3 - github.com/rudderlabs/rudder-go-kit v0.16.2 + github.com/rudderlabs/rudder-go-kit v0.16.3 github.com/rudderlabs/sql-tunnels v0.1.5 github.com/samber/lo v1.38.1 github.com/segmentio/kafka-go v0.4.42 diff --git a/go.sum b/go.sum index d3a13de991..f1fcffcc64 100644 --- a/go.sum +++ b/go.sum @@ -949,8 +949,8 @@ github.com/rudderlabs/compose-test v0.1.3 h1:uyep6jDCIF737sfv4zIaMsKRQKX95IDz5Xb github.com/rudderlabs/compose-test v0.1.3/go.mod h1:tuvS1eQdSfwOYv1qwyVAcpdJxPLQXJgy5xGDd/9XmMg= github.com/rudderlabs/parquet-go v0.0.2 h1:ZXRdZdimB0PdJtmxeSSxfI0fDQ3kZjwzBxRi6Ut1J8k= github.com/rudderlabs/parquet-go v0.0.2/go.mod h1:g6guum7o8uhj/uNhunnt7bw5Vabu/goI5i21/3fnxWQ= -github.com/rudderlabs/rudder-go-kit v0.16.2 h1:1zR0ivPT3Rp9bHmfq5k8VVfOy3bJrag2gHbjnqbUmtM= -github.com/rudderlabs/rudder-go-kit v0.16.2/go.mod h1:vRRTcYmAtYg87R4liGy24wO3452WlGHkFwtEopgme3k= +github.com/rudderlabs/rudder-go-kit v0.16.3 h1:IZIg7RjwbQN0GAHpiZgNLW388AwBmgVnh3bYPXP7SKQ= +github.com/rudderlabs/rudder-go-kit v0.16.3/go.mod h1:vRRTcYmAtYg87R4liGy24wO3452WlGHkFwtEopgme3k= github.com/rudderlabs/sql-tunnels v0.1.5 h1:L/e9GQtqJlTVMauAE+ym/XUqhg+Va6RZQiOvBgbhspY= github.com/rudderlabs/sql-tunnels v0.1.5/go.mod h1:ZwQkCLb/5hHm5U90juAj9idkkFGv2R2dzDHJoPbKIto= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= diff --git a/testhelper/clone.go b/testhelper/clone.go new file mode 100644 index 0000000000..8c37b67529 --- /dev/null +++ b/testhelper/clone.go @@ -0,0 +1,20 @@ +package testhelper + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func Clone[T any](t testing.TB, v T) T { + t.Helper() + + buf, err := json.Marshal(v) + require.NoError(t, err) + + var clone T + require.NoError(t, json.Unmarshal(buf, &clone)) + + return clone +} diff --git 
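The new Clone helper round-trips a value through JSON to give tests an independent deep copy, and the integration tests in this patch lean on it to flip the per-destination merge flag without mutating the shared warehouse fixture. A minimal usage sketch, assuming the `warehouse` fixture, mock uploader, and `model.EnableMergeSetting` key that the deltalake tests below already set up:

    // Clone the shared fixture so the mutation below stays local to this subtest;
    // Clone works for any JSON-serializable value.
    mergeWarehouse := th.Clone(t, warehouse)
    // Opt this destination in to merge/dedup loading; other subtests keep the default (append).
    mergeWarehouse.Destination.Config[string(model.EnableMergeSetting)] = true

    d := deltalake.New(config.New(), logger.NOP, memstats.New())
    require.NoError(t, d.Setup(ctx, mergeWarehouse, mockUploader))
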
a/warehouse/integrations/bigquery/bigquery.go b/warehouse/integrations/bigquery/bigquery.go index 018ef7cf8b..d252f17bec 100644 --- a/warehouse/integrations/bigquery/bigquery.go +++ b/warehouse/integrations/bigquery/bigquery.go @@ -10,11 +10,8 @@ import ( "strings" "time" - "github.com/rudderlabs/rudder-server/warehouse/integrations/types" - - "github.com/samber/lo" - "cloud.google.com/go/bigquery" + "github.com/samber/lo" bqService "google.golang.org/api/bigquery/v2" "google.golang.org/api/googleapi" "google.golang.org/api/iterator" @@ -26,6 +23,7 @@ import ( "github.com/rudderlabs/rudder-server/utils/misc" "github.com/rudderlabs/rudder-server/warehouse/client" "github.com/rudderlabs/rudder-server/warehouse/integrations/bigquery/middleware" + "github.com/rudderlabs/rudder-server/warehouse/integrations/types" "github.com/rudderlabs/rudder-server/warehouse/internal/model" "github.com/rudderlabs/rudder-server/warehouse/logfield" warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" @@ -43,8 +41,7 @@ type BigQuery struct { config struct { setUsersLoadPartitionFirstEventFilter bool customPartitionsEnabled bool - isUsersTableDedupEnabled bool - isDedupEnabled bool + allowMerge bool enableDeleteByJobs bool customPartitionsEnabledWorkspaceIDs []string slowQueryThreshold time.Duration @@ -140,8 +137,7 @@ func New(conf *config.Config, log logger.Logger) *BigQuery { bq.config.setUsersLoadPartitionFirstEventFilter = conf.GetBool("Warehouse.bigquery.setUsersLoadPartitionFirstEventFilter", true) bq.config.customPartitionsEnabled = conf.GetBool("Warehouse.bigquery.customPartitionsEnabled", false) - bq.config.isUsersTableDedupEnabled = conf.GetBool("Warehouse.bigquery.isUsersTableDedupEnabled", false) - bq.config.isDedupEnabled = conf.GetBool("Warehouse.bigquery.isDedupEnabled", false) + bq.config.allowMerge = conf.GetBool("Warehouse.bigquery.allowMerge", true) bq.config.enableDeleteByJobs = conf.GetBool("Warehouse.bigquery.enableDeleteByJobs", false) bq.config.customPartitionsEnabledWorkspaceIDs = conf.GetStringSlice("Warehouse.bigquery.customPartitionsEnabledWorkspaceIDs", nil) bq.config.slowQueryThreshold = conf.GetDuration("Warehouse.bigquery.slowQueryThreshold", 5, time.Minute) @@ -163,6 +159,7 @@ func (bq *BigQuery) getMiddleware() *middleware.Client { logfield.DestinationType, bq.warehouse.Destination.DestinationDefinition.Name, logfield.WorkspaceID, bq.warehouse.WorkspaceID, logfield.Schema, bq.namespace, + logfield.ShouldMerge, bq.shouldMerge(), ), middleware.WithSlowQueryThreshold(bq.config.slowQueryThreshold), ) @@ -193,23 +190,21 @@ func (bq *BigQuery) CreateTable(ctx context.Context, tableName string, columnMap return fmt.Errorf("create table: %w", err) } - if !bq.dedupEnabled() { - if err = bq.createTableView(ctx, tableName, columnMap); err != nil { - return fmt.Errorf("create view: %w", err) - } + if err = bq.createTableView(ctx, tableName, columnMap); err != nil { + return fmt.Errorf("create view: %w", err) } + return nil } -func (bq *BigQuery) DropTable(ctx context.Context, tableName string) (err error) { - err = bq.DeleteTable(ctx, tableName) - if err != nil { - return +func (bq *BigQuery) DropTable(ctx context.Context, tableName string) error { + if err := bq.DeleteTable(ctx, tableName); err != nil { + return fmt.Errorf("cannot delete table %q: %w", tableName, err) } - if !bq.dedupEnabled() { - err = bq.DeleteTable(ctx, tableName+"_view") + if err := bq.DeleteTable(ctx, tableName+"_view"); err != nil { + return fmt.Errorf("cannot delete table %q: %w", tableName+"_view", 
err) } - return + return nil } func (bq *BigQuery) createTableView(ctx context.Context, tableName string, columnMap model.TableSchema) (err error) { @@ -225,9 +220,16 @@ func (bq *BigQuery) createTableView(ctx context.Context, tableName string, colum // assuming it has field named id upon which dedup is done in view viewQuery := `SELECT * EXCEPT (__row_number) FROM ( - SELECT *, ROW_NUMBER() OVER (PARTITION BY ` + partitionKey + viewOrderByStmt + `) AS __row_number FROM ` + "`" + bq.projectID + "." + bq.namespace + "." + tableName + "`" + ` WHERE _PARTITIONTIME BETWEEN TIMESTAMP_TRUNC(TIMESTAMP_MICROS(UNIX_MICROS(CURRENT_TIMESTAMP()) - 60 * 60 * 60 * 24 * 1000000), DAY, 'UTC') - AND TIMESTAMP_TRUNC(CURRENT_TIMESTAMP(), DAY, 'UTC') - ) + SELECT *, ROW_NUMBER() OVER (PARTITION BY ` + partitionKey + viewOrderByStmt + `) AS __row_number + FROM ` + "`" + bq.projectID + "." + bq.namespace + "." + tableName + "`" + ` + WHERE + _PARTITIONTIME BETWEEN TIMESTAMP_TRUNC( + TIMESTAMP_MICROS(UNIX_MICROS(CURRENT_TIMESTAMP()) - 60 * 60 * 60 * 24 * 1000000), + DAY, + 'UTC' + ) + AND TIMESTAMP_TRUNC(CURRENT_TIMESTAMP(), DAY, 'UTC') + ) WHERE __row_number = 1` metaData := &bigquery.TableMetadata{ ViewQuery: viewQuery, @@ -241,7 +243,8 @@ func (bq *BigQuery) schemaExists(ctx context.Context, _, _ string) (exists bool, ds := bq.db.Dataset(bq.namespace) _, err = ds.Metadata(ctx) if err != nil { - if e, ok := err.(*googleapi.Error); ok && e.Code == 404 { + var e *googleapi.Error + if errors.As(err, &e) && e.Code == 404 { bq.logger.Debugf("BQ: Dataset %s not found", bq.namespace) return false, nil } @@ -275,7 +278,8 @@ func (bq *BigQuery) CreateSchema(ctx context.Context) (err error) { bq.logger.Infof("BQ: Creating schema: %s ...", bq.namespace) err = ds.Create(ctx, meta) if err != nil { - if e, ok := err.(*googleapi.Error); ok && e.Code == 409 { + var e *googleapi.Error + if errors.As(err, &e) && e.Code == 409 { bq.logger.Infof("BQ: Create schema %s failed as schema already exists", bq.namespace) return nil } @@ -285,7 +289,8 @@ func (bq *BigQuery) CreateSchema(ctx context.Context) (err error) { func checkAndIgnoreAlreadyExistError(err error) bool { if err != nil { - if e, ok := err.(*googleapi.Error); ok { + var e *googleapi.Error + if errors.As(err, &e) { // 409 is returned when we try to create a table that already exists // 400 is returned for all kinds of invalid input - so we need to check the error message too if e.Code == 409 || (e.Code == 400 && strings.Contains(e.Message, "already exists in schema")) { @@ -384,14 +389,14 @@ func (bq *BigQuery) loadTable( gcsRef.MaxBadRecords = 0 gcsRef.IgnoreUnknownValues = false - if bq.dedupEnabled() { + if bq.shouldMerge() { return bq.loadTableByMerge(ctx, tableName, gcsRef, log, skipTempTableDelete) } return bq.loadTableByAppend(ctx, tableName, gcsRef, log) } func (bq *BigQuery) loadTableStrategy() string { - if bq.dedupEnabled() { + if bq.shouldMerge() { return "MERGE" } return "APPEND" @@ -599,8 +604,7 @@ func (bq *BigQuery) loadTableByMerge( SET %[6]s WHEN NOT MATCHED THEN INSERT (%[4]s) VALUES - (%[5]s); -`, + (%[5]s);`, bqTable(tableName), bqTable(stagingTableName), primaryJoinClause, @@ -646,7 +650,7 @@ func (bq *BigQuery) loadTableByMerge( func (bq *BigQuery) LoadUserTables(ctx context.Context) (errorMap map[string]error) { errorMap = map[string]error{warehouseutils.IdentifiesTable: nil} - bq.logger.Infof("BQ: Starting load for identifies and users tables\n") + bq.logger.Infof("BQ: Starting load for identifies and users tables") _, identifyLoadTable, err := 
bq.loadTable(ctx, warehouseutils.IdentifiesTable, true) if err != nil { errorMap[warehouseutils.IdentifiesTable] = err @@ -704,10 +708,12 @@ func (bq *BigQuery) LoadUserTables(ctx context.Context) (errorMap map[string]err bqIdentifiesTable := bqTable(warehouseutils.IdentifiesTable) partition := fmt.Sprintf("TIMESTAMP('%s')", identifyLoadTable.partitionDate) var identifiesFrom string - if bq.dedupEnabled() { - identifiesFrom = fmt.Sprintf(`%s WHERE user_id IS NOT NULL %s`, bqTable(identifyLoadTable.stagingTableName), loadedAtFilter()) + if bq.shouldMerge() { + identifiesFrom = fmt.Sprintf(`%s WHERE user_id IS NOT NULL %s`, + bqTable(identifyLoadTable.stagingTableName), loadedAtFilter()) } else { - identifiesFrom = fmt.Sprintf(`%s WHERE _PARTITIONTIME = %s AND user_id IS NOT NULL %s`, bqIdentifiesTable, partition, loadedAtFilter()) + identifiesFrom = fmt.Sprintf(`%s WHERE _PARTITIONTIME = %s AND user_id IS NOT NULL %s`, + bqIdentifiesTable, partition, loadedAtFilter()) } sqlStatement := fmt.Sprintf(`SELECT DISTINCT * FROM ( SELECT id, %[1]s FROM ( @@ -824,7 +830,7 @@ func (bq *BigQuery) LoadUserTables(ctx context.Context) (errorMap map[string]err } } - if !bq.dedupEnabled() { + if !bq.shouldMerge() { loadUserTableByAppend() return } @@ -847,28 +853,24 @@ func Connect(context context.Context, cred *BQCredentials) (*bigquery.Client, er } opts = append(opts, option.WithCredentialsJSON(credBytes)) } - client, err := bigquery.NewClient(context, cred.ProjectID, opts...) - return client, err + c, err := bigquery.NewClient(context, cred.ProjectID, opts...) + return c, err } func (bq *BigQuery) connect(ctx context.Context, cred BQCredentials) (*bigquery.Client, error) { bq.logger.Infof("BQ: Connecting to BigQuery in project: %s", cred.ProjectID) - client, err := Connect(ctx, &cred) - return client, err + c, err := Connect(ctx, &cred) + return c, err } -func (bq *BigQuery) dedupEnabled() bool { - return bq.config.isDedupEnabled || bq.config.isUsersTableDedupEnabled +// shouldMerge returns true if: +// * the server config says we allow merging +// * the user opted in to merging +func (bq *BigQuery) shouldMerge() bool { + return bq.config.allowMerge && bq.warehouse.GetBoolDestinationConfig(model.EnableMergeSetting) } func (bq *BigQuery) CrashRecover(ctx context.Context) { - if !bq.dedupEnabled() { - return - } - bq.dropDanglingStagingTables(ctx) -} - -func (bq *BigQuery) dropDanglingStagingTables(ctx context.Context) bool { sqlStatement := fmt.Sprintf(` SELECT table_name @@ -876,8 +878,7 @@ func (bq *BigQuery) dropDanglingStagingTables(ctx context.Context) bool { %[1]s.INFORMATION_SCHEMA.TABLES WHERE table_schema = '%[1]s' - AND table_name LIKE '%[2]s'; - `, + AND table_name LIKE '%[2]s';`, bq.namespace, fmt.Sprintf(`%s%%`, warehouseutils.StagingTablePrefix(provider)), ) @@ -885,7 +886,7 @@ func (bq *BigQuery) dropDanglingStagingTables(ctx context.Context) bool { it, err := bq.getMiddleware().Read(ctx, query) if err != nil { bq.logger.Errorf("WH: BQ: Error dropping dangling staging tables in BQ: %v\nQuery: %s\n", err, sqlStatement) - return false + return } var stagingTableNames []string @@ -897,22 +898,20 @@ func (bq *BigQuery) dropDanglingStagingTables(ctx context.Context) bool { break } bq.logger.Errorf("BQ: Error in processing fetched staging tables from information schema in dataset %v : %v", bq.namespace, err) - return false + return } if _, ok := values[0].(string); ok { stagingTableNames = append(stagingTableNames, values[0].(string)) } } bq.logger.Infof("WH: PG: Dropping dangling staging 
tables: %+v %+v\n", len(stagingTableNames), stagingTableNames) - delSuccess := true for _, stagingTableName := range stagingTableNames { err := bq.DeleteTable(ctx, stagingTableName) if err != nil { bq.logger.Errorf("WH: BQ: Error dropping dangling staging table: %s in BQ: %v", stagingTableName, err) - delSuccess = false + return } } - return delSuccess } func (bq *BigQuery) IsEmpty( @@ -1040,7 +1039,8 @@ func (bq *BigQuery) FetchSchema(ctx context.Context) (model.Schema, model.Schema it, err := bq.getMiddleware().Read(ctx, query) if err != nil { - if e, ok := err.(*googleapi.Error); ok && e.Code == 404 { + var e *googleapi.Error + if errors.As(err, &e) && e.Code == 404 { // if dataset resource is not found, return empty schema return schema, unrecognizedSchema, nil } @@ -1109,7 +1109,8 @@ func (bq *BigQuery) tableExists(ctx context.Context, tableName string) (exists b if err == nil { return true, nil } - if e, ok := err.(*googleapi.Error); ok { + var e *googleapi.Error + if errors.As(err, &e) { if e.Code == 404 { return false, nil } diff --git a/warehouse/integrations/bigquery/bigquery_test.go b/warehouse/integrations/bigquery/bigquery_test.go index c484c16a85..9508076d63 100644 --- a/warehouse/integrations/bigquery/bigquery_test.go +++ b/warehouse/integrations/bigquery/bigquery_test.go @@ -11,32 +11,30 @@ import ( "testing" "time" - "github.com/golang/mock/gomock" - - "github.com/rudderlabs/rudder-go-kit/config" - "github.com/rudderlabs/rudder-go-kit/filemanager" - "github.com/rudderlabs/rudder-go-kit/logger" - mockuploader "github.com/rudderlabs/rudder-server/warehouse/internal/mocks/utils" - "github.com/rudderlabs/rudder-server/warehouse/internal/model" - "cloud.google.com/go/bigquery" - "google.golang.org/api/option" - + "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/api/option" "github.com/rudderlabs/compose-test/compose" "github.com/rudderlabs/compose-test/testcompose" + "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/filemanager" + "github.com/rudderlabs/rudder-go-kit/logger" kithelper "github.com/rudderlabs/rudder-go-kit/testhelper" backendconfig "github.com/rudderlabs/rudder-server/backend-config" "github.com/rudderlabs/rudder-server/runner" + th "github.com/rudderlabs/rudder-server/testhelper" "github.com/rudderlabs/rudder-server/testhelper/health" "github.com/rudderlabs/rudder-server/testhelper/workspaceConfig" "github.com/rudderlabs/rudder-server/utils/misc" "github.com/rudderlabs/rudder-server/warehouse/client" whbigquery "github.com/rudderlabs/rudder-server/warehouse/integrations/bigquery" bqHelper "github.com/rudderlabs/rudder-server/warehouse/integrations/bigquery/testhelper" - "github.com/rudderlabs/rudder-server/warehouse/integrations/testhelper" + whth "github.com/rudderlabs/rudder-server/warehouse/integrations/testhelper" + mockuploader "github.com/rudderlabs/rudder-server/warehouse/internal/mocks/utils" + "github.com/rudderlabs/rudder-server/warehouse/internal/model" warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" "github.com/rudderlabs/rudder-server/warehouse/validations" ) @@ -68,11 +66,9 @@ func TestIntegration(t *testing.T) { sourcesSourceID := warehouseutils.RandHex() sourcesDestinationID := warehouseutils.RandHex() sourcesWriteKey := warehouseutils.RandHex() - destType := warehouseutils.BQ - - namespace := testhelper.RandSchema(destType) - sourcesNamespace := testhelper.RandSchema(destType) + namespace := 
whth.RandSchema(destType) + sourcesNamespace := whth.RandSchema(destType) bqTestCredentials, err := bqHelper.GetBQTestCredentials() require.NoError(t, err) @@ -82,71 +78,63 @@ func TestIntegration(t *testing.T) { escapedCredentialsTrimmedStr := strings.Trim(string(escapedCredentials), `"`) - templateConfigurations := map[string]any{ - "workspaceID": workspaceID, - "sourceID": sourceID, - "destinationID": destinationID, - "writeKey": writeKey, - "sourcesSourceID": sourcesSourceID, - "sourcesDestinationID": sourcesDestinationID, - "sourcesWriteKey": sourcesWriteKey, - "namespace": namespace, - "project": bqTestCredentials.ProjectID, - "location": bqTestCredentials.Location, - "bucketName": bqTestCredentials.BucketName, - "credentials": escapedCredentialsTrimmedStr, - "sourcesNamespace": sourcesNamespace, - } - workspaceConfigPath := workspaceConfig.CreateTempFile(t, "testdata/template.json", templateConfigurations) - - testhelper.EnhanceWithDefaultEnvs(t) - t.Setenv("JOBS_DB_PORT", strconv.Itoa(jobsDBPort)) - t.Setenv("WAREHOUSE_JOBS_DB_PORT", strconv.Itoa(jobsDBPort)) - t.Setenv("RSERVER_WAREHOUSE_BIGQUERY_MAX_PARALLEL_LOADS", "8") - t.Setenv("RSERVER_WAREHOUSE_BIGQUERY_ENABLE_DELETE_BY_JOBS", "true") - t.Setenv("RSERVER_WAREHOUSE_WEB_PORT", strconv.Itoa(httpPort)) - t.Setenv("RSERVER_BACKEND_CONFIG_CONFIG_JSONPATH", workspaceConfigPath) - t.Setenv("RSERVER_WAREHOUSE_BIGQUERY_SLOW_QUERY_THRESHOLD", "0s") - - svcDone := make(chan struct{}) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go func() { - r := runner.New(runner.ReleaseInfo{}) - _ = r.Run(ctx, []string{"bigquery-integration-test"}) - - close(svcDone) - }() - t.Cleanup(func() { <-svcDone }) + bootstrapSvc := func(t *testing.T, enableMerge bool) *bigquery.Client { + templateConfigurations := map[string]any{ + "workspaceID": workspaceID, + "sourceID": sourceID, + "destinationID": destinationID, + "writeKey": writeKey, + "sourcesSourceID": sourcesSourceID, + "sourcesDestinationID": sourcesDestinationID, + "sourcesWriteKey": sourcesWriteKey, + "namespace": namespace, + "project": bqTestCredentials.ProjectID, + "location": bqTestCredentials.Location, + "bucketName": bqTestCredentials.BucketName, + "credentials": escapedCredentialsTrimmedStr, + "sourcesNamespace": sourcesNamespace, + "enableMerge": enableMerge, + } + workspaceConfigPath := workspaceConfig.CreateTempFile(t, "testdata/template.json", templateConfigurations) + + whth.EnhanceWithDefaultEnvs(t) + t.Setenv("JOBS_DB_PORT", strconv.Itoa(jobsDBPort)) + t.Setenv("WAREHOUSE_JOBS_DB_PORT", strconv.Itoa(jobsDBPort)) + t.Setenv("RSERVER_WAREHOUSE_BIGQUERY_MAX_PARALLEL_LOADS", "8") + t.Setenv("RSERVER_WAREHOUSE_BIGQUERY_ENABLE_DELETE_BY_JOBS", "true") + t.Setenv("RSERVER_WAREHOUSE_WEB_PORT", strconv.Itoa(httpPort)) + t.Setenv("RSERVER_BACKEND_CONFIG_CONFIG_JSONPATH", workspaceConfigPath) + t.Setenv("RSERVER_WAREHOUSE_BIGQUERY_SLOW_QUERY_THRESHOLD", "0s") + + svcDone := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + r := runner.New(runner.ReleaseInfo{}) + _ = r.Run(ctx, []string{"bigquery-integration-test"}) + close(svcDone) + }() + + t.Cleanup(func() { <-svcDone }) + t.Cleanup(cancel) + + serviceHealthEndpoint := fmt.Sprintf("http://localhost:%d/health", httpPort) + health.WaitUntilReady(ctx, t, + serviceHealthEndpoint, time.Minute, time.Second, "serviceHealthEndpoint", + ) - serviceHealthEndpoint := fmt.Sprintf("http://localhost:%d/health", httpPort) - health.WaitUntilReady(ctx, t, serviceHealthEndpoint, 
time.Minute, time.Second, "serviceHealthEndpoint") + db, err := bigquery.NewClient(ctx, + bqTestCredentials.ProjectID, + option.WithCredentialsJSON([]byte(bqTestCredentials.Credentials)), + ) + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) - db, err := bigquery.NewClient( - ctx, - bqTestCredentials.ProjectID, option.WithCredentialsJSON([]byte(bqTestCredentials.Credentials)), - ) - require.NoError(t, err) + return db + } t.Run("Event flow", func(t *testing.T) { - jobsDB := testhelper.JobsDB(t, jobsDBPort) - - t.Cleanup(func() { - for _, dataset := range []string{namespace, sourcesNamespace} { - require.Eventually(t, func() bool { - if err := db.Dataset(dataset).DeleteWithContents(ctx); err != nil { - t.Logf("error deleting dataset: %v", err) - return false - } - return true - }, - time.Minute, - time.Second, - ) - } - }) + jobsDB := whth.JobsDB(t, jobsDBPort) testcase := []struct { name string @@ -155,15 +143,15 @@ func TestIntegration(t *testing.T) { sourceID string destinationID string tables []string - stagingFilesEventsMap testhelper.EventsCountMap - stagingFilesModifiedEventsMap testhelper.EventsCountMap - loadFilesEventsMap testhelper.EventsCountMap - tableUploadsEventsMap testhelper.EventsCountMap - warehouseEventsMap testhelper.EventsCountMap + stagingFilesEventsMap whth.EventsCountMap + stagingFilesModifiedEventsMap whth.EventsCountMap + loadFilesEventsMap whth.EventsCountMap + tableUploadsEventsMap whth.EventsCountMap + warehouseEventsMap whth.EventsCountMap asyncJob bool skipModifiedEvents bool - prerequisite func(t testing.TB) - isDedupEnabled bool + prerequisite func(context.Context, testing.TB, *bigquery.Client) + enableMerge bool customPartitionsEnabledWorkspaceIDs string stagingFilePrefix string }{ @@ -181,10 +169,9 @@ func TestIntegration(t *testing.T) { loadFilesEventsMap: loadFilesEventsMap(), tableUploadsEventsMap: tableUploadsEventsMap(), warehouseEventsMap: mergeEventsMap(), - isDedupEnabled: true, - prerequisite: func(t testing.TB) { + enableMerge: true, + prerequisite: func(ctx context.Context, t testing.TB, db *bigquery.Client) { t.Helper() - _ = db.Dataset(namespace).DeleteWithContents(ctx) }, stagingFilePrefix: "testdata/upload-job-merge-mode", @@ -196,20 +183,19 @@ func TestIntegration(t *testing.T) { destinationID: sourcesDestinationID, schema: sourcesNamespace, tables: []string{"tracks", "google_sheet"}, - stagingFilesEventsMap: testhelper.EventsCountMap{ + stagingFilesEventsMap: whth.EventsCountMap{ "wh_staging_files": 9, // 8 + 1 (merge events because of ID resolution) }, - stagingFilesModifiedEventsMap: testhelper.EventsCountMap{ + stagingFilesModifiedEventsMap: whth.EventsCountMap{ "wh_staging_files": 8, // 8 (de-duped by encounteredMergeRuleMap) }, - loadFilesEventsMap: testhelper.SourcesLoadFilesEventsMap(), - tableUploadsEventsMap: testhelper.SourcesTableUploadsEventsMap(), - warehouseEventsMap: testhelper.SourcesWarehouseEventsMap(), + loadFilesEventsMap: whth.SourcesLoadFilesEventsMap(), + tableUploadsEventsMap: whth.SourcesTableUploadsEventsMap(), + warehouseEventsMap: whth.SourcesWarehouseEventsMap(), asyncJob: true, - isDedupEnabled: false, - prerequisite: func(t testing.TB) { + enableMerge: false, + prerequisite: func(ctx context.Context, t testing.TB, db *bigquery.Client) { t.Helper() - _ = db.Dataset(namespace).DeleteWithContents(ctx) }, stagingFilePrefix: "testdata/sources-job", @@ -229,10 +215,9 @@ func TestIntegration(t *testing.T) { tableUploadsEventsMap: tableUploadsEventsMap(), warehouseEventsMap: appendEventsMap(), 
skipModifiedEvents: true, - isDedupEnabled: false, - prerequisite: func(t testing.TB) { + enableMerge: false, + prerequisite: func(ctx context.Context, t testing.TB, db *bigquery.Client) { t.Helper() - _ = db.Dataset(namespace).DeleteWithContents(ctx) }, stagingFilePrefix: "testdata/upload-job-append-mode", @@ -252,9 +237,9 @@ func TestIntegration(t *testing.T) { tableUploadsEventsMap: tableUploadsEventsMap(), warehouseEventsMap: appendEventsMap(), skipModifiedEvents: true, - isDedupEnabled: false, + enableMerge: false, customPartitionsEnabledWorkspaceIDs: workspaceID, - prerequisite: func(t testing.TB) { + prerequisite: func(ctx context.Context, t testing.TB, db *bigquery.Client) { t.Helper() _ = db.Dataset(namespace).DeleteWithContents(ctx) @@ -284,13 +269,33 @@ func TestIntegration(t *testing.T) { for _, tc := range testcase { tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Setenv("RSERVER_WAREHOUSE_BIGQUERY_IS_DEDUP_ENABLED", strconv.FormatBool(tc.isDedupEnabled)) - t.Setenv("RSERVER_WAREHOUSE_BIGQUERY_CUSTOM_PARTITIONS_ENABLED_WORKSPACE_IDS", tc.customPartitionsEnabledWorkspaceIDs) + t.Setenv( + "RSERVER_WAREHOUSE_BIGQUERY_CUSTOM_PARTITIONS_ENABLED_WORKSPACE_IDS", + tc.customPartitionsEnabledWorkspaceIDs, + ) + db := bootstrapSvc(t, tc.enableMerge) + + t.Cleanup(func() { + for _, dataset := range []string{tc.schema} { + t.Logf("Cleaning up dataset %s.%s", tc.schema, dataset) + require.Eventually(t, + func() bool { + err := db.Dataset(dataset).DeleteWithContents(context.Background()) + if err != nil { + t.Logf("Error deleting dataset %s.%s: %v", tc.schema, dataset, err) + return false + } + return true + }, + time.Minute, + time.Second, + ) + } + }) if tc.prerequisite != nil { - tc.prerequisite(t) + tc.prerequisite(context.Background(), t, db) } sqlClient := &client.Client{ @@ -304,7 +309,7 @@ func TestIntegration(t *testing.T) { } t.Log("verifying test case 1") - ts1 := testhelper.TestConfig{ + ts1 := whth.TestConfig{ WriteKey: tc.writeKey, Schema: tc.schema, Tables: tc.tables, @@ -323,7 +328,7 @@ func TestIntegration(t *testing.T) { JobRunID: misc.FastUUID().String(), TaskRunID: misc.FastUUID().String(), StagingFilePath: tc.stagingFilePrefix + ".staging-1.json", - UserID: testhelper.GetUserId(destType), + UserID: whth.GetUserId(destType), } ts1.VerifyEvents(t) @@ -332,7 +337,7 @@ func TestIntegration(t *testing.T) { } t.Log("verifying test case 2") - ts2 := testhelper.TestConfig{ + ts2 := whth.TestConfig{ WriteKey: tc.writeKey, Schema: tc.schema, Tables: tc.tables, @@ -352,7 +357,7 @@ func TestIntegration(t *testing.T) { JobRunID: misc.FastUUID().String(), TaskRunID: misc.FastUUID().String(), StagingFilePath: tc.stagingFilePrefix + ".staging-2.json", - UserID: testhelper.GetUserId(destType), + UserID: whth.GetUserId(destType), } if tc.asyncJob { ts2.UserID = ts1.UserID @@ -363,14 +368,22 @@ func TestIntegration(t *testing.T) { }) t.Run("Validations", func(t *testing.T) { + ctx := context.Background() + db, err := bigquery.NewClient(ctx, + bqTestCredentials.ProjectID, + option.WithCredentialsJSON([]byte(bqTestCredentials.Credentials)), + ) + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) t.Cleanup(func() { - require.Eventually(t, func() bool { - if err := db.Dataset(namespace).DeleteWithContents(ctx); err != nil { - t.Logf("error deleting dataset: %v", err) - return false - } - return true - }, + require.Eventually(t, + func() bool { + if err := db.Dataset(namespace).DeleteWithContents(ctx); err != nil { + t.Logf("error deleting dataset: %v", err) + return false + } 
+ return true + }, time.Minute, time.Second, ) @@ -396,7 +409,7 @@ func TestIntegration(t *testing.T) { Enabled: true, RevisionID: destinationID, } - testhelper.VerifyConfigurationTest(t, dest) + whth.VerifyConfigurationTest(t, dest) }) t.Run("Load Table", func(t *testing.T) { @@ -406,8 +419,15 @@ func TestIntegration(t *testing.T) { workspaceID = "test_workspace_id" ) - namespace := testhelper.RandSchema(destType) + namespace := whth.RandSchema(destType) + ctx := context.Background() + db, err := bigquery.NewClient(ctx, + bqTestCredentials.ProjectID, + option.WithCredentialsJSON([]byte(bqTestCredentials.Credentials)), + ) + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) t.Cleanup(func() { require.Eventually(t, func() bool { if err := db.Dataset(namespace).DeleteWithContents(ctx); err != nil { @@ -480,10 +500,10 @@ func TestIntegration(t *testing.T) { }) require.NoError(t, err) - t.Run("schema does not exists", func(t *testing.T) { + t.Run("schema does not exist", func(t *testing.T) { tableName := "schema_not_exists_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.json.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.json.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) @@ -496,10 +516,10 @@ func TestIntegration(t *testing.T) { require.Error(t, err) require.Nil(t, loadTableStat) }) - t.Run("table does not exists", func(t *testing.T) { + t.Run("table does not exist", func(t *testing.T) { tableName := "table_not_exists_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.json.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.json.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) @@ -515,40 +535,39 @@ func TestIntegration(t *testing.T) { require.Error(t, err) require.Nil(t, loadTableStat) }) - t.Run("merge", func(t *testing.T) { - tableName := "merge_test_table" - - c := config.New() - c.Set("Warehouse.bigquery.isDedupEnabled", true) + t.Run("merge with dedup", func(t *testing.T) { + tableName := "merge_with_dedup_test_table" + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.json.gz", tableName) - t.Run("without dedup", func(t *testing.T) { - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.json.gz", tableName) + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) - loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} - mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + dedupWarehouse := th.Clone(t, warehouse) + dedupWarehouse.Destination.Config[string(model.EnableMergeSetting)] = true - bq := whbigquery.New(c, logger.NOP) - err := bq.Setup(ctx, warehouse, mockUploader) - require.NoError(t, err) + c := config.New() + bq := whbigquery.New(c, logger.NOP) + err := bq.Setup(ctx, dedupWarehouse, mockUploader) + require.NoError(t, err) - err = bq.CreateSchema(ctx) - require.NoError(t, err) + err = bq.CreateSchema(ctx) + require.NoError(t, err) - err = bq.CreateTable(ctx, tableName, schemaInWarehouse) - require.NoError(t, err) + err = bq.CreateTable(ctx, tableName, schemaInWarehouse) + 
require.NoError(t, err) - loadTableStat, err := bq.LoadTable(ctx, tableName) - require.NoError(t, err) - require.Equal(t, loadTableStat.RowsInserted, int64(14)) - require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + loadTableStat, err := bq.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - loadTableStat, err = bq.LoadTable(ctx, tableName) - require.NoError(t, err) - require.Equal(t, loadTableStat.RowsInserted, int64(0)) - require.Equal(t, loadTableStat.RowsUpdated, int64(14)) + loadTableStat, err = bq.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(0)) + require.Equal(t, loadTableStat.RowsUpdated, int64(14)) - records := bqHelper.RetrieveRecordsFromWarehouse(t, db, - fmt.Sprintf(` + records := bqHelper.RetrieveRecordsFromWarehouse(t, db, + fmt.Sprintf(` SELECT id, received_at, @@ -557,62 +576,64 @@ func TestIntegration(t *testing.T) { test_float, test_int, test_string - FROM - %s - ORDER BY - id; - `, - fmt.Sprintf("`%s`.`%s`", namespace, tableName), - ), - ) - require.Equal(t, records, testhelper.SampleTestRecords()) - }) - t.Run("with dedup", func(t *testing.T) { - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/dedup.json.gz", tableName) + FROM %s + ORDER BY id;`, + fmt.Sprintf("`%s`.`%s`", namespace, tableName), + ), + ) + require.Equal(t, records, whth.SampleTestRecords()) + }) + t.Run("merge without dedup", func(t *testing.T) { + tableName := "merge_without_dedup_test_table" + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/dedup.json.gz", tableName) - loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} - mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) - bq := whbigquery.New(c, logger.NOP) - err := bq.Setup(ctx, warehouse, mockUploader) - require.NoError(t, err) + c := config.New() + bq := whbigquery.New(c, logger.NOP) + err := bq.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) - err = bq.CreateSchema(ctx) - require.NoError(t, err) + err = bq.CreateSchema(ctx) + require.NoError(t, err) - err = bq.CreateTable(ctx, tableName, schemaInWarehouse) - require.NoError(t, err) + err = bq.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) - loadTableStat, err := bq.LoadTable(ctx, tableName) - require.NoError(t, err) - require.Equal(t, loadTableStat.RowsInserted, int64(0)) - require.Equal(t, loadTableStat.RowsUpdated, int64(14)) + loadTableStat, err := bq.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := bqHelper.RetrieveRecordsFromWarehouse(t, db, - fmt.Sprintf(` - SELECT - id, - received_at, - test_bool, - test_datetime, - test_float, - test_int, - test_string - FROM - %s - ORDER BY - id; - `, - fmt.Sprintf("`%s`.`%s`", namespace, tableName), - ), - ) - require.Equal(t, records, testhelper.DedupTestRecords()) - }) + retrieveRecordsSQL := fmt.Sprintf( + `SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM %s + ORDER BY id;`, + fmt.Sprintf("`%s`.`%s`", namespace, tableName), + ) + records := bqHelper.RetrieveRecordsFromWarehouse(t, db, 
retrieveRecordsSQL) + require.Equal(t, records, whth.DedupTestRecords()) + + loadTableStat, err = bq.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + records = bqHelper.RetrieveRecordsFromWarehouse(t, db, retrieveRecordsSQL) + require.Len(t, records, 28) }) t.Run("append", func(t *testing.T) { tableName := "append_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.json.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.json.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) @@ -647,20 +668,16 @@ func TestIntegration(t *testing.T) { test_float, test_int, test_string - FROM - %s.%s - WHERE - _PARTITIONTIME BETWEEN TIMESTAMP('%s') AND TIMESTAMP('%s') - ORDER BY - id; - `, + FROM %s.%s + WHERE _PARTITIONTIME BETWEEN TIMESTAMP('%s') AND TIMESTAMP('%s') + ORDER BY id;`, namespace, tableName, time.Now().Add(-24*time.Hour).Format("2006-01-02"), time.Now().Add(+24*time.Hour).Format("2006-01-02"), ), ) - require.Equal(t, records, testhelper.AppendTestRecords()) + require.Equal(t, records, whth.AppendTestRecords()) }) t.Run("load file does not exists", func(t *testing.T) { tableName := "load_file_not_exists_test_table" @@ -687,7 +704,7 @@ func TestIntegration(t *testing.T) { t.Run("mismatch in number of columns", func(t *testing.T) { tableName := "mismatch_columns_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/mismatch-columns.json.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/mismatch-columns.json.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) @@ -709,7 +726,7 @@ func TestIntegration(t *testing.T) { t.Run("mismatch in schema", func(t *testing.T) { tableName := "mismatch_schema_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/mismatch-schema.json.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/mismatch-schema.json.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) @@ -731,7 +748,7 @@ func TestIntegration(t *testing.T) { t.Run("discards", func(t *testing.T) { tableName := warehouseutils.DiscardsTable - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/discards.json.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/discards.json.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, warehouseutils.DiscardsSchema, warehouseutils.DiscardsSchema) @@ -752,27 +769,25 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsUpdated, int64(0)) records := bqHelper.RetrieveRecordsFromWarehouse(t, db, - fmt.Sprintf(` - SELECT - column_name, - column_value, - received_at, - row_id, - table_name, - uuid_ts - FROM - %s - ORDER BY row_id ASC; - `, + fmt.Sprintf( + `SELECT + column_name, + column_value, + received_at, + row_id, + table_name, + uuid_ts + FROM %s + ORDER BY row_id ASC;`, fmt.Sprintf("`%s`.`%s`", namespace, tableName), ), ) - require.Equal(t, records, testhelper.DiscardTestRecords()) + 
require.Equal(t, records, whth.DiscardTestRecords()) }) t.Run("custom partition", func(t *testing.T) { tableName := "partition_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.json.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.json.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader( @@ -800,50 +815,52 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsUpdated, int64(0)) records := bqHelper.RetrieveRecordsFromWarehouse(t, db, - fmt.Sprintf(` - SELECT - id, - received_at, - test_bool, - test_datetime, - test_float, - test_int, - test_string - FROM - %s.%s - WHERE - _PARTITIONTIME BETWEEN TIMESTAMP('%s') AND TIMESTAMP('%s') - ORDER BY - id; - `, + fmt.Sprintf( + `SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM %s.%s + WHERE _PARTITIONTIME BETWEEN TIMESTAMP('%s') AND TIMESTAMP('%s') + ORDER BY id;`, namespace, tableName, time.Now().Add(-24*time.Hour).Format("2006-01-02"), time.Now().Add(+24*time.Hour).Format("2006-01-02"), ), ) - require.Equal(t, records, testhelper.SampleTestRecords()) + require.Equal(t, records, whth.SampleTestRecords()) }) }) t.Run("IsEmpty", func(t *testing.T) { - namespace := testhelper.RandSchema(warehouseutils.BQ) + ctx := context.Background() + db, err := bigquery.NewClient(ctx, + bqTestCredentials.ProjectID, + option.WithCredentialsJSON([]byte(bqTestCredentials.Credentials)), + ) + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + namespace := whth.RandSchema(warehouseutils.BQ) t.Cleanup(func() { - require.Eventually(t, func() bool { - if err := db.Dataset(namespace).DeleteWithContents(ctx); err != nil { - t.Logf("error deleting dataset: %v", err) - return false - } - return true - }, + require.Eventually(t, + func() bool { + if err := db.Dataset(namespace).DeleteWithContents(ctx); err != nil { + t.Logf("error deleting dataset: %v", err) + return false + } + return true + }, time.Minute, time.Second, ) }) - ctx := context.Background() - credentials, err := bqHelper.GetBQTestCredentials() require.NoError(t, err) @@ -969,8 +986,8 @@ func newMockUploader( return mockUploader } -func loadFilesEventsMap() testhelper.EventsCountMap { - return testhelper.EventsCountMap{ +func loadFilesEventsMap() whth.EventsCountMap { + return whth.EventsCountMap{ "identifies": 4, "users": 4, "tracks": 4, @@ -983,8 +1000,8 @@ func loadFilesEventsMap() testhelper.EventsCountMap { } } -func tableUploadsEventsMap() testhelper.EventsCountMap { - return testhelper.EventsCountMap{ +func tableUploadsEventsMap() whth.EventsCountMap { + return whth.EventsCountMap{ "identifies": 4, "users": 4, "tracks": 4, @@ -997,14 +1014,14 @@ func tableUploadsEventsMap() testhelper.EventsCountMap { } } -func stagingFilesEventsMap() testhelper.EventsCountMap { - return testhelper.EventsCountMap{ +func stagingFilesEventsMap() whth.EventsCountMap { + return whth.EventsCountMap{ "wh_staging_files": 34, // Since extra 2 merge events because of ID resolution } } -func mergeEventsMap() testhelper.EventsCountMap { - return testhelper.EventsCountMap{ +func mergeEventsMap() whth.EventsCountMap { + return whth.EventsCountMap{ "identifies": 1, "users": 1, "tracks": 1, @@ -1017,8 +1034,8 @@ func mergeEventsMap() testhelper.EventsCountMap { } } -func appendEventsMap() testhelper.EventsCountMap { - return testhelper.EventsCountMap{ +func appendEventsMap() whth.EventsCountMap { + return 
whth.EventsCountMap{ "identifies": 4, "users": 1, "tracks": 4, diff --git a/warehouse/integrations/bigquery/testdata/template.json b/warehouse/integrations/bigquery/testdata/template.json index d865510ba7..c6eac2a526 100644 --- a/warehouse/integrations/bigquery/testdata/template.json +++ b/warehouse/integrations/bigquery/testdata/template.json @@ -31,7 +31,8 @@ "credentials": "{{.credentials}}", "prefix": "", "namespace": "{{.namespace}}", - "syncFrequency": "30" + "syncFrequency": "30", + "enableMerge": {{.enableMerge}} }, "liveEventsConfig": {}, "secretConfig": {}, @@ -98,7 +99,7 @@ "id": "1dCzCUAtpWDzNxgGUYzq9sZdZZB", "name": "HTTP", "displayName": "HTTP", - "category": "", + "category": "singer-protocol", "createdAt": "2020-06-12T06:35:35.962Z", "updatedAt": "2020-06-12T06:35:35.962Z" }, @@ -159,7 +160,8 @@ "credentials": "{{.credentials}}", "prefix": "", "namespace": "{{.sourcesNamespace}}", - "syncFrequency": "30" + "syncFrequency": "30", + "enableMerge": {{.enableMerge}} }, "liveEventsConfig": {}, "secretConfig": {}, diff --git a/warehouse/integrations/deltalake/deltalake.go b/warehouse/integrations/deltalake/deltalake.go index ccee4c253d..eb6b02e122 100644 --- a/warehouse/integrations/deltalake/deltalake.go +++ b/warehouse/integrations/deltalake/deltalake.go @@ -11,8 +11,6 @@ import ( "strings" "time" - "github.com/rudderlabs/rudder-server/warehouse/integrations/types" - dbsql "github.com/databricks/databricks-sql-go" dbsqllog "github.com/databricks/databricks-sql-go/logger" @@ -22,6 +20,7 @@ import ( "github.com/rudderlabs/rudder-server/utils/misc" warehouseclient "github.com/rudderlabs/rudder-server/warehouse/client" sqlmiddleware "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" + "github.com/rudderlabs/rudder-server/warehouse/integrations/types" "github.com/rudderlabs/rudder-server/warehouse/internal/model" "github.com/rudderlabs/rudder-server/warehouse/logfield" warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" @@ -46,9 +45,6 @@ const ( partitionNotFound = "SHOW PARTITIONS is not allowed on a table that is not partitioned" columnsAlreadyExists = "already exists in root" - mergeMode = "MERGE" - appendMode = "APPEND" - rudderStagingTableRegex = "^rudder_staging_.*$" // matches rudder_staging_* tables nonRudderStagingTableRegex = "^(?!rudder_staging_.*$).*" // matches tables that do not start with rudder_staging_ ) @@ -125,7 +121,7 @@ type Deltalake struct { stats stats.Stats config struct { - loadTableStrategy string + allowMerge bool enablePartitionPruning bool slowQueryThreshold time.Duration maxRetries int @@ -141,7 +137,7 @@ func New(conf *config.Config, log logger.Logger, stat stats.Stats) *Deltalake { dl.logger = log.Child("integration").Child("deltalake") dl.stats = stat - dl.config.loadTableStrategy = conf.GetString("Warehouse.deltalake.loadTableStrategy", mergeMode) + dl.config.allowMerge = conf.GetBool("Warehouse.deltalake.allowMerge", true) dl.config.enablePartitionPruning = conf.GetBool("Warehouse.deltalake.enablePartitionPruning", true) dl.config.slowQueryThreshold = conf.GetDuration("Warehouse.deltalake.slowQueryThreshold", 5, time.Minute) dl.config.maxRetries = conf.GetInt("Warehouse.deltalake.maxRetries", 10) @@ -244,6 +240,7 @@ func (d *Deltalake) dropDanglingStagingTables(ctx context.Context) { logfield.DestinationType, d.Warehouse.Destination.DestinationDefinition.Name, logfield.WorkspaceID, d.Warehouse.WorkspaceID, logfield.Namespace, d.Namespace, + logfield.ShouldMerge, d.ShouldMerge(), logfield.Error, 
err.Error(), ) return @@ -586,7 +583,7 @@ func (d *Deltalake) loadTable( logfield.WorkspaceID, d.Warehouse.WorkspaceID, logfield.Namespace, d.Namespace, logfield.TableName, tableName, - logfield.LoadTableStrategy, d.config.loadTableStrategy, + logfield.ShouldMerge, d.ShouldMerge(), ) log.Infow("started loading") @@ -617,7 +614,7 @@ func (d *Deltalake) loadTable( } var loadTableStat *types.LoadTableStats - if d.ShouldAppend() { + if !d.ShouldMerge() { log.Infow("inserting data from staging table to main table") loadTableStat, err = d.insertIntoLoadTable( ctx, tableName, stagingTableName, @@ -1110,7 +1107,7 @@ func (d *Deltalake) LoadUserTables(ctx context.Context) map[string]error { columnKeys := append([]string{`id`}, userColNames...) - if d.ShouldAppend() { + if !d.ShouldMerge() { query = fmt.Sprintf(` INSERT INTO %[1]s.%[2]s (%[4]s) SELECT @@ -1172,7 +1169,7 @@ func (d *Deltalake) LoadUserTables(ctx context.Context) map[string]error { inserted int64 ) - if d.ShouldAppend() { + if !d.ShouldMerge() { err = row.Scan(&affected, &inserted) } else { err = row.Scan(&affected, &updated, &deleted, &inserted) @@ -1385,9 +1382,10 @@ func (*Deltalake) DeleteBy(context.Context, []string, warehouseutils.DeleteByPar return fmt.Errorf(warehouseutils.NotImplementedErrorCode) } -// ShouldAppend returns true if: -// * the load table strategy is "append" mode -// * the uploader says we can append -func (d *Deltalake) ShouldAppend() bool { - return d.config.loadTableStrategy == appendMode && d.Uploader.CanAppend() +// ShouldMerge returns true if: +// * the uploader says we cannot append +// * the user opted in to merging and we allow merging +func (d *Deltalake) ShouldMerge() bool { + return !d.Uploader.CanAppend() || + (d.config.allowMerge && d.Warehouse.GetBoolDestinationConfig(model.EnableMergeSetting)) } diff --git a/warehouse/integrations/deltalake/deltalake_test.go b/warehouse/integrations/deltalake/deltalake_test.go index 09d3eb0510..2fc4ea20ab 100644 --- a/warehouse/integrations/deltalake/deltalake_test.go +++ b/warehouse/integrations/deltalake/deltalake_test.go @@ -13,11 +13,6 @@ import ( "testing" "time" - "github.com/rudderlabs/rudder-go-kit/stats/memstats" - - "github.com/rudderlabs/rudder-go-kit/filemanager" - "github.com/rudderlabs/rudder-server/warehouse/internal/model" - dbsql "github.com/databricks/databricks-sql-go" "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" @@ -25,17 +20,21 @@ import ( "github.com/rudderlabs/compose-test/compose" "github.com/rudderlabs/compose-test/testcompose" "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/filemanager" "github.com/rudderlabs/rudder-go-kit/logger" + "github.com/rudderlabs/rudder-go-kit/stats/memstats" kithelper "github.com/rudderlabs/rudder-go-kit/testhelper" backendconfig "github.com/rudderlabs/rudder-server/backend-config" "github.com/rudderlabs/rudder-server/runner" + th "github.com/rudderlabs/rudder-server/testhelper" "github.com/rudderlabs/rudder-server/testhelper/health" "github.com/rudderlabs/rudder-server/testhelper/workspaceConfig" "github.com/rudderlabs/rudder-server/utils/misc" warehouseclient "github.com/rudderlabs/rudder-server/warehouse/client" "github.com/rudderlabs/rudder-server/warehouse/integrations/deltalake" - "github.com/rudderlabs/rudder-server/warehouse/integrations/testhelper" + whth "github.com/rudderlabs/rudder-server/warehouse/integrations/testhelper" mockuploader "github.com/rudderlabs/rudder-server/warehouse/internal/mocks/utils" + 
"github.com/rudderlabs/rudder-server/warehouse/internal/model" warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" "github.com/rudderlabs/rudder-server/warehouse/validations" ) @@ -95,54 +94,12 @@ func TestIntegration(t *testing.T) { sourceID := warehouseutils.RandHex() destinationID := warehouseutils.RandHex() writeKey := warehouseutils.RandHex() - destType := warehouseutils.DELTALAKE - - namespace := testhelper.RandSchema(destType) + namespace := whth.RandSchema(destType) deltaLakeCredentials, err := deltaLakeTestCredentials() require.NoError(t, err) - templateConfigurations := map[string]any{ - "workspaceID": workspaceID, - "sourceID": sourceID, - "destinationID": destinationID, - "writeKey": writeKey, - "host": deltaLakeCredentials.Host, - "port": deltaLakeCredentials.Port, - "path": deltaLakeCredentials.Path, - "token": deltaLakeCredentials.Token, - "namespace": namespace, - "containerName": deltaLakeCredentials.ContainerName, - "accountName": deltaLakeCredentials.AccountName, - "accountKey": deltaLakeCredentials.AccountKey, - } - workspaceConfigPath := workspaceConfig.CreateTempFile(t, "testdata/template.json", templateConfigurations) - - testhelper.EnhanceWithDefaultEnvs(t) - t.Setenv("JOBS_DB_PORT", strconv.Itoa(jobsDBPort)) - t.Setenv("WAREHOUSE_JOBS_DB_PORT", strconv.Itoa(jobsDBPort)) - t.Setenv("RSERVER_WAREHOUSE_DELTALAKE_MAX_PARALLEL_LOADS", "8") - t.Setenv("RSERVER_WAREHOUSE_WEB_PORT", strconv.Itoa(httpPort)) - t.Setenv("RSERVER_BACKEND_CONFIG_CONFIG_JSONPATH", workspaceConfigPath) - t.Setenv("RSERVER_WAREHOUSE_DELTALAKE_SLOW_QUERY_THRESHOLD", "0s") - - svcDone := make(chan struct{}) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go func() { - r := runner.New(runner.ReleaseInfo{}) - _ = r.Run(ctx, []string{"deltalake-integration-test"}) - - close(svcDone) - }() - t.Cleanup(func() { <-svcDone }) - - serviceHealthEndpoint := fmt.Sprintf("http://localhost:%d/health", httpPort) - health.WaitUntilReady(ctx, t, serviceHealthEndpoint, time.Minute, time.Second, "serviceHealthEndpoint") - port, err := strconv.Atoi(deltaLakeCredentials.Port) require.NoError(t, err) @@ -160,17 +117,62 @@ func TestIntegration(t *testing.T) { db := sql.OpenDB(connector) require.NoError(t, db.Ping()) + bootstrapSvc := func(t *testing.T, enableMerge bool) { + templateConfigurations := map[string]any{ + "workspaceID": workspaceID, + "sourceID": sourceID, + "destinationID": destinationID, + "writeKey": writeKey, + "host": deltaLakeCredentials.Host, + "port": deltaLakeCredentials.Port, + "path": deltaLakeCredentials.Path, + "token": deltaLakeCredentials.Token, + "namespace": namespace, + "containerName": deltaLakeCredentials.ContainerName, + "accountName": deltaLakeCredentials.AccountName, + "accountKey": deltaLakeCredentials.AccountKey, + "enableMerge": enableMerge, + } + workspaceConfigPath := workspaceConfig.CreateTempFile(t, "testdata/template.json", templateConfigurations) + + whth.EnhanceWithDefaultEnvs(t) + t.Setenv("JOBS_DB_PORT", strconv.Itoa(jobsDBPort)) + t.Setenv("WAREHOUSE_JOBS_DB_PORT", strconv.Itoa(jobsDBPort)) + t.Setenv("RSERVER_WAREHOUSE_DELTALAKE_MAX_PARALLEL_LOADS", "8") + t.Setenv("RSERVER_WAREHOUSE_WEB_PORT", strconv.Itoa(httpPort)) + t.Setenv("RSERVER_BACKEND_CONFIG_CONFIG_JSONPATH", workspaceConfigPath) + t.Setenv("RSERVER_WAREHOUSE_DELTALAKE_SLOW_QUERY_THRESHOLD", "0s") + + svcDone := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + r := runner.New(runner.ReleaseInfo{}) + _ = r.Run(ctx, 
[]string{"deltalake-integration-test"}) + close(svcDone) + }() + + t.Cleanup(func() { <-svcDone }) + t.Cleanup(cancel) + + serviceHealthEndpoint := fmt.Sprintf("http://localhost:%d/health", httpPort) + health.WaitUntilReady(ctx, t, + serviceHealthEndpoint, time.Minute, time.Second, "serviceHealthEndpoint", + ) + } + t.Run("Event flow", func(t *testing.T) { - jobsDB := testhelper.JobsDB(t, jobsDBPort) + jobsDB := whth.JobsDB(t, jobsDBPort) t.Cleanup(func() { - require.Eventually(t, func() bool { - if _, err := db.Exec(fmt.Sprintf(`DROP SCHEMA %[1]s CASCADE;`, namespace)); err != nil { - t.Logf("error deleting schema: %v", err) - return false - } - return true - }, + require.Eventually(t, + func() bool { + if _, err := db.Exec(fmt.Sprintf(`DROP SCHEMA %s CASCADE;`, namespace)); err != nil { + t.Logf("error deleting schema %q: %v", namespace, err) + return false + } + return true + }, time.Minute, time.Second, ) @@ -183,8 +185,8 @@ func TestIntegration(t *testing.T) { sourceID string destinationID string messageID string - warehouseEventsMap testhelper.EventsCountMap - loadTableStrategy string + warehouseEventsMap whth.EventsCountMap + enableMerge bool useParquetLoadFiles bool stagingFilePrefix string jobRunID string @@ -196,7 +198,7 @@ func TestIntegration(t *testing.T) { sourceID: sourceID, destinationID: destinationID, warehouseEventsMap: mergeEventsMap(), - loadTableStrategy: "MERGE", + enableMerge: true, useParquetLoadFiles: false, stagingFilePrefix: "testdata/upload-job-merge-mode", jobRunID: misc.FastUUID().String(), @@ -208,7 +210,7 @@ func TestIntegration(t *testing.T) { sourceID: sourceID, destinationID: destinationID, warehouseEventsMap: appendEventsMap(), - loadTableStrategy: "APPEND", + enableMerge: false, useParquetLoadFiles: false, stagingFilePrefix: "testdata/upload-job-append-mode", // an empty jobRunID means that the source is not an ETL one @@ -222,7 +224,7 @@ func TestIntegration(t *testing.T) { sourceID: sourceID, destinationID: destinationID, warehouseEventsMap: mergeEventsMap(), - loadTableStrategy: "MERGE", + enableMerge: true, useParquetLoadFiles: true, stagingFilePrefix: "testdata/upload-job-parquet", jobRunID: misc.FastUUID().String(), @@ -231,9 +233,8 @@ func TestIntegration(t *testing.T) { for _, tc := range testCases { tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Setenv("RSERVER_WAREHOUSE_DELTALAKE_LOAD_TABLE_STRATEGY", tc.loadTableStrategy) + bootstrapSvc(t, tc.enableMerge) t.Setenv("RSERVER_WAREHOUSE_DELTALAKE_USE_PARQUET_LOAD_FILES", strconv.FormatBool(tc.useParquetLoadFiles)) sqlClient := &warehouseclient.Client{ @@ -253,14 +254,14 @@ func TestIntegration(t *testing.T) { tables := []string{"identifies", "users", "tracks", "product_track", "pages", "screens", "aliases", "groups"} t.Log("verifying test case 1") - ts1 := testhelper.TestConfig{ + ts1 := whth.TestConfig{ WriteKey: writeKey, Schema: tc.schema, Tables: tables, SourceID: tc.sourceID, DestinationID: tc.destinationID, JobRunID: tc.jobRunID, - WarehouseEventsMap: testhelper.EventsCountMap{ + WarehouseEventsMap: whth.EventsCountMap{ "identifies": 1, "users": 1, "tracks": 1, @@ -277,12 +278,12 @@ func TestIntegration(t *testing.T) { HTTPPort: httpPort, Client: sqlClient, StagingFilePath: tc.stagingFilePrefix + ".staging-1.json", - UserID: testhelper.GetUserId(destType), + UserID: whth.GetUserId(destType), } ts1.VerifyEvents(t) t.Log("verifying test case 2") - ts2 := testhelper.TestConfig{ + ts2 := whth.TestConfig{ WriteKey: writeKey, Schema: tc.schema, Tables: tables, @@ -306,13 +307,14 @@ func 
TestIntegration(t *testing.T) { t.Run("Validation", func(t *testing.T) { t.Cleanup(func() { - require.Eventually(t, func() bool { - if _, err := db.Exec(fmt.Sprintf(`DROP SCHEMA %[1]s CASCADE;`, namespace)); err != nil { - t.Logf("error deleting schema: %v", err) - return false - } - return true - }, + require.Eventually(t, + func() bool { + if _, err := db.Exec(fmt.Sprintf(`DROP SCHEMA %s CASCADE;`, namespace)); err != nil { + t.Logf("error deleting schema %q: %v", namespace, err) + return false + } + return true + }, time.Minute, time.Second, ) @@ -372,15 +374,17 @@ func TestIntegration(t *testing.T) { for _, tc := range testCases { tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Setenv("RSERVER_WAREHOUSE_DELTALAKE_USE_PARQUET_LOAD_FILES", strconv.FormatBool(tc.useParquetLoadFiles)) + t.Setenv( + "RSERVER_WAREHOUSE_DELTALAKE_USE_PARQUET_LOAD_FILES", + strconv.FormatBool(tc.useParquetLoadFiles), + ) for k, v := range tc.conf { dest.Config[k] = v } - testhelper.VerifyConfigurationTest(t, dest) + whth.VerifyConfigurationTest(t, dest) }) } }) @@ -392,20 +396,22 @@ func TestIntegration(t *testing.T) { workspaceID = "test_workspace_id" ) - namespace := testhelper.RandSchema(destType) - - t.Cleanup(func() { - require.Eventually(t, func() bool { - if _, err := db.Exec(fmt.Sprintf(`DROP SCHEMA %[1]s CASCADE;`, namespace)); err != nil { - t.Logf("error deleting schema: %v", err) - return false - } - return true - }, + ctx := context.Background() + namespace := whth.RandSchema(destType) + cleanupSchema := func() { + require.Eventually(t, + func() bool { + _, err := db.Exec(fmt.Sprintf(`DROP SCHEMA %s CASCADE;`, namespace)) + if err != nil { + t.Logf("error deleting schema %q: %v", namespace, err) + return false + } + return true + }, time.Minute, time.Second, ) - }) + } schemaInUpload := model.TableSchema{ "test_bool": "boolean", @@ -470,7 +476,7 @@ func TestIntegration(t *testing.T) { t.Run("schema does not exists", func(t *testing.T) { tableName := "schema_not_exists_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, false, false, "2022-12-15T06:53:49.640Z") @@ -486,7 +492,7 @@ func TestIntegration(t *testing.T) { t.Run("table does not exists", func(t *testing.T) { tableName := "table_not_exists_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, false, false, "2022-12-15T06:53:49.640Z") @@ -497,6 +503,7 @@ func TestIntegration(t *testing.T) { err = d.CreateSchema(ctx) require.NoError(t, err) + t.Cleanup(cleanupSchema) loadTableStat, err := d.LoadTable(ctx, tableName) require.Error(t, err) @@ -506,7 +513,7 @@ func TestIntegration(t *testing.T) { tableName := "merge_test_table" t.Run("without dedup", func(t *testing.T) { - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: 
uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, false, false, "2022-12-15T06:53:49.640Z") @@ -517,6 +524,7 @@ func TestIntegration(t *testing.T) { err = d.CreateSchema(ctx) require.NoError(t, err) + t.Cleanup(cleanupSchema) err = d.CreateTable(ctx, tableName, schemaInWarehouse) require.NoError(t, err) @@ -531,7 +539,7 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(0)) require.Equal(t, loadTableStat.RowsUpdated, int64(14)) - records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, + records := whth.RetrieveRecordsFromWarehouse(t, d.DB.DB, fmt.Sprintf(` SELECT id, @@ -541,40 +549,40 @@ func TestIntegration(t *testing.T) { test_float, test_int, test_string - FROM - %s.%s - ORDER BY - id; - `, + FROM %s.%s + ORDER BY id;`, namespace, tableName, ), ) - require.Equal(t, records, testhelper.SampleTestRecords()) + require.Equal(t, records, whth.SampleTestRecords()) }) t.Run("with dedup use new record", func(t *testing.T) { - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} - mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, false, true, "2022-12-15T06:53:49.640Z") + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, true, true, "2022-12-15T06:53:49.640Z") + + mergeWarehouse := th.Clone(t, warehouse) + mergeWarehouse.Destination.Config[string(model.EnableMergeSetting)] = true d := deltalake.New(config.New(), logger.NOP, memstats.New()) - err := d.Setup(ctx, warehouse, mockUploader) + err := d.Setup(ctx, mergeWarehouse, mockUploader) require.NoError(t, err) err = d.CreateSchema(ctx) require.NoError(t, err) + t.Cleanup(cleanupSchema) err = d.CreateTable(ctx, tableName, schemaInWarehouse) require.NoError(t, err) loadTableStat, err := d.LoadTable(ctx, tableName) require.NoError(t, err) - require.Equal(t, loadTableStat.RowsInserted, int64(0)) - require.Equal(t, loadTableStat.RowsUpdated, int64(14)) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, - fmt.Sprintf(` + retrieveRecordsSQL := fmt.Sprintf(` SELECT id, received_at, @@ -583,29 +591,38 @@ func TestIntegration(t *testing.T) { test_float, test_int, test_string - FROM - %s.%s - ORDER BY - id; - `, - namespace, - tableName, - ), + FROM %s.%s + ORDER BY id;`, + namespace, + tableName, ) - require.Equal(t, records, testhelper.DedupTestRecords()) + records := whth.RetrieveRecordsFromWarehouse(t, d.DB.DB, retrieveRecordsSQL) + require.Equal(t, records, whth.DedupTestRecords()) + + loadTableStat, err = d.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(0)) + require.Equal(t, loadTableStat.RowsUpdated, int64(14)) + + records = whth.RetrieveRecordsFromWarehouse(t, d.DB.DB, retrieveRecordsSQL) + require.Equal(t, records, whth.DedupTestRecords()) }) t.Run("with no overlapping partition", func(t *testing.T) { - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) loadFiles := 
[]warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, false, false, "2022-11-15T06:53:49.640Z") + mergeWarehouse := th.Clone(t, warehouse) + mergeWarehouse.Destination.Config[string(model.EnableMergeSetting)] = true + d := deltalake.New(config.New(), logger.NOP, memstats.New()) - err := d.Setup(ctx, warehouse, mockUploader) + err := d.Setup(ctx, mergeWarehouse, mockUploader) require.NoError(t, err) err = d.CreateSchema(ctx) require.NoError(t, err) + t.Cleanup(cleanupSchema) err = d.CreateTable(ctx, tableName, schemaInWarehouse) require.NoError(t, err) @@ -615,45 +632,45 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(14)) require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, - fmt.Sprintf(` - SELECT - id, - received_at, - test_bool, - test_datetime, - test_float, - test_int, - test_string - FROM - %s.%s - ORDER BY - id; - `, + loadTableStat, err = d.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + records := whth.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf( + `SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM %s.%s + ORDER BY id;`, namespace, tableName, ), ) - require.Equal(t, records, testhelper.DedupTwiceTestRecords()) + require.Equal(t, records, whth.DedupTwiceTestRecords()) }) }) t.Run("append", func(t *testing.T) { tableName := "append_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, true, false, "2022-12-15T06:53:49.640Z") - c := config.New() - c.Set("Warehouse.deltalake.loadTableStrategy", "APPEND") - - d := deltalake.New(c, logger.NOP, memstats.New()) + d := deltalake.New(config.New(), logger.NOP, memstats.New()) err := d.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) err = d.CreateSchema(ctx) require.NoError(t, err) + t.Cleanup(cleanupSchema) err = d.CreateTable(ctx, tableName, schemaInWarehouse) require.NoError(t, err) @@ -668,26 +685,23 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(14)) require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, - fmt.Sprintf(` - SELECT - id, - received_at, - test_bool, - test_datetime, - test_float, - test_int, - test_string - FROM - %s.%s - ORDER BY - id; - `, + records := whth.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf( + `SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM %s.%s + ORDER BY id;`, namespace, tableName, ), ) - require.Equal(t, records, testhelper.AppendTestRecords()) + require.Equal(t, records, whth.AppendTestRecords()) }) t.Run("load file does not exists", func(t *testing.T) { tableName := "load_file_not_exists_test_table" @@ -703,6 +717,7 @@ func TestIntegration(t *testing.T) { err = d.CreateSchema(ctx) require.NoError(t, err) + t.Cleanup(cleanupSchema) err = d.CreateTable(ctx, tableName, 
schemaInWarehouse) require.NoError(t, err) @@ -714,7 +729,7 @@ func TestIntegration(t *testing.T) { t.Run("mismatch in number of columns", func(t *testing.T) { tableName := "mismatch_columns_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/mismatch-columns.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/mismatch-columns.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, false, false, "2022-12-15T06:53:49.640Z") @@ -725,6 +740,7 @@ func TestIntegration(t *testing.T) { err = d.CreateSchema(ctx) require.NoError(t, err) + t.Cleanup(cleanupSchema) err = d.CreateTable(ctx, tableName, schemaInWarehouse) require.NoError(t, err) @@ -734,31 +750,28 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(14)) require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, - fmt.Sprintf(` - SELECT - id, - received_at, - test_bool, - test_datetime, - test_float, - test_int, - test_string - FROM - %s.%s - ORDER BY - id; - `, + records := whth.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf( + `SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM %s.%s + ORDER BY id;`, namespace, tableName, ), ) - require.Equal(t, records, testhelper.SampleTestRecords()) + require.Equal(t, records, whth.SampleTestRecords()) }) t.Run("mismatch in schema", func(t *testing.T) { tableName := "mismatch_schema_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/mismatch-schema.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/mismatch-schema.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, false, false, "2022-12-15T06:53:49.640Z") @@ -769,6 +782,7 @@ func TestIntegration(t *testing.T) { err = d.CreateSchema(ctx) require.NoError(t, err) + t.Cleanup(cleanupSchema) err = d.CreateTable(ctx, tableName, schemaInWarehouse) require.NoError(t, err) @@ -778,31 +792,28 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(14)) require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, - fmt.Sprintf(` - SELECT - id, - received_at, - test_bool, - test_datetime, - test_float, - test_int, - test_string - FROM - %s.%s - ORDER BY - id; - `, + records := whth.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf( + `SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM %s.%s + ORDER BY id;`, namespace, tableName, ), ) - require.Equal(t, records, testhelper.MismatchSchemaTestRecords()) + require.Equal(t, records, whth.MismatchSchemaTestRecords()) }) t.Run("discards", func(t *testing.T) { tableName := warehouseutils.DiscardsTable - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/discards.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/discards.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, warehouseutils.DiscardsSchema, warehouseutils.DiscardsSchema, 
warehouseutils.LoadFileTypeCsv, false, false, "2022-12-15T06:53:49.640Z") @@ -813,6 +824,7 @@ func TestIntegration(t *testing.T) { err = d.CreateSchema(ctx) require.NoError(t, err) + t.Cleanup(cleanupSchema) err = d.CreateTable(ctx, tableName, warehouseutils.DiscardsSchema) require.NoError(t, err) @@ -822,29 +834,27 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(6)) require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, - fmt.Sprintf(` - SELECT - column_name, - column_value, - received_at, - row_id, - table_name, - uuid_ts - FROM - %s.%s - ORDER BY row_id ASC; - `, + records := whth.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf( + `SELECT + column_name, + column_value, + received_at, + row_id, + table_name, + uuid_ts + FROM %s.%s + ORDER BY row_id ASC;`, namespace, tableName, ), ) - require.Equal(t, records, testhelper.DiscardTestRecords()) + require.Equal(t, records, whth.DiscardTestRecords()) }) t.Run("parquet", func(t *testing.T) { tableName := "parquet_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.parquet", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.parquet", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeParquet, false, false, "2022-12-15T06:53:49.640Z") @@ -855,6 +865,7 @@ func TestIntegration(t *testing.T) { err = d.CreateSchema(ctx) require.NoError(t, err) + t.Cleanup(cleanupSchema) err = d.CreateTable(ctx, tableName, schemaInWarehouse) require.NoError(t, err) @@ -864,32 +875,29 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(14)) require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, - fmt.Sprintf(` - SELECT - id, - received_at, - test_bool, - test_datetime, - test_float, - test_int, - test_string - FROM - %s.%s - ORDER BY - id; - `, + records := whth.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf( + `SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM %s.%s + ORDER BY id;`, namespace, tableName, ), ) - require.Equal(t, records, testhelper.SampleTestRecords()) + require.Equal(t, records, whth.SampleTestRecords()) }) t.Run("partition pruning", func(t *testing.T) { t.Run("not partitioned", func(t *testing.T) { tableName := "not_partitioned_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, false, false, "2022-12-15T06:53:49.640Z") @@ -900,26 +908,26 @@ func TestIntegration(t *testing.T) { err = d.CreateSchema(ctx) require.NoError(t, err) - - _, err = d.DB.QueryContext(ctx, ` - CREATE TABLE IF NOT EXISTS `+namespace+`.`+tableName+` ( - extra_test_bool BOOLEAN, - extra_test_datetime TIMESTAMP, - extra_test_float DOUBLE, - extra_test_int BIGINT, - extra_test_string STRING, - id STRING, - received_at TIMESTAMP, - event_date DATE GENERATED ALWAYS AS ( - CAST(received_at AS DATE) - ), - test_bool BOOLEAN, - test_datetime TIMESTAMP, - test_float DOUBLE, - test_int 
BIGINT, - test_string STRING - ) USING DELTA; - `) + t.Cleanup(cleanupSchema) + + _, err = d.DB.QueryContext(ctx, + `CREATE TABLE IF NOT EXISTS `+namespace+`.`+tableName+` ( + extra_test_bool BOOLEAN, + extra_test_datetime TIMESTAMP, + extra_test_float DOUBLE, + extra_test_int BIGINT, + extra_test_string STRING, + id STRING, + received_at TIMESTAMP, + event_date DATE GENERATED ALWAYS AS ( + CAST(received_at AS DATE) + ), + test_bool BOOLEAN, + test_datetime TIMESTAMP, + test_float DOUBLE, + test_int BIGINT, + test_string STRING + ) USING DELTA;`) require.NoError(t, err) loadTableStat, err := d.LoadTable(ctx, tableName) @@ -927,31 +935,28 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(14)) require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, - fmt.Sprintf(` - SELECT - id, - received_at, - test_bool, - test_datetime, - test_float, - test_int, - test_string - FROM - %s.%s - ORDER BY - id; - `, + records := whth.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf( + `SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM %s.%s + ORDER BY id;`, namespace, tableName, ), ) - require.Equal(t, records, testhelper.SampleTestRecords()) + require.Equal(t, records, whth.SampleTestRecords()) }) t.Run("event_date is not in partition", func(t *testing.T) { tableName := "not_event_date_partition_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, false, false, "2022-12-15T06:53:49.640Z") @@ -962,26 +967,26 @@ func TestIntegration(t *testing.T) { err = d.CreateSchema(ctx) require.NoError(t, err) - - _, err = d.DB.QueryContext(ctx, ` - CREATE TABLE IF NOT EXISTS `+namespace+`.`+tableName+` ( - extra_test_bool BOOLEAN, - extra_test_datetime TIMESTAMP, - extra_test_float DOUBLE, - extra_test_int BIGINT, - extra_test_string STRING, - id STRING, - received_at TIMESTAMP, - event_date DATE GENERATED ALWAYS AS ( - CAST(received_at AS DATE) - ), - test_bool BOOLEAN, - test_datetime TIMESTAMP, - test_float DOUBLE, - test_int BIGINT, - test_string STRING - ) USING DELTA PARTITIONED BY(id); - `) + t.Cleanup(cleanupSchema) + + _, err = d.DB.QueryContext(ctx, + `CREATE TABLE IF NOT EXISTS `+namespace+`.`+tableName+` ( + extra_test_bool BOOLEAN, + extra_test_datetime TIMESTAMP, + extra_test_float DOUBLE, + extra_test_int BIGINT, + extra_test_string STRING, + id STRING, + received_at TIMESTAMP, + event_date DATE GENERATED ALWAYS AS ( + CAST(received_at AS DATE) + ), + test_bool BOOLEAN, + test_datetime TIMESTAMP, + test_float DOUBLE, + test_int BIGINT, + test_string STRING + ) USING DELTA PARTITIONED BY(id);`) require.NoError(t, err) loadTableStat, err := d.LoadTable(ctx, tableName) @@ -989,26 +994,23 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(14)) require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, - fmt.Sprintf(` - SELECT - id, - received_at, - test_bool, - test_datetime, - test_float, - test_int, - test_string - FROM - %s.%s - ORDER BY - id; - `, + records := whth.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf( + 
`SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM %s.%s + ORDER BY id;`, namespace, tableName, ), ) - require.Equal(t, records, testhelper.SampleTestRecords()) + require.Equal(t, records, whth.SampleTestRecords()) }) }) }) @@ -1047,62 +1049,66 @@ func TestDeltalake_TrimErrorMessage(t *testing.T) { c.Set("Warehouse.deltalake.maxErrorLength", len(tempError.Error())*25) d := deltalake.New(c, logger.NOP, memstats.New()) - require.Equal(t, d.TrimErrorMessage(tc.inputError), tc.expectedError) + require.Equal(t, tc.expectedError, d.TrimErrorMessage(tc.inputError)) }) } } -func TestDeltalake_ShouldAppend(t *testing.T) { +func TestDeltalake_ShouldMerge(t *testing.T) { testCases := []struct { name string - loadTableStrategy string + enableMerge bool uploaderCanAppend bool uploaderExpectedCalls int expected bool }{ { - name: "uploader says we can append and we are in append mode", - loadTableStrategy: "APPEND", + name: "uploader says we can append and merge is not enabled", + enableMerge: false, uploaderCanAppend: true, uploaderExpectedCalls: 1, - expected: true, + expected: false, }, { - name: "uploader says we cannot append and we are in append mode", - loadTableStrategy: "APPEND", - uploaderCanAppend: false, + name: "uploader says we can append and merge is enabled", + enableMerge: true, + uploaderCanAppend: true, uploaderExpectedCalls: 1, - expected: false, + expected: true, }, { - name: "uploader says we can append and we are in merge mode", - loadTableStrategy: "MERGE", - uploaderCanAppend: true, - uploaderExpectedCalls: 0, - expected: false, + name: "uploader says we cannot append so enableMerge false is ignored", + enableMerge: false, + uploaderCanAppend: false, + uploaderExpectedCalls: 1, + expected: true, }, { - name: "uploader says we cannot append and we are in merge mode", - loadTableStrategy: "MERGE", + name: "uploader says we cannot append so enableMerge true is ignored", + enableMerge: true, uploaderCanAppend: false, - uploaderExpectedCalls: 0, - expected: false, + uploaderExpectedCalls: 1, + expected: true, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - c := config.New() - c.Set("Warehouse.deltalake.loadTableStrategy", tc.loadTableStrategy) - - d := deltalake.New(c, logger.NOP, memstats.New()) + d := deltalake.New(config.New(), logger.NOP, memstats.New()) + d.Warehouse = model.Warehouse{ + Destination: backendconfig.DestinationT{ + Config: map[string]any{ + string(model.EnableMergeSetting): tc.enableMerge, + }, + }, + } mockCtrl := gomock.NewController(t) uploader := mockuploader.NewMockUploader(mockCtrl) uploader.EXPECT().CanAppend().Times(tc.uploaderExpectedCalls).Return(tc.uploaderCanAppend) d.Uploader = uploader - require.Equal(t, d.ShouldAppend(), tc.expected) + require.Equal(t, d.ShouldMerge(), tc.expected) }) } } @@ -1142,8 +1148,8 @@ func newMockUploader( return mockUploader } -func mergeEventsMap() testhelper.EventsCountMap { - return testhelper.EventsCountMap{ +func mergeEventsMap() whth.EventsCountMap { + return whth.EventsCountMap{ "identifies": 1, "users": 1, "tracks": 1, @@ -1155,8 +1161,8 @@ func mergeEventsMap() testhelper.EventsCountMap { } } -func appendEventsMap() testhelper.EventsCountMap { - return testhelper.EventsCountMap{ +func appendEventsMap() whth.EventsCountMap { + return whth.EventsCountMap{ "identifies": 2, "users": 2, "tracks": 2, diff --git a/warehouse/integrations/deltalake/testdata/template.json b/warehouse/integrations/deltalake/testdata/template.json index 
da90ff8365..2b796aa448 100644 --- a/warehouse/integrations/deltalake/testdata/template.json +++ b/warehouse/integrations/deltalake/testdata/template.json @@ -38,7 +38,8 @@ "accountKey": "{{.accountKey}}", "syncFrequency": "30", "eventDelivery": false, - "eventDeliveryTS": 1648195480174 + "eventDeliveryTS": 1648195480174, + "enableMerge": {{.enableMerge}} }, "liveEventsConfig": { "eventDelivery": false, diff --git a/warehouse/integrations/mssql/mssql.go b/warehouse/integrations/mssql/mssql.go index 8071c9d6a9..7538b8646c 100644 --- a/warehouse/integrations/mssql/mssql.go +++ b/warehouse/integrations/mssql/mssql.go @@ -294,11 +294,9 @@ func (ms *MSSQL) loadTable( // - Refer to Microsoft's documentation on temporary tables at // https://docs.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/ms175528(v=sql.105)?redirectedfrom=MSDN. log.Debugw("creating staging table") - createStagingTableStmt := fmt.Sprintf(` - SELECT - TOP 0 * INTO %[1]s.%[2]s - FROM - %[1]s.%[3]s;`, + createStagingTableStmt := fmt.Sprintf( + `SELECT TOP 0 * INTO %[1]s.%[2]s + FROM %[1]s.%[3]s;`, ms.Namespace, stagingTableName, tableName, @@ -471,7 +469,7 @@ func (ms *MSSQL) loadDataIntoStagingTable( return nil } -func (as *MSSQL) ProcessColumnValue( +func (ms *MSSQL) ProcessColumnValue( value string, valueType string, ) (interface{}, error) { diff --git a/warehouse/integrations/postgres/load.go b/warehouse/integrations/postgres/load.go index 89b9faa74c..da76cfedff 100644 --- a/warehouse/integrations/postgres/load.go +++ b/warehouse/integrations/postgres/load.go @@ -31,27 +31,15 @@ type loadUsersTableResponse struct { } func (pg *Postgres) LoadTable(ctx context.Context, tableName string) (*types.LoadTableStats, error) { - log := pg.logger.With( - logfield.SourceID, pg.Warehouse.Source.ID, - logfield.SourceType, pg.Warehouse.Source.SourceDefinition.Name, - logfield.DestinationID, pg.Warehouse.Destination.ID, - logfield.DestinationType, pg.Warehouse.Destination.DestinationDefinition.Name, - logfield.WorkspaceID, pg.Warehouse.WorkspaceID, - logfield.Namespace, pg.Namespace, - logfield.TableName, tableName, - logfield.LoadTableStrategy, pg.loadTableStrategy(), - ) - log.Infow("started loading") - var loadTableStats *types.LoadTableStats - var err error - - err = pg.DB.WithTx(ctx, func(tx *sqlmiddleware.Tx) error { + err := pg.DB.WithTx(ctx, func(tx *sqlmiddleware.Tx) error { + var err error loadTableStats, _, err = pg.loadTable( ctx, tx, tableName, pg.Uploader.GetTableSchemaInUpload(tableName), + false, ) return err }) @@ -59,8 +47,6 @@ func (pg *Postgres) LoadTable(ctx context.Context, tableName string) (*types.Loa return nil, fmt.Errorf("loading table: %w", err) } - log.Infow("completed loading") - return loadTableStats, err } @@ -69,6 +55,7 @@ func (pg *Postgres) loadTable( txn *sqlmiddleware.Tx, tableName string, tableSchemaInUpload model.TableSchema, + forceMerge bool, ) (*types.LoadTableStats, string, error) { log := pg.logger.With( logfield.SourceID, pg.Warehouse.Source.ID, @@ -78,9 +65,10 @@ func (pg *Postgres) loadTable( logfield.WorkspaceID, pg.Warehouse.WorkspaceID, logfield.Namespace, pg.Namespace, logfield.TableName, tableName, - logfield.LoadTableStrategy, pg.loadTableStrategy(), + logfield.ShouldMerge, pg.shouldMerge(), ) log.Infow("started loading") + defer log.Infow("completed loading") log.Debugw("setting search path") searchPathStmt := fmt.Sprintf(`SET search_path TO %q;`, @@ -105,10 +93,9 @@ func (pg *Postgres) loadTable( ) log.Debugw("creating staging table") - createStagingTableStmt := 
fmt.Sprintf(` - CREATE TEMPORARY TABLE %[2]s (LIKE %[1]q.%[3]q) - ON COMMIT PRESERVE ROWS; -`, + createStagingTableStmt := fmt.Sprintf( + `CREATE TEMPORARY TABLE %[2]s (LIKE %[1]q.%[3]q) + ON COMMIT PRESERVE ROWS;`, pg.Namespace, stagingTableName, tableName, @@ -143,7 +130,7 @@ func (pg *Postgres) loadTable( } var rowsDeleted int64 - if !slices.Contains(pg.config.skipDedupDestinationIDs, pg.Warehouse.Destination.ID) { + if forceMerge || pg.shouldMerge() { log.Infow("deleting from load table") rowsDeleted, err = pg.deleteFromLoadTable( ctx, txn, tableName, @@ -225,13 +212,6 @@ func (pg *Postgres) loadDataIntoStagingTable( return nil } -func (pg *Postgres) loadTableStrategy() string { - if slices.Contains(pg.config.skipDedupDestinationIDs, pg.Warehouse.Destination.ID) { - return "APPEND" - } - return "MERGE" -} - func (pg *Postgres) deleteFromLoadTable( ctx context.Context, txn *sqlmiddleware.Tx, @@ -387,7 +367,7 @@ func (pg *Postgres) loadUsersTable( usersSchemaInUpload, usersSchemaInWarehouse model.TableSchema, ) loadUsersTableResponse { - _, identifyStagingTable, err := pg.loadTable(ctx, tx, warehouseutils.IdentifiesTable, identifiesSchemaInUpload) + _, identifyStagingTable, err := pg.loadTable(ctx, tx, warehouseutils.IdentifiesTable, identifiesSchemaInUpload, false) if err != nil { return loadUsersTableResponse{ identifiesError: fmt.Errorf("loading identifies table: %w", err), @@ -398,9 +378,10 @@ func (pg *Postgres) loadUsersTable( return loadUsersTableResponse{} } - canSkipComputingLatestUserTraits := pg.config.skipComputingUserLatestTraits || slices.Contains(pg.config.skipComputingUserLatestTraitsWorkspaceIDs, pg.Warehouse.WorkspaceID) + canSkipComputingLatestUserTraits := pg.config.skipComputingUserLatestTraits || + slices.Contains(pg.config.skipComputingUserLatestTraitsWorkspaceIDs, pg.Warehouse.WorkspaceID) if canSkipComputingLatestUserTraits { - if _, _, err = pg.loadTable(ctx, tx, warehouseutils.UsersTable, usersSchemaInUpload); err != nil { + if _, _, err = pg.loadTable(ctx, tx, warehouseutils.UsersTable, usersSchemaInUpload, true); err != nil { return loadUsersTableResponse{ usersError: fmt.Errorf("loading users table: %w", err), } @@ -417,60 +398,40 @@ func (pg *Postgres) loadUsersTable( continue } userColNames = append(userColNames, fmt.Sprintf(`%q`, colName)) - caseSubQuery := fmt.Sprintf(` - CASE WHEN ( - SELECT - true + caseSubQuery := fmt.Sprintf( + `CASE WHEN ( + SELECT true ) THEN ( - SELECT - %[1]q - FROM - %[2]q AS staging_table - WHERE - x.id = staging_table.id AND - %[1]q IS NOT NULL - ORDER BY - received_at DESC - LIMIT - 1 - ) END AS %[1]q -`, + SELECT %[1]q + FROM %[2]q AS staging_table + WHERE x.id = staging_table.id AND %[1]q IS NOT NULL + ORDER BY received_at DESC + LIMIT 1 + ) END AS %[1]q`, colName, unionStagingTableName, ) firstValProps = append(firstValProps, caseSubQuery) } - query := fmt.Sprintf(` - CREATE TEMPORARY TABLE %[5]s AS ( - ( - SELECT - id, - %[4]s - FROM - %[1]q.%[2]q - WHERE - id IN ( - SELECT - user_id - FROM - %[3]q - WHERE - user_id IS NOT NULL - ) - ) - UNION + query := fmt.Sprintf( + `CREATE TEMPORARY TABLE %[5]s AS ( ( - SELECT - user_id, - %[4]s - FROM - %[3]q - WHERE - user_id IS NOT NULL + SELECT id, %[4]s + FROM %[1]q.%[2]q + WHERE id IN ( + SELECT user_id + FROM %[3]q + WHERE user_id IS NOT NULL + ) ) - ); -`, + UNION + ( + SELECT user_id, %[4]s + FROM %[3]q + WHERE user_id IS NOT NULL + ) + );`, pg.Namespace, warehouseutils.UsersTable, identifyStagingTable, @@ -534,13 +495,8 @@ func (pg *Postgres) loadUsersTable( // Delete from 
users table if the id is present in the staging table primaryKey := "id" query = fmt.Sprintf(` - DELETE FROM - %[1]q.%[2]q using %[3]q _source - WHERE - ( - _source.%[4]s = %[1]s.%[2]s.%[4]s - ); -`, + DELETE FROM %[1]q.%[2]q using %[3]q _source + WHERE _source.%[4]s = %[1]s.%[2]s.%[4]s;`, pg.Namespace, warehouseutils.UsersTable, usersStagingTableName, @@ -596,3 +552,10 @@ func (pg *Postgres) loadUsersTable( return loadUsersTableResponse{} } + +func (pg *Postgres) shouldMerge() bool { + return !pg.Uploader.CanAppend() || + (pg.config.allowMerge && + pg.Warehouse.GetBoolDestinationConfig(model.EnableMergeSetting) && + !slices.Contains(pg.config.skipDedupDestinationIDs, pg.Warehouse.Destination.ID)) +} diff --git a/warehouse/integrations/postgres/load_test.go b/warehouse/integrations/postgres/load_test.go index 71e22197e4..969c650b5b 100644 --- a/warehouse/integrations/postgres/load_test.go +++ b/warehouse/integrations/postgres/load_test.go @@ -247,6 +247,7 @@ func TestLoadUsersTable(t *testing.T) { mockUploader := mockuploader.NewMockUploader(ctrl) mockUploader.EXPECT().GetTableSchemaInUpload(gomock.Any()).AnyTimes().DoAndReturn(f) mockUploader.EXPECT().GetTableSchemaInWarehouse(gomock.Any()).AnyTimes().DoAndReturn(f) + mockUploader.EXPECT().CanAppend().Return(true).AnyTimes() pg.DB = db pg.Namespace = namespace diff --git a/warehouse/integrations/postgres/postgres.go b/warehouse/integrations/postgres/postgres.go index bebc1a6b18..efe3a51814 100644 --- a/warehouse/integrations/postgres/postgres.go +++ b/warehouse/integrations/postgres/postgres.go @@ -123,6 +123,7 @@ type Postgres struct { LoadFileDownloader downloader.Downloader config struct { + allowMerge bool enableDeleteByJobs bool numWorkersDownloadLoadFiles int slowQueryThreshold time.Duration @@ -163,6 +164,7 @@ func New(conf *config.Config, log logger.Logger, stat stats.Stats) *Postgres { pg.logger = log.Child("integrations").Child("postgres") pg.stats = stat + pg.config.allowMerge = conf.GetBool("Warehouse.postgres.allowMerge", true) pg.config.enableDeleteByJobs = conf.GetBool("Warehouse.postgres.enableDeleteByJobs", false) pg.config.numWorkersDownloadLoadFiles = conf.GetInt("Warehouse.postgres.numWorkersDownloadLoadFiles", 1) pg.config.slowQueryThreshold = conf.GetDuration("Warehouse.postgres.slowQueryThreshold", 5, time.Minute) diff --git a/warehouse/integrations/postgres/postgres_test.go b/warehouse/integrations/postgres/postgres_test.go index 35080e7242..6a741bd8f2 100644 --- a/warehouse/integrations/postgres/postgres_test.go +++ b/warehouse/integrations/postgres/postgres_test.go @@ -10,38 +10,30 @@ import ( "testing" "time" - "github.com/rudderlabs/rudder-go-kit/stats/memstats" - - "github.com/rudderlabs/rudder-server/warehouse/integrations/tunnelling" - "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" + "github.com/rudderlabs/compose-test/compose" + "github.com/rudderlabs/compose-test/testcompose" "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/filemanager" "github.com/rudderlabs/rudder-go-kit/logger" - "github.com/rudderlabs/rudder-server/warehouse/integrations/postgres" - mockuploader "github.com/rudderlabs/rudder-server/warehouse/internal/mocks/utils" - "github.com/rudderlabs/rudder-server/warehouse/internal/model" - - "github.com/rudderlabs/compose-test/compose" - - "github.com/rudderlabs/rudder-server/testhelper/workspaceConfig" - - "github.com/rudderlabs/compose-test/testcompose" + "github.com/rudderlabs/rudder-go-kit/stats/memstats" kithelper 
"github.com/rudderlabs/rudder-go-kit/testhelper" backendconfig "github.com/rudderlabs/rudder-server/backend-config" "github.com/rudderlabs/rudder-server/runner" + th "github.com/rudderlabs/rudder-server/testhelper" "github.com/rudderlabs/rudder-server/testhelper/health" - "github.com/rudderlabs/rudder-server/warehouse/client" - - "github.com/rudderlabs/rudder-server/warehouse/integrations/testhelper" - - "github.com/stretchr/testify/require" - + "github.com/rudderlabs/rudder-server/testhelper/workspaceConfig" "github.com/rudderlabs/rudder-server/utils/misc" - "github.com/rudderlabs/rudder-server/warehouse/validations" - + "github.com/rudderlabs/rudder-server/warehouse/client" + "github.com/rudderlabs/rudder-server/warehouse/integrations/postgres" + whth "github.com/rudderlabs/rudder-server/warehouse/integrations/testhelper" + "github.com/rudderlabs/rudder-server/warehouse/integrations/tunnelling" + mockuploader "github.com/rudderlabs/rudder-server/warehouse/internal/mocks/utils" + "github.com/rudderlabs/rudder-server/warehouse/internal/model" warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" + "github.com/rudderlabs/rudder-server/warehouse/validations" ) func TestIntegration(t *testing.T) { @@ -82,9 +74,9 @@ func TestIntegration(t *testing.T) { destType := warehouseutils.POSTGRES - namespace := testhelper.RandSchema(destType) - sourcesNamespace := testhelper.RandSchema(destType) - tunnelledNamespace := testhelper.RandSchema(destType) + namespace := whth.RandSchema(destType) + sourcesNamespace := whth.RandSchema(destType) + tunnelledNamespace := whth.RandSchema(destType) host := "localhost" database := "rudderdb" @@ -143,7 +135,7 @@ func TestIntegration(t *testing.T) { } workspaceConfigPath := workspaceConfig.CreateTempFile(t, "testdata/template.json", templateConfigurations) - testhelper.EnhanceWithDefaultEnvs(t) + whth.EnhanceWithDefaultEnvs(t) t.Setenv("JOBS_DB_PORT", strconv.Itoa(jobsDBPort)) t.Setenv("WAREHOUSE_JOBS_DB_PORT", strconv.Itoa(jobsDBPort)) t.Setenv("MINIO_ACCESS_KEY_ID", accessKeyID) @@ -184,7 +176,7 @@ func TestIntegration(t *testing.T) { require.NoError(t, err) require.NoError(t, db.Ping()) - jobsDB := testhelper.JobsDB(t, jobsDBPort) + jobsDB := whth.JobsDB(t, jobsDBPort) testCases := []struct { name string @@ -193,10 +185,10 @@ func TestIntegration(t *testing.T) { sourceID string destinationID string tables []string - stagingFilesEventsMap testhelper.EventsCountMap - loadFilesEventsMap testhelper.EventsCountMap - tableUploadsEventsMap testhelper.EventsCountMap - warehouseEventsMap testhelper.EventsCountMap + stagingFilesEventsMap whth.EventsCountMap + loadFilesEventsMap whth.EventsCountMap + tableUploadsEventsMap whth.EventsCountMap + warehouseEventsMap whth.EventsCountMap asyncJob bool stagingFilePrefix string }{ @@ -218,10 +210,10 @@ func TestIntegration(t *testing.T) { tables: []string{"tracks", "google_sheet"}, sourceID: sourcesSourceID, destinationID: sourcesDestinationID, - stagingFilesEventsMap: testhelper.SourcesStagingFilesEventsMap(), - loadFilesEventsMap: testhelper.SourcesLoadFilesEventsMap(), - tableUploadsEventsMap: testhelper.SourcesTableUploadsEventsMap(), - warehouseEventsMap: testhelper.SourcesWarehouseEventsMap(), + stagingFilesEventsMap: whth.SourcesStagingFilesEventsMap(), + loadFilesEventsMap: whth.SourcesLoadFilesEventsMap(), + tableUploadsEventsMap: whth.SourcesTableUploadsEventsMap(), + warehouseEventsMap: whth.SourcesWarehouseEventsMap(), asyncJob: true, stagingFilePrefix: "testdata/sources-job", }, @@ -249,7 +241,7 @@ func 
TestIntegration(t *testing.T) { } t.Log("verifying test case 1") - ts1 := testhelper.TestConfig{ + ts1 := whth.TestConfig{ WriteKey: tc.writeKey, Schema: tc.schema, Tables: tc.tables, @@ -268,12 +260,12 @@ func TestIntegration(t *testing.T) { JobRunID: misc.FastUUID().String(), TaskRunID: misc.FastUUID().String(), StagingFilePath: tc.stagingFilePrefix + ".staging-1.json", - UserID: testhelper.GetUserId(destType), + UserID: whth.GetUserId(destType), } ts1.VerifyEvents(t) t.Log("verifying test case 2") - ts2 := testhelper.TestConfig{ + ts2 := whth.TestConfig{ WriteKey: tc.writeKey, Schema: tc.schema, Tables: tc.tables, @@ -293,7 +285,7 @@ func TestIntegration(t *testing.T) { JobRunID: misc.FastUUID().String(), TaskRunID: misc.FastUUID().String(), StagingFilePath: tc.stagingFilePrefix + ".staging-2.json", - UserID: testhelper.GetUserId(destType), + UserID: whth.GetUserId(destType), } if tc.asyncJob { ts2.UserID = ts1.UserID @@ -330,7 +322,7 @@ func TestIntegration(t *testing.T) { require.NoError(t, err) require.NoError(t, db.Ping()) - jobsDB := testhelper.JobsDB(t, jobsDBPort) + jobsDB := whth.JobsDB(t, jobsDBPort) testcases := []struct { name string @@ -339,10 +331,10 @@ func TestIntegration(t *testing.T) { sourceID string destinationID string tables []string - stagingFilesEventsMap testhelper.EventsCountMap - loadFilesEventsMap testhelper.EventsCountMap - tableUploadsEventsMap testhelper.EventsCountMap - warehouseEventsMap testhelper.EventsCountMap + stagingFilesEventsMap whth.EventsCountMap + loadFilesEventsMap whth.EventsCountMap + tableUploadsEventsMap whth.EventsCountMap + warehouseEventsMap whth.EventsCountMap stagingFilePrefix string }{ { @@ -380,7 +372,7 @@ func TestIntegration(t *testing.T) { } t.Log("verifying test case 1") - ts1 := testhelper.TestConfig{ + ts1 := whth.TestConfig{ WriteKey: tc.writeKey, Schema: tc.schema, SourceID: tc.sourceID, @@ -399,12 +391,12 @@ func TestIntegration(t *testing.T) { JobRunID: misc.FastUUID().String(), TaskRunID: misc.FastUUID().String(), StagingFilePath: tc.stagingFilePrefix + ".staging-1.json", - UserID: testhelper.GetUserId(destType), + UserID: whth.GetUserId(destType), } ts1.VerifyEvents(t) t.Log("verifying test case 2") - ts2 := testhelper.TestConfig{ + ts2 := whth.TestConfig{ WriteKey: tc.writeKey, Schema: tc.schema, SourceID: tc.sourceID, @@ -423,7 +415,7 @@ func TestIntegration(t *testing.T) { JobRunID: misc.FastUUID().String(), TaskRunID: misc.FastUUID().String(), StagingFilePath: tc.stagingFilePrefix + ".staging-2.json", - UserID: testhelper.GetUserId(destType), + UserID: whth.GetUserId(destType), } ts2.VerifyEvents(t) }) @@ -459,7 +451,7 @@ func TestIntegration(t *testing.T) { Enabled: true, RevisionID: "29eeuu9kywWsRAybaXcxcnTVEl8", } - testhelper.VerifyConfigurationTest(t, dest) + whth.VerifyConfigurationTest(t, dest) }) t.Run("Load Table", func(t *testing.T) { @@ -469,7 +461,7 @@ func TestIntegration(t *testing.T) { workspaceID = "test_workspace_id" ) - namespace := testhelper.RandSchema(destType) + namespace := whth.RandSchema(destType) schemaInUpload := model.TableSchema{ "test_bool": "boolean", @@ -546,7 +538,7 @@ func TestIntegration(t *testing.T) { t.Run("schema does not exists", func(t *testing.T) { tableName := "schema_not_exists_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := mockUploader(t, loadFiles, 
tableName, schemaInUpload, schemaInWarehouse) @@ -562,7 +554,7 @@ func TestIntegration(t *testing.T) { t.Run("table does not exists", func(t *testing.T) { tableName := "table_not_exists_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := mockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) @@ -582,7 +574,7 @@ func TestIntegration(t *testing.T) { tableName := "merge_test_table" t.Run("without dedup", func(t *testing.T) { - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := mockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) @@ -590,8 +582,11 @@ func TestIntegration(t *testing.T) { c := config.New() c.Set("Warehouse.postgres.EnableSQLStatementExecutionPlanWorkspaceIDs", workspaceID) + mergeWarehouse := th.Clone(t, warehouse) + mergeWarehouse.Destination.Config[string(model.EnableMergeSetting)] = true + pg := postgres.New(c, logger.NOP, memstats.New()) - err := pg.Setup(ctx, warehouse, mockUploader) + err := pg.Setup(ctx, mergeWarehouse, mockUploader) require.NoError(t, err) err = pg.CreateSchema(ctx) @@ -610,7 +605,7 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(0)) require.Equal(t, loadTableStat.RowsUpdated, int64(14)) - records := testhelper.RetrieveRecordsFromWarehouse(t, pg.DB.DB, + records := whth.RetrieveRecordsFromWarehouse(t, pg.DB.DB, fmt.Sprintf(` SELECT id, @@ -629,10 +624,10 @@ func TestIntegration(t *testing.T) { tableName, ), ) - require.Equal(t, records, testhelper.SampleTestRecords()) + require.Equal(t, records, whth.SampleTestRecords()) }) t.Run("with dedup", func(t *testing.T) { - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := mockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) @@ -640,8 +635,11 @@ func TestIntegration(t *testing.T) { c := config.New() c.Set("Warehouse.postgres.EnableSQLStatementExecutionPlanWorkspaceIDs", workspaceID) + mergeWarehouse := th.Clone(t, warehouse) + mergeWarehouse.Destination.Config[string(model.EnableMergeSetting)] = true + pg := postgres.New(config.New(), logger.NOP, memstats.New()) - err := pg.Setup(ctx, warehouse, mockUploader) + err := pg.Setup(ctx, mergeWarehouse, mockUploader) require.NoError(t, err) err = pg.CreateSchema(ctx) @@ -655,7 +653,7 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(0)) require.Equal(t, loadTableStat.RowsUpdated, int64(14)) - records := testhelper.RetrieveRecordsFromWarehouse(t, pg.DB.DB, + records := whth.RetrieveRecordsFromWarehouse(t, pg.DB.DB, fmt.Sprintf(` SELECT id, @@ -674,13 +672,13 @@ func TestIntegration(t *testing.T) { tableName, ), ) - require.Equal(t, records, testhelper.DedupTestRecords()) + require.Equal(t, records, whth.DedupTestRecords()) }) }) t.Run("append", func(t *testing.T) { tableName := "append_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + uploadOutput 
:= whth.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := mockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) @@ -708,7 +706,7 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(14)) require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, pg.DB.DB, + records := whth.RetrieveRecordsFromWarehouse(t, pg.DB.DB, fmt.Sprintf(` SELECT id, @@ -727,7 +725,7 @@ func TestIntegration(t *testing.T) { tableName, ), ) - require.Equal(t, records, testhelper.AppendTestRecords()) + require.Equal(t, records, whth.AppendTestRecords()) }) t.Run("load file does not exists", func(t *testing.T) { tableName := "load_file_not_exists_test_table" @@ -754,7 +752,7 @@ func TestIntegration(t *testing.T) { t.Run("mismatch in number of columns", func(t *testing.T) { tableName := "mismatch_columns_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/mismatch-columns.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/mismatch-columns.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := mockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) @@ -776,7 +774,7 @@ func TestIntegration(t *testing.T) { t.Run("mismatch in schema", func(t *testing.T) { tableName := "mismatch_schema_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/mismatch-schema.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/mismatch-schema.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := mockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) @@ -798,7 +796,7 @@ func TestIntegration(t *testing.T) { t.Run("discards", func(t *testing.T) { tableName := warehouseutils.DiscardsTable - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/discards.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/discards.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := mockUploader(t, loadFiles, tableName, warehouseutils.DiscardsSchema, warehouseutils.DiscardsSchema) @@ -818,7 +816,7 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(6)) require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, pg.DB.DB, + records := whth.RetrieveRecordsFromWarehouse(t, pg.DB.DB, fmt.Sprintf(` SELECT column_name, @@ -835,7 +833,7 @@ func TestIntegration(t *testing.T) { tableName, ), ) - require.Equal(t, records, testhelper.DiscardTestRecords()) + require.Equal(t, records, whth.DiscardTestRecords()) }) }) } @@ -855,6 +853,7 @@ func mockUploader( mockUploader.EXPECT().GetLoadFilesMetadata(gomock.Any(), gomock.Any()).Return(loadFiles, nil).AnyTimes() // Try removing this mockUploader.EXPECT().GetTableSchemaInUpload(tableName).Return(schemaInUpload).AnyTimes() mockUploader.EXPECT().GetTableSchemaInWarehouse(tableName).Return(schemaInWarehouse).AnyTimes() + mockUploader.EXPECT().CanAppend().Return(true).AnyTimes() return mockUploader } diff --git a/warehouse/integrations/redshift/redshift.go b/warehouse/integrations/redshift/redshift.go index 39ea0ecf1c..fedf6a1529 100644 --- a/warehouse/integrations/redshift/redshift.go +++ 
b/warehouse/integrations/redshift/redshift.go @@ -157,6 +157,7 @@ type Redshift struct { stats stats.Stats config struct { + allowMerge bool slowQueryThreshold time.Duration dedupWindow bool dedupWindowInHours time.Duration @@ -180,7 +181,7 @@ type s3Manifest struct { Entries []s3ManifestEntry `json:"entries"` } -type RedshiftCredentials struct { +type connectionCredentials struct { Host string Port string DbName string @@ -196,6 +197,7 @@ func New(conf *config.Config, log logger.Logger, stat stats.Stats) *Redshift { rs.logger = log.Child("integrations").Child("redshift") rs.stats = stat + rs.config.allowMerge = conf.GetBool("Warehouse.redshift.allowMerge", true) rs.config.dedupWindow = conf.GetBool("Warehouse.redshift.dedupWindow", false) rs.config.dedupWindowInHours = conf.GetDuration("Warehouse.redshift.dedupWindowInHours", 720, time.Hour) rs.config.skipDedupDestinationIDs = conf.GetStringSlice("Warehouse.redshift.skipDedupDestinationIDs", nil) @@ -300,7 +302,8 @@ func (rs *Redshift) AddColumns(ctx context.Context, tableName string, columnsInf func CheckAndIgnoreColumnAlreadyExistError(err error) bool { if err != nil { - if e, ok := err.(*pq.Error); ok { + var e *pq.Error + if errors.As(err, &e) { if e.Code == "42701" { return true } @@ -315,10 +318,10 @@ func (rs *Redshift) DeleteBy(ctx context.Context, tableNames []string, params wa rs.logger.Infof("RS: Flag for enableDeleteByJobs is %t", rs.config.enableDeleteByJobs) for _, tb := range tableNames { sqlStatement := fmt.Sprintf(`DELETE FROM "%[1]s"."%[2]s" WHERE - context_sources_job_run_id <> $1 AND - context_sources_task_run_id <> $2 AND - context_source_id = $3 AND - received_at < $4`, + context_sources_job_run_id <> $1 AND + context_sources_task_run_id <> $2 AND + context_source_id = $3 AND + received_at < $4`, rs.Namespace, tb, ) @@ -461,7 +464,7 @@ func (rs *Redshift) loadTable( logfield.WorkspaceID, rs.Warehouse.WorkspaceID, logfield.Namespace, rs.Namespace, logfield.TableName, tableName, - logfield.LoadTableStrategy, rs.loadTableStrategy(), + logfield.ShouldMerge, rs.shouldMerge(), ) log.Infow("started loading") @@ -516,7 +519,7 @@ func (rs *Redshift) loadTable( } var rowsDeleted int64 - if !slices.Contains(rs.config.skipDedupDestinationIDs, rs.Warehouse.Destination.ID) { + if rs.shouldMerge() { log.Infow("deleting from load table") rowsDeleted, err = rs.deleteFromLoadTable( ctx, txn, tableName, @@ -549,13 +552,6 @@ func (rs *Redshift) loadTable( }, stagingTableName, nil } -func (rs *Redshift) loadTableStrategy() string { - if slices.Contains(rs.config.skipDedupDestinationIDs, rs.Warehouse.Destination.ID) { - return "APPEND" - } - return "MERGE" -} - func (rs *Redshift) copyIntoLoadTable( ctx context.Context, txn *sqlmiddleware.Tx, @@ -579,15 +575,13 @@ func (rs *Redshift) copyIntoLoadTable( var copyStmt string if rs.Uploader.GetLoadFileType() == warehouseutils.LoadFileTypeParquet { - copyStmt = fmt.Sprintf(` - COPY %s - FROM - '%s' + copyStmt = fmt.Sprintf( + `COPY %s + FROM '%s' ACCESS_KEY_ID '%s' SECRET_ACCESS_KEY '%s' SESSION_TOKEN '%s' - MANIFEST FORMAT PARQUET; - `, + MANIFEST FORMAT PARQUET;`, fmt.Sprintf(`%q.%q`, rs.Namespace, stagingTableName), manifestS3Location, tempAccessKeyId, @@ -595,10 +589,9 @@ func (rs *Redshift) copyIntoLoadTable( token, ) } else { - copyStmt = fmt.Sprintf(` - COPY %s(%s) - FROM - '%s' + copyStmt = fmt.Sprintf( + `COPY %s(%s) + FROM '%s' CSV GZIP ACCESS_KEY_ID '%s' SECRET_ACCESS_KEY '%s' @@ -608,8 +601,7 @@ func (rs *Redshift) copyIntoLoadTable( TIMEFORMAT 'auto' MANIFEST TRUNCATECOLUMNS 
EMPTYASNULL BLANKSASNULL FILLRECORD ACCEPTANYDATE TRIMBLANKS ACCEPTINVCHARS COMPUPDATE OFF - STATUPDATE OFF; - `, + STATUPDATE OFF;`, fmt.Sprintf(`%q.%q`, rs.Namespace, stagingTableName), sortedColumnNames, manifestS3Location, @@ -638,14 +630,10 @@ func (rs *Redshift) deleteFromLoadTable( primaryKey = column } - deleteStmt := fmt.Sprintf(` - DELETE FROM - %[1]s.%[2]q - USING - %[1]s.%[3]q _source - WHERE - _source.%[4]s = %[1]s.%[2]q.%[4]s -`, + deleteStmt := fmt.Sprintf( + `DELETE FROM %[1]s.%[2]q + USING %[1]s.%[3]q _source + WHERE _source.%[4]s = %[1]s.%[2]q.%[4]s`, rs.Namespace, tableName, stagingTableName, @@ -653,9 +641,8 @@ func (rs *Redshift) deleteFromLoadTable( ) if rs.config.dedupWindow { if _, ok := tableSchemaAfterUpload["received_at"]; ok { - deleteStmt += fmt.Sprintf(` - AND %[1]s.%[2]q.received_at > GETDATE() - INTERVAL '%[3]d HOUR' -`, + deleteStmt += fmt.Sprintf( + ` AND %[1]s.%[2]q.received_at > GETDATE() - INTERVAL '%[3]d HOUR'`, rs.Namespace, tableName, rs.config.dedupWindowInHours/time.Hour, @@ -663,10 +650,8 @@ func (rs *Redshift) deleteFromLoadTable( } } if tableName == warehouseutils.DiscardsTable { - deleteStmt += fmt.Sprintf(` - AND _source.%[3]s = %[1]s.%[2]q.%[3]s - AND _source.%[4]s = %[1]s.%[2]q.%[4]s -`, + deleteStmt += fmt.Sprintf( + ` AND _source.%[3]s = %[1]s.%[2]q.%[3]s AND _source.%[4]s = %[1]s.%[2]q.%[4]s`, rs.Namespace, tableName, "table_name", @@ -697,10 +682,9 @@ func (rs *Redshift) insertIntoLoadTable( sortedColumnKeys, ) - insertStmt := fmt.Sprintf(` - INSERT INTO %[1]q.%[2]q (%[3]s) - SELECT - %[3]s + insertStmt := fmt.Sprintf( + `INSERT INTO %[1]q.%[2]q (%[3]s) + SELECT %[3]s FROM ( SELECT @@ -710,12 +694,9 @@ func (rs *Redshift) insertIntoLoadTable( ORDER BY received_at DESC ) AS _rudder_staging_row_number - FROM - %[1]q.%[4]q + FROM %[1]q.%[4]q ) AS _ - WHERE - _rudder_staging_row_number = 1; -`, + WHERE _rudder_staging_row_number = 1;`, rs.Namespace, tableName, quotedColumnNames, @@ -740,14 +721,17 @@ func (rs *Redshift) loadUserTables(ctx context.Context) map[string]error { firstValProps []string ) - rs.logger.Infow("started loading for identifies and users tables", + logFields := []any{ logfield.SourceID, rs.Warehouse.Source.ID, logfield.SourceType, rs.Warehouse.Source.SourceDefinition.Name, logfield.DestinationID, rs.Warehouse.Destination.ID, logfield.DestinationType, rs.Warehouse.Destination.DestinationDefinition.Name, logfield.WorkspaceID, rs.Warehouse.WorkspaceID, logfield.Namespace, rs.Namespace, - ) + logfield.ShouldMerge, rs.shouldMerge(), + logfield.TableName, warehouseutils.UsersTable, + } + rs.logger.Infow("started loading for identifies and users tables", logFields...) 
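As a reading aid for the surrounding hunks, which replace the per-warehouse load-table-strategy flags with a single append-vs-merge decision (surfaced in logs via logfield.ShouldMerge): below is a minimal, self-contained sketch of the rule that the shouldMerge() helpers in this patch implement for Postgres and Redshift. The standalone function and its parameter names are illustrative only, not code from the change; the config keys it refers to (Warehouse.<dest>.allowMerge, the destination's enableMerge setting, skipDedupDestinationIDs) are the ones introduced or used in the diff.

package main

import "fmt"

// shouldMerge mirrors the decision in the patch's shouldMerge() helpers:
// always merge when the uploader cannot safely append; otherwise merge only
// when merging is allowed by server config (Warehouse.<dest>.allowMerge),
// enabled on the destination (enableMerge), and the destination is not in
// skipDedupDestinationIDs.
func shouldMerge(canAppend, allowMerge, enableMerge, skipDedup bool) bool {
	return !canAppend || (allowMerge && enableMerge && !skipDedup)
}

func main() {
	fmt.Println(shouldMerge(false, true, false, false)) // true: uploader cannot append, fall back to merge
	fmt.Println(shouldMerge(true, true, false, false))  // false: append mode
	fmt.Println(shouldMerge(true, true, true, false))   // true: merge enabled on the destination
	fmt.Println(shouldMerge(true, true, true, true))    // false: destination listed in skipDedupDestinationIDs
}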
_, identifyStagingTable, err = rs.loadTable(ctx, warehouseutils.IdentifiesTable, rs.Uploader.GetTableSchemaInUpload(warehouseutils.IdentifiesTable), rs.Uploader.GetTableSchemaInWarehouse(warehouseutils.IdentifiesTable), true) if err != nil { @@ -791,15 +775,12 @@ func (rs *Redshift) loadUserTables(ctx context.Context) map[string]error { stagingTableName := warehouseutils.StagingTableName(provider, warehouseutils.UsersTable, tableNameLimit) - query = fmt.Sprintf(` - CREATE TABLE %[1]q.%[2]q AS ( - SELECT - DISTINCT * + query = fmt.Sprintf( + `CREATE TABLE %[1]q.%[2]q AS ( + SELECT DISTINCT * FROM ( - SELECT - id, - %[3]s + SELECT id, %[3]s FROM ( ( @@ -820,18 +801,13 @@ func (rs *Redshift) loadUserTables(ctx context.Context) map[string]error { ) UNION ( - SELECT - user_id, - %[6]s - FROM - %[1]q.%[5]q - WHERE - user_id IS NOT NULL + SELECT user_id, %[6]s + FROM %[1]q.%[5]q + WHERE user_id IS NOT NULL ) ) ) - ); -`, + );`, rs.Namespace, stagingTableName, strings.Join(firstValProps, ","), @@ -849,16 +825,7 @@ func (rs *Redshift) loadUserTables(ctx context.Context) map[string]error { if _, err = txn.ExecContext(ctx, query); err != nil { _ = txn.Rollback() - rs.logger.Warnw("creating staging table for users", - logfield.SourceID, rs.Warehouse.Source.ID, - logfield.SourceType, rs.Warehouse.Source.SourceDefinition.Name, - logfield.DestinationID, rs.Warehouse.Destination.ID, - logfield.DestinationType, rs.Warehouse.Destination.DestinationDefinition.Name, - logfield.WorkspaceID, rs.Warehouse.WorkspaceID, - logfield.Namespace, rs.Namespace, - logfield.TableName, warehouseutils.UsersTable, - logfield.Error, err.Error(), - ) + rs.logger.Warnw("creating staging table for users", append(logFields, logfield.Error, err.Error())...) return map[string]error{ warehouseutils.IdentifiesTable: nil, warehouseutils.UsersTable: fmt.Errorf("creating staging table for users: %w", err), @@ -866,77 +833,45 @@ func (rs *Redshift) loadUserTables(ctx context.Context) map[string]error { } defer rs.dropStagingTables(ctx, []string{stagingTableName}) - primaryKey := "id" - query = fmt.Sprintf(` - DELETE FROM - %[1]s.%[2]q USING %[1]s.%[3]q _source - WHERE - ( - _source.%[4]s = %[1]s.%[2]s.%[4]s - ); -`, - rs.Namespace, - warehouseutils.UsersTable, - stagingTableName, - primaryKey, - ) + if rs.shouldMerge() { + primaryKey := "id" + query = fmt.Sprintf(`DELETE FROM %[1]s.%[2]q USING %[1]s.%[3]q _source + WHERE _source.%[4]s = %[1]s.%[2]s.%[4]s;`, + rs.Namespace, + warehouseutils.UsersTable, + stagingTableName, + primaryKey, + ) - if _, err = txn.ExecContext(ctx, query); err != nil { - _ = txn.Rollback() + if _, err = txn.ExecContext(ctx, query); err != nil { + _ = txn.Rollback() - rs.logger.Warnw("deleting from users table for dedup", - logfield.SourceID, rs.Warehouse.Source.ID, - logfield.SourceType, rs.Warehouse.Source.SourceDefinition.Name, - logfield.DestinationID, rs.Warehouse.Destination.ID, - logfield.DestinationType, rs.Warehouse.Destination.DestinationDefinition.Name, - logfield.WorkspaceID, rs.Warehouse.WorkspaceID, - logfield.Namespace, rs.Namespace, - logfield.Query, query, - logfield.TableName, warehouseutils.UsersTable, - logfield.Error, err.Error(), - ) - return map[string]error{ - warehouseutils.UsersTable: fmt.Errorf("deleting from main table for dedup: %w", normalizeError(err)), + rs.logger.Warnw("deleting from users table for dedup", append(logFields, + logfield.Query, query, + logfield.Error, err.Error(), + )...) 
+ return map[string]error{ + warehouseutils.UsersTable: fmt.Errorf("deleting from main table for dedup: %w", normalizeError(err)), + } } } - query = fmt.Sprintf(` - INSERT INTO %[1]q.%[2]q (%[4]s) - SELECT - %[4]s - FROM - %[1]q.%[3]q; -`, + query = fmt.Sprintf( + `INSERT INTO %[1]q.%[2]q (%[4]s) + SELECT %[4]s + FROM %[1]q.%[3]q;`, rs.Namespace, warehouseutils.UsersTable, stagingTableName, warehouseutils.DoubleQuoteAndJoinByComma(append([]string{"id"}, userColNames...)), ) - rs.logger.Infow("inserting into users table", - logfield.SourceID, rs.Warehouse.Source.ID, - logfield.SourceType, rs.Warehouse.Source.SourceDefinition.Name, - logfield.DestinationID, rs.Warehouse.Destination.ID, - logfield.DestinationType, rs.Warehouse.Destination.DestinationDefinition.Name, - logfield.WorkspaceID, rs.Warehouse.WorkspaceID, - logfield.Namespace, rs.Namespace, - logfield.TableName, warehouseutils.UsersTable, - logfield.Query, query, - ) + rs.logger.Infow("inserting into users table", append(logFields, logfield.Query, query)...) if _, err = txn.ExecContext(ctx, query); err != nil { _ = txn.Rollback() - rs.logger.Warnw("failed inserting into users table", - logfield.SourceID, rs.Warehouse.Source.ID, - logfield.SourceType, rs.Warehouse.Source.SourceDefinition.Name, - logfield.DestinationID, rs.Warehouse.Destination.ID, - logfield.DestinationType, rs.Warehouse.Destination.DestinationDefinition.Name, - logfield.WorkspaceID, rs.Warehouse.WorkspaceID, - logfield.Namespace, rs.Namespace, - logfield.TableName, warehouseutils.UsersTable, - logfield.Error, err.Error(), - ) + rs.logger.Warnw("failed inserting into users table", append(logFields, logfield.Error, err.Error())...) return map[string]error{ warehouseutils.IdentifiesTable: nil, @@ -947,16 +882,7 @@ func (rs *Redshift) loadUserTables(ctx context.Context) map[string]error { if err = txn.Commit(); err != nil { _ = txn.Rollback() - rs.logger.Warnw("committing transaction for user table", - logfield.SourceID, rs.Warehouse.Source.ID, - logfield.SourceType, rs.Warehouse.Source.SourceDefinition.Name, - logfield.DestinationID, rs.Warehouse.Destination.ID, - logfield.DestinationType, rs.Warehouse.Destination.DestinationDefinition.Name, - logfield.WorkspaceID, rs.Warehouse.WorkspaceID, - logfield.Namespace, rs.Namespace, - logfield.TableName, warehouseutils.UsersTable, - logfield.Error, err.Error(), - ) + rs.logger.Warnw("committing transaction for user table", append(logFields, logfield.Error, err.Error())...) return map[string]error{ warehouseutils.IdentifiesTable: nil, @@ -964,14 +890,7 @@ func (rs *Redshift) loadUserTables(ctx context.Context) map[string]error { } } - rs.logger.Infow("completed loading for users and identifies tables", - logfield.SourceID, rs.Warehouse.Source.ID, - logfield.SourceType, rs.Warehouse.Source.SourceDefinition.Name, - logfield.DestinationID, rs.Warehouse.Destination.ID, - logfield.DestinationType, rs.Warehouse.Destination.DestinationDefinition.Name, - logfield.WorkspaceID, rs.Warehouse.WorkspaceID, - logfield.Namespace, rs.Namespace, - ) + rs.logger.Infow("completed loading for users and identifies tables", logFields...) 
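One more recurring pattern in the redshift.go hunks (above in CheckAndIgnoreColumnAlreadyExistError, and below in AlterColumn and normalizeError): bare type assertions on *pq.Error are replaced with errors.As so that wrapped errors are still recognized. A small, self-contained sketch of that idiom follows; it assumes lib/pq's exported Error type, and the helper name is illustrative, not part of the change.

package main

import (
	"errors"
	"fmt"

	"github.com/lib/pq"
)

// isColumnAlreadyExists reports whether err is (or wraps) a Postgres/Redshift
// "duplicate_column" error, SQLSTATE 42701 — the same check the patched
// CheckAndIgnoreColumnAlreadyExistError performs via errors.As.
func isColumnAlreadyExists(err error) bool {
	var pqErr *pq.Error
	return errors.As(err, &pqErr) && pqErr.Code == "42701"
}

func main() {
	wrapped := fmt.Errorf("adding column: %w", &pq.Error{Code: "42701"})
	fmt.Println(isColumnAlreadyExists(wrapped)) // true: errors.As walks the wrap chain
	// A bare type assertion, err.(*pq.Error), would not match the wrapped error,
	// which is why these checks are migrated to errors.As in the patch.
}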
return map[string]error{ warehouseutils.IdentifiesTable: nil, @@ -1012,8 +931,7 @@ func (rs *Redshift) connect(ctx context.Context) (*sqlmiddleware.DB, error) { } } - stmt := `SET query_group to 'RudderStack'` - _, err = db.ExecContext(ctx, stmt) + _, err = db.ExecContext(ctx, `SET query_group to 'RudderStack'`) if err != nil { return nil, fmt.Errorf("redshift set query_group error : %v", err) } @@ -1040,15 +958,9 @@ func (rs *Redshift) connect(ctx context.Context) (*sqlmiddleware.DB, error) { } func (rs *Redshift) dropDanglingStagingTables(ctx context.Context) { - sqlStatement := ` - SELECT - table_name - FROM - information_schema.tables - WHERE - table_schema = $1 AND - table_name like $2; - ` + sqlStatement := `SELECT table_name + FROM information_schema.tables + WHERE table_schema = $1 AND table_name like $2;` rows, err := rs.DB.QueryContext(ctx, sqlStatement, rs.Namespace, @@ -1121,12 +1033,7 @@ func (rs *Redshift) AlterColumn(ctx context.Context, tableName, columnName, colu // creating staging column stagingColumnType = getRSDataType(columnType) stagingColumnName = fmt.Sprintf(`%s-staging-%s`, columnName, misc.FastUUID().String()) - query = fmt.Sprintf(` - ALTER TABLE - %q.%q - ADD - COLUMN %q %s; - `, + query = fmt.Sprintf(`ALTER TABLE %q.%q ADD COLUMN %q %s;`, rs.Namespace, tableName, stagingColumnName, @@ -1137,14 +1044,10 @@ func (rs *Redshift) AlterColumn(ctx context.Context, tableName, columnName, colu } // populating staging column - query = fmt.Sprintf(` - UPDATE - %[1]q.%[2]q - SET - %[3]q = CAST (%[4]q AS %[5]s) - WHERE - %[4]q IS NOT NULL; - `, + query = fmt.Sprintf( + `UPDATE %[1]q.%[2]q + SET %[3]q = CAST (%[4]q AS %[5]s) + WHERE %[4]q IS NOT NULL;`, rs.Namespace, tableName, stagingColumnName, @@ -1157,12 +1060,8 @@ func (rs *Redshift) AlterColumn(ctx context.Context, tableName, columnName, colu // renaming original column to deprecated column deprecatedColumnName = fmt.Sprintf(`%s-deprecated-%s`, columnName, misc.FastUUID().String()) - query = fmt.Sprintf(` - ALTER TABLE - %[1]q.%[2]q - RENAME COLUMN - %[3]q TO %[4]q; - `, + query = fmt.Sprintf( + `ALTER TABLE %[1]q.%[2]q RENAME COLUMN %[3]q TO %[4]q;`, rs.Namespace, tableName, columnName, @@ -1173,12 +1072,8 @@ func (rs *Redshift) AlterColumn(ctx context.Context, tableName, columnName, colu } // renaming staging column to original column - query = fmt.Sprintf(` - ALTER TABLE - %[1]q.%[2]q - RENAME COLUMN - %[3]q TO %[4]q; - `, + query = fmt.Sprintf( + `ALTER TABLE %[1]q.%[2]q RENAME COLUMN %[3]q TO %[4]q;`, rs.Namespace, tableName, stagingColumnName, @@ -1195,20 +1090,17 @@ func (rs *Redshift) AlterColumn(ctx context.Context, tableName, columnName, colu // dropping deprecated column // Since dropping the column can fail, we need to do it outside the transaction - // Because if will fail during the commit of the transaction + // Because if it will fail during the commit of the transaction // https://github.com/lib/pq/blob/d5affd5073b06f745459768de35356df2e5fd91d/conn.go#L600 - query = fmt.Sprintf(` - ALTER TABLE - %[1]q.%[2]q - DROP COLUMN - %[3]q; - `, + query = fmt.Sprintf( + `ALTER TABLE %[1]q.%[2]q DROP COLUMN %[3]q;`, rs.Namespace, tableName, deprecatedColumnName, ) if _, err = rs.DB.ExecContext(ctx, query); err != nil { - if pqError, ok := err.(*pq.Error); !ok || pqError.Code != "2BP01" { + var pqError *pq.Error + if !errors.As(err, &pqError) || pqError.Code != "2BP01" { return model.AlterTableResponse{}, fmt.Errorf("drop deprecated column: %w", err) } @@ -1228,8 +1120,8 @@ func (rs *Redshift) AlterColumn(ctx 
context.Context, tableName, columnName, colu return res, nil } -func (rs *Redshift) getConnectionCredentials() RedshiftCredentials { - creds := RedshiftCredentials{ +func (rs *Redshift) getConnectionCredentials() connectionCredentials { + creds := connectionCredentials{ Host: warehouseutils.GetConfigValue(RSHost, rs.Warehouse), Port: warehouseutils.GetConfigValue(RSPort, rs.Warehouse), DbName: warehouseutils.GetConfigValue(RSDbName, rs.Warehouse), @@ -1247,18 +1139,13 @@ func (rs *Redshift) FetchSchema(ctx context.Context) (model.Schema, model.Schema schema := make(model.Schema) unrecognizedSchema := make(model.Schema) - sqlStatement := ` - SELECT + sqlStatement := `SELECT table_name, column_name, data_type, character_maximum_length - FROM - INFORMATION_SCHEMA.COLUMNS - WHERE - table_schema = $1 - and table_name not like $2; - ` + FROM INFORMATION_SCHEMA.COLUMNS + WHERE table_schema = $1 and table_name not like $2;` rows, err := rs.DB.QueryContext( ctx, @@ -1439,12 +1326,20 @@ func (rs *Redshift) SetConnectionTimeout(timeout time.Duration) { rs.connectTimeout = timeout } +func (rs *Redshift) shouldMerge() bool { + return !rs.Uploader.CanAppend() || + (rs.config.allowMerge && + rs.Warehouse.GetBoolDestinationConfig(model.EnableMergeSetting) && + !slices.Contains(rs.config.skipDedupDestinationIDs, rs.Warehouse.Destination.ID)) +} + func (*Redshift) ErrorMappings() []model.JobError { return errorsMappings } func normalizeError(err error) error { - if pqErr, ok := err.(*pq.Error); ok { + var pqErr *pq.Error + if errors.As(err, &pqErr) { return fmt.Errorf("pq: message: %s, detail: %s", pqErr.Message, pqErr.Detail, diff --git a/warehouse/integrations/redshift/redshift_test.go b/warehouse/integrations/redshift/redshift_test.go index 87ce47ce6f..abdcf72fe2 100644 --- a/warehouse/integrations/redshift/redshift_test.go +++ b/warehouse/integrations/redshift/redshift_test.go @@ -13,45 +13,33 @@ import ( "testing" "time" - "github.com/rudderlabs/rudder-go-kit/stats/memstats" - "github.com/golang/mock/gomock" - - "github.com/rudderlabs/rudder-go-kit/filemanager" - mockuploader "github.com/rudderlabs/rudder-server/warehouse/internal/mocks/utils" - "github.com/rudderlabs/rudder-server/warehouse/internal/model" - - "github.com/rudderlabs/compose-test/compose" - "github.com/rudderlabs/rudder-go-kit/logger" - - "github.com/rudderlabs/rudder-server/warehouse/integrations/redshift" - - "github.com/rudderlabs/rudder-server/testhelper/workspaceConfig" - - sqlmiddleware "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" - "github.com/lib/pq" "github.com/ory/dockertest/v3" + "github.com/stretchr/testify/require" + "github.com/rudderlabs/compose-test/compose" "github.com/rudderlabs/compose-test/testcompose" + "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/filemanager" + "github.com/rudderlabs/rudder-go-kit/logger" + "github.com/rudderlabs/rudder-go-kit/stats/memstats" kithelper "github.com/rudderlabs/rudder-go-kit/testhelper" + "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource" + backendconfig "github.com/rudderlabs/rudder-server/backend-config" "github.com/rudderlabs/rudder-server/runner" + th "github.com/rudderlabs/rudder-server/testhelper" "github.com/rudderlabs/rudder-server/testhelper/health" - - "github.com/rudderlabs/rudder-go-kit/config" - "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource" - - "github.com/rudderlabs/rudder-server/warehouse/integrations/testhelper" - + 
"github.com/rudderlabs/rudder-server/testhelper/workspaceConfig" "github.com/rudderlabs/rudder-server/utils/misc" - "github.com/rudderlabs/rudder-server/warehouse/validations" - - backendconfig "github.com/rudderlabs/rudder-server/backend-config" - - "github.com/stretchr/testify/require" - "github.com/rudderlabs/rudder-server/warehouse/client" + sqlmiddleware "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" + "github.com/rudderlabs/rudder-server/warehouse/integrations/redshift" + whth "github.com/rudderlabs/rudder-server/warehouse/integrations/testhelper" + mockuploader "github.com/rudderlabs/rudder-server/warehouse/internal/mocks/utils" + "github.com/rudderlabs/rudder-server/warehouse/internal/model" warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" + "github.com/rudderlabs/rudder-server/warehouse/validations" ) type testCredentials struct { @@ -116,8 +104,8 @@ func TestIntegration(t *testing.T) { destType := warehouseutils.RS - namespace := testhelper.RandSchema(destType) - sourcesNamespace := testhelper.RandSchema(destType) + namespace := whth.RandSchema(destType) + sourcesNamespace := whth.RandSchema(destType) rsTestCredentials, err := rsTestCredentials() require.NoError(t, err) @@ -143,7 +131,7 @@ func TestIntegration(t *testing.T) { } workspaceConfigPath := workspaceConfig.CreateTempFile(t, "testdata/template.json", templateConfigurations) - testhelper.EnhanceWithDefaultEnvs(t) + whth.EnhanceWithDefaultEnvs(t) t.Setenv("JOBS_DB_PORT", strconv.Itoa(jobsDBPort)) t.Setenv("WAREHOUSE_JOBS_DB_PORT", strconv.Itoa(jobsDBPort)) t.Setenv("RSERVER_WAREHOUSE_WEB_PORT", strconv.Itoa(httpPort)) @@ -184,7 +172,7 @@ func TestIntegration(t *testing.T) { t.Setenv("RSERVER_WAREHOUSE_REDSHIFT_DEDUP_WINDOW", "true") t.Setenv("RSERVER_WAREHOUSE_REDSHIFT_DEDUP_WINDOW_IN_HOURS", "5") - jobsDB := testhelper.JobsDB(t, jobsDBPort) + jobsDB := whth.JobsDB(t, jobsDBPort) testcase := []struct { name string @@ -193,10 +181,10 @@ func TestIntegration(t *testing.T) { sourceID string destinationID string tables []string - stagingFilesEventsMap testhelper.EventsCountMap - loadFilesEventsMap testhelper.EventsCountMap - tableUploadsEventsMap testhelper.EventsCountMap - warehouseEventsMap testhelper.EventsCountMap + stagingFilesEventsMap whth.EventsCountMap + loadFilesEventsMap whth.EventsCountMap + tableUploadsEventsMap whth.EventsCountMap + warehouseEventsMap whth.EventsCountMap asyncJob bool stagingFilePrefix string }{ @@ -216,10 +204,10 @@ func TestIntegration(t *testing.T) { tables: []string{"tracks", "google_sheet"}, sourceID: sourcesSourceID, destinationID: sourcesDestinationID, - stagingFilesEventsMap: testhelper.SourcesStagingFilesEventsMap(), - loadFilesEventsMap: testhelper.SourcesLoadFilesEventsMap(), - tableUploadsEventsMap: testhelper.SourcesTableUploadsEventsMap(), - warehouseEventsMap: testhelper.SourcesWarehouseEventsMap(), + stagingFilesEventsMap: whth.SourcesStagingFilesEventsMap(), + loadFilesEventsMap: whth.SourcesLoadFilesEventsMap(), + tableUploadsEventsMap: whth.SourcesTableUploadsEventsMap(), + warehouseEventsMap: whth.SourcesWarehouseEventsMap(), asyncJob: true, stagingFilePrefix: "testdata/sources-job", }, @@ -258,7 +246,7 @@ func TestIntegration(t *testing.T) { } t.Log("verifying test case 1") - ts1 := testhelper.TestConfig{ + ts1 := whth.TestConfig{ WriteKey: tc.writeKey, Schema: tc.schema, Tables: tc.tables, @@ -277,12 +265,12 @@ func TestIntegration(t *testing.T) { JobRunID: misc.FastUUID().String(), TaskRunID: 
misc.FastUUID().String(), StagingFilePath: tc.stagingFilePrefix + ".staging-1.json", - UserID: testhelper.GetUserId(destType), + UserID: whth.GetUserId(destType), } ts1.VerifyEvents(t) t.Log("verifying test case 2") - ts2 := testhelper.TestConfig{ + ts2 := whth.TestConfig{ WriteKey: tc.writeKey, Schema: tc.schema, Tables: tc.tables, @@ -302,7 +290,7 @@ func TestIntegration(t *testing.T) { JobRunID: misc.FastUUID().String(), TaskRunID: misc.FastUUID().String(), StagingFilePath: tc.stagingFilePrefix + ".staging-1.json", - UserID: testhelper.GetUserId(destType), + UserID: whth.GetUserId(destType), } if tc.asyncJob { ts2.UserID = ts1.UserID @@ -351,7 +339,7 @@ func TestIntegration(t *testing.T) { Enabled: true, RevisionID: "29HgOWobrn0RYZLpaSwPIbN2987", } - testhelper.VerifyConfigurationTest(t, dest) + whth.VerifyConfigurationTest(t, dest) }) t.Run("Load Table", func(t *testing.T) { @@ -361,7 +349,7 @@ func TestIntegration(t *testing.T) { workspaceID = "test_workspace_id" ) - namespace := testhelper.RandSchema(destType) + namespace := whth.RandSchema(destType) t.Cleanup(func() { require.Eventually(t, func() bool { @@ -442,7 +430,7 @@ func TestIntegration(t *testing.T) { t.Run("schema does not exists", func(t *testing.T) { tableName := "schema_not_exists_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv) @@ -458,7 +446,7 @@ func TestIntegration(t *testing.T) { t.Run("table does not exists", func(t *testing.T) { tableName := "table_not_exists_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv) @@ -475,10 +463,9 @@ func TestIntegration(t *testing.T) { require.Nil(t, loadTableStat) }) t.Run("merge", func(t *testing.T) { - tableName := "merge_test_table" - t.Run("without dedup", func(t *testing.T) { - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + tableName := "merge_without_dedup_test_table" + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv) @@ -500,12 +487,12 @@ func TestIntegration(t *testing.T) { loadTableStat, err = d.LoadTable(ctx, tableName) require.NoError(t, err) - require.Equal(t, loadTableStat.RowsInserted, int64(0)) - require.Equal(t, loadTableStat.RowsUpdated, int64(14)) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, - fmt.Sprintf(` - SELECT + records := whth.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf( + `SELECT id, received_at, test_bool, @@ -513,25 +500,26 @@ func TestIntegration(t *testing.T) { test_float, test_int, test_string - FROM - %s.%s - ORDER BY - id; - `, + FROM %s.%s + ORDER BY id;`, namespace, tableName, ), ) - 
require.Equal(t, records, testhelper.SampleTestRecords()) + require.Equal(t, whth.DedupTwiceTestRecords(), records) }) t.Run("with dedup", func(t *testing.T) { - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) + tableName := "merge_with_dedup_test_table" + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv) + mergeWarehouse := th.Clone(t, warehouse) + mergeWarehouse.Destination.Config[string(model.EnableMergeSetting)] = true + d := redshift.New(config.New(), logger.NOP, memstats.New()) - err := d.Setup(ctx, warehouse, mockUploader) + err := d.Setup(ctx, mergeWarehouse, mockUploader) require.NoError(t, err) err = d.CreateSchema(ctx) @@ -542,12 +530,17 @@ func TestIntegration(t *testing.T) { loadTableStat, err := d.LoadTable(ctx, tableName) require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + loadTableStat, err = d.LoadTable(ctx, tableName) + require.NoError(t, err) require.Equal(t, loadTableStat.RowsInserted, int64(0)) require.Equal(t, loadTableStat.RowsUpdated, int64(14)) - records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, - fmt.Sprintf(` - SELECT + records := whth.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf( + `SELECT id, received_at, test_bool, @@ -555,29 +548,82 @@ func TestIntegration(t *testing.T) { test_float, test_int, test_string - FROM - %s.%s - ORDER BY - id; - `, + FROM %s.%s + ORDER BY id;`, namespace, tableName, ), ) - require.Equal(t, records, testhelper.DedupTestRecords()) + require.Equal(t, whth.DedupTestRecords(), records) }) t.Run("with dedup window", func(t *testing.T) { - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) + tableName := "merge_with_dedup_window_test_table" + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv) + mergeWarehouse := th.Clone(t, warehouse) + mergeWarehouse.Destination.Config[string(model.EnableMergeSetting)] = true + + c := config.New() + c.Set("Warehouse.redshift.dedupWindow", true) + c.Set("Warehouse.redshift.dedupWindowInHours", 999999) + + d := redshift.New(c, logger.NOP, memstats.New()) + err := d.Setup(ctx, mergeWarehouse, mockUploader) + require.NoError(t, err) + + err = d.CreateSchema(ctx) + require.NoError(t, err) + + err = d.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := d.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + loadTableStat, err = d.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(0)) + require.Equal(t, loadTableStat.RowsUpdated, int64(14)) + + records := whth.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf( + `SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM %s.%s + ORDER BY id;`, + namespace, + tableName, + ), + ) + require.Equal(t, whth.DedupTestRecords(), records) + }) + t.Run("with short dedup 
window", func(t *testing.T) { + tableName := "merge_with_short_dedup_window_test_table" + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv) + + mergeWarehouse := th.Clone(t, warehouse) + mergeWarehouse.Destination.Config[string(model.EnableMergeSetting)] = true + c := config.New() c.Set("Warehouse.redshift.dedupWindow", true) c.Set("Warehouse.redshift.dedupWindowInHours", 0) d := redshift.New(c, logger.NOP, memstats.New()) - err := d.Setup(ctx, warehouse, mockUploader) + err := d.Setup(ctx, mergeWarehouse, mockUploader) require.NoError(t, err) err = d.CreateSchema(ctx) @@ -591,9 +637,14 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(14)) require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, - fmt.Sprintf(` - SELECT + loadTableStat, err = d.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + records := whth.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf( + `SELECT id, received_at, test_bool, @@ -601,22 +652,19 @@ func TestIntegration(t *testing.T) { test_float, test_int, test_string - FROM - %s.%s - ORDER BY - id; - `, + FROM %s.%s + ORDER BY id;`, namespace, tableName, ), ) - require.Equal(t, records, testhelper.DedupTwiceTestRecords()) + require.Equal(t, whth.DedupTwiceTestRecords(), records) }) }) t.Run("append", func(t *testing.T) { tableName := "append_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv) @@ -644,7 +692,7 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(14)) require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, rs.DB.DB, + records := whth.RetrieveRecordsFromWarehouse(t, rs.DB.DB, fmt.Sprintf(` SELECT id, @@ -663,7 +711,7 @@ func TestIntegration(t *testing.T) { tableName, ), ) - require.Equal(t, records, testhelper.AppendTestRecords()) + require.Equal(t, records, whth.AppendTestRecords()) }) t.Run("load file does not exists", func(t *testing.T) { tableName := "load_file_not_exists_test_table" @@ -690,7 +738,7 @@ func TestIntegration(t *testing.T) { t.Run("mismatch in number of columns", func(t *testing.T) { tableName := "mismatch_columns_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/mismatch-columns.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/mismatch-columns.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv) @@ -712,7 +760,7 @@ func TestIntegration(t *testing.T) { t.Run("mismatch in schema", func(t *testing.T) { tableName := "mismatch_schema_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/mismatch-schema.csv.gz", tableName) + 
uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/mismatch-schema.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv) @@ -734,7 +782,7 @@ func TestIntegration(t *testing.T) { t.Run("discards", func(t *testing.T) { tableName := warehouseutils.DiscardsTable - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/discards.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/discards.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, warehouseutils.DiscardsSchema, warehouseutils.DiscardsSchema, warehouseutils.LoadFileTypeCsv) @@ -754,7 +802,7 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(6)) require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, rs.DB.DB, + records := whth.RetrieveRecordsFromWarehouse(t, rs.DB.DB, fmt.Sprintf(` SELECT column_name, @@ -771,12 +819,12 @@ func TestIntegration(t *testing.T) { tableName, ), ) - require.Equal(t, records, testhelper.DiscardTestRecords()) + require.Equal(t, records, whth.DiscardTestRecords()) }) t.Run("parquet", func(t *testing.T) { tableName := "parquet_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.parquet", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.parquet", tableName) fileStat, err := os.Stat("../testdata/load.parquet") require.NoError(t, err) @@ -802,7 +850,7 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(14)) require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, rs.DB.DB, + records := whth.RetrieveRecordsFromWarehouse(t, rs.DB.DB, fmt.Sprintf(` SELECT id, @@ -821,7 +869,7 @@ func TestIntegration(t *testing.T) { tableName, ), ) - require.Equal(t, records, testhelper.SampleTestRecords()) + require.Equal(t, records, whth.SampleTestRecords()) }) }) } @@ -990,6 +1038,7 @@ func newMockUploader( mockUploader.EXPECT().GetTableSchemaInUpload(tableName).Return(schemaInUpload).AnyTimes() mockUploader.EXPECT().GetTableSchemaInWarehouse(tableName).Return(schemaInWarehouse).AnyTimes() mockUploader.EXPECT().GetLoadFileType().Return(loadFileType).AnyTimes() + mockUploader.EXPECT().CanAppend().Return(true).AnyTimes() return mockUploader } diff --git a/warehouse/integrations/snowflake/snowflake.go b/warehouse/integrations/snowflake/snowflake.go index eed8ab2aba..d1c707a002 100644 --- a/warehouse/integrations/snowflake/snowflake.go +++ b/warehouse/integrations/snowflake/snowflake.go @@ -13,8 +13,6 @@ import ( "strings" "time" - "github.com/rudderlabs/rudder-server/warehouse/integrations/types" - "github.com/samber/lo" snowflake "github.com/snowflakedb/gosnowflake" @@ -24,6 +22,7 @@ import ( "github.com/rudderlabs/rudder-server/utils/misc" "github.com/rudderlabs/rudder-server/warehouse/client" sqlmw "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" + "github.com/rudderlabs/rudder-server/warehouse/integrations/types" "github.com/rudderlabs/rudder-server/warehouse/internal/model" lf "github.com/rudderlabs/rudder-server/warehouse/logfield" whutils "github.com/rudderlabs/rudder-server/warehouse/utils" @@ -44,9 +43,6 @@ const ( role = "role" password = "password" 
application = "Rudderstack_Warehouse" - - loadTableStrategyMergeMode = "MERGE" - loadTableStrategyAppendMode = "APPEND" ) var primaryKeyMap = map[string]string{ @@ -169,7 +165,7 @@ type Snowflake struct { stats stats.Stats config struct { - loadTableStrategy string + allowMerge bool slowQueryThreshold time.Duration enableDeleteByJobs bool @@ -186,16 +182,7 @@ func New(conf *config.Config, log logger.Logger, stat stats.Stats) (*Snowflake, sf.logger = log.Child("integrations").Child("snowflake") sf.stats = stat - loadTableStrategy := conf.GetString("Warehouse.snowflake.loadTableStrategy", loadTableStrategyMergeMode) - switch loadTableStrategy { - case loadTableStrategyMergeMode, loadTableStrategyAppendMode: - sf.config.loadTableStrategy = loadTableStrategy - default: - return nil, fmt.Errorf("loadTableStrategy out of the known domain [%+v]: %v", - []string{loadTableStrategyMergeMode, loadTableStrategyAppendMode}, loadTableStrategy, - ) - } - + sf.config.allowMerge = conf.GetBool("Warehouse.snowflake.allowMerge", true) sf.config.enableDeleteByJobs = conf.GetBool("Warehouse.snowflake.enableDeleteByJobs", false) sf.config.slowQueryThreshold = conf.GetDuration("Warehouse.snowflake.slowQueryThreshold", 5, time.Minute) sf.config.debugDuplicateWorkspaceIDs = conf.GetStringSlice("Warehouse.snowflake.debugDuplicateWorkspaceIDs", nil) @@ -285,7 +272,8 @@ func (sf *Snowflake) createSchema(ctx context.Context) (err error) { func checkAndIgnoreAlreadyExistError(err error) bool { if err != nil { // TODO: throw error if column already exists but of different type - if e, ok := err.(*snowflake.SnowflakeError); ok && e.SQLState == "42601" { + var e *snowflake.SnowflakeError + if errors.As(err, &e) && e.SQLState == "42601" { return true } return false @@ -312,9 +300,8 @@ func (sf *Snowflake) DeleteBy(ctx context.Context, tableNames []string, params w ) log.Infow("Cleaning up the following tables in snowflake") - sqlStatement := fmt.Sprintf(` - DELETE FROM - %[1]q.%[2]q + sqlStatement := fmt.Sprintf( + `DELETE FROM %[1]q.%[2]q WHERE context_sources_job_run_id <> '%[3]s' AND context_sources_task_run_id <> '%[4]s' AND @@ -360,7 +347,7 @@ func (sf *Snowflake) loadTable( lf.WorkspaceID, sf.Warehouse.WorkspaceID, lf.Namespace, sf.Namespace, lf.TableName, tableName, - lf.LoadTableStrategy, sf.config.loadTableStrategy, + lf.ShouldMerge, sf.ShouldMerge(), ) log.Infow("started loading") @@ -384,7 +371,7 @@ func (sf *Snowflake) loadTable( // Truncating the columns by default to avoid size limitation errors // https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html#copy-options-copyoptions - if sf.ShouldAppend() { + if !sf.ShouldMerge() { log.Infow("copying data into main table") loadTableStats, err := sf.copyInto(ctx, db, schemaIdentifier, tableName, sortedColumnNames, tableName) if err != nil { @@ -551,24 +538,17 @@ func (sf *Snowflake) sampleDuplicateMessages( mainTable := fmt.Sprintf("%s.%q", identifier, mainTableName) stagingTable := fmt.Sprintf("%s.%q", identifier, stagingTableName) - rows, err := db.QueryContext(ctx, ` - SELECT - ID, - RECEIVED_AT - FROM - `+mainTable+` + rows, err := db.QueryContext(ctx, + `SELECT ID, RECEIVED_AT + FROM `+mainTable+` WHERE RECEIVED_AT > (SELECT DATEADD(day,-?,current_date())) AND ID IN ( - SELECT - ID - FROM - `+stagingTable+` + SELECT ID + FROM `+stagingTable+` ) ORDER BY RECEIVED_AT ASC - LIMIT - ?; -`, + LIMIT ?;`, sf.config.debugDuplicateIntervalInDays, sf.config.debugDuplicateLimit, ) @@ -843,11 +823,13 @@ func (sf *Snowflake) LoadIdentityMappingsTable(ctx 
context.Context) error { return nil } -// ShouldAppend returns true if: -// * the load table strategy is "append" mode -// * the uploader says we can append -func (sf *Snowflake) ShouldAppend() bool { - return sf.config.loadTableStrategy == loadTableStrategyAppendMode && sf.Uploader.CanAppend() +// ShouldMerge returns true if: +// * the uploader says we cannot append +// * the server configuration says we can merge +// * the user opted-in +func (sf *Snowflake) ShouldMerge() bool { + return !sf.Uploader.CanAppend() || + (sf.config.allowMerge && sf.Warehouse.GetBoolDestinationConfig(model.EnableMergeSetting)) } func (sf *Snowflake) LoadUserTables(ctx context.Context) map[string]error { @@ -872,7 +854,7 @@ func (sf *Snowflake) LoadUserTables(ctx context.Context) map[string]error { lf.WorkspaceID, sf.Warehouse.WorkspaceID, lf.Namespace, sf.Namespace, lf.TableName, whutils.UsersTable, - lf.LoadTableStrategy, sf.config.loadTableStrategy, + lf.ShouldMerge, sf.ShouldMerge(), ) log.Infow("started loading for identifies and users tables") @@ -889,7 +871,7 @@ func (sf *Snowflake) LoadUserTables(ctx context.Context) map[string]error { } schemaIdentifier := sf.schemaIdentifier() - if sf.ShouldAppend() { + if !sf.ShouldMerge() { tmpIdentifiesStagingTable := whutils.StagingTableName(provider, identifiesTable, tableNameLimit) sqlStatement := fmt.Sprintf( `CREATE TEMPORARY TABLE %[1]s.%[2]q LIKE %[1]s.%[3]q;`, diff --git a/warehouse/integrations/snowflake/snowflake_test.go b/warehouse/integrations/snowflake/snowflake_test.go index b4d8447f1a..6782aac3e1 100644 --- a/warehouse/integrations/snowflake/snowflake_test.go +++ b/warehouse/integrations/snowflake/snowflake_test.go @@ -13,21 +13,17 @@ import ( "testing" "time" - "github.com/rudderlabs/rudder-go-kit/stats/memstats" - - "github.com/samber/lo" - - "github.com/rudderlabs/rudder-go-kit/filemanager" - "github.com/rudderlabs/rudder-server/warehouse/internal/model" - "github.com/golang/mock/gomock" + "github.com/samber/lo" sfdb "github.com/snowflakedb/gosnowflake" "github.com/stretchr/testify/require" "github.com/rudderlabs/compose-test/compose" "github.com/rudderlabs/compose-test/testcompose" "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/filemanager" "github.com/rudderlabs/rudder-go-kit/logger" + "github.com/rudderlabs/rudder-go-kit/stats/memstats" kithelper "github.com/rudderlabs/rudder-go-kit/testhelper" backendconfig "github.com/rudderlabs/rudder-server/backend-config" "github.com/rudderlabs/rudder-server/runner" @@ -38,6 +34,7 @@ import ( "github.com/rudderlabs/rudder-server/warehouse/integrations/snowflake" "github.com/rudderlabs/rudder-server/warehouse/integrations/testhelper" mockuploader "github.com/rudderlabs/rudder-server/warehouse/internal/mocks/utils" + "github.com/rudderlabs/rudder-server/warehouse/internal/model" whutils "github.com/rudderlabs/rudder-server/warehouse/utils" "github.com/rudderlabs/rudder-server/warehouse/validations" ) @@ -133,57 +130,53 @@ func TestIntegration(t *testing.T) { rbacCredentials, err := getSnowflakeTestCredentials(testRBACKey) require.NoError(t, err) - templateConfigurations := map[string]any{ - "workspaceID": workspaceID, - "sourceID": sourceID, - "destinationID": destinationID, - "writeKey": writeKey, - "caseSensitiveSourceID": caseSensitiveSourceID, - "caseSensitiveDestinationID": caseSensitiveDestinationID, - "caseSensitiveWriteKey": caseSensitiveWriteKey, - "rbacSourceID": rbacSourceID, - "rbacDestinationID": rbacDestinationID, - "rbacWriteKey": rbacWriteKey, - 
"sourcesSourceID": sourcesSourceID, - "sourcesDestinationID": sourcesDestinationID, - "sourcesWriteKey": sourcesWriteKey, - "account": credentials.Account, - "user": credentials.User, - "password": credentials.Password, - "role": credentials.Role, - "database": credentials.Database, - "caseSensitiveDatabase": strings.ToLower(credentials.Database), - "warehouse": credentials.Warehouse, - "bucketName": credentials.BucketName, - "accessKeyID": credentials.AccessKeyID, - "accessKey": credentials.AccessKey, - "namespace": namespace, - "sourcesNamespace": sourcesNamespace, - "caseSensitiveNamespace": caseSensitiveNamespace, - "rbacNamespace": rbacNamespace, - "rbacAccount": rbacCredentials.Account, - "rbacUser": rbacCredentials.User, - "rbacPassword": rbacCredentials.Password, - "rbacRole": rbacCredentials.Role, - "rbacDatabase": rbacCredentials.Database, - "rbacWarehouse": rbacCredentials.Warehouse, - "rbacBucketName": rbacCredentials.BucketName, - "rbacAccessKeyID": rbacCredentials.AccessKeyID, - "rbacAccessKey": rbacCredentials.AccessKey, - } - workspaceConfigPath := workspaceConfig.CreateTempFile(t, "testdata/template.json", templateConfigurations) - - bootstrap := func(t testing.TB, appendMode bool) func() { - loadTableStrategy := "MERGE" - if appendMode { - loadTableStrategy = "APPEND" + bootstrapSvc := func(t testing.TB, enableMerge bool) { + templateConfigurations := map[string]any{ + "workspaceID": workspaceID, + "sourceID": sourceID, + "destinationID": destinationID, + "writeKey": writeKey, + "caseSensitiveSourceID": caseSensitiveSourceID, + "caseSensitiveDestinationID": caseSensitiveDestinationID, + "caseSensitiveWriteKey": caseSensitiveWriteKey, + "rbacSourceID": rbacSourceID, + "rbacDestinationID": rbacDestinationID, + "rbacWriteKey": rbacWriteKey, + "sourcesSourceID": sourcesSourceID, + "sourcesDestinationID": sourcesDestinationID, + "sourcesWriteKey": sourcesWriteKey, + "account": credentials.Account, + "user": credentials.User, + "password": credentials.Password, + "role": credentials.Role, + "database": credentials.Database, + "caseSensitiveDatabase": strings.ToLower(credentials.Database), + "warehouse": credentials.Warehouse, + "bucketName": credentials.BucketName, + "accessKeyID": credentials.AccessKeyID, + "accessKey": credentials.AccessKey, + "namespace": namespace, + "sourcesNamespace": sourcesNamespace, + "caseSensitiveNamespace": caseSensitiveNamespace, + "rbacNamespace": rbacNamespace, + "rbacAccount": rbacCredentials.Account, + "rbacUser": rbacCredentials.User, + "rbacPassword": rbacCredentials.Password, + "rbacRole": rbacCredentials.Role, + "rbacDatabase": rbacCredentials.Database, + "rbacWarehouse": rbacCredentials.Warehouse, + "rbacBucketName": rbacCredentials.BucketName, + "rbacAccessKeyID": rbacCredentials.AccessKeyID, + "rbacAccessKey": rbacCredentials.AccessKey, + "enableMerge": enableMerge, } + workspaceConfigPath := workspaceConfig.CreateTempFile(t, "testdata/template.json", templateConfigurations) + testhelper.EnhanceWithDefaultEnvs(t) t.Setenv("JOBS_DB_PORT", strconv.Itoa(jobsDBPort)) t.Setenv("WAREHOUSE_JOBS_DB_PORT", strconv.Itoa(jobsDBPort)) t.Setenv("RSERVER_WAREHOUSE_SNOWFLAKE_MAX_PARALLEL_LOADS", "8") t.Setenv("RSERVER_WAREHOUSE_SNOWFLAKE_ENABLE_DELETE_BY_JOBS", "true") - t.Setenv("RSERVER_WAREHOUSE_SNOWFLAKE_LOAD_TABLE_STRATEGY", loadTableStrategy) t.Setenv("RSERVER_WAREHOUSE_WEB_PORT", strconv.Itoa(httpPort)) t.Setenv("RSERVER_BACKEND_CONFIG_CONFIG_JSONPATH", workspaceConfigPath) t.Setenv("RSERVER_WAREHOUSE_SNOWFLAKE_SLOW_QUERY_THRESHOLD", "0s") @@ 
-196,20 +189,19 @@ func TestIntegration(t *testing.T) { )) ctx, cancel := context.WithCancel(context.Background()) - svcDone := make(chan struct{}) + go func() { r := runner.New(runner.ReleaseInfo{}) _ = r.Run(ctx, []string{"snowflake-integration-test"}) - close(svcDone) }() + t.Cleanup(func() { <-svcDone }) + t.Cleanup(cancel) serviceHealthEndpoint := fmt.Sprintf("http://localhost:%d/health", httpPort) health.WaitUntilReady(ctx, t, serviceHealthEndpoint, time.Minute, 100*time.Millisecond, "serviceHealthEndpoint") - - return cancel } t.Run("Event flow", func(t *testing.T) { @@ -235,7 +227,7 @@ func TestIntegration(t *testing.T) { asyncJob bool stagingFilePrefix string emptyJobRunID bool - appendMode bool + enableMerge bool customUserID string }{ { @@ -256,6 +248,7 @@ func TestIntegration(t *testing.T) { "wh_staging_files": 34, // 32 + 2 (merge events because of ID resolution) }, stagingFilePrefix: "testdata/upload-job", + enableMerge: true, }, { name: "Upload Job with Role", @@ -275,6 +268,7 @@ func TestIntegration(t *testing.T) { "wh_staging_files": 34, // 32 + 2 (merge events because of ID resolution) }, stagingFilePrefix: "testdata/upload-job-with-role", + enableMerge: true, }, { name: "Upload Job with Case Sensitive Database", @@ -294,6 +288,7 @@ func TestIntegration(t *testing.T) { "wh_staging_files": 34, // 32 + 2 (merge events because of ID resolution) }, stagingFilePrefix: "testdata/upload-job-case-sensitive", + enableMerge: true, }, { name: "Async Job with Sources", @@ -315,6 +310,7 @@ func TestIntegration(t *testing.T) { warehouseEventsMap: testhelper.SourcesWarehouseEventsMap(), asyncJob: true, stagingFilePrefix: "testdata/sources-job", + enableMerge: true, }, { name: "Upload Job in append mode", @@ -335,7 +331,7 @@ func TestIntegration(t *testing.T) { // an empty jobRunID means that the source is not an ETL one // see Uploader.CanAppend() emptyJobRunID: true, - appendMode: true, + enableMerge: false, customUserID: testhelper.GetUserId("append_test"), }, } @@ -343,8 +339,7 @@ func TestIntegration(t *testing.T) { for _, tc := range testcase { tc := tc t.Run(tc.name, func(t *testing.T) { - cancel := bootstrap(t, tc.appendMode) - defer cancel() + bootstrapSvc(t, tc.enableMerge) urlConfig := sfdb.Config{ Account: tc.cred.Account, @@ -506,6 +501,7 @@ func TestIntegration(t *testing.T) { "syncFrequency": "30", "enableSSE": false, "useRudderStorage": false, + "enableMerge": true, }, DestinationDefinition: backendconfig.DestinationDefinitionT{ ID: "1XjvXnzw34UMAz1YOuKqL1kwzh6", @@ -664,17 +660,9 @@ func TestIntegration(t *testing.T) { uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) loadFiles := []whutils.LoadFile{{Location: uploadOutput.Location}} - mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, false, false) - - c := config.New() - c.Set("Warehouse.snowflake.debugDuplicateWorkspaceIDs", []string{workspaceID}) - c.Set("Warehouse.snowflake.debugDuplicateIntervalInDays", 1000) - c.Set("Warehouse.snowflake.debugDuplicateTables", []string{whutils.ToProviderCase( - whutils.SNOWFLAKE, - tableName, - )}) + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, true, false) - sf, err := snowflake.New(c, logger.NOP, memstats.New()) + sf, err := snowflake.New(config.New(), logger.NOP, memstats.New()) require.NoError(t, err) err = sf.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -692,12 +680,14 @@ func TestIntegration(t *testing.T) { loadTableStat, err = 
sf.LoadTable(ctx, tableName) require.NoError(t, err) - require.Equal(t, loadTableStat.RowsInserted, int64(0)) - require.Equal(t, loadTableStat.RowsUpdated, int64(14)) + require.Equal(t, loadTableStat.RowsInserted, int64(0), + "2nd copy on the same table with the same data should not have any 'rows_loaded'") + require.Equal(t, loadTableStat.RowsUpdated, int64(0), + "2nd copy on the same table with the same data should not have any 'rows_loaded'") records := testhelper.RetrieveRecordsFromWarehouse(t, sf.DB.DB, - fmt.Sprintf(` - SELECT + fmt.Sprintf( + `SELECT id, received_at, test_bool, @@ -705,16 +695,13 @@ func TestIntegration(t *testing.T) { test_float, test_int, test_string - FROM - %q.%q - ORDER BY - id; - `, + FROM %q.%q + ORDER BY id;`, namespace, tableName, ), ) - require.Equal(t, records, testhelper.SampleTestRecords()) + require.Equal(t, testhelper.SampleTestRecords(), records) }) t.Run("with dedup use new record", func(t *testing.T) { uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) @@ -769,10 +756,7 @@ func TestIntegration(t *testing.T) { loadFiles := []whutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, true, false) - c := config.New() - c.Set("Warehouse.snowflake.loadTableStrategy", "APPEND") - - sf, err := snowflake.New(c, logger.NOP, memstats.New()) + sf, err := snowflake.New(config.New(), logger.NOP, memstats.New()) require.NoError(t, err) err = sf.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -962,58 +946,63 @@ func TestIntegration(t *testing.T) { }) } -func TestSnowflake_ShouldAppend(t *testing.T) { +func TestSnowflake_ShouldMerge(t *testing.T) { testCases := []struct { name string - loadTableStrategy string + enableMerge bool uploaderCanAppend bool uploaderExpectedCalls int expected bool }{ { - name: "uploader says we can append and we are in append mode", - loadTableStrategy: "APPEND", + name: "uploader says we can append and merge is not enabled", + enableMerge: false, uploaderCanAppend: true, uploaderExpectedCalls: 1, - expected: true, + expected: false, }, { - name: "uploader says we cannot append and we are in append mode", - loadTableStrategy: "APPEND", + name: "uploader says we cannot append and merge is not enabled", + enableMerge: false, uploaderCanAppend: false, uploaderExpectedCalls: 1, - expected: false, + expected: true, }, { - name: "uploader says we can append and we are in merge mode", - loadTableStrategy: "MERGE", + name: "uploader says we can append and merge is enabled", + enableMerge: true, uploaderCanAppend: true, - uploaderExpectedCalls: 0, - expected: false, + uploaderExpectedCalls: 1, + expected: true, }, { name: "uploader says we cannot append and we are in merge mode", - loadTableStrategy: "MERGE", + enableMerge: true, uploaderCanAppend: false, - uploaderExpectedCalls: 0, - expected: false, + uploaderExpectedCalls: 1, + expected: true, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - c := config.New() - c.Set("Warehouse.snowflake.loadTableStrategy", tc.loadTableStrategy) - - sf, err := snowflake.New(c, logger.NOP, memstats.New()) + sf, err := snowflake.New(config.New(), logger.NOP, memstats.New()) require.NoError(t, err) + sf.Warehouse = model.Warehouse{ + Destination: backendconfig.DestinationT{ + Config: map[string]any{ + string(model.EnableMergeSetting): tc.enableMerge, + }, + }, + } + mockCtrl := gomock.NewController(t) uploader := mockuploader.NewMockUploader(mockCtrl) 
uploader.EXPECT().CanAppend().Times(tc.uploaderExpectedCalls).Return(tc.uploaderCanAppend) sf.Uploader = uploader - require.Equal(t, sf.ShouldAppend(), tc.expected) + require.Equal(t, sf.ShouldMerge(), tc.expected) }) } } diff --git a/warehouse/integrations/snowflake/testdata/template.json b/warehouse/integrations/snowflake/testdata/template.json index 2b7aab4b96..73d113e764 100644 --- a/warehouse/integrations/snowflake/testdata/template.json +++ b/warehouse/integrations/snowflake/testdata/template.json @@ -39,7 +39,8 @@ "prefix": "snowflake-prefix", "syncFrequency": "30", "enableSSE": false, - "useRudderStorage": false + "useRudderStorage": false, + "enableMerge": {{.enableMerge}} }, "liveEventsConfig": {}, "secretConfig": {}, @@ -163,7 +164,8 @@ "prefix": "snowflake-prefix", "syncFrequency": "30", "enableSSE": false, - "useRudderStorage": false + "useRudderStorage": false, + "enableMerge": {{.enableMerge}} }, "liveEventsConfig": {}, "secretConfig": {}, @@ -288,7 +290,8 @@ "prefix": "snowflake-prefix", "syncFrequency": "30", "enableSSE": false, - "useRudderStorage": false + "useRudderStorage": false, + "enableMerge": {{.enableMerge}} }, "liveEventsConfig": {}, "secretConfig": {}, @@ -438,7 +441,8 @@ "prefix": "snowflake-prefix", "syncFrequency": "30", "enableSSE": false, - "useRudderStorage": false + "useRudderStorage": false, + "enableMerge": {{.enableMerge}} }, "liveEventsConfig": {}, "secretConfig": {}, diff --git a/warehouse/internal/model/warehouse.go b/warehouse/internal/model/warehouse.go index 4608ffb241..f398960e5e 100644 --- a/warehouse/internal/model/warehouse.go +++ b/warehouse/internal/model/warehouse.go @@ -2,6 +2,17 @@ package model import backendconfig "github.com/rudderlabs/rudder-server/backend-config" +type DestinationConfigSetting interface{ string() string } + +type destConfSetting string + +func (s destConfSetting) string() string { return string(s) } + +const ( + EnableMergeSetting destConfSetting = "enableMerge" + UseRudderStorageSetting destConfSetting = "useRudderStorage" +) + type Warehouse struct { WorkspaceID string Source backendconfig.SourceT @@ -11,10 +22,10 @@ type Warehouse struct { Identifier string } -func (w *Warehouse) GetBoolDestinationConfig(key string) bool { +func (w *Warehouse) GetBoolDestinationConfig(key DestinationConfigSetting) bool { destConfig := w.Destination.Config - if destConfig[key] != nil { - if val, ok := destConfig[key].(bool); ok { + if destConfig[key.string()] != nil { + if val, ok := destConfig[key.string()].(bool); ok { return val } } diff --git a/warehouse/logfield/logfield.go b/warehouse/logfield/logfield.go index 5c084dc4f0..a50cc33b44 100644 --- a/warehouse/logfield/logfield.go +++ b/warehouse/logfield/logfield.go @@ -36,4 +36,5 @@ const ( IntervalInHours = "intervalInHours" StartTime = "startTime" EndTime = "endTime" + ShouldMerge = "shouldMerge" ) diff --git a/warehouse/router/router.go b/warehouse/router/router.go index 96ce9734bd..1e209538c4 100644 --- a/warehouse/router/router.go +++ b/warehouse/router/router.go @@ -480,7 +480,7 @@ func (r *Router) uploadsToProcess(ctx context.Context, availableWorkers int, ski }) r.configSubscriberLock.RUnlock() - upload.UseRudderStorage = warehouse.GetBoolDestinationConfig("useRudderStorage") + upload.UseRudderStorage = warehouse.GetBoolDestinationConfig(model.UseRudderStorageSetting) if !found { uploadJob := r.uploadJobFactory.NewUploadJob(ctx, &model.UploadJob{ diff --git a/warehouse/router/upload.go b/warehouse/router/upload.go index 324bdac5a1..522cb47e0d 100644 --- 
a/warehouse/router/upload.go +++ b/warehouse/router/upload.go @@ -721,6 +721,7 @@ func (job *UploadJob) exportRegularTables(specialTables []string, loadFilesTable // * the source is not an ETL source // * the source is not a replay source // * the source category is not in "mergeSourceCategoryMap" +// * the job is not a retry func (job *UploadJob) CanAppend() bool { if isSourceETL := job.upload.SourceJobRunID != ""; isSourceETL { return false @@ -731,6 +732,9 @@ func (job *UploadJob) CanAppend() bool { if _, isMergeCategory := mergeSourceCategoryMap[job.warehouse.Source.SourceDefinition.Category]; isMergeCategory { return false } + if job.upload.Retried { + return false + } return true } diff --git a/warehouse/utils/querytype.go b/warehouse/utils/querytype.go index a439f3fee9..9fb994ef14 100644 --- a/warehouse/utils/querytype.go +++ b/warehouse/utils/querytype.go @@ -28,6 +28,7 @@ func init() { "(?P(?:IF.*)*DROP.*TABLE)", "(?PSHOW.*TABLES)", "(?PSHOW.*PARTITIONS)", + "(?PSHOW.*SCHEMAS)", "(?PDESCRIBE.*(?:QUERY.*)*TABLE)", "(?PSET.*TO)", } diff --git a/warehouse/utils/querytype_test.go b/warehouse/utils/querytype_test.go index 55ce137bb0..96f2470234 100644 --- a/warehouse/utils/querytype_test.go +++ b/warehouse/utils/querytype_test.go @@ -38,6 +38,7 @@ func TestGetQueryType(t *testing.T) { {"drop table 2", "\t\n\n \t\n\n IF OBJECT_ID ('foo.qux','X') IS NOT NULL DROP TABLE foo.bar", "DROP_TABLE", true}, {"show tables", "\t\n\n \t\n\n sHoW tAbLes FROM some_table", "SHOW_TABLES", true}, {"show partitions", "\t\n\n \t\n\n sHoW pArtItiOns billing.tracks_t1", "SHOW_PARTITIONS", true}, + {"show schemas", "\t\n\n \t\n\n sHoW schemaS like foobar", "SHOW_SCHEMAS", true}, {"describe table 1", "\t\n\n \t\n\n dEscrIbe tABLE t1", "DESCRIBE_TABLE", true}, {"describe table 2", "\t\n\n \t\n\n dEscrIbe qUeRy tABLE t1", "DESCRIBE_TABLE", true}, {"set", "\t\n\n \t\n\n sEt something TO something_else", "SET_TO", true}, diff --git a/warehouse/utils/utils_test.go b/warehouse/utils/utils_test.go index bcf81dcecb..06b5ad6be1 100644 --- a/warehouse/utils/utils_test.go +++ b/warehouse/utils/utils_test.go @@ -979,7 +979,7 @@ func TestWarehouseT_GetBoolDestinationConfig(t *testing.T) { warehouse: model.Warehouse{ Destination: backendconfig.DestinationT{ Config: map[string]interface{}{ - "k1": "true", + "useRudderStorage": "true", }, }, }, @@ -989,7 +989,7 @@ func TestWarehouseT_GetBoolDestinationConfig(t *testing.T) { warehouse: model.Warehouse{ Destination: backendconfig.DestinationT{ Config: map[string]interface{}{ - "k1": false, + "useRudderStorage": false, }, }, }, @@ -999,7 +999,7 @@ func TestWarehouseT_GetBoolDestinationConfig(t *testing.T) { warehouse: model.Warehouse{ Destination: backendconfig.DestinationT{ Config: map[string]interface{}{ - "k1": true, + "useRudderStorage": true, }, }, }, @@ -1007,7 +1007,7 @@ func TestWarehouseT_GetBoolDestinationConfig(t *testing.T) { }, } for idx, input := range inputs { - got := input.warehouse.GetBoolDestinationConfig("k1") + got := input.warehouse.GetBoolDestinationConfig(model.UseRudderStorageSetting) want := input.expected if got != want { t.Errorf("got %t expected %t input %d", got, want, idx)
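
Taken together, the Redshift and Snowflake hunks above reduce to a single decision: merge (dedup) whenever the uploader reports it cannot append, and otherwise merge only when the server-side allowMerge flag and the destination's enableMerge setting are both true. The following is a minimal standalone Go sketch of that decision, not code from this repository: uploader, stubUploader and shouldMerge are illustrative names, the real Redshift variant additionally consults skipDedupDestinationIDs, and UploadJob.CanAppend further rejects ETL, replay, merge-category and retried uploads.

    package main

    import "fmt"

    // uploader captures only the method the append-vs-merge decision needs;
    // the real warehouse Uploader interface has many more methods.
    type uploader interface {
        CanAppend() bool
    }

    // stubUploader is a hypothetical stand-in used purely for this sketch.
    type stubUploader struct{ canAppend bool }

    func (s stubUploader) CanAppend() bool { return s.canAppend }

    // shouldMerge mirrors the shape of the helpers added in the patch:
    // merge is forced when appending is impossible, and is otherwise used
    // only when both the server config and the destination config opt in.
    func shouldMerge(u uploader, allowMerge, enableMergeSetting bool) bool {
        return !u.CanAppend() || (allowMerge && enableMergeSetting)
    }

    func main() {
        cases := []struct {
            canAppend, allowMerge, enableMerge bool
        }{
            {true, true, true},   // append-capable, user opted in  -> merge
            {true, true, false},  // append-capable, no opt-in      -> append
            {true, false, true},  // server disallows merge         -> append
            {false, false, false}, // uploader cannot append        -> merge
        }
        for _, c := range cases {
            fmt.Printf("canAppend=%-5v allowMerge=%-5v enableMerge=%-5v -> merge=%v\n",
                c.canAppend, c.allowMerge, c.enableMerge,
                shouldMerge(stubUploader{c.canAppend}, c.allowMerge, c.enableMerge))
        }
    }

The design choice this illustrates is that appending is now the default path (allowMerge defaults to true but enableMerge must be set on the destination), while merge remains a hard requirement whenever deduplication is unavoidable, such as ETL, replay and retried uploads.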