diff --git a/.github/tools/matrixchecker/main.go b/.github/tools/matrixchecker/main.go index 2a3b6a9de4e..fcdc0bfa003 100644 --- a/.github/tools/matrixchecker/main.go +++ b/.github/tools/matrixchecker/main.go @@ -4,8 +4,8 @@ import ( "log" "os" "path/filepath" + "slices" - "golang.org/x/exp/slices" "gopkg.in/yaml.v2" ) @@ -20,6 +20,7 @@ var IgnorePackages = []string{ "warehouse/integrations/testdata", "warehouse/integrations/config", "warehouse/integrations/types", + "warehouse/integrations/tunnelling", } func main() { diff --git a/archiver/archiver_isolation_test.go b/archiver/archiver_isolation_test.go index 782f3db8289..733ff2e019b 100644 --- a/archiver/archiver_isolation_test.go +++ b/archiver/archiver_isolation_test.go @@ -14,13 +14,16 @@ import ( "testing" "time" + "github.com/rudderlabs/rudder-go-kit/bytesize" + "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource/postgres" "github.com/rudderlabs/rudder-server/testhelper/destination" + "golang.org/x/sync/errgroup" + "github.com/google/uuid" "github.com/ory/dockertest/v3" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" - "golang.org/x/sync/errgroup" "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/filemanager" @@ -127,7 +130,7 @@ func ArchivalScenario( cleanup := &testhelper.Cleanup{} defer cleanup.Run() - postgresContainer, err := resource.SetupPostgres(pool, cleanup) + postgresContainer, err := resource.SetupPostgres(pool, cleanup, postgres.WithShmSize(256*bytesize.MB)) require.NoError(t, err, "failed to setup postgres container") minioResource, err := resource.SetupMinio(pool, cleanup) diff --git a/archiver/archiver_test.go b/archiver/archiver_test.go index b85bacf6710..b520d2bad97 100644 --- a/archiver/archiver_test.go +++ b/archiver/archiver_test.go @@ -17,12 +17,12 @@ import ( "github.com/stretchr/testify/require" "github.com/tidwall/gjson" - "github.com/rudderlabs/rudder-go-kit/config" - "github.com/rudderlabs/rudder-go-kit/stats" - "github.com/rudderlabs/rudder-go-kit/bytesize" + "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/filemanager" + "github.com/rudderlabs/rudder-go-kit/stats" "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource" + "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource/postgres" trand "github.com/rudderlabs/rudder-go-kit/testhelper/rand" backendconfig "github.com/rudderlabs/rudder-server/backend-config" "github.com/rudderlabs/rudder-server/jobsdb" @@ -47,7 +47,7 @@ func TestJobsArchival(t *testing.T) { pool, err := dockertest.NewPool("") require.NoError(t, err, "Failed to create docker pool") - postgresResource, err := resource.SetupPostgres(pool, t) + postgresResource, err := resource.SetupPostgres(pool, t, postgres.WithShmSize(256*bytesize.MB)) require.NoError(t, err, "failed to setup postgres resource") c := config.New() c.Set("DB.name", postgresResource.Database) @@ -205,7 +205,7 @@ func readGzipJobFile(filename string) ([]*jobsdb.JobT, error) { if err != nil { return []*jobsdb.JobT{}, err } - defer gz.Close() + defer func() { _ = gz.Close() }() sc := bufio.NewScanner(gz) // default scanner buffer maxCapacity is 64K @@ -255,24 +255,24 @@ type jdWrapper struct { queries *int32 } -func (jd jdWrapper) GetDistinctParameterValues(ctx context.Context, parameterName string) ([]string, error) { +func (jd jdWrapper) GetDistinctParameterValues(context.Context, string) ([]string, error) { atomic.AddInt32(jd.queries, 1) return []string{}, nil } func (jd jdWrapper) GetUnprocessed( - ctx 
context.Context, - params jobsdb.GetQueryParams, + context.Context, + jobsdb.GetQueryParams, ) (jobsdb.JobsResult, error) { atomic.AddInt32(jd.queries, 1) return jobsdb.JobsResult{}, nil } func (jd jdWrapper) UpdateJobStatus( - ctx context.Context, - statusList []*jobsdb.JobStatusT, - customValFilters []string, - parameterFilters []jobsdb.ParameterFilterT, + context.Context, + []*jobsdb.JobStatusT, + []string, + []jobsdb.ParameterFilterT, ) error { atomic.AddInt32(jd.queries, 1) return nil diff --git a/controlplane/controlplane.go b/controlplane/controlplane.go index 55f139fd032..c61af12086d 100644 --- a/controlplane/controlplane.go +++ b/controlplane/controlplane.go @@ -38,7 +38,7 @@ func (cm *ConnectionManager) establishConnection() (*ConnHandler, error) { return nil, err } - grpcServer := grpc.NewServer() + grpcServer := grpc.NewServer(cm.Options...) service := &authService{authInfo: cm.AuthInfo} proto.RegisterDPAuthServiceServer(grpcServer, service) cn := &ConnHandler{ diff --git a/controlplane/manager.go b/controlplane/manager.go index 0ebeec30f6f..f9b1216ad57 100644 --- a/controlplane/manager.go +++ b/controlplane/manager.go @@ -18,6 +18,7 @@ type ConnectionManager struct { active bool url string connHandler *ConnHandler + Options []grpc.ServerOption } type LoggerI interface { diff --git a/enterprise/reporting/error_index/error_index_reporting.go b/enterprise/reporting/error_index/error_index_reporting.go new file mode 100644 index 00000000000..b0188a6bbf5 --- /dev/null +++ b/enterprise/reporting/error_index/error_index_reporting.go @@ -0,0 +1,340 @@ +package error_index + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "sync" + "time" + + "github.com/rudderlabs/rudder-server/utils/misc" + + kitsync "github.com/rudderlabs/rudder-go-kit/sync" + + "golang.org/x/sync/errgroup" + + "github.com/rudderlabs/rudder-go-kit/filemanager" + "github.com/rudderlabs/rudder-go-kit/stats" + "github.com/rudderlabs/rudder-server/utils/workerpool" + + "github.com/google/uuid" + + "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/logger" + "github.com/rudderlabs/rudder-server/jobsdb" + . 
"github.com/rudderlabs/rudder-server/utils/tx" //nolint:staticcheck + "github.com/rudderlabs/rudder-server/utils/types" +) + +type ErrorIndexReporter struct { + ctx context.Context + cancel context.CancelFunc + g *errgroup.Group + log logger.Logger + conf *config.Config + configSubscriber configSubscriber + now func() time.Time + dbsMu sync.RWMutex + dbs map[string]*handleWithSqlDB + + trigger func() <-chan time.Time + + limiterGroup sync.WaitGroup + limiter struct { + fetch kitsync.Limiter + upload kitsync.Limiter + update kitsync.Limiter + } + + concurrency misc.ValueLoader[int] + + statsFactory stats.Stats + stats struct { + partitionTime stats.Timer + partitions stats.Gauge + } +} + +type handleWithSqlDB struct { + *jobsdb.Handle + sqlDB *sql.DB +} + +func NewErrorIndexReporter(ctx context.Context, log logger.Logger, configSubscriber configSubscriber, conf *config.Config, statsFactory stats.Stats) *ErrorIndexReporter { + ctx, cancel := context.WithCancel(ctx) + g, ctx := errgroup.WithContext(ctx) + + eir := &ErrorIndexReporter{ + ctx: ctx, + cancel: cancel, + g: g, + log: log.Child("error-index-reporter"), + conf: conf, + statsFactory: statsFactory, + + configSubscriber: configSubscriber, + now: time.Now, + dbs: map[string]*handleWithSqlDB{}, + } + + eir.concurrency = conf.GetReloadableIntVar(10, 1, "Reporting.errorIndexReporting.concurrency") + + eir.limiterGroup = sync.WaitGroup{} + eir.limiter.fetch = kitsync.NewLimiter( + eir.ctx, &eir.limiterGroup, "erridx_fetch", + eir.concurrency.Load(), + eir.statsFactory, + ) + eir.limiter.upload = kitsync.NewLimiter( + eir.ctx, &eir.limiterGroup, "erridx_upload", + eir.concurrency.Load(), + eir.statsFactory, + ) + eir.limiter.update = kitsync.NewLimiter( + eir.ctx, &eir.limiterGroup, "erridx_update", + eir.concurrency.Load(), + eir.statsFactory, + ) + g.Go(func() error { + eir.limiterGroup.Wait() + return nil + }) + + eir.trigger = func() <-chan time.Time { + return time.After(conf.GetDuration("Reporting.errorIndexReporting.SleepDuration", 30, time.Second)) + } + + eir.stats.partitionTime = eir.statsFactory.NewStat("erridx_partition_time", stats.TimerType) + eir.stats.partitions = eir.statsFactory.NewStat("erridx_active_partitions", stats.GaugeType) + + return eir +} + +// Report reports the metrics to the errorIndex JobsDB +func (eir *ErrorIndexReporter) Report(metrics []*types.PUReportedMetric, tx *Tx) error { + failedAt := eir.now() + + var jobs []*jobsdb.JobT + for _, metric := range metrics { + if metric.StatusDetail == nil { + continue + } + + for _, failedMessage := range metric.StatusDetail.FailedMessages { + workspaceID := eir.configSubscriber.WorkspaceIDFromSource(metric.SourceID) + + payload := payload{ + MessageID: failedMessage.MessageID, + SourceID: metric.SourceID, + DestinationID: metric.DestinationID, + TransformationID: metric.TransformationID, + TrackingPlanID: metric.TrackingPlanID, + FailedStage: metric.PUDetails.PU, + EventName: metric.StatusDetail.EventName, + EventType: metric.StatusDetail.EventType, + } + payload.SetReceivedAt(failedMessage.ReceivedAt) + payload.SetFailedAt(failedAt) + + payloadJSON, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("unable to marshal payload: %v", err) + } + + params := struct { + WorkspaceID string `json:"workspaceId"` + SourceID string `json:"source_id"` + }{ + WorkspaceID: workspaceID, + SourceID: metric.SourceID, + } + paramsJSON, err := json.Marshal(params) + if err != nil { + return fmt.Errorf("unable to marshal params: %v", err) + } + + jobs = append(jobs, 
&jobsdb.JobT{ + UUID: uuid.New(), + Parameters: paramsJSON, + EventPayload: payloadJSON, + EventCount: 1, + WorkspaceId: workspaceID, + }) + } + } + + if len(jobs) == 0 { + return nil + } + db, err := eir.resolveJobsDB(tx) + if err != nil { + return fmt.Errorf("failed to resolve jobsdb: %w", err) + } + if err := db.WithStoreSafeTxFromTx(eir.ctx, tx, func(tx jobsdb.StoreSafeTx) error { + return db.StoreInTx(eir.ctx, tx, jobs) + }); err != nil { + return fmt.Errorf("failed to store jobs: %w", err) + } + + return nil +} + +func (eir *ErrorIndexReporter) DatabaseSyncer(c types.SyncerConfig) types.ReportingSyncer { + eir.dbsMu.Lock() + defer eir.dbsMu.Unlock() + + if _, ok := eir.dbs[c.ConnInfo]; ok { + return func() {} // returning a no-op syncer since another go routine has already started syncing + } + + dbHandle, err := sql.Open("postgres", c.ConnInfo) + if err != nil { + panic(fmt.Errorf("failed to open error index db: %w", err)) + } + errIndexDB := jobsdb.NewForReadWrite( + "err_idx", + jobsdb.WithDBHandle(dbHandle), + jobsdb.WithDSLimit(eir.conf.GetReloadableIntVar(0, 1, "Reporting.errorIndexReporting.dsLimit")), + jobsdb.WithConfig(eir.conf), + jobsdb.WithSkipMaintenanceErr(eir.conf.GetBool("Reporting.errorIndexReporting.skipMaintenanceError", false)), + jobsdb.WithJobMaxAge( + func() time.Duration { + return eir.conf.GetDurationVar(24, time.Hour, "Reporting.errorIndexReporting.jobRetention") + }, + ), + ) + if err := errIndexDB.Start(); err != nil { + panic(fmt.Errorf("failed to start error index db: %w", err)) + } + eir.dbs[c.ConnInfo] = &handleWithSqlDB{ + Handle: errIndexDB, + sqlDB: dbHandle, + } + + if !eir.conf.GetBool("Reporting.errorIndexReporting.syncer.enabled", true) { + return func() { + <-eir.ctx.Done() + errIndexDB.Stop() + } + } + + return func() { + eir.g.Go(func() error { + defer errIndexDB.Stop() + return eir.mainLoop(eir.ctx, errIndexDB) + }) + } +} + +func (eir *ErrorIndexReporter) mainLoop(ctx context.Context, errIndexDB *jobsdb.Handle) error { + eir.log.Infow("Starting main loop for error index reporting") + + var ( + bucket = eir.conf.GetStringVar("rudder-failed-messages", "ErrorIndex.storage.Bucket") + regionHint = eir.conf.GetStringVar("us-east-1", "ErrorIndex.storage.RegionHint", "AWS_S3_REGION_HINT") + endpoint = eir.conf.GetStringVar("", "ErrorIndex.storage.Endpoint") + accessKeyID = eir.conf.GetStringVar("", "ErrorIndex.storage.AccessKey", "AWS_ACCESS_KEY_ID") + secretAccessKey = eir.conf.GetStringVar("", "ErrorIndex.storage.SecretAccessKey", "AWS_SECRET_ACCESS_KEY") + s3ForcePathStyle = eir.conf.GetBoolVar(false, "ErrorIndex.storage.S3ForcePathStyle") + disableSSL = eir.conf.GetBoolVar(false, "ErrorIndex.storage.DisableSSL") + enableSSE = eir.conf.GetBoolVar(false, "ErrorIndex.storage.EnableSSE", "AWS_ENABLE_SSE") + ) + + s3Config := map[string]interface{}{ + "bucketName": bucket, + "regionHint": regionHint, + "endpoint": endpoint, + "accessKeyID": accessKeyID, + "secretAccessKey": secretAccessKey, + "s3ForcePathStyle": s3ForcePathStyle, + "disableSSL": disableSSL, + "enableSSE": enableSSE, + } + fm, err := filemanager.NewS3Manager(s3Config, eir.log, func() time.Duration { + return eir.conf.GetDuration("ErrorIndex.Uploader.Timeout", 120, time.Second) + }) + if err != nil { + return fmt.Errorf("creating file manager: %w", err) + } + + workerPool := workerpool.New( + ctx, + func(sourceID string) workerpool.Worker { + return newWorker( + sourceID, + eir.conf, + eir.log, + eir.statsFactory, + errIndexDB, + eir.configSubscriber, + fm, + eir.limiter.fetch, + 
eir.limiter.upload, + eir.limiter.update, + ) + }, + eir.log, + workerpool.WithIdleTimeout(2*eir.conf.GetDuration("Reporting.errorIndexReporting.uploadFrequency", 5, time.Minute)), + ) + defer workerPool.Shutdown() + + for { + start := time.Now() + sources, err := errIndexDB.GetDistinctParameterValues(ctx, "source_id") + if err != nil && ctx.Err() != nil { + return nil + } + if err != nil { + return fmt.Errorf("getting distinct parameter values: %w", err) + } + + eir.stats.partitionTime.Since(start) + eir.stats.partitions.Gauge(len(sources)) + + for _, source := range sources { + workerPool.PingWorker(source) + } + + select { + case <-ctx.Done(): + return nil + case <-eir.trigger(): + } + } +} + +func (eir *ErrorIndexReporter) Stop() { + eir.cancel() + _ = eir.g.Wait() +} + +// resolveJobsDB returns the jobsdb that matches the current transaction (using system information functions) +// https://www.postgresql.org/docs/11/functions-info.html +func (eir *ErrorIndexReporter) resolveJobsDB(tx *Tx) (jobsdb.JobsDB, error) { + eir.dbsMu.RLock() + defer eir.dbsMu.RUnlock() + + if len(eir.dbs) == 1 { // optimisation, if there is only one jobsdb, return this. If it is the wrong one, it will fail anyway + for i := range eir.dbs { + return eir.dbs[i].Handle, nil + } + } + + dbIdentityQuery := `select inet_server_addr()::text || ':' || inet_server_port()::text || ':' || current_user || ':' || current_database() || ':' || current_schema || ':' || pg_postmaster_start_time()::text || ':' || version()` + var txDatabaseIdentity string + if err := tx.QueryRow(dbIdentityQuery).Scan(&txDatabaseIdentity); err != nil { + return nil, fmt.Errorf("failed to get current tx's db identity: %w", err) + } + + for key := range eir.dbs { + var databaseIdentity string + if err := eir.dbs[key].sqlDB.QueryRow(dbIdentityQuery).Scan(&databaseIdentity); err != nil { + return nil, fmt.Errorf("failed to get db identity for %q: %w", key, err) + } + if databaseIdentity == txDatabaseIdentity { + return eir.dbs[key].Handle, nil + } + } + return nil, fmt.Errorf("no jobsdb found matching the current transaction") +} diff --git a/enterprise/reporting/error_index_reporting_test.go b/enterprise/reporting/error_index/error_index_reporting_test.go similarity index 63% rename from enterprise/reporting/error_index_reporting_test.go rename to enterprise/reporting/error_index/error_index_reporting_test.go index 61b169e566e..c28b79bc37e 100644 --- a/enterprise/reporting/error_index_reporting_test.go +++ b/enterprise/reporting/error_index/error_index_reporting_test.go @@ -1,4 +1,4 @@ -package reporting +package error_index import ( "context" @@ -6,23 +6,40 @@ import ( "testing" "time" - "github.com/ory/dockertest/v3" + "github.com/rudderlabs/rudder-go-kit/stats/memstats" + + "github.com/samber/lo" - "github.com/golang/mock/gomock" + "github.com/ory/dockertest/v3" "github.com/stretchr/testify/require" "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/logger" "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource" - backendconfig "github.com/rudderlabs/rudder-server/backend-config" "github.com/rudderlabs/rudder-server/jobsdb" - mocksBackendConfig "github.com/rudderlabs/rudder-server/mocks/backend-config" - "github.com/rudderlabs/rudder-server/utils/pubsub" . 
"github.com/rudderlabs/rudder-server/utils/tx" //nolint:staticcheck "github.com/rudderlabs/rudder-server/utils/types" ) +func newMockConfigSubscriber() *mockConfigSubscriber { + return &mockConfigSubscriber{ + workspaceIDForSourceIDMap: make(map[string]string), + } +} + +type mockConfigSubscriber struct { + workspaceIDForSourceIDMap map[string]string +} + +func (m *mockConfigSubscriber) WorkspaceIDFromSource(sourceID string) string { + return m.workspaceIDForSourceIDMap[sourceID] +} + +func (m *mockConfigSubscriber) addWorkspaceIDForSourceID(sourceID, workspaceID string) { + m.workspaceIDForSourceIDMap[sourceID] = workspaceID +} + func TestErrorIndexReporter(t *testing.T) { workspaceID := "test-workspace-id" sourceID := "test-source-id" @@ -30,8 +47,6 @@ func TestErrorIndexReporter(t *testing.T) { transformationID := "test-transformation-id" trackingPlanID := "test-tracking-plan-id" reportedBy := "test-reported-by" - destinationDefinitionID := "test-destination-definition-id" - destType := "test-dest-type" eventName := "test-event-name" eventType := "test-event-type" messageID := "test-message-id" @@ -41,48 +56,8 @@ func TestErrorIndexReporter(t *testing.T) { ctx := context.Background() - ctrl := gomock.NewController(t) - mockBackendConfig := mocksBackendConfig.NewMockBackendConfig(ctrl) - mockBackendConfig.EXPECT().Subscribe(gomock.Any(), backendconfig.TopicBackendConfig).DoAndReturn(func(ctx context.Context, topic backendconfig.Topic) pubsub.DataChannel { - ch := make(chan pubsub.DataEvent, 1) - ch <- pubsub.DataEvent{ - Data: map[string]backendconfig.ConfigT{ - workspaceID: { - WorkspaceID: workspaceID, - Sources: []backendconfig.SourceT{ - { - ID: sourceID, - Enabled: true, - Destinations: []backendconfig.DestinationT{ - { - ID: destinationID, - Enabled: true, - DestinationDefinition: backendconfig.DestinationDefinitionT{ - ID: destinationDefinitionID, - Name: destType, - }, - }, - }, - }, - }, - Settings: backendconfig.Settings{ - DataRetention: backendconfig.DataRetention{ - DisableReportingPII: true, - }, - }, - }, - }, - Topic: string(backendconfig.TopicBackendConfig), - } - close(ch) - return ch - }).AnyTimes() - - receivedAt := time.Now() - - failedAt := func() time.Time { - return receivedAt.Add(time.Hour) - } + receivedAt := time.Now().UTC() + failedAt := receivedAt.Add(time.Hour) t.Run("reports", func(t *testing.T) { testCases := []struct { @@ -188,7 +163,7 @@ func TestErrorIndexReporter(t *testing.T) { expectedPayload: []payload{ { MessageID: messageID + "1", - ReceivedAt: receivedAt.Add(1 * time.Hour), + ReceivedAt: receivedAt.Add(1 * time.Hour).UnixMicro(), SourceID: sourceID, DestinationID: destinationID, TransformationID: transformationID, @@ -196,11 +171,11 @@ func TestErrorIndexReporter(t *testing.T) { EventName: eventName, EventType: eventType, FailedStage: reportedBy, - FailedAt: failedAt(), + FailedAt: failedAt.UnixMicro(), }, { MessageID: messageID + "2", - ReceivedAt: receivedAt.Add(2 * time.Hour), + ReceivedAt: receivedAt.Add(2 * time.Hour).UnixMicro(), SourceID: sourceID, DestinationID: destinationID, TransformationID: transformationID, @@ -208,11 +183,11 @@ func TestErrorIndexReporter(t *testing.T) { EventName: eventName, EventType: eventType, FailedStage: reportedBy, - FailedAt: failedAt(), + FailedAt: failedAt.UnixMicro(), }, { MessageID: messageID + "3", - ReceivedAt: receivedAt.Add(3 * time.Hour), + ReceivedAt: receivedAt.Add(3 * time.Hour).UnixMicro(), SourceID: sourceID, DestinationID: destinationID, TransformationID: transformationID, @@ -220,11 +195,11 
@@ func TestErrorIndexReporter(t *testing.T) { EventName: eventName, EventType: eventType, FailedStage: reportedBy, - FailedAt: failedAt(), + FailedAt: failedAt.UnixMicro(), }, { MessageID: messageID + "4", - ReceivedAt: receivedAt.Add(4 * time.Hour), + ReceivedAt: receivedAt.Add(4 * time.Hour).UnixMicro(), SourceID: sourceID, DestinationID: destinationID, TransformationID: transformationID, @@ -232,7 +207,7 @@ func TestErrorIndexReporter(t *testing.T) { EventName: eventName, EventType: eventType, FailedStage: reportedBy, - FailedAt: failedAt(), + FailedAt: failedAt.UnixMicro(), }, }, }, @@ -244,22 +219,29 @@ func TestErrorIndexReporter(t *testing.T) { require.NoError(t, err) c := config.New() + ctx, cancel := context.WithCancel(ctx) - cs := newConfigSubscriber(logger.NOP) - subscribeDone := make(chan struct{}) - go func() { - defer close(subscribeDone) - cs.Subscribe(ctx, mockBackendConfig) - }() + defer cancel() + + cs := newMockConfigSubscriber() + cs.addWorkspaceIDForSourceID(sourceID, workspaceID) - eir := NewErrorIndexReporter(ctx, logger.NOP, cs, c) - _ = eir.DatabaseSyncer(types.SyncerConfig{ConnInfo: postgresContainer.DBDsn}) + eir := NewErrorIndexReporter(ctx, logger.NOP, cs, c, memstats.New()) defer eir.Stop() - eir.now = failedAt - sqltx, err := postgresContainer.DB.Begin() + syncer := eir.DatabaseSyncer(types.SyncerConfig{ConnInfo: postgresContainer.DBDsn}) + syncerDone := make(chan struct{}) + go func() { + defer close(syncerDone) + syncer() + }() + + eir.now = func() time.Time { + return failedAt + } + sqlTx, err := postgresContainer.DB.Begin() require.NoError(t, err) - tx := &Tx{Tx: sqltx} + tx := &Tx{Tx: sqlTx} err = eir.Report(tc.reports, tx) require.NoError(t, err) require.NoError(t, tx.Commit()) @@ -283,8 +265,8 @@ func TestErrorIndexReporter(t *testing.T) { require.Equal(t, eventPayload.FailedStage, tc.expectedPayload[i].FailedStage) require.Equal(t, eventPayload.EventName, tc.expectedPayload[i].EventName) require.Equal(t, eventPayload.EventType, tc.expectedPayload[i].EventType) - require.EqualValues(t, eventPayload.FailedAt.UTC(), failedAt().UTC()) - require.EqualValues(t, eventPayload.ReceivedAt.UTC(), tc.expectedPayload[i].ReceivedAt.UTC()) + require.Equal(t, eventPayload.FailedAt, tc.expectedPayload[i].FailedAt) + require.Equal(t, eventPayload.ReceivedAt, tc.expectedPayload[i].ReceivedAt) var params map[string]interface{} err = json.Unmarshal(job.Parameters, ¶ms) @@ -292,44 +274,43 @@ func TestErrorIndexReporter(t *testing.T) { require.Equal(t, params["source_id"], sourceID) require.Equal(t, params["workspaceId"], workspaceID) + + <-syncerDone } - cancel() - <-subscribeDone }) } }) + t.Run("graceful shutdown", func(t *testing.T) { postgresContainer, err := resource.SetupPostgres(pool, t) require.NoError(t, err) c := config.New() + ctx, cancel := context.WithCancel(ctx) - cs := newConfigSubscriber(logger.NOP) - subscribeDone := make(chan struct{}) - go func() { - defer close(subscribeDone) - cs.Subscribe(ctx, mockBackendConfig) - }() + defer cancel() - eir := NewErrorIndexReporter(ctx, logger.NOP, cs, c) - defer eir.Stop() - syncer := eir.DatabaseSyncer(types.SyncerConfig{ConnInfo: postgresContainer.DBDsn}) + cf := newMockConfigSubscriber() + cf.addWorkspaceIDForSourceID(sourceID, workspaceID) - sqltx, err := postgresContainer.DB.Begin() - require.NoError(t, err) - tx := &Tx{Tx: sqltx} - err = eir.Report([]*types.PUReportedMetric{}, tx) - require.NoError(t, err) - require.NoError(t, tx.Commit()) + eir := NewErrorIndexReporter(ctx, logger.NOP, cf, c, memstats.New()) 
+ defer eir.Stop() + syncer := eir.DatabaseSyncer(types.SyncerConfig{ConnInfo: postgresContainer.DBDsn}) syncerDone := make(chan struct{}) go func() { defer close(syncerDone) syncer() }() + sqlTx, err := postgresContainer.DB.Begin() + require.NoError(t, err) + tx := &Tx{Tx: sqlTx} + err = eir.Report([]*types.PUReportedMetric{}, tx) + require.NoError(t, err) + require.NoError(t, tx.Commit()) + cancel() - <-subscribeDone <-syncerDone }) @@ -341,21 +322,26 @@ func TestErrorIndexReporter(t *testing.T) { require.NoError(t, err) c := config.New() + ctx, cancel := context.WithCancel(ctx) - cs := newConfigSubscriber(logger.NOP) - subscribeDone := make(chan struct{}) - go func() { - defer close(subscribeDone) - cs.Subscribe(ctx, mockBackendConfig) - }() + defer cancel() + + cf := newMockConfigSubscriber() + cf.addWorkspaceIDForSourceID(sourceID, workspaceID) - eir := NewErrorIndexReporter(ctx, logger.NOP, cs, c) + eir := NewErrorIndexReporter(ctx, logger.NOP, cf, c, memstats.New()) defer eir.Stop() - _ = eir.DatabaseSyncer(types.SyncerConfig{ConnInfo: pg1.DBDsn}) - sqltx, err := pg2.DB.Begin() + syncer := eir.DatabaseSyncer(types.SyncerConfig{ConnInfo: pg1.DBDsn}) + syncerDone := make(chan struct{}) + go func() { + defer close(syncerDone) + syncer() + }() + + sqlTx, err := pg2.DB.Begin() require.NoError(t, err) - tx := &Tx{Tx: sqltx} + tx := &Tx{Tx: sqlTx} err = eir.Report([]*types.PUReportedMetric{ { ConnectionDetails: types.ConnectionDetails{ @@ -386,8 +372,7 @@ func TestErrorIndexReporter(t *testing.T) { require.Error(t, err) require.Error(t, tx.Commit()) - cancel() - <-subscribeDone + <-syncerDone }) }) @@ -400,23 +385,30 @@ func TestErrorIndexReporter(t *testing.T) { require.NoError(t, err) c := config.New() + ctx, cancel := context.WithCancel(ctx) - cs := newConfigSubscriber(logger.NOP) - subscribeDone := make(chan struct{}) - go func() { - defer close(subscribeDone) - cs.Subscribe(ctx, mockBackendConfig) - }() + defer cancel() + + cs := newMockConfigSubscriber() + cs.addWorkspaceIDForSourceID(sourceID, workspaceID) - eir := NewErrorIndexReporter(ctx, logger.NOP, cs, c) + eir := NewErrorIndexReporter(ctx, logger.NOP, cs, c, memstats.New()) defer eir.Stop() - _ = eir.DatabaseSyncer(types.SyncerConfig{ConnInfo: pg1.DBDsn}) - _ = eir.DatabaseSyncer(types.SyncerConfig{ConnInfo: pg2.DBDsn}) + + syncer1 := eir.DatabaseSyncer(types.SyncerConfig{ConnInfo: pg1.DBDsn}) + syncer2 := eir.DatabaseSyncer(types.SyncerConfig{ConnInfo: pg2.DBDsn}) + + syncersDone := make(chan struct{}) + go func() { + defer close(syncersDone) + syncer1() + syncer2() + }() t.Run("correct transaction", func(t *testing.T) { - sqltx, err := pg1.DB.Begin() + sqlTx, err := pg1.DB.Begin() require.NoError(t, err) - tx := &Tx{Tx: sqltx} + tx := &Tx{Tx: sqlTx} err = eir.Report([]*types.PUReportedMetric{ { ConnectionDetails: types.ConnectionDetails{ @@ -448,9 +440,9 @@ func TestErrorIndexReporter(t *testing.T) { require.NoError(t, tx.Commit()) }) t.Run("wrong transaction", func(t *testing.T) { - sqltx, err := pg3.DB.Begin() + sqlTx, err := pg3.DB.Begin() require.NoError(t, err) - tx := &Tx{Tx: sqltx} + tx := &Tx{Tx: sqlTx} err = eir.Report([]*types.PUReportedMetric{ { ConnectionDetails: types.ConnectionDetails{ @@ -482,7 +474,128 @@ func TestErrorIndexReporter(t *testing.T) { require.NoError(t, tx.Commit()) }) + <-syncersDone + }) + + t.Run("sync data", func(t *testing.T) { + postgresContainer, err := resource.SetupPostgres(pool, t) + require.NoError(t, err) + minioResource, err := resource.SetupMinio(pool, t) + require.NoError(t, err) 
+ + reports := []*types.PUReportedMetric{ + { + ConnectionDetails: types.ConnectionDetails{ + SourceID: sourceID, + DestinationID: destinationID, + TransformationID: transformationID, + TrackingPlanID: trackingPlanID, + }, + PUDetails: types.PUDetails{ + PU: reportedBy, + }, + StatusDetail: &types.StatusDetail{ + EventName: eventName, + EventType: eventType, + FailedMessages: []*types.FailedMessage{ + { + MessageID: messageID + "1", + ReceivedAt: receivedAt.Add(1 * time.Hour), + }, + { + MessageID: messageID + "2", + ReceivedAt: receivedAt.Add(2 * time.Hour), + }, + }, + }, + }, + { + ConnectionDetails: types.ConnectionDetails{ + SourceID: sourceID, + DestinationID: destinationID, + TransformationID: transformationID, + TrackingPlanID: trackingPlanID, + }, + PUDetails: types.PUDetails{ + PU: reportedBy, + }, + StatusDetail: &types.StatusDetail{ + EventName: eventName, + EventType: eventType, + FailedMessages: []*types.FailedMessage{ + { + MessageID: messageID + "3", + ReceivedAt: receivedAt.Add(3 * time.Hour), + }, + { + MessageID: messageID + "4", + ReceivedAt: receivedAt.Add(4 * time.Hour), + }, + }, + }, + }, + } + + c := config.New() + c.Set("ErrorIndex.storage.Bucket", minioResource.BucketName) + c.Set("ErrorIndex.storage.Endpoint", minioResource.Endpoint) + c.Set("ErrorIndex.storage.AccessKey", minioResource.AccessKeyID) + c.Set("ErrorIndex.storage.SecretAccessKey", minioResource.AccessKeySecret) + c.Set("ErrorIndex.storage.S3ForcePathStyle", true) + c.Set("ErrorIndex.storage.DisableSSL", true) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + cs := newMockConfigSubscriber() + cs.addWorkspaceIDForSourceID(sourceID, workspaceID) + + eir := NewErrorIndexReporter(ctx, logger.NOP, cs, c, memstats.New()) + eir.now = func() time.Time { + return failedAt + } + eir.trigger = func() <-chan time.Time { + return time.After(time.Duration(0)) + } + defer eir.Stop() + + syncer := eir.DatabaseSyncer(types.SyncerConfig{ConnInfo: postgresContainer.DBDsn}) + + syncerDone := make(chan struct{}) + go func() { + defer close(syncerDone) + syncer() + }() + + sqlTx, err := postgresContainer.DB.Begin() + require.NoError(t, err) + + tx := &Tx{Tx: sqlTx} + err = eir.Report(reports, tx) + require.NoError(t, err) + require.NoError(t, tx.Commit()) + + db, err := eir.resolveJobsDB(tx) + require.NoError(t, err) + + failedJobs := lo.Flatten(lo.Map(reports, func(item *types.PUReportedMetric, index int) []*types.FailedMessage { + return item.StatusDetail.FailedMessages + })) + + require.Eventually(t, func() bool { + jr, err := db.GetSucceeded(ctx, jobsdb.GetQueryParams{ + JobsLimit: 100, + }) + require.NoError(t, err) + + return len(jr.Jobs) == len(failedJobs) + }, + time.Second*30, + time.Millisecond*100, + ) + cancel() - <-subscribeDone + + <-syncerDone }) } diff --git a/enterprise/reporting/error_index/types.go b/enterprise/reporting/error_index/types.go new file mode 100644 index 00000000000..30f364bb983 --- /dev/null +++ b/enterprise/reporting/error_index/types.go @@ -0,0 +1,61 @@ +package error_index + +import ( + "context" + "os" + "strconv" + "time" + + "github.com/rudderlabs/rudder-server/jobsdb" + + "github.com/rudderlabs/rudder-go-kit/filemanager" +) + +type configSubscriber interface { + WorkspaceIDFromSource(sourceID string) string +} + +type uploader interface { + Upload(context.Context, *os.File, ...string) (filemanager.UploadedFile, error) +} + +type jobWithPayload struct { + *jobsdb.JobT + + payload payload +} + +type payload struct { + MessageID string `json:"messageId" 
parquet:"name=message_id, type=BYTE_ARRAY, convertedtype=UTF8, encoding=RLE_DICTIONARY"` + SourceID string `json:"sourceId" parquet:"name=source_id, type=BYTE_ARRAY, convertedtype=UTF8, encoding=RLE_DICTIONARY"` + DestinationID string `json:"destinationId" parquet:"name=destination_id, type=BYTE_ARRAY, convertedtype=UTF8, encoding=RLE_DICTIONARY"` + TransformationID string `json:"transformationId" parquet:"name=transformation_id, type=BYTE_ARRAY, convertedtype=UTF8, encoding=RLE_DICTIONARY"` + TrackingPlanID string `json:"trackingPlanId" parquet:"name=tracking_plan_id, type=BYTE_ARRAY, convertedtype=UTF8, encoding=RLE_DICTIONARY"` + FailedStage string `json:"failedStage" parquet:"name=failed_stage, type=BYTE_ARRAY, convertedtype=UTF8, encoding=RLE_DICTIONARY"` + EventType string `json:"eventType" parquet:"name=event_type, type=BYTE_ARRAY, convertedtype=UTF8, encoding=RLE_DICTIONARY"` + EventName string `json:"eventName" parquet:"name=event_name, type=BYTE_ARRAY, convertedtype=UTF8, encoding=RLE_DICTIONARY"` + ReceivedAt int64 `json:"receivedAt" parquet:"name=received_at, type=INT64, convertedtype=TIMESTAMP_MICROS, encoding=DELTA_BINARY_PACKED"` // In Microseconds + FailedAt int64 `json:"failedAt" parquet:"name=failed_at, type=INT64, convertedtype=TIMESTAMP_MICROS, encoding=DELTA_BINARY_PACKED"` // In Microseconds +} + +func (p *payload) SetReceivedAt(t time.Time) { + p.ReceivedAt = t.UTC().UnixMicro() +} + +func (p *payload) SetFailedAt(t time.Time) { + p.FailedAt = t.UTC().UnixMicro() +} + +func (p *payload) FailedAtTime() time.Time { + return time.UnixMicro(p.FailedAt).UTC() +} + +func (p *payload) SortingKey() string { + const sep = "_" + return strconv.FormatInt(p.FailedAt, 10) + sep + + p.DestinationID + sep + + p.EventType + sep + + p.EventName + sep + + p.TransformationID + sep + + p.TrackingPlanID +} diff --git a/enterprise/reporting/error_index/worker.go b/enterprise/reporting/error_index/worker.go new file mode 100644 index 00000000000..a2797be901c --- /dev/null +++ b/enterprise/reporting/error_index/worker.go @@ -0,0 +1,317 @@ +package error_index + +import ( + "context" + "encoding/json" + "fmt" + "io" + "os" + "path" + "slices" + "sort" + "strconv" + "time" + + "github.com/rudderlabs/rudder-go-kit/filemanager" + + "github.com/samber/lo" + "github.com/xitongsys/parquet-go/parquet" + "github.com/xitongsys/parquet-go/writer" + + "github.com/rudderlabs/rudder-go-kit/bytesize" + "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/logger" + "github.com/rudderlabs/rudder-go-kit/stats" + kitsync "github.com/rudderlabs/rudder-go-kit/sync" + "github.com/rudderlabs/rudder-server/jobsdb" + "github.com/rudderlabs/rudder-server/utils/misc" +) + +type worker struct { + sourceID string + workspaceID string + + log logger.Logger + statsFactory stats.Stats + + jobsDB jobsdb.JobsDB + configSubscriber configSubscriber + uploader uploader + + lifecycle struct { + ctx context.Context + cancel context.CancelFunc + } + + limiter struct { + fetch kitsync.Limiter + upload kitsync.Limiter + update kitsync.Limiter + } + + now func() time.Time + lastUploadTime time.Time + + config struct { + parquetParallelWriters, parquetRowGroupSize, parquetPageSize misc.ValueLoader[int64] + bucketName, instanceID string + payloadLimit, eventsLimit misc.ValueLoader[int64] + minWorkerSleep, uploadFrequency, jobsDBCommandTimeout time.Duration + jobsDBMaxRetries misc.ValueLoader[int] + } +} + +// newWorker creates a new worker for the given sourceID. 
+func newWorker( + sourceID string, + conf *config.Config, + log logger.Logger, + statsFactory stats.Stats, + jobsDB jobsdb.JobsDB, + configFetcher configSubscriber, + uploader uploader, + fetchLimiter, uploadLimiter, updateLimiter kitsync.Limiter, +) *worker { + workspaceID := configFetcher.WorkspaceIDFromSource(sourceID) + + w := &worker{ + sourceID: sourceID, + workspaceID: workspaceID, + log: log.Child("worker").With("workspaceID", workspaceID).With("sourceID", sourceID), + statsFactory: statsFactory, + jobsDB: jobsDB, + configSubscriber: configFetcher, + uploader: uploader, + now: time.Now, + } + w.lifecycle.ctx, w.lifecycle.cancel = context.WithCancel(context.Background()) + + w.config.parquetParallelWriters = conf.GetReloadableInt64Var(8, 1, "Reporting.errorIndexReporting.parquetParallelWriters") + w.config.parquetRowGroupSize = conf.GetReloadableInt64Var(512*bytesize.MB, 1, "Reporting.errorIndexReporting.parquetRowGroupSize") + w.config.parquetPageSize = conf.GetReloadableInt64Var(8*bytesize.KB, 1, "Reporting.errorIndexReporting.parquetPageSizeInKB") + w.config.instanceID = conf.GetString("INSTANCE_ID", "1") + w.config.bucketName = conf.GetString("ErrorIndex.Storage.Bucket", "rudder-failed-messages") + w.config.payloadLimit = conf.GetReloadableInt64Var(1*bytesize.GB, 1, "Reporting.errorIndexReporting.payloadLimit") + w.config.eventsLimit = conf.GetReloadableInt64Var(100000, 1, "Reporting.errorIndexReporting.eventsLimit") + w.config.minWorkerSleep = conf.GetDuration("Reporting.errorIndexReporting.minWorkerSleep", 1, time.Minute) + w.config.uploadFrequency = conf.GetDuration("Reporting.errorIndexReporting.uploadFrequency", 5, time.Minute) + w.config.jobsDBCommandTimeout = conf.GetDurationVar(10, time.Minute, "JobsDB.CommandRequestTimeout", "Reporting.errorIndexReporting.CommandRequestTimeout") + w.config.jobsDBMaxRetries = conf.GetReloadableIntVar(3, 1, "JobsDB.MaxRetries", "Reporting.errorIndexReporting.MaxRetries") + + w.limiter.fetch = fetchLimiter + w.limiter.upload = uploadLimiter + w.limiter.update = updateLimiter + return w +} + +// Work fetches and processes job results: +// 1. Fetches job results. +// 2. If no jobs are fetched, returns. +// 3. If job limits are not reached and upload frequency is not met, returns. +// 4. Uploads jobs to object storage. +// 5. Updates job status in the jobsDB. 
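+// It returns true only if jobs were uploaded and their statuses were marked successfully.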
+func (w *worker) Work() (worked bool) { + jobResult, err := w.fetchJobs() + if err != nil && w.lifecycle.ctx.Err() != nil { + return + } + if err != nil { + panic(fmt.Errorf("failed to fetch jobs for error index: %s", err.Error())) + } + if len(jobResult.Jobs) == 0 { + return + } + if !jobResult.LimitsReached && time.Since(w.lastUploadTime) < w.config.uploadFrequency { + return + } + + statusList, err := w.uploadJobs(w.lifecycle.ctx, jobResult.Jobs) + if err != nil { + w.log.Warnw("failed to upload jobs", "error", err) + return + } + w.lastUploadTime = w.now() + + err = w.markJobsStatus(statusList) + if err != nil && w.lifecycle.ctx.Err() != nil { + return + } + if err != nil { + panic(fmt.Errorf("failed to mark jobs: %s", err.Error())) + } + worked = true + + tags := stats.Tags{ + "workspaceId": w.workspaceID, + "sourceId": w.sourceID, + } + w.statsFactory.NewTaggedStat("erridx_uploaded_jobs", stats.CountType, tags).Count(len(jobResult.Jobs)) + return +} + +func (w *worker) fetchJobs() (jobsdb.JobsResult, error) { + defer w.limiter.fetch.Begin(w.sourceID)() + + return w.jobsDB.GetUnprocessed(w.lifecycle.ctx, jobsdb.GetQueryParams{ + ParameterFilters: []jobsdb.ParameterFilterT{ + {Name: "source_id", Value: w.sourceID}, + }, + PayloadSizeLimit: w.config.payloadLimit.Load(), + EventsLimit: int(w.config.eventsLimit.Load()), + JobsLimit: int(w.config.eventsLimit.Load()), + }) +} + +// uploadJobs uploads aggregated job payloads to object storage. +// It aggregates payloads from a list of jobs, applies transformations if needed, +// uploads the payloads, and returns the concatenated locations of the uploaded files. +func (w *worker) uploadJobs(ctx context.Context, jobs []*jobsdb.JobT) ([]*jobsdb.JobStatusT, error) { + defer w.limiter.upload.Begin(w.sourceID)() + + jobWithPayloadsMap := make(map[string][]jobWithPayload) + for _, job := range jobs { + var p payload + if err := json.Unmarshal(job.EventPayload, &p); err != nil { + return nil, fmt.Errorf("unmarshalling payload: %w", err) + } + + key := p.FailedAtTime().Format("2006-01-02/15") + jobWithPayloadsMap[key] = append(jobWithPayloadsMap[key], jobWithPayload{JobT: job, payload: p}) + } + + statusList := make([]*jobsdb.JobStatusT, 0, len(jobs)) + for _, jobWithPayloads := range jobWithPayloadsMap { + uploadFile, err := w.uploadPayloads(ctx, lo.Map(jobWithPayloads, func(item jobWithPayload, index int) payload { + return item.payload + })) + if err != nil { + return nil, fmt.Errorf("uploading aggregated payloads: %w", err) + } + + statusList = append(statusList, lo.Map(jobWithPayloads, func(item jobWithPayload, index int) *jobsdb.JobStatusT { + return &jobsdb.JobStatusT{ + JobID: item.JobT.JobID, + JobState: jobsdb.Succeeded.State, + ErrorResponse: []byte(fmt.Sprintf(`{"location": "%s"}`, uploadFile.Location)), + Parameters: []byte(`{}`), + AttemptNum: item.JobT.LastJobStatus.AttemptNum + 1, + ExecTime: w.now(), + RetryTime: w.now(), + } + })...) 
+ } + return statusList, nil +} + +func (w *worker) uploadPayloads(ctx context.Context, payloads []payload) (*filemanager.UploadedFile, error) { + slices.SortFunc(payloads, func(i, j payload) int { + return i.FailedAtTime().Compare(j.FailedAtTime()) + }) + + tmpDirPath, err := misc.CreateTMPDIR() + if err != nil { + return nil, fmt.Errorf("creating tmp directory: %w", err) + } + + dir, err := os.MkdirTemp(tmpDirPath, "*") + if err != nil { + return nil, fmt.Errorf("creating tmp directory: %w", err) + } + + minFailedAt := payloads[0].FailedAtTime() + maxFailedAt := payloads[len(payloads)-1].FailedAtTime() + + filePath := path.Join(dir, fmt.Sprintf("%d_%d_%s.parquet", minFailedAt.Unix(), maxFailedAt.Unix(), w.config.instanceID)) + + f, err := os.Create(filePath) + if err != nil { + return nil, fmt.Errorf("creating file: %w", err) + } + defer func() { + _ = os.Remove(f.Name()) + }() + + if err = w.encodeToParquet(f, payloads); err != nil { + return nil, fmt.Errorf("writing to file: %w", err) + } + if err = f.Close(); err != nil { + return nil, fmt.Errorf("closing file: %w", err) + } + + f, err = os.Open(f.Name()) + if err != nil { + return nil, fmt.Errorf("opening file: %w", err) + } + + prefixes := []string{w.sourceID, minFailedAt.Format("2006-01-02"), strconv.Itoa(minFailedAt.Hour())} + uploadOutput, err := w.uploader.Upload(ctx, f, prefixes...) + if err != nil { + return nil, fmt.Errorf("uploading file to object storage: %w", err) + } + return &uploadOutput, nil +} + +// encodeToParquet writes the payloads to the writer using parquet encoding. It sorts the payloads to achieve better encoding. +func (w *worker) encodeToParquet(wr io.Writer, payloads []payload) error { + pw, err := writer.NewParquetWriterFromWriter(wr, new(payload), w.config.parquetParallelWriters.Load()) + if err != nil { + return fmt.Errorf("creating parquet writer: %v", err) + } + + pw.RowGroupSize = w.config.parquetRowGroupSize.Load() + pw.PageSize = w.config.parquetPageSize.Load() + pw.CompressionType = parquet.CompressionCodec_SNAPPY + + sort.Slice(payloads, func(i, j int) bool { + return payloads[i].FailedAt > payloads[j].FailedAt + }) + + for _, payload := range payloads { + if err = pw.Write(payload); err != nil { + return fmt.Errorf("writing to parquet writer: %v", err) + } + } + if err = pw.WriteStop(); err != nil { + return fmt.Errorf("stopping parquet writer: %v", err) + } + return nil +} + +// markJobsStatus marks the status of the jobs in the erridx jobsDB. 
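+// Updates are retried according to jobsDBMaxRetries and jobsDBCommandTimeout before an error is returned.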
+func (w *worker) markJobsStatus(statusList []*jobsdb.JobStatusT) error { + defer w.limiter.update.Begin(w.sourceID)() + + err := misc.RetryWithNotify( + w.lifecycle.ctx, + w.config.jobsDBCommandTimeout, + w.config.jobsDBMaxRetries.Load(), + func(ctx context.Context) error { + return w.jobsDB.UpdateJobStatus(ctx, statusList, nil, nil) + }, + func(attempt int) { + w.log.Warnw("failed to mark job's status", "attempt", attempt) + }, + ) + if err != nil { + return fmt.Errorf("updating job status: %w", err) + } + + tags := stats.Tags{ + "workspaceId": w.workspaceID, + "sourceId": w.sourceID, + "state": jobsdb.Succeeded.State, + } + w.statsFactory.NewTaggedStat("erridx_processed_jobs", stats.CountType, tags).Count(len(statusList)) + return nil +} + +func (w *worker) SleepDurations() (time.Duration, time.Duration) { + if w.lastUploadTime.IsZero() { + return w.config.minWorkerSleep, w.config.uploadFrequency + } + return w.config.minWorkerSleep, time.Until(w.lastUploadTime.Add(w.config.uploadFrequency)) +} + +func (w *worker) Stop() { + w.lifecycle.cancel() +} diff --git a/enterprise/reporting/error_index/worker_test.go b/enterprise/reporting/error_index/worker_test.go new file mode 100644 index 00000000000..dade052022c --- /dev/null +++ b/enterprise/reporting/error_index/worker_test.go @@ -0,0 +1,719 @@ +package error_index + +import ( + "bytes" + "context" + "database/sql" + "encoding/csv" + "encoding/json" + "fmt" + "io" + "os" + "path" + "slices" + "strconv" + "strings" + "sync" + "testing" + "time" + + kitsync "github.com/rudderlabs/rudder-go-kit/sync" + + "github.com/minio/minio-go/v7" + + "github.com/rudderlabs/rudder-go-kit/bytesize" + + "github.com/google/uuid" + "github.com/ory/dockertest/v3" + "github.com/samber/lo" + "github.com/stretchr/testify/require" + "github.com/xitongsys/parquet-go-source/buffer" + "github.com/xitongsys/parquet-go-source/local" + "github.com/xitongsys/parquet-go/reader" + + "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/filemanager" + "github.com/rudderlabs/rudder-go-kit/logger" + "github.com/rudderlabs/rudder-go-kit/stats" + "github.com/rudderlabs/rudder-go-kit/stats/memstats" + "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource" + "github.com/rudderlabs/rudder-server/jobsdb" + "github.com/rudderlabs/rudder-server/utils/misc" + warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" + + _ "github.com/marcboeker/go-duckdb" +) + +func TestWorkerWriter(t *testing.T) { + const ( + sourceID = "test-source-id" + workspaceID = "test-workspace-id" + instanceID = "test-instance-id" + ) + + ctx := context.Background() + + pool, err := dockertest.NewPool("") + require.NoError(t, err) + + t.Run("writer", func(t *testing.T) { + receivedAt := time.Date(2021, 1, 1, 1, 1, 1, 0, time.UTC) + failedAt := receivedAt.Add(time.Hour) + + count := 100 + factor := 10 + payloads := make([]payload, 0, count) + + for i := 0; i < count; i++ { + p := payload{ + MessageID: "messageId" + strconv.Itoa(i), + SourceID: "sourceId" + strconv.Itoa(i%5), + DestinationID: "destinationId" + strconv.Itoa(i%10), + TransformationID: "transformationId" + strconv.Itoa(i), + TrackingPlanID: "trackingPlanId" + strconv.Itoa(i), + FailedStage: "failedStage" + strconv.Itoa(i), + EventType: "eventType" + strconv.Itoa(i), + EventName: "eventName" + strconv.Itoa(i), + } + p.SetReceivedAt(receivedAt.Add(time.Duration(i) * time.Second)) + p.SetFailedAt(failedAt.Add(time.Duration(i) * time.Second)) + + payloads = append(payloads, p) + } + + t.Run("writes", 
func(t *testing.T) { + buf := bytes.NewBuffer(make([]byte, 0, 1024)) + + w := worker{} + w.config.parquetRowGroupSize = misc.SingleValueLoader(512 * bytesize.MB) + w.config.parquetPageSize = misc.SingleValueLoader(8 * bytesize.KB) + w.config.parquetParallelWriters = misc.SingleValueLoader(int64(8)) + + require.NoError(t, w.encodeToParquet(buf, payloads)) + + pr, err := reader.NewParquetReader(buffer.NewBufferFileFromBytes(buf.Bytes()), new(payload), 8) + require.NoError(t, err) + require.EqualValues(t, len(payloads), pr.GetNumRows()) + + for i := 0; i < int(pr.GetNumRows())/factor; i++ { + expectedPayloads := make([]payload, factor) + + err := pr.Read(&expectedPayloads) + require.NoError(t, err) + + for j, expectedPayload := range expectedPayloads { + require.Equal(t, payloads[i*factor+j].MessageID, expectedPayload.MessageID) + require.Equal(t, payloads[i*factor+j].SourceID, expectedPayload.SourceID) + require.Equal(t, payloads[i*factor+j].DestinationID, expectedPayload.DestinationID) + require.Equal(t, payloads[i*factor+j].TransformationID, expectedPayload.TransformationID) + require.Equal(t, payloads[i*factor+j].TrackingPlanID, expectedPayload.TrackingPlanID) + require.Equal(t, payloads[i*factor+j].FailedStage, expectedPayload.FailedStage) + require.Equal(t, payloads[i*factor+j].EventType, expectedPayload.EventType) + require.Equal(t, payloads[i*factor+j].EventName, expectedPayload.EventName) + require.Equal(t, payloads[i*factor+j].ReceivedAt, expectedPayload.ReceivedAt) + require.Equal(t, payloads[i*factor+j].FailedAt, expectedPayload.FailedAt) + } + } + }) + + t.Run("filters", func(t *testing.T) { + filePath := path.Join(t.TempDir(), "payloads.parquet") + t.Cleanup(func() { + _ = os.Remove(filePath) + }) + + fw, err := local.NewLocalFileWriter(filePath) + require.NoError(t, err) + + w := worker{} + w.config.parquetRowGroupSize = misc.SingleValueLoader(512 * bytesize.MB) + w.config.parquetPageSize = misc.SingleValueLoader(8 * bytesize.KB) + w.config.parquetParallelWriters = misc.SingleValueLoader(int64(8)) + + require.NoError(t, w.encodeToParquet(fw, payloads)) + + t.Run("count all", func(t *testing.T) { + var count int64 + err := duckDB(t).QueryRowContext(ctx, fmt.Sprintf("SELECT count(*) FROM read_parquet('%s');", filePath)).Scan(&count) + require.NoError(t, err) + require.EqualValues(t, len(payloads), count) + }) + t.Run("count for sourceId, destinationId", func(t *testing.T) { + var count int64 + err := duckDB(t).QueryRowContext(ctx, fmt.Sprintf("SELECT count(*) FROM read_parquet('%s') WHERE source_id = $1 AND destination_id = $2;", filePath), "sourceId3", "destinationId3").Scan(&count) + require.NoError(t, err) + require.EqualValues(t, 10, count) + }) + t.Run("select all", func(t *testing.T) { + failedMessages := failedMessagesUsingDuckDB(t, ctx, nil, fmt.Sprintf("SELECT * FROM read_parquet('%s') ORDER BY failed_at DESC;", filePath)) + + for i, failedMessage := range failedMessages { + require.Equal(t, payloads[i].MessageID, failedMessage.MessageID) + require.Equal(t, payloads[i].SourceID, failedMessage.SourceID) + require.Equal(t, payloads[i].DestinationID, failedMessage.DestinationID) + require.Equal(t, payloads[i].TransformationID, failedMessage.TransformationID) + require.Equal(t, payloads[i].TrackingPlanID, failedMessage.TrackingPlanID) + require.Equal(t, payloads[i].FailedStage, failedMessage.FailedStage) + require.Equal(t, payloads[i].EventType, failedMessage.EventType) + require.Equal(t, payloads[i].EventName, failedMessage.EventName) + require.EqualValues(t, 
payloads[i].ReceivedAt, failedMessage.ReceivedAt) + require.EqualValues(t, payloads[i].FailedAt, failedMessage.FailedAt) + } + }) + }) + }) + + t.Run("workers work", func(t *testing.T) { + t.Run("same hours", func(t *testing.T) { + receivedAt := time.Date(2021, 1, 1, 1, 1, 1, 0, time.UTC) + failedAt := time.Date(2021, 1, 1, 1, 1, 1, 0, time.UTC) + + postgresContainer, err := resource.SetupPostgres(pool, t) + require.NoError(t, err) + minioResource, err := resource.SetupMinio(pool, t) + require.NoError(t, err) + + c := config.New() + c.Set("INSTANCE_ID", instanceID) + + errIndexDB := jobsdb.NewForReadWrite("err_idx", jobsdb.WithDBHandle(postgresContainer.DB), jobsdb.WithConfig(c)) + require.NoError(t, errIndexDB.Start()) + defer errIndexDB.TearDown() + + count := 100 + payloads := make([]payload, 0, count) + jobs := make([]*jobsdb.JobT, 0, count) + + for i := 0; i < count; i++ { + p := payload{ + MessageID: "message-id-" + strconv.Itoa(i), + SourceID: sourceID, + DestinationID: "destination-id-" + strconv.Itoa(i), + TransformationID: "transformation-id-" + strconv.Itoa(i), + TrackingPlanID: "tracking-plan-id-" + strconv.Itoa(i), + FailedStage: "failed-stage-" + strconv.Itoa(i), + EventType: "event-type-" + strconv.Itoa(i), + EventName: "event-name-" + strconv.Itoa(i), + } + p.SetReceivedAt(receivedAt) + p.SetFailedAt(failedAt.Add(time.Duration(i) * time.Second)) + payloads = append(payloads, p) + + epJSON, err := json.Marshal(p) + require.NoError(t, err) + + jobs = append(jobs, &jobsdb.JobT{ + UUID: uuid.New(), + Parameters: []byte(`{"source_id":"` + sourceID + `","workspaceId":"` + workspaceID + `"}`), + EventPayload: epJSON, + EventCount: 1, + WorkspaceId: workspaceID, + }) + } + + require.NoError(t, errIndexDB.Store(ctx, jobs)) + + cs := newMockConfigSubscriber() + cs.addWorkspaceIDForSourceID(sourceID, workspaceID) + + statsStore := memstats.New() + + fm, err := filemanager.New(&filemanager.Settings{ + Provider: warehouseutils.MINIO, + Config: map[string]any{ + "bucketName": minioResource.BucketName, + "accessKeyID": minioResource.AccessKeyID, + "secretAccessKey": minioResource.AccessKeySecret, + "endPoint": minioResource.Endpoint, + }, + }) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + limiterGroup := sync.WaitGroup{} + limiter := kitsync.NewLimiter(ctx, &limiterGroup, "erridx_test", 1000, statsStore) + defer func() { + cancel() + limiterGroup.Wait() + }() + + w := newWorker(sourceID, c, logger.NOP, statsStore, errIndexDB, cs, fm, limiter, limiter, limiter) + defer w.Stop() + + require.True(t, w.Work()) + require.EqualValues(t, len(jobs), statsStore.Get("erridx_uploaded_jobs", stats.Tags{ + "workspaceId": w.workspaceID, + "sourceId": w.sourceID, + }).LastValue()) + require.EqualValues(t, len(jobs), statsStore.Get("erridx_processed_jobs", stats.Tags{ + "workspaceId": w.workspaceID, + "sourceId": w.sourceID, + "state": jobsdb.Succeeded.State, + }).LastValue()) + require.False(t, w.Work()) + + lastFailedAt := failedAt.Add(time.Duration(len(jobs)-1) * time.Second) + filePath := fmt.Sprintf("s3://%s/%s/%s/%s/%d_%d_%s.parquet", + minioResource.BucketName, + w.sourceID, + failedAt.Format("2006-01-02"), + strconv.Itoa(failedAt.Hour()), + failedAt.Unix(), + lastFailedAt.Unix(), + instanceID, + ) + query := fmt.Sprintf("SELECT * FROM read_parquet('%s') ORDER BY failed_at ASC;", filePath) + failedMessages := failedMessagesUsingDuckDB(t, ctx, minioResource, query) + require.Len(t, failedMessages, len(jobs)) + require.EqualValues(t, payloads, 
failedMessages) + + s3SelectPath := fmt.Sprintf("%s/%s/%s/%d_%d_%s.parquet", + w.sourceID, + failedAt.Format("2006-01-02"), + strconv.Itoa(failedAt.Hour()), + failedAt.Unix(), + lastFailedAt.Unix(), + instanceID, + ) + s3SelectQuery := fmt.Sprint("SELECT message_id, source_id, destination_id, transformation_id, tracking_plan_id, failed_stage, event_type, event_name, received_at, failed_at FROM S3Object") + failedMessagesUsing3Select := failedMessagesUsingMinioS3Select(t, ctx, minioResource, s3SelectPath, s3SelectQuery) + slices.SortFunc(failedMessagesUsing3Select, func(a, b payload) int { + return a.FailedAtTime().Compare(b.FailedAtTime()) + }) + require.Equal(t, len(failedMessages), len(failedMessagesUsing3Select)) + require.Equal(t, failedMessages, failedMessagesUsing3Select) + + jr, err := errIndexDB.GetSucceeded(ctx, jobsdb.GetQueryParams{ + ParameterFilters: []jobsdb.ParameterFilterT{ + {Name: "source_id", Value: w.sourceID}, + }, + PayloadSizeLimit: w.config.payloadLimit.Load(), + EventsLimit: int(w.config.eventsLimit.Load()), + JobsLimit: int(w.config.eventsLimit.Load()), + }) + require.NoError(t, err) + require.Len(t, jr.Jobs, len(jobs)) + + lo.ForEach(jr.Jobs, func(item *jobsdb.JobT, index int) { + require.EqualValues(t, string(item.LastJobStatus.ErrorResponse), fmt.Sprintf(`{"location": "%s"}`, strings.Replace(filePath, "s3://", fmt.Sprintf("http://%s/", minioResource.Endpoint), 1))) + }) + }) + t.Run("multiple hours and days", func(t *testing.T) { + receivedAt := time.Date(2021, 1, 1, 1, 1, 1, 0, time.UTC) + failedAt := time.Date(2021, 1, 1, 1, 1, 1, 0, time.UTC) + + postgresContainer, err := resource.SetupPostgres(pool, t) + require.NoError(t, err) + minioResource, err := resource.SetupMinio(pool, t) + require.NoError(t, err) + + c := config.New() + c.Set("INSTANCE_ID", instanceID) + + errIndexDB := jobsdb.NewForReadWrite("err_idx", jobsdb.WithDBHandle(postgresContainer.DB), jobsdb.WithConfig(c)) + require.NoError(t, errIndexDB.Start()) + defer errIndexDB.TearDown() + + count := 100 + payloads := make([]payload, 0, count) + jobs := make([]*jobsdb.JobT, 0, count) + + for i := 0; i < count; i++ { + p := payload{ + MessageID: "message-id-" + strconv.Itoa(i), + SourceID: sourceID, + DestinationID: "destination-id-" + strconv.Itoa(i), + TransformationID: "transformation-id-" + strconv.Itoa(i), + TrackingPlanID: "tracking-plan-id-" + strconv.Itoa(i), + FailedStage: "failed-stage-" + strconv.Itoa(i), + EventType: "event-type-" + strconv.Itoa(i), + EventName: "event-name-" + strconv.Itoa(i), + } + p.SetReceivedAt(receivedAt) + p.SetFailedAt(failedAt.Add(time.Duration(i) * time.Hour)) + payloads = append(payloads, p) + + epJSON, err := json.Marshal(p) + require.NoError(t, err) + + jobs = append(jobs, &jobsdb.JobT{ + UUID: uuid.New(), + Parameters: []byte(`{"source_id":"` + sourceID + `","workspaceId":"` + workspaceID + `"}`), + EventPayload: epJSON, + EventCount: 1, + WorkspaceId: workspaceID, + }) + } + + require.NoError(t, errIndexDB.Store(ctx, jobs)) + + cs := newMockConfigSubscriber() + cs.addWorkspaceIDForSourceID(sourceID, workspaceID) + + statsStore := memstats.New() + + fm, err := filemanager.New(&filemanager.Settings{ + Provider: warehouseutils.MINIO, + Config: map[string]any{ + "bucketName": minioResource.BucketName, + "accessKeyID": minioResource.AccessKeyID, + "secretAccessKey": minioResource.AccessKeySecret, + "endPoint": minioResource.Endpoint, + }, + }) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + limiterGroup := 
sync.WaitGroup{} + limiter := kitsync.NewLimiter(ctx, &limiterGroup, "erridx_test", 1000, statsStore) + defer func() { + cancel() + limiterGroup.Wait() + }() + + w := newWorker(sourceID, c, logger.NOP, statsStore, errIndexDB, cs, fm, limiter, limiter, limiter) + defer w.Stop() + + require.True(t, w.Work()) + + for i := 0; i < count; i++ { + failedAt := failedAt.Add(time.Duration(i) * time.Hour) + query := fmt.Sprintf("SELECT * FROM read_parquet('%s') ORDER BY failed_at ASC;", fmt.Sprintf("s3://%s/%s/%s/%s/%d_%d_%s.parquet", + minioResource.BucketName, + w.sourceID, + failedAt.Format("2006-01-02"), + strconv.Itoa(failedAt.Hour()), + failedAt.Unix(), + failedAt.Unix(), + instanceID, + )) + + failedMessages := failedMessagesUsingDuckDB(t, ctx, minioResource, query) + require.EqualValues(t, []payload{payloads[i]}, failedMessages) + } + + jr, err := errIndexDB.GetSucceeded(ctx, jobsdb.GetQueryParams{ + ParameterFilters: []jobsdb.ParameterFilterT{ + {Name: "source_id", Value: w.sourceID}, + }, + PayloadSizeLimit: w.config.payloadLimit.Load(), + EventsLimit: int(w.config.eventsLimit.Load()), + JobsLimit: int(w.config.eventsLimit.Load()), + }) + require.NoError(t, err) + require.Len(t, jr.Jobs, len(jobs)) + + lo.ForEach(jr.Jobs, func(item *jobsdb.JobT, index int) { + failedAt := failedAt.Add(time.Duration(index) * time.Hour) + filePath := fmt.Sprintf("http://%s/%s/%s/%s/%s/%d_%d_%s.parquet", + minioResource.Endpoint, + minioResource.BucketName, + w.sourceID, + failedAt.Format("2006-01-02"), + strconv.Itoa(failedAt.Hour()), + failedAt.Unix(), + failedAt.Unix(), + instanceID, + ) + require.EqualValues(t, string(item.LastJobStatus.ErrorResponse), fmt.Sprintf(`{"location": "%s"}`, strings.Replace(filePath, "s3://", fmt.Sprintf("http://%s/", minioResource.Endpoint), 1))) + }) + }) + t.Run("limits reached but few left without crossing upload frequency", func(t *testing.T) { + receivedAt := time.Date(2021, 1, 1, 1, 1, 1, 0, time.UTC) + failedAt := time.Date(2021, 1, 1, 1, 1, 1, 0, time.UTC) + + postgresContainer, err := resource.SetupPostgres(pool, t) + require.NoError(t, err) + minioResource, err := resource.SetupMinio(pool, t) + require.NoError(t, err) + + eventsLimit := 24 + + c := config.New() + c.Set("INSTANCE_ID", instanceID) + c.Set("Reporting.errorIndexReporting.minWorkerSleep", "1s") + c.Set("Reporting.errorIndexReporting.uploadFrequency", "600s") + c.Set("Reporting.errorIndexReporting.eventsLimit", strconv.Itoa(eventsLimit)) + + errIndexDB := jobsdb.NewForReadWrite("err_idx", jobsdb.WithDBHandle(postgresContainer.DB), jobsdb.WithConfig(c)) + require.NoError(t, errIndexDB.Start()) + defer errIndexDB.TearDown() + + count := 100 + payloads := make([]payload, 0, count) + jobs := make([]*jobsdb.JobT, 0, count) + + for i := 0; i < count; i++ { + p := payload{ + MessageID: "message-id-" + strconv.Itoa(i), + SourceID: sourceID, + DestinationID: "destination-id-" + strconv.Itoa(i), + TransformationID: "transformation-id-" + strconv.Itoa(i), + TrackingPlanID: "tracking-plan-id-" + strconv.Itoa(i), + FailedStage: "failed-stage-" + strconv.Itoa(i), + EventType: "event-type-" + strconv.Itoa(i), + EventName: "event-name-" + strconv.Itoa(i), + } + p.SetReceivedAt(receivedAt) + p.SetFailedAt(failedAt.Add(time.Duration(i) * time.Second)) + payloads = append(payloads, p) + + epJSON, err := json.Marshal(p) + require.NoError(t, err) + + jobs = append(jobs, &jobsdb.JobT{ + UUID: uuid.New(), + Parameters: []byte(`{"source_id":"` + sourceID + `","workspaceId":"` + workspaceID + `"}`), + EventPayload: epJSON, + 
EventCount: 1, + WorkspaceId: workspaceID, + }) + } + require.NoError(t, errIndexDB.Store(ctx, jobs)) + + cs := newMockConfigSubscriber() + cs.addWorkspaceIDForSourceID(sourceID, workspaceID) + + statsStore := memstats.New() + + fm, err := filemanager.New(&filemanager.Settings{ + Provider: warehouseutils.MINIO, + Config: map[string]any{ + "bucketName": minioResource.BucketName, + "accessKeyID": minioResource.AccessKeyID, + "secretAccessKey": minioResource.AccessKeySecret, + "endPoint": minioResource.Endpoint, + }, + }) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + limiterGroup := sync.WaitGroup{} + limiter := kitsync.NewLimiter(ctx, &limiterGroup, "erridx_test", 1000, statsStore) + defer func() { + cancel() + limiterGroup.Wait() + }() + + w := newWorker(sourceID, c, logger.NOP, statsStore, errIndexDB, cs, fm, limiter, limiter, limiter) + defer w.Stop() + + for i := 0; i < count/eventsLimit; i++ { + require.True(t, w.Work()) + } + require.False(t, w.Work()) + + jr, err := errIndexDB.GetUnprocessed(ctx, jobsdb.GetQueryParams{ + ParameterFilters: []jobsdb.ParameterFilterT{ + {Name: "source_id", Value: w.sourceID}, + }, + PayloadSizeLimit: w.config.payloadLimit.Load(), + EventsLimit: int(w.config.eventsLimit.Load()), + JobsLimit: int(w.config.eventsLimit.Load()), + }) + require.NoError(t, err) + require.Len(t, jr.Jobs, 4) + }) + }) +} + +func failedMessagesUsingMinioS3Select(t testing.TB, ctx context.Context, mr *resource.MinioResource, filePath, query string) []payload { + t.Helper() + + r, err := mr.Client.SelectObjectContent(ctx, mr.BucketName, filePath, minio.SelectObjectOptions{ + Expression: query, + ExpressionType: minio.QueryExpressionTypeSQL, + InputSerialization: minio.SelectObjectInputSerialization{ + CompressionType: minio.SelectCompressionNONE, + Parquet: &minio.ParquetInputOptions{}, + }, + OutputSerialization: minio.SelectObjectOutputSerialization{ + CSV: &minio.CSVOutputOptions{ + RecordDelimiter: "\n", + FieldDelimiter: ",", + }, + }, + }) + require.NoError(t, err) + defer func() { _ = r.Close() }() + + buf := bytes.NewBuffer(make([]byte, 0, bytesize.MB)) + + _, err = io.Copy(buf, r) + require.NoError(t, err) + + c := csv.NewReader(buf) + records, err := c.ReadAll() + require.NoError(t, err) + + payloads := make([]payload, 0, len(records)) + for _, r := range records { + p := payload{ + MessageID: r[0], + SourceID: r[1], + DestinationID: r[2], + TransformationID: r[3], + TrackingPlanID: r[4], + FailedStage: r[5], + EventType: r[6], + EventName: r[7], + } + + receivedAt, err := strconv.Atoi(r[8]) + require.NoError(t, err) + failedAt, err := strconv.Atoi(r[9]) + require.NoError(t, err) + + p.SetReceivedAt(time.UnixMicro(int64(receivedAt))) + p.SetFailedAt(time.UnixMicro(int64(failedAt))) + + payloads = append(payloads, p) + } + return payloads +} + +func failedMessagesUsingDuckDB(t testing.TB, ctx context.Context, mr *resource.MinioResource, query string) []payload { + t.Helper() + + db := duckDB(t) + + if mr != nil { + _, err := db.Exec(fmt.Sprintf(`INSTALL httpfs; LOAD httpfs;SET s3_region='%s';SET s3_endpoint='%s';SET s3_access_key_id='%s';SET s3_secret_access_key='%s';SET s3_use_ssl= false;SET s3_url_style='path';`, + mr.Region, + mr.Endpoint, + mr.AccessKeyID, + mr.AccessKeySecret, + )) + require.NoError(t, err) + } + + rows, err := db.QueryContext(ctx, query) + require.NoError(t, err) + defer func() { _ = rows.Close() }() + + var expectedPayloads []payload + for rows.Next() { + var p payload + var receivedAt time.Time + var 
failedAt time.Time + require.NoError(t, rows.Scan( + &p.MessageID, &p.SourceID, &p.DestinationID, + &p.TransformationID, &p.TrackingPlanID, &p.FailedStage, + &p.EventType, &p.EventName, &receivedAt, + &failedAt, + )) + p.SetReceivedAt(receivedAt) + p.SetFailedAt(failedAt) + expectedPayloads = append(expectedPayloads, p) + } + require.NoError(t, rows.Err()) + return expectedPayloads +} + +func duckDB(t testing.TB) *sql.DB { + t.Helper() + + db, err := sql.Open("duckdb", "") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec(`INSTALL parquet; LOAD parquet;`) + require.NoError(t, err) + return db +} + +func BenchmarkFileFormat(b *testing.B) { + now := time.Date(2021, 1, 1, 1, 1, 1, 0, time.UTC) + + entries := 1000000 + + b.Run("csv", func(b *testing.B) { + var records [][]string + + for i := 0; i < entries; i++ { + record := make([]string, 0, 10) + record = append(record, "messageId"+strconv.Itoa(i)) + record = append(record, "sourceId") + record = append(record, "destinationId"+strconv.Itoa(i%10)) + record = append(record, "transformationId"+strconv.Itoa(i%10)) + record = append(record, "trackingPlanId"+strconv.Itoa(i%10)) + record = append(record, "failedStage"+strconv.Itoa(i%10)) + record = append(record, "eventType"+strconv.Itoa(i%10)) + record = append(record, "eventName"+strconv.Itoa(i%10)) + record = append(record, now.Add(time.Duration(i)*time.Second).Format(time.RFC3339)) + record = append(record, now.Add(time.Duration(i)*time.Second).Format(time.RFC3339)) + + records = append(records, record) + } + + buf := bytes.NewBuffer(make([]byte, 0, 1024)) + c := csv.NewWriter(buf) + + err := c.WriteAll(records) + require.NoError(b, err) + + b.Log("csv size:", buf.Len()) // csv size: 150 MB + }) + b.Run("json", func(b *testing.B) { + var records []payload + + for i := 0; i < entries; i++ { + records = append(records, payload{ + MessageID: "messageId" + strconv.Itoa(i), + SourceID: "sourceId", + DestinationID: "destinationId" + strconv.Itoa(i%10), + TransformationID: "transformationId" + strconv.Itoa(i%10), + TrackingPlanID: "trackingPlanId" + strconv.Itoa(i%10), + FailedStage: "failedStage" + strconv.Itoa(i%10), + EventType: "eventType" + strconv.Itoa(i%10), + EventName: "eventName" + strconv.Itoa(i%10), + ReceivedAt: now.Add(time.Duration(i) * time.Second).UnixMicro(), + FailedAt: now.Add(time.Duration(i) * time.Second).UnixMicro(), + }) + } + + buf := bytes.NewBuffer(make([]byte, 0, 1024)) + e := json.NewEncoder(buf) + + for _, record := range records { + require.NoError(b, e.Encode(record)) + } + + b.Log("json size:", buf.Len()) // json size: 292 MB + }) + b.Run("parquet", func(b *testing.B) { + var records []payload + + for i := 0; i < entries; i++ { + records = append(records, payload{ + MessageID: "messageId" + strconv.Itoa(i), + SourceID: "sourceId", + DestinationID: "destinationId" + strconv.Itoa(i%10), + TransformationID: "transformationId" + strconv.Itoa(i%10), + TrackingPlanID: "trackingPlanId" + strconv.Itoa(i%10), + FailedStage: "failedStage" + strconv.Itoa(i%10), + EventType: "eventType" + strconv.Itoa(i%10), + EventName: "eventName" + strconv.Itoa(i%10), + ReceivedAt: now.Add(time.Duration(i) * time.Second).UnixMicro(), + FailedAt: now.Add(time.Duration(i) * time.Second).UnixMicro(), + }) + } + + w := worker{} + w.config.parquetRowGroupSize = misc.SingleValueLoader(512 * bytesize.MB) + w.config.parquetPageSize = misc.SingleValueLoader(8 * bytesize.KB) + w.config.parquetParallelWriters = misc.SingleValueLoader(int64(8)) + + buf := 
bytes.NewBuffer(make([]byte, 0, 1024)) + + require.NoError(b, w.encodeToParquet(buf, records)) + + b.Log("parquet size:", buf.Len()) // parquet size: 13.8 MB + }) +} diff --git a/enterprise/reporting/error_index_reporting.go b/enterprise/reporting/error_index_reporting.go deleted file mode 100644 index 4690d25f68a..00000000000 --- a/enterprise/reporting/error_index_reporting.go +++ /dev/null @@ -1,201 +0,0 @@ -package reporting - -import ( - "context" - "database/sql" - "encoding/json" - "fmt" - "sync" - "time" - - "github.com/google/uuid" - - "github.com/rudderlabs/rudder-go-kit/config" - "github.com/rudderlabs/rudder-go-kit/logger" - "github.com/rudderlabs/rudder-server/jobsdb" - . "github.com/rudderlabs/rudder-server/utils/tx" //nolint:staticcheck - "github.com/rudderlabs/rudder-server/utils/types" -) - -type payload struct { - MessageID string `json:"messageId"` - SourceID string `json:"sourceId"` - DestinationID string `json:"destinationId"` - TransformationID string `json:"transformationId"` - TrackingPlanID string `json:"trackingPlanId"` - FailedStage string `json:"failedStage"` - EventType string `json:"eventType"` - EventName string `json:"eventName"` - ReceivedAt time.Time `json:"receivedAt"` - FailedAt time.Time `json:"failedAt"` -} - -type ErrorIndexReporter struct { - ctx context.Context - log logger.Logger - conf *config.Config - configSubscriber *configSubscriber - now func() time.Time - dbsMu sync.RWMutex - dbs map[string]*handleWithSqlDB -} - -type handleWithSqlDB struct { - *jobsdb.Handle - sqlDB *sql.DB -} - -func NewErrorIndexReporter( - ctx context.Context, - log logger.Logger, - configSubscriber *configSubscriber, - conf *config.Config, -) *ErrorIndexReporter { - eir := &ErrorIndexReporter{ - ctx: ctx, - log: log, - conf: conf, - configSubscriber: configSubscriber, - now: time.Now, - dbs: map[string]*handleWithSqlDB{}, - } - return eir -} - -// Report reports the metrics to the errorIndex JobsDB -func (eir *ErrorIndexReporter) Report(metrics []*types.PUReportedMetric, tx *Tx) error { - failedAt := eir.now() - - var jobs []*jobsdb.JobT - for _, metric := range metrics { - if metric.StatusDetail == nil { - continue - } - - for _, failedMessage := range metric.StatusDetail.FailedMessages { - workspaceID := eir.configSubscriber.WorkspaceIDFromSource(metric.SourceID) - - payload := payload{ - MessageID: failedMessage.MessageID, - SourceID: metric.SourceID, - DestinationID: metric.DestinationID, - TransformationID: metric.TransformationID, - TrackingPlanID: metric.TrackingPlanID, - FailedStage: metric.PUDetails.PU, - EventName: metric.StatusDetail.EventName, - EventType: metric.StatusDetail.EventType, - ReceivedAt: failedMessage.ReceivedAt, - FailedAt: failedAt, - } - payloadJSON, err := json.Marshal(payload) - if err != nil { - return fmt.Errorf("unable to marshal payload: %v", err) - } - - params := struct { - WorkspaceID string `json:"workspaceId"` - SourceID string `json:"source_id"` - }{ - WorkspaceID: workspaceID, - SourceID: metric.SourceID, - } - paramsJSON, err := json.Marshal(params) - if err != nil { - return fmt.Errorf("unable to marshal params: %v", err) - } - - jobs = append(jobs, &jobsdb.JobT{ - UUID: uuid.New(), - Parameters: paramsJSON, - EventPayload: payloadJSON, - EventCount: 1, - WorkspaceId: workspaceID, - }) - } - } - - if len(jobs) == 0 { - return nil - } - db, err := eir.resolveJobsDB(tx) - if err != nil { - return fmt.Errorf("failed to resolve jobsdb: %w", err) - } - if err := db.WithStoreSafeTxFromTx(eir.ctx, tx, func(tx jobsdb.StoreSafeTx) error 
{ - return db.StoreInTx(eir.ctx, tx, jobs) - }); err != nil { - return fmt.Errorf("failed to store jobs: %w", err) - } - - return nil -} - -func (eir *ErrorIndexReporter) DatabaseSyncer(c types.SyncerConfig) types.ReportingSyncer { - eir.dbsMu.Lock() - defer eir.dbsMu.Unlock() - if _, ok := eir.dbs[c.ConnInfo]; !ok { - dbHandle, err := sql.Open("postgres", c.ConnInfo) - if err != nil { - panic(fmt.Errorf("failed to open error index db: %w", err)) - } - errIndexDB := jobsdb.NewForReadWrite( - "err_idx", - jobsdb.WithDBHandle(dbHandle), - jobsdb.WithDSLimit(eir.conf.GetReloadableIntVar(0, 1, "Reporting.errorIndexReporting.dsLimit")), - jobsdb.WithConfig(eir.conf), - jobsdb.WithSkipMaintenanceErr(eir.conf.GetBool("Reporting.errorIndexReporting.skipMaintenanceError", false)), - jobsdb.WithJobMaxAge( - func() time.Duration { - return eir.conf.GetDurationVar(24, time.Hour, "Reporting.errorIndexReporting.jobRetention") - }, - ), - ) - if err := errIndexDB.Start(); err != nil { - panic(fmt.Errorf("failed to start error index db: %w", err)) - } - eir.dbs[c.ConnInfo] = &handleWithSqlDB{ - Handle: errIndexDB, - sqlDB: dbHandle, - } - } - return func() { - } -} - -func (eir *ErrorIndexReporter) Stop() { - eir.dbsMu.RLock() - defer eir.dbsMu.RUnlock() - for _, db := range eir.dbs { - db.Handle.Stop() - } -} - -// resolveJobsDB returns the jobsdb that matches the current transaction (using system information functions) -// https://www.postgresql.org/docs/11/functions-info.html -func (eir *ErrorIndexReporter) resolveJobsDB(tx *Tx) (jobsdb.JobsDB, error) { - eir.dbsMu.RLock() - defer eir.dbsMu.RUnlock() - - if len(eir.dbs) == 1 { // optimisation, if there is only one jobsdb, return this. If it is the wrong one, it will fail anyway - for i := range eir.dbs { - return eir.dbs[i].Handle, nil - } - } - - dbIdentityQuery := `select inet_server_addr()::text || ':' || inet_server_port()::text || ':' || current_user || ':' || current_database() || ':' || current_schema || ':' || pg_postmaster_start_time()::text || ':' || version()` - var txDatabaseIdentity string - if err := tx.QueryRow(dbIdentityQuery).Scan(&txDatabaseIdentity); err != nil { - return nil, fmt.Errorf("failed to get current tx's db identity: %w", err) - } - - for key := range eir.dbs { - var databaseIdentity string - if err := eir.dbs[key].sqlDB.QueryRow(dbIdentityQuery).Scan(&databaseIdentity); err != nil { - return nil, fmt.Errorf("failed to get db identity for %q: %w", key, err) - } - if databaseIdentity == txDatabaseIdentity { - return eir.dbs[key].Handle, nil - } - } - return nil, fmt.Errorf("no jobsdb found matching the current transaction") -} diff --git a/enterprise/reporting/mediator.go b/enterprise/reporting/mediator.go index 46af6213b62..706db2a81b7 100644 --- a/enterprise/reporting/mediator.go +++ b/enterprise/reporting/mediator.go @@ -3,6 +3,10 @@ package reporting import ( "context" + "github.com/rudderlabs/rudder-go-kit/stats" + + erridx "github.com/rudderlabs/rudder-server/enterprise/reporting/error_index" + "golang.org/x/sync/errgroup" "github.com/rudderlabs/rudder-go-kit/config" @@ -55,7 +59,7 @@ func NewReportingMediator(ctx context.Context, log logger.Logger, enterpriseToke // error index reporting implementation if config.GetBool("Reporting.errorIndexReporting.enabled", false) { - errorIndexReporter := NewErrorIndexReporter(rm.ctx, rm.log, configSubscriber, config.Default) + errorIndexReporter := erridx.NewErrorIndexReporter(rm.ctx, rm.log, configSubscriber, config.Default, stats.Default) rm.reporters = append(rm.reporters, 
errorIndexReporter) } diff --git a/enterprise/reporting/reporting.go b/enterprise/reporting/reporting.go index 8b72dcea0af..8e73541bd9b 100644 --- a/enterprise/reporting/reporting.go +++ b/enterprise/reporting/reporting.go @@ -8,17 +8,18 @@ import ( "fmt" "io" "net/http" + "slices" "strconv" "strings" "sync" "time" + "go.uber.org/atomic" + "golang.org/x/sync/errgroup" + "github.com/cenkalti/backoff/v4" "github.com/lib/pq" "github.com/samber/lo" - "go.uber.org/atomic" - "golang.org/x/exp/slices" - "golang.org/x/sync/errgroup" "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/logger" diff --git a/gateway/webhook/webhook.go b/gateway/webhook/webhook.go index 6b72033d0d7..3f352b01d31 100644 --- a/gateway/webhook/webhook.go +++ b/gateway/webhook/webhook.go @@ -9,6 +9,7 @@ import ( "mime/multipart" "net/http" "net/url" + "slices" "strconv" "strings" "sync" @@ -17,7 +18,6 @@ import ( "github.com/hashicorp/go-retryablehttp" "github.com/samber/lo" - "golang.org/x/exp/slices" "github.com/rudderlabs/rudder-go-kit/logger" "github.com/rudderlabs/rudder-go-kit/stats" diff --git a/go.mod b/go.mod index 36651bad036..be1182facb9 100644 --- a/go.mod +++ b/go.mod @@ -54,6 +54,7 @@ require ( github.com/golang/mock v1.6.0 github.com/gomodule/redigo v1.8.9 github.com/google/uuid v1.4.0 + github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 github.com/hashicorp/go-retryablehttp v0.7.4 github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/hashicorp/yamux v0.1.1 @@ -65,6 +66,7 @@ require ( github.com/lib/pq v1.10.9 github.com/linkedin/goavro/v2 v2.12.0 github.com/manifoldco/promptui v0.9.0 + github.com/marcboeker/go-duckdb v1.5.1 github.com/minio/minio-go/v7 v7.0.63 github.com/mitchellh/mapstructure v1.5.0 github.com/mkmik/multierror v0.3.0 @@ -79,7 +81,7 @@ require ( github.com/rs/cors v1.10.1 github.com/rudderlabs/analytics-go v3.3.3+incompatible github.com/rudderlabs/compose-test v0.1.3 - github.com/rudderlabs/rudder-go-kit v0.16.2 + github.com/rudderlabs/rudder-go-kit v0.16.3 github.com/rudderlabs/sql-tunnels v0.1.5 github.com/samber/lo v1.38.1 github.com/segmentio/kafka-go v0.4.42 @@ -197,7 +199,6 @@ require ( github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.1 // indirect github.com/googleapis/gax-go/v2 v2.12.0 // indirect - github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 // indirect github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect diff --git a/go.sum b/go.sum index d0e1e80ab1d..f1fcffcc648 100644 --- a/go.sum +++ b/go.sum @@ -790,6 +790,8 @@ github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0V github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/manifoldco/promptui v0.9.0 h1:3V4HzJk1TtXW1MTZMP7mdlwbBpIinw3HztaIlYthEiA= github.com/manifoldco/promptui v0.9.0/go.mod h1:ka04sppxSGFAtxX0qhlYQjISsg9mR4GWtQEhdbn6Pgg= +github.com/marcboeker/go-duckdb v1.5.1 h1:Mh6h0ke9EyM2XA9dWiNOawM+oUFXYOY5o2csJ32uxBw= +github.com/marcboeker/go-duckdb v1.5.1/go.mod h1:wm91jO2GNKa6iO9NTcjXIRsW+/ykPoJbQcHSXhdAl28= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= @@ -947,8 +949,8 @@ 
github.com/rudderlabs/compose-test v0.1.3 h1:uyep6jDCIF737sfv4zIaMsKRQKX95IDz5Xb github.com/rudderlabs/compose-test v0.1.3/go.mod h1:tuvS1eQdSfwOYv1qwyVAcpdJxPLQXJgy5xGDd/9XmMg= github.com/rudderlabs/parquet-go v0.0.2 h1:ZXRdZdimB0PdJtmxeSSxfI0fDQ3kZjwzBxRi6Ut1J8k= github.com/rudderlabs/parquet-go v0.0.2/go.mod h1:g6guum7o8uhj/uNhunnt7bw5Vabu/goI5i21/3fnxWQ= -github.com/rudderlabs/rudder-go-kit v0.16.2 h1:1zR0ivPT3Rp9bHmfq5k8VVfOy3bJrag2gHbjnqbUmtM= -github.com/rudderlabs/rudder-go-kit v0.16.2/go.mod h1:vRRTcYmAtYg87R4liGy24wO3452WlGHkFwtEopgme3k= +github.com/rudderlabs/rudder-go-kit v0.16.3 h1:IZIg7RjwbQN0GAHpiZgNLW388AwBmgVnh3bYPXP7SKQ= +github.com/rudderlabs/rudder-go-kit v0.16.3/go.mod h1:vRRTcYmAtYg87R4liGy24wO3452WlGHkFwtEopgme3k= github.com/rudderlabs/sql-tunnels v0.1.5 h1:L/e9GQtqJlTVMauAE+ym/XUqhg+Va6RZQiOvBgbhspY= github.com/rudderlabs/sql-tunnels v0.1.5/go.mod h1:ZwQkCLb/5hHm5U90juAj9idkkFGv2R2dzDHJoPbKIto= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= diff --git a/integration_test/reporting_error_index/reporting_error_index_test.go b/integration_test/reporting_error_index/reporting_error_index_test.go new file mode 100644 index 00000000000..28111cb66af --- /dev/null +++ b/integration_test/reporting_error_index/reporting_error_index_test.go @@ -0,0 +1,888 @@ +package reportingfailedmessages_test + +import ( + "bytes" + "context" + "database/sql" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "path" + "strconv" + "strings" + "testing" + "time" + + "github.com/samber/lo" + + "github.com/rudderlabs/rudder-server/jobsdb" + + "github.com/rudderlabs/rudder-server/processor/transformer" + + "github.com/ory/dockertest/v3" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + + "github.com/rudderlabs/rudder-go-kit/config" + kithttputil "github.com/rudderlabs/rudder-go-kit/httputil" + kithelper "github.com/rudderlabs/rudder-go-kit/testhelper" + "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource" + "github.com/rudderlabs/rudder-go-kit/testhelper/rand" + "github.com/rudderlabs/rudder-server/runner" + "github.com/rudderlabs/rudder-server/testhelper/backendconfigtest" + "github.com/rudderlabs/rudder-server/testhelper/health" + "github.com/rudderlabs/rudder-server/testhelper/transformertest" + + _ "github.com/marcboeker/go-duckdb" +) + +func TestReportingErrorIndex(t *testing.T) { + t.Run("Events failed during tracking plan validation stage", func(t *testing.T) { + config.Reset() + defer config.Reset() + + bcServer := backendconfigtest.NewBuilder(). + WithWorkspaceConfig( + backendconfigtest.NewConfigBuilder(). + WithSource( + backendconfigtest.NewSourceBuilder(). + WithID("source-1"). + WithWriteKey("writekey-1"). + WithTrackingPlan("trackingplan-1", 1). + WithConnection( + backendconfigtest.NewDestinationBuilder("WEBHOOK"). + WithID("destination-1"). + Build()). + Build()). + Build()). + Build() + defer bcServer.Close() + + trServer := transformertest.NewBuilder(). + WithTrackingPlanHandler( + transformertest.ViolationErrorTransformerHandler( + http.StatusBadRequest, + "tracking plan validation failed", + []transformer.ValidationError{{Type: "Datatype-Mismatch", Message: "must be number"}}, + ), + ). 
+ Build() + defer trServer.Close() + + pool, err := dockertest.NewPool("") + require.NoError(t, err) + + postgresContainer, err := resource.SetupPostgres(pool, t) + require.NoError(t, err) + minioResource, err := resource.SetupMinio(pool, t) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + gwPort, err := kithelper.GetFreePort() + require.NoError(t, err) + + wg, ctx := errgroup.WithContext(ctx) + wg.Go(func() error { + err := runRudderServer(ctx, gwPort, postgresContainer, minioResource, bcServer.URL, trServer.URL, t.TempDir()) + if err != nil { + t.Logf("rudder-server exited with error: %v", err) + } + return err + }) + + url := fmt.Sprintf("http://localhost:%d", gwPort) + health.WaitUntilReady(ctx, t, url+"/health", 60*time.Second, 10*time.Millisecond, t.Name()) + + eventsCount := 12 + + err = sendEvents(eventsCount, "identify", "writekey-1", url) + require.NoError(t, err) + + requireJobsCount(t, postgresContainer.DB, "gw", jobsdb.Succeeded.State, eventsCount) + requireJobsCount(t, postgresContainer.DB, "err_idx", jobsdb.Succeeded.State, eventsCount) + requireMessagesCount(t, ctx, minioResource, eventsCount, []lo.Tuple2[string, string]{ + {A: "source_id", B: "source-1"}, + {A: "tracking_plan_id", B: "trackingplan-1"}, + {A: "failed_stage", B: "tracking_plan_validator"}, + {A: "event_type", B: "identify"}, + }...) + + cancel() + require.NoError(t, wg.Wait()) + }) + + t.Run("Events failed during user transformation stage", func(t *testing.T) { + config.Reset() + defer config.Reset() + + bcServer := backendconfigtest.NewBuilder(). + WithWorkspaceConfig( + backendconfigtest.NewConfigBuilder(). + WithSource( + backendconfigtest.NewSourceBuilder(). + WithID("source-1"). + WithWriteKey("writekey-2"). + WithConnection( + backendconfigtest.NewDestinationBuilder("WEBHOOK"). + WithID("destination-1"). + WithUserTransformation("transformation-1", "version-1"). + Build()). + Build()). + Build()). + Build() + defer bcServer.Close() + + trServer := transformertest.NewBuilder(). + WithUserTransformHandler( + transformertest.ErrorTransformerHandler( + http.StatusBadRequest, "TypeError: Cannot read property 'uuid' of undefined", + ), + ). 
+ Build() + defer trServer.Close() + + pool, err := dockertest.NewPool("") + require.NoError(t, err) + + postgresContainer, err := resource.SetupPostgres(pool, t) + require.NoError(t, err) + minioResource, err := resource.SetupMinio(pool, t) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + gwPort, err := kithelper.GetFreePort() + require.NoError(t, err) + + wg, ctx := errgroup.WithContext(ctx) + wg.Go(func() error { + err := runRudderServer(ctx, gwPort, postgresContainer, minioResource, bcServer.URL, trServer.URL, t.TempDir()) + if err != nil { + t.Logf("rudder-server exited with error: %v", err) + } + return err + }) + + url := fmt.Sprintf("http://localhost:%d", gwPort) + health.WaitUntilReady(ctx, t, url+"/health", 60*time.Second, 10*time.Millisecond, t.Name()) + + eventsCount := 12 + + err = sendEvents(eventsCount, "identify", "writekey-2", url) + require.NoError(t, err) + + requireJobsCount(t, postgresContainer.DB, "gw", jobsdb.Succeeded.State, eventsCount) + requireJobsCount(t, postgresContainer.DB, "err_idx", jobsdb.Succeeded.State, eventsCount) + requireMessagesCount(t, ctx, minioResource, eventsCount, []lo.Tuple2[string, string]{ + {A: "source_id", B: "source-1"}, + {A: "destination_id", B: "destination-1"}, + {A: "transformation_id", B: "transformation-1"}, + {A: "failed_stage", B: "user_transformer"}, + {A: "event_type", B: "identify"}, + }...) + + cancel() + require.NoError(t, wg.Wait()) + }) + + t.Run("Events failed during event filtering stage", func(t *testing.T) { + t.Run("empty message type", func(t *testing.T) { + config.Reset() + defer config.Reset() + + bcServer := backendconfigtest.NewBuilder(). + WithWorkspaceConfig( + backendconfigtest.NewConfigBuilder(). + WithSource( + backendconfigtest.NewSourceBuilder(). + WithID("source-1"). + WithWriteKey("writekey-1"). + WithConnection( + backendconfigtest.NewDestinationBuilder("WEBHOOK"). + WithID("destination-1"). + WithDefinitionConfigOption("supportedMessageTypes", []string{"track"}). + Build()). + Build()). + Build()). + Build() + defer bcServer.Close() + + trServer := transformertest.NewBuilder(). + WithUserTransformHandler(transformertest.EmptyTransformerHandler). 
+ Build() + defer trServer.Close() + + pool, err := dockertest.NewPool("") + require.NoError(t, err) + + postgresContainer, err := resource.SetupPostgres(pool, t) + require.NoError(t, err) + minioResource, err := resource.SetupMinio(pool, t) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + gwPort, err := kithelper.GetFreePort() + require.NoError(t, err) + + wg, ctx := errgroup.WithContext(ctx) + wg.Go(func() error { + err := runRudderServer(ctx, gwPort, postgresContainer, minioResource, bcServer.URL, trServer.URL, t.TempDir()) + if err != nil { + t.Logf("rudder-server exited with error: %v", err) + } + return err + }) + + url := fmt.Sprintf("http://localhost:%d", gwPort) + health.WaitUntilReady(ctx, t, url+"/health", 60*time.Second, 10*time.Millisecond, t.Name()) + + eventsCount := 12 + + err = sendEvents(eventsCount, "", "writekey-1", url) + require.NoError(t, err) + + requireJobsCount(t, postgresContainer.DB, "gw", jobsdb.Succeeded.State, eventsCount) + requireJobsCount(t, postgresContainer.DB, "err_idx", jobsdb.Succeeded.State, eventsCount) + requireMessagesCount(t, ctx, minioResource, eventsCount, []lo.Tuple2[string, string]{ + {A: "source_id", B: "source-1"}, + {A: "destination_id", B: "destination-1"}, + {A: "failed_stage", B: "event_filter"}, + }...) + + cancel() + require.NoError(t, wg.Wait()) + }) + + t.Run("empty message event", func(t *testing.T) { + config.Reset() + defer config.Reset() + + bcServer := backendconfigtest.NewBuilder(). + WithWorkspaceConfig( + backendconfigtest.NewConfigBuilder(). + WithSource( + backendconfigtest.NewSourceBuilder(). + WithID("source-1"). + WithWriteKey("writekey-1"). + WithConnection( + backendconfigtest.NewDestinationBuilder("WEBHOOK"). + WithID("destination-1"). + WithConfigOption("listOfConversions", []map[string]string{ + { + "conversions": "Test event", + }, + }). + Build()). + Build()). + Build()). + Build() + defer bcServer.Close() + + trServer := transformertest.NewBuilder(). + WithUserTransformHandler(transformertest.EmptyTransformerHandler). + Build() + defer trServer.Close() + + pool, err := dockertest.NewPool("") + require.NoError(t, err) + + postgresContainer, err := resource.SetupPostgres(pool, t) + require.NoError(t, err) + minioResource, err := resource.SetupMinio(pool, t) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + gwPort, err := kithelper.GetFreePort() + require.NoError(t, err) + + wg, ctx := errgroup.WithContext(ctx) + wg.Go(func() error { + err := runRudderServer(ctx, gwPort, postgresContainer, minioResource, bcServer.URL, trServer.URL, t.TempDir()) + if err != nil { + t.Logf("rudder-server exited with error: %v", err) + } + return err + }) + + url := fmt.Sprintf("http://localhost:%d", gwPort) + health.WaitUntilReady(ctx, t, url+"/health", 60*time.Second, 10*time.Millisecond, t.Name()) + + eventsCount := 12 + + err = sendEvents(eventsCount, "", "writekey-1", url) + require.NoError(t, err) + + requireJobsCount(t, postgresContainer.DB, "gw", jobsdb.Succeeded.State, eventsCount) + requireJobsCount(t, postgresContainer.DB, "err_idx", jobsdb.Succeeded.State, eventsCount) + requireMessagesCount(t, ctx, minioResource, eventsCount, []lo.Tuple2[string, string]{ + {A: "source_id", B: "source-1"}, + {A: "destination_id", B: "destination-1"}, + {A: "failed_stage", B: "event_filter"}, + }...) 
+ + cancel() + require.NoError(t, wg.Wait()) + }) + }) + + t.Run("Events failed during destination transformation stage", func(t *testing.T) { + config.Reset() + defer config.Reset() + + bcServer := backendconfigtest.NewBuilder(). + WithWorkspaceConfig( + backendconfigtest.NewConfigBuilder(). + WithSource( + backendconfigtest.NewSourceBuilder(). + WithID("source-1"). + WithWriteKey("writekey-1"). + WithConnection( + backendconfigtest.NewDestinationBuilder("WEBHOOK"). + WithID("destination-1"). + Build()). + Build()). + Build()). + Build() + defer bcServer.Close() + + trServer := transformertest.NewBuilder(). + WithDestTransformHandler( + "WEBHOOK", + transformertest.ErrorTransformerHandler(http.StatusBadRequest, "dest transformation failed"), + ). + Build() + defer trServer.Close() + + pool, err := dockertest.NewPool("") + require.NoError(t, err) + + postgresContainer, err := resource.SetupPostgres(pool, t) + require.NoError(t, err) + minioResource, err := resource.SetupMinio(pool, t) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + gwPort, err := kithelper.GetFreePort() + require.NoError(t, err) + + wg, ctx := errgroup.WithContext(ctx) + wg.Go(func() error { + err := runRudderServer(ctx, gwPort, postgresContainer, minioResource, bcServer.URL, trServer.URL, t.TempDir()) + if err != nil { + t.Logf("rudder-server exited with error: %v", err) + } + return err + }) + + url := fmt.Sprintf("http://localhost:%d", gwPort) + health.WaitUntilReady(ctx, t, url+"/health", 60*time.Second, 10*time.Millisecond, t.Name()) + + eventsCount := 12 + + err = sendEvents(eventsCount, "identify", "writekey-1", url) + require.NoError(t, err) + + requireJobsCount(t, postgresContainer.DB, "gw", jobsdb.Succeeded.State, eventsCount) + requireJobsCount(t, postgresContainer.DB, "err_idx", jobsdb.Succeeded.State, eventsCount) + requireMessagesCount(t, ctx, minioResource, eventsCount, []lo.Tuple2[string, string]{ + {A: "source_id", B: "source-1"}, + {A: "destination_id", B: "destination-1"}, + {A: "failed_stage", B: "dest_transformer"}, + {A: "event_type", B: "identify"}, + }...) + + cancel() + require.NoError(t, wg.Wait()) + }) + + t.Run("Events failed during router delivery stage", func(t *testing.T) { + t.Run("rejected by destination itself", func(t *testing.T) { + config.Reset() + defer config.Reset() + + bcServer := backendconfigtest.NewBuilder(). + WithWorkspaceConfig( + backendconfigtest.NewConfigBuilder(). + WithSource( + backendconfigtest.NewSourceBuilder(). + WithID("source-1"). + WithWriteKey("writekey-1"). + WithConnection( + backendconfigtest.NewDestinationBuilder("WEBHOOK"). + WithID("destination-1"). + Build()). + Build()). + Build()). + Build() + defer bcServer.Close() + + webhook := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "aborted", http.StatusBadRequest) + })) + defer webhook.Close() + + trServer := transformertest.NewBuilder(). + WithDestTransformHandler( + "WEBHOOK", + transformertest.RESTJSONDestTransformerHandler(http.MethodPost, webhook.URL), + ). 
+ Build() + defer trServer.Close() + + pool, err := dockertest.NewPool("") + require.NoError(t, err) + + postgresContainer, err := resource.SetupPostgres(pool, t) + require.NoError(t, err) + minioResource, err := resource.SetupMinio(pool, t) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + gwPort, err := kithelper.GetFreePort() + require.NoError(t, err) + + wg, ctx := errgroup.WithContext(ctx) + wg.Go(func() error { + err := runRudderServer(ctx, gwPort, postgresContainer, minioResource, bcServer.URL, trServer.URL, t.TempDir()) + if err != nil { + t.Logf("rudder-server exited with error: %v", err) + } + return err + }) + + url := fmt.Sprintf("http://localhost:%d", gwPort) + health.WaitUntilReady(ctx, t, url+"/health", 60*time.Second, 10*time.Millisecond, t.Name()) + + eventsCount := 12 + + err = sendEvents(eventsCount, "identify", "writekey-1", url) + require.NoError(t, err) + + requireJobsCount(t, postgresContainer.DB, "gw", jobsdb.Succeeded.State, eventsCount) + requireJobsCount(t, postgresContainer.DB, "rt", jobsdb.Aborted.State, eventsCount) + requireJobsCount(t, postgresContainer.DB, "err_idx", jobsdb.Succeeded.State, eventsCount) + requireMessagesCount(t, ctx, minioResource, eventsCount, []lo.Tuple2[string, string]{ + {A: "source_id", B: "source-1"}, + {A: "destination_id", B: "destination-1"}, + {A: "failed_stage", B: "router"}, + {A: "event_type", B: "identify"}, + }...) + + cancel() + require.NoError(t, wg.Wait()) + }) + }) + + t.Run("Events failed during batch router delivery stage", func(t *testing.T) { + t.Run("destination id included in BatchRouter.toAbortDestinationIDs", func(t *testing.T) { + config.Reset() + defer config.Reset() + + bcServer := backendconfigtest.NewBuilder(). + WithWorkspaceConfig( + backendconfigtest.NewConfigBuilder(). + WithSource( + backendconfigtest.NewSourceBuilder(). + WithID("source-1"). + WithWriteKey("writekey-1"). + WithConnection( + backendconfigtest.NewDestinationBuilder("S3"). + WithID("destination-1"). + Build()). + Build()). + Build()). + Build() + defer bcServer.Close() + + trServer := transformertest.NewBuilder(). + WithDestTransformHandler( + "S3", + transformertest.MirroringTransformerHandler, + ). 
+ Build() + defer trServer.Close() + + pool, err := dockertest.NewPool("") + require.NoError(t, err) + + postgresContainer, err := resource.SetupPostgres(pool, t) + require.NoError(t, err) + minioResource, err := resource.SetupMinio(pool, t) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + gwPort, err := kithelper.GetFreePort() + require.NoError(t, err) + + wg, ctx := errgroup.WithContext(ctx) + wg.Go(func() error { + config.Set("BatchRouter.toAbortDestinationIDs", "destination-1") + + err := runRudderServer(ctx, gwPort, postgresContainer, minioResource, bcServer.URL, trServer.URL, t.TempDir()) + if err != nil { + t.Logf("rudder-server exited with error: %v", err) + } + return err + }) + + url := fmt.Sprintf("http://localhost:%d", gwPort) + health.WaitUntilReady(ctx, t, url+"/health", 60*time.Second, 10*time.Millisecond, t.Name()) + + eventsCount := 12 + + err = sendEvents(eventsCount, "identify", "writekey-1", url) + require.NoError(t, err) + + requireJobsCount(t, postgresContainer.DB, "gw", jobsdb.Succeeded.State, eventsCount) + requireJobsCount(t, postgresContainer.DB, "batch_rt", jobsdb.Aborted.State, eventsCount) + requireJobsCount(t, postgresContainer.DB, "err_idx", jobsdb.Succeeded.State, eventsCount) + requireMessagesCount(t, ctx, minioResource, eventsCount, []lo.Tuple2[string, string]{ + {A: "source_id", B: "source-1"}, + {A: "destination_id", B: "destination-1"}, + {A: "failed_stage", B: "batch_router"}, + {A: "event_type", B: "identify"}, + }...) + + cancel() + require.NoError(t, wg.Wait()) + }) + + t.Run("invalid object storage configuration", func(t *testing.T) { + config.Reset() + defer config.Reset() + + bcServer := backendconfigtest.NewBuilder(). + WithWorkspaceConfig( + backendconfigtest.NewConfigBuilder(). + WithSource( + backendconfigtest.NewSourceBuilder(). + WithID("source-1"). + WithWriteKey("writekey-1"). + WithConnection( + backendconfigtest.NewDestinationBuilder("S3"). + WithID("destination-1"). + Build()). + Build()). + Build()). + Build() + defer bcServer.Close() + + trServer := transformertest.NewBuilder(). + WithDestTransformHandler( + "S3", + transformertest.MirroringTransformerHandler, + ). 
+ Build() + defer trServer.Close() + + pool, err := dockertest.NewPool("") + require.NoError(t, err) + + postgresContainer, err := resource.SetupPostgres(pool, t) + require.NoError(t, err) + minioResource, err := resource.SetupMinio(pool, t) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + gwPort, err := kithelper.GetFreePort() + require.NoError(t, err) + + wg, ctx := errgroup.WithContext(ctx) + wg.Go(func() error { + config.Set("BatchRouter.S3.retryTimeWindow", "0s") + config.Set("BatchRouter.S3.maxFailedCountForJob", 0) + + err := runRudderServer(ctx, gwPort, postgresContainer, minioResource, bcServer.URL, trServer.URL, t.TempDir()) + if err != nil { + t.Logf("rudder-server exited with error: %v", err) + } + return err + }) + + url := fmt.Sprintf("http://localhost:%d", gwPort) + health.WaitUntilReady(ctx, t, url+"/health", 60*time.Second, 10*time.Millisecond, t.Name()) + + eventsCount := 12 + + err = sendEvents(eventsCount, "identify", "writekey-1", url) + require.NoError(t, err) + + requireJobsCount(t, postgresContainer.DB, "gw", jobsdb.Succeeded.State, eventsCount) + requireJobsCount(t, postgresContainer.DB, "batch_rt", jobsdb.Aborted.State, eventsCount) + requireJobsCount(t, postgresContainer.DB, "err_idx", jobsdb.Succeeded.State, eventsCount) + requireMessagesCount(t, ctx, minioResource, eventsCount, []lo.Tuple2[string, string]{ + {A: "source_id", B: "source-1"}, + {A: "destination_id", B: "destination-1"}, + {A: "failed_stage", B: "batch_router"}, + {A: "event_type", B: "identify"}, + }...) + + cancel() + require.NoError(t, wg.Wait()) + }) + + t.Run("unable to ping to warehouse", func(t *testing.T) { + config.Reset() + defer config.Reset() + + pool, err := dockertest.NewPool("") + require.NoError(t, err) + + postgresContainer, err := resource.SetupPostgres(pool, t) + require.NoError(t, err) + minioResource, err := resource.SetupMinio(pool, t) + require.NoError(t, err) + + bcServer := backendconfigtest.NewBuilder(). + WithWorkspaceConfig( + backendconfigtest.NewConfigBuilder(). + WithSource( + backendconfigtest.NewSourceBuilder(). + WithID("source-1"). + WithWriteKey("writekey-1"). + WithConnection( + backendconfigtest.NewDestinationBuilder("POSTGRES"). + WithID("destination-1"). + WithConfigOption("bucketProvider", "MINIO"). + WithConfigOption("bucketName", minioResource.BucketName). + WithConfigOption("accessKeyID", minioResource.AccessKeyID). + WithConfigOption("secretAccessKey", minioResource.AccessKeySecret). + WithConfigOption("endPoint", minioResource.Endpoint). + Build()). + Build()). + Build()). + Build() + defer bcServer.Close() + + trServer := transformertest.NewBuilder(). + WithDestTransformHandler( + "POSTGRES", + transformertest.WarehouseTransformerHandler( + "tracks", http.StatusOK, "", + ), + ). 
+ Build() + defer trServer.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + gwPort, err := kithelper.GetFreePort() + require.NoError(t, err) + + wg, ctx := errgroup.WithContext(ctx) + wg.Go(func() error { + config.Set("BatchRouter.warehouseServiceMaxRetryTime", "0s") + + err := runRudderServer(ctx, gwPort, postgresContainer, minioResource, bcServer.URL, trServer.URL, t.TempDir()) + if err != nil { + t.Logf("rudder-server exited with error: %v", err) + } + return err + }) + + url := fmt.Sprintf("http://localhost:%d", gwPort) + health.WaitUntilReady(ctx, t, url+"/health", 60*time.Second, 10*time.Millisecond, t.Name()) + + eventsCount := 12 + + err = sendEvents(eventsCount, "identify", "writekey-1", url) + require.NoError(t, err) + + requireJobsCount(t, postgresContainer.DB, "gw", jobsdb.Succeeded.State, eventsCount) + requireJobsCount(t, postgresContainer.DB, "batch_rt", jobsdb.Aborted.State, eventsCount) + requireJobsCount(t, postgresContainer.DB, "err_idx", jobsdb.Succeeded.State, eventsCount) + requireMessagesCount(t, ctx, minioResource, eventsCount, []lo.Tuple2[string, string]{ + {A: "source_id", B: "source-1"}, + {A: "destination_id", B: "destination-1"}, + {A: "failed_stage", B: "batch_router"}, + {A: "event_type", B: "identify"}, + }...) + + cancel() + require.NoError(t, wg.Wait()) + }) + }) +} + +func runRudderServer( + ctx context.Context, + port int, + postgresContainer *resource.PostgresResource, + minioResource *resource.MinioResource, + cbURL, transformerURL, tmpDir string, +) (err error) { + config.Set("CONFIG_BACKEND_URL", cbURL) + config.Set("WORKSPACE_TOKEN", "token") + config.Set("DB.port", postgresContainer.Port) + config.Set("DB.user", postgresContainer.User) + config.Set("DB.name", postgresContainer.Database) + config.Set("DB.password", postgresContainer.Password) + config.Set("DEST_TRANSFORM_URL", transformerURL) + + config.Set("Warehouse.mode", "off") + config.Set("DestinationDebugger.disableEventDeliveryStatusUploads", true) + config.Set("SourceDebugger.disableEventUploads", true) + config.Set("TransformationDebugger.disableTransformationStatusUploads", true) + config.Set("JobsDB.backup.enabled", false) + config.Set("JobsDB.migrateDSLoopSleepDuration", "60m") + config.Set("archival.Enabled", false) + config.Set("Reporting.syncer.enabled", false) + config.Set("Reporting.errorIndexReporting.enabled", true) + config.Set("Reporting.errorIndexReporting.syncer.enabled", true) + config.Set("Reporting.errorIndexReporting.SleepDuration", "1s") + config.Set("Reporting.errorIndexReporting.minWorkerSleep", "1s") + config.Set("Reporting.errorIndexReporting.uploadFrequency", "1s") + config.Set("BatchRouter.mainLoopFreq", "1s") + config.Set("BatchRouter.uploadFreq", "1s") + config.Set("Gateway.webPort", strconv.Itoa(port)) + config.Set("RUDDER_TMPDIR", os.TempDir()) + config.Set("recovery.storagePath", path.Join(tmpDir, "/recovery_data.json")) + config.Set("recovery.enabled", false) + config.Set("Profiler.Enabled", false) + config.Set("Gateway.enableSuppressUserFeature", false) + + config.Set("ErrorIndex.storage.Bucket", minioResource.BucketName) + config.Set("ErrorIndex.storage.Endpoint", minioResource.Endpoint) + config.Set("ErrorIndex.storage.AccessKey", minioResource.AccessKeyID) + config.Set("ErrorIndex.storage.SecretAccessKey", minioResource.AccessKeySecret) + config.Set("ErrorIndex.storage.S3ForcePathStyle", true) + config.Set("ErrorIndex.storage.DisableSSL", true) + + defer func() { + if r := recover(); r != nil { + err = 
fmt.Errorf("panicked: %v", r) + } + }() + r := runner.New(runner.ReleaseInfo{EnterpriseToken: "DUMMY"}) + c := r.Run(ctx, []string{"rudder-error-reporting"}) + if c != 0 { + err = fmt.Errorf("rudder-server exited with a non-0 exit code: %d", c) + } + return +} + +// nolint: unparam +func requireJobsCount( + t *testing.T, + db *sql.DB, + queue, state string, + expectedCount int, +) { + t.Helper() + + require.Eventually(t, func() bool { + var jobsCount int + require.NoError(t, db.QueryRow(fmt.Sprintf(`SELECT count(*) FROM unionjobsdbmetadata('%s',1) WHERE job_state = '%s';`, queue, state)).Scan(&jobsCount)) + t.Logf("%s %sJobCount: %d", queue, state, jobsCount) + return jobsCount == expectedCount + }, + 20*time.Second, + 1*time.Second, + fmt.Sprintf("%d %s events should be in %s state", expectedCount, queue, state), + ) +} + +// nolint: unparam +func requireMessagesCount( + t *testing.T, + ctx context.Context, + mr *resource.MinioResource, + expectedCount int, + filters ...lo.Tuple2[string, string], +) { + t.Helper() + + db, err := sql.Open("duckdb", "") + require.NoError(t, err) + + _, err = db.Exec(fmt.Sprintf(`INSTALL parquet; LOAD parquet; INSTALL httpfs; LOAD httpfs;SET s3_region='%s';SET s3_endpoint='%s';SET s3_access_key_id='%s';SET s3_secret_access_key='%s';SET s3_use_ssl= false;SET s3_url_style='path';`, + mr.Region, + mr.Endpoint, + mr.AccessKeyID, + mr.AccessKeySecret, + )) + require.NoError(t, err) + + query := fmt.Sprintf("SELECT count(*) FROM read_parquet('%s') WHERE 1 = 1", fmt.Sprintf("s3://%s/**/**/**/*.parquet", mr.BucketName)) + query += strings.Join(lo.Map(filters, func(t lo.Tuple2[string, string], _ int) string { + return fmt.Sprintf(" AND %s = '%s'", t.A, t.B) + }), "") + + require.Eventually(t, func() bool { + var messagesCount int + require.NoError(t, db.QueryRowContext(ctx, query).Scan(&messagesCount)) + t.Logf("messagesCount: %d", messagesCount) + return messagesCount == expectedCount + }, + 10*time.Second, + 1*time.Second, + fmt.Sprintf("%d messages should be in the bucket", expectedCount), + ) +} + +// nolint: unparam +func sendEvents( + num int, + eventType, writeKey, + url string, +) error { + for i := 0; i < num; i++ { + payload := []byte(fmt.Sprintf(` + { + "batch": [ + { + "userId": %[1]q, + "type": %[2]q, + "context": { + "traits": { + "trait1": "new-val" + }, + "ip": "14.5.67.21", + "library": { + "name": "http" + } + }, + "timestamp": "2020-02-02T00:23:09.544Z" + } + ] + }`, + rand.String(10), + eventType, + )) + req, err := http.NewRequest(http.MethodPost, url+"/v1/batch", bytes.NewReader(payload)) + if err != nil { + return err + } + req.SetBasicAuth(writeKey, "password") + + resp, err := (&http.Client{}).Do(req) + if err != nil { + return err + } + + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(resp.Body) + return fmt.Errorf("failed to send event to rudder server, status code: %d: %s", resp.StatusCode, string(b)) + } + func() { kithttputil.CloseResponse(resp) }() + } + return nil +} diff --git a/jobsdb/internal/cache/cache.go b/jobsdb/internal/cache/cache.go index 6f1e8b772a7..f87e8b3b133 100644 --- a/jobsdb/internal/cache/cache.go +++ b/jobsdb/internal/cache/cache.go @@ -2,12 +2,12 @@ package cache import ( "fmt" + "slices" "sync" "time" "github.com/google/uuid" "github.com/samber/lo" - "golang.org/x/exp/slices" ) const ( diff --git a/jobsdb/jobsdb.go b/jobsdb/jobsdb.go index 9cd75522d61..89fb4cf0687 100644 --- a/jobsdb/jobsdb.go +++ b/jobsdb/jobsdb.go @@ -28,6 +28,7 @@ import ( "encoding/json" "errors" "fmt" + "slices" "sort" "strconv" 
"strings" @@ -35,7 +36,6 @@ import ( "time" "unicode/utf8" - "golang.org/x/exp/slices" "golang.org/x/sync/errgroup" "github.com/samber/lo" diff --git a/processor/eventfilter/eventfilter.go b/processor/eventfilter/eventfilter.go index f34a3c516fb..a57eff9d3e4 100644 --- a/processor/eventfilter/eventfilter.go +++ b/processor/eventfilter/eventfilter.go @@ -1,10 +1,9 @@ package eventfilter import ( + "slices" "strings" - "golang.org/x/exp/slices" - "github.com/rudderlabs/rudder-go-kit/logger" backendconfig "github.com/rudderlabs/rudder-server/backend-config" "github.com/rudderlabs/rudder-server/processor/transformer" diff --git a/processor/processor.go b/processor/processor.go index 1a1862c2eda..2d7c95c4836 100644 --- a/processor/processor.go +++ b/processor/processor.go @@ -9,12 +9,12 @@ import ( "io" "net/http" "runtime/trace" + "slices" "strconv" "strings" "sync" "time" - "golang.org/x/exp/slices" "golang.org/x/sync/errgroup" jsoniter "github.com/json-iterator/go" diff --git a/processor/transformer/transformer_test.go b/processor/transformer/transformer_test.go index 27416362014..e966db1ecec 100644 --- a/processor/transformer/transformer_test.go +++ b/processor/transformer/transformer_test.go @@ -8,6 +8,7 @@ import ( "net/http/httptest" "net/url" "os" + "slices" "strconv" "sync/atomic" "testing" @@ -17,8 +18,6 @@ import ( warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" - "golang.org/x/exp/slices" - "github.com/rudderlabs/rudder-server/utils/misc" "github.com/rudderlabs/rudder-server/utils/types" diff --git a/router/batchrouter/handle.go b/router/batchrouter/handle.go index 9420a002d4c..917365c4158 100644 --- a/router/batchrouter/handle.go +++ b/router/batchrouter/handle.go @@ -8,6 +8,7 @@ import ( "net/http" "os" "path/filepath" + "slices" "sort" "strconv" "strings" @@ -20,7 +21,6 @@ import ( "github.com/samber/lo" "github.com/tidwall/gjson" "github.com/tidwall/sjson" - "golang.org/x/exp/slices" "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/filemanager" diff --git a/router/batchrouter/handle_async.go b/router/batchrouter/handle_async.go index 7b4f4edcede..a556a0a8708 100644 --- a/router/batchrouter/handle_async.go +++ b/router/batchrouter/handle_async.go @@ -8,12 +8,12 @@ import ( "net/http" "os" "path/filepath" + "slices" "strconv" "time" "github.com/google/uuid" "github.com/tidwall/gjson" - "golang.org/x/exp/slices" "github.com/rudderlabs/rudder-server/jobsdb" "github.com/rudderlabs/rudder-server/router/batchrouter/asyncdestinationmanager" diff --git a/router/batchrouter/handle_lifecycle.go b/router/batchrouter/handle_lifecycle.go index f4488eb79e5..f23ab1f5a8e 100644 --- a/router/batchrouter/handle_lifecycle.go +++ b/router/batchrouter/handle_lifecycle.go @@ -8,14 +8,15 @@ import ( "net/http" "os" "path/filepath" + "slices" "strings" "sync" "time" + "golang.org/x/sync/errgroup" + "github.com/google/uuid" "github.com/tidwall/gjson" - "golang.org/x/exp/slices" - "golang.org/x/sync/errgroup" "github.com/rudderlabs/rudder-go-kit/bytesize" "github.com/rudderlabs/rudder-go-kit/config" diff --git a/router/batchrouter/util.go b/router/batchrouter/util.go index 9761be77011..ef23db27f32 100644 --- a/router/batchrouter/util.go +++ b/router/batchrouter/util.go @@ -3,12 +3,11 @@ package batchrouter import ( "context" "fmt" + "slices" "strings" "sync" "time" - "golang.org/x/exp/slices" - "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/filemanager" "github.com/rudderlabs/rudder-go-kit/logger" diff --git 
a/router/batchrouter/worker.go b/router/batchrouter/worker.go index 1152d12e1ee..127f3a639b9 100644 --- a/router/batchrouter/worker.go +++ b/router/batchrouter/worker.go @@ -3,13 +3,13 @@ package batchrouter import ( "context" "fmt" + "slices" "strings" "sync" "time" "github.com/samber/lo" "github.com/tidwall/gjson" - "golang.org/x/exp/slices" "github.com/rudderlabs/rudder-go-kit/logger" "github.com/rudderlabs/rudder-go-kit/stats" diff --git a/router/customdestinationmanager/customdestinationmanager.go b/router/customdestinationmanager/customdestinationmanager.go index f2877ea4549..2535d3ac9d7 100644 --- a/router/customdestinationmanager/customdestinationmanager.go +++ b/router/customdestinationmanager/customdestinationmanager.go @@ -6,11 +6,11 @@ import ( "errors" "fmt" "reflect" + "slices" "sync" "time" "github.com/sony/gobreaker" - "golang.org/x/exp/slices" "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/logger" diff --git a/router/internal/partition/stats.go b/router/internal/partition/stats.go index 91fe989c53a..607af04ccd4 100644 --- a/router/internal/partition/stats.go +++ b/router/internal/partition/stats.go @@ -2,11 +2,11 @@ package partition import ( "math" + "slices" "sync" "time" "github.com/samber/lo" - "golang.org/x/exp/slices" "github.com/rudderlabs/rudder-go-kit/stats/metric" ) diff --git a/router/manager/manager.go b/router/manager/manager.go index 7f04caed62e..0ad196b9844 100644 --- a/router/manager/manager.go +++ b/router/manager/manager.go @@ -3,8 +3,8 @@ package manager import ( "context" "fmt" + "slices" - "golang.org/x/exp/slices" "golang.org/x/sync/errgroup" "github.com/rudderlabs/rudder-go-kit/logger" diff --git a/router/utils/utils.go b/router/utils/utils.go index d2e5f0dada0..2ef478016e3 100644 --- a/router/utils/utils.go +++ b/router/utils/utils.go @@ -1,11 +1,10 @@ package utils import ( + "slices" "strings" "time" - "golang.org/x/exp/slices" - "github.com/tidwall/sjson" "github.com/rudderlabs/rudder-go-kit/config" diff --git a/router/worker.go b/router/worker.go index e0861d5f2a7..85affe34124 100644 --- a/router/worker.go +++ b/router/worker.go @@ -5,13 +5,12 @@ import ( "encoding/json" "fmt" "net/http" + "slices" "sort" "strconv" "strings" "time" - "golang.org/x/exp/slices" - "github.com/samber/lo" "github.com/tidwall/gjson" diff --git a/schema-forwarder/internal/forwarder/jobsforwarder.go b/schema-forwarder/internal/forwarder/jobsforwarder.go index 7135a35d861..c3cd1abb041 100644 --- a/schema-forwarder/internal/forwarder/jobsforwarder.go +++ b/schema-forwarder/internal/forwarder/jobsforwarder.go @@ -5,10 +5,10 @@ import ( "context" "encoding/json" "fmt" + "slices" "sync" "time" - "golang.org/x/exp/slices" "golang.org/x/sync/errgroup" pulsarType "github.com/apache/pulsar-client-go/pulsar" diff --git a/services/debugger/source/eventUploader.go b/services/debugger/source/eventUploader.go index bf58570118a..ba32bf13943 100644 --- a/services/debugger/source/eventUploader.go +++ b/services/debugger/source/eventUploader.go @@ -4,11 +4,10 @@ import ( "context" "encoding/json" "fmt" + "slices" "sync" "time" - "golang.org/x/exp/slices" - "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/logger" backendconfig "github.com/rudderlabs/rudder-server/backend-config" diff --git a/testhelper/clone.go b/testhelper/clone.go new file mode 100644 index 00000000000..8c37b67529a --- /dev/null +++ b/testhelper/clone.go @@ -0,0 +1,20 @@ +package testhelper + +import ( + "encoding/json" + "testing" + + 
"github.com/stretchr/testify/require" +) + +func Clone[T any](t testing.TB, v T) T { + t.Helper() + + buf, err := json.Marshal(v) + require.NoError(t, err) + + var clone T + require.NoError(t, json.Unmarshal(buf, &clone)) + + return clone +} diff --git a/testhelper/transformertest/handler_funcs.go b/testhelper/transformertest/handler_funcs.go index e6fc315915c..d517f221d17 100644 --- a/testhelper/transformertest/handler_funcs.go +++ b/testhelper/transformertest/handler_funcs.go @@ -92,3 +92,26 @@ func RESTJSONDestTransformerHandler(method, url string) func(request []transform } }) } + +// WarehouseTransformerHandler mirrors the request payload in the response but uses an error, status code along with warehouse compatible output +func WarehouseTransformerHandler(tableName string, code int, err string) TransformerHandler { + return func(request []transformer.TransformerEvent) (response []transformer.TransformerResponse) { + for i := range request { + req := request[i] + response = append(response, transformer.TransformerResponse{ + Metadata: req.Metadata, + Output: map[string]interface{}{ + "table": tableName, + "data": req.Message, + "metadata": map[string]interface{}{ + "table": tableName, + "columns": map[string]interface{}{}, + }, + }, + StatusCode: code, + Error: err, + }) + } + return + } +} diff --git a/warehouse/admin/admin.go b/warehouse/admin/admin.go index 29d78d218aa..ceb9f3e2560 100644 --- a/warehouse/admin/admin.go +++ b/warehouse/admin/admin.go @@ -31,34 +31,34 @@ type ConfigurationTestOutput struct { } type Admin struct { - csf connectionSourcesFetcher - suas startUploadAlwaysSetter - logger logger.Logger + connectionSources connectionSourcesFetcher + createUploadAlways createUploadAlwaysSetter + logger logger.Logger } type connectionSourcesFetcher interface { ConnectionSourcesMap(destID string) (map[string]model.Warehouse, bool) } -type startUploadAlwaysSetter interface { +type createUploadAlwaysSetter interface { Store(bool) } func New( - csf connectionSourcesFetcher, - suas startUploadAlwaysSetter, + connectionSources connectionSourcesFetcher, + createUploadAlways createUploadAlwaysSetter, logger logger.Logger, ) *Admin { return &Admin{ - csf: csf, - suas: suas, - logger: logger.Child("admin"), + connectionSources: connectionSources, + createUploadAlways: createUploadAlways, + logger: logger.Child("admin"), } } // TriggerUpload sets uploads to start without delay func (a *Admin) TriggerUpload(off bool, reply *string) error { - a.suas.Store(!off) + a.createUploadAlways.Store(!off) if off { *reply = "Turned off explicit warehouse upload triggers.\nWarehouse uploads will continue to be done as per schedule in control plane." 
} else { @@ -73,7 +73,7 @@ func (a *Admin) Query(s QueryInput, reply *warehouseutils.QueryResult) error { return errors.New("please specify the destination ID to query the warehouse") } - srcMap, ok := a.csf.ConnectionSourcesMap(s.DestID) + srcMap, ok := a.connectionSources.ConnectionSourcesMap(s.DestID) if !ok { return errors.New("please specify a valid and existing destination ID") } @@ -119,7 +119,7 @@ func (a *Admin) ConfigurationTest(s ConfigurationTestInput, reply *Configuration } var warehouse model.Warehouse - srcMap, ok := a.csf.ConnectionSourcesMap(s.DestID) + srcMap, ok := a.connectionSources.ConnectionSourcesMap(s.DestID) if !ok { return fmt.Errorf("please specify a valid and existing destinationID: %s", s.DestID) } diff --git a/warehouse/api/grpc.go b/warehouse/api/grpc.go index 37420d4cd97..a35182b5711 100644 --- a/warehouse/api/grpc.go +++ b/warehouse/api/grpc.go @@ -7,14 +7,19 @@ import ( "fmt" "net/http" "os" + "slices" + "strconv" "sync" "time" + "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" + + "github.com/rudderlabs/rudder-go-kit/stats" + "github.com/rudderlabs/rudder-server/warehouse/bcm" "github.com/samber/lo" - "golang.org/x/exp/slices" "google.golang.org/genproto/googleapis/rpc/code" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -80,6 +85,7 @@ type GRPC struct { func NewGRPCServer( conf *config.Config, logger logger.Logger, + statsFactory stats.Stats, db *sqlmw.DB, tenantManager *multitenant.Manager, bcManager *bcm.BackendConfigManager, @@ -134,6 +140,9 @@ func NewGRPCServer( RetryInterval: 0, UseTLS: g.config.cpRouterUseTLS, Logger: g.logger, + Options: []grpc.ServerOption{ + grpc.UnaryInterceptor(statsInterceptor(statsFactory)), + }, RegisterService: func(srv *grpc.Server) { proto.RegisterWarehouseServer(srv, g) }, @@ -941,3 +950,20 @@ func (g *GRPC) RetryFailedBatches( } return resp, nil } + +func statsInterceptor(statsFactory stats.Stats) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + start := time.Now() + res, err := handler(ctx, req) + statusCode := codes.Unknown + if s, ok := status.FromError(err); ok { + statusCode = s.Code() + } + tags := stats.Tags{ + "reqType": info.FullMethod, + "code": strconv.Itoa(runtime.HTTPStatusFromCode(statusCode)), + } + statsFactory.NewTaggedStat("warehouse.grpc.response_time", stats.TimerType, tags).Since(start) + return res, err + } +} diff --git a/warehouse/api/grpc_test.go b/warehouse/api/grpc_test.go index 952e32a7b97..1d17e315e09 100644 --- a/warehouse/api/grpc_test.go +++ b/warehouse/api/grpc_test.go @@ -12,7 +12,7 @@ import ( "testing" "time" - "github.com/rudderlabs/rudder-go-kit/stats" + "github.com/rudderlabs/rudder-go-kit/stats/memstats" "github.com/rudderlabs/rudder-server/warehouse/bcm" @@ -148,8 +148,8 @@ func TestGRPC(t *testing.T) { triggerStore := &sync.Map{} tenantManager := multitenant.New(c, mockBackendConfig) - bcManager := bcm.New(c, db, tenantManager, logger.NOP, stats.Default) - grpcServer, err := NewGRPCServer(c, logger.NOP, db, tenantManager, bcManager, triggerStore) + bcManager := bcm.New(c, db, tenantManager, logger.NOP, memstats.New()) + grpcServer, err := NewGRPCServer(c, logger.NOP, memstats.New(), db, tenantManager, bcManager, triggerStore) require.NoError(t, err) tcpPort, err := kithelper.GetFreePort() @@ -160,7 +160,7 @@ func TestGRPC(t *testing.T) { listener, err := net.Listen("tcp", tcpAddress) require.NoError(t, err) - server := 
grpc.NewServer(grpc.Creds(insecure.NewCredentials())) + server := grpc.NewServer(grpc.Creds(insecure.NewCredentials()), grpc.UnaryInterceptor(statsInterceptor(memstats.New()))) proto.RegisterWarehouseServer(server, grpcServer) g, gCtx := errgroup.WithContext(ctx) diff --git a/warehouse/api/http_test.go b/warehouse/api/http_test.go index e6a95545c96..75f83535586 100644 --- a/warehouse/api/http_test.go +++ b/warehouse/api/http_test.go @@ -13,6 +13,8 @@ import ( "testing" "time" + "github.com/rudderlabs/rudder-go-kit/stats/memstats" + "github.com/rudderlabs/rudder-server/warehouse/internal/mode" "github.com/rudderlabs/rudder-server/warehouse/bcm" @@ -37,7 +39,6 @@ import ( "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/logger" - "github.com/rudderlabs/rudder-go-kit/stats" "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource" backendconfig "github.com/rudderlabs/rudder-server/backend-config" mocksBackendConfig "github.com/rudderlabs/rudder-server/mocks/backend-config" @@ -177,13 +178,13 @@ func TestHTTPApi(t *testing.T) { tenantManager := multitenant.New(c, mockBackendConfig) - bcManager := bcm.New(config.Default, db, tenantManager, logger.NOP, stats.Default) + bcManager := bcm.New(config.New(), db, tenantManager, logger.NOP, memstats.New()) triggerStore := &sync.Map{} ctx, stopTest := context.WithCancel(context.Background()) - n := notifier.New(config.Default, logger.NOP, stats.Default, workspaceIdentifier) + n := notifier.New(config.New(), logger.NOP, memstats.New(), workspaceIdentifier) err = n.Setup(ctx, pgResource.DBDsn) require.NoError(t, err) @@ -192,7 +193,7 @@ func TestHTTPApi(t *testing.T) { db, n, ) - jobs.WithConfig(sourcesManager, config.Default) + jobs.WithConfig(sourcesManager, config.New()) g, gCtx := errgroup.WithContext(ctx) g.Go(func() error { @@ -411,7 +412,7 @@ func TestHTTPApi(t *testing.T) { c := config.New() c.Set("Warehouse.runningMode", tc.runningMode) - a := NewApi(tc.mode, c, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) + a := NewApi(tc.mode, c, logger.NOP, memstats.New(), mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) a.healthHandler(resp, req) var healthBody map[string]string @@ -429,7 +430,7 @@ func TestHTTPApi(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/v1/warehouse/pending-events", bytes.NewReader([]byte(`"Invalid payload"`))) resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) + a := NewApi(config.MasterMode, config.New(), logger.NOP, memstats.New(), mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) a.pendingEventsHandler(resp, req) require.Equal(t, http.StatusBadRequest, resp.Code) @@ -447,7 +448,7 @@ func TestHTTPApi(t *testing.T) { `))) resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) + a := NewApi(config.MasterMode, config.New(), logger.NOP, memstats.New(), mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) a.pendingEventsHandler(resp, req) require.Equal(t, http.StatusBadRequest, resp.Code) @@ -465,7 +466,7 @@ func TestHTTPApi(t *testing.T) { `))) resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, 
mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) + a := NewApi(config.MasterMode, config.New(), logger.NOP, memstats.New(), mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) a.pendingEventsHandler(resp, req) require.Equal(t, http.StatusBadRequest, resp.Code) @@ -483,7 +484,7 @@ func TestHTTPApi(t *testing.T) { `))) resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) + a := NewApi(config.MasterMode, config.New(), logger.NOP, memstats.New(), mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) a.pendingEventsHandler(resp, req) require.Equal(t, http.StatusServiceUnavailable, resp.Code) @@ -501,7 +502,7 @@ func TestHTTPApi(t *testing.T) { `))) resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) + a := NewApi(config.MasterMode, config.New(), logger.NOP, memstats.New(), mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) a.pendingEventsHandler(resp, req) require.Equal(t, http.StatusOK, resp.Code) @@ -527,7 +528,7 @@ func TestHTTPApi(t *testing.T) { `))) resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) + a := NewApi(config.MasterMode, config.New(), logger.NOP, memstats.New(), mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) a.pendingEventsHandler(resp, req) require.Equal(t, http.StatusOK, resp.Code) @@ -557,7 +558,7 @@ func TestHTTPApi(t *testing.T) { `))) resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) + a := NewApi(config.MasterMode, config.New(), logger.NOP, memstats.New(), mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) a.pendingEventsHandler(resp, req) require.Equal(t, http.StatusOK, resp.Code) @@ -579,7 +580,7 @@ func TestHTTPApi(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/internal/v1/warehouse/fetch-tables", bytes.NewReader([]byte(`"Invalid payload"`))) resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) + a := NewApi(config.MasterMode, config.New(), logger.NOP, memstats.New(), mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) a.fetchTablesHandler(resp, req) require.Equal(t, http.StatusBadRequest, resp.Code) @@ -596,7 +597,7 @@ func TestHTTPApi(t *testing.T) { `))) resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) + a := NewApi(config.MasterMode, config.New(), logger.NOP, memstats.New(), mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) a.fetchTablesHandler(resp, req) require.Equal(t, http.StatusInternalServerError, resp.Code) @@ -618,7 +619,7 @@ func TestHTTPApi(t *testing.T) { `))) resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, 
db, n, tenantManager, bcManager, sourcesManager, triggerStore) + a := NewApi(config.MasterMode, config.New(), logger.NOP, memstats.New(), mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) a.fetchTablesHandler(resp, req) require.Equal(t, http.StatusOK, resp.Code) @@ -641,7 +642,7 @@ func TestHTTPApi(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/v1/warehouse/trigger-upload", bytes.NewReader([]byte(`"Invalid payload"`))) resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) + a := NewApi(config.MasterMode, config.New(), logger.NOP, memstats.New(), mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) a.triggerUploadHandler(resp, req) require.Equal(t, http.StatusBadRequest, resp.Code) @@ -659,7 +660,7 @@ func TestHTTPApi(t *testing.T) { `))) resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) + a := NewApi(config.MasterMode, config.New(), logger.NOP, memstats.New(), mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) a.triggerUploadHandler(resp, req) require.Equal(t, http.StatusBadRequest, resp.Code) @@ -677,7 +678,7 @@ func TestHTTPApi(t *testing.T) { `))) resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) + a := NewApi(config.MasterMode, config.New(), logger.NOP, memstats.New(), mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) a.triggerUploadHandler(resp, req) require.Equal(t, http.StatusServiceUnavailable, resp.Code) @@ -695,7 +696,7 @@ func TestHTTPApi(t *testing.T) { `))) resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) + a := NewApi(config.MasterMode, config.New(), logger.NOP, memstats.New(), mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) a.triggerUploadHandler(resp, req) require.Equal(t, http.StatusBadRequest, resp.Code) @@ -713,7 +714,7 @@ func TestHTTPApi(t *testing.T) { `))) resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) + a := NewApi(config.MasterMode, config.New(), logger.NOP, memstats.New(), mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) a.triggerUploadHandler(resp, req) require.Equal(t, http.StatusOK, resp.Code) @@ -735,7 +736,7 @@ func TestHTTPApi(t *testing.T) { `))) resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) + a := NewApi(config.MasterMode, config.New(), logger.NOP, memstats.New(), mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) a.triggerUploadHandler(resp, req) require.Equal(t, http.StatusOK, resp.Code) @@ -759,7 +760,7 @@ func TestHTTPApi(t *testing.T) { srvCtx, stopServer := context.WithCancel(ctx) - a := NewApi(config.MasterMode, c, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, 
sourcesManager, triggerStore) + a := NewApi(config.MasterMode, c, logger.NOP, memstats.New(), mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) serverSetupCh := make(chan struct{}) go func() { @@ -957,7 +958,7 @@ func TestHTTPApi(t *testing.T) { srvCtx, stopServer := context.WithCancel(ctx) - a := NewApi(config.MasterMode, c, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) + a := NewApi(config.MasterMode, c, logger.NOP, memstats.New(), mockBackendConfig, db, n, tenantManager, bcManager, sourcesManager, triggerStore) serverSetupCh := make(chan struct{}) go func() { diff --git a/warehouse/app.go b/warehouse/app.go index cdabaf13788..05c8a4b3d37 100644 --- a/warehouse/app.go +++ b/warehouse/app.go @@ -6,7 +6,9 @@ import ( "errors" "fmt" "os" + "slices" "sync" + "sync/atomic" "time" "github.com/rudderlabs/rudder-server/warehouse/internal/mode" @@ -25,7 +27,6 @@ import ( "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" "github.com/samber/lo" - "golang.org/x/exp/slices" "golang.org/x/sync/errgroup" "github.com/rudderlabs/rudder-go-kit/config" @@ -69,6 +70,7 @@ type App struct { sourcesManager *jobs.AsyncJobWh admin *whadmin.Admin triggerStore *sync.Map + createUploadAlways *atomic.Bool appName string @@ -132,6 +134,7 @@ func (a *App) Setup(ctx context.Context) error { return fmt.Errorf("setting up database: %w", err) } + a.createUploadAlways = &atomic.Bool{} a.triggerStore = &sync.Map{} a.tenantManager = multitenant.New( a.conf, @@ -181,6 +184,7 @@ func (a *App) Setup(ctx context.Context) error { a.grpcServer, err = api.NewGRPCServer( a.conf, a.logger, + a.statsFactory, a.db, a.tenantManager, a.bcManager, @@ -205,7 +209,7 @@ func (a *App) Setup(ctx context.Context) error { ) a.admin = whadmin.New( a.bcManager, - &router.StartUploadAlways, + a.createUploadAlways, a.logger, ) @@ -483,6 +487,7 @@ func (a *App) onConfigDataEvent( a.bcManager, a.encodingFactory, a.triggerStore, + a.createUploadAlways, ) if err != nil { return fmt.Errorf("setup warehouse %q: %w", destination.DestinationDefinition.Name, err) diff --git a/warehouse/app_test.go b/warehouse/app_test.go index 0faaa688d3c..d8a6d9746f7 100644 --- a/warehouse/app_test.go +++ b/warehouse/app_test.go @@ -9,6 +9,8 @@ import ( "testing" "time" + "github.com/rudderlabs/rudder-go-kit/stats/memstats" + "github.com/rudderlabs/rudder-server/warehouse/internal/mode" "github.com/ory/dockertest/v3" @@ -30,7 +32,6 @@ import ( "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/filemanager" "github.com/rudderlabs/rudder-go-kit/logger" - "github.com/rudderlabs/rudder-go-kit/stats" kithelper "github.com/rudderlabs/rudder-go-kit/testhelper" "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource" "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource/postgres" @@ -131,7 +132,7 @@ func TestApp(t *testing.T) { c.Set("Warehouse.runningMode", subTC.runningMode) c.Set("Warehouse.webPort", webPort) - a := New(mockApp, c, logger.NOP, stats.Default, &bcConfig.NOOP{}, filemanager.New) + a := New(mockApp, c, logger.NOP, memstats.New(), &bcConfig.NOOP{}, filemanager.New) err = a.Setup(ctx) require.NoError(t, err) @@ -210,7 +211,7 @@ func TestApp(t *testing.T) { return ch }).AnyTimes() - a := New(mockApp, c, logger.NOP, stats.Default, mockBackendConfig, filemanager.New) + a := New(mockApp, c, logger.NOP, memstats.New(), mockBackendConfig, filemanager.New) err = a.Setup(ctx) require.NoError(t, 
err) @@ -278,7 +279,7 @@ func TestApp(t *testing.T) { c.Set("WAREHOUSE_JOBS_DB_PASSWORD", pgResource.Password) c.Set("WAREHOUSE_JOBS_DB_DB_NAME", pgResource.Database) - a := New(mockApp, c, logger.NOP, stats.Default, &bcConfig.NOOP{}, filemanager.New) + a := New(mockApp, c, logger.NOP, memstats.New(), &bcConfig.NOOP{}, filemanager.New) err = a.Setup(context.Background()) require.EqualError(t, err, "setting up database: warehouse Service needs postgres version >= 10. Exiting") }) @@ -290,7 +291,7 @@ func TestApp(t *testing.T) { c.Set("WAREHOUSE_JOBS_DB_PASSWORD", "ubuntu") c.Set("WAREHOUSE_JOBS_DB_DB_NAME", "ubuntu") - a := New(mockApp, c, logger.NOP, stats.Default, &bcConfig.NOOP{}, filemanager.New) + a := New(mockApp, c, logger.NOP, memstats.New(), &bcConfig.NOOP{}, filemanager.New) err = a.Setup(context.Background()) require.ErrorContains(t, err, "setting up database: could not check compatibility:") }) @@ -305,7 +306,7 @@ func TestApp(t *testing.T) { c.Set("DB.password", pgResource.Password) c.Set("DB.name", pgResource.Database) - a := New(mockApp, c, logger.NOP, stats.Default, &bcConfig.NOOP{}, filemanager.New) + a := New(mockApp, c, logger.NOP, memstats.New(), &bcConfig.NOOP{}, filemanager.New) err = a.Setup(context.Background()) require.NoError(t, err) }) @@ -386,7 +387,7 @@ func TestApp(t *testing.T) { return ch }).AnyTimes() - a := New(mockApp, c, logger.NOP, stats.Default, mockBackendConfig, filemanager.New) + a := New(mockApp, c, logger.NOP, memstats.New(), mockBackendConfig, filemanager.New) require.NoError(t, a.Setup(ctx)) require.NoError(t, a.monitorDestRouters(ctx)) }) @@ -417,7 +418,7 @@ func TestApp(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - a := New(mockApp, c, logger.NOP, stats.Default, &bcConfig.NOOP{}, filemanager.New) + a := New(mockApp, c, logger.NOP, memstats.New(), &bcConfig.NOOP{}, filemanager.New) err = a.Setup(ctx) require.NoError(t, err) @@ -475,7 +476,7 @@ func TestApp(t *testing.T) { mockLogger.EXPECT().Info(gomock.Any()).AnyTimes() mockLogger.EXPECT().Infof(gomock.Any()).AnyTimes() - a := New(mockApp, c, mockLogger, stats.Default, &bcConfig.NOOP{}, filemanager.New) + a := New(mockApp, c, mockLogger, memstats.New(), &bcConfig.NOOP{}, filemanager.New) err = a.Setup(ctx) require.NoError(t, err) diff --git a/warehouse/archive/archiver_test.go b/warehouse/archive/archiver_test.go index 431e06219b5..3801ca0f55f 100644 --- a/warehouse/archive/archiver_test.go +++ b/warehouse/archive/archiver_test.go @@ -159,7 +159,7 @@ func TestArchiver(t *testing.T) { db := sqlmw.New(pgResource.DB) archiver := archive.New( - config.Default, + config.New(), logger.NOP, mockStats, db, diff --git a/warehouse/bcm/backend_config_test.go b/warehouse/bcm/backend_config_test.go index 3bf8e7da2c3..171e1e9aa2f 100644 --- a/warehouse/bcm/backend_config_test.go +++ b/warehouse/bcm/backend_config_test.go @@ -8,7 +8,7 @@ import ( "os" "testing" - "github.com/rudderlabs/rudder-go-kit/stats" + "github.com/rudderlabs/rudder-go-kit/stats/memstats" "github.com/rudderlabs/rudder-server/warehouse/multitenant" @@ -112,10 +112,10 @@ func TestBackendConfigManager(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - tenantManager := multitenant.New(config.Default, mockBackendConfig) + tenantManager := multitenant.New(config.New(), mockBackendConfig) t.Run("Subscriptions", func(t *testing.T) { - bcm := New(c, db, tenantManager, logger.NOP, stats.Default) + bcm := New(c, db, tenantManager, logger.NOP, memstats.New()) require.False(t, 
bcm.IsInitialized()) require.Equal(t, bcm.Connections(), map[string]map[string]model.Warehouse{}) @@ -191,7 +191,7 @@ func TestBackendConfigManager(t *testing.T) { }) t.Run("Tunnelling", func(t *testing.T) { - bcm := New(c, db, tenantManager, logger.NOP, stats.Default) + bcm := New(c, db, tenantManager, logger.NOP, memstats.New()) testCases := []struct { name string @@ -261,7 +261,7 @@ func TestBackendConfigManager(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) numSubscribers := 1000 - bcm := New(c, db, tenantManager, logger.NOP, stats.Default) + bcm := New(c, db, tenantManager, logger.NOP, memstats.New()) subscriptionsChs := make([]<-chan []model.Warehouse, numSubscribers) for i := 0; i < numSubscribers; i++ { @@ -459,11 +459,11 @@ func TestBackendConfigManager_Namespace(t *testing.T) { require.NoError(t, err) tenantManager := multitenant.New( - config.Default, + config.New(), backendconfig.DefaultBackendConfig, ) - bcm := New(c, db, tenantManager, logger.NOP, stats.Default) + bcm := New(c, db, tenantManager, logger.NOP, memstats.New()) namespace := bcm.namespace(context.Background(), tc.source, tc.destination) require.Equal(t, tc.expectedNamespace, namespace) diff --git a/warehouse/encoding/encoding_test.go b/warehouse/encoding/encoding_test.go index 7822fb9a2b4..d734df192a5 100644 --- a/warehouse/encoding/encoding_test.go +++ b/warehouse/encoding/encoding_test.go @@ -58,7 +58,7 @@ func TestReaderLoader(t *testing.T) { ) t.Log("Parquet", outputFilePath) - ef := encoding.NewFactory(config.Default) + ef := encoding.NewFactory(config.New()) writer, err := ef.NewLoadFileWriter(loadFileType, outputFilePath, schema, destinationType) require.NoError(t, err) @@ -205,7 +205,7 @@ func TestReaderLoader(t *testing.T) { lines = 100 ) - ef := encoding.NewFactory(config.Default) + ef := encoding.NewFactory(config.New()) writer, err := ef.NewLoadFileWriter(loadFileType, outputFilePath, nil, destinationType) require.NoError(t, err) @@ -267,7 +267,7 @@ func TestReaderLoader(t *testing.T) { lines = 100 ) - ef := encoding.NewFactory(config.Default) + ef := encoding.NewFactory(config.New()) writer, err := ef.NewLoadFileWriter(loadFileType, outputFilePath, nil, destinationType) require.NoError(t, err) @@ -320,7 +320,7 @@ func TestReaderLoader(t *testing.T) { }) t.Run("Empty files", func(t *testing.T) { - ef := encoding.NewFactory(config.Default) + ef := encoding.NewFactory(config.New()) t.Run("csv", func(t *testing.T) { destinationType := warehouseutils.RS diff --git a/warehouse/integrations/azure-synapse/azure-synapse.go b/warehouse/integrations/azure-synapse/azure-synapse.go index 2c61e9087e7..ef6a08e546b 100644 --- a/warehouse/integrations/azure-synapse/azure-synapse.go +++ b/warehouse/integrations/azure-synapse/azure-synapse.go @@ -11,6 +11,7 @@ import ( "net" "net/url" "os" + "slices" "strconv" "strings" "time" @@ -25,8 +26,6 @@ import ( sqlmw "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" "github.com/rudderlabs/rudder-server/warehouse/logfield" - "golang.org/x/exp/slices" - "github.com/rudderlabs/rudder-server/warehouse/internal/service/loadfiles/downloader" "github.com/rudderlabs/rudder-go-kit/config" diff --git a/warehouse/integrations/azure-synapse/azure_synapse_test.go b/warehouse/integrations/azure-synapse/azure_synapse_test.go index 785db12f035..0752bd9fa99 100644 --- a/warehouse/integrations/azure-synapse/azure_synapse_test.go +++ b/warehouse/integrations/azure-synapse/azure_synapse_test.go @@ -10,12 +10,13 @@ import ( "testing" 
"time" + "github.com/rudderlabs/rudder-go-kit/stats/memstats" + "github.com/golang/mock/gomock" "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/filemanager" "github.com/rudderlabs/rudder-go-kit/logger" - "github.com/rudderlabs/rudder-go-kit/stats" azuresynapse "github.com/rudderlabs/rudder-server/warehouse/integrations/azure-synapse" mockuploader "github.com/rudderlabs/rudder-server/warehouse/internal/mocks/utils" "github.com/rudderlabs/rudder-server/warehouse/internal/model" @@ -346,7 +347,7 @@ func TestIntegration(t *testing.T) { loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) - az := azuresynapse.New(config.Default, logger.NOP, stats.Default) + az := azuresynapse.New(config.New(), logger.NOP, memstats.New()) err := az.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -362,7 +363,7 @@ func TestIntegration(t *testing.T) { loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) - az := azuresynapse.New(config.Default, logger.NOP, stats.Default) + az := azuresynapse.New(config.New(), logger.NOP, memstats.New()) err := az.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -382,7 +383,7 @@ func TestIntegration(t *testing.T) { loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) - az := azuresynapse.New(config.Default, logger.NOP, stats.Default) + az := azuresynapse.New(config.New(), logger.NOP, memstats.New()) err := az.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -429,7 +430,7 @@ func TestIntegration(t *testing.T) { loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) - az := azuresynapse.New(config.Default, logger.NOP, stats.Default) + az := azuresynapse.New(config.New(), logger.NOP, memstats.New()) err := az.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -474,7 +475,7 @@ func TestIntegration(t *testing.T) { }} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) - az := azuresynapse.New(config.Default, logger.NOP, stats.Default) + az := azuresynapse.New(config.New(), logger.NOP, memstats.New()) err := az.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -496,7 +497,7 @@ func TestIntegration(t *testing.T) { loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) - az := azuresynapse.New(config.Default, logger.NOP, stats.Default) + az := azuresynapse.New(config.New(), logger.NOP, memstats.New()) err := az.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -518,7 +519,7 @@ func TestIntegration(t *testing.T) { loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) - az := azuresynapse.New(config.Default, logger.NOP, stats.Default) + az := azuresynapse.New(config.New(), logger.NOP, memstats.New()) err := az.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -562,7 +563,7 @@ func TestIntegration(t *testing.T) { loadFiles := []warehouseutils.LoadFile{{Location: 
uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, warehouseutils.DiscardsSchema, warehouseutils.DiscardsSchema) - az := azuresynapse.New(config.Default, logger.NOP, stats.Default) + az := azuresynapse.New(config.New(), logger.NOP, memstats.New()) err := az.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -677,7 +678,7 @@ func TestAzureSynapse_ProcessColumnValue(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - az := azuresynapse.New(config.Default, logger.NOP, stats.Default) + az := azuresynapse.New(config.New(), logger.NOP, memstats.New()) value, err := az.ProcessColumnValue(tc.data, tc.dataType) if tc.wantError { diff --git a/warehouse/integrations/bigquery/bigquery.go b/warehouse/integrations/bigquery/bigquery.go index 2f1bb5366d3..d252f17bec2 100644 --- a/warehouse/integrations/bigquery/bigquery.go +++ b/warehouse/integrations/bigquery/bigquery.go @@ -6,15 +6,12 @@ import ( "errors" "fmt" "regexp" + "slices" "strings" "time" - "github.com/rudderlabs/rudder-server/warehouse/integrations/types" - - "github.com/samber/lo" - "cloud.google.com/go/bigquery" - "golang.org/x/exp/slices" + "github.com/samber/lo" bqService "google.golang.org/api/bigquery/v2" "google.golang.org/api/googleapi" "google.golang.org/api/iterator" @@ -26,6 +23,7 @@ import ( "github.com/rudderlabs/rudder-server/utils/misc" "github.com/rudderlabs/rudder-server/warehouse/client" "github.com/rudderlabs/rudder-server/warehouse/integrations/bigquery/middleware" + "github.com/rudderlabs/rudder-server/warehouse/integrations/types" "github.com/rudderlabs/rudder-server/warehouse/internal/model" "github.com/rudderlabs/rudder-server/warehouse/logfield" warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" @@ -43,8 +41,7 @@ type BigQuery struct { config struct { setUsersLoadPartitionFirstEventFilter bool customPartitionsEnabled bool - isUsersTableDedupEnabled bool - isDedupEnabled bool + allowMerge bool enableDeleteByJobs bool customPartitionsEnabledWorkspaceIDs []string slowQueryThreshold time.Duration @@ -140,8 +137,7 @@ func New(conf *config.Config, log logger.Logger) *BigQuery { bq.config.setUsersLoadPartitionFirstEventFilter = conf.GetBool("Warehouse.bigquery.setUsersLoadPartitionFirstEventFilter", true) bq.config.customPartitionsEnabled = conf.GetBool("Warehouse.bigquery.customPartitionsEnabled", false) - bq.config.isUsersTableDedupEnabled = conf.GetBool("Warehouse.bigquery.isUsersTableDedupEnabled", false) - bq.config.isDedupEnabled = conf.GetBool("Warehouse.bigquery.isDedupEnabled", false) + bq.config.allowMerge = conf.GetBool("Warehouse.bigquery.allowMerge", true) bq.config.enableDeleteByJobs = conf.GetBool("Warehouse.bigquery.enableDeleteByJobs", false) bq.config.customPartitionsEnabledWorkspaceIDs = conf.GetStringSlice("Warehouse.bigquery.customPartitionsEnabledWorkspaceIDs", nil) bq.config.slowQueryThreshold = conf.GetDuration("Warehouse.bigquery.slowQueryThreshold", 5, time.Minute) @@ -163,6 +159,7 @@ func (bq *BigQuery) getMiddleware() *middleware.Client { logfield.DestinationType, bq.warehouse.Destination.DestinationDefinition.Name, logfield.WorkspaceID, bq.warehouse.WorkspaceID, logfield.Schema, bq.namespace, + logfield.ShouldMerge, bq.shouldMerge(), ), middleware.WithSlowQueryThreshold(bq.config.slowQueryThreshold), ) @@ -193,23 +190,21 @@ func (bq *BigQuery) CreateTable(ctx context.Context, tableName string, columnMap return fmt.Errorf("create table: %w", err) } - if !bq.dedupEnabled() { - if err = 
bq.createTableView(ctx, tableName, columnMap); err != nil { - return fmt.Errorf("create view: %w", err) - } + if err = bq.createTableView(ctx, tableName, columnMap); err != nil { + return fmt.Errorf("create view: %w", err) } + return nil } -func (bq *BigQuery) DropTable(ctx context.Context, tableName string) (err error) { - err = bq.DeleteTable(ctx, tableName) - if err != nil { - return +func (bq *BigQuery) DropTable(ctx context.Context, tableName string) error { + if err := bq.DeleteTable(ctx, tableName); err != nil { + return fmt.Errorf("cannot delete table %q: %w", tableName, err) } - if !bq.dedupEnabled() { - err = bq.DeleteTable(ctx, tableName+"_view") + if err := bq.DeleteTable(ctx, tableName+"_view"); err != nil { + return fmt.Errorf("cannot delete table %q: %w", tableName+"_view", err) } - return + return nil } func (bq *BigQuery) createTableView(ctx context.Context, tableName string, columnMap model.TableSchema) (err error) { @@ -225,9 +220,16 @@ func (bq *BigQuery) createTableView(ctx context.Context, tableName string, colum // assuming it has field named id upon which dedup is done in view viewQuery := `SELECT * EXCEPT (__row_number) FROM ( - SELECT *, ROW_NUMBER() OVER (PARTITION BY ` + partitionKey + viewOrderByStmt + `) AS __row_number FROM ` + "`" + bq.projectID + "." + bq.namespace + "." + tableName + "`" + ` WHERE _PARTITIONTIME BETWEEN TIMESTAMP_TRUNC(TIMESTAMP_MICROS(UNIX_MICROS(CURRENT_TIMESTAMP()) - 60 * 60 * 60 * 24 * 1000000), DAY, 'UTC') - AND TIMESTAMP_TRUNC(CURRENT_TIMESTAMP(), DAY, 'UTC') - ) + SELECT *, ROW_NUMBER() OVER (PARTITION BY ` + partitionKey + viewOrderByStmt + `) AS __row_number + FROM ` + "`" + bq.projectID + "." + bq.namespace + "." + tableName + "`" + ` + WHERE + _PARTITIONTIME BETWEEN TIMESTAMP_TRUNC( + TIMESTAMP_MICROS(UNIX_MICROS(CURRENT_TIMESTAMP()) - 60 * 60 * 60 * 24 * 1000000), + DAY, + 'UTC' + ) + AND TIMESTAMP_TRUNC(CURRENT_TIMESTAMP(), DAY, 'UTC') + ) WHERE __row_number = 1` metaData := &bigquery.TableMetadata{ ViewQuery: viewQuery, @@ -241,7 +243,8 @@ func (bq *BigQuery) schemaExists(ctx context.Context, _, _ string) (exists bool, ds := bq.db.Dataset(bq.namespace) _, err = ds.Metadata(ctx) if err != nil { - if e, ok := err.(*googleapi.Error); ok && e.Code == 404 { + var e *googleapi.Error + if errors.As(err, &e) && e.Code == 404 { bq.logger.Debugf("BQ: Dataset %s not found", bq.namespace) return false, nil } @@ -275,7 +278,8 @@ func (bq *BigQuery) CreateSchema(ctx context.Context) (err error) { bq.logger.Infof("BQ: Creating schema: %s ...", bq.namespace) err = ds.Create(ctx, meta) if err != nil { - if e, ok := err.(*googleapi.Error); ok && e.Code == 409 { + var e *googleapi.Error + if errors.As(err, &e) && e.Code == 409 { bq.logger.Infof("BQ: Create schema %s failed as schema already exists", bq.namespace) return nil } @@ -285,7 +289,8 @@ func (bq *BigQuery) CreateSchema(ctx context.Context) (err error) { func checkAndIgnoreAlreadyExistError(err error) bool { if err != nil { - if e, ok := err.(*googleapi.Error); ok { + var e *googleapi.Error + if errors.As(err, &e) { // 409 is returned when we try to create a table that already exists // 400 is returned for all kinds of invalid input - so we need to check the error message too if e.Code == 409 || (e.Code == 400 && strings.Contains(e.Message, "already exists in schema")) { @@ -384,14 +389,14 @@ func (bq *BigQuery) loadTable( gcsRef.MaxBadRecords = 0 gcsRef.IgnoreUnknownValues = false - if bq.dedupEnabled() { + if bq.shouldMerge() { return bq.loadTableByMerge(ctx, tableName, gcsRef, log, 
skipTempTableDelete) } return bq.loadTableByAppend(ctx, tableName, gcsRef, log) } func (bq *BigQuery) loadTableStrategy() string { - if bq.dedupEnabled() { + if bq.shouldMerge() { return "MERGE" } return "APPEND" @@ -599,8 +604,7 @@ func (bq *BigQuery) loadTableByMerge( SET %[6]s WHEN NOT MATCHED THEN INSERT (%[4]s) VALUES - (%[5]s); -`, + (%[5]s);`, bqTable(tableName), bqTable(stagingTableName), primaryJoinClause, @@ -646,7 +650,7 @@ func (bq *BigQuery) loadTableByMerge( func (bq *BigQuery) LoadUserTables(ctx context.Context) (errorMap map[string]error) { errorMap = map[string]error{warehouseutils.IdentifiesTable: nil} - bq.logger.Infof("BQ: Starting load for identifies and users tables\n") + bq.logger.Infof("BQ: Starting load for identifies and users tables") _, identifyLoadTable, err := bq.loadTable(ctx, warehouseutils.IdentifiesTable, true) if err != nil { errorMap[warehouseutils.IdentifiesTable] = err @@ -704,10 +708,12 @@ func (bq *BigQuery) LoadUserTables(ctx context.Context) (errorMap map[string]err bqIdentifiesTable := bqTable(warehouseutils.IdentifiesTable) partition := fmt.Sprintf("TIMESTAMP('%s')", identifyLoadTable.partitionDate) var identifiesFrom string - if bq.dedupEnabled() { - identifiesFrom = fmt.Sprintf(`%s WHERE user_id IS NOT NULL %s`, bqTable(identifyLoadTable.stagingTableName), loadedAtFilter()) + if bq.shouldMerge() { + identifiesFrom = fmt.Sprintf(`%s WHERE user_id IS NOT NULL %s`, + bqTable(identifyLoadTable.stagingTableName), loadedAtFilter()) } else { - identifiesFrom = fmt.Sprintf(`%s WHERE _PARTITIONTIME = %s AND user_id IS NOT NULL %s`, bqIdentifiesTable, partition, loadedAtFilter()) + identifiesFrom = fmt.Sprintf(`%s WHERE _PARTITIONTIME = %s AND user_id IS NOT NULL %s`, + bqIdentifiesTable, partition, loadedAtFilter()) } sqlStatement := fmt.Sprintf(`SELECT DISTINCT * FROM ( SELECT id, %[1]s FROM ( @@ -824,7 +830,7 @@ func (bq *BigQuery) LoadUserTables(ctx context.Context) (errorMap map[string]err } } - if !bq.dedupEnabled() { + if !bq.shouldMerge() { loadUserTableByAppend() return } @@ -847,28 +853,24 @@ func Connect(context context.Context, cred *BQCredentials) (*bigquery.Client, er } opts = append(opts, option.WithCredentialsJSON(credBytes)) } - client, err := bigquery.NewClient(context, cred.ProjectID, opts...) - return client, err + c, err := bigquery.NewClient(context, cred.ProjectID, opts...) 
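	// A minimal sketch of the load-strategy decision used in this file (names taken from
	// this change; the view-based dedup note is an assumption drawn from createTableView):
	// BigQuery merges only when the server-level flag allows it AND the destination has
	// opted in; otherwise it appends.
	//
	//	allowMerge := conf.GetBool("Warehouse.bigquery.allowMerge", true)
	//	optedIn := warehouse.GetBoolDestinationConfig(model.EnableMergeSetting)
	//	if allowMerge && optedIn {
	//		// loadTableByMerge: MERGE on the primary key, updating matched rows
	//	} else {
	//		// loadTableByAppend: plain append; duplicates are left to the `<table>_view`
	//	}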
+ return c, err } func (bq *BigQuery) connect(ctx context.Context, cred BQCredentials) (*bigquery.Client, error) { bq.logger.Infof("BQ: Connecting to BigQuery in project: %s", cred.ProjectID) - client, err := Connect(ctx, &cred) - return client, err + c, err := Connect(ctx, &cred) + return c, err } -func (bq *BigQuery) dedupEnabled() bool { - return bq.config.isDedupEnabled || bq.config.isUsersTableDedupEnabled +// shouldMerge returns true if: +// * the server config says we allow merging +// * the user opted in to merging +func (bq *BigQuery) shouldMerge() bool { + return bq.config.allowMerge && bq.warehouse.GetBoolDestinationConfig(model.EnableMergeSetting) } func (bq *BigQuery) CrashRecover(ctx context.Context) { - if !bq.dedupEnabled() { - return - } - bq.dropDanglingStagingTables(ctx) -} - -func (bq *BigQuery) dropDanglingStagingTables(ctx context.Context) bool { sqlStatement := fmt.Sprintf(` SELECT table_name @@ -876,8 +878,7 @@ func (bq *BigQuery) dropDanglingStagingTables(ctx context.Context) bool { %[1]s.INFORMATION_SCHEMA.TABLES WHERE table_schema = '%[1]s' - AND table_name LIKE '%[2]s'; - `, + AND table_name LIKE '%[2]s';`, bq.namespace, fmt.Sprintf(`%s%%`, warehouseutils.StagingTablePrefix(provider)), ) @@ -885,7 +886,7 @@ func (bq *BigQuery) dropDanglingStagingTables(ctx context.Context) bool { it, err := bq.getMiddleware().Read(ctx, query) if err != nil { bq.logger.Errorf("WH: BQ: Error dropping dangling staging tables in BQ: %v\nQuery: %s\n", err, sqlStatement) - return false + return } var stagingTableNames []string @@ -897,22 +898,20 @@ func (bq *BigQuery) dropDanglingStagingTables(ctx context.Context) bool { break } bq.logger.Errorf("BQ: Error in processing fetched staging tables from information schema in dataset %v : %v", bq.namespace, err) - return false + return } if _, ok := values[0].(string); ok { stagingTableNames = append(stagingTableNames, values[0].(string)) } } bq.logger.Infof("WH: PG: Dropping dangling staging tables: %+v %+v\n", len(stagingTableNames), stagingTableNames) - delSuccess := true for _, stagingTableName := range stagingTableNames { err := bq.DeleteTable(ctx, stagingTableName) if err != nil { bq.logger.Errorf("WH: BQ: Error dropping dangling staging table: %s in BQ: %v", stagingTableName, err) - delSuccess = false + return } } - return delSuccess } func (bq *BigQuery) IsEmpty( @@ -1040,7 +1039,8 @@ func (bq *BigQuery) FetchSchema(ctx context.Context) (model.Schema, model.Schema it, err := bq.getMiddleware().Read(ctx, query) if err != nil { - if e, ok := err.(*googleapi.Error); ok && e.Code == 404 { + var e *googleapi.Error + if errors.As(err, &e) && e.Code == 404 { // if dataset resource is not found, return empty schema return schema, unrecognizedSchema, nil } @@ -1109,7 +1109,8 @@ func (bq *BigQuery) tableExists(ctx context.Context, tableName string) (exists b if err == nil { return true, nil } - if e, ok := err.(*googleapi.Error); ok { + var e *googleapi.Error + if errors.As(err, &e) { if e.Code == 404 { return false, nil } diff --git a/warehouse/integrations/bigquery/bigquery_test.go b/warehouse/integrations/bigquery/bigquery_test.go index fb384c9de7b..9508076d63d 100644 --- a/warehouse/integrations/bigquery/bigquery_test.go +++ b/warehouse/integrations/bigquery/bigquery_test.go @@ -5,39 +5,36 @@ import ( "encoding/json" "fmt" "os" + "slices" "strconv" "strings" "testing" "time" - "golang.org/x/exp/slices" - - "github.com/golang/mock/gomock" - - "github.com/rudderlabs/rudder-go-kit/config" - "github.com/rudderlabs/rudder-go-kit/filemanager" - 
"github.com/rudderlabs/rudder-go-kit/logger" - mockuploader "github.com/rudderlabs/rudder-server/warehouse/internal/mocks/utils" - "github.com/rudderlabs/rudder-server/warehouse/internal/model" - "cloud.google.com/go/bigquery" - "google.golang.org/api/option" - + "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/api/option" "github.com/rudderlabs/compose-test/compose" "github.com/rudderlabs/compose-test/testcompose" + "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/filemanager" + "github.com/rudderlabs/rudder-go-kit/logger" kithelper "github.com/rudderlabs/rudder-go-kit/testhelper" backendconfig "github.com/rudderlabs/rudder-server/backend-config" "github.com/rudderlabs/rudder-server/runner" + th "github.com/rudderlabs/rudder-server/testhelper" "github.com/rudderlabs/rudder-server/testhelper/health" "github.com/rudderlabs/rudder-server/testhelper/workspaceConfig" "github.com/rudderlabs/rudder-server/utils/misc" "github.com/rudderlabs/rudder-server/warehouse/client" whbigquery "github.com/rudderlabs/rudder-server/warehouse/integrations/bigquery" bqHelper "github.com/rudderlabs/rudder-server/warehouse/integrations/bigquery/testhelper" - "github.com/rudderlabs/rudder-server/warehouse/integrations/testhelper" + whth "github.com/rudderlabs/rudder-server/warehouse/integrations/testhelper" + mockuploader "github.com/rudderlabs/rudder-server/warehouse/internal/mocks/utils" + "github.com/rudderlabs/rudder-server/warehouse/internal/model" warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" "github.com/rudderlabs/rudder-server/warehouse/validations" ) @@ -69,11 +66,9 @@ func TestIntegration(t *testing.T) { sourcesSourceID := warehouseutils.RandHex() sourcesDestinationID := warehouseutils.RandHex() sourcesWriteKey := warehouseutils.RandHex() - destType := warehouseutils.BQ - - namespace := testhelper.RandSchema(destType) - sourcesNamespace := testhelper.RandSchema(destType) + namespace := whth.RandSchema(destType) + sourcesNamespace := whth.RandSchema(destType) bqTestCredentials, err := bqHelper.GetBQTestCredentials() require.NoError(t, err) @@ -83,71 +78,63 @@ func TestIntegration(t *testing.T) { escapedCredentialsTrimmedStr := strings.Trim(string(escapedCredentials), `"`) - templateConfigurations := map[string]any{ - "workspaceID": workspaceID, - "sourceID": sourceID, - "destinationID": destinationID, - "writeKey": writeKey, - "sourcesSourceID": sourcesSourceID, - "sourcesDestinationID": sourcesDestinationID, - "sourcesWriteKey": sourcesWriteKey, - "namespace": namespace, - "project": bqTestCredentials.ProjectID, - "location": bqTestCredentials.Location, - "bucketName": bqTestCredentials.BucketName, - "credentials": escapedCredentialsTrimmedStr, - "sourcesNamespace": sourcesNamespace, - } - workspaceConfigPath := workspaceConfig.CreateTempFile(t, "testdata/template.json", templateConfigurations) - - testhelper.EnhanceWithDefaultEnvs(t) - t.Setenv("JOBS_DB_PORT", strconv.Itoa(jobsDBPort)) - t.Setenv("WAREHOUSE_JOBS_DB_PORT", strconv.Itoa(jobsDBPort)) - t.Setenv("RSERVER_WAREHOUSE_BIGQUERY_MAX_PARALLEL_LOADS", "8") - t.Setenv("RSERVER_WAREHOUSE_BIGQUERY_ENABLE_DELETE_BY_JOBS", "true") - t.Setenv("RSERVER_WAREHOUSE_WEB_PORT", strconv.Itoa(httpPort)) - t.Setenv("RSERVER_BACKEND_CONFIG_CONFIG_JSONPATH", workspaceConfigPath) - t.Setenv("RSERVER_WAREHOUSE_BIGQUERY_SLOW_QUERY_THRESHOLD", "0s") - - svcDone := make(chan struct{}) - - ctx, cancel := 
context.WithCancel(context.Background()) - defer cancel() - - go func() { - r := runner.New(runner.ReleaseInfo{}) - _ = r.Run(ctx, []string{"bigquery-integration-test"}) - - close(svcDone) - }() - t.Cleanup(func() { <-svcDone }) + bootstrapSvc := func(t *testing.T, enableMerge bool) *bigquery.Client { + templateConfigurations := map[string]any{ + "workspaceID": workspaceID, + "sourceID": sourceID, + "destinationID": destinationID, + "writeKey": writeKey, + "sourcesSourceID": sourcesSourceID, + "sourcesDestinationID": sourcesDestinationID, + "sourcesWriteKey": sourcesWriteKey, + "namespace": namespace, + "project": bqTestCredentials.ProjectID, + "location": bqTestCredentials.Location, + "bucketName": bqTestCredentials.BucketName, + "credentials": escapedCredentialsTrimmedStr, + "sourcesNamespace": sourcesNamespace, + "enableMerge": enableMerge, + } + workspaceConfigPath := workspaceConfig.CreateTempFile(t, "testdata/template.json", templateConfigurations) + + whth.EnhanceWithDefaultEnvs(t) + t.Setenv("JOBS_DB_PORT", strconv.Itoa(jobsDBPort)) + t.Setenv("WAREHOUSE_JOBS_DB_PORT", strconv.Itoa(jobsDBPort)) + t.Setenv("RSERVER_WAREHOUSE_BIGQUERY_MAX_PARALLEL_LOADS", "8") + t.Setenv("RSERVER_WAREHOUSE_BIGQUERY_ENABLE_DELETE_BY_JOBS", "true") + t.Setenv("RSERVER_WAREHOUSE_WEB_PORT", strconv.Itoa(httpPort)) + t.Setenv("RSERVER_BACKEND_CONFIG_CONFIG_JSONPATH", workspaceConfigPath) + t.Setenv("RSERVER_WAREHOUSE_BIGQUERY_SLOW_QUERY_THRESHOLD", "0s") + + svcDone := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + r := runner.New(runner.ReleaseInfo{}) + _ = r.Run(ctx, []string{"bigquery-integration-test"}) + close(svcDone) + }() + + t.Cleanup(func() { <-svcDone }) + t.Cleanup(cancel) + + serviceHealthEndpoint := fmt.Sprintf("http://localhost:%d/health", httpPort) + health.WaitUntilReady(ctx, t, + serviceHealthEndpoint, time.Minute, time.Second, "serviceHealthEndpoint", + ) - serviceHealthEndpoint := fmt.Sprintf("http://localhost:%d/health", httpPort) - health.WaitUntilReady(ctx, t, serviceHealthEndpoint, time.Minute, time.Second, "serviceHealthEndpoint") + db, err := bigquery.NewClient(ctx, + bqTestCredentials.ProjectID, + option.WithCredentialsJSON([]byte(bqTestCredentials.Credentials)), + ) + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) - db, err := bigquery.NewClient( - ctx, - bqTestCredentials.ProjectID, option.WithCredentialsJSON([]byte(bqTestCredentials.Credentials)), - ) - require.NoError(t, err) + return db + } t.Run("Event flow", func(t *testing.T) { - jobsDB := testhelper.JobsDB(t, jobsDBPort) - - t.Cleanup(func() { - for _, dataset := range []string{namespace, sourcesNamespace} { - require.Eventually(t, func() bool { - if err := db.Dataset(dataset).DeleteWithContents(ctx); err != nil { - t.Logf("error deleting dataset: %v", err) - return false - } - return true - }, - time.Minute, - time.Second, - ) - } - }) + jobsDB := whth.JobsDB(t, jobsDBPort) testcase := []struct { name string @@ -156,15 +143,15 @@ func TestIntegration(t *testing.T) { sourceID string destinationID string tables []string - stagingFilesEventsMap testhelper.EventsCountMap - stagingFilesModifiedEventsMap testhelper.EventsCountMap - loadFilesEventsMap testhelper.EventsCountMap - tableUploadsEventsMap testhelper.EventsCountMap - warehouseEventsMap testhelper.EventsCountMap + stagingFilesEventsMap whth.EventsCountMap + stagingFilesModifiedEventsMap whth.EventsCountMap + loadFilesEventsMap whth.EventsCountMap + tableUploadsEventsMap whth.EventsCountMap + 
warehouseEventsMap whth.EventsCountMap asyncJob bool skipModifiedEvents bool - prerequisite func(t testing.TB) - isDedupEnabled bool + prerequisite func(context.Context, testing.TB, *bigquery.Client) + enableMerge bool customPartitionsEnabledWorkspaceIDs string stagingFilePrefix string }{ @@ -182,10 +169,9 @@ func TestIntegration(t *testing.T) { loadFilesEventsMap: loadFilesEventsMap(), tableUploadsEventsMap: tableUploadsEventsMap(), warehouseEventsMap: mergeEventsMap(), - isDedupEnabled: true, - prerequisite: func(t testing.TB) { + enableMerge: true, + prerequisite: func(ctx context.Context, t testing.TB, db *bigquery.Client) { t.Helper() - _ = db.Dataset(namespace).DeleteWithContents(ctx) }, stagingFilePrefix: "testdata/upload-job-merge-mode", @@ -197,20 +183,19 @@ func TestIntegration(t *testing.T) { destinationID: sourcesDestinationID, schema: sourcesNamespace, tables: []string{"tracks", "google_sheet"}, - stagingFilesEventsMap: testhelper.EventsCountMap{ + stagingFilesEventsMap: whth.EventsCountMap{ "wh_staging_files": 9, // 8 + 1 (merge events because of ID resolution) }, - stagingFilesModifiedEventsMap: testhelper.EventsCountMap{ + stagingFilesModifiedEventsMap: whth.EventsCountMap{ "wh_staging_files": 8, // 8 (de-duped by encounteredMergeRuleMap) }, - loadFilesEventsMap: testhelper.SourcesLoadFilesEventsMap(), - tableUploadsEventsMap: testhelper.SourcesTableUploadsEventsMap(), - warehouseEventsMap: testhelper.SourcesWarehouseEventsMap(), + loadFilesEventsMap: whth.SourcesLoadFilesEventsMap(), + tableUploadsEventsMap: whth.SourcesTableUploadsEventsMap(), + warehouseEventsMap: whth.SourcesWarehouseEventsMap(), asyncJob: true, - isDedupEnabled: false, - prerequisite: func(t testing.TB) { + enableMerge: false, + prerequisite: func(ctx context.Context, t testing.TB, db *bigquery.Client) { t.Helper() - _ = db.Dataset(namespace).DeleteWithContents(ctx) }, stagingFilePrefix: "testdata/sources-job", @@ -230,10 +215,9 @@ func TestIntegration(t *testing.T) { tableUploadsEventsMap: tableUploadsEventsMap(), warehouseEventsMap: appendEventsMap(), skipModifiedEvents: true, - isDedupEnabled: false, - prerequisite: func(t testing.TB) { + enableMerge: false, + prerequisite: func(ctx context.Context, t testing.TB, db *bigquery.Client) { t.Helper() - _ = db.Dataset(namespace).DeleteWithContents(ctx) }, stagingFilePrefix: "testdata/upload-job-append-mode", @@ -253,9 +237,9 @@ func TestIntegration(t *testing.T) { tableUploadsEventsMap: tableUploadsEventsMap(), warehouseEventsMap: appendEventsMap(), skipModifiedEvents: true, - isDedupEnabled: false, + enableMerge: false, customPartitionsEnabledWorkspaceIDs: workspaceID, - prerequisite: func(t testing.TB) { + prerequisite: func(ctx context.Context, t testing.TB, db *bigquery.Client) { t.Helper() _ = db.Dataset(namespace).DeleteWithContents(ctx) @@ -285,13 +269,33 @@ func TestIntegration(t *testing.T) { for _, tc := range testcase { tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Setenv("RSERVER_WAREHOUSE_BIGQUERY_IS_DEDUP_ENABLED", strconv.FormatBool(tc.isDedupEnabled)) - t.Setenv("RSERVER_WAREHOUSE_BIGQUERY_CUSTOM_PARTITIONS_ENABLED_WORKSPACE_IDS", tc.customPartitionsEnabledWorkspaceIDs) + t.Setenv( + "RSERVER_WAREHOUSE_BIGQUERY_CUSTOM_PARTITIONS_ENABLED_WORKSPACE_IDS", + tc.customPartitionsEnabledWorkspaceIDs, + ) + db := bootstrapSvc(t, tc.enableMerge) + + t.Cleanup(func() { + for _, dataset := range []string{tc.schema} { + t.Logf("Cleaning up dataset %s.%s", tc.schema, dataset) + require.Eventually(t, + func() bool { + err := 
db.Dataset(dataset).DeleteWithContents(context.Background()) + if err != nil { + t.Logf("Error deleting dataset %s.%s: %v", tc.schema, dataset, err) + return false + } + return true + }, + time.Minute, + time.Second, + ) + } + }) if tc.prerequisite != nil { - tc.prerequisite(t) + tc.prerequisite(context.Background(), t, db) } sqlClient := &client.Client{ @@ -305,7 +309,7 @@ func TestIntegration(t *testing.T) { } t.Log("verifying test case 1") - ts1 := testhelper.TestConfig{ + ts1 := whth.TestConfig{ WriteKey: tc.writeKey, Schema: tc.schema, Tables: tc.tables, @@ -324,7 +328,7 @@ func TestIntegration(t *testing.T) { JobRunID: misc.FastUUID().String(), TaskRunID: misc.FastUUID().String(), StagingFilePath: tc.stagingFilePrefix + ".staging-1.json", - UserID: testhelper.GetUserId(destType), + UserID: whth.GetUserId(destType), } ts1.VerifyEvents(t) @@ -333,7 +337,7 @@ func TestIntegration(t *testing.T) { } t.Log("verifying test case 2") - ts2 := testhelper.TestConfig{ + ts2 := whth.TestConfig{ WriteKey: tc.writeKey, Schema: tc.schema, Tables: tc.tables, @@ -353,7 +357,7 @@ func TestIntegration(t *testing.T) { JobRunID: misc.FastUUID().String(), TaskRunID: misc.FastUUID().String(), StagingFilePath: tc.stagingFilePrefix + ".staging-2.json", - UserID: testhelper.GetUserId(destType), + UserID: whth.GetUserId(destType), } if tc.asyncJob { ts2.UserID = ts1.UserID @@ -364,14 +368,22 @@ func TestIntegration(t *testing.T) { }) t.Run("Validations", func(t *testing.T) { + ctx := context.Background() + db, err := bigquery.NewClient(ctx, + bqTestCredentials.ProjectID, + option.WithCredentialsJSON([]byte(bqTestCredentials.Credentials)), + ) + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) t.Cleanup(func() { - require.Eventually(t, func() bool { - if err := db.Dataset(namespace).DeleteWithContents(ctx); err != nil { - t.Logf("error deleting dataset: %v", err) - return false - } - return true - }, + require.Eventually(t, + func() bool { + if err := db.Dataset(namespace).DeleteWithContents(ctx); err != nil { + t.Logf("error deleting dataset: %v", err) + return false + } + return true + }, time.Minute, time.Second, ) @@ -397,7 +409,7 @@ func TestIntegration(t *testing.T) { Enabled: true, RevisionID: destinationID, } - testhelper.VerifyConfigurationTest(t, dest) + whth.VerifyConfigurationTest(t, dest) }) t.Run("Load Table", func(t *testing.T) { @@ -407,8 +419,15 @@ func TestIntegration(t *testing.T) { workspaceID = "test_workspace_id" ) - namespace := testhelper.RandSchema(destType) + namespace := whth.RandSchema(destType) + ctx := context.Background() + db, err := bigquery.NewClient(ctx, + bqTestCredentials.ProjectID, + option.WithCredentialsJSON([]byte(bqTestCredentials.Credentials)), + ) + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) t.Cleanup(func() { require.Eventually(t, func() bool { if err := db.Dataset(namespace).DeleteWithContents(ctx); err != nil { @@ -481,15 +500,15 @@ func TestIntegration(t *testing.T) { }) require.NoError(t, err) - t.Run("schema does not exists", func(t *testing.T) { + t.Run("schema does not exist", func(t *testing.T) { tableName := "schema_not_exists_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.json.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.json.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) - bq := whbigquery.New(config.Default, logger.NOP) + 
bq := whbigquery.New(config.New(), logger.NOP) err := bq.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -497,15 +516,15 @@ func TestIntegration(t *testing.T) { require.Error(t, err) require.Nil(t, loadTableStat) }) - t.Run("table does not exists", func(t *testing.T) { + t.Run("table does not exist", func(t *testing.T) { tableName := "table_not_exists_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.json.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.json.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) - bq := whbigquery.New(config.Default, logger.NOP) + bq := whbigquery.New(config.New(), logger.NOP) err := bq.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -516,40 +535,39 @@ func TestIntegration(t *testing.T) { require.Error(t, err) require.Nil(t, loadTableStat) }) - t.Run("merge", func(t *testing.T) { - tableName := "merge_test_table" + t.Run("merge with dedup", func(t *testing.T) { + tableName := "merge_with_dedup_test_table" + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.json.gz", tableName) - c := config.New() - c.Set("Warehouse.bigquery.isDedupEnabled", true) - - t.Run("without dedup", func(t *testing.T) { - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.json.gz", tableName) + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) - loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} - mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + dedupWarehouse := th.Clone(t, warehouse) + dedupWarehouse.Destination.Config[string(model.EnableMergeSetting)] = true - bq := whbigquery.New(c, logger.NOP) - err := bq.Setup(ctx, warehouse, mockUploader) - require.NoError(t, err) + c := config.New() + bq := whbigquery.New(c, logger.NOP) + err := bq.Setup(ctx, dedupWarehouse, mockUploader) + require.NoError(t, err) - err = bq.CreateSchema(ctx) - require.NoError(t, err) + err = bq.CreateSchema(ctx) + require.NoError(t, err) - err = bq.CreateTable(ctx, tableName, schemaInWarehouse) - require.NoError(t, err) + err = bq.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) - loadTableStat, err := bq.LoadTable(ctx, tableName) - require.NoError(t, err) - require.Equal(t, loadTableStat.RowsInserted, int64(14)) - require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + loadTableStat, err := bq.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - loadTableStat, err = bq.LoadTable(ctx, tableName) - require.NoError(t, err) - require.Equal(t, loadTableStat.RowsInserted, int64(0)) - require.Equal(t, loadTableStat.RowsUpdated, int64(14)) + loadTableStat, err = bq.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(0)) + require.Equal(t, loadTableStat.RowsUpdated, int64(14)) - records := bqHelper.RetrieveRecordsFromWarehouse(t, db, - fmt.Sprintf(` + records := bqHelper.RetrieveRecordsFromWarehouse(t, db, + fmt.Sprintf(` SELECT id, received_at, @@ -558,67 +576,69 @@ func TestIntegration(t *testing.T) { test_float, test_int, test_string - FROM - %s - ORDER BY - id; - `, - fmt.Sprintf("`%s`.`%s`", 
namespace, tableName), - ), - ) - require.Equal(t, records, testhelper.SampleTestRecords()) - }) - t.Run("with dedup", func(t *testing.T) { - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/dedup.json.gz", tableName) + FROM %s + ORDER BY id;`, + fmt.Sprintf("`%s`.`%s`", namespace, tableName), + ), + ) + require.Equal(t, records, whth.SampleTestRecords()) + }) + t.Run("merge without dedup", func(t *testing.T) { + tableName := "merge_without_dedup_test_table" + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/dedup.json.gz", tableName) - loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} - mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) - bq := whbigquery.New(c, logger.NOP) - err := bq.Setup(ctx, warehouse, mockUploader) - require.NoError(t, err) + c := config.New() + bq := whbigquery.New(c, logger.NOP) + err := bq.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) - err = bq.CreateSchema(ctx) - require.NoError(t, err) + err = bq.CreateSchema(ctx) + require.NoError(t, err) - err = bq.CreateTable(ctx, tableName, schemaInWarehouse) - require.NoError(t, err) + err = bq.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) - loadTableStat, err := bq.LoadTable(ctx, tableName) - require.NoError(t, err) - require.Equal(t, loadTableStat.RowsInserted, int64(0)) - require.Equal(t, loadTableStat.RowsUpdated, int64(14)) + loadTableStat, err := bq.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := bqHelper.RetrieveRecordsFromWarehouse(t, db, - fmt.Sprintf(` - SELECT - id, - received_at, - test_bool, - test_datetime, - test_float, - test_int, - test_string - FROM - %s - ORDER BY - id; - `, - fmt.Sprintf("`%s`.`%s`", namespace, tableName), - ), - ) - require.Equal(t, records, testhelper.DedupTestRecords()) - }) + retrieveRecordsSQL := fmt.Sprintf( + `SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM %s + ORDER BY id;`, + fmt.Sprintf("`%s`.`%s`", namespace, tableName), + ) + records := bqHelper.RetrieveRecordsFromWarehouse(t, db, retrieveRecordsSQL) + require.Equal(t, records, whth.DedupTestRecords()) + + loadTableStat, err = bq.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + records = bqHelper.RetrieveRecordsFromWarehouse(t, db, retrieveRecordsSQL) + require.Len(t, records, 28) }) t.Run("append", func(t *testing.T) { tableName := "append_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.json.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.json.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) - bq := whbigquery.New(config.Default, logger.NOP) + bq := whbigquery.New(config.New(), logger.NOP) err := bq.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -648,20 +668,16 @@ func TestIntegration(t *testing.T) { test_float, test_int, test_string - FROM - %s.%s - WHERE - _PARTITIONTIME BETWEEN TIMESTAMP('%s') AND TIMESTAMP('%s') - 
ORDER BY - id; - `, + FROM %s.%s + WHERE _PARTITIONTIME BETWEEN TIMESTAMP('%s') AND TIMESTAMP('%s') + ORDER BY id;`, namespace, tableName, time.Now().Add(-24*time.Hour).Format("2006-01-02"), time.Now().Add(+24*time.Hour).Format("2006-01-02"), ), ) - require.Equal(t, records, testhelper.AppendTestRecords()) + require.Equal(t, records, whth.AppendTestRecords()) }) t.Run("load file does not exists", func(t *testing.T) { tableName := "load_file_not_exists_test_table" @@ -671,7 +687,7 @@ func TestIntegration(t *testing.T) { }} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) - bq := whbigquery.New(config.Default, logger.NOP) + bq := whbigquery.New(config.New(), logger.NOP) err := bq.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -688,12 +704,12 @@ func TestIntegration(t *testing.T) { t.Run("mismatch in number of columns", func(t *testing.T) { tableName := "mismatch_columns_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/mismatch-columns.json.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/mismatch-columns.json.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) - bq := whbigquery.New(config.Default, logger.NOP) + bq := whbigquery.New(config.New(), logger.NOP) err := bq.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -710,12 +726,12 @@ func TestIntegration(t *testing.T) { t.Run("mismatch in schema", func(t *testing.T) { tableName := "mismatch_schema_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/mismatch-schema.json.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/mismatch-schema.json.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) - bq := whbigquery.New(config.Default, logger.NOP) + bq := whbigquery.New(config.New(), logger.NOP) err := bq.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -732,12 +748,12 @@ func TestIntegration(t *testing.T) { t.Run("discards", func(t *testing.T) { tableName := warehouseutils.DiscardsTable - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/discards.json.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/discards.json.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, warehouseutils.DiscardsSchema, warehouseutils.DiscardsSchema) - bq := whbigquery.New(config.Default, logger.NOP) + bq := whbigquery.New(config.New(), logger.NOP) err := bq.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -753,27 +769,25 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsUpdated, int64(0)) records := bqHelper.RetrieveRecordsFromWarehouse(t, db, - fmt.Sprintf(` - SELECT - column_name, - column_value, - received_at, - row_id, - table_name, - uuid_ts - FROM - %s - ORDER BY row_id ASC; - `, + fmt.Sprintf( + `SELECT + column_name, + column_value, + received_at, + row_id, + table_name, + uuid_ts + FROM %s + ORDER BY row_id ASC;`, fmt.Sprintf("`%s`.`%s`", namespace, tableName), ), ) - require.Equal(t, records, testhelper.DiscardTestRecords()) + require.Equal(t, records, whth.DiscardTestRecords()) }) t.Run("custom partition", func(t *testing.T) { tableName := 
"partition_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.json.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.json.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader( @@ -801,50 +815,52 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsUpdated, int64(0)) records := bqHelper.RetrieveRecordsFromWarehouse(t, db, - fmt.Sprintf(` - SELECT - id, - received_at, - test_bool, - test_datetime, - test_float, - test_int, - test_string - FROM - %s.%s - WHERE - _PARTITIONTIME BETWEEN TIMESTAMP('%s') AND TIMESTAMP('%s') - ORDER BY - id; - `, + fmt.Sprintf( + `SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM %s.%s + WHERE _PARTITIONTIME BETWEEN TIMESTAMP('%s') AND TIMESTAMP('%s') + ORDER BY id;`, namespace, tableName, time.Now().Add(-24*time.Hour).Format("2006-01-02"), time.Now().Add(+24*time.Hour).Format("2006-01-02"), ), ) - require.Equal(t, records, testhelper.SampleTestRecords()) + require.Equal(t, records, whth.SampleTestRecords()) }) }) t.Run("IsEmpty", func(t *testing.T) { - namespace := testhelper.RandSchema(warehouseutils.BQ) + ctx := context.Background() + db, err := bigquery.NewClient(ctx, + bqTestCredentials.ProjectID, + option.WithCredentialsJSON([]byte(bqTestCredentials.Credentials)), + ) + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + namespace := whth.RandSchema(warehouseutils.BQ) t.Cleanup(func() { - require.Eventually(t, func() bool { - if err := db.Dataset(namespace).DeleteWithContents(ctx); err != nil { - t.Logf("error deleting dataset: %v", err) - return false - } - return true - }, + require.Eventually(t, + func() bool { + if err := db.Dataset(namespace).DeleteWithContents(ctx); err != nil { + t.Logf("error deleting dataset: %v", err) + return false + } + return true + }, time.Minute, time.Second, ) }) - ctx := context.Background() - credentials, err := bqHelper.GetBQTestCredentials() require.NoError(t, err) @@ -898,7 +914,7 @@ func TestIntegration(t *testing.T) { } t.Run("tables doesn't exists", func(t *testing.T) { - bq := whbigquery.New(config.Default, logger.NOP) + bq := whbigquery.New(config.New(), logger.NOP) err := bq.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -907,7 +923,7 @@ func TestIntegration(t *testing.T) { require.True(t, isEmpty) }) t.Run("tables empty", func(t *testing.T) { - bq := whbigquery.New(config.Default, logger.NOP) + bq := whbigquery.New(config.New(), logger.NOP) err := bq.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -933,7 +949,7 @@ func TestIntegration(t *testing.T) { require.True(t, isEmpty) }) t.Run("tables not empty", func(t *testing.T) { - bq := whbigquery.New(config.Default, logger.NOP) + bq := whbigquery.New(config.New(), logger.NOP) err := bq.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -970,8 +986,8 @@ func newMockUploader( return mockUploader } -func loadFilesEventsMap() testhelper.EventsCountMap { - return testhelper.EventsCountMap{ +func loadFilesEventsMap() whth.EventsCountMap { + return whth.EventsCountMap{ "identifies": 4, "users": 4, "tracks": 4, @@ -984,8 +1000,8 @@ func loadFilesEventsMap() testhelper.EventsCountMap { } } -func tableUploadsEventsMap() testhelper.EventsCountMap { - return testhelper.EventsCountMap{ +func tableUploadsEventsMap() whth.EventsCountMap { + return whth.EventsCountMap{ "identifies": 4, "users": 4, "tracks": 4, @@ -998,14 
+1014,14 @@ func tableUploadsEventsMap() testhelper.EventsCountMap { } } -func stagingFilesEventsMap() testhelper.EventsCountMap { - return testhelper.EventsCountMap{ +func stagingFilesEventsMap() whth.EventsCountMap { + return whth.EventsCountMap{ "wh_staging_files": 34, // Since extra 2 merge events because of ID resolution } } -func mergeEventsMap() testhelper.EventsCountMap { - return testhelper.EventsCountMap{ +func mergeEventsMap() whth.EventsCountMap { + return whth.EventsCountMap{ "identifies": 1, "users": 1, "tracks": 1, @@ -1018,8 +1034,8 @@ func mergeEventsMap() testhelper.EventsCountMap { } } -func appendEventsMap() testhelper.EventsCountMap { - return testhelper.EventsCountMap{ +func appendEventsMap() whth.EventsCountMap { + return whth.EventsCountMap{ "identifies": 4, "users": 1, "tracks": 4, diff --git a/warehouse/integrations/bigquery/testdata/template.json b/warehouse/integrations/bigquery/testdata/template.json index d865510ba73..c6eac2a5268 100644 --- a/warehouse/integrations/bigquery/testdata/template.json +++ b/warehouse/integrations/bigquery/testdata/template.json @@ -31,7 +31,8 @@ "credentials": "{{.credentials}}", "prefix": "", "namespace": "{{.namespace}}", - "syncFrequency": "30" + "syncFrequency": "30", + "enableMerge": {{.enableMerge}} }, "liveEventsConfig": {}, "secretConfig": {}, @@ -98,7 +99,7 @@ "id": "1dCzCUAtpWDzNxgGUYzq9sZdZZB", "name": "HTTP", "displayName": "HTTP", - "category": "", + "category": "singer-protocol", "createdAt": "2020-06-12T06:35:35.962Z", "updatedAt": "2020-06-12T06:35:35.962Z" }, @@ -159,7 +160,8 @@ "credentials": "{{.credentials}}", "prefix": "", "namespace": "{{.sourcesNamespace}}", - "syncFrequency": "30" + "syncFrequency": "30", + "enableMerge": {{.enableMerge}} }, "liveEventsConfig": {}, "secretConfig": {}, diff --git a/warehouse/integrations/clickhouse/clickhouse.go b/warehouse/integrations/clickhouse/clickhouse.go index 0b57498a997..604cabc3d5d 100644 --- a/warehouse/integrations/clickhouse/clickhouse.go +++ b/warehouse/integrations/clickhouse/clickhouse.go @@ -15,6 +15,7 @@ import ( "os" "path" "regexp" + "slices" "sort" "strconv" "strings" @@ -29,8 +30,6 @@ import ( "github.com/rudderlabs/rudder-server/warehouse/internal/model" - "golang.org/x/exp/slices" - "github.com/cenkalti/backoff/v4" "github.com/rudderlabs/rudder-go-kit/stats" diff --git a/warehouse/integrations/clickhouse/clickhouse_test.go b/warehouse/integrations/clickhouse/clickhouse_test.go index 9775e04dec9..866594186d2 100644 --- a/warehouse/integrations/clickhouse/clickhouse_test.go +++ b/warehouse/integrations/clickhouse/clickhouse_test.go @@ -12,6 +12,8 @@ import ( "testing" "time" + "github.com/rudderlabs/rudder-go-kit/stats/memstats" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" @@ -20,7 +22,6 @@ import ( "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/filemanager" "github.com/rudderlabs/rudder-go-kit/logger" - "github.com/rudderlabs/rudder-go-kit/stats" kithelper "github.com/rudderlabs/rudder-go-kit/testhelper" backendconfig "github.com/rudderlabs/rudder-server/backend-config" "github.com/rudderlabs/rudder-server/runner" @@ -342,7 +343,7 @@ func TestClickhouse_UseS3CopyEngineForLoading(t *testing.T) { c := config.New() c.Set("Warehouse.clickhouse.s3EngineEnabledWorkspaceIDs", S3EngineEnabledWorkspaceIDs) - ch := clickhouse.New(c, logger.NOP, stats.Default) + ch := clickhouse.New(c, logger.NOP, memstats.New()) ch.Warehouse = model.Warehouse{ WorkspaceID: tc.workspaceID, } @@ -423,7 +424,7 @@ func 
TestClickhouse_LoadTableRoundTrip(t *testing.T) { c.Set("Warehouse.clickhouse.s3EngineEnabledWorkspaceIDs", tc.S3EngineEnabledWorkspaceIDs) c.Set("Warehouse.clickhouse.disableNullable", tc.disableNullable) - ch := clickhouse.New(c, logger.NOP, stats.Default) + ch := clickhouse.New(c, logger.NOP, memstats.New()) warehouse := model.Warehouse{ Namespace: fmt.Sprintf("test_namespace_%d", i), @@ -669,7 +670,7 @@ func TestClickhouse_TestConnection(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { - ch := clickhouse.New(config.Default, logger.NOP, stats.Default) + ch := clickhouse.New(config.New(), logger.NOP, memstats.New()) host := "localhost" if tc.host != "" { @@ -768,7 +769,7 @@ func TestClickhouse_LoadTestTable(t *testing.T) { i := i t.Run(tc.name, func(t *testing.T) { - ch := clickhouse.New(config.Default, logger.NOP, stats.Default) + ch := clickhouse.New(config.New(), logger.NOP, memstats.New()) warehouse := model.Warehouse{ Namespace: namespace, @@ -840,7 +841,7 @@ func TestClickhouse_FetchSchema(t *testing.T) { ctx := context.Background() t.Run("Success", func(t *testing.T) { - ch := clickhouse.New(config.Default, logger.NOP, stats.Default) + ch := clickhouse.New(config.New(), logger.NOP, memstats.New()) warehouse := model.Warehouse{ Namespace: fmt.Sprintf("%s_success", namespace), @@ -885,7 +886,7 @@ func TestClickhouse_FetchSchema(t *testing.T) { }) t.Run("Invalid host", func(t *testing.T) { - ch := clickhouse.New(config.Default, logger.NOP, stats.Default) + ch := clickhouse.New(config.New(), logger.NOP, memstats.New()) warehouse := model.Warehouse{ Namespace: fmt.Sprintf("%s_invalid_host", namespace), @@ -911,7 +912,7 @@ func TestClickhouse_FetchSchema(t *testing.T) { }) t.Run("Invalid database", func(t *testing.T) { - ch := clickhouse.New(config.Default, logger.NOP, stats.Default) + ch := clickhouse.New(config.New(), logger.NOP, memstats.New()) warehouse := model.Warehouse{ Namespace: fmt.Sprintf("%s_invalid_database", namespace), @@ -937,7 +938,7 @@ func TestClickhouse_FetchSchema(t *testing.T) { }) t.Run("Empty schema", func(t *testing.T) { - ch := clickhouse.New(config.Default, logger.NOP, stats.Default) + ch := clickhouse.New(config.New(), logger.NOP, memstats.New()) warehouse := model.Warehouse{ Namespace: fmt.Sprintf("%s_empty_schema", namespace), @@ -966,7 +967,7 @@ func TestClickhouse_FetchSchema(t *testing.T) { }) t.Run("Unrecognized schema", func(t *testing.T) { - ch := clickhouse.New(config.Default, logger.NOP, stats.Default) + ch := clickhouse.New(config.New(), logger.NOP, memstats.New()) warehouse := model.Warehouse{ Namespace: fmt.Sprintf("%s_unrecognized_schema", namespace), diff --git a/warehouse/integrations/deltalake/deltalake.go b/warehouse/integrations/deltalake/deltalake.go index 04aff63b117..eb6b02e122b 100644 --- a/warehouse/integrations/deltalake/deltalake.go +++ b/warehouse/integrations/deltalake/deltalake.go @@ -6,15 +6,13 @@ import ( "errors" "fmt" "regexp" + "slices" "strconv" "strings" "time" - "github.com/rudderlabs/rudder-server/warehouse/integrations/types" - dbsql "github.com/databricks/databricks-sql-go" dbsqllog "github.com/databricks/databricks-sql-go/logger" - "golang.org/x/exp/slices" "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/logger" @@ -22,6 +20,7 @@ import ( "github.com/rudderlabs/rudder-server/utils/misc" warehouseclient "github.com/rudderlabs/rudder-server/warehouse/client" sqlmiddleware "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" + 
"github.com/rudderlabs/rudder-server/warehouse/integrations/types" "github.com/rudderlabs/rudder-server/warehouse/internal/model" "github.com/rudderlabs/rudder-server/warehouse/logfield" warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" @@ -46,9 +45,6 @@ const ( partitionNotFound = "SHOW PARTITIONS is not allowed on a table that is not partitioned" columnsAlreadyExists = "already exists in root" - mergeMode = "MERGE" - appendMode = "APPEND" - rudderStagingTableRegex = "^rudder_staging_.*$" // matches rudder_staging_* tables nonRudderStagingTableRegex = "^(?!rudder_staging_.*$).*" // matches tables that do not start with rudder_staging_ ) @@ -125,7 +121,7 @@ type Deltalake struct { stats stats.Stats config struct { - loadTableStrategy string + allowMerge bool enablePartitionPruning bool slowQueryThreshold time.Duration maxRetries int @@ -141,7 +137,7 @@ func New(conf *config.Config, log logger.Logger, stat stats.Stats) *Deltalake { dl.logger = log.Child("integration").Child("deltalake") dl.stats = stat - dl.config.loadTableStrategy = conf.GetString("Warehouse.deltalake.loadTableStrategy", mergeMode) + dl.config.allowMerge = conf.GetBool("Warehouse.deltalake.allowMerge", true) dl.config.enablePartitionPruning = conf.GetBool("Warehouse.deltalake.enablePartitionPruning", true) dl.config.slowQueryThreshold = conf.GetDuration("Warehouse.deltalake.slowQueryThreshold", 5, time.Minute) dl.config.maxRetries = conf.GetInt("Warehouse.deltalake.maxRetries", 10) @@ -244,6 +240,7 @@ func (d *Deltalake) dropDanglingStagingTables(ctx context.Context) { logfield.DestinationType, d.Warehouse.Destination.DestinationDefinition.Name, logfield.WorkspaceID, d.Warehouse.WorkspaceID, logfield.Namespace, d.Namespace, + logfield.ShouldMerge, d.ShouldMerge(), logfield.Error, err.Error(), ) return @@ -586,7 +583,7 @@ func (d *Deltalake) loadTable( logfield.WorkspaceID, d.Warehouse.WorkspaceID, logfield.Namespace, d.Namespace, logfield.TableName, tableName, - logfield.LoadTableStrategy, d.config.loadTableStrategy, + logfield.ShouldMerge, d.ShouldMerge(), ) log.Infow("started loading") @@ -617,7 +614,7 @@ func (d *Deltalake) loadTable( } var loadTableStat *types.LoadTableStats - if d.ShouldAppend() { + if !d.ShouldMerge() { log.Infow("inserting data from staging table to main table") loadTableStat, err = d.insertIntoLoadTable( ctx, tableName, stagingTableName, @@ -1110,7 +1107,7 @@ func (d *Deltalake) LoadUserTables(ctx context.Context) map[string]error { columnKeys := append([]string{`id`}, userColNames...) 
- if d.ShouldAppend() { + if !d.ShouldMerge() { query = fmt.Sprintf(` INSERT INTO %[1]s.%[2]s (%[4]s) SELECT @@ -1172,7 +1169,7 @@ func (d *Deltalake) LoadUserTables(ctx context.Context) map[string]error { inserted int64 ) - if d.ShouldAppend() { + if !d.ShouldMerge() { err = row.Scan(&affected, &inserted) } else { err = row.Scan(&affected, &updated, &deleted, &inserted) @@ -1385,9 +1382,10 @@ func (*Deltalake) DeleteBy(context.Context, []string, warehouseutils.DeleteByPar return fmt.Errorf(warehouseutils.NotImplementedErrorCode) } -// ShouldAppend returns true if: -// * the load table strategy is "append" mode -// * the uploader says we can append -func (d *Deltalake) ShouldAppend() bool { - return d.config.loadTableStrategy == appendMode && d.Uploader.CanAppend() +// ShouldMerge returns true if: +// * the uploader says we cannot append +// * the user opted in to merging and we allow merging +func (d *Deltalake) ShouldMerge() bool { + return !d.Uploader.CanAppend() || + (d.config.allowMerge && d.Warehouse.GetBoolDestinationConfig(model.EnableMergeSetting)) } diff --git a/warehouse/integrations/deltalake/deltalake_test.go b/warehouse/integrations/deltalake/deltalake_test.go index 8ab75a4c054..2fc4ea20ab8 100644 --- a/warehouse/integrations/deltalake/deltalake_test.go +++ b/warehouse/integrations/deltalake/deltalake_test.go @@ -7,16 +7,12 @@ import ( "errors" "fmt" "os" + "slices" "strconv" "strings" "testing" "time" - "golang.org/x/exp/slices" - - "github.com/rudderlabs/rudder-go-kit/filemanager" - "github.com/rudderlabs/rudder-server/warehouse/internal/model" - dbsql "github.com/databricks/databricks-sql-go" "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" @@ -24,18 +20,21 @@ import ( "github.com/rudderlabs/compose-test/compose" "github.com/rudderlabs/compose-test/testcompose" "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/filemanager" "github.com/rudderlabs/rudder-go-kit/logger" - "github.com/rudderlabs/rudder-go-kit/stats" + "github.com/rudderlabs/rudder-go-kit/stats/memstats" kithelper "github.com/rudderlabs/rudder-go-kit/testhelper" backendconfig "github.com/rudderlabs/rudder-server/backend-config" "github.com/rudderlabs/rudder-server/runner" + th "github.com/rudderlabs/rudder-server/testhelper" "github.com/rudderlabs/rudder-server/testhelper/health" "github.com/rudderlabs/rudder-server/testhelper/workspaceConfig" "github.com/rudderlabs/rudder-server/utils/misc" warehouseclient "github.com/rudderlabs/rudder-server/warehouse/client" "github.com/rudderlabs/rudder-server/warehouse/integrations/deltalake" - "github.com/rudderlabs/rudder-server/warehouse/integrations/testhelper" + whth "github.com/rudderlabs/rudder-server/warehouse/integrations/testhelper" mockuploader "github.com/rudderlabs/rudder-server/warehouse/internal/mocks/utils" + "github.com/rudderlabs/rudder-server/warehouse/internal/model" warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" "github.com/rudderlabs/rudder-server/warehouse/validations" ) @@ -95,54 +94,12 @@ func TestIntegration(t *testing.T) { sourceID := warehouseutils.RandHex() destinationID := warehouseutils.RandHex() writeKey := warehouseutils.RandHex() - destType := warehouseutils.DELTALAKE - - namespace := testhelper.RandSchema(destType) + namespace := whth.RandSchema(destType) deltaLakeCredentials, err := deltaLakeTestCredentials() require.NoError(t, err) - templateConfigurations := map[string]any{ - "workspaceID": workspaceID, - "sourceID": sourceID, - "destinationID": destinationID, 
- "writeKey": writeKey, - "host": deltaLakeCredentials.Host, - "port": deltaLakeCredentials.Port, - "path": deltaLakeCredentials.Path, - "token": deltaLakeCredentials.Token, - "namespace": namespace, - "containerName": deltaLakeCredentials.ContainerName, - "accountName": deltaLakeCredentials.AccountName, - "accountKey": deltaLakeCredentials.AccountKey, - } - workspaceConfigPath := workspaceConfig.CreateTempFile(t, "testdata/template.json", templateConfigurations) - - testhelper.EnhanceWithDefaultEnvs(t) - t.Setenv("JOBS_DB_PORT", strconv.Itoa(jobsDBPort)) - t.Setenv("WAREHOUSE_JOBS_DB_PORT", strconv.Itoa(jobsDBPort)) - t.Setenv("RSERVER_WAREHOUSE_DELTALAKE_MAX_PARALLEL_LOADS", "8") - t.Setenv("RSERVER_WAREHOUSE_WEB_PORT", strconv.Itoa(httpPort)) - t.Setenv("RSERVER_BACKEND_CONFIG_CONFIG_JSONPATH", workspaceConfigPath) - t.Setenv("RSERVER_WAREHOUSE_DELTALAKE_SLOW_QUERY_THRESHOLD", "0s") - - svcDone := make(chan struct{}) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go func() { - r := runner.New(runner.ReleaseInfo{}) - _ = r.Run(ctx, []string{"deltalake-integration-test"}) - - close(svcDone) - }() - t.Cleanup(func() { <-svcDone }) - - serviceHealthEndpoint := fmt.Sprintf("http://localhost:%d/health", httpPort) - health.WaitUntilReady(ctx, t, serviceHealthEndpoint, time.Minute, time.Second, "serviceHealthEndpoint") - port, err := strconv.Atoi(deltaLakeCredentials.Port) require.NoError(t, err) @@ -160,17 +117,62 @@ func TestIntegration(t *testing.T) { db := sql.OpenDB(connector) require.NoError(t, db.Ping()) + bootstrapSvc := func(t *testing.T, enableMerge bool) { + templateConfigurations := map[string]any{ + "workspaceID": workspaceID, + "sourceID": sourceID, + "destinationID": destinationID, + "writeKey": writeKey, + "host": deltaLakeCredentials.Host, + "port": deltaLakeCredentials.Port, + "path": deltaLakeCredentials.Path, + "token": deltaLakeCredentials.Token, + "namespace": namespace, + "containerName": deltaLakeCredentials.ContainerName, + "accountName": deltaLakeCredentials.AccountName, + "accountKey": deltaLakeCredentials.AccountKey, + "enableMerge": enableMerge, + } + workspaceConfigPath := workspaceConfig.CreateTempFile(t, "testdata/template.json", templateConfigurations) + + whth.EnhanceWithDefaultEnvs(t) + t.Setenv("JOBS_DB_PORT", strconv.Itoa(jobsDBPort)) + t.Setenv("WAREHOUSE_JOBS_DB_PORT", strconv.Itoa(jobsDBPort)) + t.Setenv("RSERVER_WAREHOUSE_DELTALAKE_MAX_PARALLEL_LOADS", "8") + t.Setenv("RSERVER_WAREHOUSE_WEB_PORT", strconv.Itoa(httpPort)) + t.Setenv("RSERVER_BACKEND_CONFIG_CONFIG_JSONPATH", workspaceConfigPath) + t.Setenv("RSERVER_WAREHOUSE_DELTALAKE_SLOW_QUERY_THRESHOLD", "0s") + + svcDone := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + r := runner.New(runner.ReleaseInfo{}) + _ = r.Run(ctx, []string{"deltalake-integration-test"}) + close(svcDone) + }() + + t.Cleanup(func() { <-svcDone }) + t.Cleanup(cancel) + + serviceHealthEndpoint := fmt.Sprintf("http://localhost:%d/health", httpPort) + health.WaitUntilReady(ctx, t, + serviceHealthEndpoint, time.Minute, time.Second, "serviceHealthEndpoint", + ) + } + t.Run("Event flow", func(t *testing.T) { - jobsDB := testhelper.JobsDB(t, jobsDBPort) + jobsDB := whth.JobsDB(t, jobsDBPort) t.Cleanup(func() { - require.Eventually(t, func() bool { - if _, err := db.Exec(fmt.Sprintf(`DROP SCHEMA %[1]s CASCADE;`, namespace)); err != nil { - t.Logf("error deleting schema: %v", err) - return false - } - return true - }, + require.Eventually(t, + func() bool 
{ + if _, err := db.Exec(fmt.Sprintf(`DROP SCHEMA %s CASCADE;`, namespace)); err != nil { + t.Logf("error deleting schema %q: %v", namespace, err) + return false + } + return true + }, time.Minute, time.Second, ) @@ -183,8 +185,8 @@ func TestIntegration(t *testing.T) { sourceID string destinationID string messageID string - warehouseEventsMap testhelper.EventsCountMap - loadTableStrategy string + warehouseEventsMap whth.EventsCountMap + enableMerge bool useParquetLoadFiles bool stagingFilePrefix string jobRunID string @@ -196,7 +198,7 @@ func TestIntegration(t *testing.T) { sourceID: sourceID, destinationID: destinationID, warehouseEventsMap: mergeEventsMap(), - loadTableStrategy: "MERGE", + enableMerge: true, useParquetLoadFiles: false, stagingFilePrefix: "testdata/upload-job-merge-mode", jobRunID: misc.FastUUID().String(), @@ -208,7 +210,7 @@ func TestIntegration(t *testing.T) { sourceID: sourceID, destinationID: destinationID, warehouseEventsMap: appendEventsMap(), - loadTableStrategy: "APPEND", + enableMerge: false, useParquetLoadFiles: false, stagingFilePrefix: "testdata/upload-job-append-mode", // an empty jobRunID means that the source is not an ETL one @@ -222,7 +224,7 @@ func TestIntegration(t *testing.T) { sourceID: sourceID, destinationID: destinationID, warehouseEventsMap: mergeEventsMap(), - loadTableStrategy: "MERGE", + enableMerge: true, useParquetLoadFiles: true, stagingFilePrefix: "testdata/upload-job-parquet", jobRunID: misc.FastUUID().String(), @@ -231,9 +233,8 @@ func TestIntegration(t *testing.T) { for _, tc := range testCases { tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Setenv("RSERVER_WAREHOUSE_DELTALAKE_LOAD_TABLE_STRATEGY", tc.loadTableStrategy) + bootstrapSvc(t, tc.enableMerge) t.Setenv("RSERVER_WAREHOUSE_DELTALAKE_USE_PARQUET_LOAD_FILES", strconv.FormatBool(tc.useParquetLoadFiles)) sqlClient := &warehouseclient.Client{ @@ -253,14 +254,14 @@ func TestIntegration(t *testing.T) { tables := []string{"identifies", "users", "tracks", "product_track", "pages", "screens", "aliases", "groups"} t.Log("verifying test case 1") - ts1 := testhelper.TestConfig{ + ts1 := whth.TestConfig{ WriteKey: writeKey, Schema: tc.schema, Tables: tables, SourceID: tc.sourceID, DestinationID: tc.destinationID, JobRunID: tc.jobRunID, - WarehouseEventsMap: testhelper.EventsCountMap{ + WarehouseEventsMap: whth.EventsCountMap{ "identifies": 1, "users": 1, "tracks": 1, @@ -277,12 +278,12 @@ func TestIntegration(t *testing.T) { HTTPPort: httpPort, Client: sqlClient, StagingFilePath: tc.stagingFilePrefix + ".staging-1.json", - UserID: testhelper.GetUserId(destType), + UserID: whth.GetUserId(destType), } ts1.VerifyEvents(t) t.Log("verifying test case 2") - ts2 := testhelper.TestConfig{ + ts2 := whth.TestConfig{ WriteKey: writeKey, Schema: tc.schema, Tables: tables, @@ -306,13 +307,14 @@ func TestIntegration(t *testing.T) { t.Run("Validation", func(t *testing.T) { t.Cleanup(func() { - require.Eventually(t, func() bool { - if _, err := db.Exec(fmt.Sprintf(`DROP SCHEMA %[1]s CASCADE;`, namespace)); err != nil { - t.Logf("error deleting schema: %v", err) - return false - } - return true - }, + require.Eventually(t, + func() bool { + if _, err := db.Exec(fmt.Sprintf(`DROP SCHEMA %s CASCADE;`, namespace)); err != nil { + t.Logf("error deleting schema %q: %v", namespace, err) + return false + } + return true + }, time.Minute, time.Second, ) @@ -372,15 +374,17 @@ func TestIntegration(t *testing.T) { for _, tc := range testCases { tc := tc - t.Run(tc.name, func(t *testing.T) { - 
t.Setenv("RSERVER_WAREHOUSE_DELTALAKE_USE_PARQUET_LOAD_FILES", strconv.FormatBool(tc.useParquetLoadFiles)) + t.Setenv( + "RSERVER_WAREHOUSE_DELTALAKE_USE_PARQUET_LOAD_FILES", + strconv.FormatBool(tc.useParquetLoadFiles), + ) for k, v := range tc.conf { dest.Config[k] = v } - testhelper.VerifyConfigurationTest(t, dest) + whth.VerifyConfigurationTest(t, dest) }) } }) @@ -392,20 +396,22 @@ func TestIntegration(t *testing.T) { workspaceID = "test_workspace_id" ) - namespace := testhelper.RandSchema(destType) - - t.Cleanup(func() { - require.Eventually(t, func() bool { - if _, err := db.Exec(fmt.Sprintf(`DROP SCHEMA %[1]s CASCADE;`, namespace)); err != nil { - t.Logf("error deleting schema: %v", err) - return false - } - return true - }, + ctx := context.Background() + namespace := whth.RandSchema(destType) + cleanupSchema := func() { + require.Eventually(t, + func() bool { + _, err := db.Exec(fmt.Sprintf(`DROP SCHEMA %s CASCADE;`, namespace)) + if err != nil { + t.Logf("error deleting schema %q: %v", namespace, err) + return false + } + return true + }, time.Minute, time.Second, ) - }) + } schemaInUpload := model.TableSchema{ "test_bool": "boolean", @@ -470,12 +476,12 @@ func TestIntegration(t *testing.T) { t.Run("schema does not exists", func(t *testing.T) { tableName := "schema_not_exists_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, false, false, "2022-12-15T06:53:49.640Z") - d := deltalake.New(config.Default, logger.NOP, stats.Default) + d := deltalake.New(config.New(), logger.NOP, memstats.New()) err := d.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -486,17 +492,18 @@ func TestIntegration(t *testing.T) { t.Run("table does not exists", func(t *testing.T) { tableName := "table_not_exists_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, false, false, "2022-12-15T06:53:49.640Z") - d := deltalake.New(config.Default, logger.NOP, stats.Default) + d := deltalake.New(config.New(), logger.NOP, memstats.New()) err := d.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) err = d.CreateSchema(ctx) require.NoError(t, err) + t.Cleanup(cleanupSchema) loadTableStat, err := d.LoadTable(ctx, tableName) require.Error(t, err) @@ -506,17 +513,18 @@ func TestIntegration(t *testing.T) { tableName := "merge_test_table" t.Run("without dedup", func(t *testing.T) { - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, false, false, "2022-12-15T06:53:49.640Z") - d := deltalake.New(config.Default, logger.NOP, stats.Default) + d := deltalake.New(config.New(), logger.NOP, memstats.New()) err := d.Setup(ctx, warehouse, 
mockUploader) require.NoError(t, err) err = d.CreateSchema(ctx) require.NoError(t, err) + t.Cleanup(cleanupSchema) err = d.CreateTable(ctx, tableName, schemaInWarehouse) require.NoError(t, err) @@ -531,7 +539,7 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(0)) require.Equal(t, loadTableStat.RowsUpdated, int64(14)) - records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, + records := whth.RetrieveRecordsFromWarehouse(t, d.DB.DB, fmt.Sprintf(` SELECT id, @@ -541,40 +549,40 @@ func TestIntegration(t *testing.T) { test_float, test_int, test_string - FROM - %s.%s - ORDER BY - id; - `, + FROM %s.%s + ORDER BY id;`, namespace, tableName, ), ) - require.Equal(t, records, testhelper.SampleTestRecords()) + require.Equal(t, records, whth.SampleTestRecords()) }) t.Run("with dedup use new record", func(t *testing.T) { - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} - mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, false, true, "2022-12-15T06:53:49.640Z") + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, true, true, "2022-12-15T06:53:49.640Z") - d := deltalake.New(config.Default, logger.NOP, stats.Default) - err := d.Setup(ctx, warehouse, mockUploader) + mergeWarehouse := th.Clone(t, warehouse) + mergeWarehouse.Destination.Config[string(model.EnableMergeSetting)] = true + + d := deltalake.New(config.New(), logger.NOP, memstats.New()) + err := d.Setup(ctx, mergeWarehouse, mockUploader) require.NoError(t, err) err = d.CreateSchema(ctx) require.NoError(t, err) + t.Cleanup(cleanupSchema) err = d.CreateTable(ctx, tableName, schemaInWarehouse) require.NoError(t, err) loadTableStat, err := d.LoadTable(ctx, tableName) require.NoError(t, err) - require.Equal(t, loadTableStat.RowsInserted, int64(0)) - require.Equal(t, loadTableStat.RowsUpdated, int64(14)) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, - fmt.Sprintf(` + retrieveRecordsSQL := fmt.Sprintf(` SELECT id, received_at, @@ -583,29 +591,38 @@ func TestIntegration(t *testing.T) { test_float, test_int, test_string - FROM - %s.%s - ORDER BY - id; - `, - namespace, - tableName, - ), + FROM %s.%s + ORDER BY id;`, + namespace, + tableName, ) - require.Equal(t, records, testhelper.DedupTestRecords()) + records := whth.RetrieveRecordsFromWarehouse(t, d.DB.DB, retrieveRecordsSQL) + require.Equal(t, records, whth.DedupTestRecords()) + + loadTableStat, err = d.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(0)) + require.Equal(t, loadTableStat.RowsUpdated, int64(14)) + + records = whth.RetrieveRecordsFromWarehouse(t, d.DB.DB, retrieveRecordsSQL) + require.Equal(t, records, whth.DedupTestRecords()) }) t.Run("with no overlapping partition", func(t *testing.T) { - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, 
schemaInWarehouse, warehouseutils.LoadFileTypeCsv, false, false, "2022-11-15T06:53:49.640Z") - d := deltalake.New(config.Default, logger.NOP, stats.Default) - err := d.Setup(ctx, warehouse, mockUploader) + mergeWarehouse := th.Clone(t, warehouse) + mergeWarehouse.Destination.Config[string(model.EnableMergeSetting)] = true + + d := deltalake.New(config.New(), logger.NOP, memstats.New()) + err := d.Setup(ctx, mergeWarehouse, mockUploader) require.NoError(t, err) err = d.CreateSchema(ctx) require.NoError(t, err) + t.Cleanup(cleanupSchema) err = d.CreateTable(ctx, tableName, schemaInWarehouse) require.NoError(t, err) @@ -615,45 +632,45 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(14)) require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, - fmt.Sprintf(` - SELECT - id, - received_at, - test_bool, - test_datetime, - test_float, - test_int, - test_string - FROM - %s.%s - ORDER BY - id; - `, + loadTableStat, err = d.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + records := whth.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf( + `SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM %s.%s + ORDER BY id;`, namespace, tableName, ), ) - require.Equal(t, records, testhelper.DedupTwiceTestRecords()) + require.Equal(t, records, whth.DedupTwiceTestRecords()) }) }) t.Run("append", func(t *testing.T) { tableName := "append_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, true, false, "2022-12-15T06:53:49.640Z") - c := config.New() - c.Set("Warehouse.deltalake.loadTableStrategy", "APPEND") - - d := deltalake.New(c, logger.NOP, stats.Default) + d := deltalake.New(config.New(), logger.NOP, memstats.New()) err := d.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) err = d.CreateSchema(ctx) require.NoError(t, err) + t.Cleanup(cleanupSchema) err = d.CreateTable(ctx, tableName, schemaInWarehouse) require.NoError(t, err) @@ -668,26 +685,23 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(14)) require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, - fmt.Sprintf(` - SELECT - id, - received_at, - test_bool, - test_datetime, - test_float, - test_int, - test_string - FROM - %s.%s - ORDER BY - id; - `, + records := whth.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf( + `SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM %s.%s + ORDER BY id;`, namespace, tableName, ), ) - require.Equal(t, records, testhelper.AppendTestRecords()) + require.Equal(t, records, whth.AppendTestRecords()) }) t.Run("load file does not exists", func(t *testing.T) { tableName := "load_file_not_exists_test_table" @@ -697,12 +711,13 @@ func TestIntegration(t *testing.T) { }} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, false, false, "2022-12-15T06:53:49.640Z") - d := 
deltalake.New(config.Default, logger.NOP, stats.Default) + d := deltalake.New(config.New(), logger.NOP, memstats.New()) err := d.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) err = d.CreateSchema(ctx) require.NoError(t, err) + t.Cleanup(cleanupSchema) err = d.CreateTable(ctx, tableName, schemaInWarehouse) require.NoError(t, err) @@ -714,17 +729,18 @@ func TestIntegration(t *testing.T) { t.Run("mismatch in number of columns", func(t *testing.T) { tableName := "mismatch_columns_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/mismatch-columns.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/mismatch-columns.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, false, false, "2022-12-15T06:53:49.640Z") - d := deltalake.New(config.Default, logger.NOP, stats.Default) + d := deltalake.New(config.New(), logger.NOP, memstats.New()) err := d.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) err = d.CreateSchema(ctx) require.NoError(t, err) + t.Cleanup(cleanupSchema) err = d.CreateTable(ctx, tableName, schemaInWarehouse) require.NoError(t, err) @@ -734,41 +750,39 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(14)) require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, - fmt.Sprintf(` - SELECT - id, - received_at, - test_bool, - test_datetime, - test_float, - test_int, - test_string - FROM - %s.%s - ORDER BY - id; - `, + records := whth.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf( + `SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM %s.%s + ORDER BY id;`, namespace, tableName, ), ) - require.Equal(t, records, testhelper.SampleTestRecords()) + require.Equal(t, records, whth.SampleTestRecords()) }) t.Run("mismatch in schema", func(t *testing.T) { tableName := "mismatch_schema_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/mismatch-schema.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/mismatch-schema.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, false, false, "2022-12-15T06:53:49.640Z") - d := deltalake.New(config.Default, logger.NOP, stats.Default) + d := deltalake.New(config.New(), logger.NOP, memstats.New()) err := d.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) err = d.CreateSchema(ctx) require.NoError(t, err) + t.Cleanup(cleanupSchema) err = d.CreateTable(ctx, tableName, schemaInWarehouse) require.NoError(t, err) @@ -778,41 +792,39 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(14)) require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, - fmt.Sprintf(` - SELECT - id, - received_at, - test_bool, - test_datetime, - test_float, - test_int, - test_string - FROM - %s.%s - ORDER BY - id; - `, + records := whth.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf( + `SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM %s.%s + ORDER BY id;`, namespace, tableName, ), ) - require.Equal(t, records, 
testhelper.MismatchSchemaTestRecords()) + require.Equal(t, records, whth.MismatchSchemaTestRecords()) }) t.Run("discards", func(t *testing.T) { tableName := warehouseutils.DiscardsTable - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/discards.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/discards.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, warehouseutils.DiscardsSchema, warehouseutils.DiscardsSchema, warehouseutils.LoadFileTypeCsv, false, false, "2022-12-15T06:53:49.640Z") - d := deltalake.New(config.Default, logger.NOP, stats.Default) + d := deltalake.New(config.New(), logger.NOP, memstats.New()) err := d.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) err = d.CreateSchema(ctx) require.NoError(t, err) + t.Cleanup(cleanupSchema) err = d.CreateTable(ctx, tableName, warehouseutils.DiscardsSchema) require.NoError(t, err) @@ -822,39 +834,38 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(6)) require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, - fmt.Sprintf(` - SELECT - column_name, - column_value, - received_at, - row_id, - table_name, - uuid_ts - FROM - %s.%s - ORDER BY row_id ASC; - `, + records := whth.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf( + `SELECT + column_name, + column_value, + received_at, + row_id, + table_name, + uuid_ts + FROM %s.%s + ORDER BY row_id ASC;`, namespace, tableName, ), ) - require.Equal(t, records, testhelper.DiscardTestRecords()) + require.Equal(t, records, whth.DiscardTestRecords()) }) t.Run("parquet", func(t *testing.T) { tableName := "parquet_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.parquet", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.parquet", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeParquet, false, false, "2022-12-15T06:53:49.640Z") - d := deltalake.New(config.Default, logger.NOP, stats.Default) + d := deltalake.New(config.New(), logger.NOP, memstats.New()) err := d.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) err = d.CreateSchema(ctx) require.NoError(t, err) + t.Cleanup(cleanupSchema) err = d.CreateTable(ctx, tableName, schemaInWarehouse) require.NoError(t, err) @@ -864,62 +875,59 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(14)) require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, - fmt.Sprintf(` - SELECT - id, - received_at, - test_bool, - test_datetime, - test_float, - test_int, - test_string - FROM - %s.%s - ORDER BY - id; - `, + records := whth.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf( + `SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM %s.%s + ORDER BY id;`, namespace, tableName, ), ) - require.Equal(t, records, testhelper.SampleTestRecords()) + require.Equal(t, records, whth.SampleTestRecords()) }) t.Run("partition pruning", func(t *testing.T) { t.Run("not partitioned", func(t *testing.T) { tableName := "not_partitioned_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + uploadOutput := 
whth.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, false, false, "2022-12-15T06:53:49.640Z") - d := deltalake.New(config.Default, logger.NOP, stats.Default) + d := deltalake.New(config.New(), logger.NOP, memstats.New()) err := d.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) err = d.CreateSchema(ctx) require.NoError(t, err) - - _, err = d.DB.QueryContext(ctx, ` - CREATE TABLE IF NOT EXISTS `+namespace+`.`+tableName+` ( - extra_test_bool BOOLEAN, - extra_test_datetime TIMESTAMP, - extra_test_float DOUBLE, - extra_test_int BIGINT, - extra_test_string STRING, - id STRING, - received_at TIMESTAMP, - event_date DATE GENERATED ALWAYS AS ( - CAST(received_at AS DATE) - ), - test_bool BOOLEAN, - test_datetime TIMESTAMP, - test_float DOUBLE, - test_int BIGINT, - test_string STRING - ) USING DELTA; - `) + t.Cleanup(cleanupSchema) + + _, err = d.DB.QueryContext(ctx, + `CREATE TABLE IF NOT EXISTS `+namespace+`.`+tableName+` ( + extra_test_bool BOOLEAN, + extra_test_datetime TIMESTAMP, + extra_test_float DOUBLE, + extra_test_int BIGINT, + extra_test_string STRING, + id STRING, + received_at TIMESTAMP, + event_date DATE GENERATED ALWAYS AS ( + CAST(received_at AS DATE) + ), + test_bool BOOLEAN, + test_datetime TIMESTAMP, + test_float DOUBLE, + test_int BIGINT, + test_string STRING + ) USING DELTA;`) require.NoError(t, err) loadTableStat, err := d.LoadTable(ctx, tableName) @@ -927,61 +935,58 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(14)) require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, - fmt.Sprintf(` - SELECT - id, - received_at, - test_bool, - test_datetime, - test_float, - test_int, - test_string - FROM - %s.%s - ORDER BY - id; - `, + records := whth.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf( + `SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM %s.%s + ORDER BY id;`, namespace, tableName, ), ) - require.Equal(t, records, testhelper.SampleTestRecords()) + require.Equal(t, records, whth.SampleTestRecords()) }) t.Run("event_date is not in partition", func(t *testing.T) { tableName := "not_event_date_partition_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, false, false, "2022-12-15T06:53:49.640Z") - d := deltalake.New(config.Default, logger.NOP, stats.Default) + d := deltalake.New(config.New(), logger.NOP, memstats.New()) err := d.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) err = d.CreateSchema(ctx) require.NoError(t, err) - - _, err = d.DB.QueryContext(ctx, ` - CREATE TABLE IF NOT EXISTS `+namespace+`.`+tableName+` ( - extra_test_bool BOOLEAN, - extra_test_datetime TIMESTAMP, - extra_test_float DOUBLE, - extra_test_int BIGINT, - extra_test_string STRING, - id STRING, - received_at TIMESTAMP, - event_date DATE GENERATED ALWAYS AS ( - CAST(received_at AS DATE) - ), - test_bool BOOLEAN, - test_datetime TIMESTAMP, - test_float DOUBLE, - test_int BIGINT, - 
test_string STRING - ) USING DELTA PARTITIONED BY(id); - `) + t.Cleanup(cleanupSchema) + + _, err = d.DB.QueryContext(ctx, + `CREATE TABLE IF NOT EXISTS `+namespace+`.`+tableName+` ( + extra_test_bool BOOLEAN, + extra_test_datetime TIMESTAMP, + extra_test_float DOUBLE, + extra_test_int BIGINT, + extra_test_string STRING, + id STRING, + received_at TIMESTAMP, + event_date DATE GENERATED ALWAYS AS ( + CAST(received_at AS DATE) + ), + test_bool BOOLEAN, + test_datetime TIMESTAMP, + test_float DOUBLE, + test_int BIGINT, + test_string STRING + ) USING DELTA PARTITIONED BY(id);`) require.NoError(t, err) loadTableStat, err := d.LoadTable(ctx, tableName) @@ -989,26 +994,23 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(14)) require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, - fmt.Sprintf(` - SELECT - id, - received_at, - test_bool, - test_datetime, - test_float, - test_int, - test_string - FROM - %s.%s - ORDER BY - id; - `, + records := whth.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf( + `SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM %s.%s + ORDER BY id;`, namespace, tableName, ), ) - require.Equal(t, records, testhelper.SampleTestRecords()) + require.Equal(t, records, whth.SampleTestRecords()) }) }) }) @@ -1046,63 +1048,67 @@ func TestDeltalake_TrimErrorMessage(t *testing.T) { c := config.New() c.Set("Warehouse.deltalake.maxErrorLength", len(tempError.Error())*25) - d := deltalake.New(c, logger.NOP, stats.Default) - require.Equal(t, d.TrimErrorMessage(tc.inputError), tc.expectedError) + d := deltalake.New(c, logger.NOP, memstats.New()) + require.Equal(t, tc.expectedError, d.TrimErrorMessage(tc.inputError)) }) } } -func TestDeltalake_ShouldAppend(t *testing.T) { +func TestDeltalake_ShouldMerge(t *testing.T) { testCases := []struct { name string - loadTableStrategy string + enableMerge bool uploaderCanAppend bool uploaderExpectedCalls int expected bool }{ { - name: "uploader says we can append and we are in append mode", - loadTableStrategy: "APPEND", + name: "uploader says we can append and merge is not enabled", + enableMerge: false, uploaderCanAppend: true, uploaderExpectedCalls: 1, - expected: true, + expected: false, }, { - name: "uploader says we cannot append and we are in append mode", - loadTableStrategy: "APPEND", - uploaderCanAppend: false, + name: "uploader says we can append and merge is enabled", + enableMerge: true, + uploaderCanAppend: true, uploaderExpectedCalls: 1, - expected: false, + expected: true, }, { - name: "uploader says we can append and we are in merge mode", - loadTableStrategy: "MERGE", - uploaderCanAppend: true, - uploaderExpectedCalls: 0, - expected: false, + name: "uploader says we cannot append so enableMerge false is ignored", + enableMerge: false, + uploaderCanAppend: false, + uploaderExpectedCalls: 1, + expected: true, }, { - name: "uploader says we cannot append and we are in merge mode", - loadTableStrategy: "MERGE", + name: "uploader says we cannot append so enableMerge true is ignored", + enableMerge: true, uploaderCanAppend: false, - uploaderExpectedCalls: 0, - expected: false, + uploaderExpectedCalls: 1, + expected: true, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - c := config.New() - c.Set("Warehouse.deltalake.loadTableStrategy", tc.loadTableStrategy) - - d := deltalake.New(c, logger.NOP, stats.Default) + d := deltalake.New(config.New(), 
logger.NOP, memstats.New()) + d.Warehouse = model.Warehouse{ + Destination: backendconfig.DestinationT{ + Config: map[string]any{ + string(model.EnableMergeSetting): tc.enableMerge, + }, + }, + } mockCtrl := gomock.NewController(t) uploader := mockuploader.NewMockUploader(mockCtrl) uploader.EXPECT().CanAppend().Times(tc.uploaderExpectedCalls).Return(tc.uploaderCanAppend) d.Uploader = uploader - require.Equal(t, d.ShouldAppend(), tc.expected) + require.Equal(t, d.ShouldMerge(), tc.expected) }) } } @@ -1142,8 +1148,8 @@ func newMockUploader( return mockUploader } -func mergeEventsMap() testhelper.EventsCountMap { - return testhelper.EventsCountMap{ +func mergeEventsMap() whth.EventsCountMap { + return whth.EventsCountMap{ "identifies": 1, "users": 1, "tracks": 1, @@ -1155,8 +1161,8 @@ func mergeEventsMap() testhelper.EventsCountMap { } } -func appendEventsMap() testhelper.EventsCountMap { - return testhelper.EventsCountMap{ +func appendEventsMap() whth.EventsCountMap { + return whth.EventsCountMap{ "identifies": 2, "users": 2, "tracks": 2, diff --git a/warehouse/integrations/deltalake/testdata/template.json b/warehouse/integrations/deltalake/testdata/template.json index da90ff83657..2b796aa448f 100644 --- a/warehouse/integrations/deltalake/testdata/template.json +++ b/warehouse/integrations/deltalake/testdata/template.json @@ -38,7 +38,8 @@ "accountKey": "{{.accountKey}}", "syncFrequency": "30", "eventDelivery": false, - "eventDeliveryTS": 1648195480174 + "eventDeliveryTS": 1648195480174, + "enableMerge": {{.enableMerge}} }, "liveEventsConfig": { "eventDelivery": false, diff --git a/warehouse/integrations/mssql/mssql.go b/warehouse/integrations/mssql/mssql.go index 8071c9d6a9f..7538b8646c0 100644 --- a/warehouse/integrations/mssql/mssql.go +++ b/warehouse/integrations/mssql/mssql.go @@ -294,11 +294,9 @@ func (ms *MSSQL) loadTable( // - Refer to Microsoft's documentation on temporary tables at // https://docs.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/ms175528(v=sql.105)?redirectedfrom=MSDN. 
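// Illustrative sketch, not part of this patch: the reformatted createStagingTableStmt just
// below uses `SELECT TOP 0 * INTO` to clone the target table's column layout into a fresh
// staging table without copying any rows (TOP 0 keeps the result set empty). A minimal
// stand-alone version, assuming only the "fmt" package; all names are placeholders:
func buildStagingTableStmt(namespace, stagingTable, targetTable string) string {
	return fmt.Sprintf(
		`SELECT TOP 0 * INTO %[1]s.%[2]s
		FROM %[1]s.%[3]s;`,
		namespace,    // schema that holds both tables
		stagingTable, // table created by the statement, e.g. a rudder_staging_* table
		targetTable,  // existing table whose column shape is copied
	)
}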
log.Debugw("creating staging table") - createStagingTableStmt := fmt.Sprintf(` - SELECT - TOP 0 * INTO %[1]s.%[2]s - FROM - %[1]s.%[3]s;`, + createStagingTableStmt := fmt.Sprintf( + `SELECT TOP 0 * INTO %[1]s.%[2]s + FROM %[1]s.%[3]s;`, ms.Namespace, stagingTableName, tableName, @@ -471,7 +469,7 @@ func (ms *MSSQL) loadDataIntoStagingTable( return nil } -func (as *MSSQL) ProcessColumnValue( +func (ms *MSSQL) ProcessColumnValue( value string, valueType string, ) (interface{}, error) { diff --git a/warehouse/integrations/mssql/mssql_test.go b/warehouse/integrations/mssql/mssql_test.go index df2d2702282..73702680ce4 100644 --- a/warehouse/integrations/mssql/mssql_test.go +++ b/warehouse/integrations/mssql/mssql_test.go @@ -10,12 +10,13 @@ import ( "testing" "time" + "github.com/rudderlabs/rudder-go-kit/stats/memstats" + "github.com/golang/mock/gomock" "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/filemanager" "github.com/rudderlabs/rudder-go-kit/logger" - "github.com/rudderlabs/rudder-go-kit/stats" "github.com/rudderlabs/rudder-server/warehouse/integrations/mssql" mockuploader "github.com/rudderlabs/rudder-server/warehouse/internal/mocks/utils" "github.com/rudderlabs/rudder-server/warehouse/internal/model" @@ -385,7 +386,7 @@ func TestIntegration(t *testing.T) { loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) - ms := mssql.New(config.Default, logger.NOP, stats.Default) + ms := mssql.New(config.New(), logger.NOP, memstats.New()) err := ms.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -401,7 +402,7 @@ func TestIntegration(t *testing.T) { loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) - ms := mssql.New(config.Default, logger.NOP, stats.Default) + ms := mssql.New(config.New(), logger.NOP, memstats.New()) err := ms.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -421,7 +422,7 @@ func TestIntegration(t *testing.T) { loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) - ms := mssql.New(config.Default, logger.NOP, stats.Default) + ms := mssql.New(config.New(), logger.NOP, memstats.New()) err := ms.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -468,7 +469,7 @@ func TestIntegration(t *testing.T) { loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) - ms := mssql.New(config.Default, logger.NOP, stats.Default) + ms := mssql.New(config.New(), logger.NOP, memstats.New()) err := ms.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -513,7 +514,7 @@ func TestIntegration(t *testing.T) { }} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) - ms := mssql.New(config.Default, logger.NOP, stats.Default) + ms := mssql.New(config.New(), logger.NOP, memstats.New()) err := ms.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -535,7 +536,7 @@ func TestIntegration(t *testing.T) { loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) - ms := mssql.New(config.Default, logger.NOP, stats.Default) + ms := 
mssql.New(config.New(), logger.NOP, memstats.New()) err := ms.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -557,7 +558,7 @@ func TestIntegration(t *testing.T) { loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) - ms := mssql.New(config.Default, logger.NOP, stats.Default) + ms := mssql.New(config.New(), logger.NOP, memstats.New()) err := ms.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -601,7 +602,7 @@ func TestIntegration(t *testing.T) { loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, warehouseutils.DiscardsSchema, warehouseutils.DiscardsSchema) - ms := mssql.New(config.Default, logger.NOP, stats.Default) + ms := mssql.New(config.New(), logger.NOP, memstats.New()) err := ms.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -716,7 +717,7 @@ func TestMSSQL_ProcessColumnValue(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - ms := mssql.New(config.Default, logger.NOP, stats.Default) + ms := mssql.New(config.New(), logger.NOP, memstats.New()) value, err := ms.ProcessColumnValue(tc.data, tc.dataType) if tc.wantError { diff --git a/warehouse/integrations/postgres/load.go b/warehouse/integrations/postgres/load.go index 4f9471a85cf..da76cfedff7 100644 --- a/warehouse/integrations/postgres/load.go +++ b/warehouse/integrations/postgres/load.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "os" + "slices" "strings" "github.com/rudderlabs/rudder-server/warehouse/integrations/types" @@ -19,7 +20,6 @@ import ( "github.com/rudderlabs/rudder-server/utils/misc" "github.com/lib/pq" - "golang.org/x/exp/slices" "github.com/rudderlabs/rudder-server/warehouse/logfield" warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" @@ -31,27 +31,15 @@ type loadUsersTableResponse struct { } func (pg *Postgres) LoadTable(ctx context.Context, tableName string) (*types.LoadTableStats, error) { - log := pg.logger.With( - logfield.SourceID, pg.Warehouse.Source.ID, - logfield.SourceType, pg.Warehouse.Source.SourceDefinition.Name, - logfield.DestinationID, pg.Warehouse.Destination.ID, - logfield.DestinationType, pg.Warehouse.Destination.DestinationDefinition.Name, - logfield.WorkspaceID, pg.Warehouse.WorkspaceID, - logfield.Namespace, pg.Namespace, - logfield.TableName, tableName, - logfield.LoadTableStrategy, pg.loadTableStrategy(), - ) - log.Infow("started loading") - var loadTableStats *types.LoadTableStats - var err error - - err = pg.DB.WithTx(ctx, func(tx *sqlmiddleware.Tx) error { + err := pg.DB.WithTx(ctx, func(tx *sqlmiddleware.Tx) error { + var err error loadTableStats, _, err = pg.loadTable( ctx, tx, tableName, pg.Uploader.GetTableSchemaInUpload(tableName), + false, ) return err }) @@ -59,8 +47,6 @@ func (pg *Postgres) LoadTable(ctx context.Context, tableName string) (*types.Loa return nil, fmt.Errorf("loading table: %w", err) } - log.Infow("completed loading") - return loadTableStats, err } @@ -69,6 +55,7 @@ func (pg *Postgres) loadTable( txn *sqlmiddleware.Tx, tableName string, tableSchemaInUpload model.TableSchema, + forceMerge bool, ) (*types.LoadTableStats, string, error) { log := pg.logger.With( logfield.SourceID, pg.Warehouse.Source.ID, @@ -78,9 +65,10 @@ func (pg *Postgres) loadTable( logfield.WorkspaceID, pg.Warehouse.WorkspaceID, logfield.Namespace, pg.Namespace, logfield.TableName, tableName, - logfield.LoadTableStrategy, 
pg.loadTableStrategy(), + logfield.ShouldMerge, pg.shouldMerge(), ) log.Infow("started loading") + defer log.Infow("completed loading") log.Debugw("setting search path") searchPathStmt := fmt.Sprintf(`SET search_path TO %q;`, @@ -105,10 +93,9 @@ func (pg *Postgres) loadTable( ) log.Debugw("creating staging table") - createStagingTableStmt := fmt.Sprintf(` - CREATE TEMPORARY TABLE %[2]s (LIKE %[1]q.%[3]q) - ON COMMIT PRESERVE ROWS; -`, + createStagingTableStmt := fmt.Sprintf( + `CREATE TEMPORARY TABLE %[2]s (LIKE %[1]q.%[3]q) + ON COMMIT PRESERVE ROWS;`, pg.Namespace, stagingTableName, tableName, @@ -143,7 +130,7 @@ func (pg *Postgres) loadTable( } var rowsDeleted int64 - if !slices.Contains(pg.config.skipDedupDestinationIDs, pg.Warehouse.Destination.ID) { + if forceMerge || pg.shouldMerge() { log.Infow("deleting from load table") rowsDeleted, err = pg.deleteFromLoadTable( ctx, txn, tableName, @@ -225,13 +212,6 @@ func (pg *Postgres) loadDataIntoStagingTable( return nil } -func (pg *Postgres) loadTableStrategy() string { - if slices.Contains(pg.config.skipDedupDestinationIDs, pg.Warehouse.Destination.ID) { - return "APPEND" - } - return "MERGE" -} - func (pg *Postgres) deleteFromLoadTable( ctx context.Context, txn *sqlmiddleware.Tx, @@ -387,7 +367,7 @@ func (pg *Postgres) loadUsersTable( usersSchemaInUpload, usersSchemaInWarehouse model.TableSchema, ) loadUsersTableResponse { - _, identifyStagingTable, err := pg.loadTable(ctx, tx, warehouseutils.IdentifiesTable, identifiesSchemaInUpload) + _, identifyStagingTable, err := pg.loadTable(ctx, tx, warehouseutils.IdentifiesTable, identifiesSchemaInUpload, false) if err != nil { return loadUsersTableResponse{ identifiesError: fmt.Errorf("loading identifies table: %w", err), @@ -398,9 +378,10 @@ func (pg *Postgres) loadUsersTable( return loadUsersTableResponse{} } - canSkipComputingLatestUserTraits := pg.config.skipComputingUserLatestTraits || slices.Contains(pg.config.skipComputingUserLatestTraitsWorkspaceIDs, pg.Warehouse.WorkspaceID) + canSkipComputingLatestUserTraits := pg.config.skipComputingUserLatestTraits || + slices.Contains(pg.config.skipComputingUserLatestTraitsWorkspaceIDs, pg.Warehouse.WorkspaceID) if canSkipComputingLatestUserTraits { - if _, _, err = pg.loadTable(ctx, tx, warehouseutils.UsersTable, usersSchemaInUpload); err != nil { + if _, _, err = pg.loadTable(ctx, tx, warehouseutils.UsersTable, usersSchemaInUpload, true); err != nil { return loadUsersTableResponse{ usersError: fmt.Errorf("loading users table: %w", err), } @@ -417,60 +398,40 @@ func (pg *Postgres) loadUsersTable( continue } userColNames = append(userColNames, fmt.Sprintf(`%q`, colName)) - caseSubQuery := fmt.Sprintf(` - CASE WHEN ( - SELECT - true + caseSubQuery := fmt.Sprintf( + `CASE WHEN ( + SELECT true ) THEN ( - SELECT - %[1]q - FROM - %[2]q AS staging_table - WHERE - x.id = staging_table.id AND - %[1]q IS NOT NULL - ORDER BY - received_at DESC - LIMIT - 1 - ) END AS %[1]q -`, + SELECT %[1]q + FROM %[2]q AS staging_table + WHERE x.id = staging_table.id AND %[1]q IS NOT NULL + ORDER BY received_at DESC + LIMIT 1 + ) END AS %[1]q`, colName, unionStagingTableName, ) firstValProps = append(firstValProps, caseSubQuery) } - query := fmt.Sprintf(` - CREATE TEMPORARY TABLE %[5]s AS ( - ( - SELECT - id, - %[4]s - FROM - %[1]q.%[2]q - WHERE - id IN ( - SELECT - user_id - FROM - %[3]q - WHERE - user_id IS NOT NULL - ) - ) - UNION + query := fmt.Sprintf( + `CREATE TEMPORARY TABLE %[5]s AS ( ( - SELECT - user_id, - %[4]s - FROM - %[3]q - WHERE - user_id IS NOT NULL + 
SELECT id, %[4]s + FROM %[1]q.%[2]q + WHERE id IN ( + SELECT user_id + FROM %[3]q + WHERE user_id IS NOT NULL + ) ) - ); -`, + UNION + ( + SELECT user_id, %[4]s + FROM %[3]q + WHERE user_id IS NOT NULL + ) + );`, pg.Namespace, warehouseutils.UsersTable, identifyStagingTable, @@ -534,13 +495,8 @@ func (pg *Postgres) loadUsersTable( // Delete from users table if the id is present in the staging table primaryKey := "id" query = fmt.Sprintf(` - DELETE FROM - %[1]q.%[2]q using %[3]q _source - WHERE - ( - _source.%[4]s = %[1]s.%[2]s.%[4]s - ); -`, + DELETE FROM %[1]q.%[2]q using %[3]q _source + WHERE _source.%[4]s = %[1]s.%[2]s.%[4]s;`, pg.Namespace, warehouseutils.UsersTable, usersStagingTableName, @@ -596,3 +552,10 @@ func (pg *Postgres) loadUsersTable( return loadUsersTableResponse{} } + +func (pg *Postgres) shouldMerge() bool { + return !pg.Uploader.CanAppend() || + (pg.config.allowMerge && + pg.Warehouse.GetBoolDestinationConfig(model.EnableMergeSetting) && + !slices.Contains(pg.config.skipDedupDestinationIDs, pg.Warehouse.Destination.ID)) +} diff --git a/warehouse/integrations/postgres/load_test.go b/warehouse/integrations/postgres/load_test.go index 71e22197e48..969c650b5bc 100644 --- a/warehouse/integrations/postgres/load_test.go +++ b/warehouse/integrations/postgres/load_test.go @@ -247,6 +247,7 @@ func TestLoadUsersTable(t *testing.T) { mockUploader := mockuploader.NewMockUploader(ctrl) mockUploader.EXPECT().GetTableSchemaInUpload(gomock.Any()).AnyTimes().DoAndReturn(f) mockUploader.EXPECT().GetTableSchemaInWarehouse(gomock.Any()).AnyTimes().DoAndReturn(f) + mockUploader.EXPECT().CanAppend().Return(true).AnyTimes() pg.DB = db pg.Namespace = namespace diff --git a/warehouse/integrations/postgres/postgres.go b/warehouse/integrations/postgres/postgres.go index ba8eec34476..efe3a51814d 100644 --- a/warehouse/integrations/postgres/postgres.go +++ b/warehouse/integrations/postgres/postgres.go @@ -10,6 +10,8 @@ import ( "strings" "time" + "github.com/rudderlabs/rudder-server/warehouse/integrations/tunnelling" + "github.com/rudderlabs/rudder-go-kit/stats" sqlmiddleware "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" @@ -22,7 +24,6 @@ import ( "github.com/rudderlabs/rudder-go-kit/logger" "github.com/rudderlabs/rudder-server/utils/misc" "github.com/rudderlabs/rudder-server/warehouse/client" - "github.com/rudderlabs/rudder-server/warehouse/tunnelling" warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" ) @@ -122,6 +123,7 @@ type Postgres struct { LoadFileDownloader downloader.Downloader config struct { + allowMerge bool enableDeleteByJobs bool numWorkersDownloadLoadFiles int slowQueryThreshold time.Duration @@ -162,6 +164,7 @@ func New(conf *config.Config, log logger.Logger, stat stats.Stats) *Postgres { pg.logger = log.Child("integrations").Child("postgres") pg.stats = stat + pg.config.allowMerge = conf.GetBool("Warehouse.postgres.allowMerge", true) pg.config.enableDeleteByJobs = conf.GetBool("Warehouse.postgres.enableDeleteByJobs", false) pg.config.numWorkersDownloadLoadFiles = conf.GetInt("Warehouse.postgres.numWorkersDownloadLoadFiles", 1) pg.config.slowQueryThreshold = conf.GetDuration("Warehouse.postgres.slowQueryThreshold", 5, time.Minute) @@ -223,7 +226,7 @@ func (pg *Postgres) connect() (*sqlmiddleware.DB, error) { if cred.tunnelInfo != nil { - db, err = tunnelling.SQLConnectThroughTunnel(dsn.String(), cred.tunnelInfo.Config) + db, err = tunnelling.Connect(dsn.String(), cred.tunnelInfo.Config) if err != nil { return nil, 
fmt.Errorf("opening connection to postgres through tunnelling: %w", err) } @@ -248,7 +251,7 @@ func (pg *Postgres) getConnectionCredentials() credentials { sslMode: sslMode, sslDir: warehouseutils.GetSSLKeyDirPath(pg.Warehouse.Destination.ID), timeout: pg.connectTimeout, - tunnelInfo: warehouseutils.ExtractTunnelInfoFromDestinationConfig( + tunnelInfo: tunnelling.ExtractTunnelInfoFromDestinationConfig( pg.Warehouse.Destination.Config, ), } diff --git a/warehouse/integrations/postgres/postgres_test.go b/warehouse/integrations/postgres/postgres_test.go index 9faae121191..6a741bd8f25 100644 --- a/warehouse/integrations/postgres/postgres_test.go +++ b/warehouse/integrations/postgres/postgres_test.go @@ -11,36 +11,29 @@ import ( "time" "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" + "github.com/rudderlabs/compose-test/compose" + "github.com/rudderlabs/compose-test/testcompose" "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/filemanager" "github.com/rudderlabs/rudder-go-kit/logger" - "github.com/rudderlabs/rudder-go-kit/stats" - "github.com/rudderlabs/rudder-server/warehouse/integrations/postgres" - mockuploader "github.com/rudderlabs/rudder-server/warehouse/internal/mocks/utils" - "github.com/rudderlabs/rudder-server/warehouse/internal/model" - - "github.com/rudderlabs/compose-test/compose" - - "github.com/rudderlabs/rudder-server/testhelper/workspaceConfig" - - backendconfig "github.com/rudderlabs/rudder-server/backend-config" - "github.com/rudderlabs/rudder-server/warehouse/client" - "github.com/rudderlabs/rudder-server/warehouse/tunnelling" - - "github.com/rudderlabs/compose-test/testcompose" + "github.com/rudderlabs/rudder-go-kit/stats/memstats" kithelper "github.com/rudderlabs/rudder-go-kit/testhelper" + backendconfig "github.com/rudderlabs/rudder-server/backend-config" "github.com/rudderlabs/rudder-server/runner" + th "github.com/rudderlabs/rudder-server/testhelper" "github.com/rudderlabs/rudder-server/testhelper/health" - - "github.com/rudderlabs/rudder-server/warehouse/integrations/testhelper" - - "github.com/stretchr/testify/require" - + "github.com/rudderlabs/rudder-server/testhelper/workspaceConfig" "github.com/rudderlabs/rudder-server/utils/misc" - "github.com/rudderlabs/rudder-server/warehouse/validations" - + "github.com/rudderlabs/rudder-server/warehouse/client" + "github.com/rudderlabs/rudder-server/warehouse/integrations/postgres" + whth "github.com/rudderlabs/rudder-server/warehouse/integrations/testhelper" + "github.com/rudderlabs/rudder-server/warehouse/integrations/tunnelling" + mockuploader "github.com/rudderlabs/rudder-server/warehouse/internal/mocks/utils" + "github.com/rudderlabs/rudder-server/warehouse/internal/model" warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" + "github.com/rudderlabs/rudder-server/warehouse/validations" ) func TestIntegration(t *testing.T) { @@ -81,9 +74,9 @@ func TestIntegration(t *testing.T) { destType := warehouseutils.POSTGRES - namespace := testhelper.RandSchema(destType) - sourcesNamespace := testhelper.RandSchema(destType) - tunnelledNamespace := testhelper.RandSchema(destType) + namespace := whth.RandSchema(destType) + sourcesNamespace := whth.RandSchema(destType) + tunnelledNamespace := whth.RandSchema(destType) host := "localhost" database := "rudderdb" @@ -142,7 +135,7 @@ func TestIntegration(t *testing.T) { } workspaceConfigPath := workspaceConfig.CreateTempFile(t, "testdata/template.json", templateConfigurations) - testhelper.EnhanceWithDefaultEnvs(t) + 
whth.EnhanceWithDefaultEnvs(t) t.Setenv("JOBS_DB_PORT", strconv.Itoa(jobsDBPort)) t.Setenv("WAREHOUSE_JOBS_DB_PORT", strconv.Itoa(jobsDBPort)) t.Setenv("MINIO_ACCESS_KEY_ID", accessKeyID) @@ -183,7 +176,7 @@ func TestIntegration(t *testing.T) { require.NoError(t, err) require.NoError(t, db.Ping()) - jobsDB := testhelper.JobsDB(t, jobsDBPort) + jobsDB := whth.JobsDB(t, jobsDBPort) testCases := []struct { name string @@ -192,10 +185,10 @@ func TestIntegration(t *testing.T) { sourceID string destinationID string tables []string - stagingFilesEventsMap testhelper.EventsCountMap - loadFilesEventsMap testhelper.EventsCountMap - tableUploadsEventsMap testhelper.EventsCountMap - warehouseEventsMap testhelper.EventsCountMap + stagingFilesEventsMap whth.EventsCountMap + loadFilesEventsMap whth.EventsCountMap + tableUploadsEventsMap whth.EventsCountMap + warehouseEventsMap whth.EventsCountMap asyncJob bool stagingFilePrefix string }{ @@ -217,10 +210,10 @@ func TestIntegration(t *testing.T) { tables: []string{"tracks", "google_sheet"}, sourceID: sourcesSourceID, destinationID: sourcesDestinationID, - stagingFilesEventsMap: testhelper.SourcesStagingFilesEventsMap(), - loadFilesEventsMap: testhelper.SourcesLoadFilesEventsMap(), - tableUploadsEventsMap: testhelper.SourcesTableUploadsEventsMap(), - warehouseEventsMap: testhelper.SourcesWarehouseEventsMap(), + stagingFilesEventsMap: whth.SourcesStagingFilesEventsMap(), + loadFilesEventsMap: whth.SourcesLoadFilesEventsMap(), + tableUploadsEventsMap: whth.SourcesTableUploadsEventsMap(), + warehouseEventsMap: whth.SourcesWarehouseEventsMap(), asyncJob: true, stagingFilePrefix: "testdata/sources-job", }, @@ -248,7 +241,7 @@ func TestIntegration(t *testing.T) { } t.Log("verifying test case 1") - ts1 := testhelper.TestConfig{ + ts1 := whth.TestConfig{ WriteKey: tc.writeKey, Schema: tc.schema, Tables: tc.tables, @@ -267,12 +260,12 @@ func TestIntegration(t *testing.T) { JobRunID: misc.FastUUID().String(), TaskRunID: misc.FastUUID().String(), StagingFilePath: tc.stagingFilePrefix + ".staging-1.json", - UserID: testhelper.GetUserId(destType), + UserID: whth.GetUserId(destType), } ts1.VerifyEvents(t) t.Log("verifying test case 2") - ts2 := testhelper.TestConfig{ + ts2 := whth.TestConfig{ WriteKey: tc.writeKey, Schema: tc.schema, Tables: tc.tables, @@ -292,7 +285,7 @@ func TestIntegration(t *testing.T) { JobRunID: misc.FastUUID().String(), TaskRunID: misc.FastUUID().String(), StagingFilePath: tc.stagingFilePrefix + ".staging-2.json", - UserID: testhelper.GetUserId(destType), + UserID: whth.GetUserId(destType), } if tc.asyncJob { ts2.UserID = ts1.UserID @@ -325,11 +318,11 @@ func TestIntegration(t *testing.T) { }, } - db, err := tunnelling.SQLConnectThroughTunnel(dsn, tunnelInfo.Config) + db, err := tunnelling.Connect(dsn, tunnelInfo.Config) require.NoError(t, err) require.NoError(t, db.Ping()) - jobsDB := testhelper.JobsDB(t, jobsDBPort) + jobsDB := whth.JobsDB(t, jobsDBPort) testcases := []struct { name string @@ -338,10 +331,10 @@ func TestIntegration(t *testing.T) { sourceID string destinationID string tables []string - stagingFilesEventsMap testhelper.EventsCountMap - loadFilesEventsMap testhelper.EventsCountMap - tableUploadsEventsMap testhelper.EventsCountMap - warehouseEventsMap testhelper.EventsCountMap + stagingFilesEventsMap whth.EventsCountMap + loadFilesEventsMap whth.EventsCountMap + tableUploadsEventsMap whth.EventsCountMap + warehouseEventsMap whth.EventsCountMap stagingFilePrefix string }{ { @@ -379,7 +372,7 @@ func TestIntegration(t *testing.T) { } 
t.Log("verifying test case 1") - ts1 := testhelper.TestConfig{ + ts1 := whth.TestConfig{ WriteKey: tc.writeKey, Schema: tc.schema, SourceID: tc.sourceID, @@ -398,12 +391,12 @@ func TestIntegration(t *testing.T) { JobRunID: misc.FastUUID().String(), TaskRunID: misc.FastUUID().String(), StagingFilePath: tc.stagingFilePrefix + ".staging-1.json", - UserID: testhelper.GetUserId(destType), + UserID: whth.GetUserId(destType), } ts1.VerifyEvents(t) t.Log("verifying test case 2") - ts2 := testhelper.TestConfig{ + ts2 := whth.TestConfig{ WriteKey: tc.writeKey, Schema: tc.schema, SourceID: tc.sourceID, @@ -422,7 +415,7 @@ func TestIntegration(t *testing.T) { JobRunID: misc.FastUUID().String(), TaskRunID: misc.FastUUID().String(), StagingFilePath: tc.stagingFilePrefix + ".staging-2.json", - UserID: testhelper.GetUserId(destType), + UserID: whth.GetUserId(destType), } ts2.VerifyEvents(t) }) @@ -458,7 +451,7 @@ func TestIntegration(t *testing.T) { Enabled: true, RevisionID: "29eeuu9kywWsRAybaXcxcnTVEl8", } - testhelper.VerifyConfigurationTest(t, dest) + whth.VerifyConfigurationTest(t, dest) }) t.Run("Load Table", func(t *testing.T) { @@ -468,7 +461,7 @@ func TestIntegration(t *testing.T) { workspaceID = "test_workspace_id" ) - namespace := testhelper.RandSchema(destType) + namespace := whth.RandSchema(destType) schemaInUpload := model.TableSchema{ "test_bool": "boolean", @@ -545,12 +538,12 @@ func TestIntegration(t *testing.T) { t.Run("schema does not exists", func(t *testing.T) { tableName := "schema_not_exists_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := mockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) - pg := postgres.New(config.Default, logger.NOP, stats.Default) + pg := postgres.New(config.New(), logger.NOP, memstats.New()) err := pg.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -561,12 +554,12 @@ func TestIntegration(t *testing.T) { t.Run("table does not exists", func(t *testing.T) { tableName := "table_not_exists_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := mockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) - pg := postgres.New(config.Default, logger.NOP, stats.Default) + pg := postgres.New(config.New(), logger.NOP, memstats.New()) err := pg.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -581,7 +574,7 @@ func TestIntegration(t *testing.T) { tableName := "merge_test_table" t.Run("without dedup", func(t *testing.T) { - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := mockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) @@ -589,8 +582,11 @@ func TestIntegration(t *testing.T) { c := config.New() c.Set("Warehouse.postgres.EnableSQLStatementExecutionPlanWorkspaceIDs", workspaceID) - pg := postgres.New(c, logger.NOP, stats.Default) - err := pg.Setup(ctx, warehouse, mockUploader) + mergeWarehouse := th.Clone(t, warehouse) + 
mergeWarehouse.Destination.Config[string(model.EnableMergeSetting)] = true + + pg := postgres.New(c, logger.NOP, memstats.New()) + err := pg.Setup(ctx, mergeWarehouse, mockUploader) require.NoError(t, err) err = pg.CreateSchema(ctx) @@ -609,7 +605,7 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(0)) require.Equal(t, loadTableStat.RowsUpdated, int64(14)) - records := testhelper.RetrieveRecordsFromWarehouse(t, pg.DB.DB, + records := whth.RetrieveRecordsFromWarehouse(t, pg.DB.DB, fmt.Sprintf(` SELECT id, @@ -628,10 +624,10 @@ func TestIntegration(t *testing.T) { tableName, ), ) - require.Equal(t, records, testhelper.SampleTestRecords()) + require.Equal(t, records, whth.SampleTestRecords()) }) t.Run("with dedup", func(t *testing.T) { - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := mockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) @@ -639,8 +635,11 @@ func TestIntegration(t *testing.T) { c := config.New() c.Set("Warehouse.postgres.EnableSQLStatementExecutionPlanWorkspaceIDs", workspaceID) - pg := postgres.New(config.Default, logger.NOP, stats.Default) - err := pg.Setup(ctx, warehouse, mockUploader) + mergeWarehouse := th.Clone(t, warehouse) + mergeWarehouse.Destination.Config[string(model.EnableMergeSetting)] = true + + pg := postgres.New(config.New(), logger.NOP, memstats.New()) + err := pg.Setup(ctx, mergeWarehouse, mockUploader) require.NoError(t, err) err = pg.CreateSchema(ctx) @@ -654,7 +653,7 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(0)) require.Equal(t, loadTableStat.RowsUpdated, int64(14)) - records := testhelper.RetrieveRecordsFromWarehouse(t, pg.DB.DB, + records := whth.RetrieveRecordsFromWarehouse(t, pg.DB.DB, fmt.Sprintf(` SELECT id, @@ -673,13 +672,13 @@ func TestIntegration(t *testing.T) { tableName, ), ) - require.Equal(t, records, testhelper.DedupTestRecords()) + require.Equal(t, records, whth.DedupTestRecords()) }) }) t.Run("append", func(t *testing.T) { tableName := "append_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := mockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) @@ -687,7 +686,7 @@ func TestIntegration(t *testing.T) { c := config.New() c.Set("Warehouse.postgres.skipDedupDestinationIDs", destinationID) - pg := postgres.New(c, logger.NOP, stats.Default) + pg := postgres.New(c, logger.NOP, memstats.New()) err := pg.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -707,7 +706,7 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(14)) require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, pg.DB.DB, + records := whth.RetrieveRecordsFromWarehouse(t, pg.DB.DB, fmt.Sprintf(` SELECT id, @@ -726,7 +725,7 @@ func TestIntegration(t *testing.T) { tableName, ), ) - require.Equal(t, records, testhelper.AppendTestRecords()) + require.Equal(t, records, whth.AppendTestRecords()) }) t.Run("load file does not exists", func(t *testing.T) { tableName := "load_file_not_exists_test_table" @@ -736,7 +735,7 @@ func 
TestIntegration(t *testing.T) { }} mockUploader := mockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) - pg := postgres.New(config.Default, logger.NOP, stats.Default) + pg := postgres.New(config.New(), logger.NOP, memstats.New()) err := pg.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -753,12 +752,12 @@ func TestIntegration(t *testing.T) { t.Run("mismatch in number of columns", func(t *testing.T) { tableName := "mismatch_columns_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/mismatch-columns.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/mismatch-columns.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := mockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) - pg := postgres.New(config.Default, logger.NOP, stats.Default) + pg := postgres.New(config.New(), logger.NOP, memstats.New()) err := pg.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -775,12 +774,12 @@ func TestIntegration(t *testing.T) { t.Run("mismatch in schema", func(t *testing.T) { tableName := "mismatch_schema_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/mismatch-schema.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/mismatch-schema.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := mockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) - pg := postgres.New(config.Default, logger.NOP, stats.Default) + pg := postgres.New(config.New(), logger.NOP, memstats.New()) err := pg.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -797,12 +796,12 @@ func TestIntegration(t *testing.T) { t.Run("discards", func(t *testing.T) { tableName := warehouseutils.DiscardsTable - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/discards.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/discards.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := mockUploader(t, loadFiles, tableName, warehouseutils.DiscardsSchema, warehouseutils.DiscardsSchema) - pg := postgres.New(config.Default, logger.NOP, stats.Default) + pg := postgres.New(config.New(), logger.NOP, memstats.New()) err := pg.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -817,7 +816,7 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(6)) require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, pg.DB.DB, + records := whth.RetrieveRecordsFromWarehouse(t, pg.DB.DB, fmt.Sprintf(` SELECT column_name, @@ -834,7 +833,7 @@ func TestIntegration(t *testing.T) { tableName, ), ) - require.Equal(t, records, testhelper.DiscardTestRecords()) + require.Equal(t, records, whth.DiscardTestRecords()) }) }) } @@ -854,6 +853,7 @@ func mockUploader( mockUploader.EXPECT().GetLoadFilesMetadata(gomock.Any(), gomock.Any()).Return(loadFiles, nil).AnyTimes() // Try removing this mockUploader.EXPECT().GetTableSchemaInUpload(tableName).Return(schemaInUpload).AnyTimes() mockUploader.EXPECT().GetTableSchemaInWarehouse(tableName).Return(schemaInWarehouse).AnyTimes() + mockUploader.EXPECT().CanAppend().Return(true).AnyTimes() return mockUploader } diff --git a/warehouse/integrations/redshift/redshift.go b/warehouse/integrations/redshift/redshift.go index 58041ec119a..fedf6a15296 100644 --- 
a/warehouse/integrations/redshift/redshift.go +++ b/warehouse/integrations/redshift/redshift.go @@ -10,16 +10,17 @@ import ( "os" "path/filepath" "regexp" + "slices" "sort" "strings" "time" + "github.com/rudderlabs/rudder-server/warehouse/integrations/tunnelling" + "github.com/samber/lo" "github.com/rudderlabs/rudder-server/warehouse/integrations/types" - "golang.org/x/exp/slices" - "github.com/lib/pq" "github.com/tidwall/gjson" @@ -32,7 +33,6 @@ import ( sqlmiddleware "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" "github.com/rudderlabs/rudder-server/warehouse/internal/model" "github.com/rudderlabs/rudder-server/warehouse/logfield" - "github.com/rudderlabs/rudder-server/warehouse/tunnelling" warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" ) @@ -157,6 +157,7 @@ type Redshift struct { stats stats.Stats config struct { + allowMerge bool slowQueryThreshold time.Duration dedupWindow bool dedupWindowInHours time.Duration @@ -180,7 +181,7 @@ type s3Manifest struct { Entries []s3ManifestEntry `json:"entries"` } -type RedshiftCredentials struct { +type connectionCredentials struct { Host string Port string DbName string @@ -196,6 +197,7 @@ func New(conf *config.Config, log logger.Logger, stat stats.Stats) *Redshift { rs.logger = log.Child("integrations").Child("redshift") rs.stats = stat + rs.config.allowMerge = conf.GetBool("Warehouse.redshift.allowMerge", true) rs.config.dedupWindow = conf.GetBool("Warehouse.redshift.dedupWindow", false) rs.config.dedupWindowInHours = conf.GetDuration("Warehouse.redshift.dedupWindowInHours", 720, time.Hour) rs.config.skipDedupDestinationIDs = conf.GetStringSlice("Warehouse.redshift.skipDedupDestinationIDs", nil) @@ -300,7 +302,8 @@ func (rs *Redshift) AddColumns(ctx context.Context, tableName string, columnsInf func CheckAndIgnoreColumnAlreadyExistError(err error) bool { if err != nil { - if e, ok := err.(*pq.Error); ok { + var e *pq.Error + if errors.As(err, &e) { if e.Code == "42701" { return true } @@ -315,10 +318,10 @@ func (rs *Redshift) DeleteBy(ctx context.Context, tableNames []string, params wa rs.logger.Infof("RS: Flag for enableDeleteByJobs is %t", rs.config.enableDeleteByJobs) for _, tb := range tableNames { sqlStatement := fmt.Sprintf(`DELETE FROM "%[1]s"."%[2]s" WHERE - context_sources_job_run_id <> $1 AND - context_sources_task_run_id <> $2 AND - context_source_id = $3 AND - received_at < $4`, + context_sources_job_run_id <> $1 AND + context_sources_task_run_id <> $2 AND + context_source_id = $3 AND + received_at < $4`, rs.Namespace, tb, ) @@ -461,7 +464,7 @@ func (rs *Redshift) loadTable( logfield.WorkspaceID, rs.Warehouse.WorkspaceID, logfield.Namespace, rs.Namespace, logfield.TableName, tableName, - logfield.LoadTableStrategy, rs.loadTableStrategy(), + logfield.ShouldMerge, rs.shouldMerge(), ) log.Infow("started loading") @@ -516,7 +519,7 @@ func (rs *Redshift) loadTable( } var rowsDeleted int64 - if !slices.Contains(rs.config.skipDedupDestinationIDs, rs.Warehouse.Destination.ID) { + if rs.shouldMerge() { log.Infow("deleting from load table") rowsDeleted, err = rs.deleteFromLoadTable( ctx, txn, tableName, @@ -549,13 +552,6 @@ func (rs *Redshift) loadTable( }, stagingTableName, nil } -func (rs *Redshift) loadTableStrategy() string { - if slices.Contains(rs.config.skipDedupDestinationIDs, rs.Warehouse.Destination.ID) { - return "APPEND" - } - return "MERGE" -} - func (rs *Redshift) copyIntoLoadTable( ctx context.Context, txn *sqlmiddleware.Tx, @@ -579,15 +575,13 @@ func (rs *Redshift) 
copyIntoLoadTable( var copyStmt string if rs.Uploader.GetLoadFileType() == warehouseutils.LoadFileTypeParquet { - copyStmt = fmt.Sprintf(` - COPY %s - FROM - '%s' + copyStmt = fmt.Sprintf( + `COPY %s + FROM '%s' ACCESS_KEY_ID '%s' SECRET_ACCESS_KEY '%s' SESSION_TOKEN '%s' - MANIFEST FORMAT PARQUET; - `, + MANIFEST FORMAT PARQUET;`, fmt.Sprintf(`%q.%q`, rs.Namespace, stagingTableName), manifestS3Location, tempAccessKeyId, @@ -595,10 +589,9 @@ func (rs *Redshift) copyIntoLoadTable( token, ) } else { - copyStmt = fmt.Sprintf(` - COPY %s(%s) - FROM - '%s' + copyStmt = fmt.Sprintf( + `COPY %s(%s) + FROM '%s' CSV GZIP ACCESS_KEY_ID '%s' SECRET_ACCESS_KEY '%s' @@ -608,8 +601,7 @@ func (rs *Redshift) copyIntoLoadTable( TIMEFORMAT 'auto' MANIFEST TRUNCATECOLUMNS EMPTYASNULL BLANKSASNULL FILLRECORD ACCEPTANYDATE TRIMBLANKS ACCEPTINVCHARS COMPUPDATE OFF - STATUPDATE OFF; - `, + STATUPDATE OFF;`, fmt.Sprintf(`%q.%q`, rs.Namespace, stagingTableName), sortedColumnNames, manifestS3Location, @@ -638,14 +630,10 @@ func (rs *Redshift) deleteFromLoadTable( primaryKey = column } - deleteStmt := fmt.Sprintf(` - DELETE FROM - %[1]s.%[2]q - USING - %[1]s.%[3]q _source - WHERE - _source.%[4]s = %[1]s.%[2]q.%[4]s -`, + deleteStmt := fmt.Sprintf( + `DELETE FROM %[1]s.%[2]q + USING %[1]s.%[3]q _source + WHERE _source.%[4]s = %[1]s.%[2]q.%[4]s`, rs.Namespace, tableName, stagingTableName, @@ -653,9 +641,8 @@ func (rs *Redshift) deleteFromLoadTable( ) if rs.config.dedupWindow { if _, ok := tableSchemaAfterUpload["received_at"]; ok { - deleteStmt += fmt.Sprintf(` - AND %[1]s.%[2]q.received_at > GETDATE() - INTERVAL '%[3]d HOUR' -`, + deleteStmt += fmt.Sprintf( + ` AND %[1]s.%[2]q.received_at > GETDATE() - INTERVAL '%[3]d HOUR'`, rs.Namespace, tableName, rs.config.dedupWindowInHours/time.Hour, @@ -663,10 +650,8 @@ func (rs *Redshift) deleteFromLoadTable( } } if tableName == warehouseutils.DiscardsTable { - deleteStmt += fmt.Sprintf(` - AND _source.%[3]s = %[1]s.%[2]q.%[3]s - AND _source.%[4]s = %[1]s.%[2]q.%[4]s -`, + deleteStmt += fmt.Sprintf( + ` AND _source.%[3]s = %[1]s.%[2]q.%[3]s AND _source.%[4]s = %[1]s.%[2]q.%[4]s`, rs.Namespace, tableName, "table_name", @@ -697,10 +682,9 @@ func (rs *Redshift) insertIntoLoadTable( sortedColumnKeys, ) - insertStmt := fmt.Sprintf(` - INSERT INTO %[1]q.%[2]q (%[3]s) - SELECT - %[3]s + insertStmt := fmt.Sprintf( + `INSERT INTO %[1]q.%[2]q (%[3]s) + SELECT %[3]s FROM ( SELECT @@ -710,12 +694,9 @@ func (rs *Redshift) insertIntoLoadTable( ORDER BY received_at DESC ) AS _rudder_staging_row_number - FROM - %[1]q.%[4]q + FROM %[1]q.%[4]q ) AS _ - WHERE - _rudder_staging_row_number = 1; -`, + WHERE _rudder_staging_row_number = 1;`, rs.Namespace, tableName, quotedColumnNames, @@ -740,14 +721,17 @@ func (rs *Redshift) loadUserTables(ctx context.Context) map[string]error { firstValProps []string ) - rs.logger.Infow("started loading for identifies and users tables", + logFields := []any{ logfield.SourceID, rs.Warehouse.Source.ID, logfield.SourceType, rs.Warehouse.Source.SourceDefinition.Name, logfield.DestinationID, rs.Warehouse.Destination.ID, logfield.DestinationType, rs.Warehouse.Destination.DestinationDefinition.Name, logfield.WorkspaceID, rs.Warehouse.WorkspaceID, logfield.Namespace, rs.Namespace, - ) + logfield.ShouldMerge, rs.shouldMerge(), + logfield.TableName, warehouseutils.UsersTable, + } + rs.logger.Infow("started loading for identifies and users tables", logFields...) 
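// Identifies are loaded first; the identifies staging table returned here is then reused below
// when building the users staging table.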
_, identifyStagingTable, err = rs.loadTable(ctx, warehouseutils.IdentifiesTable, rs.Uploader.GetTableSchemaInUpload(warehouseutils.IdentifiesTable), rs.Uploader.GetTableSchemaInWarehouse(warehouseutils.IdentifiesTable), true) if err != nil { @@ -791,15 +775,12 @@ func (rs *Redshift) loadUserTables(ctx context.Context) map[string]error { stagingTableName := warehouseutils.StagingTableName(provider, warehouseutils.UsersTable, tableNameLimit) - query = fmt.Sprintf(` - CREATE TABLE %[1]q.%[2]q AS ( - SELECT - DISTINCT * + query = fmt.Sprintf( + `CREATE TABLE %[1]q.%[2]q AS ( + SELECT DISTINCT * FROM ( - SELECT - id, - %[3]s + SELECT id, %[3]s FROM ( ( @@ -820,18 +801,13 @@ func (rs *Redshift) loadUserTables(ctx context.Context) map[string]error { ) UNION ( - SELECT - user_id, - %[6]s - FROM - %[1]q.%[5]q - WHERE - user_id IS NOT NULL + SELECT user_id, %[6]s + FROM %[1]q.%[5]q + WHERE user_id IS NOT NULL ) ) ) - ); -`, + );`, rs.Namespace, stagingTableName, strings.Join(firstValProps, ","), @@ -849,16 +825,7 @@ func (rs *Redshift) loadUserTables(ctx context.Context) map[string]error { if _, err = txn.ExecContext(ctx, query); err != nil { _ = txn.Rollback() - rs.logger.Warnw("creating staging table for users", - logfield.SourceID, rs.Warehouse.Source.ID, - logfield.SourceType, rs.Warehouse.Source.SourceDefinition.Name, - logfield.DestinationID, rs.Warehouse.Destination.ID, - logfield.DestinationType, rs.Warehouse.Destination.DestinationDefinition.Name, - logfield.WorkspaceID, rs.Warehouse.WorkspaceID, - logfield.Namespace, rs.Namespace, - logfield.TableName, warehouseutils.UsersTable, - logfield.Error, err.Error(), - ) + rs.logger.Warnw("creating staging table for users", append(logFields, logfield.Error, err.Error())...) return map[string]error{ warehouseutils.IdentifiesTable: nil, warehouseutils.UsersTable: fmt.Errorf("creating staging table for users: %w", err), @@ -866,77 +833,45 @@ func (rs *Redshift) loadUserTables(ctx context.Context) map[string]error { } defer rs.dropStagingTables(ctx, []string{stagingTableName}) - primaryKey := "id" - query = fmt.Sprintf(` - DELETE FROM - %[1]s.%[2]q USING %[1]s.%[3]q _source - WHERE - ( - _source.%[4]s = %[1]s.%[2]s.%[4]s - ); -`, - rs.Namespace, - warehouseutils.UsersTable, - stagingTableName, - primaryKey, - ) + if rs.shouldMerge() { + primaryKey := "id" + query = fmt.Sprintf(`DELETE FROM %[1]s.%[2]q USING %[1]s.%[3]q _source + WHERE _source.%[4]s = %[1]s.%[2]s.%[4]s;`, + rs.Namespace, + warehouseutils.UsersTable, + stagingTableName, + primaryKey, + ) - if _, err = txn.ExecContext(ctx, query); err != nil { - _ = txn.Rollback() + if _, err = txn.ExecContext(ctx, query); err != nil { + _ = txn.Rollback() - rs.logger.Warnw("deleting from users table for dedup", - logfield.SourceID, rs.Warehouse.Source.ID, - logfield.SourceType, rs.Warehouse.Source.SourceDefinition.Name, - logfield.DestinationID, rs.Warehouse.Destination.ID, - logfield.DestinationType, rs.Warehouse.Destination.DestinationDefinition.Name, - logfield.WorkspaceID, rs.Warehouse.WorkspaceID, - logfield.Namespace, rs.Namespace, - logfield.Query, query, - logfield.TableName, warehouseutils.UsersTable, - logfield.Error, err.Error(), - ) - return map[string]error{ - warehouseutils.UsersTable: fmt.Errorf("deleting from main table for dedup: %w", normalizeError(err)), + rs.logger.Warnw("deleting from users table for dedup", append(logFields, + logfield.Query, query, + logfield.Error, err.Error(), + )...) 
+ return map[string]error{ + warehouseutils.UsersTable: fmt.Errorf("deleting from main table for dedup: %w", normalizeError(err)), + } } } - query = fmt.Sprintf(` - INSERT INTO %[1]q.%[2]q (%[4]s) - SELECT - %[4]s - FROM - %[1]q.%[3]q; -`, + query = fmt.Sprintf( + `INSERT INTO %[1]q.%[2]q (%[4]s) + SELECT %[4]s + FROM %[1]q.%[3]q;`, rs.Namespace, warehouseutils.UsersTable, stagingTableName, warehouseutils.DoubleQuoteAndJoinByComma(append([]string{"id"}, userColNames...)), ) - rs.logger.Infow("inserting into users table", - logfield.SourceID, rs.Warehouse.Source.ID, - logfield.SourceType, rs.Warehouse.Source.SourceDefinition.Name, - logfield.DestinationID, rs.Warehouse.Destination.ID, - logfield.DestinationType, rs.Warehouse.Destination.DestinationDefinition.Name, - logfield.WorkspaceID, rs.Warehouse.WorkspaceID, - logfield.Namespace, rs.Namespace, - logfield.TableName, warehouseutils.UsersTable, - logfield.Query, query, - ) + rs.logger.Infow("inserting into users table", append(logFields, logfield.Query, query)...) if _, err = txn.ExecContext(ctx, query); err != nil { _ = txn.Rollback() - rs.logger.Warnw("failed inserting into users table", - logfield.SourceID, rs.Warehouse.Source.ID, - logfield.SourceType, rs.Warehouse.Source.SourceDefinition.Name, - logfield.DestinationID, rs.Warehouse.Destination.ID, - logfield.DestinationType, rs.Warehouse.Destination.DestinationDefinition.Name, - logfield.WorkspaceID, rs.Warehouse.WorkspaceID, - logfield.Namespace, rs.Namespace, - logfield.TableName, warehouseutils.UsersTable, - logfield.Error, err.Error(), - ) + rs.logger.Warnw("failed inserting into users table", append(logFields, logfield.Error, err.Error())...) return map[string]error{ warehouseutils.IdentifiesTable: nil, @@ -947,16 +882,7 @@ func (rs *Redshift) loadUserTables(ctx context.Context) map[string]error { if err = txn.Commit(); err != nil { _ = txn.Rollback() - rs.logger.Warnw("committing transaction for user table", - logfield.SourceID, rs.Warehouse.Source.ID, - logfield.SourceType, rs.Warehouse.Source.SourceDefinition.Name, - logfield.DestinationID, rs.Warehouse.Destination.ID, - logfield.DestinationType, rs.Warehouse.Destination.DestinationDefinition.Name, - logfield.WorkspaceID, rs.Warehouse.WorkspaceID, - logfield.Namespace, rs.Namespace, - logfield.TableName, warehouseutils.UsersTable, - logfield.Error, err.Error(), - ) + rs.logger.Warnw("committing transaction for user table", append(logFields, logfield.Error, err.Error())...) return map[string]error{ warehouseutils.IdentifiesTable: nil, @@ -964,14 +890,7 @@ func (rs *Redshift) loadUserTables(ctx context.Context) map[string]error { } } - rs.logger.Infow("completed loading for users and identifies tables", - logfield.SourceID, rs.Warehouse.Source.ID, - logfield.SourceType, rs.Warehouse.Source.SourceDefinition.Name, - logfield.DestinationID, rs.Warehouse.Destination.ID, - logfield.DestinationType, rs.Warehouse.Destination.DestinationDefinition.Name, - logfield.WorkspaceID, rs.Warehouse.WorkspaceID, - logfield.Namespace, rs.Namespace, - ) + rs.logger.Infow("completed loading for users and identifies tables", logFields...) 
return map[string]error{ warehouseutils.IdentifiesTable: nil, @@ -1003,7 +922,7 @@ func (rs *Redshift) connect(ctx context.Context) (*sqlmiddleware.DB, error) { ) if cred.TunnelInfo != nil { - if db, err = tunnelling.SQLConnectThroughTunnel(dsn.String(), cred.TunnelInfo.Config); err != nil { + if db, err = tunnelling.Connect(dsn.String(), cred.TunnelInfo.Config); err != nil { return nil, fmt.Errorf("connecting to redshift through tunnel: %w", err) } } else { @@ -1012,8 +931,7 @@ func (rs *Redshift) connect(ctx context.Context) (*sqlmiddleware.DB, error) { } } - stmt := `SET query_group to 'RudderStack'` - _, err = db.ExecContext(ctx, stmt) + _, err = db.ExecContext(ctx, `SET query_group to 'RudderStack'`) if err != nil { return nil, fmt.Errorf("redshift set query_group error : %v", err) } @@ -1040,15 +958,9 @@ func (rs *Redshift) connect(ctx context.Context) (*sqlmiddleware.DB, error) { } func (rs *Redshift) dropDanglingStagingTables(ctx context.Context) { - sqlStatement := ` - SELECT - table_name - FROM - information_schema.tables - WHERE - table_schema = $1 AND - table_name like $2; - ` + sqlStatement := `SELECT table_name + FROM information_schema.tables + WHERE table_schema = $1 AND table_name like $2;` rows, err := rs.DB.QueryContext(ctx, sqlStatement, rs.Namespace, @@ -1121,12 +1033,7 @@ func (rs *Redshift) AlterColumn(ctx context.Context, tableName, columnName, colu // creating staging column stagingColumnType = getRSDataType(columnType) stagingColumnName = fmt.Sprintf(`%s-staging-%s`, columnName, misc.FastUUID().String()) - query = fmt.Sprintf(` - ALTER TABLE - %q.%q - ADD - COLUMN %q %s; - `, + query = fmt.Sprintf(`ALTER TABLE %q.%q ADD COLUMN %q %s;`, rs.Namespace, tableName, stagingColumnName, @@ -1137,14 +1044,10 @@ func (rs *Redshift) AlterColumn(ctx context.Context, tableName, columnName, colu } // populating staging column - query = fmt.Sprintf(` - UPDATE - %[1]q.%[2]q - SET - %[3]q = CAST (%[4]q AS %[5]s) - WHERE - %[4]q IS NOT NULL; - `, + query = fmt.Sprintf( + `UPDATE %[1]q.%[2]q + SET %[3]q = CAST (%[4]q AS %[5]s) + WHERE %[4]q IS NOT NULL;`, rs.Namespace, tableName, stagingColumnName, @@ -1157,12 +1060,8 @@ func (rs *Redshift) AlterColumn(ctx context.Context, tableName, columnName, colu // renaming original column to deprecated column deprecatedColumnName = fmt.Sprintf(`%s-deprecated-%s`, columnName, misc.FastUUID().String()) - query = fmt.Sprintf(` - ALTER TABLE - %[1]q.%[2]q - RENAME COLUMN - %[3]q TO %[4]q; - `, + query = fmt.Sprintf( + `ALTER TABLE %[1]q.%[2]q RENAME COLUMN %[3]q TO %[4]q;`, rs.Namespace, tableName, columnName, @@ -1173,12 +1072,8 @@ func (rs *Redshift) AlterColumn(ctx context.Context, tableName, columnName, colu } // renaming staging column to original column - query = fmt.Sprintf(` - ALTER TABLE - %[1]q.%[2]q - RENAME COLUMN - %[3]q TO %[4]q; - `, + query = fmt.Sprintf( + `ALTER TABLE %[1]q.%[2]q RENAME COLUMN %[3]q TO %[4]q;`, rs.Namespace, tableName, stagingColumnName, @@ -1195,20 +1090,17 @@ func (rs *Redshift) AlterColumn(ctx context.Context, tableName, columnName, colu // dropping deprecated column // Since dropping the column can fail, we need to do it outside the transaction - // Because if will fail during the commit of the transaction + // Because it will fail during the commit of the transaction // https://github.com/lib/pq/blob/d5affd5073b06f745459768de35356df2e5fd91d/conn.go#L600 - query = fmt.Sprintf(` - ALTER TABLE - %[1]q.%[2]q - DROP COLUMN - %[3]q; - `, + query = fmt.Sprintf( + `ALTER TABLE %[1]q.%[2]q DROP COLUMN %[3]q;`,
rs.Namespace, tableName, deprecatedColumnName, ) if _, err = rs.DB.ExecContext(ctx, query); err != nil { - if pqError, ok := err.(*pq.Error); !ok || pqError.Code != "2BP01" { + var pqError *pq.Error + if !errors.As(err, &pqError) || pqError.Code != "2BP01" { return model.AlterTableResponse{}, fmt.Errorf("drop deprecated column: %w", err) } @@ -1228,15 +1120,15 @@ func (rs *Redshift) AlterColumn(ctx context.Context, tableName, columnName, colu return res, nil } -func (rs *Redshift) getConnectionCredentials() RedshiftCredentials { - creds := RedshiftCredentials{ +func (rs *Redshift) getConnectionCredentials() connectionCredentials { + creds := connectionCredentials{ Host: warehouseutils.GetConfigValue(RSHost, rs.Warehouse), Port: warehouseutils.GetConfigValue(RSPort, rs.Warehouse), DbName: warehouseutils.GetConfigValue(RSDbName, rs.Warehouse), Username: warehouseutils.GetConfigValue(RSUserName, rs.Warehouse), Password: warehouseutils.GetConfigValue(RSPassword, rs.Warehouse), timeout: rs.connectTimeout, - TunnelInfo: warehouseutils.ExtractTunnelInfoFromDestinationConfig(rs.Warehouse.Destination.Config), + TunnelInfo: tunnelling.ExtractTunnelInfoFromDestinationConfig(rs.Warehouse.Destination.Config), } return creds @@ -1247,18 +1139,13 @@ func (rs *Redshift) FetchSchema(ctx context.Context) (model.Schema, model.Schema schema := make(model.Schema) unrecognizedSchema := make(model.Schema) - sqlStatement := ` - SELECT + sqlStatement := `SELECT table_name, column_name, data_type, character_maximum_length - FROM - INFORMATION_SCHEMA.COLUMNS - WHERE - table_schema = $1 - and table_name not like $2; - ` + FROM INFORMATION_SCHEMA.COLUMNS + WHERE table_schema = $1 and table_name not like $2;` rows, err := rs.DB.QueryContext( ctx, @@ -1439,12 +1326,20 @@ func (rs *Redshift) SetConnectionTimeout(timeout time.Duration) { rs.connectTimeout = timeout } +func (rs *Redshift) shouldMerge() bool { + return !rs.Uploader.CanAppend() || + (rs.config.allowMerge && + rs.Warehouse.GetBoolDestinationConfig(model.EnableMergeSetting) && + !slices.Contains(rs.config.skipDedupDestinationIDs, rs.Warehouse.Destination.ID)) +} + func (*Redshift) ErrorMappings() []model.JobError { return errorsMappings } func normalizeError(err error) error { - if pqErr, ok := err.(*pq.Error); ok { + var pqErr *pq.Error + if errors.As(err, &pqErr) { return fmt.Errorf("pq: message: %s, detail: %s", pqErr.Message, pqErr.Detail, diff --git a/warehouse/integrations/redshift/redshift_test.go b/warehouse/integrations/redshift/redshift_test.go index 483e18cd814..abdcf72fe21 100644 --- a/warehouse/integrations/redshift/redshift_test.go +++ b/warehouse/integrations/redshift/redshift_test.go @@ -7,52 +7,39 @@ import ( "errors" "fmt" "os" + "slices" "strconv" "strings" "testing" "time" - "golang.org/x/exp/slices" - "github.com/golang/mock/gomock" - - "github.com/rudderlabs/rudder-go-kit/filemanager" - mockuploader "github.com/rudderlabs/rudder-server/warehouse/internal/mocks/utils" - "github.com/rudderlabs/rudder-server/warehouse/internal/model" - - "github.com/rudderlabs/rudder-go-kit/logger" - "github.com/rudderlabs/rudder-go-kit/stats" - - "github.com/rudderlabs/compose-test/compose" - - "github.com/rudderlabs/rudder-server/warehouse/integrations/redshift" - - "github.com/rudderlabs/rudder-server/testhelper/workspaceConfig" - - sqlmiddleware "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" - "github.com/lib/pq" "github.com/ory/dockertest/v3" + "github.com/stretchr/testify/require" + 
"github.com/rudderlabs/compose-test/compose" "github.com/rudderlabs/compose-test/testcompose" + "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/filemanager" + "github.com/rudderlabs/rudder-go-kit/logger" + "github.com/rudderlabs/rudder-go-kit/stats/memstats" kithelper "github.com/rudderlabs/rudder-go-kit/testhelper" + "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource" + backendconfig "github.com/rudderlabs/rudder-server/backend-config" "github.com/rudderlabs/rudder-server/runner" + th "github.com/rudderlabs/rudder-server/testhelper" "github.com/rudderlabs/rudder-server/testhelper/health" - - "github.com/rudderlabs/rudder-go-kit/config" - "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource" - - "github.com/rudderlabs/rudder-server/warehouse/integrations/testhelper" - + "github.com/rudderlabs/rudder-server/testhelper/workspaceConfig" "github.com/rudderlabs/rudder-server/utils/misc" - "github.com/rudderlabs/rudder-server/warehouse/validations" - - backendconfig "github.com/rudderlabs/rudder-server/backend-config" - - "github.com/stretchr/testify/require" - "github.com/rudderlabs/rudder-server/warehouse/client" + sqlmiddleware "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" + "github.com/rudderlabs/rudder-server/warehouse/integrations/redshift" + whth "github.com/rudderlabs/rudder-server/warehouse/integrations/testhelper" + mockuploader "github.com/rudderlabs/rudder-server/warehouse/internal/mocks/utils" + "github.com/rudderlabs/rudder-server/warehouse/internal/model" warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" + "github.com/rudderlabs/rudder-server/warehouse/validations" ) type testCredentials struct { @@ -117,8 +104,8 @@ func TestIntegration(t *testing.T) { destType := warehouseutils.RS - namespace := testhelper.RandSchema(destType) - sourcesNamespace := testhelper.RandSchema(destType) + namespace := whth.RandSchema(destType) + sourcesNamespace := whth.RandSchema(destType) rsTestCredentials, err := rsTestCredentials() require.NoError(t, err) @@ -144,7 +131,7 @@ func TestIntegration(t *testing.T) { } workspaceConfigPath := workspaceConfig.CreateTempFile(t, "testdata/template.json", templateConfigurations) - testhelper.EnhanceWithDefaultEnvs(t) + whth.EnhanceWithDefaultEnvs(t) t.Setenv("JOBS_DB_PORT", strconv.Itoa(jobsDBPort)) t.Setenv("WAREHOUSE_JOBS_DB_PORT", strconv.Itoa(jobsDBPort)) t.Setenv("RSERVER_WAREHOUSE_WEB_PORT", strconv.Itoa(httpPort)) @@ -185,7 +172,7 @@ func TestIntegration(t *testing.T) { t.Setenv("RSERVER_WAREHOUSE_REDSHIFT_DEDUP_WINDOW", "true") t.Setenv("RSERVER_WAREHOUSE_REDSHIFT_DEDUP_WINDOW_IN_HOURS", "5") - jobsDB := testhelper.JobsDB(t, jobsDBPort) + jobsDB := whth.JobsDB(t, jobsDBPort) testcase := []struct { name string @@ -194,10 +181,10 @@ func TestIntegration(t *testing.T) { sourceID string destinationID string tables []string - stagingFilesEventsMap testhelper.EventsCountMap - loadFilesEventsMap testhelper.EventsCountMap - tableUploadsEventsMap testhelper.EventsCountMap - warehouseEventsMap testhelper.EventsCountMap + stagingFilesEventsMap whth.EventsCountMap + loadFilesEventsMap whth.EventsCountMap + tableUploadsEventsMap whth.EventsCountMap + warehouseEventsMap whth.EventsCountMap asyncJob bool stagingFilePrefix string }{ @@ -217,10 +204,10 @@ func TestIntegration(t *testing.T) { tables: []string{"tracks", "google_sheet"}, sourceID: sourcesSourceID, destinationID: sourcesDestinationID, - stagingFilesEventsMap: 
testhelper.SourcesStagingFilesEventsMap(), - loadFilesEventsMap: testhelper.SourcesLoadFilesEventsMap(), - tableUploadsEventsMap: testhelper.SourcesTableUploadsEventsMap(), - warehouseEventsMap: testhelper.SourcesWarehouseEventsMap(), + stagingFilesEventsMap: whth.SourcesStagingFilesEventsMap(), + loadFilesEventsMap: whth.SourcesLoadFilesEventsMap(), + tableUploadsEventsMap: whth.SourcesTableUploadsEventsMap(), + warehouseEventsMap: whth.SourcesWarehouseEventsMap(), asyncJob: true, stagingFilePrefix: "testdata/sources-job", }, @@ -259,7 +246,7 @@ func TestIntegration(t *testing.T) { } t.Log("verifying test case 1") - ts1 := testhelper.TestConfig{ + ts1 := whth.TestConfig{ WriteKey: tc.writeKey, Schema: tc.schema, Tables: tc.tables, @@ -278,12 +265,12 @@ func TestIntegration(t *testing.T) { JobRunID: misc.FastUUID().String(), TaskRunID: misc.FastUUID().String(), StagingFilePath: tc.stagingFilePrefix + ".staging-1.json", - UserID: testhelper.GetUserId(destType), + UserID: whth.GetUserId(destType), } ts1.VerifyEvents(t) t.Log("verifying test case 2") - ts2 := testhelper.TestConfig{ + ts2 := whth.TestConfig{ WriteKey: tc.writeKey, Schema: tc.schema, Tables: tc.tables, @@ -303,7 +290,7 @@ func TestIntegration(t *testing.T) { JobRunID: misc.FastUUID().String(), TaskRunID: misc.FastUUID().String(), StagingFilePath: tc.stagingFilePrefix + ".staging-1.json", - UserID: testhelper.GetUserId(destType), + UserID: whth.GetUserId(destType), } if tc.asyncJob { ts2.UserID = ts1.UserID @@ -352,7 +339,7 @@ func TestIntegration(t *testing.T) { Enabled: true, RevisionID: "29HgOWobrn0RYZLpaSwPIbN2987", } - testhelper.VerifyConfigurationTest(t, dest) + whth.VerifyConfigurationTest(t, dest) }) t.Run("Load Table", func(t *testing.T) { @@ -362,7 +349,7 @@ func TestIntegration(t *testing.T) { workspaceID = "test_workspace_id" ) - namespace := testhelper.RandSchema(destType) + namespace := whth.RandSchema(destType) t.Cleanup(func() { require.Eventually(t, func() bool { @@ -443,12 +430,12 @@ func TestIntegration(t *testing.T) { t.Run("schema does not exists", func(t *testing.T) { tableName := "schema_not_exists_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv) - rs := redshift.New(config.Default, logger.NOP, stats.Default) + rs := redshift.New(config.New(), logger.NOP, memstats.New()) err := rs.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -459,12 +446,12 @@ func TestIntegration(t *testing.T) { t.Run("table does not exists", func(t *testing.T) { tableName := "table_not_exists_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv) - rs := redshift.New(config.Default, logger.NOP, stats.Default) + rs := redshift.New(config.New(), logger.NOP, memstats.New()) err := rs.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -476,15 +463,14 @@ func TestIntegration(t *testing.T) { require.Nil(t, loadTableStat) }) t.Run("merge", func(t *testing.T) { - 
tableName := "merge_test_table" - t.Run("without dedup", func(t *testing.T) { - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + tableName := "merge_without_dedup_test_table" + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv) - d := redshift.New(config.Default, logger.NOP, stats.Default) + d := redshift.New(config.New(), logger.NOP, memstats.New()) err := d.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -501,12 +487,12 @@ func TestIntegration(t *testing.T) { loadTableStat, err = d.LoadTable(ctx, tableName) require.NoError(t, err) - require.Equal(t, loadTableStat.RowsInserted, int64(0)) - require.Equal(t, loadTableStat.RowsUpdated, int64(14)) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, - fmt.Sprintf(` - SELECT + records := whth.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf( + `SELECT id, received_at, test_bool, @@ -514,25 +500,26 @@ func TestIntegration(t *testing.T) { test_float, test_int, test_string - FROM - %s.%s - ORDER BY - id; - `, + FROM %s.%s + ORDER BY id;`, namespace, tableName, ), ) - require.Equal(t, records, testhelper.SampleTestRecords()) + require.Equal(t, whth.DedupTwiceTestRecords(), records) }) t.Run("with dedup", func(t *testing.T) { - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) + tableName := "merge_with_dedup_test_table" + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv) - d := redshift.New(config.Default, logger.NOP, stats.Default) - err := d.Setup(ctx, warehouse, mockUploader) + mergeWarehouse := th.Clone(t, warehouse) + mergeWarehouse.Destination.Config[string(model.EnableMergeSetting)] = true + + d := redshift.New(config.New(), logger.NOP, memstats.New()) + err := d.Setup(ctx, mergeWarehouse, mockUploader) require.NoError(t, err) err = d.CreateSchema(ctx) @@ -543,12 +530,17 @@ func TestIntegration(t *testing.T) { loadTableStat, err := d.LoadTable(ctx, tableName) require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + loadTableStat, err = d.LoadTable(ctx, tableName) + require.NoError(t, err) require.Equal(t, loadTableStat.RowsInserted, int64(0)) require.Equal(t, loadTableStat.RowsUpdated, int64(14)) - records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, - fmt.Sprintf(` - SELECT + records := whth.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf( + `SELECT id, received_at, test_bool, @@ -556,29 +548,82 @@ func TestIntegration(t *testing.T) { test_float, test_int, test_string - FROM - %s.%s - ORDER BY - id; - `, + FROM %s.%s + ORDER BY id;`, namespace, tableName, ), ) - require.Equal(t, records, testhelper.DedupTestRecords()) + require.Equal(t, whth.DedupTestRecords(), records) }) t.Run("with dedup window", func(t *testing.T) { - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) + tableName := "merge_with_dedup_window_test_table" + 
uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv) + + mergeWarehouse := th.Clone(t, warehouse) + mergeWarehouse.Destination.Config[string(model.EnableMergeSetting)] = true + + c := config.New() + c.Set("Warehouse.redshift.dedupWindow", true) + c.Set("Warehouse.redshift.dedupWindowInHours", 999999) + + d := redshift.New(c, logger.NOP, memstats.New()) + err := d.Setup(ctx, mergeWarehouse, mockUploader) + require.NoError(t, err) + + err = d.CreateSchema(ctx) + require.NoError(t, err) + + err = d.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := d.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + loadTableStat, err = d.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(0)) + require.Equal(t, loadTableStat.RowsUpdated, int64(14)) + + records := whth.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf( + `SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM %s.%s + ORDER BY id;`, + namespace, + tableName, + ), + ) + require.Equal(t, whth.DedupTestRecords(), records) + }) + t.Run("with short dedup window", func(t *testing.T) { + tableName := "merge_with_short_dedup_window_test_table" + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv) + mergeWarehouse := th.Clone(t, warehouse) + mergeWarehouse.Destination.Config[string(model.EnableMergeSetting)] = true + c := config.New() c.Set("Warehouse.redshift.dedupWindow", true) c.Set("Warehouse.redshift.dedupWindowInHours", 0) - d := redshift.New(c, logger.NOP, stats.Default) - err := d.Setup(ctx, warehouse, mockUploader) + d := redshift.New(c, logger.NOP, memstats.New()) + err := d.Setup(ctx, mergeWarehouse, mockUploader) require.NoError(t, err) err = d.CreateSchema(ctx) @@ -592,9 +637,14 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(14)) require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, - fmt.Sprintf(` - SELECT + loadTableStat, err = d.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + records := whth.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf( + `SELECT id, received_at, test_bool, @@ -602,22 +652,19 @@ func TestIntegration(t *testing.T) { test_float, test_int, test_string - FROM - %s.%s - ORDER BY - id; - `, + FROM %s.%s + ORDER BY id;`, namespace, tableName, ), ) - require.Equal(t, records, testhelper.DedupTwiceTestRecords()) + require.Equal(t, whth.DedupTwiceTestRecords(), records) }) }) t.Run("append", func(t *testing.T) { tableName := "append_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: 
uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv) @@ -625,7 +672,7 @@ func TestIntegration(t *testing.T) { c := config.New() c.Set("Warehouse.redshift.skipDedupDestinationIDs", []string{destinationID}) - rs := redshift.New(c, logger.NOP, stats.Default) + rs := redshift.New(c, logger.NOP, memstats.New()) err := rs.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -645,7 +692,7 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(14)) require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, rs.DB.DB, + records := whth.RetrieveRecordsFromWarehouse(t, rs.DB.DB, fmt.Sprintf(` SELECT id, @@ -664,7 +711,7 @@ func TestIntegration(t *testing.T) { tableName, ), ) - require.Equal(t, records, testhelper.AppendTestRecords()) + require.Equal(t, records, whth.AppendTestRecords()) }) t.Run("load file does not exists", func(t *testing.T) { tableName := "load_file_not_exists_test_table" @@ -674,7 +721,7 @@ func TestIntegration(t *testing.T) { }} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv) - rs := redshift.New(config.Default, logger.NOP, stats.Default) + rs := redshift.New(config.New(), logger.NOP, memstats.New()) err := rs.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -691,12 +738,12 @@ func TestIntegration(t *testing.T) { t.Run("mismatch in number of columns", func(t *testing.T) { tableName := "mismatch_columns_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/mismatch-columns.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/mismatch-columns.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv) - rs := redshift.New(config.Default, logger.NOP, stats.Default) + rs := redshift.New(config.New(), logger.NOP, memstats.New()) err := rs.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -713,12 +760,12 @@ func TestIntegration(t *testing.T) { t.Run("mismatch in schema", func(t *testing.T) { tableName := "mismatch_schema_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/mismatch-schema.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/mismatch-schema.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv) - rs := redshift.New(config.Default, logger.NOP, stats.Default) + rs := redshift.New(config.New(), logger.NOP, memstats.New()) err := rs.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -735,12 +782,12 @@ func TestIntegration(t *testing.T) { t.Run("discards", func(t *testing.T) { tableName := warehouseutils.DiscardsTable - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/discards.csv.gz", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/discards.csv.gz", tableName) loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, warehouseutils.DiscardsSchema, warehouseutils.DiscardsSchema, warehouseutils.LoadFileTypeCsv) - rs := redshift.New(config.Default, logger.NOP, 
stats.Default) + rs := redshift.New(config.New(), logger.NOP, memstats.New()) err := rs.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -755,7 +802,7 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(6)) require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, rs.DB.DB, + records := whth.RetrieveRecordsFromWarehouse(t, rs.DB.DB, fmt.Sprintf(` SELECT column_name, @@ -772,12 +819,12 @@ func TestIntegration(t *testing.T) { tableName, ), ) - require.Equal(t, records, testhelper.DiscardTestRecords()) + require.Equal(t, records, whth.DiscardTestRecords()) }) t.Run("parquet", func(t *testing.T) { tableName := "parquet_test_table" - uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.parquet", tableName) + uploadOutput := whth.UploadLoadFile(t, fm, "../testdata/load.parquet", tableName) fileStat, err := os.Stat("../testdata/load.parquet") require.NoError(t, err) @@ -788,7 +835,7 @@ func TestIntegration(t *testing.T) { }} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInUpload, warehouseutils.LoadFileTypeParquet) - rs := redshift.New(config.Default, logger.NOP, stats.Default) + rs := redshift.New(config.New(), logger.NOP, memstats.New()) err = rs.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -803,7 +850,7 @@ func TestIntegration(t *testing.T) { require.Equal(t, loadTableStat.RowsInserted, int64(14)) require.Equal(t, loadTableStat.RowsUpdated, int64(0)) - records := testhelper.RetrieveRecordsFromWarehouse(t, rs.DB.DB, + records := whth.RetrieveRecordsFromWarehouse(t, rs.DB.DB, fmt.Sprintf(` SELECT id, @@ -822,7 +869,7 @@ func TestIntegration(t *testing.T) { tableName, ), ) - require.Equal(t, records, testhelper.SampleTestRecords()) + require.Equal(t, records, whth.SampleTestRecords()) }) }) } @@ -897,7 +944,7 @@ func TestRedshift_AlterColumn(t *testing.T) { t.Log("db:", pgResource.DBDsn) - rs := redshift.New(config.Default, logger.NOP, stats.Default) + rs := redshift.New(config.New(), logger.NOP, memstats.New()) rs.DB = sqlmiddleware.New(pgResource.DB) rs.Namespace = testNamespace @@ -991,6 +1038,7 @@ func newMockUploader( mockUploader.EXPECT().GetTableSchemaInUpload(tableName).Return(schemaInUpload).AnyTimes() mockUploader.EXPECT().GetTableSchemaInWarehouse(tableName).Return(schemaInWarehouse).AnyTimes() mockUploader.EXPECT().GetLoadFileType().Return(loadFileType).AnyTimes() + mockUploader.EXPECT().CanAppend().Return(true).AnyTimes() return mockUploader } diff --git a/warehouse/integrations/snowflake/snowflake.go b/warehouse/integrations/snowflake/snowflake.go index eed8ab2aba4..d1c707a002a 100644 --- a/warehouse/integrations/snowflake/snowflake.go +++ b/warehouse/integrations/snowflake/snowflake.go @@ -13,8 +13,6 @@ import ( "strings" "time" - "github.com/rudderlabs/rudder-server/warehouse/integrations/types" - "github.com/samber/lo" snowflake "github.com/snowflakedb/gosnowflake" @@ -24,6 +22,7 @@ import ( "github.com/rudderlabs/rudder-server/utils/misc" "github.com/rudderlabs/rudder-server/warehouse/client" sqlmw "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" + "github.com/rudderlabs/rudder-server/warehouse/integrations/types" "github.com/rudderlabs/rudder-server/warehouse/internal/model" lf "github.com/rudderlabs/rudder-server/warehouse/logfield" whutils "github.com/rudderlabs/rudder-server/warehouse/utils" @@ -44,9 +43,6 @@ const ( role = "role" password = "password" application = 
"Rudderstack_Warehouse" - - loadTableStrategyMergeMode = "MERGE" - loadTableStrategyAppendMode = "APPEND" ) var primaryKeyMap = map[string]string{ @@ -169,7 +165,7 @@ type Snowflake struct { stats stats.Stats config struct { - loadTableStrategy string + allowMerge bool slowQueryThreshold time.Duration enableDeleteByJobs bool @@ -186,16 +182,7 @@ func New(conf *config.Config, log logger.Logger, stat stats.Stats) (*Snowflake, sf.logger = log.Child("integrations").Child("snowflake") sf.stats = stat - loadTableStrategy := conf.GetString("Warehouse.snowflake.loadTableStrategy", loadTableStrategyMergeMode) - switch loadTableStrategy { - case loadTableStrategyMergeMode, loadTableStrategyAppendMode: - sf.config.loadTableStrategy = loadTableStrategy - default: - return nil, fmt.Errorf("loadTableStrategy out of the known domain [%+v]: %v", - []string{loadTableStrategyMergeMode, loadTableStrategyAppendMode}, loadTableStrategy, - ) - } - + sf.config.allowMerge = conf.GetBool("Warehouse.snowflake.allowMerge", true) sf.config.enableDeleteByJobs = conf.GetBool("Warehouse.snowflake.enableDeleteByJobs", false) sf.config.slowQueryThreshold = conf.GetDuration("Warehouse.snowflake.slowQueryThreshold", 5, time.Minute) sf.config.debugDuplicateWorkspaceIDs = conf.GetStringSlice("Warehouse.snowflake.debugDuplicateWorkspaceIDs", nil) @@ -285,7 +272,8 @@ func (sf *Snowflake) createSchema(ctx context.Context) (err error) { func checkAndIgnoreAlreadyExistError(err error) bool { if err != nil { // TODO: throw error if column already exists but of different type - if e, ok := err.(*snowflake.SnowflakeError); ok && e.SQLState == "42601" { + var e *snowflake.SnowflakeError + if errors.As(err, &e) && e.SQLState == "42601" { return true } return false @@ -312,9 +300,8 @@ func (sf *Snowflake) DeleteBy(ctx context.Context, tableNames []string, params w ) log.Infow("Cleaning up the following tables in snowflake") - sqlStatement := fmt.Sprintf(` - DELETE FROM - %[1]q.%[2]q + sqlStatement := fmt.Sprintf( + `DELETE FROM %[1]q.%[2]q WHERE context_sources_job_run_id <> '%[3]s' AND context_sources_task_run_id <> '%[4]s' AND @@ -360,7 +347,7 @@ func (sf *Snowflake) loadTable( lf.WorkspaceID, sf.Warehouse.WorkspaceID, lf.Namespace, sf.Namespace, lf.TableName, tableName, - lf.LoadTableStrategy, sf.config.loadTableStrategy, + lf.ShouldMerge, sf.ShouldMerge(), ) log.Infow("started loading") @@ -384,7 +371,7 @@ func (sf *Snowflake) loadTable( // Truncating the columns by default to avoid size limitation errors // https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html#copy-options-copyoptions - if sf.ShouldAppend() { + if !sf.ShouldMerge() { log.Infow("copying data into main table") loadTableStats, err := sf.copyInto(ctx, db, schemaIdentifier, tableName, sortedColumnNames, tableName) if err != nil { @@ -551,24 +538,17 @@ func (sf *Snowflake) sampleDuplicateMessages( mainTable := fmt.Sprintf("%s.%q", identifier, mainTableName) stagingTable := fmt.Sprintf("%s.%q", identifier, stagingTableName) - rows, err := db.QueryContext(ctx, ` - SELECT - ID, - RECEIVED_AT - FROM - `+mainTable+` + rows, err := db.QueryContext(ctx, + `SELECT ID, RECEIVED_AT + FROM `+mainTable+` WHERE RECEIVED_AT > (SELECT DATEADD(day,-?,current_date())) AND ID IN ( - SELECT - ID - FROM - `+stagingTable+` + SELECT ID + FROM `+stagingTable+` ) ORDER BY RECEIVED_AT ASC - LIMIT - ?; -`, + LIMIT ?;`, sf.config.debugDuplicateIntervalInDays, sf.config.debugDuplicateLimit, ) @@ -843,11 +823,13 @@ func (sf *Snowflake) LoadIdentityMappingsTable(ctx context.Context) 
error { return nil } -// ShouldAppend returns true if: -// * the load table strategy is "append" mode -// * the uploader says we can append -func (sf *Snowflake) ShouldAppend() bool { - return sf.config.loadTableStrategy == loadTableStrategyAppendMode && sf.Uploader.CanAppend() +// ShouldMerge returns true if: +// * the uploader says we cannot append +// * the server configuration says we can merge +// * the user opted-in +func (sf *Snowflake) ShouldMerge() bool { + return !sf.Uploader.CanAppend() || + (sf.config.allowMerge && sf.Warehouse.GetBoolDestinationConfig(model.EnableMergeSetting)) } func (sf *Snowflake) LoadUserTables(ctx context.Context) map[string]error { @@ -872,7 +854,7 @@ func (sf *Snowflake) LoadUserTables(ctx context.Context) map[string]error { lf.WorkspaceID, sf.Warehouse.WorkspaceID, lf.Namespace, sf.Namespace, lf.TableName, whutils.UsersTable, - lf.LoadTableStrategy, sf.config.loadTableStrategy, + lf.ShouldMerge, sf.ShouldMerge(), ) log.Infow("started loading for identifies and users tables") @@ -889,7 +871,7 @@ func (sf *Snowflake) LoadUserTables(ctx context.Context) map[string]error { } schemaIdentifier := sf.schemaIdentifier() - if sf.ShouldAppend() { + if !sf.ShouldMerge() { tmpIdentifiesStagingTable := whutils.StagingTableName(provider, identifiesTable, tableNameLimit) sqlStatement := fmt.Sprintf( `CREATE TEMPORARY TABLE %[1]s.%[2]q LIKE %[1]s.%[3]q;`, diff --git a/warehouse/integrations/snowflake/snowflake_test.go b/warehouse/integrations/snowflake/snowflake_test.go index 3ff4a612040..6782aac3e18 100644 --- a/warehouse/integrations/snowflake/snowflake_test.go +++ b/warehouse/integrations/snowflake/snowflake_test.go @@ -7,27 +7,23 @@ import ( "errors" "fmt" "os" + "slices" "strconv" "strings" "testing" "time" - "golang.org/x/exp/slices" - - "github.com/samber/lo" - - "github.com/rudderlabs/rudder-go-kit/filemanager" - "github.com/rudderlabs/rudder-server/warehouse/internal/model" - "github.com/golang/mock/gomock" + "github.com/samber/lo" sfdb "github.com/snowflakedb/gosnowflake" "github.com/stretchr/testify/require" "github.com/rudderlabs/compose-test/compose" "github.com/rudderlabs/compose-test/testcompose" "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/filemanager" "github.com/rudderlabs/rudder-go-kit/logger" - "github.com/rudderlabs/rudder-go-kit/stats" + "github.com/rudderlabs/rudder-go-kit/stats/memstats" kithelper "github.com/rudderlabs/rudder-go-kit/testhelper" backendconfig "github.com/rudderlabs/rudder-server/backend-config" "github.com/rudderlabs/rudder-server/runner" @@ -38,6 +34,7 @@ import ( "github.com/rudderlabs/rudder-server/warehouse/integrations/snowflake" "github.com/rudderlabs/rudder-server/warehouse/integrations/testhelper" mockuploader "github.com/rudderlabs/rudder-server/warehouse/internal/mocks/utils" + "github.com/rudderlabs/rudder-server/warehouse/internal/model" whutils "github.com/rudderlabs/rudder-server/warehouse/utils" "github.com/rudderlabs/rudder-server/warehouse/validations" ) @@ -133,57 +130,53 @@ func TestIntegration(t *testing.T) { rbacCredentials, err := getSnowflakeTestCredentials(testRBACKey) require.NoError(t, err) - templateConfigurations := map[string]any{ - "workspaceID": workspaceID, - "sourceID": sourceID, - "destinationID": destinationID, - "writeKey": writeKey, - "caseSensitiveSourceID": caseSensitiveSourceID, - "caseSensitiveDestinationID": caseSensitiveDestinationID, - "caseSensitiveWriteKey": caseSensitiveWriteKey, - "rbacSourceID": rbacSourceID, - "rbacDestinationID": 
rbacDestinationID, - "rbacWriteKey": rbacWriteKey, - "sourcesSourceID": sourcesSourceID, - "sourcesDestinationID": sourcesDestinationID, - "sourcesWriteKey": sourcesWriteKey, - "account": credentials.Account, - "user": credentials.User, - "password": credentials.Password, - "role": credentials.Role, - "database": credentials.Database, - "caseSensitiveDatabase": strings.ToLower(credentials.Database), - "warehouse": credentials.Warehouse, - "bucketName": credentials.BucketName, - "accessKeyID": credentials.AccessKeyID, - "accessKey": credentials.AccessKey, - "namespace": namespace, - "sourcesNamespace": sourcesNamespace, - "caseSensitiveNamespace": caseSensitiveNamespace, - "rbacNamespace": rbacNamespace, - "rbacAccount": rbacCredentials.Account, - "rbacUser": rbacCredentials.User, - "rbacPassword": rbacCredentials.Password, - "rbacRole": rbacCredentials.Role, - "rbacDatabase": rbacCredentials.Database, - "rbacWarehouse": rbacCredentials.Warehouse, - "rbacBucketName": rbacCredentials.BucketName, - "rbacAccessKeyID": rbacCredentials.AccessKeyID, - "rbacAccessKey": rbacCredentials.AccessKey, - } - workspaceConfigPath := workspaceConfig.CreateTempFile(t, "testdata/template.json", templateConfigurations) - - bootstrap := func(t testing.TB, appendMode bool) func() { - loadTableStrategy := "MERGE" - if appendMode { - loadTableStrategy = "APPEND" + bootstrapSvc := func(t testing.TB, enableMerge bool) { + templateConfigurations := map[string]any{ + "workspaceID": workspaceID, + "sourceID": sourceID, + "destinationID": destinationID, + "writeKey": writeKey, + "caseSensitiveSourceID": caseSensitiveSourceID, + "caseSensitiveDestinationID": caseSensitiveDestinationID, + "caseSensitiveWriteKey": caseSensitiveWriteKey, + "rbacSourceID": rbacSourceID, + "rbacDestinationID": rbacDestinationID, + "rbacWriteKey": rbacWriteKey, + "sourcesSourceID": sourcesSourceID, + "sourcesDestinationID": sourcesDestinationID, + "sourcesWriteKey": sourcesWriteKey, + "account": credentials.Account, + "user": credentials.User, + "password": credentials.Password, + "role": credentials.Role, + "database": credentials.Database, + "caseSensitiveDatabase": strings.ToLower(credentials.Database), + "warehouse": credentials.Warehouse, + "bucketName": credentials.BucketName, + "accessKeyID": credentials.AccessKeyID, + "accessKey": credentials.AccessKey, + "namespace": namespace, + "sourcesNamespace": sourcesNamespace, + "caseSensitiveNamespace": caseSensitiveNamespace, + "rbacNamespace": rbacNamespace, + "rbacAccount": rbacCredentials.Account, + "rbacUser": rbacCredentials.User, + "rbacPassword": rbacCredentials.Password, + "rbacRole": rbacCredentials.Role, + "rbacDatabase": rbacCredentials.Database, + "rbacWarehouse": rbacCredentials.Warehouse, + "rbacBucketName": rbacCredentials.BucketName, + "rbacAccessKeyID": rbacCredentials.AccessKeyID, + "rbacAccessKey": rbacCredentials.AccessKey, + "enableMerge": enableMerge, } + workspaceConfigPath := workspaceConfig.CreateTempFile(t, "testdata/template.json", templateConfigurations) + testhelper.EnhanceWithDefaultEnvs(t) t.Setenv("JOBS_DB_PORT", strconv.Itoa(jobsDBPort)) t.Setenv("WAREHOUSE_JOBS_DB_PORT", strconv.Itoa(jobsDBPort)) t.Setenv("RSERVER_WAREHOUSE_SNOWFLAKE_MAX_PARALLEL_LOADS", "8") t.Setenv("RSERVER_WAREHOUSE_SNOWFLAKE_ENABLE_DELETE_BY_JOBS", "true") - t.Setenv("RSERVER_WAREHOUSE_SNOWFLAKE_LOAD_TABLE_STRATEGY", loadTableStrategy) t.Setenv("RSERVER_WAREHOUSE_WEB_PORT", strconv.Itoa(httpPort)) t.Setenv("RSERVER_BACKEND_CONFIG_CONFIG_JSONPATH", workspaceConfigPath) 
t.Setenv("RSERVER_WAREHOUSE_SNOWFLAKE_SLOW_QUERY_THRESHOLD", "0s") @@ -196,20 +189,19 @@ func TestIntegration(t *testing.T) { )) ctx, cancel := context.WithCancel(context.Background()) - svcDone := make(chan struct{}) + go func() { r := runner.New(runner.ReleaseInfo{}) _ = r.Run(ctx, []string{"snowflake-integration-test"}) - close(svcDone) }() + t.Cleanup(func() { <-svcDone }) + t.Cleanup(cancel) serviceHealthEndpoint := fmt.Sprintf("http://localhost:%d/health", httpPort) health.WaitUntilReady(ctx, t, serviceHealthEndpoint, time.Minute, 100*time.Millisecond, "serviceHealthEndpoint") - - return cancel } t.Run("Event flow", func(t *testing.T) { @@ -235,7 +227,7 @@ func TestIntegration(t *testing.T) { asyncJob bool stagingFilePrefix string emptyJobRunID bool - appendMode bool + enableMerge bool customUserID string }{ { @@ -256,6 +248,7 @@ func TestIntegration(t *testing.T) { "wh_staging_files": 34, // 32 + 2 (merge events because of ID resolution) }, stagingFilePrefix: "testdata/upload-job", + enableMerge: true, }, { name: "Upload Job with Role", @@ -275,6 +268,7 @@ func TestIntegration(t *testing.T) { "wh_staging_files": 34, // 32 + 2 (merge events because of ID resolution) }, stagingFilePrefix: "testdata/upload-job-with-role", + enableMerge: true, }, { name: "Upload Job with Case Sensitive Database", @@ -294,6 +288,7 @@ func TestIntegration(t *testing.T) { "wh_staging_files": 34, // 32 + 2 (merge events because of ID resolution) }, stagingFilePrefix: "testdata/upload-job-case-sensitive", + enableMerge: true, }, { name: "Async Job with Sources", @@ -315,6 +310,7 @@ func TestIntegration(t *testing.T) { warehouseEventsMap: testhelper.SourcesWarehouseEventsMap(), asyncJob: true, stagingFilePrefix: "testdata/sources-job", + enableMerge: true, }, { name: "Upload Job in append mode", @@ -335,7 +331,7 @@ func TestIntegration(t *testing.T) { // an empty jobRunID means that the source is not an ETL one // see Uploader.CanAppend() emptyJobRunID: true, - appendMode: true, + enableMerge: false, customUserID: testhelper.GetUserId("append_test"), }, } @@ -343,8 +339,7 @@ func TestIntegration(t *testing.T) { for _, tc := range testcase { tc := tc t.Run(tc.name, func(t *testing.T) { - cancel := bootstrap(t, tc.appendMode) - defer cancel() + bootstrapSvc(t, tc.enableMerge) urlConfig := sfdb.Config{ Account: tc.cred.Account, @@ -506,6 +501,7 @@ func TestIntegration(t *testing.T) { "syncFrequency": "30", "enableSSE": false, "useRudderStorage": false, + "enableMerge": true, }, DestinationDefinition: backendconfig.DestinationDefinitionT{ ID: "1XjvXnzw34UMAz1YOuKqL1kwzh6", @@ -628,7 +624,7 @@ func TestIntegration(t *testing.T) { loadFiles := []whutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, false, false) - sf, err := snowflake.New(config.Default, logger.NOP, stats.Default) + sf, err := snowflake.New(config.New(), logger.NOP, memstats.New()) require.NoError(t, err) err = sf.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -645,7 +641,7 @@ func TestIntegration(t *testing.T) { loadFiles := []whutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, false, false) - sf, err := snowflake.New(config.Default, logger.NOP, stats.Default) + sf, err := snowflake.New(config.New(), logger.NOP, memstats.New()) require.NoError(t, err) err = sf.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -664,17 +660,9 @@ func 
TestIntegration(t *testing.T) { uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) loadFiles := []whutils.LoadFile{{Location: uploadOutput.Location}} - mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, false, false) - - c := config.New() - c.Set("Warehouse.snowflake.debugDuplicateWorkspaceIDs", []string{workspaceID}) - c.Set("Warehouse.snowflake.debugDuplicateIntervalInDays", 1000) - c.Set("Warehouse.snowflake.debugDuplicateTables", []string{whutils.ToProviderCase( - whutils.SNOWFLAKE, - tableName, - )}) + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, true, false) - sf, err := snowflake.New(c, logger.NOP, stats.Default) + sf, err := snowflake.New(config.New(), logger.NOP, memstats.New()) require.NoError(t, err) err = sf.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -692,12 +680,14 @@ func TestIntegration(t *testing.T) { loadTableStat, err = sf.LoadTable(ctx, tableName) require.NoError(t, err) - require.Equal(t, loadTableStat.RowsInserted, int64(0)) - require.Equal(t, loadTableStat.RowsUpdated, int64(14)) + require.Equal(t, loadTableStat.RowsInserted, int64(0), + "2nd copy on the same table with the same data should not have any 'rows_loaded'") + require.Equal(t, loadTableStat.RowsUpdated, int64(0), + "2nd copy on the same table with the same data should not have any 'rows_loaded'") records := testhelper.RetrieveRecordsFromWarehouse(t, sf.DB.DB, - fmt.Sprintf(` - SELECT + fmt.Sprintf( + `SELECT id, received_at, test_bool, @@ -705,16 +695,13 @@ func TestIntegration(t *testing.T) { test_float, test_int, test_string - FROM - %q.%q - ORDER BY - id; - `, + FROM %q.%q + ORDER BY id;`, namespace, tableName, ), ) - require.Equal(t, records, testhelper.SampleTestRecords()) + require.Equal(t, testhelper.SampleTestRecords(), records) }) t.Run("with dedup use new record", func(t *testing.T) { uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) @@ -722,7 +709,7 @@ func TestIntegration(t *testing.T) { loadFiles := []whutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, false, true) - sf, err := snowflake.New(config.Default, logger.NOP, stats.Default) + sf, err := snowflake.New(config.New(), logger.NOP, memstats.New()) require.NoError(t, err) err = sf.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -769,10 +756,7 @@ func TestIntegration(t *testing.T) { loadFiles := []whutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, true, false) - c := config.New() - c.Set("Warehouse.snowflake.loadTableStrategy", "APPEND") - - sf, err := snowflake.New(c, logger.NOP, stats.Default) + sf, err := snowflake.New(config.New(), logger.NOP, memstats.New()) require.NoError(t, err) err = sf.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -829,7 +813,7 @@ func TestIntegration(t *testing.T) { }} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, false, false) - sf, err := snowflake.New(config.Default, logger.NOP, stats.Default) + sf, err := snowflake.New(config.New(), logger.NOP, memstats.New()) require.NoError(t, err) err = sf.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -852,7 +836,7 @@ func TestIntegration(t *testing.T) { loadFiles := []whutils.LoadFile{{Location: uploadOutput.Location}} 
mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, false, false) - sf, err := snowflake.New(config.Default, logger.NOP, stats.Default) + sf, err := snowflake.New(config.New(), logger.NOP, memstats.New()) require.NoError(t, err) err = sf.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -897,7 +881,7 @@ func TestIntegration(t *testing.T) { loadFiles := []whutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, false, false) - sf, err := snowflake.New(config.Default, logger.NOP, stats.Default) + sf, err := snowflake.New(config.New(), logger.NOP, memstats.New()) require.NoError(t, err) err = sf.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -924,7 +908,7 @@ func TestIntegration(t *testing.T) { loadFiles := []whutils.LoadFile{{Location: uploadOutput.Location}} mockUploader := newMockUploader(t, loadFiles, tableName, discardsSchema, discardsSchema, false, false) - sf, err := snowflake.New(config.Default, logger.NOP, stats.Default) + sf, err := snowflake.New(config.New(), logger.NOP, memstats.New()) require.NoError(t, err) err = sf.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) @@ -962,58 +946,63 @@ func TestIntegration(t *testing.T) { }) } -func TestSnowflake_ShouldAppend(t *testing.T) { +func TestSnowflake_ShouldMerge(t *testing.T) { testCases := []struct { name string - loadTableStrategy string + enableMerge bool uploaderCanAppend bool uploaderExpectedCalls int expected bool }{ { - name: "uploader says we can append and we are in append mode", - loadTableStrategy: "APPEND", + name: "uploader says we can append and merge is not enabled", + enableMerge: false, uploaderCanAppend: true, uploaderExpectedCalls: 1, - expected: true, + expected: false, }, { - name: "uploader says we cannot append and we are in append mode", - loadTableStrategy: "APPEND", + name: "uploader says we cannot append and merge is not enabled", + enableMerge: false, uploaderCanAppend: false, uploaderExpectedCalls: 1, - expected: false, + expected: true, }, { - name: "uploader says we can append and we are in merge mode", - loadTableStrategy: "MERGE", + name: "uploader says we can append and merge is enabled", + enableMerge: true, uploaderCanAppend: true, - uploaderExpectedCalls: 0, - expected: false, + uploaderExpectedCalls: 1, + expected: true, }, { name: "uploader says we cannot append and we are in merge mode", - loadTableStrategy: "MERGE", + enableMerge: true, uploaderCanAppend: false, - uploaderExpectedCalls: 0, - expected: false, + uploaderExpectedCalls: 1, + expected: true, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - c := config.New() - c.Set("Warehouse.snowflake.loadTableStrategy", tc.loadTableStrategy) - - sf, err := snowflake.New(c, logger.NOP, stats.Default) + sf, err := snowflake.New(config.New(), logger.NOP, memstats.New()) require.NoError(t, err) + sf.Warehouse = model.Warehouse{ + Destination: backendconfig.DestinationT{ + Config: map[string]any{ + string(model.EnableMergeSetting): tc.enableMerge, + }, + }, + } + mockCtrl := gomock.NewController(t) uploader := mockuploader.NewMockUploader(mockCtrl) uploader.EXPECT().CanAppend().Times(tc.uploaderExpectedCalls).Return(tc.uploaderCanAppend) sf.Uploader = uploader - require.Equal(t, sf.ShouldAppend(), tc.expected) + require.Equal(t, sf.ShouldMerge(), tc.expected) }) } } diff --git a/warehouse/integrations/snowflake/testdata/template.json 
b/warehouse/integrations/snowflake/testdata/template.json index 2b7aab4b967..73d113e7647 100644 --- a/warehouse/integrations/snowflake/testdata/template.json +++ b/warehouse/integrations/snowflake/testdata/template.json @@ -39,7 +39,8 @@ "prefix": "snowflake-prefix", "syncFrequency": "30", "enableSSE": false, - "useRudderStorage": false + "useRudderStorage": false, + "enableMerge": {{.enableMerge}} }, "liveEventsConfig": {}, "secretConfig": {}, @@ -163,7 +164,8 @@ "prefix": "snowflake-prefix", "syncFrequency": "30", "enableSSE": false, - "useRudderStorage": false + "useRudderStorage": false, + "enableMerge": {{.enableMerge}} }, "liveEventsConfig": {}, "secretConfig": {}, @@ -288,7 +290,8 @@ "prefix": "snowflake-prefix", "syncFrequency": "30", "enableSSE": false, - "useRudderStorage": false + "useRudderStorage": false, + "enableMerge": {{.enableMerge}} }, "liveEventsConfig": {}, "secretConfig": {}, @@ -438,7 +441,8 @@ "prefix": "snowflake-prefix", "syncFrequency": "30", "enableSSE": false, - "useRudderStorage": false + "useRudderStorage": false, + "enableMerge": {{.enableMerge}} }, "liveEventsConfig": {}, "secretConfig": {}, diff --git a/warehouse/tunnelling/connect.go b/warehouse/integrations/tunnelling/connect.go similarity index 54% rename from warehouse/tunnelling/connect.go rename to warehouse/integrations/tunnelling/connect.go index c91118e7ba5..b938904dee9 100644 --- a/warehouse/tunnelling/connect.go +++ b/warehouse/integrations/tunnelling/connect.go @@ -6,6 +6,8 @@ import ( "fmt" "strconv" + whutils "github.com/rudderlabs/rudder-server/warehouse/utils" + stunnel "github.com/rudderlabs/sql-tunnels/driver/ssh" ) @@ -22,30 +24,56 @@ const ( ) type ( - Type string - Config map[string]interface{} + Config map[string]interface{} + TunnelInfo struct { + Config Config + } ) -type TunnelInfo struct { - Config Config +// ExtractTunnelInfoFromDestinationConfig extracts TunnelInfo from destination config if tunnel is enabled for the destination. +func ExtractTunnelInfoFromDestinationConfig(config Config) *TunnelInfo { + if tunnelEnabled := whutils.ReadAsBool("useSSH", config); !tunnelEnabled { + return nil + } + + return &TunnelInfo{ + Config: config, + } } -func ReadSSHTunnelConfig(config Config) (conf *stunnel.Config, err error) { +// Connect establishes a database connection over an SSH tunnel. 
+func Connect(dsn string, config Config) (*sql.DB, error) { + tunnelConfig, err := extractTunnelConfig(config) + if err != nil { + return nil, fmt.Errorf("reading ssh tunnel config: %w", err) + } + + encodedDSN, err := tunnelConfig.EncodeWithDSN(dsn) + if err != nil { + return nil, fmt.Errorf("encoding with dsn: %w", err) + } + + db, err := sql.Open("sql+ssh", encodedDSN) + if err != nil { + return nil, fmt.Errorf("opening warehouse connection sql+ssh driver: %w", err) + } + return db, nil +} + +func extractTunnelConfig(config Config) (*stunnel.Config, error) { var user, host, port, privateKey *string + var err error - if user, err = ReadString(sshUser, config); err != nil { + if user, err = readString(sshUser, config); err != nil { return nil, err } - - if host, err = ReadString(sshHost, config); err != nil { + if host, err = readString(sshHost, config); err != nil { return nil, err } - - if port, err = ReadString(sshPort, config); err != nil { + if port, err = readString(sshPort, config); err != nil { return nil, err } - - if privateKey, err = ReadString(sshPrivateKey, config); err != nil { + if privateKey, err = readString(sshPrivateKey, config); err != nil { return nil, err } @@ -62,7 +90,7 @@ func ReadSSHTunnelConfig(config Config) (conf *stunnel.Config, err error) { }, nil } -func ReadString(key string, config Config) (*string, error) { +func readString(key string, config Config) (*string, error) { val, ok := config[key] if !ok { return nil, fmt.Errorf("%w: %s", ErrMissingKey, key) @@ -72,22 +100,5 @@ func ReadString(key string, config Config) (*string, error) { if !ok { return nil, fmt.Errorf("%w: %s expected string", ErrUnexpectedType, key) } - return &resp, nil } - -func SQLConnectThroughTunnel(dsn string, tunnelConfig Config) (*sql.DB, error) { - conf, err := ReadSSHTunnelConfig(tunnelConfig) - if err != nil { - return nil, fmt.Errorf("reading ssh tunnel config: %w", err) - } - encodedDSN, err := conf.EncodeWithDSN(dsn) - if err != nil { - return nil, fmt.Errorf("encoding with dsn: %w", err) - } - db, err := sql.Open("sql+ssh", encodedDSN) - if err != nil { - return nil, fmt.Errorf("opening warehouse connection sql+ssh driver: %w", err) - } - return db, nil -} diff --git a/warehouse/integrations/tunnelling/connect_test.go b/warehouse/integrations/tunnelling/connect_test.go new file mode 100644 index 00000000000..3104d6aa122 --- /dev/null +++ b/warehouse/integrations/tunnelling/connect_test.go @@ -0,0 +1,140 @@ +package tunnelling + +import ( + "context" + "errors" + "fmt" + "os" + "testing" + + "github.com/rudderlabs/compose-test/compose" + "github.com/rudderlabs/compose-test/testcompose" + + "github.com/stretchr/testify/require" +) + +func TestConnect(t *testing.T) { + privateKey, err := os.ReadFile("testdata/test_key") + require.Nil(t, err) + + ctx := context.Background() + + c := testcompose.New(t, compose.FilePaths{"./testdata/docker-compose.yml"}) + c.Start(context.Background()) + + host := "0.0.0.0" + user := c.Env("openssh-server", "USER_NAME") + port := c.Port("openssh-server", 2222) + postgresPort := c.Port("postgres", 5432) + + testCases := []struct { + name string + dsn string + config Config + wantError error + }{ + { + name: "empty config", + dsn: "dsn", + config: Config{}, + wantError: ErrMissingKey, + }, + { + name: "invalid config", + dsn: "dsn", + config: Config{ + sshUser: "user", + sshHost: "host", + sshPort: 22, + sshPrivateKey: "privateKey", + }, + wantError: errors.New("invalid type"), + }, + { + name: "missing sshUser", + dsn: "dsn", + config: Config{ + 
sshHost: "host", + sshPort: "port", + sshPrivateKey: "privateKey", + }, + wantError: ErrMissingKey, + }, + { + name: "missing sshHost", + dsn: "dsn", + config: Config{ + sshUser: "user", + sshPort: "port", + sshPrivateKey: "privateKey", + }, + wantError: ErrMissingKey, + }, + { + name: "missing sshPort", + dsn: "dsn", + config: Config{ + sshUser: "user", + sshHost: "host", + sshPrivateKey: "privateKey", + }, + wantError: ErrMissingKey, + }, + { + name: "missing sshPrivateKey", + dsn: "dsn", + config: Config{ + sshUser: "user", + sshHost: "host", + sshPort: "port", + }, + wantError: ErrMissingKey, + }, + { + name: "invalid sshPort", + dsn: "dsn", + config: Config{ + sshUser: "user", + sshHost: "host", + sshPort: "port", + sshPrivateKey: "privateKey", + }, + wantError: errors.New("invalid port"), + }, + { + name: "invalid dsn", + dsn: "postgres://user:password@host:5439/db?query1=val1&query2=val2", + config: Config{ + sshUser: "user", + sshHost: "0.0.0.0", + sshPort: "22", + sshPrivateKey: "privateKey", + }, + wantError: errors.New("invalid dsn"), + }, + { + name: "valid dsn", + dsn: fmt.Sprintf("postgres://postgres:postgres@db_postgres:%d/postgres?sslmode=disable", postgresPort), + config: Config{ + sshUser: user, + sshHost: host, + sshPort: port, + sshPrivateKey: privateKey, + }, + wantError: errors.New("invalid dsn"), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + db, err := Connect(tc.dsn, tc.config) + t.Log(err) + if tc.wantError != nil { + require.Error(t, err, tc.wantError) + return + } + require.NoError(t, err) + require.NoError(t, db.PingContext(ctx)) + }) + } +} diff --git a/warehouse/integrations/tunnelling/testdata/docker-compose.yml b/warehouse/integrations/tunnelling/testdata/docker-compose.yml new file mode 100644 index 00000000000..5da14098f6c --- /dev/null +++ b/warehouse/integrations/tunnelling/testdata/docker-compose.yml @@ -0,0 +1,37 @@ +version: "3.9" +services: + openssh-server: + image: lscr.io/linuxserver/openssh-server:latest + environment: + - PUBLIC_KEY_FILE=/test_key.pub + - SUDO_ACCESS=false + - PASSWORD_ACCESS=false + - USER_PASSWORD=password + - USER_NAME=linuxserver.io + - DOCKER_MODS=linuxserver/mods:openssh-server-ssh-tunnel + ports: + - 2222 + volumes: + - type: bind + source: ./test_key.pub + target: /test_key.pub + read_only: true + healthcheck: + test: [ "CMD", "grep", "Server listening on :: port 2222", "/config/logs/openssh/current" ] + interval: 1s + timeout: 1s + retries: 60 + depends_on: + - postgres + + postgres: + image: postgres:15-alpine + environment: + - POSTGRES_USER=postgres + - POSTGRES_PASSWORD=postgres + ports: + - "5432" + healthcheck: + test: [ "CMD-SHELL", "pg_isready" ] + interval: 1s + retries: 25 diff --git a/warehouse/integrations/tunnelling/testdata/test_key b/warehouse/integrations/tunnelling/testdata/test_key new file mode 100644 index 00000000000..c095f906c56 --- /dev/null +++ b/warehouse/integrations/tunnelling/testdata/test_key @@ -0,0 +1,39 @@ +-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABlwAAAAdzc2gtcn +NhAAAAAwEAAQAAAYEA0f/mqkkZ3c9qw8MTz5FoEO3PGecO/dtUFfJ4g1UBu9E7hi/pyVYY +fLfdsd5bqA2pXdU0ROymyVe683I1VzJcihUtwB1eQxP1mUhmoo0ixK0IUUGm4PRieCGv+r +0/gMvaYbVGUPCi5tAUVh02vZB7p2cTIaz872lvCnRhYbhGUHSbhNSSQOjnCtZfjuZZnE0l +PKjWV/wbJ7Pvoc/FZMlWOqL1AjAKuwFH5zs1RMrPDDv5PCZksq4a7DDxziEdq39jvA3sOm +pQXvzBBBLBOzu7rM3/MPJb6dvAGJcYxkptfL4YXTscIMINr0g24cn+Thvt9yqA93rkb9RB +kw6RIEwMlQKqserA+pfsaoW0SkvnlDKzS1DLwXioL4Uc1Jpr/9jTMEfR+W7v7gJPB1JDnV 
+gen5FBfiMqbsG1amUS+mjgNfC8I00tR+CUHxpqUWANtcWTinhSnLJ2skj/2QnciPHkHurR +EKyEwCVecgn+xVKyRgVDCGsJ+QnAdn51+i/kO3nvAAAFqENNbN9DTWzfAAAAB3NzaC1yc2 +EAAAGBANH/5qpJGd3PasPDE8+RaBDtzxnnDv3bVBXyeINVAbvRO4Yv6clWGHy33bHeW6gN +qV3VNETspslXuvNyNVcyXIoVLcAdXkMT9ZlIZqKNIsStCFFBpuD0Ynghr/q9P4DL2mG1Rl +DwoubQFFYdNr2Qe6dnEyGs/O9pbwp0YWG4RlB0m4TUkkDo5wrWX47mWZxNJTyo1lf8Gyez +76HPxWTJVjqi9QIwCrsBR+c7NUTKzww7+TwmZLKuGuww8c4hHat/Y7wN7DpqUF78wQQSwT +s7u6zN/zDyW+nbwBiXGMZKbXy+GF07HCDCDa9INuHJ/k4b7fcqgPd65G/UQZMOkSBMDJUC +qrHqwPqX7GqFtEpL55Qys0tQy8F4qC+FHNSaa//Y0zBH0flu7+4CTwdSQ51YHp+RQX4jKm +7BtWplEvpo4DXwvCNNLUfglB8aalFgDbXFk4p4UpyydrJI/9kJ3Ijx5B7q0RCshMAlXnIJ +/sVSskYFQwhrCfkJwHZ+dfov5Dt57wAAAAMBAAEAAAGAd9pxr+ag2LO0353LBMCcgGz5sn +LpX4F6cDw/A9XUc3lrW56k88AroaLe6NFbxoJlk6RHfL8EQg3MKX2Za/bWUgjcX7VjQy11 +EtL7oPKkUVPgV1/8+o8AVEgFxDmWsM+oB/QJ+dAdaVaBBNUPlQmNSXHOvX2ZrpqiQXlCyx +79IpYq3JjmEB3dH5ZSW6CkrExrYD+MdhLw/Kv5rISEyI0Qpc6zv1fkB+8nNpXYRTbrDLR9 +/xJ6jnBH9V3J5DeKU4MUQ39nrAp6iviyWydB973+MOygpy41fXO6hHyVZ2aSCysn1t6J/K +QdeEjqAOI/5CbdtiFGp06et799EFyzPItW0FKetW1UTOL2YHqdb+Q9sNjiNlUSzgxMbJWJ +RGO6g9B1mJsHl5mJZUiHQPsG/wgBER8VOP4bLOEB6gzVO2GE9HTJTOh5C+eEfrl52wPfXj +TqjtWAnhssxtgmWjkS0ibi+u1KMVXKHfaiqJ7nH0jMx+eu1RpMvuR8JqkU8qdMMGChAAAA +wHkQMfpCnjNAo6sllEB5FwjEdTBBOt7gu6nLQ2O3uGv0KNEEZ/BWJLQ5fKOfBtDHO+kl+5 +Qoxc0cE7cg64CyBF3+VjzrEzuX5Tuh4NwrsjT4vTTHhCIbIynxEPmKzvIyCMuglqd/nhu9 +6CXhghuTg8NrC7lY+cImiBfhxE32zqNITlpHW7exr95Gz1sML2TRJqxDN93oUFfrEuInx8 +HpXXnvMQxPRhcp9nDMU9/ahUamMabQqVVMwKDi8n3sPPzTiAAAAMEA+/hm3X/yNotAtMAH +y11parKQwPgEF4HYkSE0bEe+2MPJmEk4M4PGmmt/MQC5N5dXdUGxiQeVMR+Sw0kN9qZjM6 +SIz0YHQFMsxVmUMKFpAh4UI0GlsW49jSpVXs34Fg95AfhZOYZmOcGcYosp0huCeRlpLeIH +7Vv2bkfQaic3uNaVPg7+cXg7zdY6tZlzwa/4Fj0udfTjGQJOPSzIihdMLHnV81rZ2cUOZq +MSk6b02aMpVB4TV0l1w4j2mlF2eGD9AAAAwQDVW6p2VXKuPR7SgGGQgHXpAQCFZPGLYd8K +duRaCbxKJXzUnZBn53OX5fuLlFhmRmAMXE6ztHPN1/5JjwILn+O49qel1uUvzU8TaWioq7 +Are3SJR2ZucR4AKUvzUHGP3GWW96xPN8lq+rgb0th1eOSU2aVkaIdeTJhV1iPfaUUf+15S +YcJlSHLGgeqkok+VfuudZ73f3RFFhjoe1oAjlPB4leeMsBD9UBLx2U3xAevnfkecF4Lm83 +4sVswWATSFAFsAAAAsYWJoaW1hbnl1YmFiYmFyQEFiaGltYW55dXMtTWFjQm9vay1Qcm8u +bG9jYWwBAgMEBQYH +-----END OPENSSH PRIVATE KEY----- diff --git a/warehouse/integrations/tunnelling/testdata/test_key.pub b/warehouse/integrations/tunnelling/testdata/test_key.pub new file mode 100644 index 00000000000..63413c96a00 --- /dev/null +++ b/warehouse/integrations/tunnelling/testdata/test_key.pub @@ -0,0 +1 @@ +ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQDR/+aqSRndz2rDwxPPkWgQ7c8Z5w7921QV8niDVQG70TuGL+nJVhh8t92x3luoDald1TRE7KbJV7rzcjVXMlyKFS3AHV5DE/WZSGaijSLErQhRQabg9GJ4Ia/6vT+Ay9phtUZQ8KLm0BRWHTa9kHunZxMhrPzvaW8KdGFhuEZQdJuE1JJA6OcK1l+O5lmcTSU8qNZX/Bsns++hz8VkyVY6ovUCMAq7AUfnOzVEys8MO/k8JmSyrhrsMPHOIR2rf2O8Dew6alBe/MEEEsE7O7uszf8w8lvp28AYlxjGSm18vhhdOxwgwg2vSDbhyf5OG+33KoD3euRv1EGTDpEgTAyVAqqx6sD6l+xqhbRKS+eUMrNLUMvBeKgvhRzUmmv/2NMwR9H5bu/uAk8HUkOdWB6fkUF+IypuwbVqZRL6aOA18LwjTS1H4JQfGmpRYA21xZOKeFKcsnaySP/ZCdyI8eQe6tEQrITAJV5yCf7FUrJGBUMIawn5CcB2fnX6L+Q7ee8= abc@abc-MacBook-Pro.local diff --git a/warehouse/internal/api/http_test.go b/warehouse/internal/api/http_test.go index 352347baa78..fe035a766f7 100644 --- a/warehouse/internal/api/http_test.go +++ b/warehouse/internal/api/http_test.go @@ -13,13 +13,14 @@ import ( "testing" "time" - "github.com/rudderlabs/rudder-go-kit/config" - backendConfig "github.com/rudderlabs/rudder-server/backend-config" + "github.com/rudderlabs/rudder-go-kit/stats/memstats" + backendconfig "github.com/rudderlabs/rudder-server/backend-config" "github.com/stretchr/testify/require" + 
"github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/logger" - "github.com/rudderlabs/rudder-go-kit/stats" "github.com/rudderlabs/rudder-server/warehouse/internal/api" "github.com/rudderlabs/rudder-server/warehouse/internal/model" "github.com/rudderlabs/rudder-server/warehouse/multitenant" @@ -172,12 +173,12 @@ func TestAPI_Process(t *testing.T) { c := config.New() c.Set("Warehouse.degradedWorkspaceIDs", tc.degradedWorkspaceIDs) - m := multitenant.New(c, backendConfig.DefaultBackendConfig) + m := multitenant.New(c, backendconfig.DefaultBackendConfig) wAPI := api.WarehouseAPI{ Repo: r, Logger: logger.NOP, - Stats: stats.Default, // TODO: use a NOP stats + Stats: memstats.New(), // TODO: use a NOP stats Multitenant: m, } diff --git a/warehouse/internal/loadfiles/loadfiles.go b/warehouse/internal/loadfiles/loadfiles.go index ed88d4defcf..1fd303289c5 100644 --- a/warehouse/internal/loadfiles/loadfiles.go +++ b/warehouse/internal/loadfiles/loadfiles.go @@ -3,6 +3,7 @@ package loadfiles import ( "context" "fmt" + "slices" "strings" "time" @@ -12,7 +13,6 @@ import ( "github.com/samber/lo" - "golang.org/x/exp/slices" "golang.org/x/sync/errgroup" jsoniter "github.com/json-iterator/go" diff --git a/warehouse/internal/loadfiles/mock_loadfile_repo_test.go b/warehouse/internal/loadfiles/mock_loadfile_repo_test.go index 199dd0a0e32..c50c9412119 100644 --- a/warehouse/internal/loadfiles/mock_loadfile_repo_test.go +++ b/warehouse/internal/loadfiles/mock_loadfile_repo_test.go @@ -2,8 +2,7 @@ package loadfiles_test import ( "context" - - "golang.org/x/exp/slices" + "slices" "github.com/rudderlabs/rudder-server/warehouse/internal/model" ) diff --git a/warehouse/internal/model/warehouse.go b/warehouse/internal/model/warehouse.go index 4608ffb241f..f398960e5e7 100644 --- a/warehouse/internal/model/warehouse.go +++ b/warehouse/internal/model/warehouse.go @@ -2,6 +2,17 @@ package model import backendconfig "github.com/rudderlabs/rudder-server/backend-config" +type DestinationConfigSetting interface{ string() string } + +type destConfSetting string + +func (s destConfSetting) string() string { return string(s) } + +const ( + EnableMergeSetting destConfSetting = "enableMerge" + UseRudderStorageSetting destConfSetting = "useRudderStorage" +) + type Warehouse struct { WorkspaceID string Source backendconfig.SourceT @@ -11,10 +22,10 @@ type Warehouse struct { Identifier string } -func (w *Warehouse) GetBoolDestinationConfig(key string) bool { +func (w *Warehouse) GetBoolDestinationConfig(key DestinationConfigSetting) bool { destConfig := w.Destination.Config - if destConfig[key] != nil { - if val, ok := destConfig[key].(bool); ok { + if destConfig[key.string()] != nil { + if val, ok := destConfig[key.string()].(bool); ok { return val } } diff --git a/warehouse/internal/repo/schema_test.go b/warehouse/internal/repo/schema_test.go index b893fa6b41c..c16ad8daaf5 100644 --- a/warehouse/internal/repo/schema_test.go +++ b/warehouse/internal/repo/schema_test.go @@ -3,11 +3,10 @@ package repo_test import ( "context" "errors" + "slices" "testing" "time" - "golang.org/x/exp/slices" - warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" "github.com/stretchr/testify/require" diff --git a/warehouse/internal/service/recovery.go b/warehouse/internal/service/recovery.go index 502be55fa61..8d55525532f 100644 --- a/warehouse/internal/service/recovery.go +++ b/warehouse/internal/service/recovery.go @@ -3,12 +3,11 @@ package service import ( "context" "fmt" + "slices" "sync" 
"github.com/rudderlabs/rudder-server/warehouse/internal/model" - "golang.org/x/exp/slices" - warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" ) diff --git a/warehouse/jobs/http_test.go b/warehouse/jobs/http_test.go index 47b8c7d1f2b..9f5c3e3750a 100644 --- a/warehouse/jobs/http_test.go +++ b/warehouse/jobs/http_test.go @@ -13,8 +13,9 @@ import ( "testing" "time" + "github.com/rudderlabs/rudder-go-kit/stats/memstats" + "github.com/rudderlabs/rudder-go-kit/config" - "github.com/rudderlabs/rudder-go-kit/stats" "github.com/rudderlabs/rudder-server/services/notifier" "github.com/ory/dockertest/v3" @@ -61,7 +62,7 @@ func TestAsyncJobHandlers(t *testing.T) { ctx := context.Background() - n := notifier.New(config.Default, logger.NOP, stats.Default, workspaceIdentifier) + n := notifier.New(config.New(), logger.NOP, memstats.New(), workspaceIdentifier) err = n.Setup(ctx, pgResource.DBDsn) require.NoError(t, err) diff --git a/warehouse/logfield/logfield.go b/warehouse/logfield/logfield.go index 5c084dc4f0d..a50cc33b448 100644 --- a/warehouse/logfield/logfield.go +++ b/warehouse/logfield/logfield.go @@ -36,4 +36,5 @@ const ( IntervalInHours = "intervalInHours" StartTime = "startTime" EndTime = "endTime" + ShouldMerge = "shouldMerge" ) diff --git a/warehouse/multitenant/manager_test.go b/warehouse/multitenant/manager_test.go index 6dfc6697a39..85851095ba4 100644 --- a/warehouse/multitenant/manager_test.go +++ b/warehouse/multitenant/manager_test.go @@ -150,7 +150,7 @@ func TestSourceToWorkspace(t *testing.T) { backendConfig[workspace] = entry } - m := multitenant.New(config.Default, &mockBackendConfig{ + m := multitenant.New(config.New(), &mockBackendConfig{ config: backendConfig, }) @@ -177,7 +177,7 @@ func TestSourceToWorkspace(t *testing.T) { require.NoError(t, g.Wait()) t.Run("context canceled", func(t *testing.T) { - m := multitenant.New(config.Default, &mockBackendConfig{ + m := multitenant.New(config.New(), &mockBackendConfig{ config: backendConfig, }) ctx, cancel := context.WithCancel(context.Background()) diff --git a/warehouse/router/errors_test.go b/warehouse/router/errors_test.go index 1d7fe33e6bc..099345c77a8 100644 --- a/warehouse/router/errors_test.go +++ b/warehouse/router/errors_test.go @@ -4,14 +4,14 @@ import ( "errors" "testing" + "github.com/rudderlabs/rudder-go-kit/stats/memstats" + "github.com/rudderlabs/rudder-server/warehouse/router" warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/logger" - "github.com/rudderlabs/rudder-go-kit/stats" - "github.com/rudderlabs/rudder-server/warehouse/internal/model" "github.com/stretchr/testify/require" @@ -179,7 +179,7 @@ func TestErrorHandler_MatchUploadJobErrorType(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - m, err := manager.New(tc.destType, config.Default, logger.NOP, stats.Default) + m, err := manager.New(tc.destType, config.New(), logger.NOP, memstats.New()) require.NoError(t, err) er := &router.ErrorHandler{Mapper: m} @@ -191,7 +191,7 @@ func TestErrorHandler_MatchUploadJobErrorType(t *testing.T) { }) t.Run("UnKnown errors", func(t *testing.T) { - m, err := manager.New(warehouseutils.RS, config.Default, logger.NOP, stats.Default) + m, err := manager.New(warehouseutils.RS, config.New(), logger.NOP, memstats.New()) require.NoError(t, err) er := &router.ErrorHandler{Mapper: m} @@ -206,7 +206,7 @@ func TestErrorHandler_MatchUploadJobErrorType(t *testing.T) { }) t.Run("Nil error: ", 
func(t *testing.T) { - m, err := manager.New(warehouseutils.RS, config.Default, logger.NOP, stats.Default) + m, err := manager.New(warehouseutils.RS, config.New(), logger.NOP, memstats.New()) require.NoError(t, err) er := &router.ErrorHandler{Mapper: m} diff --git a/warehouse/router/router.go b/warehouse/router/router.go index 0f258e3eaf9..1e209538c48 100644 --- a/warehouse/router/router.go +++ b/warehouse/router/router.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "math/rand" + "slices" "strconv" "sync" "sync/atomic" @@ -20,8 +21,6 @@ import ( "github.com/rudderlabs/rudder-server/warehouse/integrations/manager" - "golang.org/x/exp/slices" - "github.com/rudderlabs/rudder-server/services/controlplane" "github.com/rudderlabs/rudder-server/warehouse/multitenant" @@ -58,7 +57,9 @@ type Router struct { stagingRepo *repo.StagingFiles uploadRepo *repo.Uploads whSchemaRepo *repo.WHSchema - triggerStore *sync.Map + + triggerStore *sync.Map + createUploadAlways createUploadAlwaysLoader isEnabled atomic.Bool @@ -79,6 +80,9 @@ type Router struct { inProgressMap map[workerIdentifierMapKey][]jobID inProgressMapLock sync.RWMutex + scheduledTimesCache map[string][]int + scheduledTimesCacheLock sync.RWMutex + activeWorkerCount atomic.Int32 now func() time.Time nowSQL string @@ -134,6 +138,7 @@ func New( bcManager *bcm.BackendConfigManager, encodingFactory *encoding.Factory, triggerStore *sync.Map, + createUploadAlways createUploadAlwaysLoader, ) (*Router, error) { r := &Router{} @@ -155,6 +160,8 @@ func New( r.now = time.Now r.triggerStore = triggerStore r.createJobMarkerMap = make(map[string]time.Time) + r.createUploadAlways = createUploadAlways + r.scheduledTimesCache = make(map[string][]int) if err := r.uploadRepo.ResetInProgress(ctx, r.destType); err != nil { return nil, err @@ -473,7 +480,7 @@ func (r *Router) uploadsToProcess(ctx context.Context, availableWorkers int, ski }) r.configSubscriberLock.RUnlock() - upload.UseRudderStorage = warehouse.GetBoolDestinationConfig("useRudderStorage") + upload.UseRudderStorage = warehouse.GetBoolDestinationConfig(model.UseRudderStorageSetting) if !found { uploadJob := r.uploadJobFactory.NewUploadJob(ctx, &model.UploadJob{ diff --git a/warehouse/router/router_test.go b/warehouse/router/router_test.go index 4332779e75c..570e202bf9d 100644 --- a/warehouse/router/router_test.go +++ b/warehouse/router/router_test.go @@ -7,6 +7,7 @@ import ( "net/http/httptest" "os" "sync" + "sync/atomic" "testing" "time" @@ -91,14 +92,15 @@ func TestRouter(t *testing.T) { ctx := context.Background() - n := notifier.New(config.Default, logger.NOP, stats.Default, workspaceIdentifier) + n := notifier.New(config.New(), logger.NOP, memstats.New(), workspaceIdentifier) err = n.Setup(ctx, pgResource.DBDsn) require.NoError(t, err) ctrl := gomock.NewController(t) + createUploadAlways := &atomic.Bool{} triggerStore := &sync.Map{} - tenantManager := multitenant.New(config.Default, mocksBackendConfig.NewMockBackendConfig(ctrl)) + tenantManager := multitenant.New(config.New(), mocksBackendConfig.NewMockBackendConfig(ctrl)) s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) @@ -108,20 +110,20 @@ func TestRouter(t *testing.T) { cp := controlplane.NewClient(s.URL, &identity.Namespace{}, controlplane.WithHTTPClient(s.Client()), ) - backendConfigManager := bcm.New(config.Default, db, tenantManager, logger.NOP, stats.Default) + backendConfigManager := bcm.New(config.New(), db, tenantManager, logger.NOP, memstats.New()) ctx, cancel := 
context.WithCancel(ctx) defer cancel() - ef := encoding.NewFactory(config.Default) + ef := encoding.NewFactory(config.New()) r, err := New( ctx, &reporting.NOOP{}, destinationType, - config.Default, + config.New(), logger.NOP, - stats.Default, + memstats.New(), db, n, tenantManager, @@ -129,6 +131,7 @@ func TestRouter(t *testing.T) { backendConfigManager, ef, triggerStore, + createUploadAlways, ) require.NoError(t, err) @@ -186,7 +189,7 @@ func TestRouter(t *testing.T) { r.uploadRepo = repoUpload r.stagingRepo = repoStaging r.statsFactory = memstats.New() - r.conf = config.Default + r.conf = config.New() r.config.uploadFreqInS = misc.SingleValueLoader(int64(1800)) r.config.stagingFilesBatchSize = misc.SingleValueLoader(100) r.config.warehouseSyncFreqIgnore = misc.SingleValueLoader(true) @@ -196,6 +199,8 @@ func TestRouter(t *testing.T) { r.triggerStore = &sync.Map{} r.inProgressMap = make(map[workerIdentifierMapKey][]jobID) r.createJobMarkerMap = make(map[string]time.Time) + r.createUploadAlways = &atomic.Bool{} + r.scheduledTimesCache = make(map[string][]int) t.Run("no staging files", func(t *testing.T) { err = r.createJobs(ctx, warehouse) @@ -342,7 +347,7 @@ func TestRouter(t *testing.T) { r.uploadRepo = repoUpload r.stagingRepo = repoStaging r.statsFactory = memstats.New() - r.conf = config.Default + r.conf = config.New() r.config.stagingFilesBatchSize = misc.SingleValueLoader(100) r.config.warehouseSyncFreqIgnore = misc.SingleValueLoader(true) r.config.enableJitterForSyncs = misc.SingleValueLoader(true) @@ -350,6 +355,8 @@ func TestRouter(t *testing.T) { r.inProgressMap = make(map[workerIdentifierMapKey][]jobID) r.triggerStore = &sync.Map{} r.logger = logger.NOP + r.createUploadAlways = &atomic.Bool{} + r.scheduledTimesCache = make(map[string][]int) priority := 50 @@ -475,7 +482,7 @@ func TestRouter(t *testing.T) { r.statsFactory = statsStore r.uploadRepo = repoUpload r.stagingRepo = repoStaging - r.conf = config.Default + r.conf = config.New() r.config.uploadFreqInS = misc.SingleValueLoader(int64(1800)) r.config.stagingFilesBatchSize = misc.SingleValueLoader(100) r.config.warehouseSyncFreqIgnore = misc.SingleValueLoader(true) @@ -497,6 +504,8 @@ func TestRouter(t *testing.T) { }) r.createJobMarkerMap = make(map[string]time.Time) r.triggerStore = &sync.Map{} + r.createUploadAlways = &atomic.Bool{} + r.scheduledTimesCache = make(map[string][]int) r.Enable() stagingFiles := createStagingFiles(t, ctx, repoStaging, workspaceID, sourceID, destinationID) @@ -581,17 +590,17 @@ func TestRouter(t *testing.T) { r.uploadRepo = repoUpload r.stagingRepo = repoStaging r.statsFactory = memstats.New() - r.conf = config.Default + r.conf = config.New() r.config.allowMultipleSourcesForJobsPickup = true r.config.stagingFilesBatchSize = misc.SingleValueLoader(100) r.config.warehouseSyncFreqIgnore = misc.SingleValueLoader(true) r.destType = destinationType r.logger = logger.NOP - r.tenantManager = multitenant.New(config.Default, mocksBackendConfig.NewMockBackendConfig(ctrl)) + r.tenantManager = multitenant.New(config.New(), mocksBackendConfig.NewMockBackendConfig(ctrl)) r.warehouses = []model.Warehouse{warehouse} r.uploadJobFactory = UploadJobFactory{ reporting: &reporting.NOOP{}, - conf: config.Default, + conf: config.New(), logger: logger.NOP, statsFactory: r.statsFactory, db: r.db, @@ -610,6 +619,8 @@ func TestRouter(t *testing.T) { }) r.createJobMarkerMap = make(map[string]time.Time) r.triggerStore = &sync.Map{} + r.createUploadAlways = &atomic.Bool{} + r.scheduledTimesCache = make(map[string][]int) 
t.Run("no uploads", func(t *testing.T) { ujs, err := r.uploadsToProcess(ctx, 1, []string{}) @@ -714,7 +725,7 @@ func TestRouter(t *testing.T) { r.uploadRepo = repoUpload r.stagingRepo = repoStaging r.statsFactory = memstats.New() - r.conf = config.Default + r.conf = config.New() r.config.allowMultipleSourcesForJobsPickup = true r.config.stagingFilesBatchSize = misc.SingleValueLoader(100) r.config.warehouseSyncFreqIgnore = misc.SingleValueLoader(true) @@ -723,12 +734,12 @@ func TestRouter(t *testing.T) { r.config.uploadAllocatorSleep = time.Millisecond * 100 r.destType = warehouseutils.RS r.logger = logger.NOP - r.tenantManager = multitenant.New(config.Default, mocksBackendConfig.NewMockBackendConfig(ctrl)) - r.bcManager = bcm.New(r.conf, r.db, r.tenantManager, r.logger, stats.Default) + r.tenantManager = multitenant.New(config.New(), mocksBackendConfig.NewMockBackendConfig(ctrl)) + r.bcManager = bcm.New(r.conf, r.db, r.tenantManager, r.logger, memstats.New()) r.warehouses = []model.Warehouse{warehouse} r.uploadJobFactory = UploadJobFactory{ reporting: &reporting.NOOP{}, - conf: config.Default, + conf: config.New(), logger: logger.NOP, statsFactory: r.statsFactory, db: r.db, @@ -751,6 +762,8 @@ func TestRouter(t *testing.T) { }) r.createJobMarkerMap = make(map[string]time.Time) r.triggerStore = &sync.Map{} + r.createUploadAlways = &atomic.Bool{} + r.scheduledTimesCache = make(map[string][]int) close(r.bcManager.InitialConfigFetched) @@ -863,7 +876,7 @@ func TestRouter(t *testing.T) { r.uploadRepo = repoUpload r.stagingRepo = repoStaging r.statsFactory = memstats.New() - r.conf = config.Default + r.conf = config.New() r.config.allowMultipleSourcesForJobsPickup = true r.config.stagingFilesBatchSize = misc.SingleValueLoader(100) r.config.warehouseSyncFreqIgnore = misc.SingleValueLoader(true) @@ -872,12 +885,12 @@ func TestRouter(t *testing.T) { r.config.uploadAllocatorSleep = time.Millisecond * 100 r.destType = warehouseutils.RS r.logger = logger.NOP - r.tenantManager = multitenant.New(config.Default, mocksBackendConfig.NewMockBackendConfig(ctrl)) - r.bcManager = bcm.New(r.conf, r.db, r.tenantManager, r.logger, stats.Default) + r.tenantManager = multitenant.New(config.New(), mocksBackendConfig.NewMockBackendConfig(ctrl)) + r.bcManager = bcm.New(r.conf, r.db, r.tenantManager, r.logger, memstats.New()) r.warehouses = []model.Warehouse{warehouse} r.uploadJobFactory = UploadJobFactory{ reporting: &reporting.NOOP{}, - conf: config.Default, + conf: config.New(), logger: logger.NOP, statsFactory: r.statsFactory, db: r.db, @@ -900,6 +913,8 @@ func TestRouter(t *testing.T) { }) r.createJobMarkerMap = make(map[string]time.Time) r.triggerStore = &sync.Map{} + r.createUploadAlways = &atomic.Bool{} + r.scheduledTimesCache = make(map[string][]int) close(r.bcManager.InitialConfigFetched) @@ -952,7 +967,7 @@ func TestRouter(t *testing.T) { r.statsFactory = statsStore r.uploadRepo = repoUpload r.stagingRepo = repoStaging - r.conf = config.Default + r.conf = config.New() r.config.stagingFilesBatchSize = misc.SingleValueLoader(100) r.config.warehouseSyncFreqIgnore = misc.SingleValueLoader(true) @@ -972,6 +987,8 @@ func TestRouter(t *testing.T) { }) r.createJobMarkerMap = make(map[string]time.Time) r.triggerStore = &sync.Map{} + r.createUploadAlways = &atomic.Bool{} + r.scheduledTimesCache = make(map[string][]int) closeCh := make(chan struct{}) @@ -1178,14 +1195,16 @@ func TestRouter(t *testing.T) { r.uploadRepo = repoUpload r.stagingRepo = repoStaging r.statsFactory = memstats.New() - r.conf = config.Default + 
r.conf = config.New() r.logger = logger.NOP r.destType = warehouseutils.RS r.config.maxConcurrentUploadJobs = 1 - r.tenantManager = multitenant.New(config.Default, mockBackendConfig) - r.bcManager = bcm.New(r.conf, r.db, r.tenantManager, r.logger, stats.Default) + r.tenantManager = multitenant.New(config.New(), mockBackendConfig) + r.bcManager = bcm.New(r.conf, r.db, r.tenantManager, r.logger, memstats.New()) r.createJobMarkerMap = make(map[string]time.Time) r.triggerStore = &sync.Map{} + r.createUploadAlways = &atomic.Bool{} + r.scheduledTimesCache = make(map[string][]int) go func() { r.bcManager.Start(ctx) diff --git a/warehouse/router/scheduling.go b/warehouse/router/scheduling.go index 2eaec845496..761306d2be9 100644 --- a/warehouse/router/scheduling.go +++ b/warehouse/router/scheduling.go @@ -4,33 +4,23 @@ import ( "context" "fmt" "strconv" - "sync" - "sync/atomic" "time" "github.com/samber/lo" "github.com/rudderlabs/rudder-server/utils/timeutil" "github.com/rudderlabs/rudder-server/warehouse/internal/model" - warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" + whutils "github.com/rudderlabs/rudder-server/warehouse/utils" ) -// TODO: Move this to router struct instead of exposing it as globals. -var ( - scheduledTimesCache map[string][]int - scheduledTimesCacheLock sync.RWMutex - - StartUploadAlways atomic.Bool -) - -func init() { - scheduledTimesCache = map[string][]int{} +type createUploadAlwaysLoader interface { + Load() bool } // canCreateUpload indicates if an upload can be started now for the warehouse based on its configured schedule func (r *Router) canCreateUpload(ctx context.Context, warehouse model.Warehouse) (bool, error) { // can be set from rudder-cli to force uploads always - if StartUploadAlways.Load() { + if r.createUploadAlways.Load() { return true, nil } @@ -47,15 +37,15 @@ func (r *Router) canCreateUpload(ctx context.Context, warehouse model.Warehouse) } // gets exclude window start time and end time - excludeWindow := warehouseutils.GetConfigValueAsMap(warehouseutils.ExcludeWindow, warehouse.Destination.Config) + excludeWindow := whutils.GetConfigValueAsMap(whutils.ExcludeWindow, warehouse.Destination.Config) excludeWindowStartTime, excludeWindowEndTime := excludeWindowStartEndTimes(excludeWindow) if checkCurrentTimeExistsInExcludeWindow(r.now().UTC(), excludeWindowStartTime, excludeWindowEndTime) { return false, fmt.Errorf("exclude window: current time exists in exclude window") } - syncFrequency := warehouseutils.GetConfigValue(warehouseutils.SyncFrequency, warehouse) - syncStartAt := warehouseutils.GetConfigValue(warehouseutils.SyncStartAt, warehouse) + syncFrequency := whutils.GetConfigValue(whutils.SyncFrequency, warehouse) + syncStartAt := whutils.GetConfigValue(whutils.SyncStartAt, warehouse) if syncFrequency == "" || syncStartAt == "" { if r.uploadFrequencyExceeded(warehouse, syncFrequency) { return true, nil @@ -63,7 +53,7 @@ func (r *Router) canCreateUpload(ctx context.Context, warehouse model.Warehouse) return false, fmt.Errorf("upload frequency exceeded") } - prevScheduledTime := prevScheduledTime(syncFrequency, syncStartAt, r.now()) + prevScheduledTime := r.prevScheduledTime(syncFrequency, syncStartAt, r.now()) lastUploadCreatedAt, err := r.uploadRepo.LastCreatedAt(ctx, warehouse.Source.ID, warehouse.Destination.ID) if err != nil { return false, err @@ -80,14 +70,12 @@ func (r *Router) canCreateUpload(ctx context.Context, warehouse model.Warehouse) func excludeWindowStartEndTimes(excludeWindow map[string]interface{}) (string, 
string) { var startTime, endTime string - if st, ok := excludeWindow[warehouseutils.ExcludeWindowStartTime].(string); ok { + if st, ok := excludeWindow[whutils.ExcludeWindowStartTime].(string); ok { startTime = st } - - if et, ok := excludeWindow[warehouseutils.ExcludeWindowEndTime].(string); ok { + if et, ok := excludeWindow[whutils.ExcludeWindowEndTime].(string); ok { endTime = et } - return startTime, endTime } @@ -96,34 +84,31 @@ func checkCurrentTimeExistsInExcludeWindow(currentTime time.Time, windowStartTim return false } - startTimeMins := timeutil.MinsOfDay(windowStartTime) - endTimeMins := timeutil.MinsOfDay(windowEndTime) + startTimeInMin := timeutil.MinsOfDay(windowStartTime) + endTimeInMin := timeutil.MinsOfDay(windowEndTime) currentTimeMins := timeutil.GetElapsedMinsInThisDay(currentTime) // startTime, currentTime, endTime: 05:09, 06:19, 09:07 - > window between this day 05:09 and 09:07 - if startTimeMins < currentTimeMins && currentTimeMins < endTimeMins { + if startTimeInMin < currentTimeMins && currentTimeMins < endTimeInMin { return true } - // startTime, currentTime, endTime: 22:09, 06:19, 09:07 -> window between this day 22:09 and tomorrow 09:07 - if startTimeMins > currentTimeMins && currentTimeMins < endTimeMins && startTimeMins > endTimeMins { + if startTimeInMin > currentTimeMins && currentTimeMins < endTimeInMin && startTimeInMin > endTimeInMin { return true } - // startTime, currentTime, endTime: 22:09, 23:19, 09:07 -> window between this day 22:09 and tomorrow 09:07 - if startTimeMins < currentTimeMins && currentTimeMins > endTimeMins && startTimeMins > endTimeMins { + if startTimeInMin < currentTimeMins && currentTimeMins > endTimeInMin && startTimeInMin > endTimeInMin { return true } - return false } // prevScheduledTime returns the closest previous scheduled time // e.g. Syncing every 3hrs starting at 13:00 (scheduled times: 13:00, 16:00, 19:00, 22:00, 01:00, 04:00, 07:00, 10:00) // prev scheduled time for current time (e.g. 18:00 -> 16:00 same day, 00:30 -> 22:00 prev day) -func prevScheduledTime(syncFrequency, syncStartAt string, currTime time.Time) time.Time { - allStartTimes := scheduledTimes(syncFrequency, syncStartAt) +func (r *Router) prevScheduledTime(syncFrequency, syncStartAt string, currTime time.Time) time.Time { + allStartTimes := r.scheduledTimes(syncFrequency, syncStartAt) loc, _ := time.LoadLocation("UTC") now := currTime.In(loc) @@ -154,10 +139,10 @@ func prevScheduledTime(syncFrequency, syncStartAt string, currTime time.Time) ti // scheduledTimes returns all possible start times (minutes from start of day) as per schedule // e.g. Syncing every 3hrs starting at 13:00 (scheduled times: 13:00, 16:00, 19:00, 22:00, 01:00, 04:00, 07:00, 10:00) -func scheduledTimes(syncFrequency, syncStartAt string) []int { - scheduledTimesCacheLock.RLock() - cachedTimes, ok := scheduledTimesCache[fmt.Sprintf(`%s-%s`, syncFrequency, syncStartAt)] - scheduledTimesCacheLock.RUnlock() +func (r *Router) scheduledTimes(syncFrequency, syncStartAt string) []int { + r.scheduledTimesCacheLock.RLock() + cachedTimes, ok := r.scheduledTimesCache[fmt.Sprintf(`%s-%s`, syncFrequency, syncStartAt)] + r.scheduledTimesCacheLock.RUnlock() if ok { return cachedTimes @@ -191,9 +176,9 @@ func scheduledTimes(syncFrequency, syncStartAt string) []int { times = append(lo.Reverse(prependTimes), times...) 
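// Illustrative sketch (not part of this diff): what scheduledTimes produces for the
// example in the comment above, i.e. syncing every 3hrs starting at 13:00. This helper
// is a simplified equivalent that returns the same set of minutes-of-day; the real
// method builds the list from syncStartAt forward, prepends the earlier slots, and
// caches the result per schedule.
package main

import (
	"fmt"
	"sort"
	"time"
)

// scheduledMinutes lists every sync start time as minutes from midnight (UTC),
// given a sync frequency in minutes and a "HH:MM" start time.
func scheduledMinutes(freqMin int, startAt string) ([]int, error) {
	if freqMin <= 0 {
		return nil, fmt.Errorf("frequency must be positive, got %d", freqMin)
	}
	t, err := time.Parse("15:04", startAt)
	if err != nil {
		return nil, err
	}
	start := t.Hour()*60 + t.Minute()
	var mins []int
	// Walk the whole day in freqMin steps, keeping every slot congruent to the start time.
	for m := start % freqMin; m < 24*60; m += freqMin {
		mins = append(mins, m)
	}
	sort.Ints(mins)
	return mins, nil
}

func main() {
	mins, err := scheduledMinutes(180, "13:00")
	if err != nil {
		panic(err)
	}
	for _, m := range mins {
		fmt.Printf("%02d:%02d ", m/60, m%60) // 01:00 04:00 07:00 10:00 13:00 16:00 19:00 22:00
	}
	fmt.Println()
}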
- scheduledTimesCacheLock.Lock() - scheduledTimesCache[fmt.Sprintf(`%s-%s`, syncFrequency, syncStartAt)] = times - scheduledTimesCacheLock.Unlock() + r.scheduledTimesCacheLock.Lock() + r.scheduledTimesCache[fmt.Sprintf(`%s-%s`, syncFrequency, syncStartAt)] = times + r.scheduledTimesCacheLock.Unlock() return times } diff --git a/warehouse/router/scheduling_test.go b/warehouse/router/scheduling_test.go index 5e6e7ec3a09..e2c6668ce09 100644 --- a/warehouse/router/scheduling_test.go +++ b/warehouse/router/scheduling_test.go @@ -6,6 +6,7 @@ import ( "fmt" "strconv" "sync" + "sync/atomic" "testing" "time" @@ -91,7 +92,10 @@ func TestRouter_CanCreateUpload(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - require.Equal(t, tc.expectedPrevScheduledTime, prevScheduledTime(tc.syncFrequency, tc.syncStartAt, tc.currTime)) + r := Router{} + r.createUploadAlways = &atomic.Bool{} + r.scheduledTimesCache = make(map[string][]int) + require.Equal(t, tc.expectedPrevScheduledTime, r.prevScheduledTime(tc.syncFrequency, tc.syncStartAt, tc.currTime)) }) } }) @@ -194,6 +198,8 @@ func TestRouter_CanCreateUpload(t *testing.T) { r := Router{} r.triggerStore = &sync.Map{} r.triggerStore.Store(w.Identifier, struct{}{}) + r.createUploadAlways = &atomic.Bool{} + r.scheduledTimesCache = make(map[string][]int) canCreate, err := r.canCreateUpload(context.Background(), w) require.NoError(t, err) @@ -210,6 +216,8 @@ func TestRouter_CanCreateUpload(t *testing.T) { r.config.uploadFreqInS = misc.SingleValueLoader(int64(1800)) r.config.warehouseSyncFreqIgnore = misc.SingleValueLoader(true) r.triggerStore = &sync.Map{} + r.createUploadAlways = &atomic.Bool{} + r.scheduledTimesCache = make(map[string][]int) canCreate, err := r.canCreateUpload(context.Background(), w) require.NoError(t, err) @@ -231,6 +239,8 @@ func TestRouter_CanCreateUpload(t *testing.T) { r.config.warehouseSyncFreqIgnore = misc.SingleValueLoader(true) r.createJobMarkerMap = make(map[string]time.Time) r.triggerStore = &sync.Map{} + r.createUploadAlways = &atomic.Bool{} + r.scheduledTimesCache = make(map[string][]int) r.updateCreateJobMarker(w, now.Add(-time.Hour)) @@ -254,6 +264,8 @@ func TestRouter_CanCreateUpload(t *testing.T) { r.config.warehouseSyncFreqIgnore = misc.SingleValueLoader(true) r.createJobMarkerMap = make(map[string]time.Time) r.triggerStore = &sync.Map{} + r.createUploadAlways = &atomic.Bool{} + r.scheduledTimesCache = make(map[string][]int) r.updateCreateJobMarker(w, now) @@ -278,6 +290,8 @@ func TestRouter_CanCreateUpload(t *testing.T) { r := Router{} r.triggerStore = &sync.Map{} + r.createUploadAlways = &atomic.Bool{} + r.scheduledTimesCache = make(map[string][]int) r.config.warehouseSyncFreqIgnore = misc.SingleValueLoader(false) r.now = func() time.Time { return time.Date(2009, time.November, 10, 5, 30, 0, 0, time.UTC) @@ -306,6 +320,8 @@ func TestRouter_CanCreateUpload(t *testing.T) { r.config.uploadFreqInS = misc.SingleValueLoader(int64(1800)) r.createJobMarkerMap = make(map[string]time.Time) r.triggerStore = &sync.Map{} + r.createUploadAlways = &atomic.Bool{} + r.scheduledTimesCache = make(map[string][]int) r.updateCreateJobMarker(w, now) @@ -331,6 +347,8 @@ func TestRouter_CanCreateUpload(t *testing.T) { r.config.warehouseSyncFreqIgnore = misc.SingleValueLoader(false) r.config.uploadFreqInS = misc.SingleValueLoader(int64(1800)) r.triggerStore = &sync.Map{} + r.createUploadAlways = &atomic.Bool{} + r.scheduledTimesCache = make(map[string][]int) r.createJobMarkerMap = make(map[string]time.Time) 
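// Illustrative sketch (not part of this diff): the read-mostly caching pattern that
// scheduledTimes now applies through Router fields (scheduledTimesCache plus
// scheduledTimesCacheLock) instead of package-level globals. The timesCache type and
// the key literal are stand-ins; the real cache key is fmt.Sprintf("%s-%s", syncFrequency, syncStartAt).
package main

import (
	"fmt"
	"sync"
)

type timesCache struct {
	mu    sync.RWMutex
	cache map[string][]int
}

// get takes the read lock on the fast path and only takes the write lock to store a
// freshly computed value; recomputing on a rare race is harmless because the
// computation is deterministic.
func (c *timesCache) get(key string, compute func() []int) []int {
	c.mu.RLock()
	times, ok := c.cache[key]
	c.mu.RUnlock()
	if ok {
		return times
	}

	times = compute()

	c.mu.Lock()
	c.cache[key] = times
	c.mu.Unlock()
	return times
}

func main() {
	c := &timesCache{cache: make(map[string][]int)}
	fmt.Println(c.get("180-13:00", func() []int { return []int{60, 240, 420} }))
	fmt.Println(c.get("180-13:00", func() []int { panic("never called: served from cache") }))
}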
r.updateCreateJobMarker(w, now.Add(-time.Hour)) @@ -421,6 +439,8 @@ func TestRouter_CanCreateUpload(t *testing.T) { r.triggerStore = &sync.Map{} r.config.warehouseSyncFreqIgnore = misc.SingleValueLoader(false) r.createJobMarkerMap = make(map[string]time.Time) + r.createUploadAlways = &atomic.Bool{} + r.scheduledTimesCache = make(map[string][]int) r.uploadRepo = repoUpload r.now = func() time.Time { return tc.now diff --git a/warehouse/router/tracker_test.go b/warehouse/router/tracker_test.go index 473cb1aaf2b..c14af1fde85 100644 --- a/warehouse/router/tracker_test.go +++ b/warehouse/router/tracker_test.go @@ -270,7 +270,7 @@ func TestRouter_CronTracker(t *testing.T) { statsFactory: memstats.New(), db: sqlquerywrapper.New(pgResource.DB), logger: logger.NOP, - conf: config.Default, + conf: config.New(), } r.warehouses = append(r.warehouses, warehouse) diff --git a/warehouse/router/upload.go b/warehouse/router/upload.go index 9638e01acd8..223af89cb3b 100644 --- a/warehouse/router/upload.go +++ b/warehouse/router/upload.go @@ -6,14 +6,13 @@ import ( "encoding/json" "errors" "fmt" + "slices" "strconv" "strings" "sync" "sync/atomic" "time" - "golang.org/x/exp/slices" - "github.com/cenkalti/backoff/v4" "github.com/samber/lo" @@ -709,6 +708,7 @@ func (job *UploadJob) exportRegularTables(specialTables []string, loadFilesTable // * the source is not an ETL source // * the source is not a replay source // * the source category is not in "mergeSourceCategoryMap" +// * the job is not a retry func (job *UploadJob) CanAppend() bool { if isSourceETL := job.upload.SourceJobRunID != ""; isSourceETL { return false @@ -719,6 +719,9 @@ func (job *UploadJob) CanAppend() bool { if _, isMergeCategory := mergeSourceCategoryMap[job.warehouse.Source.SourceDefinition.Category]; isMergeCategory { return false } + if job.upload.Retried { + return false + } return true } diff --git a/warehouse/router/upload_stats.go b/warehouse/router/upload_stats.go index 2b2ef87051d..7134f79bd77 100644 --- a/warehouse/router/upload_stats.go +++ b/warehouse/router/upload_stats.go @@ -2,11 +2,10 @@ package router import ( "fmt" + "slices" "strings" "time" - "golang.org/x/exp/slices" - "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/stats" "github.com/rudderlabs/rudder-server/utils/misc" diff --git a/warehouse/router/upload_stats_test.go b/warehouse/router/upload_stats_test.go index e81125d12f4..5bebaba8c40 100644 --- a/warehouse/router/upload_stats_test.go +++ b/warehouse/router/upload_stats_test.go @@ -7,7 +7,8 @@ import ( "testing" "time" - "github.com/rudderlabs/rudder-go-kit/stats" + "github.com/rudderlabs/rudder-go-kit/stats/memstats" + backendconfig "github.com/rudderlabs/rudder-server/backend-config" "github.com/rudderlabs/rudder-server/warehouse/internal/repo" @@ -41,7 +42,7 @@ func TestUploadJob_Stats(t *testing.T) { mockMeasurement.EXPECT().Count(1).Times(1) ujf := &UploadJobFactory{ - conf: config.Default, + conf: config.New(), logger: logger.NOP, statsFactory: mockStats, db: sqlmiddleware.New(db), @@ -71,7 +72,7 @@ func TestUploadJob_Stats(t *testing.T) { mockMeasurement.EXPECT().Count(4).Times(2) ujf := &UploadJobFactory{ - conf: config.Default, + conf: config.New(), logger: logger.NOP, statsFactory: mockStats, db: sqlmiddleware.New(db), @@ -102,7 +103,7 @@ func TestUploadJob_Stats(t *testing.T) { mockMeasurement.EXPECT().Since(gomock.Any()).Times(1) ujf := &UploadJobFactory{ - conf: config.Default, + conf: config.New(), logger: logger.NOP, statsFactory: mockStats, db: 
sqlmiddleware.New(db), @@ -131,7 +132,7 @@ func TestUploadJob_Stats(t *testing.T) { mockMeasurement.EXPECT().SendTiming(gomock.Any()).Times(1) ujf := &UploadJobFactory{ - conf: config.Default, + conf: config.New(), logger: logger.NOP, statsFactory: mockStats, db: sqlmiddleware.New(db), @@ -165,9 +166,9 @@ func TestUploadJob_MatchRows(t *testing.T) { t.Run("Total rows in load files", func(t *testing.T) { ujf := &UploadJobFactory{ - conf: config.Default, + conf: config.New(), logger: logger.NOP, - statsFactory: stats.Default, + statsFactory: memstats.New(), db: sqlmiddleware.New(db), } job := ujf.NewUploadJob(context.Background(), &model.UploadJob{ @@ -205,9 +206,9 @@ func TestUploadJob_MatchRows(t *testing.T) { t.Run("Total rows in staging files", func(t *testing.T) { ujf := &UploadJobFactory{ - conf: config.Default, + conf: config.New(), logger: logger.NOP, - statsFactory: stats.Default, + statsFactory: memstats.New(), db: sqlmiddleware.New(db), } job := ujf.NewUploadJob(context.Background(), &model.UploadJob{ @@ -246,9 +247,9 @@ func TestUploadJob_MatchRows(t *testing.T) { t.Run("Get uploads timings", func(t *testing.T) { ujf := &UploadJobFactory{ - conf: config.Default, + conf: config.New(), logger: logger.NOP, - statsFactory: stats.Default, + statsFactory: memstats.New(), db: sqlmiddleware.New(db), } job := ujf.NewUploadJob(context.Background(), &model.UploadJob{ @@ -333,7 +334,7 @@ func TestUploadJob_MatchRows(t *testing.T) { mockMeasurement.EXPECT().Gauge(gomock.Any()).Times(tc.statsCount) ujf := &UploadJobFactory{ - conf: config.Default, + conf: config.New(), logger: logger.NOP, statsFactory: mockStats, db: sqlmiddleware.New(db), diff --git a/warehouse/router/upload_test.go b/warehouse/router/upload_test.go index 2556063a4c4..9655b058c57 100644 --- a/warehouse/router/upload_test.go +++ b/warehouse/router/upload_test.go @@ -13,7 +13,6 @@ import ( "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/logger" - "github.com/rudderlabs/rudder-go-kit/stats" "github.com/rudderlabs/rudder-go-kit/stats/memstats" "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource" backendconfig "github.com/rudderlabs/rudder-server/backend-config" @@ -241,14 +240,14 @@ func TestUploadJobT_UpdateTableSchema(t *testing.T) { t.Log("db:", pgResource.DBDsn) - rs := redshift.New(config.Default, logger.NOP, stats.Default) + rs := redshift.New(config.New(), logger.NOP, memstats.New()) rs.DB = sqlmiddleware.New(pgResource.DB) rs.Namespace = testNamespace ujf := &UploadJobFactory{ - conf: config.Default, + conf: config.New(), logger: logger.NOP, - statsFactory: stats.Default, + statsFactory: memstats.New(), db: sqlmiddleware.New(pgResource.DB), } @@ -317,14 +316,14 @@ func TestUploadJobT_UpdateTableSchema(t *testing.T) { t.Log("db:", pgResource.DBDsn) - rs := redshift.New(config.Default, logger.NOP, stats.Default) + rs := redshift.New(config.New(), logger.NOP, memstats.New()) rs.DB = sqlmiddleware.New(pgResource.DB) rs.Namespace = testNamespace ujf := &UploadJobFactory{ - conf: config.Default, + conf: config.New(), logger: logger.NOP, - statsFactory: stats.Default, + statsFactory: memstats.New(), db: sqlmiddleware.New(pgResource.DB), } diff --git a/warehouse/schema/schema.go b/warehouse/schema/schema.go index ae5e93996b3..b8644214581 100644 --- a/warehouse/schema/schema.go +++ b/warehouse/schema/schema.go @@ -5,10 +5,10 @@ import ( "fmt" "reflect" "regexp" + "slices" "sync" "github.com/samber/lo" - "golang.org/x/exp/slices" "github.com/rudderlabs/rudder-go-kit/config" 
"github.com/rudderlabs/rudder-go-kit/logger" diff --git a/warehouse/slave/slave_test.go b/warehouse/slave/slave_test.go index 25012262913..948c480f76b 100644 --- a/warehouse/slave/slave_test.go +++ b/warehouse/slave/slave_test.go @@ -8,6 +8,8 @@ import ( "os" "testing" + "github.com/rudderlabs/rudder-go-kit/stats/memstats" + "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource" "github.com/rudderlabs/rudder-server/warehouse/bcm" @@ -29,7 +31,6 @@ import ( "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/filemanager" "github.com/rudderlabs/rudder-go-kit/logger" - "github.com/rudderlabs/rudder-go-kit/stats" "github.com/rudderlabs/rudder-server/utils/misc" "github.com/rudderlabs/rudder-server/warehouse/internal/model" ) @@ -95,18 +96,18 @@ func TestSlave(t *testing.T) { workerJobs := 25 tenantManager := multitenant.New( - config.Default, + config.New(), backendconfig.DefaultBackendConfig, ) slave := New( - config.Default, + config.New(), logger.NOP, - stats.Default, + memstats.New(), slaveNotifier, - bcm.New(config.Default, nil, tenantManager, logger.NOP, stats.Default), - constraints.New(config.Default), - encoding.NewFactory(config.Default), + bcm.New(config.New(), nil, tenantManager, logger.NOP, memstats.New()), + constraints.New(config.New()), + encoding.NewFactory(config.New()), ) slave.config.noOfSlaveWorkerRoutines = workers diff --git a/warehouse/slave/worker_job.go b/warehouse/slave/worker_job.go index 5e7b1dfeb07..21bc1814fc9 100644 --- a/warehouse/slave/worker_job.go +++ b/warehouse/slave/worker_job.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "path/filepath" + "slices" "sort" "strconv" "strings" @@ -19,7 +20,6 @@ import ( "github.com/rudderlabs/rudder-server/utils/timeutil" "go.uber.org/atomic" - "golang.org/x/exp/slices" "golang.org/x/sync/errgroup" "github.com/rudderlabs/rudder-go-kit/config" diff --git a/warehouse/slave/worker_job_test.go b/warehouse/slave/worker_job_test.go index cf7eed69eff..69dbe117f48 100644 --- a/warehouse/slave/worker_job_test.go +++ b/warehouse/slave/worker_job_test.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "os" + "slices" "strings" "testing" "time" @@ -19,7 +20,6 @@ import ( "github.com/google/uuid" "github.com/ory/dockertest/v3" "github.com/stretchr/testify/require" - "golang.org/x/exp/slices" "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/logger" @@ -198,7 +198,7 @@ func TestSlaveJob(t *testing.T) { StagingFileLocation: uf.ObjectName, } - jr := newJobRun(p, config.Default, logger.NOP, stats.Default, encoding.NewFactory(config.Default)) + jr := newJobRun(p, config.New(), logger.NOP, memstats.New(), encoding.NewFactory(config.New())) defer jr.cleanup() @@ -226,7 +226,7 @@ func TestSlaveJob(t *testing.T) { statsStore := memstats.New() - jr := newJobRun(p, config.Default, logger.NOP, statsStore, encoding.NewFactory(config.Default)) + jr := newJobRun(p, config.New(), logger.NOP, statsStore, encoding.NewFactory(config.New())) defer jr.cleanup() @@ -263,7 +263,7 @@ func TestSlaveJob(t *testing.T) { statsStore := memstats.New() - jr := newJobRun(p, config.Default, logger.NOP, statsStore, encoding.NewFactory(config.Default)) + jr := newJobRun(p, config.New(), logger.NOP, statsStore, encoding.NewFactory(config.New())) defer jr.cleanup() @@ -296,7 +296,7 @@ func TestSlaveJob(t *testing.T) { StagingDestinationRevisionID: uuid.New().String(), } - jr := newJobRun(p, config.Default, logger.NOP, stats.Default, encoding.NewFactory(config.Default)) + jr := newJobRun(p, config.New(), logger.NOP, 
memstats.New(), encoding.NewFactory(config.New())) defer jr.cleanup() @@ -323,7 +323,7 @@ func TestSlaveJob(t *testing.T) { DestinationType: destType, } - jr := newJobRun(p, config.Default, logger.NOP, stats.Default, encoding.NewFactory(config.Default)) + jr := newJobRun(p, config.New(), logger.NOP, memstats.New(), encoding.NewFactory(config.New())) defer jr.cleanup() @@ -364,7 +364,7 @@ func TestSlaveJob(t *testing.T) { now := time.Date(2020, 4, 27, 20, 0, 0, 0, time.UTC) - jr := newJobRun(p, config.Default, logger.NOP, stats.Default, encoding.NewFactory(config.Default)) + jr := newJobRun(p, config.New(), logger.NOP, memstats.New(), encoding.NewFactory(config.New())) jr.uuidTS = now jr.now = func() time.Time { return now @@ -515,7 +515,7 @@ func TestSlaveJob(t *testing.T) { c.Set("Warehouse.slaveUploadTimeout", "5m") c.Set("WAREHOUSE_BUCKET_LOAD_OBJECTS_FOLDER_NAME", loadObjectFolder) - jr := newJobRun(job, c, logger.NOP, store, encoding.NewFactory(config.Default)) + jr := newJobRun(job, c, logger.NOP, store, encoding.NewFactory(config.New())) jr.since = func(t time.Time) time.Duration { return time.Second } diff --git a/warehouse/slave/worker_test.go b/warehouse/slave/worker_test.go index 71bd09fa275..279e22ac8f6 100644 --- a/warehouse/slave/worker_test.go +++ b/warehouse/slave/worker_test.go @@ -9,6 +9,8 @@ import ( "os" "testing" + "github.com/rudderlabs/rudder-go-kit/stats/memstats" + "github.com/rudderlabs/rudder-server/warehouse/bcm" "github.com/rudderlabs/rudder-server/warehouse/constraints" @@ -25,7 +27,6 @@ import ( "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/filemanager" "github.com/rudderlabs/rudder-go-kit/logger" - "github.com/rudderlabs/rudder-go-kit/stats" "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource" backendconfig "github.com/rudderlabs/rudder-server/backend-config" mocksBackendConfig "github.com/rudderlabs/rudder-server/mocks/backend-config" @@ -77,7 +78,7 @@ func TestSlaveWorker(t *testing.T) { jobLocation := uploadFile(t, ctx, destConf, "testdata/staging.json.gz") schemaMap := stagingSchema(t) - ef := encoding.NewFactory(config.Default) + ef := encoding.NewFactory(config.New()) t.Run("success", func(t *testing.T) { subscribeCh := make(chan *notifier.ClaimJobResponse) @@ -87,15 +88,15 @@ func TestSlaveWorker(t *testing.T) { subscribeCh: subscribeCh, } - tenantManager := multitenant.New(config.Default, backendconfig.DefaultBackendConfig) + tenantManager := multitenant.New(config.New(), backendconfig.DefaultBackendConfig) slaveWorker := newWorker( - config.Default, + config.New(), logger.NOP, - stats.Default, + memstats.New(), slaveNotifier, - bcm.New(config.Default, nil, tenantManager, logger.NOP, stats.Default), - constraints.New(config.Default), + bcm.New(config.New(), nil, tenantManager, logger.NOP, memstats.New()), + constraints.New(config.New()), ef, workerIdx, ) @@ -189,15 +190,15 @@ func TestSlaveWorker(t *testing.T) { subscribeCh: subscribeCh, } - tenantManager := multitenant.New(config.Default, backendconfig.DefaultBackendConfig) + tenantManager := multitenant.New(config.New(), backendconfig.DefaultBackendConfig) slaveWorker := newWorker( - config.Default, + config.New(), logger.NOP, - stats.Default, + memstats.New(), slaveNotifier, - bcm.New(config.Default, nil, tenantManager, logger.NOP, stats.Default), - constraints.New(config.Default), + bcm.New(config.New(), nil, tenantManager, logger.NOP, memstats.New()), + constraints.New(config.New()), ef, workerIdx, ) @@ -318,15 +319,15 @@ func TestSlaveWorker(t 
*testing.T) { c := config.New() c.Set("Warehouse.s3_datalake.columnCountLimit", 10) - tenantManager := multitenant.New(config.Default, backendconfig.DefaultBackendConfig) + tenantManager := multitenant.New(config.New(), backendconfig.DefaultBackendConfig) slaveWorker := newWorker( c, logger.NOP, - stats.Default, + memstats.New(), slaveNotifier, - bcm.New(config.Default, nil, tenantManager, logger.NOP, stats.Default), - constraints.New(config.Default), + bcm.New(config.New(), nil, tenantManager, logger.NOP, memstats.New()), + constraints.New(config.New()), ef, workerIdx, ) @@ -387,15 +388,15 @@ func TestSlaveWorker(t *testing.T) { subscribeCh: subscribeCh, } - tenantManager := multitenant.New(config.Default, backendconfig.DefaultBackendConfig) + tenantManager := multitenant.New(config.New(), backendconfig.DefaultBackendConfig) slaveWorker := newWorker( - config.Default, + config.New(), logger.NOP, - stats.Default, + memstats.New(), slaveNotifier, - bcm.New(config.Default, nil, tenantManager, logger.NOP, stats.Default), - constraints.New(config.Default), + bcm.New(config.New(), nil, tenantManager, logger.NOP, memstats.New()), + constraints.New(config.New()), ef, workerIdx, ) @@ -540,9 +541,9 @@ func TestSlaveWorker(t *testing.T) { return ch }).AnyTimes() - tenantManager := multitenant.New(config.Default, mockBackendConfig) - bcm := bcm.New(config.Default, nil, tenantManager, logger.NOP, stats.Default) - ef := encoding.NewFactory(config.Default) + tenantManager := multitenant.New(config.New(), mockBackendConfig) + bcm := bcm.New(config.New(), nil, tenantManager, logger.NOP, memstats.New()) + ef := encoding.NewFactory(config.New()) setupCh := make(chan struct{}) go func() { @@ -568,10 +569,10 @@ func TestSlaveWorker(t *testing.T) { slaveWorker := newWorker( c, logger.NOP, - stats.Default, + memstats.New(), slaveNotifier, bcm, - constraints.New(config.Default), + constraints.New(config.New()), ef, workerIdx, ) @@ -634,10 +635,10 @@ func TestSlaveWorker(t *testing.T) { slaveWorker := newWorker( c, logger.NOP, - stats.Default, + memstats.New(), slaveNotifier, bcm, - constraints.New(config.Default), + constraints.New(config.New()), ef, workerIdx, ) diff --git a/warehouse/utils/querytype.go b/warehouse/utils/querytype.go index a439f3fee94..9fb994ef14d 100644 --- a/warehouse/utils/querytype.go +++ b/warehouse/utils/querytype.go @@ -28,6 +28,7 @@ func init() { "(?P<DROP_TABLE>(?:IF.*)*DROP.*TABLE)", "(?P<SHOW_TABLES>SHOW.*TABLES)", "(?P<SHOW_PARTITIONS>SHOW.*PARTITIONS)", + "(?P<SHOW_SCHEMAS>SHOW.*SCHEMAS)", "(?P<DESCRIBE_TABLE>DESCRIBE.*(?:QUERY.*)*TABLE)", "(?P<SET_TO>SET.*TO)", } diff --git a/warehouse/utils/querytype_test.go b/warehouse/utils/querytype_test.go index 55ce137bb04..96f24702341 100644 --- a/warehouse/utils/querytype_test.go +++ b/warehouse/utils/querytype_test.go @@ -38,6 +38,7 @@ func TestGetQueryType(t *testing.T) { {"drop table 2", "\t\n\n \t\n\n IF OBJECT_ID ('foo.qux','X') IS NOT NULL DROP TABLE foo.bar", "DROP_TABLE", true}, {"show tables", "\t\n\n \t\n\n sHoW tAbLes FROM some_table", "SHOW_TABLES", true}, {"show partitions", "\t\n\n \t\n\n sHoW pArtItiOns billing.tracks_t1", "SHOW_PARTITIONS", true}, + {"show schemas", "\t\n\n \t\n\n sHoW schemaS like foobar", "SHOW_SCHEMAS", true}, {"describe table 1", "\t\n\n \t\n\n dEscrIbe tABLE t1", "DESCRIBE_TABLE", true}, {"describe table 2", "\t\n\n \t\n\n dEscrIbe qUeRy tABLE t1", "DESCRIBE_TABLE", true}, {"set", "\t\n\n \t\n\n sEt something TO something_else", "SET_TO", true}, diff --git a/warehouse/utils/utils.go b/warehouse/utils/utils.go index 4897ba0c42a..de961279023 100644 --- a/warehouse/utils/utils.go +++ 
b/warehouse/utils/utils.go @@ -35,7 +35,6 @@ import ( backendconfig "github.com/rudderlabs/rudder-server/backend-config" "github.com/rudderlabs/rudder-server/utils/awsutils" "github.com/rudderlabs/rudder-server/utils/misc" - "github.com/rudderlabs/rudder-server/warehouse/tunnelling" ) const ( @@ -935,16 +934,6 @@ func RandHex() string { return string(buf[:]) } -func ExtractTunnelInfoFromDestinationConfig(config map[string]interface{}) *tunnelling.TunnelInfo { - if tunnelEnabled := ReadAsBool("useSSH", config); !tunnelEnabled { - return nil - } - - return &tunnelling.TunnelInfo{ - Config: config, - } -} - func ReadAsBool(key string, config map[string]interface{}) bool { if _, ok := config[key]; ok { if val, ok := config[key].(bool); ok { diff --git a/warehouse/utils/utils_test.go b/warehouse/utils/utils_test.go index bcf81dcecb3..06b5ad6be1e 100644 --- a/warehouse/utils/utils_test.go +++ b/warehouse/utils/utils_test.go @@ -979,7 +979,7 @@ func TestWarehouseT_GetBoolDestinationConfig(t *testing.T) { warehouse: model.Warehouse{ Destination: backendconfig.DestinationT{ Config: map[string]interface{}{ - "k1": "true", + "useRudderStorage": "true", }, }, }, @@ -989,7 +989,7 @@ func TestWarehouseT_GetBoolDestinationConfig(t *testing.T) { warehouse: model.Warehouse{ Destination: backendconfig.DestinationT{ Config: map[string]interface{}{ - "k1": false, + "useRudderStorage": false, }, }, }, @@ -999,7 +999,7 @@ func TestWarehouseT_GetBoolDestinationConfig(t *testing.T) { warehouse: model.Warehouse{ Destination: backendconfig.DestinationT{ Config: map[string]interface{}{ - "k1": true, + "useRudderStorage": true, }, }, }, @@ -1007,7 +1007,7 @@ func TestWarehouseT_GetBoolDestinationConfig(t *testing.T) { }, } for idx, input := range inputs { - got := input.warehouse.GetBoolDestinationConfig("k1") + got := input.warehouse.GetBoolDestinationConfig(model.UseRudderStorageSetting) want := input.expected if got != want { t.Errorf("got %t expected %t input %d", got, want, idx)
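// Illustrative sketch (not part of this diff): the test change above replaces the
// magic "k1"/"useRudderStorage" strings with the typed model.UseRudderStorageSetting
// key. The types below are simplified stand-ins for the warehouse model types,
// assuming the usual semantics that only an actual boolean true counts.
package main

import "fmt"

// DestinationSetting is a hypothetical stand-in for the typed setting keys in
// warehouse/internal/model; model.UseRudderStorageSetting plays this role in the PR.
type DestinationSetting string

const UseRudderStorageSetting DestinationSetting = "useRudderStorage"

type Warehouse struct {
	DestinationConfig map[string]interface{}
}

// GetBoolDestinationConfig returns the setting only when it is present and a real
// bool; missing keys and string values such as "true" yield false.
func (w Warehouse) GetBoolDestinationConfig(key DestinationSetting) bool {
	v, ok := w.DestinationConfig[string(key)]
	if !ok {
		return false
	}
	b, ok := v.(bool)
	return ok && b
}

func main() {
	w := Warehouse{DestinationConfig: map[string]interface{}{"useRudderStorage": true}}
	fmt.Println(w.GetBoolDestinationConfig(UseRudderStorageSetting)) // true

	w = Warehouse{DestinationConfig: map[string]interface{}{"useRudderStorage": "true"}}
	fmt.Println(w.GetBoolDestinationConfig(UseRudderStorageSetting)) // false
}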