From 33513c9f6e4d765d3f41ce729fed6c7eaa9667aa Mon Sep 17 00:00:00 2001 From: Akash Chetty Date: Wed, 1 Nov 2023 14:19:08 +0530 Subject: [PATCH 1/5] feat: introducing warehouse repo withTx (#4042) --- warehouse/internal/repo/load.go | 86 +++++++++++-------------- warehouse/internal/repo/repo.go | 21 ++++++ warehouse/internal/repo/table_upload.go | 77 +++++++++------------- warehouse/internal/repo/upload.go | 3 +- 4 files changed, 89 insertions(+), 98 deletions(-) diff --git a/warehouse/internal/repo/load.go b/warehouse/internal/repo/load.go index 71e2fb0da54..f10a7e118e3 100644 --- a/warehouse/internal/repo/load.go +++ b/warehouse/internal/repo/load.go @@ -2,7 +2,6 @@ package repo import ( "context" - "database/sql" jsonstd "encoding/json" "fmt" @@ -44,14 +43,14 @@ func NewLoadFiles(db *sqlmiddleware.DB, opts ...Opt) *LoadFiles { } // DeleteByStagingFiles deletes load files associated with stagingFileIDs. -func (repo *LoadFiles) DeleteByStagingFiles(ctx context.Context, stagingFileIDs []int64) error { +func (lf *LoadFiles) DeleteByStagingFiles(ctx context.Context, stagingFileIDs []int64) error { sqlStatement := ` DELETE FROM ` + loadTableName + ` WHERE staging_file_id = ANY($1);` - _, err := repo.db.ExecContext(ctx, sqlStatement, pq.Array(stagingFileIDs)) + _, err := lf.db.ExecContext(ctx, sqlStatement, pq.Array(stagingFileIDs)) if err != nil { return fmt.Errorf(`deleting load files: %w`, err) } @@ -60,58 +59,47 @@ func (repo *LoadFiles) DeleteByStagingFiles(ctx context.Context, stagingFileIDs } // Insert loadFiles into the database. 
-func (repo *LoadFiles) Insert(ctx context.Context, loadFiles []model.LoadFile) (err error) { - // Using transactions for bulk copying - txn, err := repo.db.BeginTx(ctx, &sql.TxOptions{}) - if err != nil { - return - } - - stmt, err := txn.PrepareContext( - ctx, - pq.CopyIn( - "wh_load_files", - "staging_file_id", - "location", - "source_id", - "destination_id", - "destination_type", - "table_name", - "total_events", - "created_at", - "metadata", - ), - ) - if err != nil { - return fmt.Errorf(`inserting load files: CopyIn: %w`, err) - } - defer func() { _ = stmt.Close() }() - - for _, loadFile := range loadFiles { - metadata := fmt.Sprintf(`{"content_length": %d, "destination_revision_id": %q, "use_rudder_storage": %t}`, loadFile.ContentLength, loadFile.DestinationRevisionID, loadFile.UseRudderStorage) - _, err = stmt.ExecContext(ctx, loadFile.StagingFileID, loadFile.Location, loadFile.SourceID, loadFile.DestinationID, loadFile.DestinationType, loadFile.TableName, loadFile.TotalRows, timeutil.Now(), metadata) +func (lf *LoadFiles) Insert(ctx context.Context, loadFiles []model.LoadFile) error { + return (*repo)(lf).WithTx(ctx, func(tx *sqlmiddleware.Tx) error { + stmt, err := tx.PrepareContext( + ctx, + pq.CopyIn( + "wh_load_files", + "staging_file_id", + "location", + "source_id", + "destination_id", + "destination_type", + "table_name", + "total_events", + "created_at", + "metadata", + ), + ) if err != nil { - _ = txn.Rollback() - return fmt.Errorf(`inserting load files: CopyIn exec: %w`, err) + return fmt.Errorf(`inserting load files: CopyIn: %w`, err) } - } - - _, err = stmt.ExecContext(ctx) - if err != nil { - _ = txn.Rollback() - return fmt.Errorf(`inserting load files: CopyIn final exec: %w`, err) - } - err = txn.Commit() - if err != nil { - return fmt.Errorf(`inserting load files: commit: %w`, err) - } - return + defer func() { _ = stmt.Close() }() + + for _, loadFile := range loadFiles { + metadata := fmt.Sprintf(`{"content_length": %d, 
"destination_revision_id": %q, "use_rudder_storage": %t}`, loadFile.ContentLength, loadFile.DestinationRevisionID, loadFile.UseRudderStorage) + _, err = stmt.ExecContext(ctx, loadFile.StagingFileID, loadFile.Location, loadFile.SourceID, loadFile.DestinationID, loadFile.DestinationType, loadFile.TableName, loadFile.TotalRows, timeutil.Now(), metadata) + if err != nil { + return fmt.Errorf(`inserting load files: CopyIn exec: %w`, err) + } + } + _, err = stmt.ExecContext(ctx) + if err != nil { + return fmt.Errorf(`inserting load files: CopyIn final exec: %w`, err) + } + return nil + }) } // GetByStagingFiles returns all load files matching the staging file ids. // // Ordered by id ascending. -func (repo *LoadFiles) GetByStagingFiles(ctx context.Context, stagingFileIDs []int64) ([]model.LoadFile, error) { +func (lf *LoadFiles) GetByStagingFiles(ctx context.Context, stagingFileIDs []int64) ([]model.LoadFile, error) { sqlStatement := ` WITH row_numbered_load_files as ( SELECT @@ -136,7 +124,7 @@ func (repo *LoadFiles) GetByStagingFiles(ctx context.Context, stagingFileIDs []i ORDER BY id ASC ` - rows, err := repo.db.QueryContext(ctx, sqlStatement, pq.Array(stagingFileIDs)) + rows, err := lf.db.QueryContext(ctx, sqlStatement, pq.Array(stagingFileIDs)) if err != nil { return nil, fmt.Errorf("query staging ids: %w", err) } diff --git a/warehouse/internal/repo/repo.go b/warehouse/internal/repo/repo.go index caf027d2d31..dd7ca9f67c1 100644 --- a/warehouse/internal/repo/repo.go +++ b/warehouse/internal/repo/repo.go @@ -1,6 +1,9 @@ package repo import ( + "context" + "database/sql" + "fmt" "time" sqlmiddleware "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" @@ -18,3 +21,21 @@ func WithNow(now func() time.Time) Opt { r.now = now } } + +func (r *repo) WithTx(ctx context.Context, f func(tx *sqlmiddleware.Tx) error) error { + tx, err := r.db.BeginTx(ctx, &sql.TxOptions{}) + if err != nil { + return fmt.Errorf("begin transaction: %w", err) + } 
+ + if err := f(tx); err != nil { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + return fmt.Errorf("rollback transaction for %w: %w", err, rollbackErr) + } + return err + } + if err := tx.Commit(); err != nil { + return fmt.Errorf("committing transaction: %w", err) + } + return nil +} diff --git a/warehouse/internal/repo/table_upload.go b/warehouse/internal/repo/table_upload.go index f6f5cca7739..df28c57f349 100644 --- a/warehouse/internal/repo/table_upload.go +++ b/warehouse/internal/repo/table_upload.go @@ -58,23 +58,9 @@ func NewTableUploads(db *sqlmiddleware.DB, opts ...Opt) *TableUploads { return r } -func (repo *TableUploads) Insert(ctx context.Context, uploadID int64, tableNames []string) error { - var ( - txn *sqlmiddleware.Tx - stmt *sql.Stmt - err error - ) - - if txn, err = repo.db.BeginTx(ctx, &sql.TxOptions{}); err != nil { - return fmt.Errorf(`begin transaction: %w`, err) - } - defer func() { - if err != nil { - _ = txn.Rollback() - } - }() - - stmt, err = txn.PrepareContext(ctx, ` +func (tu *TableUploads) Insert(ctx context.Context, uploadID int64, tableNames []string) error { + return (*repo)(tu).WithTx(ctx, func(tx *sqlmiddleware.Tx) error { + stmt, err := tx.PrepareContext(ctx, ` INSERT INTO `+tableUploadTableName+` ( wh_upload_id, table_name, status, error, created_at, updated_at @@ -85,38 +71,35 @@ func (repo *TableUploads) Insert(ctx context.Context, uploadID int64, tableNames ON CONSTRAINT `+tableUploadUniqueConstraintName+` DO NOTHING; `) - if err != nil { - return fmt.Errorf(`prepared statement: %w`, err) - } - defer func() { _ = stmt.Close() }() - - for _, tableName := range tableNames { - _, err = stmt.ExecContext(ctx, uploadID, tableName, model.TableUploadWaiting, "{}", repo.now(), repo.now()) if err != nil { - return fmt.Errorf(`stmt exec: %w`, err) + return fmt.Errorf(`prepared statement: %w`, err) } - } - if err = txn.Commit(); err != nil { - return fmt.Errorf(`commit: %w`, err) - } + defer func() { _ = stmt.Close() }() - 
return nil + for _, tableName := range tableNames { + _, err = stmt.ExecContext(ctx, uploadID, tableName, model.TableUploadWaiting, "{}", tu.now(), tu.now()) + if err != nil { + return fmt.Errorf(`stmt exec: %w`, err) + } + } + return nil + }) } -func (repo *TableUploads) GetByUploadID(ctx context.Context, uploadID int64) ([]model.TableUpload, error) { +func (tu *TableUploads) GetByUploadID(ctx context.Context, uploadID int64) ([]model.TableUpload, error) { query := `SELECT ` + tableUploadColumns + ` FROM ` + tableUploadTableName + ` WHERE wh_upload_id = $1;` - rows, err := repo.db.QueryContext(ctx, query, uploadID) + rows, err := tu.db.QueryContext(ctx, query, uploadID) if err != nil { return nil, fmt.Errorf("querying table uploads: %w", err) } - return repo.parseRows(rows) + return tu.parseRows(rows) } -func (repo *TableUploads) GetByUploadIDAndTableName(ctx context.Context, uploadID int64, tableName string) (model.TableUpload, error) { +func (tu *TableUploads) GetByUploadIDAndTableName(ctx context.Context, uploadID int64, tableName string) (model.TableUpload, error) { query := `SELECT ` + tableUploadColumns + ` FROM ` + tableUploadTableName + ` WHERE wh_upload_id = $1 AND @@ -124,12 +107,12 @@ func (repo *TableUploads) GetByUploadIDAndTableName(ctx context.Context, uploadI LIMIT 1; ` - rows, err := repo.db.QueryContext(ctx, query, uploadID, tableName) + rows, err := tu.db.QueryContext(ctx, query, uploadID, tableName) if err != nil { return model.TableUpload{}, fmt.Errorf("querying table uploads: %w", err) } - entries, err := repo.parseRows(rows) + entries, err := tu.parseRows(rows) if err != nil { return model.TableUpload{}, fmt.Errorf("parsing rows: %w", err) } @@ -191,7 +174,7 @@ func (*TableUploads) parseRows(rows *sqlmiddleware.Rows) ([]model.TableUpload, e return tableUploads, nil } -func (repo *TableUploads) PopulateTotalEventsFromStagingFileIDs(ctx context.Context, uploadId int64, tableName string, stagingFileIDs []int64) error { +func (tu *TableUploads) 
PopulateTotalEventsFromStagingFileIDs(ctx context.Context, uploadId int64, tableName string, stagingFileIDs []int64) error { subQuery := ` WITH row_numbered_load_files as ( SELECT @@ -231,7 +214,7 @@ func (repo *TableUploads) PopulateTotalEventsFromStagingFileIDs(ctx context.Cont tableName, pq.Array(stagingFileIDs), } - result, err := repo.db.ExecContext( + result, err := tu.db.ExecContext( ctx, query, queryArgs..., @@ -251,7 +234,7 @@ func (repo *TableUploads) PopulateTotalEventsFromStagingFileIDs(ctx context.Cont return nil } -func (repo *TableUploads) TotalExportedEvents(ctx context.Context, uploadId int64, skipTables []string) (int64, error) { +func (tu *TableUploads) TotalExportedEvents(ctx context.Context, uploadId int64, skipTables []string) (int64, error) { var ( count sql.NullInt64 err error @@ -261,7 +244,7 @@ func (repo *TableUploads) TotalExportedEvents(ctx context.Context, uploadId int6 skipTables = []string{} } - err = repo.db.QueryRowContext(ctx, ` + err = tu.db.QueryRowContext(ctx, ` SELECT COALESCE(sum(total_events), 0) AS total FROM @@ -285,7 +268,7 @@ func (repo *TableUploads) TotalExportedEvents(ctx context.Context, uploadId int6 return 0, errors.New(`count is not valid`) } -func (repo *TableUploads) Set(ctx context.Context, uploadId int64, tableName string, options TableUploadSetOptions) error { +func (tu *TableUploads) Set(ctx context.Context, uploadId int64, tableName string, options TableUploadSetOptions) error { var ( query string queryArgs []any @@ -325,7 +308,7 @@ func (repo *TableUploads) Set(ctx context.Context, uploadId int64, tableName str } setQuery.WriteString(fmt.Sprintf(`updated_at = $%d,`, len(queryArgs)+1)) - queryArgs = append(queryArgs, repo.now()) + queryArgs = append(queryArgs, tu.now()) // remove trailing comma setQueryString := strings.TrimSuffix(setQuery.String(), ",") @@ -339,7 +322,7 @@ func (repo *TableUploads) Set(ctx context.Context, uploadId int64, tableName str wh_upload_id = $1 AND table_name = $2; ` - result, err 
:= repo.db.ExecContext( + result, err := tu.db.ExecContext( ctx, query, queryArgs..., @@ -359,12 +342,12 @@ func (repo *TableUploads) Set(ctx context.Context, uploadId int64, tableName str return nil } -func (repo *TableUploads) ExistsForUploadID(ctx context.Context, uploadId int64) (bool, error) { +func (tu *TableUploads) ExistsForUploadID(ctx context.Context, uploadId int64) (bool, error) { var ( count int64 err error ) - err = repo.db.QueryRowContext(ctx, + err = tu.db.QueryRowContext(ctx, ` SELECT COUNT(*) @@ -381,8 +364,8 @@ func (repo *TableUploads) ExistsForUploadID(ctx context.Context, uploadId int64) return count > 0, nil } -func (repo *TableUploads) SyncsInfo(ctx context.Context, uploadID int64) ([]model.TableUploadInfo, error) { - tableUploads, err := repo.GetByUploadID(ctx, uploadID) +func (tu *TableUploads) SyncsInfo(ctx context.Context, uploadID int64) ([]model.TableUploadInfo, error) { + tableUploads, err := tu.GetByUploadID(ctx, uploadID) if err != nil { return nil, fmt.Errorf("table uploads for upload id: %w", err) } diff --git a/warehouse/internal/repo/upload.go b/warehouse/internal/repo/upload.go index b5637eaf9eb..f0171d00a94 100644 --- a/warehouse/internal/repo/upload.go +++ b/warehouse/internal/repo/upload.go @@ -30,8 +30,7 @@ var syncStatusMap = map[string]string{ const ( uploadsTableName = warehouseutils.WarehouseUploadsTable - - uploadColumns = ` + uploadColumns = ` id, status, schema, From b80d27394b42116f3b0f7e86c71d1ed126230ae4 Mon Sep 17 00:00:00 2001 From: Akash Chetty Date: Wed, 1 Nov 2023 14:34:56 +0530 Subject: [PATCH 2/5] fix: don't send error in stats (#4055) --- warehouse/router/router.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/warehouse/router/router.go b/warehouse/router/router.go index 1e209538c48..efdbe00ed2f 100644 --- a/warehouse/router/router.go +++ b/warehouse/router/router.go @@ -11,6 +11,8 @@ import ( "sync/atomic" "time" + "github.com/rudderlabs/rudder-server/warehouse/logfield" + 
"github.com/rudderlabs/rudder-server/warehouse/bcm" "github.com/lib/pq" @@ -594,10 +596,9 @@ func (r *Router) createJobs(ctx context.Context, warehouse model.Warehouse) (err "workspaceId": warehouse.WorkspaceID, "destinationID": warehouse.Destination.ID, "destType": warehouse.Destination.DestinationDefinition.Name, - "reason": err.Error(), }).Count(1) - r.logger.Debugf("[WH]: Skipping upload loop since %s upload freq not exceeded: %v", warehouse.Identifier, err) + r.logger.Debugw("Skipping upload loop since upload freq not exceeded", logfield.Error, err.Error()) return nil } From a9199aeeee599725d561cbb5e215fa56348aa156 Mon Sep 17 00:00:00 2001 From: Akash Chetty Date: Wed, 1 Nov 2023 23:16:58 +0530 Subject: [PATCH 3/5] refactor: sources async job (#4008) --- warehouse/api/http.go | 12 +- warehouse/api/http_test.go | 13 +- warehouse/app.go | 12 +- .../integrations/bigquery/bigquery_test.go | 10 +- warehouse/integrations/mssql/mssql_test.go | 10 +- .../integrations/postgres/postgres_test.go | 10 +- .../integrations/redshift/redshift_test.go | 10 +- .../integrations/snowflake/snowflake_test.go | 10 +- warehouse/integrations/testhelper/setup.go | 14 +- warehouse/integrations/testhelper/verify.go | 26 +- warehouse/internal/errors/errors.go | 2 - warehouse/internal/model/source.go | 83 +++ warehouse/internal/model/source_test.go | 114 ++++ warehouse/internal/model/upload.go | 5 +- warehouse/internal/repo/source.go | 303 +++++++++++ warehouse/internal/repo/source_test.go | 485 +++++++++++++++++ warehouse/internal/repo/table_upload.go | 150 ++++-- warehouse/internal/repo/table_upload_test.go | 76 +++ warehouse/internal/repo/upload_test.go | 1 - warehouse/jobs/http.go | 157 ------ warehouse/jobs/http_test.go | 357 ------------- warehouse/jobs/jobs.go | 64 --- warehouse/jobs/runner.go | 425 --------------- warehouse/jobs/types.go | 89 ---- warehouse/jobs/utils.go | 41 -- warehouse/slave/worker.go | 58 +- warehouse/slave/worker_test.go | 47 +- warehouse/source/http.go | 130 
+++++ warehouse/source/http_test.go | 364 +++++++++++++ warehouse/source/source.go | 306 +++++++++++ warehouse/source/source_test.go | 504 ++++++++++++++++++ warehouse/source/types.go | 67 +++ 32 files changed, 2640 insertions(+), 1315 deletions(-) create mode 100644 warehouse/internal/model/source.go create mode 100644 warehouse/internal/model/source_test.go create mode 100644 warehouse/internal/repo/source.go create mode 100644 warehouse/internal/repo/source_test.go delete mode 100644 warehouse/jobs/http.go delete mode 100644 warehouse/jobs/http_test.go delete mode 100644 warehouse/jobs/jobs.go delete mode 100644 warehouse/jobs/runner.go delete mode 100644 warehouse/jobs/types.go delete mode 100644 warehouse/jobs/utils.go create mode 100644 warehouse/source/http.go create mode 100644 warehouse/source/http_test.go create mode 100644 warehouse/source/source.go create mode 100644 warehouse/source/source_test.go create mode 100644 warehouse/source/types.go diff --git a/warehouse/api/http.go b/warehouse/api/http.go index 5df66deea44..56977e9e567 100644 --- a/warehouse/api/http.go +++ b/warehouse/api/http.go @@ -34,8 +34,8 @@ import ( sqlmw "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" "github.com/rudderlabs/rudder-server/warehouse/internal/model" "github.com/rudderlabs/rudder-server/warehouse/internal/repo" - "github.com/rudderlabs/rudder-server/warehouse/jobs" "github.com/rudderlabs/rudder-server/warehouse/multitenant" + "github.com/rudderlabs/rudder-server/warehouse/source" warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" ) @@ -75,7 +75,7 @@ type Api struct { bcConfig backendconfig.BackendConfig tenantManager *multitenant.Manager bcManager *bcm.BackendConfigManager - asyncManager *jobs.AsyncJobWh + sourceManager *source.Manager stagingRepo *repo.StagingFiles uploadRepo *repo.Uploads schemaRepo *repo.WHSchema @@ -100,7 +100,7 @@ func NewApi( notifier *notifier.Notifier, tenantManager *multitenant.Manager, 
bcManager *bcm.BackendConfigManager, - asyncManager *jobs.AsyncJobWh, + sourceManager *source.Manager, triggerStore *sync.Map, ) *Api { a := &Api{ @@ -112,7 +112,7 @@ func NewApi( statsFactory: statsFactory, tenantManager: tenantManager, bcManager: bcManager, - asyncManager: asyncManager, + sourceManager: sourceManager, triggerStore: triggerStore, stagingRepo: repo.NewStagingFiles(db), uploadRepo: repo.NewUploads(db), @@ -170,8 +170,8 @@ func (a *Api) addMasterEndpoints(ctx context.Context, r chi.Router) { r.Post("/pending-events", a.logMiddleware(a.pendingEventsHandler)) r.Post("/trigger-upload", a.logMiddleware(a.triggerUploadHandler)) - r.Post("/jobs", a.logMiddleware(a.asyncManager.InsertJobHandler)) // TODO: add degraded mode - r.Get("/jobs/status", a.logMiddleware(a.asyncManager.StatusJobHandler)) // TODO: add degraded mode + r.Post("/jobs", a.logMiddleware(a.sourceManager.InsertJobHandler)) // TODO: add degraded mode + r.Get("/jobs/status", a.logMiddleware(a.sourceManager.StatusJobHandler)) // TODO: add degraded mode r.Get("/fetch-tables", a.logMiddleware(a.fetchTablesHandler)) // TODO: Remove this endpoint once sources change is released }) diff --git a/warehouse/api/http_test.go b/warehouse/api/http_test.go index 75f83535586..64425df08e7 100644 --- a/warehouse/api/http_test.go +++ b/warehouse/api/http_test.go @@ -27,8 +27,8 @@ import ( kithelper "github.com/rudderlabs/rudder-go-kit/testhelper" sqlmiddleware "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" - "github.com/rudderlabs/rudder-server/warehouse/jobs" "github.com/rudderlabs/rudder-server/warehouse/multitenant" + "github.com/rudderlabs/rudder-server/warehouse/source" "github.com/golang/mock/gomock" "github.com/ory/dockertest/v3" @@ -188,12 +188,12 @@ func TestHTTPApi(t *testing.T) { err = n.Setup(ctx, pgResource.DBDsn) require.NoError(t, err) - sourcesManager := jobs.New( - ctx, + sourcesManager := source.New( + c, + logger.NOP, db, n, ) - 
jobs.WithConfig(sourcesManager, config.New()) g, gCtx := errgroup.WithContext(ctx) g.Go(func() error { @@ -205,7 +205,7 @@ func TestHTTPApi(t *testing.T) { return nil }) g.Go(func() error { - return sourcesManager.Run() + return sourcesManager.Run(gCtx) }) setupCh := make(chan struct{}) @@ -906,7 +906,8 @@ func TestHTTPApi(t *testing.T) { "source_id": "test_source_id", "destination_id": "test_destination_id", "job_run_id": "test_source_job_run_id", - "task_run_id": "test_source_task_run_id" + "task_run_id": "test_source_task_run_id", + "async_job_type": "deletebyjobrunid" } `))) require.NoError(t, err) diff --git a/warehouse/app.go b/warehouse/app.go index 05c8a4b3d37..e332c83ca20 100644 --- a/warehouse/app.go +++ b/warehouse/app.go @@ -45,8 +45,8 @@ import ( "github.com/rudderlabs/rudder-server/utils/types" whadmin "github.com/rudderlabs/rudder-server/warehouse/admin" "github.com/rudderlabs/rudder-server/warehouse/archive" - "github.com/rudderlabs/rudder-server/warehouse/jobs" "github.com/rudderlabs/rudder-server/warehouse/multitenant" + "github.com/rudderlabs/rudder-server/warehouse/source" warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" ) @@ -67,7 +67,7 @@ type App struct { constraintsManager *constraints.Manager encodingFactory *encoding.Factory fileManagerFactory filemanager.Factory - sourcesManager *jobs.AsyncJobWh + sourcesManager *source.Manager admin *whadmin.Admin triggerStore *sync.Map createUploadAlways *atomic.Bool @@ -174,12 +174,12 @@ func (a *App) Setup(ctx context.Context) error { return fmt.Errorf("cannot setup notifier: %w", err) } - a.sourcesManager = jobs.New( - ctx, + a.sourcesManager = source.New( + a.conf, + a.logger, a.db, a.notifier, ) - jobs.WithConfig(a.sourcesManager, a.conf) a.grpcServer, err = api.NewGRPCServer( a.conf, @@ -413,7 +413,7 @@ func (a *App) Run(ctx context.Context) error { return nil }) g.Go(misc.WithBugsnagForWarehouse(func() error { - return a.sourcesManager.Run() + return a.sourcesManager.Run(gCtx) 
})) } diff --git a/warehouse/integrations/bigquery/bigquery_test.go b/warehouse/integrations/bigquery/bigquery_test.go index 9508076d63d..43595b63e8f 100644 --- a/warehouse/integrations/bigquery/bigquery_test.go +++ b/warehouse/integrations/bigquery/bigquery_test.go @@ -148,7 +148,7 @@ func TestIntegration(t *testing.T) { loadFilesEventsMap whth.EventsCountMap tableUploadsEventsMap whth.EventsCountMap warehouseEventsMap whth.EventsCountMap - asyncJob bool + sourceJob bool skipModifiedEvents bool prerequisite func(context.Context, testing.TB, *bigquery.Client) enableMerge bool @@ -177,7 +177,7 @@ func TestIntegration(t *testing.T) { stagingFilePrefix: "testdata/upload-job-merge-mode", }, { - name: "Async Job", + name: "Source Job", writeKey: sourcesWriteKey, sourceID: sourcesSourceID, destinationID: sourcesDestinationID, @@ -192,7 +192,7 @@ func TestIntegration(t *testing.T) { loadFilesEventsMap: whth.SourcesLoadFilesEventsMap(), tableUploadsEventsMap: whth.SourcesTableUploadsEventsMap(), warehouseEventsMap: whth.SourcesWarehouseEventsMap(), - asyncJob: true, + sourceJob: true, enableMerge: false, prerequisite: func(ctx context.Context, t testing.TB, db *bigquery.Client) { t.Helper() @@ -347,7 +347,7 @@ func TestIntegration(t *testing.T) { LoadFilesEventsMap: tc.loadFilesEventsMap, TableUploadsEventsMap: tc.tableUploadsEventsMap, WarehouseEventsMap: tc.warehouseEventsMap, - AsyncJob: tc.asyncJob, + SourceJob: tc.sourceJob, Config: conf, WorkspaceID: workspaceID, DestinationType: destType, @@ -359,7 +359,7 @@ func TestIntegration(t *testing.T) { StagingFilePath: tc.stagingFilePrefix + ".staging-2.json", UserID: whth.GetUserId(destType), } - if tc.asyncJob { + if tc.sourceJob { ts2.UserID = ts1.UserID } ts2.VerifyEvents(t) diff --git a/warehouse/integrations/mssql/mssql_test.go b/warehouse/integrations/mssql/mssql_test.go index 73702680ce4..37ebac8946d 100644 --- a/warehouse/integrations/mssql/mssql_test.go +++ b/warehouse/integrations/mssql/mssql_test.go @@ -161,7 
+161,7 @@ func TestIntegration(t *testing.T) { loadFilesEventsMap testhelper.EventsCountMap tableUploadsEventsMap testhelper.EventsCountMap warehouseEventsMap testhelper.EventsCountMap - asyncJob bool + sourceJob bool stagingFilePrefix string }{ { @@ -174,7 +174,7 @@ func TestIntegration(t *testing.T) { stagingFilePrefix: "testdata/upload-job", }, { - name: "Async Job", + name: "Source Job", writeKey: sourcesWriteKey, schema: sourcesNamespace, tables: []string{"tracks", "google_sheet"}, @@ -184,7 +184,7 @@ func TestIntegration(t *testing.T) { loadFilesEventsMap: testhelper.SourcesLoadFilesEventsMap(), tableUploadsEventsMap: testhelper.SourcesTableUploadsEventsMap(), warehouseEventsMap: testhelper.SourcesWarehouseEventsMap(), - asyncJob: true, + sourceJob: true, stagingFilePrefix: "testdata/sources-job", }, } @@ -245,7 +245,7 @@ func TestIntegration(t *testing.T) { LoadFilesEventsMap: tc.loadFilesEventsMap, TableUploadsEventsMap: tc.tableUploadsEventsMap, WarehouseEventsMap: tc.warehouseEventsMap, - AsyncJob: tc.asyncJob, + SourceJob: tc.sourceJob, Config: conf, WorkspaceID: workspaceID, DestinationType: destType, @@ -257,7 +257,7 @@ func TestIntegration(t *testing.T) { StagingFilePath: tc.stagingFilePrefix + ".staging-2.json", UserID: testhelper.GetUserId(destType), } - if tc.asyncJob { + if tc.sourceJob { ts2.UserID = ts1.UserID } ts2.VerifyEvents(t) diff --git a/warehouse/integrations/postgres/postgres_test.go b/warehouse/integrations/postgres/postgres_test.go index 6a741bd8f25..c225bba767a 100644 --- a/warehouse/integrations/postgres/postgres_test.go +++ b/warehouse/integrations/postgres/postgres_test.go @@ -189,7 +189,7 @@ func TestIntegration(t *testing.T) { loadFilesEventsMap whth.EventsCountMap tableUploadsEventsMap whth.EventsCountMap warehouseEventsMap whth.EventsCountMap - asyncJob bool + sourceJob bool stagingFilePrefix string }{ { @@ -204,7 +204,7 @@ func TestIntegration(t *testing.T) { stagingFilePrefix: "testdata/upload-job", }, { - name: "Async Job", 
+ name: "Source Job", writeKey: sourcesWriteKey, schema: sourcesNamespace, tables: []string{"tracks", "google_sheet"}, @@ -214,7 +214,7 @@ func TestIntegration(t *testing.T) { loadFilesEventsMap: whth.SourcesLoadFilesEventsMap(), tableUploadsEventsMap: whth.SourcesTableUploadsEventsMap(), warehouseEventsMap: whth.SourcesWarehouseEventsMap(), - asyncJob: true, + sourceJob: true, stagingFilePrefix: "testdata/sources-job", }, } @@ -275,7 +275,7 @@ func TestIntegration(t *testing.T) { LoadFilesEventsMap: tc.loadFilesEventsMap, TableUploadsEventsMap: tc.tableUploadsEventsMap, WarehouseEventsMap: tc.warehouseEventsMap, - AsyncJob: tc.asyncJob, + SourceJob: tc.sourceJob, Config: conf, WorkspaceID: workspaceID, DestinationType: destType, @@ -287,7 +287,7 @@ func TestIntegration(t *testing.T) { StagingFilePath: tc.stagingFilePrefix + ".staging-2.json", UserID: whth.GetUserId(destType), } - if tc.asyncJob { + if tc.sourceJob { ts2.UserID = ts1.UserID } ts2.VerifyEvents(t) diff --git a/warehouse/integrations/redshift/redshift_test.go b/warehouse/integrations/redshift/redshift_test.go index abdcf72fe21..3de39467bac 100644 --- a/warehouse/integrations/redshift/redshift_test.go +++ b/warehouse/integrations/redshift/redshift_test.go @@ -185,7 +185,7 @@ func TestIntegration(t *testing.T) { loadFilesEventsMap whth.EventsCountMap tableUploadsEventsMap whth.EventsCountMap warehouseEventsMap whth.EventsCountMap - asyncJob bool + sourceJob bool stagingFilePrefix string }{ { @@ -198,7 +198,7 @@ func TestIntegration(t *testing.T) { stagingFilePrefix: "testdata/upload-job", }, { - name: "Async Job", + name: "Source Job", writeKey: sourcesWriteKey, schema: sourcesNamespace, tables: []string{"tracks", "google_sheet"}, @@ -208,7 +208,7 @@ func TestIntegration(t *testing.T) { loadFilesEventsMap: whth.SourcesLoadFilesEventsMap(), tableUploadsEventsMap: whth.SourcesTableUploadsEventsMap(), warehouseEventsMap: whth.SourcesWarehouseEventsMap(), - asyncJob: true, + sourceJob: true, 
stagingFilePrefix: "testdata/sources-job", }, } @@ -280,7 +280,7 @@ func TestIntegration(t *testing.T) { LoadFilesEventsMap: tc.loadFilesEventsMap, TableUploadsEventsMap: tc.tableUploadsEventsMap, WarehouseEventsMap: tc.warehouseEventsMap, - AsyncJob: tc.asyncJob, + SourceJob: tc.sourceJob, Config: conf, WorkspaceID: workspaceID, DestinationType: destType, @@ -292,7 +292,7 @@ func TestIntegration(t *testing.T) { StagingFilePath: tc.stagingFilePrefix + ".staging-1.json", UserID: whth.GetUserId(destType), } - if tc.asyncJob { + if tc.sourceJob { ts2.UserID = ts1.UserID } ts2.VerifyEvents(t) diff --git a/warehouse/integrations/snowflake/snowflake_test.go b/warehouse/integrations/snowflake/snowflake_test.go index 6782aac3e18..d053c9359ed 100644 --- a/warehouse/integrations/snowflake/snowflake_test.go +++ b/warehouse/integrations/snowflake/snowflake_test.go @@ -224,7 +224,7 @@ func TestIntegration(t *testing.T) { warehouseEventsMap2 testhelper.EventsCountMap cred *testCredentials database string - asyncJob bool + sourceJob bool stagingFilePrefix string emptyJobRunID bool enableMerge bool @@ -291,7 +291,7 @@ func TestIntegration(t *testing.T) { enableMerge: true, }, { - name: "Async Job with Sources", + name: "Source Job with Sources", writeKey: sourcesWriteKey, schema: sourcesNamespace, tables: []string{"tracks", "google_sheet"}, @@ -308,7 +308,7 @@ func TestIntegration(t *testing.T) { loadFilesEventsMap: testhelper.SourcesLoadFilesEventsMap(), tableUploadsEventsMap: testhelper.SourcesTableUploadsEventsMap(), warehouseEventsMap: testhelper.SourcesWarehouseEventsMap(), - asyncJob: true, + sourceJob: true, stagingFilePrefix: "testdata/sources-job", enableMerge: true, }, @@ -438,7 +438,7 @@ func TestIntegration(t *testing.T) { LoadFilesEventsMap: tc.loadFilesEventsMap, TableUploadsEventsMap: tc.tableUploadsEventsMap, WarehouseEventsMap: whEventsMap, - AsyncJob: tc.asyncJob, + SourceJob: tc.sourceJob, Config: conf, WorkspaceID: workspaceID, DestinationType: destType, @@ 
-450,7 +450,7 @@ func TestIntegration(t *testing.T) { StagingFilePath: tc.stagingFilePrefix + ".staging-2.json", UserID: userID, } - if tc.asyncJob { + if tc.sourceJob { ts2.UserID = ts1.UserID } ts2.VerifyEvents(t) diff --git a/warehouse/integrations/testhelper/setup.go b/warehouse/integrations/testhelper/setup.go index 91523e0b7f5..247a9e44743 100644 --- a/warehouse/integrations/testhelper/setup.go +++ b/warehouse/integrations/testhelper/setup.go @@ -29,10 +29,10 @@ import ( ) const ( - WaitFor2Minute = 2 * time.Minute - WaitFor10Minute = 10 * time.Minute - DefaultQueryFrequency = 100 * time.Millisecond - AsyncJOBQueryFrequency = 1000 * time.Millisecond + WaitFor2Minute = 2 * time.Minute + WaitFor10Minute = 10 * time.Minute + DefaultQueryFrequency = 100 * time.Millisecond + SourceJobQueryFrequency = 1000 * time.Millisecond ) const ( @@ -64,7 +64,7 @@ type TestConfig struct { TableUploadsEventsMap EventsCountMap WarehouseEventsMap EventsCountMap JobsDB *sql.DB - AsyncJob bool + SourceJob bool SkipWarehouse bool HTTPPort int } @@ -80,8 +80,8 @@ func (w *TestConfig) VerifyEvents(t testing.TB) { verifyEventsInLoadFiles(t, w) verifyEventsInTableUploads(t, w) - if w.AsyncJob { - verifyAsyncJob(t, w) + if w.SourceJob { + verifySourceJob(t, w) } if !w.SkipWarehouse { verifyEventsInWareHouse(t, w) diff --git a/warehouse/integrations/testhelper/verify.go b/warehouse/integrations/testhelper/verify.go index c7ce9d122da..fe1d4c7691c 100644 --- a/warehouse/integrations/testhelper/verify.go +++ b/warehouse/integrations/testhelper/verify.go @@ -13,6 +13,8 @@ import ( "testing" "time" + "github.com/rudderlabs/rudder-server/warehouse/internal/model" + "github.com/stretchr/testify/require" backendconfig "github.com/rudderlabs/rudder-server/backend-config" @@ -236,14 +238,14 @@ func queryCount(cl *whclient.Client, statement string) (int64, error) { return strconv.ParseInt(result.Values[0][0], 10, 64) } -func verifyAsyncJob(t testing.TB, tc *TestConfig) { +func verifySourceJob(t 
testing.TB, tc *TestConfig) { t.Helper() - t.Logf("Creating async job for sourceID: %s, jobRunID: %s, taskRunID: %s, destinationID: %s, workspaceID: %s", + t.Logf("Creating source job for sourceID: %s, jobRunID: %s, taskRunID: %s, destinationID: %s, workspaceID: %s", tc.SourceID, tc.JobRunID, tc.TaskRunID, tc.DestinationID, tc.WorkspaceID, ) - asyncPayload := fmt.Sprintf( + payload := fmt.Sprintf( `{ "source_id":"%s","job_run_id":"%s","task_run_id":"%s","channel":"sources", "async_job_type":"deletebyjobrunid","destination_id":"%s","start_time":"%s","workspace_id":"%s" @@ -257,7 +259,7 @@ func verifyAsyncJob(t testing.TB, tc *TestConfig) { ) url := fmt.Sprintf("http://localhost:%d/v1/warehouse/jobs", tc.HTTPPort) - req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(asyncPayload)) + req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(payload)) require.NoError(t, err) req.Header.Add("Content-Type", "application/json") @@ -275,11 +277,11 @@ func verifyAsyncJob(t testing.TB, tc *TestConfig) { require.NoError(t, err) require.Equal(t, res.Status, "200 OK") - t.Logf("Verify async job status for sourceID: %s, jobRunID: %s, taskRunID: %s, destID: %s, workspaceID: %s", + t.Logf("Verify source job status for sourceID: %s, jobRunID: %s, taskRunID: %s, destID: %s, workspaceID: %s", tc.SourceID, tc.JobRunID, tc.TaskRunID, tc.DestinationID, tc.WorkspaceID, ) - type asyncResponse struct { + type jobResponse struct { Status string `json:"status"` Error string `json:"error"` } @@ -309,14 +311,14 @@ func verifyAsyncJob(t testing.TB, tc *TestConfig) { return false } - var asyncRes asyncResponse - if err = json.NewDecoder(res.Body).Decode(&asyncRes); err != nil { + var jr jobResponse + if err = json.NewDecoder(res.Body).Decode(&jr); err != nil { return false } - return asyncRes.Status == "succeeded" + return jr.Status == model.SourceJobStatusSucceeded.String() } - require.Eventuallyf(t, operation, WaitFor10Minute, AsyncJOBQueryFrequency, - "Failed to get 
async job status for job_run_id: %s, task_run_id: %s, source_id: %s, destination_id: %s: %v", + require.Eventuallyf(t, operation, WaitFor10Minute, SourceJobQueryFrequency, + "Failed to get source job status for job_run_id: %s, task_run_id: %s, source_id: %s, destination_id: %s: %v", tc.JobRunID, tc.TaskRunID, tc.SourceID, @@ -324,7 +326,7 @@ func verifyAsyncJob(t testing.TB, tc *TestConfig) { err, ) - t.Logf("Completed verifying async job") + t.Logf("Completed verifying source job") } func VerifyConfigurationTest(t testing.TB, destination backendconfig.DestinationT) { diff --git a/warehouse/internal/errors/errors.go b/warehouse/internal/errors/errors.go index 8edf2d57e82..0fa4ecec9d3 100644 --- a/warehouse/internal/errors/errors.go +++ b/warehouse/internal/errors/errors.go @@ -9,6 +9,4 @@ var ( ErrNoWarehouseFound = errors.New("no warehouse found") ErrWorkspaceFromSourceNotFound = errors.New("workspace from source not found") ErrMarshallResponse = errors.New("can't marshall response") - ErrInvalidRequest = errors.New("invalid request") - ErrJobsApiNotInitialized = errors.New("warehouse jobs api not initialized") ) diff --git a/warehouse/internal/model/source.go b/warehouse/internal/model/source.go new file mode 100644 index 00000000000..c271eb5224a --- /dev/null +++ b/warehouse/internal/model/source.go @@ -0,0 +1,83 @@ +package model + +import ( + "encoding/json" + "fmt" + "time" +) + +type SourceJobType interface { + String() string + sourceJobTypeProtected() +} + +type sourceJobType string + +func (s sourceJobType) String() string { return string(s) } +func (s sourceJobType) sourceJobTypeProtected() {} + +var SourceJobTypeDeleteByJobRunID SourceJobType = sourceJobType("deletebyjobrunid") + +func FromSourceJobType(jobType string) (SourceJobType, error) { + switch jobType { + case SourceJobTypeDeleteByJobRunID.String(): + return SourceJobTypeDeleteByJobRunID, nil + default: + return nil, fmt.Errorf("invalid job type %s", jobType) + } +} + +type SourceJobStatus 
interface { + String() string + sourceJobStatusProtected() +} + +type sourceJobStatus string + +func (s sourceJobStatus) String() string { return string(s) } +func (s sourceJobStatus) sourceJobStatusProtected() {} + +var ( + SourceJobStatusWaiting SourceJobStatus = sourceJobStatus("waiting") + SourceJobStatusExecuting SourceJobStatus = sourceJobStatus("executing") + SourceJobStatusFailed SourceJobStatus = sourceJobStatus("failed") + SourceJobStatusAborted SourceJobStatus = sourceJobStatus("aborted") + SourceJobStatusSucceeded SourceJobStatus = sourceJobStatus("succeeded") +) + +func FromSourceJobStatus(status string) (SourceJobStatus, error) { + switch status { + case SourceJobStatusWaiting.String(): + return SourceJobStatusWaiting, nil + case SourceJobStatusExecuting.String(): + return SourceJobStatusExecuting, nil + case SourceJobStatusFailed.String(): + return SourceJobStatusFailed, nil + case SourceJobStatusAborted.String(): + return SourceJobStatusAborted, nil + case SourceJobStatusSucceeded.String(): + return SourceJobStatusSucceeded, nil + default: + return nil, fmt.Errorf("invalid job status %s", status) + } +} + +type SourceJob struct { + ID int64 + + SourceID string + DestinationID string + WorkspaceID string + + TableName string + + Status SourceJobStatus + Error error + JobType SourceJobType + + Metadata json.RawMessage + Attempts int64 + + CreatedAt time.Time + UpdatedAt time.Time +} diff --git a/warehouse/internal/model/source_test.go b/warehouse/internal/model/source_test.go new file mode 100644 index 00000000000..49896f9999f --- /dev/null +++ b/warehouse/internal/model/source_test.go @@ -0,0 +1,114 @@ +package model + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFromSourceJobType(t *testing.T) { + testCases := []struct { + name string + jobType string + expected SourceJobType + wantErr error + }{ + { + name: "delete by job run id", + jobType: "deletebyjobrunid", + expected: SourceJobTypeDeleteByJobRunID, + 
wantErr: nil, + }, + { + name: "invalid", + jobType: "invalid", + expected: nil, + wantErr: fmt.Errorf("invalid job type %s", "invalid"), + }, + { + name: "empty", + jobType: "", + expected: nil, + wantErr: fmt.Errorf("invalid job type %s", ""), + }, + { + name: "nil", + jobType: "", + expected: nil, + wantErr: fmt.Errorf("invalid job type %s", ""), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + jobType, err := FromSourceJobType(tc.jobType) + if tc.wantErr != nil { + require.Equal(t, tc.wantErr, err) + require.Nil(t, jobType) + return + } + require.NoError(t, err) + require.Equal(t, tc.expected, jobType) + }) + } +} + +func TestFromSourceJobStatus(t *testing.T) { + testCases := []struct { + name string + status string + expected SourceJobStatus + wantError error + }{ + { + name: "waiting", + status: "waiting", + expected: SourceJobStatusWaiting, + wantError: nil, + }, + { + name: "executing", + status: "executing", + expected: SourceJobStatusExecuting, + wantError: nil, + }, + { + name: "failed", + status: "failed", + expected: SourceJobStatusFailed, + wantError: nil, + }, + { + name: "aborted", + status: "aborted", + expected: SourceJobStatusAborted, + wantError: nil, + }, + { + name: "succeeded", + status: "succeeded", + expected: SourceJobStatusSucceeded, + wantError: nil, + }, + { + name: "invalid", + status: "invalid", + expected: nil, + wantError: fmt.Errorf("invalid job status %s", "invalid"), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + jobStatus, err := FromSourceJobStatus(tc.status) + if tc.wantError != nil { + require.Equal(t, tc.wantError, err) + require.Nil(t, jobStatus) + return + } + require.NoError(t, err) + require.Equal(t, tc.expected, jobStatus) + }) + } +} diff --git a/warehouse/internal/model/upload.go b/warehouse/internal/model/upload.go index 9a77fcb0fd6..a49f6b14d6f 100644 --- a/warehouse/internal/model/upload.go +++ b/warehouse/internal/model/upload.go @@ -64,8 
+64,9 @@ func GetUserFriendlyJobErrorCategory(errorType JobErrorType) string { } var ( - ErrUploadNotFound = errors.New("upload not found") - ErrNoUploadsFound = errors.New("no uploads found") + ErrUploadNotFound = errors.New("upload not found") + ErrSourcesJobNotFound = errors.New("sources job not found") + ErrNoUploadsFound = errors.New("no uploads found") ) type Upload struct { diff --git a/warehouse/internal/repo/source.go b/warehouse/internal/repo/source.go new file mode 100644 index 00000000000..45c4fd07f04 --- /dev/null +++ b/warehouse/internal/repo/source.go @@ -0,0 +1,303 @@ +package repo + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/lib/pq" + + "github.com/rudderlabs/rudder-server/utils/timeutil" + sqlmw "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" + "github.com/rudderlabs/rudder-server/warehouse/internal/model" + whutils "github.com/rudderlabs/rudder-server/warehouse/utils" +) + +const ( + sourceJobTableName = whutils.WarehouseAsyncJobTable + sourceJobColumns = ` + id, + source_id, + destination_id, + status, + created_at, + updated_at, + tablename, + error, + async_job_type, + metadata, + attempt, + workspace_id + ` +) + +type Source repo + +func NewSource(db *sqlmw.DB, opts ...Opt) *Source { + r := &Source{ + db: db, + now: timeutil.Now, + } + for _, opt := range opts { + opt((*repo)(r)) + } + return r +} + +func (s *Source) Insert(ctx context.Context, sourceJobs []model.SourceJob) ([]int64, error) { + var ids []int64 + + err := (*repo)(s).WithTx(ctx, func(tx *sqlmw.Tx) error { + stmt, err := tx.PrepareContext( + ctx, ` + INSERT INTO `+sourceJobTableName+` ( + source_id, destination_id, tablename, + status, created_at, updated_at, async_job_type, + workspace_id, metadata + ) + VALUES + ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id; +`, + ) + if err != nil { + return fmt.Errorf(`preparing statement: %w`, err) + } + defer func() { _ = stmt.Close() }() + + for _, sourceJob := 
range sourceJobs { + var id int64 + err = stmt.QueryRowContext( + ctx, + sourceJob.SourceID, + sourceJob.DestinationID, + sourceJob.TableName, + model.SourceJobStatusWaiting.String(), + s.now(), + s.now(), + sourceJob.JobType.String(), + sourceJob.WorkspaceID, + sourceJob.Metadata, + ).Scan(&id) + if err != nil { + return fmt.Errorf(`executing: %w`, err) + } + + ids = append(ids, id) + } + return nil + }) + if err != nil { + return nil, err + } + return ids, nil +} + +func (s *Source) Reset(ctx context.Context) error { + _, err := s.db.ExecContext(ctx, ` + UPDATE + `+sourceJobTableName+` + SET + status = $1 + WHERE + status = $2 OR status = $3; + `, + model.SourceJobStatusWaiting.String(), + model.SourceJobStatusExecuting.String(), + model.SourceJobStatusFailed.String(), + ) + if err != nil { + return fmt.Errorf("executing: %w", err) + } + return nil +} + +func (s *Source) GetToProcess(ctx context.Context, limit int64) ([]model.SourceJob, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT + `+sourceJobColumns+` + FROM + `+sourceJobTableName+` + WHERE + status = $1 OR status = $2 + LIMIT $3; + `, + model.SourceJobStatusWaiting.String(), + model.SourceJobStatusFailed.String(), + limit, + ) + if err != nil { + return nil, fmt.Errorf("querying: %w", err) + } + defer func() { _ = rows.Close() }() + + sourceJobs, err := scanSourceJobs(rows) + if err != nil { + return nil, fmt.Errorf("scanning source jobs: %w", err) + } + return sourceJobs, nil +} + +func scanSourceJobs(rows *sqlmw.Rows) ([]model.SourceJob, error) { + var sourceJobs []model.SourceJob + for rows.Next() { + var sourceJob model.SourceJob + err := scanSourceJob(rows.Scan, &sourceJob) + if err != nil { + return nil, fmt.Errorf("scanning source job: %w", err) + } + sourceJobs = append(sourceJobs, sourceJob) + } + if err := rows.Err(); err != nil { + return nil, err + } + return sourceJobs, nil +} + +func scanSourceJob(scan scanFn, sourceJob *model.SourceJob) error { + var errorRaw sql.NullString + var 
jobType, status string + + if err := scan( + &sourceJob.ID, + &sourceJob.SourceID, + &sourceJob.DestinationID, + &status, + &sourceJob.CreatedAt, + &sourceJob.UpdatedAt, + &sourceJob.TableName, + &errorRaw, + &jobType, + &sourceJob.Metadata, + &sourceJob.Attempts, + &sourceJob.WorkspaceID, + ); err != nil { + return fmt.Errorf("scanning row: %w", err) + } + + sourceJobStatus, err := model.FromSourceJobStatus(status) + if err != nil { + return fmt.Errorf("getting sourceJobStatus %w", err) + } + sourceJobType, err := model.FromSourceJobType(jobType) + if err != nil { + return fmt.Errorf("getting sourceJobType: %w", err) + } + if errorRaw.Valid && errorRaw.String != "" { + sourceJob.Error = errors.New(errorRaw.String) + } + + sourceJob.Status = sourceJobStatus + sourceJob.JobType = sourceJobType + sourceJob.CreatedAt = sourceJob.CreatedAt.UTC() + sourceJob.UpdatedAt = sourceJob.UpdatedAt.UTC() + return nil +} + +func (s *Source) GetByJobRunTaskRun(ctx context.Context, jobRunID, taskRunID string) (*model.SourceJob, error) { + row := s.db.QueryRowContext(ctx, ` + SELECT + `+sourceJobColumns+` + FROM + `+sourceJobTableName+` + WHERE + metadata->>'job_run_id' = $1 AND + metadata->>'task_run_id' = $2 + LIMIT 1; + `, + jobRunID, + taskRunID, + ) + + var sourceJob model.SourceJob + err := scanSourceJob(row.Scan, &sourceJob) + if errors.Is(err, sql.ErrNoRows) { + return nil, model.ErrSourcesJobNotFound + } + if err != nil { + return nil, fmt.Errorf("scanning source job: %w", err) + } + return &sourceJob, nil +} + +func (s *Source) OnUpdateSuccess(ctx context.Context, id int64) error { + r, err := s.db.ExecContext(ctx, ` + UPDATE + `+sourceJobTableName+` + SET + status = $1, + updated_at = $2 + WHERE + id = $3; +`, + model.SourceJobStatusSucceeded.String(), + s.now(), + id, + ) + if err != nil { + return fmt.Errorf("executing: %w", err) + } + rowsAffected, err := r.RowsAffected() + if err != nil { + return fmt.Errorf("rows affected: %w", err) + } + if rowsAffected == 0 { + 
return model.ErrSourcesJobNotFound + } + return nil +} + +func (s *Source) OnUpdateFailure(ctx context.Context, id int64, error error, maxAttempt int) error { + r, err := s.db.ExecContext(ctx, ` + UPDATE + `+sourceJobTableName+` + SET + status = ( + CASE WHEN attempt > $1 THEN $2 + ELSE $3 END + ), + attempt = attempt + 1, + updated_at = $4, + error = $5 + WHERE + id = $6; +`, + maxAttempt, + model.SourceJobStatusAborted.String(), + model.SourceJobStatusFailed.String(), + s.now(), + error.Error(), + id, + ) + if err != nil { + return fmt.Errorf("executing: %w", err) + } + rowsAffected, err := r.RowsAffected() + if err != nil { + return fmt.Errorf("rows affected: %w", err) + } + if rowsAffected == 0 { + return model.ErrSourcesJobNotFound + } + return nil +} + +func (s *Source) MarkExecuting(ctx context.Context, ids []int64) error { + _, err := s.db.ExecContext(ctx, ` + UPDATE + `+sourceJobTableName+` + SET + status = $1, + updated_at = $2 + WHERE + id = ANY($3); +`, + model.SourceJobStatusExecuting, + s.now(), + pq.Array(ids), + ) + if err != nil { + return fmt.Errorf("executing: %w", err) + } + return nil +} diff --git a/warehouse/internal/repo/source_test.go b/warehouse/internal/repo/source_test.go new file mode 100644 index 00000000000..4115cea6213 --- /dev/null +++ b/warehouse/internal/repo/source_test.go @@ -0,0 +1,485 @@ +package repo_test + +import ( + "context" + "encoding/json" + "errors" + "strconv" + "testing" + "time" + + "github.com/samber/lo" + + warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" + + "github.com/stretchr/testify/require" + + "github.com/rudderlabs/rudder-server/warehouse/internal/model" + "github.com/rudderlabs/rudder-server/warehouse/internal/repo" +) + +func TestSource_Insert(t *testing.T) { + const ( + sourceId = "test_source_id" + destinationId = "test_destination_id" + workspaceId = "test_workspace_id" + ) + + db, ctx := setupDB(t), context.Background() + + now := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) + 
repoSource := repo.NewSource(db, repo.WithNow(func() time.Time { + return now + })) + + t.Run("success", func(t *testing.T) { + ids, err := repoSource.Insert(ctx, lo.RepeatBy(10, func(i int) model.SourceJob { + return model.SourceJob{ + SourceID: sourceId, + DestinationID: destinationId, + TableName: "table-" + strconv.Itoa(i), + WorkspaceID: workspaceId, + Metadata: json.RawMessage(`{"key": "value"}`), + JobType: model.SourceJobTypeDeleteByJobRunID, + } + })) + require.NoError(t, err) + require.Len(t, ids, 10) + }) + t.Run("context cancelled", func(t *testing.T) { + ctx, cancel := context.WithCancel(ctx) + cancel() + + ids, err := repoSource.Insert(ctx, lo.RepeatBy(1, func(i int) model.SourceJob { + return model.SourceJob{ + SourceID: sourceId, + DestinationID: destinationId, + TableName: "table-" + strconv.Itoa(i), + WorkspaceID: workspaceId, + Metadata: json.RawMessage(`{"key": "value"}`), + JobType: model.SourceJobTypeDeleteByJobRunID, + } + })) + require.ErrorIs(t, err, context.Canceled) + require.Nil(t, ids) + }) +} + +func TestSource_Reset(t *testing.T) { + const ( + sourceId = "test_source_id" + destinationId = "test_destination_id" + workspaceId = "test_workspace_id" + ) + + db, ctx := setupDB(t), context.Background() + + now := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) + repoSource := repo.NewSource(db, repo.WithNow(func() time.Time { + return now + })) + + t.Run("success", func(t *testing.T) { + ids, err := repoSource.Insert(ctx, lo.RepeatBy(10, func(i int) model.SourceJob { + return model.SourceJob{ + SourceID: sourceId, + DestinationID: destinationId, + TableName: "table-" + strconv.Itoa(i), + WorkspaceID: workspaceId, + Metadata: json.RawMessage(`{"key": "value"}`), + JobType: model.SourceJobTypeDeleteByJobRunID, + } + })) + require.NoError(t, err) + require.Len(t, ids, 10) + + for _, id := range ids[0:3] { + _, err = db.ExecContext(ctx, `UPDATE `+warehouseutils.WarehouseAsyncJobTable+` SET status = $1 WHERE id = $2;`, 
model.SourceJobStatusSucceeded, id) + require.NoError(t, err) + } + for _, id := range ids[3:6] { + _, err = db.ExecContext(ctx, `UPDATE `+warehouseutils.WarehouseAsyncJobTable+` SET status = $1 WHERE id = $2;`, model.SourceJobStatusExecuting, id) + require.NoError(t, err) + } + for _, id := range ids[6:10] { + _, err = db.ExecContext(ctx, `UPDATE `+warehouseutils.WarehouseAsyncJobTable+` SET status = $1 WHERE id = $2;`, model.SourceJobStatusFailed, id) + require.NoError(t, err) + } + + err = repoSource.Reset(ctx) + require.NoError(t, err) + + for _, id := range ids[0:3] { + var status string + err = db.QueryRowContext(ctx, `SELECT status FROM `+warehouseutils.WarehouseAsyncJobTable+` WHERE id = $1`, id).Scan(&status) + require.NoError(t, err) + require.Equal(t, model.SourceJobStatusSucceeded.String(), status) + } + + for _, id := range ids[3:10] { + var status string + err = db.QueryRowContext(ctx, `SELECT status FROM `+warehouseutils.WarehouseAsyncJobTable+` WHERE id = $1`, id).Scan(&status) + require.NoError(t, err) + require.Equal(t, model.SourceJobStatusWaiting.String(), status) + } + }) + t.Run("context cancelled", func(t *testing.T) { + ctx, cancel := context.WithCancel(ctx) + cancel() + + err := repoSource.Reset(ctx) + require.ErrorIs(t, err, context.Canceled) + }) +} + +func TestSource_GetToProcess(t *testing.T) { + const ( + sourceId = "test_source_id" + destinationId = "test_destination_id" + workspaceId = "test_workspace_id" + ) + + db, ctx := setupDB(t), context.Background() + + now := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) + repoSource := repo.NewSource(db, repo.WithNow(func() time.Time { + return now + })) + + t.Run("nothing to process", func(t *testing.T) { + jobs, err := repoSource.GetToProcess(ctx, 20) + require.NoError(t, err) + require.Zero(t, jobs) + }) + t.Run("few to process", func(t *testing.T) { + ids, err := repoSource.Insert(ctx, lo.RepeatBy(25, func(i int) model.SourceJob { + return model.SourceJob{ + SourceID: sourceId, + 
DestinationID: destinationId, + TableName: "table-" + strconv.Itoa(i), + WorkspaceID: workspaceId, + Metadata: json.RawMessage(`{"key": "value"}`), + JobType: model.SourceJobTypeDeleteByJobRunID, + } + })) + require.NoError(t, err) + require.Len(t, ids, 25) + + for _, id := range ids[6:10] { + _, err = db.ExecContext(ctx, `UPDATE `+warehouseutils.WarehouseAsyncJobTable+` SET status = $1 WHERE id = $2;`, model.SourceJobStatusExecuting, id) + require.NoError(t, err) + } + for _, id := range ids[10:14] { + _, err = db.ExecContext(ctx, `UPDATE `+warehouseutils.WarehouseAsyncJobTable+` SET status = $1 WHERE id = $2;`, model.SourceJobStatusSucceeded, id) + require.NoError(t, err) + } + for _, id := range ids[14:18] { + _, err = db.ExecContext(ctx, `UPDATE `+warehouseutils.WarehouseAsyncJobTable+` SET status = $1 WHERE id = $2;`, model.SourceJobStatusFailed, id) + require.NoError(t, err) + } + for _, id := range ids[18:22] { + _, err = db.ExecContext(ctx, `UPDATE `+warehouseutils.WarehouseAsyncJobTable+` SET status = $1 WHERE id = $2;`, model.SourceJobStatusAborted, id) + require.NoError(t, err) + } + + jobs, err := repoSource.GetToProcess(ctx, 20) + require.NoError(t, err) + require.Len(t, jobs, 13) + + lo.ForEach(jobs, func(job model.SourceJob, index int) { + require.Equal(t, sourceId, job.SourceID) + require.Equal(t, destinationId, job.DestinationID) + require.Equal(t, workspaceId, job.WorkspaceID) + require.Contains(t, []model.SourceJobStatus{model.SourceJobStatusWaiting, model.SourceJobStatusFailed}, job.Status) + require.Equal(t, model.SourceJobTypeDeleteByJobRunID, job.JobType) + require.Equal(t, json.RawMessage(`{"key": "value"}`), job.Metadata) + require.EqualValues(t, now.UTC(), job.CreatedAt.UTC()) + require.EqualValues(t, now.UTC(), job.UpdatedAt.UTC()) + }) + }) + t.Run("context cancelled", func(t *testing.T) { + ctx, cancel := context.WithCancel(ctx) + cancel() + + jobs, err := repoSource.GetToProcess(ctx, -1) + require.ErrorIs(t, err, context.Canceled) + 
require.Nil(t, jobs) + }) +} + +func TestSource_GetByJobRunTaskRun(t *testing.T) { + const ( + sourceId = "test_source_id" + destinationId = "test_destination_id" + workspaceId = "test_workspace_id" + jobRun = "test-job-run" + taskRun = "test-task-run" + otherJobRun = "other-job-run" + otherTaskRun = "other-task-run" + ) + + db, ctx := setupDB(t), context.Background() + + now := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) + repoSource := repo.NewSource(db, repo.WithNow(func() time.Time { + return now + })) + + t.Run("source job available", func(t *testing.T) { + ids, err := repoSource.Insert(ctx, lo.RepeatBy(1, func(i int) model.SourceJob { + return model.SourceJob{ + SourceID: sourceId, + DestinationID: destinationId, + TableName: "table-" + strconv.Itoa(i), + WorkspaceID: workspaceId, + Metadata: json.RawMessage(`{"job_run_id": "test-job-run", "task_run_id": "test-task-run"}`), + JobType: model.SourceJobTypeDeleteByJobRunID, + } + })) + require.NoError(t, err) + require.Len(t, ids, 1) + + job, err := repoSource.GetByJobRunTaskRun(ctx, jobRun, taskRun) + require.NoError(t, err) + require.Equal(t, job, &model.SourceJob{ + ID: 1, + SourceID: sourceId, + DestinationID: destinationId, + WorkspaceID: workspaceId, + TableName: "table-0", + Status: model.SourceJobStatusWaiting, + Error: nil, + JobType: model.SourceJobTypeDeleteByJobRunID, + Metadata: json.RawMessage(`{"job_run_id": "test-job-run", "task_run_id": "test-task-run"}`), + CreatedAt: now.UTC(), + UpdatedAt: now.UTC(), + Attempts: 0, + }) + }) + t.Run("source job not available", func(t *testing.T) { + job, err := repoSource.GetByJobRunTaskRun(ctx, otherJobRun, otherTaskRun) + require.ErrorIs(t, err, model.ErrSourcesJobNotFound) + require.Nil(t, job) + }) + t.Run("context cancelled", func(t *testing.T) { + ctx, cancel := context.WithCancel(ctx) + cancel() + + job, err := repoSource.GetByJobRunTaskRun(ctx, jobRun, taskRun) + require.ErrorIs(t, err, context.Canceled) + require.Nil(t, job) + }) +} + +func 
TestSource_OnUpdateSuccess(t *testing.T) { + const ( + sourceId = "test_source_id" + destinationId = "test_destination_id" + workspaceId = "test_workspace_id" + jobRun = "test-job-run" + taskRun = "test-task-run" + ) + + db, ctx := setupDB(t), context.Background() + + now := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) + repoSource := repo.NewSource(db, repo.WithNow(func() time.Time { + return now + })) + + t.Run("source job found", func(t *testing.T) { + ids, err := repoSource.Insert(ctx, lo.RepeatBy(1, func(i int) model.SourceJob { + return model.SourceJob{ + SourceID: sourceId, + DestinationID: destinationId, + TableName: "table-" + strconv.Itoa(i), + WorkspaceID: workspaceId, + Metadata: json.RawMessage(`{"job_run_id": "test-job-run", "task_run_id": "test-task-run"}`), + JobType: model.SourceJobTypeDeleteByJobRunID, + } + })) + require.NoError(t, err) + require.Len(t, ids, 1) + + err = repoSource.OnUpdateSuccess(ctx, int64(1)) + require.NoError(t, err) + + job, err := repoSource.GetByJobRunTaskRun(ctx, jobRun, taskRun) + require.NoError(t, err) + + require.Equal(t, job, &model.SourceJob{ + ID: 1, + SourceID: sourceId, + DestinationID: destinationId, + WorkspaceID: workspaceId, + TableName: "table-0", + Status: model.SourceJobStatusSucceeded, + Error: nil, + JobType: model.SourceJobTypeDeleteByJobRunID, + Metadata: json.RawMessage(`{"job_run_id": "test-job-run", "task_run_id": "test-task-run"}`), + CreatedAt: now.UTC(), + UpdatedAt: now.UTC(), + Attempts: 0, + }) + }) + t.Run("source job not found", func(t *testing.T) { + err := repoSource.OnUpdateSuccess(ctx, int64(-1)) + require.ErrorIs(t, err, model.ErrSourcesJobNotFound) + }) + t.Run("context cancelled", func(t *testing.T) { + ctx, cancel := context.WithCancel(ctx) + cancel() + + err := repoSource.OnUpdateSuccess(ctx, int64(1)) + require.ErrorIs(t, err, context.Canceled) + }) +} + +func TestSource_OnUpdateFailure(t *testing.T) { + const ( + sourceId = "test_source_id" + destinationId = "test_destination_id" + 
workspaceId = "test_workspace_id" + jobRun = "test-job-run" + taskRun = "test-task-run" + testError = "test-error" + ) + + db, ctx := setupDB(t), context.Background() + + now := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) + repoSource := repo.NewSource(db, repo.WithNow(func() time.Time { + return now + })) + + t.Run("source job found", func(t *testing.T) { + ids, err := repoSource.Insert(ctx, lo.RepeatBy(1, func(i int) model.SourceJob { + return model.SourceJob{ + SourceID: sourceId, + DestinationID: destinationId, + TableName: "table-" + strconv.Itoa(i), + WorkspaceID: workspaceId, + Metadata: json.RawMessage(`{"job_run_id": "test-job-run", "task_run_id": "test-task-run"}`), + JobType: model.SourceJobTypeDeleteByJobRunID, + } + })) + require.NoError(t, err) + require.Len(t, ids, 1) + + t.Run("not crossed max attempt", func(t *testing.T) { + err = repoSource.OnUpdateFailure(ctx, int64(1), errors.New(testError), 1) + require.NoError(t, err) + + job, err := repoSource.GetByJobRunTaskRun(ctx, jobRun, taskRun) + require.NoError(t, err) + + require.Equal(t, job, &model.SourceJob{ + ID: 1, + SourceID: sourceId, + DestinationID: destinationId, + WorkspaceID: workspaceId, + TableName: "table-0", + Status: model.SourceJobStatusFailed, + Error: errors.New(testError), + JobType: model.SourceJobTypeDeleteByJobRunID, + Metadata: json.RawMessage(`{"job_run_id": "test-job-run", "task_run_id": "test-task-run"}`), + CreatedAt: now.UTC(), + UpdatedAt: now.UTC(), + Attempts: 1, + }) + }) + t.Run("crossed max attempt", func(t *testing.T) { + err = repoSource.OnUpdateFailure(ctx, int64(1), errors.New(testError), -1) + require.NoError(t, err) + + job, err := repoSource.GetByJobRunTaskRun(ctx, jobRun, taskRun) + require.NoError(t, err) + + require.Equal(t, job, &model.SourceJob{ + ID: 1, + SourceID: sourceId, + DestinationID: destinationId, + WorkspaceID: workspaceId, + TableName: "table-0", + Status: model.SourceJobStatusAborted, + Error: errors.New(testError), + JobType: 
model.SourceJobTypeDeleteByJobRunID, + Metadata: json.RawMessage(`{"job_run_id": "test-job-run", "task_run_id": "test-task-run"}`), + CreatedAt: now.UTC(), + UpdatedAt: now.UTC(), + Attempts: 2, + }) + }) + }) + t.Run("source job not found", func(t *testing.T) { + err := repoSource.OnUpdateFailure(ctx, int64(-1), errors.New(testError), 1) + require.ErrorIs(t, err, model.ErrSourcesJobNotFound) + }) + t.Run("context cancelled", func(t *testing.T) { + ctx, cancel := context.WithCancel(ctx) + cancel() + + err := repoSource.OnUpdateFailure(ctx, int64(1), errors.New(testError), 1) + require.ErrorIs(t, err, context.Canceled) + }) +} + +func TestSource_MarkExecuting(t *testing.T) { + const ( + sourceId = "test_source_id" + destinationId = "test_destination_id" + workspaceId = "test_workspace_id" + jobRun = "test-job-run" + taskRun = "test-task-run" + ) + + db, ctx := setupDB(t), context.Background() + + now := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) + repoSource := repo.NewSource(db, repo.WithNow(func() time.Time { + return now + })) + + t.Run("success", func(t *testing.T) { + ids, err := repoSource.Insert(ctx, lo.RepeatBy(1, func(i int) model.SourceJob { + return model.SourceJob{ + SourceID: sourceId, + DestinationID: destinationId, + TableName: "table-" + strconv.Itoa(i), + WorkspaceID: workspaceId, + Metadata: json.RawMessage(`{"job_run_id": "test-job-run", "task_run_id": "test-task-run"}`), + JobType: model.SourceJobTypeDeleteByJobRunID, + } + })) + require.NoError(t, err) + require.Len(t, ids, 1) + + err = repoSource.MarkExecuting(ctx, []int64{1}) + require.NoError(t, err) + + job, err := repoSource.GetByJobRunTaskRun(ctx, jobRun, taskRun) + require.NoError(t, err) + + require.Equal(t, job, &model.SourceJob{ + ID: 1, + SourceID: sourceId, + DestinationID: destinationId, + WorkspaceID: workspaceId, + TableName: "table-0", + Status: model.SourceJobStatusExecuting, + Error: nil, + JobType: model.SourceJobTypeDeleteByJobRunID, + Metadata: 
json.RawMessage(`{"job_run_id": "test-job-run", "task_run_id": "test-task-run"}`), + CreatedAt: now.UTC(), + UpdatedAt: now.UTC(), + Attempts: 0, + }) + }) + t.Run("context cancelled", func(t *testing.T) { + ctx, cancel := context.WithCancel(ctx) + cancel() + + err := repoSource.MarkExecuting(ctx, []int64{1}) + require.ErrorIs(t, err, context.Canceled) + }) +} diff --git a/warehouse/internal/repo/table_upload.go b/warehouse/internal/repo/table_upload.go index df28c57f349..24101fee15e 100644 --- a/warehouse/internal/repo/table_upload.go +++ b/warehouse/internal/repo/table_upload.go @@ -95,8 +95,13 @@ func (tu *TableUploads) GetByUploadID(ctx context.Context, uploadID int64) ([]mo if err != nil { return nil, fmt.Errorf("querying table uploads: %w", err) } + defer func() { _ = rows.Close() }() - return tu.parseRows(rows) + tableUploads, err := scanTableUploads(rows) + if err != nil { + return nil, fmt.Errorf("scanning table uploads: %w", err) + } + return tableUploads, nil } func (tu *TableUploads) GetByUploadIDAndTableName(ctx context.Context, uploadID int64, tableName string) (model.TableUpload, error) { @@ -107,73 +112,72 @@ func (tu *TableUploads) GetByUploadIDAndTableName(ctx context.Context, uploadID LIMIT 1; ` - rows, err := tu.db.QueryContext(ctx, query, uploadID, tableName) - if err != nil { - return model.TableUpload{}, fmt.Errorf("querying table uploads: %w", err) - } + row := tu.db.QueryRowContext(ctx, query, uploadID, tableName) - entries, err := tu.parseRows(rows) - if err != nil { - return model.TableUpload{}, fmt.Errorf("parsing rows: %w", err) + var tableUpload model.TableUpload + err := scanTableUpload(row.Scan, &tableUpload) + if errors.Is(err, sql.ErrNoRows) { + return tableUpload, fmt.Errorf("no table upload found with uploadID: %d, tableName: %s", uploadID, tableName) } - if len(entries) == 0 { - return model.TableUpload{}, fmt.Errorf("no table upload found with uploadID: %d, tableName: %s", uploadID, tableName) + if err != nil { + return 
tableUpload, fmt.Errorf("scanning table upload: %w", err) } - - return entries[0], err + return tableUpload, err } -func (*TableUploads) parseRows(rows *sqlmiddleware.Rows) ([]model.TableUpload, error) { +func scanTableUploads(rows *sqlmiddleware.Rows) ([]model.TableUpload, error) { var tableUploads []model.TableUpload - - defer func() { _ = rows.Close() }() - for rows.Next() { - var ( - tableUpload model.TableUpload - locationRaw sql.NullString - lastExecTimeRaw sql.NullTime - totalEvents sql.NullInt64 - ) - err := rows.Scan( - &tableUpload.ID, - &tableUpload.UploadID, - &tableUpload.TableName, - &tableUpload.Status, - &tableUpload.Error, - &lastExecTimeRaw, - &totalEvents, - &tableUpload.CreatedAt, - &tableUpload.UpdatedAt, - &locationRaw, - ) + var tableUpload model.TableUpload + err := scanTableUpload(rows.Scan, &tableUpload) if err != nil { - return nil, fmt.Errorf("scanning row: %w", err) - } - - tableUpload.CreatedAt = tableUpload.CreatedAt.UTC() - tableUpload.UpdatedAt = tableUpload.UpdatedAt.UTC() - - if lastExecTimeRaw.Valid { - tableUpload.LastExecTime = lastExecTimeRaw.Time.UTC() - } - if locationRaw.Valid { - tableUpload.Location = locationRaw.String + return nil, fmt.Errorf("scanning table upload: %w", err) } - if totalEvents.Valid { - tableUpload.TotalEvents = totalEvents.Int64 - } - tableUploads = append(tableUploads, tableUpload) } - if err := rows.Err(); err != nil { - return nil, fmt.Errorf("iterating rows: %w", err) + return nil, err } - return tableUploads, nil } +func scanTableUpload(scan scanFn, tableUpload *model.TableUpload) error { + var ( + locationRaw sql.NullString + lastExecTimeRaw sql.NullTime + totalEvents sql.NullInt64 + ) + err := scan( + &tableUpload.ID, + &tableUpload.UploadID, + &tableUpload.TableName, + &tableUpload.Status, + &tableUpload.Error, + &lastExecTimeRaw, + &totalEvents, + &tableUpload.CreatedAt, + &tableUpload.UpdatedAt, + &locationRaw, + ) + if err != nil { + return fmt.Errorf("scanning row: %w", err) + } + + 
tableUpload.CreatedAt = tableUpload.CreatedAt.UTC() + tableUpload.UpdatedAt = tableUpload.UpdatedAt.UTC() + + if lastExecTimeRaw.Valid { + tableUpload.LastExecTime = lastExecTimeRaw.Time.UTC() + } + if locationRaw.Valid { + tableUpload.Location = locationRaw.String + } + if totalEvents.Valid { + tableUpload.TotalEvents = totalEvents.Int64 + } + return nil +} + func (tu *TableUploads) PopulateTotalEventsFromStagingFileIDs(ctx context.Context, uploadId int64, tableName string, stagingFileIDs []int64) error { subQuery := ` WITH row_numbered_load_files as ( @@ -384,3 +388,45 @@ func (tu *TableUploads) SyncsInfo(ctx context.Context, uploadID int64) ([]model. }) return tableUploadInfos, nil } + +func (tu *TableUploads) GetByJobRunTaskRun( + ctx context.Context, + sourceID, + destinationID, + jobRunID, + taskRunID string, +) ([]model.TableUpload, error) { + rows, err := tu.db.QueryContext(ctx, ` + SELECT + `+tableUploadColumns+` + FROM + `+tableUploadTableName+` + WHERE + wh_upload_id IN ( + SELECT + id + FROM + `+uploadsTableName+` + WHERE + source_id=$1 AND + destination_id=$2 AND + metadata->>'source_job_run_id'=$3 AND + metadata->>'source_task_run_id'=$4 + ); + `, + sourceID, + destinationID, + jobRunID, + taskRunID, + ) + if err != nil { + return nil, fmt.Errorf("querying: %w", err) + } + defer func() { _ = rows.Close() }() + + tableUploads, err := scanTableUploads(rows) + if err != nil { + return nil, fmt.Errorf("scanning table uploads: %w", err) + } + return tableUploads, nil +} diff --git a/warehouse/internal/repo/table_upload_test.go b/warehouse/internal/repo/table_upload_test.go index 2cc4a992f5a..80f06360482 100644 --- a/warehouse/internal/repo/table_upload_test.go +++ b/warehouse/internal/repo/table_upload_test.go @@ -7,6 +7,8 @@ import ( "testing" "time" + "github.com/samber/lo" + "github.com/stretchr/testify/require" "github.com/rudderlabs/rudder-server/warehouse/internal/model" @@ -391,3 +393,77 @@ func TestTableUploadRepo(t *testing.T) { }) }) } + +func 
TestTableUploads_GetByJobRunTaskRun(t *testing.T) { + const ( + sourceID = "test_source_id" + destinationID = "test_destination_id" + destType = "test_destination_type" + workspaceID = "test_workspace_id" + taskRunID = "test_task_run_id" + jobRunID = "test_job_run_id" + ) + + db, ctx := setupDB(t), context.Background() + + cancelledCtx, cancel := context.WithCancel(context.Background()) + cancel() + + now := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) + repoUpload := repo.NewUploads(db, repo.WithNow(func() time.Time { + return now + })) + repoStaging := repo.NewStagingFiles(db, repo.WithNow(func() time.Time { + return now + })) + repoTableUpload := repo.NewTableUploads(db, repo.WithNow(func() time.Time { + return now + })) + + upload := model.Upload{ + WorkspaceID: workspaceID, + Namespace: "namespace", + SourceID: sourceID, + DestinationID: destinationID, + DestinationType: destType, + Status: model.ExportedData, + SourceTaskRunID: taskRunID, + SourceJobRunID: jobRunID, + } + + stagingID, err := repoStaging.Insert(ctx, &model.StagingFileWithSchema{}) + require.NoError(t, err) + + uploadID, err := repoUpload.CreateWithStagingFiles(ctx, upload, []*model.StagingFile{{ + ID: stagingID, + SourceID: sourceID, + DestinationID: destinationID, + SourceTaskRunID: taskRunID, + SourceJobRunID: jobRunID, + }}) + require.NoError(t, err) + + tables := []string{"table1", "table2", "table3"} + + err = repoTableUpload.Insert(ctx, uploadID, tables) + require.NoError(t, err) + + t.Run("known", func(t *testing.T) { + tableUploads, err := repoTableUpload.GetByJobRunTaskRun(ctx, sourceID, destinationID, jobRunID, taskRunID) + require.NoError(t, err) + require.Len(t, tableUploads, len(tables)) + require.Equal(t, tables, lo.Map(tableUploads, func(item model.TableUpload, index int) string { + return item.TableName + })) + }) + t.Run("unknown", func(t *testing.T) { + tableUploads, err := repoTableUpload.GetByJobRunTaskRun(ctx, sourceID, destinationID, "some-other-job-run-id", 
"some-other-task-run-id") + require.NoError(t, err) + require.Empty(t, tableUploads) + }) + t.Run("cancelled context", func(t *testing.T) { + tableUploads, err := repoTableUpload.GetByJobRunTaskRun(cancelledCtx, sourceID, destinationID, jobRunID, taskRunID) + require.ErrorIs(t, err, context.Canceled) + require.Empty(t, tableUploads) + }) +} diff --git a/warehouse/internal/repo/upload_test.go b/warehouse/internal/repo/upload_test.go index 7093a103b3b..d857cbd1809 100644 --- a/warehouse/internal/repo/upload_test.go +++ b/warehouse/internal/repo/upload_test.go @@ -92,7 +92,6 @@ func TestUploads_Count(t *testing.T) { DestinationID: uploads[i].DestinationID, SourceTaskRunID: uploads[i].SourceTaskRunID, }}) - require.NoError(t, err) uploads[i].ID = id diff --git a/warehouse/jobs/http.go b/warehouse/jobs/http.go deleted file mode 100644 index 323d3568cad..00000000000 --- a/warehouse/jobs/http.go +++ /dev/null @@ -1,157 +0,0 @@ -package jobs - -import ( - "encoding/json" - "errors" - "fmt" - "net/http" - "strings" - - "github.com/rudderlabs/rudder-server/services/notifier" - - ierrors "github.com/rudderlabs/rudder-server/warehouse/internal/errors" - lf "github.com/rudderlabs/rudder-server/warehouse/logfield" - - "github.com/samber/lo" -) - -type insertJobResponse struct { - JobIds []int64 `json:"jobids"` - Err error `json:"error"` -} - -// InsertJobHandler adds a job to the warehouse_jobs table -func (a *AsyncJobWh) InsertJobHandler(w http.ResponseWriter, r *http.Request) { - defer func() { _ = r.Body.Close() }() - - if !a.enabled { - a.logger.Errorw("jobs api not initialized for inserting async job") - http.Error(w, ierrors.ErrJobsApiNotInitialized.Error(), http.StatusInternalServerError) - return - } - - var payload StartJobReqPayload - if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { - a.logger.Warnw("invalid JSON in request body for inserting async jobs", lf.Error, err.Error()) - http.Error(w, ierrors.ErrInvalidJSONRequestBody.Error(), 
http.StatusBadRequest) - return - } - - if err := validatePayload(&payload); err != nil { - a.logger.Warnw("invalid payload for inserting async job", lf.Error, err.Error()) - http.Error(w, fmt.Sprintf("invalid payload: %s", err.Error()), http.StatusBadRequest) - return - } - - // TODO: Move to repository - tableNames, err := a.tableNamesBy(payload.SourceID, payload.DestinationID, payload.JobRunID, payload.TaskRunID) - if err != nil { - a.logger.Errorw("extracting tableNames for inserting async job", lf.Error, err.Error()) - http.Error(w, "can't extract tableNames", http.StatusInternalServerError) - return - } - - tableNames = lo.Filter(tableNames, func(tableName string, i int) bool { - switch strings.ToLower(tableName) { - case "rudder_discards", "rudder_identity_mappings", "rudder_identity_merge_rules": - return false - default: - return true - } - }) - - jobIds := make([]int64, 0, len(tableNames)) - for _, table := range tableNames { - metadataJson, err := json.Marshal(WhJobsMetaData{ - JobRunID: payload.JobRunID, - TaskRunID: payload.TaskRunID, - StartTime: payload.StartTime, - JobType: string(notifier.JobTypeAsync), - }) - if err != nil { - a.logger.Errorw("marshalling metadata for inserting async job", lf.Error, err.Error()) - http.Error(w, "can't marshall metadata", http.StatusInternalServerError) - return - } - - // TODO: Move to repository - id, err := a.addJobsToDB(&AsyncJobPayload{ - SourceID: payload.SourceID, - DestinationID: payload.DestinationID, - TableName: table, - AsyncJobType: payload.AsyncJobType, - MetaData: metadataJson, - WorkspaceID: payload.WorkspaceID, - }) - if err != nil { - a.logger.Errorw("inserting async job", lf.Error, err.Error()) - http.Error(w, "can't insert async job", http.StatusInternalServerError) - return - } - - jobIds = append(jobIds, id) - } - - resBody, err := json.Marshal(insertJobResponse{ - JobIds: jobIds, - Err: nil, - }) - if err != nil { - a.logger.Errorw("marshalling response for inserting async job", lf.Error, 
err.Error()) - http.Error(w, ierrors.ErrMarshallResponse.Error(), http.StatusInternalServerError) - return - } - - _, _ = w.Write(resBody) -} - -// StatusJobHandler The following handler gets called for getting the status of the async job -func (a *AsyncJobWh) StatusJobHandler(w http.ResponseWriter, r *http.Request) { - defer func() { _ = r.Body.Close() }() - - if !a.enabled { - a.logger.Errorw("jobs api not initialized for async job status") - http.Error(w, ierrors.ErrJobsApiNotInitialized.Error(), http.StatusInternalServerError) - return - } - - queryParams := r.URL.Query() - payload := StartJobReqPayload{ - TaskRunID: queryParams.Get("task_run_id"), - JobRunID: queryParams.Get("job_run_id"), - SourceID: queryParams.Get("source_id"), - DestinationID: queryParams.Get("destination_id"), - WorkspaceID: queryParams.Get("workspace_id"), - } - if err := validatePayload(&payload); err != nil { - a.logger.Warnw("invalid payload for async job status", lf.Error, err.Error()) - http.Error(w, fmt.Sprintf("invalid request: %s", err.Error()), http.StatusBadRequest) - return - } - - // TODO: Move to repository - jobStatus := a.jobStatus(&payload) - resBody, err := json.Marshal(jobStatus) - if err != nil { - a.logger.Errorw("marshalling response for async job status", lf.Error, err.Error()) - http.Error(w, ierrors.ErrMarshallResponse.Error(), http.StatusInternalServerError) - return - } - - _, _ = w.Write(resBody) -} - -func validatePayload(payload *StartJobReqPayload) error { - switch true { - case payload.SourceID == "": - return errors.New("source_id is required") - case payload.DestinationID == "": - return errors.New("destination_id is required") - case payload.JobRunID == "": - return errors.New("job_run_id is required") - case payload.TaskRunID == "": - return errors.New("task_run_id is required") - default: - return nil - } -} diff --git a/warehouse/jobs/http_test.go b/warehouse/jobs/http_test.go deleted file mode 100644 index 9f5c3e3750a..00000000000 --- 
a/warehouse/jobs/http_test.go +++ /dev/null @@ -1,357 +0,0 @@ -package jobs - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "net/http/httptest" - "net/url" - "testing" - "time" - - "github.com/rudderlabs/rudder-go-kit/stats/memstats" - - "github.com/rudderlabs/rudder-go-kit/config" - "github.com/rudderlabs/rudder-server/services/notifier" - - "github.com/ory/dockertest/v3" - - "github.com/rudderlabs/rudder-go-kit/logger" - "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource" - migrator "github.com/rudderlabs/rudder-server/services/sql-migrator" - sqlmiddleware "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" - "github.com/rudderlabs/rudder-server/warehouse/internal/model" - "github.com/rudderlabs/rudder-server/warehouse/internal/repo" - whutils "github.com/rudderlabs/rudder-server/warehouse/utils" - - "github.com/stretchr/testify/require" -) - -func TestAsyncJobHandlers(t *testing.T) { - const ( - workspaceID = "test_workspace_id" - sourceID = "test_source_id" - destinationID = "test_destination_id" - workspaceIdentifier = "test_workspace-identifier" - namespace = "test_namespace" - destinationType = "test_destination_type" - sourceTaskRunID = "test_source_task_run_id" - sourceJobID = "test_source_job_id" - sourceJobRunID = "test_source_job_run_id" - ) - - pool, err := dockertest.NewPool("") - require.NoError(t, err) - - pgResource, err := resource.SetupPostgres(pool, t) - require.NoError(t, err) - - t.Log("db:", pgResource.DBDsn) - - err = (&migrator.Migrator{ - Handle: pgResource.DB, - MigrationsTable: "wh_schema_migrations", - }).Migrate("warehouse") - require.NoError(t, err) - - db := sqlmiddleware.New(pgResource.DB) - - ctx := context.Background() - - n := notifier.New(config.New(), logger.NOP, memstats.New(), workspaceIdentifier) - err = n.Setup(ctx, pgResource.DBDsn) - require.NoError(t, err) - - now := time.Now().Truncate(time.Second).UTC() - - uploadsRepo := 
repo.NewUploads(db, repo.WithNow(func() time.Time { - return now - })) - tableUploadsRepo := repo.NewTableUploads(db, repo.WithNow(func() time.Time { - return now - })) - stagingRepo := repo.NewStagingFiles(db, repo.WithNow(func() time.Time { - return now - })) - - stagingFile := model.StagingFile{ - WorkspaceID: workspaceID, - Location: "s3://bucket/path/to/file", - SourceID: sourceID, - DestinationID: destinationID, - Status: whutils.StagingFileWaitingState, - Error: fmt.Errorf("dummy error"), - FirstEventAt: now.Add(time.Second), - UseRudderStorage: true, - DestinationRevisionID: "destination_revision_id", - TotalEvents: 100, - SourceTaskRunID: sourceTaskRunID, - SourceJobID: sourceJobID, - SourceJobRunID: sourceJobRunID, - TimeWindow: time.Date(1993, 8, 1, 3, 0, 0, 0, time.UTC), - }.WithSchema([]byte(`{"type": "object"}`)) - - stagingID, err := stagingRepo.Insert(ctx, &stagingFile) - require.NoError(t, err) - - uploadID, err := uploadsRepo.CreateWithStagingFiles(ctx, model.Upload{ - WorkspaceID: workspaceID, - Namespace: namespace, - SourceID: sourceID, - DestinationID: destinationID, - DestinationType: destinationType, - Status: model.Aborted, - SourceJobRunID: sourceJobRunID, - SourceTaskRunID: sourceTaskRunID, - }, []*model.StagingFile{{ - ID: stagingID, - SourceID: sourceID, - DestinationID: destinationID, - SourceJobRunID: sourceJobRunID, - SourceTaskRunID: sourceTaskRunID, - }}) - require.NoError(t, err) - - err = tableUploadsRepo.Insert(ctx, uploadID, []string{ - "test_table_1", - "test_table_2", - "test_table_3", - "test_table_4", - "test_table_5", - - "rudder_discards", - "rudder_identity_mappings", - "rudder_identity_merge_rules", - }) - require.NoError(t, err) - - t.Run("validate payload", func(t *testing.T) { - testCases := []struct { - name string - payload StartJobReqPayload - expectedError error - }{ - { - name: "invalid source", - payload: StartJobReqPayload{ - JobRunID: "job_run_id", - TaskRunID: "task_run_id", - SourceID: "", - DestinationID: 
"destination_id", - WorkspaceID: "workspace_id", - }, - expectedError: errors.New("source_id is required"), - }, - { - name: "invalid destination", - payload: StartJobReqPayload{ - JobRunID: "job_run_id", - TaskRunID: "task_run_id", - SourceID: "source_id", - DestinationID: "", - WorkspaceID: "workspace_id", - }, - expectedError: errors.New("destination_id is required"), - }, - { - name: "invalid task run", - payload: StartJobReqPayload{ - JobRunID: "job_run_id", - TaskRunID: "", - SourceID: "source_id", - DestinationID: "destination_id", - WorkspaceID: "workspace_id", - }, - expectedError: errors.New("task_run_id is required"), - }, - { - name: "invalid job run", - payload: StartJobReqPayload{ - JobRunID: "", - TaskRunID: "task_run_id", - SourceID: "source_id", - DestinationID: "destination_id", - WorkspaceID: "workspace_id", - }, - expectedError: errors.New("job_run_id is required"), - }, - { - name: "valid payload", - payload: StartJobReqPayload{ - JobRunID: "job_run_id", - TaskRunID: "task_run_id", - SourceID: "source_id", - DestinationID: "destination_id", - WorkspaceID: "workspace_id", - }, - }, - } - for _, tc := range testCases { - tc := tc - - t.Run(tc.name, func(t *testing.T) { - require.Equal(t, tc.expectedError, validatePayload(&tc.payload)) - }) - } - }) - - t.Run("InsertJobHandler", func(t *testing.T) { - t.Run("Not enabled", func(t *testing.T) { - req := httptest.NewRequest(http.MethodPost, "/v1/warehouse/jobs", nil) - resp := httptest.NewRecorder() - - jobsManager := AsyncJobWh{ - db: db, - enabled: false, - logger: logger.NOP, - context: ctx, - notifier: n, - } - jobsManager.InsertJobHandler(resp, req) - require.Equal(t, http.StatusInternalServerError, resp.Code) - - b, err := io.ReadAll(resp.Body) - require.NoError(t, err) - require.Equal(t, "warehouse jobs api not initialized\n", string(b)) - }) - t.Run("invalid payload", func(t *testing.T) { - req := httptest.NewRequest(http.MethodPost, "/v1/warehouse/jobs", bytes.NewReader([]byte(`"Invalid 
payload"`))) - resp := httptest.NewRecorder() - - jobsManager := AsyncJobWh{ - db: db, - enabled: true, - logger: logger.NOP, - context: ctx, - notifier: n, - } - jobsManager.InsertJobHandler(resp, req) - require.Equal(t, http.StatusBadRequest, resp.Code) - - b, err := io.ReadAll(resp.Body) - require.NoError(t, err) - require.Equal(t, "invalid JSON in request body\n", string(b)) - }) - t.Run("invalid request", func(t *testing.T) { - req := httptest.NewRequest(http.MethodPost, "/v1/warehouse/jobs", bytes.NewReader([]byte(`{}`))) - resp := httptest.NewRecorder() - - jobsManager := AsyncJobWh{ - db: db, - enabled: true, - logger: logger.NOP, - context: ctx, - notifier: n, - } - jobsManager.InsertJobHandler(resp, req) - require.Equal(t, http.StatusBadRequest, resp.Code) - - b, err := io.ReadAll(resp.Body) - require.NoError(t, err) - require.Equal(t, "invalid payload: source_id is required\n", string(b)) - }) - t.Run("success", func(t *testing.T) { - req := httptest.NewRequest(http.MethodPost, "/v1/warehouse/jobs", bytes.NewReader([]byte(` - { - "source_id": "test_source_id", - "destination_id": "test_destination_id", - "job_run_id": "test_source_job_run_id", - "task_run_id": "test_source_task_run_id" - } - `))) - resp := httptest.NewRecorder() - - jobsManager := AsyncJobWh{ - db: db, - enabled: true, - logger: logger.NOP, - context: ctx, - notifier: n, - } - jobsManager.InsertJobHandler(resp, req) - require.Equal(t, http.StatusOK, resp.Code) - - var insertResponse insertJobResponse - err = json.NewDecoder(resp.Body).Decode(&insertResponse) - require.NoError(t, err) - require.Nil(t, insertResponse.Err) - require.Len(t, insertResponse.JobIds, 5) - }) - }) - - t.Run("StatusJobHandler", func(t *testing.T) { - t.Run("Not enabled", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/v1/warehouse/jobs/status", nil) - resp := httptest.NewRecorder() - - jobsManager := AsyncJobWh{ - db: db, - enabled: false, - logger: logger.NOP, - context: ctx, - notifier: n, - 
} - jobsManager.StatusJobHandler(resp, req) - require.Equal(t, http.StatusInternalServerError, resp.Code) - - b, err := io.ReadAll(resp.Body) - require.NoError(t, err) - require.Equal(t, "warehouse jobs api not initialized\n", string(b)) - }) - t.Run("invalid payload", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/v1/warehouse/jobs/status", nil) - resp := httptest.NewRecorder() - - jobsManager := AsyncJobWh{ - db: db, - enabled: true, - logger: logger.NOP, - context: ctx, - notifier: n, - } - jobsManager.StatusJobHandler(resp, req) - require.Equal(t, http.StatusBadRequest, resp.Code) - - b, err := io.ReadAll(resp.Body) - require.NoError(t, err) - require.Equal(t, "invalid request: source_id is required\n", string(b)) - }) - t.Run("success", func(t *testing.T) { - _, err := db.ExecContext(ctx, ` - INSERT INTO `+whutils.WarehouseAsyncJobTable+` (source_id, destination_id, status, created_at, updated_at, tablename, error, async_job_type, metadata, workspace_id) - VALUES ('test_source_id', 'test_destination_id', 'aborted', NOW(), NOW(), 'test_table_name', 'test_error', 'deletebyjobrunid', '{"job_run_id": "test_source_job_run_id", "task_run_id": "test_source_task_run_id"}', 'test_workspace_id') - `) - require.NoError(t, err) - - qp := url.Values{} - qp.Add("task_run_id", sourceTaskRunID) - qp.Add("job_run_id", sourceJobRunID) - qp.Add("source_id", sourceID) - qp.Add("destination_id", destinationID) - qp.Add("workspace_id", workspaceID) - - req := httptest.NewRequest(http.MethodGet, "/v1/warehouse/jobs/status?"+qp.Encode(), nil) - resp := httptest.NewRecorder() - - jobsManager := AsyncJobWh{ - db: db, - enabled: true, - logger: logger.NOP, - context: ctx, - notifier: n, - } - jobsManager.StatusJobHandler(resp, req) - require.Equal(t, http.StatusOK, resp.Code) - - var statusResponse WhStatusResponse - err = json.NewDecoder(resp.Body).Decode(&statusResponse) - require.NoError(t, err) - require.Equal(t, statusResponse.Status, "aborted") - 
require.Equal(t, statusResponse.Err, "test_error") - }) - }) -} diff --git a/warehouse/jobs/jobs.go b/warehouse/jobs/jobs.go deleted file mode 100644 index 51427553bee..00000000000 --- a/warehouse/jobs/jobs.go +++ /dev/null @@ -1,64 +0,0 @@ -package jobs - -import ( - "context" - "time" - - "github.com/rudderlabs/rudder-server/warehouse/internal/model" - - warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" -) - -type WhAsyncJob struct{} - -func (*WhAsyncJob) IsWarehouseSchemaEmpty() bool { return true } - -func (*WhAsyncJob) GetLocalSchema(context.Context) (model.Schema, error) { - return model.Schema{}, nil -} - -func (*WhAsyncJob) UpdateLocalSchema(context.Context, model.Schema) error { - return nil -} - -func (*WhAsyncJob) GetTableSchemaInWarehouse(string) model.TableSchema { - return model.TableSchema{} -} - -func (*WhAsyncJob) GetTableSchemaInUpload(string) model.TableSchema { - return model.TableSchema{} -} - -func (*WhAsyncJob) GetLoadFilesMetadata(context.Context, warehouseutils.GetLoadFilesOptions) ([]warehouseutils.LoadFile, error) { - return []warehouseutils.LoadFile{}, nil -} - -func (*WhAsyncJob) GetSampleLoadFileLocation(context.Context, string) (string, error) { - return "", nil -} - -func (*WhAsyncJob) GetSingleLoadFile(context.Context, string) (warehouseutils.LoadFile, error) { - return warehouseutils.LoadFile{}, nil -} - -func (*WhAsyncJob) ShouldOnDedupUseNewRecord() bool { - return false -} - -func (*WhAsyncJob) UseRudderStorage() bool { - return false -} - -func (*WhAsyncJob) GetLoadFileGenStartTIme() time.Time { - return time.Time{} -} - -func (*WhAsyncJob) GetLoadFileType() string { - return "" -} - -func (*WhAsyncJob) GetFirstLastEvent() (time.Time, time.Time) { - return time.Now(), time.Now() -} - -func (*WhAsyncJob) CanAppend() bool { return false } diff --git a/warehouse/jobs/runner.go b/warehouse/jobs/runner.go deleted file mode 100644 index 3ae446c2293..00000000000 --- a/warehouse/jobs/runner.go +++ /dev/null @@ -1,425 
+0,0 @@ -package jobs - -import ( - "context" - "database/sql" - "encoding/json" - "fmt" - "time" - - sqlmw "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" - - "github.com/rudderlabs/rudder-server/services/notifier" - - "github.com/lib/pq" - "github.com/samber/lo" - "golang.org/x/sync/errgroup" - - "github.com/rudderlabs/rudder-go-kit/config" - "github.com/rudderlabs/rudder-go-kit/logger" - "github.com/rudderlabs/rudder-server/utils/misc" - "github.com/rudderlabs/rudder-server/utils/timeutil" - warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" -) - -// New Initializes AsyncJobWh structure with appropriate variabless -func New( - ctx context.Context, - db *sqlmw.DB, - notifier *notifier.Notifier, -) *AsyncJobWh { - return &AsyncJobWh{ - db: db, - enabled: false, - notifier: notifier, - context: ctx, - logger: logger.NewLogger().Child("asyncjob"), - } -} - -func WithConfig(a *AsyncJobWh, config *config.Config) { - a.maxBatchSizeToProcess = config.GetInt("Warehouse.jobs.maxBatchSizeToProcess", 10) - a.maxCleanUpRetries = config.GetInt("Warehouse.jobs.maxCleanUpRetries", 5) - a.maxQueryRetries = config.GetInt("Warehouse.jobs.maxQueryRetries", 3) - a.maxAttemptsPerJob = config.GetInt("Warehouse.jobs.maxAttemptsPerJob", 3) - a.retryTimeInterval = config.GetDuration("Warehouse.jobs.retryTimeInterval", 10, time.Second) - a.asyncJobTimeOut = config.GetDuration("Warehouse.jobs.asyncJobTimeOut", 300, time.Second) -} - -func (a *AsyncJobWh) tableNamesBy(sourceID, destinationID, jobRunID, taskRunID string) ([]string, error) { - a.logger.Infof("[WH-Jobs]: Extracting table names for the job run id %s", jobRunID) - var tableNames []string - var err error - - query := `SELECT table_name FROM ` + warehouseutils.WarehouseTableUploadsTable + ` WHERE wh_upload_id IN ` + - ` (SELECT id FROM ` + warehouseutils.WarehouseUploadsTable + ` WHERE metadata->>'source_job_run_id'=$1 - AND metadata->>'source_task_run_id'=$2 - AND 
source_id=$3 - AND destination_id=$4)` - a.logger.Debugf("[WH-Jobs]: Query is %s\n", query) - rows, err := a.db.QueryContext(a.context, query, jobRunID, taskRunID, sourceID, destinationID) - if err != nil { - a.logger.Errorf("[WH-Jobs]: Error executing the query %s with error %v", query, err) - return nil, err - } - defer func() { _ = rows.Close() }() - for rows.Next() { - var tableName string - err = rows.Scan(&tableName) - if err != nil { - a.logger.Errorf("[WH-Jobs]: Error scanning the rows %s", err.Error()) - return nil, err - } - tableNames = append(tableNames, tableName) - } - if err = rows.Err(); err != nil { - a.logger.Errorf("[WH-Jobs]: Error iterating the rows %s", err.Error()) - return nil, err - } - a.logger.Infof("Got the TableNames as %s", tableNames) - return lo.Uniq(tableNames), nil -} - -// Takes AsyncJobPayload and adds rows to table wh_async_jobs -func (a *AsyncJobWh) addJobsToDB(payload *AsyncJobPayload) (int64, error) { - a.logger.Infof("[WH-Jobs]: Adding job to the wh_async_jobs %s for tableName: %s", payload.MetaData, payload.TableName) - var jobId int64 - sqlStatement := `INSERT INTO ` + warehouseutils.WarehouseAsyncJobTable + ` ( - source_id, destination_id, tablename, - status, created_at, updated_at, async_job_type, - workspace_id, metadata - ) - VALUES - ($1, $2, $3, $4, $5, $6 ,$7, $8, $9 ) RETURNING id` - - stmt, err := a.db.Prepare(sqlStatement) - if err != nil { - a.logger.Errorf("[WH-Jobs]: Error preparing out the query %s ", sqlStatement) - err = fmt.Errorf("error preparing out the query, while addJobsToDB %v", err) - return 0, err - } - - defer func() { _ = stmt.Close() }() - now := timeutil.Now() - row := stmt.QueryRowContext(a.context, payload.SourceID, payload.DestinationID, payload.TableName, WhJobWaiting, now, now, payload.AsyncJobType, payload.WorkspaceID, payload.MetaData) - err = row.Scan(&jobId) - if err != nil { - a.logger.Errorf("[WH-Jobs]: Error processing the %s, %s ", sqlStatement, err.Error()) - return 0, err - } - 
return jobId, nil -} - -// Run Async Job runner's main job is to -// 1. Scan the database for entries into wh_async_jobs -// 2. Publish data to pg_notifier queue -// 3. Move any executing jobs to waiting -func (a *AsyncJobWh) Run() error { - // Start the asyncJobRunner - a.logger.Info("[WH-Jobs]: Initializing async job runner") - g, ctx := errgroup.WithContext(a.context) - a.context = ctx - err := misc.RetryWith(a.context, a.retryTimeInterval, a.maxCleanUpRetries, func(ctx context.Context) error { - err := a.cleanUpAsyncTable(ctx) - if err != nil { - a.logger.Errorf("[WH-Jobs]: unable to cleanup asynctable with error %s", err.Error()) - return err - } - a.enabled = true - return nil - }) - if err != nil { - a.logger.Errorf("[WH-Jobs]: unable to cleanup asynctable with error %s", err.Error()) - return err - } - if a.enabled { - g.Go(func() error { - return a.startAsyncJobRunner(ctx) - }) - } - return g.Wait() -} - -func (a *AsyncJobWh) cleanUpAsyncTable(ctx context.Context) error { - a.logger.Info("[WH-Jobs]: Cleaning up the zombie asyncjobs") - sqlStatement := fmt.Sprintf( - `UPDATE %s SET status=$1 WHERE status=$2 or status=$3`, - pq.QuoteIdentifier(warehouseutils.WarehouseAsyncJobTable), - ) - a.logger.Debugf("[WH-Jobs]: resetting up async jobs table query %s", sqlStatement) - _, err := a.db.ExecContext(ctx, sqlStatement, WhJobWaiting, WhJobExecuting, WhJobFailed) - return err -} - -/* -startAsyncJobRunner is the main runner that -1) Periodically queries the db for any pending async jobs -2) Groups them together -3) Publishes them to the notifier -4) Spawns a subroutine that periodically checks for responses from Notifier/slave worker post trackBatch -*/ -func (a *AsyncJobWh) startAsyncJobRunner(ctx context.Context) error { - a.logger.Info("[WH-Jobs]: Starting async job runner") - defer a.logger.Info("[WH-Jobs]: Stopping AsyncJobRunner") - - for { - a.logger.Debug("[WH-Jobs]: Scanning for waiting async job") - - select { - case <-ctx.Done(): - return nil - case 
<-time.After(a.retryTimeInterval): - } - - pendingAsyncJobs, err := a.getPendingAsyncJobs(ctx) - if err != nil { - a.logger.Errorf("[WH-Jobs]: unable to get pending async jobs with error %s", err.Error()) - continue - } - if len(pendingAsyncJobs) == 0 { - continue - } - - a.logger.Infof("[WH-Jobs]: Number of async wh jobs left = %d", len(pendingAsyncJobs)) - - notifierClaims, err := getMessagePayloadsFromAsyncJobPayloads(pendingAsyncJobs) - if err != nil { - a.logger.Errorf("Error converting the asyncJobType to notifier payload %s ", err) - asyncJobStatusMap := convertToPayloadStatusStructWithSingleStatus(pendingAsyncJobs, WhJobFailed, err) - _ = a.updateAsyncJobs(ctx, asyncJobStatusMap) - continue - } - ch, err := a.notifier.Publish(ctx, ¬ifier.PublishRequest{ - Payloads: notifierClaims, - JobType: notifier.JobTypeAsync, - Priority: 100, - }) - if err != nil { - a.logger.Errorf("[WH-Jobs]: unable to get publish async jobs to notifier. Task failed with error %s", err.Error()) - asyncJobStatusMap := convertToPayloadStatusStructWithSingleStatus(pendingAsyncJobs, WhJobFailed, err) - _ = a.updateAsyncJobs(ctx, asyncJobStatusMap) - continue - } - asyncJobStatusMap := convertToPayloadStatusStructWithSingleStatus(pendingAsyncJobs, WhJobExecuting, err) - _ = a.updateAsyncJobs(ctx, asyncJobStatusMap) - - select { - case <-ctx.Done(): - a.logger.Infof("[WH-Jobs]: Context cancelled for async job runner") - return nil - case responses, ok := <-ch: - if !ok { - a.logger.Error("[WH-Jobs]: Notifier track batch channel closed") - asyncJobStatusMap := convertToPayloadStatusStructWithSingleStatus(pendingAsyncJobs, WhJobFailed, fmt.Errorf("receiving channel closed")) - _ = a.updateAsyncJobs(ctx, asyncJobStatusMap) - continue - } - if responses.Err != nil { - a.logger.Errorf("[WH-Jobs]: Error received from the notifier track batch %s", responses.Err.Error()) - asyncJobStatusMap := convertToPayloadStatusStructWithSingleStatus(pendingAsyncJobs, WhJobFailed, responses.Err) - _ = 
a.updateAsyncJobs(ctx, asyncJobStatusMap) - continue - } - a.logger.Info("[WH-Jobs]: Response received from the notifier track batch") - asyncJobsStatusMap := getAsyncStatusMapFromAsyncPayloads(pendingAsyncJobs) - a.updateStatusJobPayloadsFromNotifierResponse(responses, asyncJobsStatusMap) - _ = a.updateAsyncJobs(ctx, asyncJobsStatusMap) - case <-time.After(a.asyncJobTimeOut): - a.logger.Errorf("Go Routine timed out waiting for a response from notifier", pendingAsyncJobs[0].Id) - asyncJobStatusMap := convertToPayloadStatusStructWithSingleStatus(pendingAsyncJobs, WhJobFailed, err) - _ = a.updateAsyncJobs(ctx, asyncJobStatusMap) - } - } -} - -func (a *AsyncJobWh) updateStatusJobPayloadsFromNotifierResponse(r *notifier.PublishResponse, m map[string]AsyncJobStatus) { - for _, resp := range r.Jobs { - var response NotifierResponse - err := json.Unmarshal(resp.Payload, &response) - if err != nil { - a.logger.Errorf("error unmarshalling notifier payload to AsyncJobStatusMa for Id: %s", response.Id) - continue - } - - if output, ok := m[response.Id]; ok { - output.Status = string(resp.Status) - if resp.Error != nil { - output.Error = fmt.Errorf(resp.Error.Error()) - } - m[response.Id] = output - } - } -} - -// Queries the jobsDB and gets active async job and returns it in a -func (a *AsyncJobWh) getPendingAsyncJobs(ctx context.Context) ([]AsyncJobPayload, error) { - asyncJobPayloads := make([]AsyncJobPayload, 0) - a.logger.Debug("[WH-Jobs]: Get pending wh async jobs") - // Filter to get most recent row for the sourceId/destinationID combo and remaining ones should relegate to abort. 
- var attempt int - query := fmt.Sprintf( - `SELECT - id, - source_id, - destination_id, - tablename, - async_job_type, - metadata, - attempt - FROM %s WHERE (status=$1 OR status=$2) LIMIT $3`, warehouseutils.WarehouseAsyncJobTable) - rows, err := a.db.QueryContext(ctx, query, WhJobWaiting, WhJobFailed, a.maxBatchSizeToProcess) - if err != nil { - a.logger.Errorf("[WH-Jobs]: Error in getting pending wh async jobs with error %s", err.Error()) - return asyncJobPayloads, err - } - defer func() { _ = rows.Close() }() - for rows.Next() { - var asyncJobPayload AsyncJobPayload - err = rows.Scan( - &asyncJobPayload.Id, - &asyncJobPayload.SourceID, - &asyncJobPayload.DestinationID, - &asyncJobPayload.TableName, - &asyncJobPayload.AsyncJobType, - &asyncJobPayload.MetaData, - &attempt, - ) - if err != nil { - a.logger.Errorf("[WH-Jobs]: Error scanning rows %s\n", err) - return asyncJobPayloads, err - } - asyncJobPayloads = append(asyncJobPayloads, asyncJobPayload) - a.logger.Infof("Adding row with Id = %s & attempt no %d", asyncJobPayload.Id, attempt) - } - if err := rows.Err(); err != nil { - a.logger.Errorf("[WH-Jobs]: Error in getting pending wh async jobs with error %s", rows.Err().Error()) - return asyncJobPayloads, err - } - return asyncJobPayloads, nil -} - -// Updates the warehouse async jobs with the status sent as a parameter -func (a *AsyncJobWh) updateAsyncJobs(ctx context.Context, payloads map[string]AsyncJobStatus) error { - a.logger.Info("[WH-Jobs]: Updating wh async jobs to Executing") - var err error - for _, payload := range payloads { - if payload.Error != nil { - err = a.updateAsyncJobStatus(ctx, payload.Id, payload.Status, payload.Error.Error()) - continue - } - err = a.updateAsyncJobStatus(ctx, payload.Id, payload.Status, "") - } - return err -} - -func (a *AsyncJobWh) updateAsyncJobStatus(ctx context.Context, Id, status, errMessage string) error { - a.logger.Infof("[WH-Jobs]: Updating status of wh async jobs to %s", status) - sqlStatement := 
fmt.Sprintf(`UPDATE %s SET status=(CASE - WHEN attempt >= $1 - THEN $2 - ELSE $3 - END) , - error=$4 WHERE id=$5 AND status!=$6 AND status!=$7 `, - warehouseutils.WarehouseAsyncJobTable, - ) - var err error - for retryCount := 0; retryCount < a.maxQueryRetries; retryCount++ { - a.logger.Debugf("[WH-Jobs]: updating async jobs table query %s, retry no : %d", sqlStatement, retryCount) - _, err := a.db.ExecContext(ctx, sqlStatement, - a.maxAttemptsPerJob, WhJobAborted, status, errMessage, Id, WhJobAborted, WhJobSucceeded, - ) - if err == nil { - a.logger.Info("Update successful") - a.logger.Debugf("query: %s successfully executed", sqlStatement) - if status == WhJobFailed { - return a.updateAsyncJobAttempt(ctx, Id) - } - return err - } - } - - a.logger.Errorf("Query: %s failed with error: %s", sqlStatement, err.Error()) - return err -} - -func (a *AsyncJobWh) updateAsyncJobAttempt(ctx context.Context, Id string) error { - a.logger.Info("[WH-Jobs]: Incrementing wh async jobs attempt") - sqlStatement := fmt.Sprintf(`UPDATE %s SET attempt=attempt+1 WHERE id=$1 AND status!=$2 AND status!=$3 `, warehouseutils.WarehouseAsyncJobTable) - var err error - for queryRetry := 0; queryRetry < a.maxQueryRetries; queryRetry++ { - a.logger.Debugf("[WH-Jobs]: updating async jobs table query %s, retry no : %d", sqlStatement, queryRetry) - row, err := a.db.QueryContext(ctx, sqlStatement, Id, WhJobAborted, WhJobSucceeded) - if err == nil { - a.logger.Info("Update successful") - a.logger.Debugf("query: %s successfully executed", sqlStatement) - return nil - } - _ = row.Err() - } - a.logger.Errorf("query: %s failed with Error : %s", sqlStatement, err.Error()) - return err -} - -// returns status and errMessage -// Only succeeded, executing & waiting states should have empty errMessage -// Rest of the states failed, aborted should send an error message conveying a message -func (a *AsyncJobWh) jobStatus(payload *StartJobReqPayload) WhStatusResponse { - var statusResponse WhStatusResponse - 
a.logger.Info("[WH-Jobs]: Getting status for wh async jobs %v", payload) - // Need to check for count first and see if there are any rows matching the job_run_id and task_run_id. If none, then raise an error instead of showing complete - sqlStatement := fmt.Sprintf(`SELECT status,error FROM %s WHERE metadata->>'job_run_id'=$1 AND metadata->>'task_run_id'=$2`, warehouseutils.WarehouseAsyncJobTable) - a.logger.Debugf("Query inside getStatusAsync function is %s", sqlStatement) - rows, err := a.db.QueryContext(a.context, sqlStatement, payload.JobRunID, payload.TaskRunID) - if err != nil { - a.logger.Errorf("[WH-Jobs]: Error executing the query %s", err.Error()) - return WhStatusResponse{ - Status: WhJobFailed, - Err: err.Error(), - } - } - defer func() { _ = rows.Close() }() - for rows.Next() { - var status string - var errMessage sql.NullString - err = rows.Scan(&status, &errMessage) - if err != nil { - a.logger.Errorf("[WH-Jobs]: Error scanning rows %s\n", err) - return WhStatusResponse{ - Status: WhJobFailed, - Err: err.Error(), - } - } - - switch status { - case WhJobFailed: - a.logger.Infof("[WH-Jobs] Async Job with job_run_id: %s, task_run_id: %s is failed", payload.JobRunID, payload.TaskRunID) - statusResponse.Status = WhJobFailed - if !errMessage.Valid { - statusResponse.Err = "Failed while scanning" - } - statusResponse.Err = errMessage.String - case WhJobAborted: - a.logger.Infof("[WH-Jobs] Async Job with job_run_id: %s, task_run_id: %s is aborted", payload.JobRunID, payload.TaskRunID) - statusResponse.Status = WhJobAborted - if !errMessage.Valid { - statusResponse.Err = "Failed while scanning" - } - statusResponse.Err = errMessage.String - case WhJobSucceeded: - a.logger.Infof("[WH-Jobs] Async Job with job_run_id: %s, task_run_id: %s is complete", payload.JobRunID, payload.TaskRunID) - statusResponse.Status = WhJobSucceeded - default: - a.logger.Infof("[WH-Jobs] Async Job with job_run_id: %s, task_run_id: %s is under processing", payload.JobRunID, 
payload.TaskRunID) - statusResponse.Status = WhJobExecuting - } - } - - if err = rows.Err(); err != nil { - a.logger.Errorf("[WH-Jobs]: Error scanning rows %s\n", err) - return WhStatusResponse{ - Status: WhJobFailed, - Err: err.Error(), - } - } - return statusResponse -} diff --git a/warehouse/jobs/types.go b/warehouse/jobs/types.go deleted file mode 100644 index f85fafab495..00000000000 --- a/warehouse/jobs/types.go +++ /dev/null @@ -1,89 +0,0 @@ -package jobs - -import ( - "context" - "encoding/json" - "time" - - "github.com/rudderlabs/rudder-server/services/notifier" - - sqlmw "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" - - "github.com/rudderlabs/rudder-go-kit/logger" -) - -// StartJobReqPayload For processing requests payload in handlers.go -type StartJobReqPayload struct { - SourceID string `json:"source_id"` - Type string `json:"type"` - Channel string `json:"channel"` - DestinationID string `json:"destination_id"` - StartTime string `json:"start_time"` - JobRunID string `json:"job_run_id"` - TaskRunID string `json:"task_run_id"` - AsyncJobType string `json:"async_job_type"` - WorkspaceID string `json:"workspace_id"` -} - -type AsyncJobWh struct { - db *sqlmw.DB - enabled bool - notifier *notifier.Notifier - context context.Context - logger logger.Logger - maxBatchSizeToProcess int - maxCleanUpRetries int - maxQueryRetries int - retryTimeInterval time.Duration - maxAttemptsPerJob int - asyncJobTimeOut time.Duration -} - -type WhJobsMetaData struct { - JobRunID string `json:"job_run_id"` - TaskRunID string `json:"task_run_id"` - JobType string `json:"jobtype"` - StartTime string `json:"start_time"` -} - -// AsyncJobPayload For creating job payload to wh_async_jobs table -type AsyncJobPayload struct { - Id string `json:"id"` - SourceID string `json:"source_id"` - DestinationID string `json:"destination_id"` - TableName string `json:"tablename"` - AsyncJobType string `json:"async_job_type"` - WorkspaceID string 
`json:"workspace_id"` - MetaData json.RawMessage `json:"metadata"` -} - -const ( - WhJobWaiting string = "waiting" - WhJobExecuting string = "executing" - WhJobSucceeded string = "succeeded" - WhJobAborted string = "aborted" - WhJobFailed string = "failed" -) - -type NotifierResponse struct { - Id string `json:"id"` -} - -type WhStatusResponse struct { - Status string - Err string -} - -type WhAsyncJobRunner interface { - startAsyncJobRunner(context.Context) - getTableNamesBy(context.Context, string, string) - getPendingAsyncJobs(context.Context) ([]AsyncJobPayload, error) - getStatusAsyncJob(*StartJobReqPayload) (string, error) - updateMultipleAsyncJobs(*[]AsyncJobPayload, string, string) -} - -type AsyncJobStatus struct { - Id string - Status string - Error error -} diff --git a/warehouse/jobs/utils.go b/warehouse/jobs/utils.go deleted file mode 100644 index 017797a8efa..00000000000 --- a/warehouse/jobs/utils.go +++ /dev/null @@ -1,41 +0,0 @@ -package jobs - -import ( - "encoding/json" -) - -func convertToPayloadStatusStructWithSingleStatus(payloads []AsyncJobPayload, status string, err error) map[string]AsyncJobStatus { - asyncJobStatusMap := make(map[string]AsyncJobStatus) - for _, payload := range payloads { - asyncJobStatusMap[payload.Id] = AsyncJobStatus{ - Id: payload.Id, - Status: status, - Error: err, - } - } - return asyncJobStatusMap -} - -// convert to notifier Payload and return the array of payloads -func getMessagePayloadsFromAsyncJobPayloads(asyncJobPayloads []AsyncJobPayload) ([]json.RawMessage, error) { - var messages []json.RawMessage - for _, job := range asyncJobPayloads { - message, err := json.Marshal(job) - if err != nil { - return messages, err - } - messages = append(messages, message) - } - return messages, nil -} - -func getAsyncStatusMapFromAsyncPayloads(payloads []AsyncJobPayload) map[string]AsyncJobStatus { - asyncJobStatusMap := make(map[string]AsyncJobStatus) - for _, payload := range payloads { - asyncJobStatusMap[payload.Id] = 
AsyncJobStatus{ - Id: payload.Id, - Status: WhJobFailed, - } - } - return asyncJobStatusMap -} diff --git a/warehouse/slave/worker.go b/warehouse/slave/worker.go index baf81195bbb..e5b86d5f517 100644 --- a/warehouse/slave/worker.go +++ b/warehouse/slave/worker.go @@ -27,7 +27,7 @@ import ( integrationsconfig "github.com/rudderlabs/rudder-server/warehouse/integrations/config" "github.com/rudderlabs/rudder-server/warehouse/integrations/manager" "github.com/rudderlabs/rudder-server/warehouse/internal/model" - "github.com/rudderlabs/rudder-server/warehouse/jobs" + "github.com/rudderlabs/rudder-server/warehouse/source" warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" ) @@ -51,11 +51,6 @@ type uploadResult struct { UseRudderStorage bool } -type asyncJobRunResult struct { - Result bool `json:"Result"` - ID string `json:"Id"` -} - type worker struct { conf *config.Config log logger.Logger @@ -134,7 +129,7 @@ func (w *worker) start(ctx context.Context, notificationChan <-chan *notifier.Cl switch claimedJob.Job.Type { case notifier.JobTypeAsync: - w.processClaimedAsyncJob(ctx, claimedJob) + w.processClaimedSourceJob(ctx, claimedJob) default: w.processClaimedUploadJob(ctx, claimedJob) } @@ -424,7 +419,7 @@ func (w *worker) processStagingFile(ctx context.Context, job payload) ([]uploadR return uploadsResults, err } -func (w *worker) processClaimedAsyncJob(ctx context.Context, claimedJob *notifier.ClaimJob) { +func (w *worker) processClaimedSourceJob(ctx context.Context, claimedJob *notifier.ClaimJob) { handleErr := func(err error, claimedJob *notifier.ClaimJob) { w.log.Errorf("Error processing claim: %v", err) @@ -434,7 +429,7 @@ func (w *worker) processClaimedAsyncJob(ctx context.Context, claimedJob *notifie } var ( - job jobs.AsyncJobPayload + job source.NotifierRequest err error ) @@ -443,13 +438,15 @@ func (w *worker) processClaimedAsyncJob(ctx context.Context, claimedJob *notifie return } - jobResult, err := w.runAsyncJob(ctx, job) + err = 
w.runSourceJob(ctx, job) if err != nil { handleErr(err, claimedJob) return } - jobResultJSON, err := json.Marshal(jobResult) + jobResultJSON, err := json.Marshal(source.NotifierResponse{ + ID: job.ID, + }) if err != nil { handleErr(err, claimedJob) return @@ -460,20 +457,15 @@ func (w *worker) processClaimedAsyncJob(ctx context.Context, claimedJob *notifie }) } -func (w *worker) runAsyncJob(ctx context.Context, asyncjob jobs.AsyncJobPayload) (asyncJobRunResult, error) { - result := asyncJobRunResult{ - ID: asyncjob.Id, - Result: false, - } - - warehouse, err := w.destinationFromSlaveConnectionMap(asyncjob.DestinationID, asyncjob.SourceID) +func (w *worker) runSourceJob(ctx context.Context, sourceJob source.NotifierRequest) error { + warehouse, err := w.destinationFromSlaveConnectionMap(sourceJob.DestinationID, sourceJob.SourceID) if err != nil { - return result, err + return fmt.Errorf("getting warehouse: %w", err) } integrationsManager, err := manager.NewWarehouseOperations(warehouse.Destination.DestinationDefinition.Name, w.conf, w.log, w.statsFactory) if err != nil { - return result, err + return fmt.Errorf("getting integrations manager: %w", err) } integrationsManager.SetConnectionTimeout(warehouseutils.GetConnectionTimeout( @@ -481,35 +473,29 @@ func (w *worker) runAsyncJob(ctx context.Context, asyncjob jobs.AsyncJobPayload) warehouse.Destination.ID, )) - err = integrationsManager.Setup(ctx, warehouse, &jobs.WhAsyncJob{}) + err = integrationsManager.Setup(ctx, warehouse, &source.Uploader{}) if err != nil { - return result, err + return fmt.Errorf("setting up integrations manager: %w", err) } defer integrationsManager.Cleanup(ctx) var metadata warehouseutils.DeleteByMetaData - if err = json.Unmarshal(asyncjob.MetaData, &metadata); err != nil { - return result, err + if err = json.Unmarshal(sourceJob.MetaData, &metadata); err != nil { + return fmt.Errorf("unmarshalling metadata: %w", err) } - switch asyncjob.AsyncJobType { - case "deletebyjobrunid": - err = 
integrationsManager.DeleteBy(ctx, []string{asyncjob.TableName}, warehouseutils.DeleteByParams{ - SourceId: asyncjob.SourceID, + switch sourceJob.JobType { + case model.SourceJobTypeDeleteByJobRunID.String(): + err = integrationsManager.DeleteBy(ctx, []string{sourceJob.TableName}, warehouseutils.DeleteByParams{ + SourceId: sourceJob.SourceID, TaskRunId: metadata.TaskRunId, JobRunId: metadata.JobRunId, StartTime: metadata.StartTime, }) default: - err = errors.New("invalid asyncJob type") + err = errors.New("invalid sourceJob type") } - if err != nil { - return result, err - } - - result.Result = true - - return result, nil + return err } func (w *worker) destinationFromSlaveConnectionMap(destinationId, sourceId string) (model.Warehouse, error) { diff --git a/warehouse/slave/worker_test.go b/warehouse/slave/worker_test.go index 279e22ac8f6..7e46e12fdf1 100644 --- a/warehouse/slave/worker_test.go +++ b/warehouse/slave/worker_test.go @@ -9,6 +9,8 @@ import ( "os" "testing" + "github.com/rudderlabs/rudder-server/warehouse/source" + "github.com/rudderlabs/rudder-go-kit/stats/memstats" "github.com/rudderlabs/rudder-server/warehouse/bcm" @@ -33,7 +35,6 @@ import ( "github.com/rudderlabs/rudder-server/utils/misc" "github.com/rudderlabs/rudder-server/utils/pubsub" "github.com/rudderlabs/rudder-server/warehouse/internal/model" - "github.com/rudderlabs/rudder-server/warehouse/jobs" "github.com/rudderlabs/rudder-server/warehouse/multitenant" warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" ) @@ -577,13 +578,13 @@ func TestSlaveWorker(t *testing.T) { workerIdx, ) - p := jobs.AsyncJobPayload{ - Id: "1", + p := source.NotifierRequest{ + ID: 1, SourceID: sourceID, DestinationID: destinationID, TableName: "test_table_name", WorkspaceID: workspaceID, - AsyncJobType: "deletebyjobrunid", + JobType: model.SourceJobTypeDeleteByJobRunID.String(), MetaData: []byte(`{"job_run_id": "1", "task_run_id": "1", "start_time": "2020-01-01T00:00:00Z"}`), } @@ -605,18 +606,17 @@ 
func TestSlaveWorker(t *testing.T) { go func() { defer close(claimedJobDone) - slaveWorker.processClaimedAsyncJob(ctx, claim) + slaveWorker.processClaimedSourceJob(ctx, claim) }() response := <-subscribeCh require.NoError(t, response.Err) - var asyncResult asyncJobRunResult - err = json.Unmarshal(response.Payload, &asyncResult) + var notifierResponse source.NotifierResponse + err = json.Unmarshal(response.Payload, ¬ifierResponse) require.NoError(t, err) - require.Equal(t, "1", asyncResult.ID) - require.True(t, asyncResult.Result) + require.Equal(t, int64(1), notifierResponse.ID) <-claimedJobDone }) @@ -647,34 +647,27 @@ func TestSlaveWorker(t *testing.T) { name string sourceID string destinationID string - jobType string + jobType model.SourceJobType expectedError error }{ - { - name: "invalid job type", - sourceID: sourceID, - destinationID: destinationID, - jobType: "invalid_job_type", - expectedError: errors.New("invalid asyncJob type"), - }, { name: "invalid parameters", - jobType: "deletebyjobrunid", - expectedError: errors.New("invalid Parameters"), + jobType: model.SourceJobTypeDeleteByJobRunID, + expectedError: errors.New("getting warehouse: invalid Parameters"), }, { name: "invalid source id", sourceID: "invalid_source_id", destinationID: destinationID, - jobType: "deletebyjobrunid", - expectedError: errors.New("invalid Source Id"), + jobType: model.SourceJobTypeDeleteByJobRunID, + expectedError: errors.New("getting warehouse: invalid Source Id"), }, { name: "invalid destination id", sourceID: sourceID, destinationID: "invalid_destination_id", - jobType: "deletebyjobrunid", - expectedError: errors.New("invalid Destination Id"), + jobType: model.SourceJobTypeDeleteByJobRunID, + expectedError: errors.New("getting warehouse: invalid Destination Id"), }, } @@ -682,13 +675,13 @@ func TestSlaveWorker(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { - p := jobs.AsyncJobPayload{ - Id: "1", + p := source.NotifierRequest{ + ID: 1, SourceID: tc.sourceID, 
DestinationID: tc.destinationID, TableName: "test_table_name", WorkspaceID: workspaceID, - AsyncJobType: tc.jobType, + JobType: tc.jobType.String(), MetaData: []byte(`{"job_run_id": "1", "task_run_id": "1", "start_time": "2020-01-01T00:00:00Z"}`), } @@ -711,7 +704,7 @@ func TestSlaveWorker(t *testing.T) { go func() { defer close(claimedJobDone) - slaveWorker.processClaimedAsyncJob(ctx, claim) + slaveWorker.processClaimedSourceJob(ctx, claim) }() response := <-subscribeCh diff --git a/warehouse/source/http.go b/warehouse/source/http.go new file mode 100644 index 00000000000..9c4ddb83e29 --- /dev/null +++ b/warehouse/source/http.go @@ -0,0 +1,130 @@ +package source + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "regexp" + + "github.com/rudderlabs/rudder-server/warehouse/internal/model" + + ierrors "github.com/rudderlabs/rudder-server/warehouse/internal/errors" + lf "github.com/rudderlabs/rudder-server/warehouse/logfield" +) + +// emptyRegex matches empty strings +var emptyRegex = regexp.MustCompile(`^\s*$`) + +// InsertJobHandler adds a job to the warehouse_jobs table +func (m *Manager) InsertJobHandler(w http.ResponseWriter, r *http.Request) { + defer func() { _ = r.Body.Close() }() + + var payload insertJobRequest + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + m.logger.Warnw("invalid JSON in request body for inserting source jobs", lf.Error, err.Error()) + http.Error(w, ierrors.ErrInvalidJSONRequestBody.Error(), http.StatusBadRequest) + return + } + + if err := validatePayload(&payload); err != nil { + m.logger.Warnw("invalid payload for inserting source job", lf.Error, err.Error()) + http.Error(w, fmt.Sprintf("invalid payload: %s", err.Error()), http.StatusBadRequest) + return + } + + jobIds, err := m.InsertJobs(r.Context(), payload) + if err != nil { + if errors.Is(r.Context().Err(), context.Canceled) { + http.Error(w, ierrors.ErrRequestCancelled.Error(), http.StatusBadRequest) + return + } + 
m.logger.Errorw("inserting source jobs", lf.Error, err.Error()) + http.Error(w, "can't insert source jobs", http.StatusInternalServerError) + return + } + + resBody, err := json.Marshal(insertJobResponse{ + JobIds: jobIds, + Err: nil, + }) + if err != nil { + m.logger.Errorw("marshalling response for inserting source job", lf.Error, err.Error()) + http.Error(w, ierrors.ErrMarshallResponse.Error(), http.StatusInternalServerError) + return + } + + _, _ = w.Write(resBody) +} + +// StatusJobHandler The following handler gets called for getting the status of the async job +func (m *Manager) StatusJobHandler(w http.ResponseWriter, r *http.Request) { + defer func() { _ = r.Body.Close() }() + + queryParams := r.URL.Query() + payload := insertJobRequest{ + TaskRunID: queryParams.Get("task_run_id"), + JobRunID: queryParams.Get("job_run_id"), + SourceID: queryParams.Get("source_id"), + DestinationID: queryParams.Get("destination_id"), + WorkspaceID: queryParams.Get("workspace_id"), + } + if err := validatePayload(&payload); err != nil { + m.logger.Warnw("invalid payload for source job status", lf.Error, err.Error()) + http.Error(w, fmt.Sprintf("invalid request: %s", err.Error()), http.StatusBadRequest) + return + } + + sourceJob, err := m.sourceRepo.GetByJobRunTaskRun(r.Context(), payload.JobRunID, payload.TaskRunID) + if err != nil { + if errors.Is(r.Context().Err(), context.Canceled) { + http.Error(w, ierrors.ErrRequestCancelled.Error(), http.StatusBadRequest) + return + } + if errors.Is(err, model.ErrSourcesJobNotFound) { + http.Error(w, model.ErrSourcesJobNotFound.Error(), http.StatusNotFound) + return + } + m.logger.Warnw("unable to get source job status", lf.Error, err.Error()) + http.Error(w, fmt.Sprintf("can't get source job status: %s", err.Error()), http.StatusBadRequest) + return + } + + var statusResponse jobStatusResponse + switch sourceJob.Status { + case model.SourceJobStatusFailed, model.SourceJobStatusAborted: + errorMessage := "source job failed" + if 
sourceJob.Error != nil { + errorMessage = sourceJob.Error.Error() + } + statusResponse.Status = sourceJob.Status.String() + statusResponse.Err = errorMessage + default: + statusResponse.Status = sourceJob.Status.String() + } + + resBody, err := json.Marshal(statusResponse) + if err != nil { + m.logger.Errorw("marshalling response for source job status", lf.Error, err.Error()) + http.Error(w, ierrors.ErrMarshallResponse.Error(), http.StatusInternalServerError) + return + } + + _, _ = w.Write(resBody) +} + +func validatePayload(payload *insertJobRequest) error { + switch true { + case emptyRegex.MatchString(payload.SourceID): + return errors.New("source_id is required") + case emptyRegex.MatchString(payload.DestinationID): + return errors.New("destination_id is required") + case emptyRegex.MatchString(payload.JobRunID): + return errors.New("job_run_id is required") + case emptyRegex.MatchString(payload.TaskRunID): + return errors.New("task_run_id is required") + default: + return nil + } +} diff --git a/warehouse/source/http_test.go b/warehouse/source/http_test.go new file mode 100644 index 00000000000..d01bb48e8b2 --- /dev/null +++ b/warehouse/source/http_test.go @@ -0,0 +1,364 @@ +package source + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/ory/dockertest/v3" + + "github.com/rudderlabs/rudder-go-kit/config" + + "github.com/rudderlabs/rudder-go-kit/logger" + "github.com/rudderlabs/rudder-server/warehouse/internal/model" + "github.com/rudderlabs/rudder-server/warehouse/internal/repo" + whutils "github.com/rudderlabs/rudder-server/warehouse/utils" + + "github.com/stretchr/testify/require" +) + +func TestValidatePayload(t *testing.T) { + testCases := []struct { + name string + payload insertJobRequest + expectedError error + }{ + { + name: "invalid source (empty string)", + payload: insertJobRequest{ + JobRunID: "job_run_id", + TaskRunID: "task_run_id", + 
SourceID: "", + DestinationID: "destination_id", + WorkspaceID: "workspace_id", + }, + expectedError: errors.New("source_id is required"), + }, + + { + name: "invalid source (empty string with spaces)", + payload: insertJobRequest{ + JobRunID: "job_run_id", + TaskRunID: "task_run_id", + SourceID: " ", + DestinationID: "destination_id", + WorkspaceID: "workspace_id", + }, + expectedError: errors.New("source_id is required"), + }, + { + name: "invalid destination", + payload: insertJobRequest{ + JobRunID: "job_run_id", + TaskRunID: "task_run_id", + SourceID: "source_id", + DestinationID: "", + WorkspaceID: "workspace_id", + }, + expectedError: errors.New("destination_id is required"), + }, + { + name: "invalid task run", + payload: insertJobRequest{ + JobRunID: "job_run_id", + TaskRunID: "", + SourceID: "source_id", + DestinationID: "destination_id", + WorkspaceID: "workspace_id", + }, + expectedError: errors.New("task_run_id is required"), + }, + { + name: "invalid job run", + payload: insertJobRequest{ + JobRunID: "", + TaskRunID: "task_run_id", + SourceID: "source_id", + DestinationID: "destination_id", + WorkspaceID: "workspace_id", + }, + expectedError: errors.New("job_run_id is required"), + }, + { + name: "valid payload", + payload: insertJobRequest{ + JobRunID: "job_run_id", + TaskRunID: "task_run_id", + SourceID: "source_id", + DestinationID: "destination_id", + WorkspaceID: "workspace_id", + }, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + require.Equal(t, tc.expectedError, validatePayload(&tc.payload)) + }) + } +} + +func TestManager_InsertJobHandler(t *testing.T) { + const ( + workspaceID = "test_workspace_id" + sourceID = "test_source_id" + destinationID = "test_destination_id" + namespace = "test_namespace" + destinationType = "test_destination_type" + sourceTaskRunID = "test_source_task_run_id" + sourceJobID = "test_source_job_id" + sourceJobRunID = "test_source_job_run_id" + ) + + pool, err := dockertest.NewPool("") 
+ require.NoError(t, err) + + db, ctx := setupDB(t, pool), context.Background() + + now := time.Now().Truncate(time.Second).UTC() + + uploadsRepo := repo.NewUploads(db, repo.WithNow(func() time.Time { + return now + })) + tableUploadsRepo := repo.NewTableUploads(db, repo.WithNow(func() time.Time { + return now + })) + stagingRepo := repo.NewStagingFiles(db, repo.WithNow(func() time.Time { + return now + })) + + stagingFile := model.StagingFile{ + WorkspaceID: workspaceID, + Location: "s3://bucket/path/to/file", + SourceID: sourceID, + DestinationID: destinationID, + Status: whutils.StagingFileWaitingState, + Error: fmt.Errorf("dummy error"), + FirstEventAt: now.Add(time.Second), + UseRudderStorage: true, + DestinationRevisionID: "destination_revision_id", + TotalEvents: 100, + SourceTaskRunID: sourceTaskRunID, + SourceJobID: sourceJobID, + SourceJobRunID: sourceJobRunID, + TimeWindow: time.Date(1993, 8, 1, 3, 0, 0, 0, time.UTC), + }.WithSchema([]byte(`{"type": "object"}`)) + + stagingID, err := stagingRepo.Insert(ctx, &stagingFile) + require.NoError(t, err) + + uploadID, err := uploadsRepo.CreateWithStagingFiles(ctx, model.Upload{ + WorkspaceID: workspaceID, + Namespace: namespace, + SourceID: sourceID, + DestinationID: destinationID, + DestinationType: destinationType, + Status: model.Aborted, + SourceJobRunID: sourceJobRunID, + SourceTaskRunID: sourceTaskRunID, + }, []*model.StagingFile{{ + ID: stagingID, + SourceID: sourceID, + DestinationID: destinationID, + SourceJobRunID: sourceJobRunID, + SourceTaskRunID: sourceTaskRunID, + }}) + require.NoError(t, err) + + err = tableUploadsRepo.Insert(ctx, uploadID, []string{ + "test_table_1", + "test_table_2", + "test_table_3", + "test_table_4", + "test_table_5", + + "rudder_discards", + "rudder_identity_mappings", + "rudder_identity_merge_rules", + }) + require.NoError(t, err) + + t.Run("invalid payload", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/v1/warehouse/jobs", 
bytes.NewReader([]byte(`"Invalid payload"`))) + resp := httptest.NewRecorder() + + sourceManager := New(config.New(), logger.NOP, db, &mockPublisher{}) + sourceManager.InsertJobHandler(resp, req) + require.Equal(t, http.StatusBadRequest, resp.Code) + + b, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "invalid JSON in request body\n", string(b)) + }) + t.Run("invalid request", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/v1/warehouse/jobs", bytes.NewReader([]byte(`{}`))) + resp := httptest.NewRecorder() + + sourceManager := New(config.New(), logger.NOP, db, &mockPublisher{}) + sourceManager.InsertJobHandler(resp, req) + require.Equal(t, http.StatusBadRequest, resp.Code) + + b, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "invalid payload: source_id is required\n", string(b)) + }) + t.Run("success", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/v1/warehouse/jobs", bytes.NewReader([]byte(` + { + "source_id": "test_source_id", + "destination_id": "test_destination_id", + "job_run_id": "test_source_job_run_id", + "task_run_id": "test_source_task_run_id", + "async_job_type": "deletebyjobrunid" + } + `))) + resp := httptest.NewRecorder() + + sourceManager := New(config.New(), logger.NOP, db, &mockPublisher{}) + sourceManager.InsertJobHandler(resp, req) + require.Equal(t, http.StatusOK, resp.Code) + + var insertResponse insertJobResponse + err = json.NewDecoder(resp.Body).Decode(&insertResponse) + require.NoError(t, err) + require.Nil(t, insertResponse.Err) + require.Len(t, insertResponse.JobIds, 5) + }) + t.Run("exclude tables", func(t *testing.T) { + // discards, merge rules and mapping tables should be excluded + }) +} + +func TestManager_StatusJobHandler(t *testing.T) { + const ( + workspaceID = "test_workspace_id" + sourceID = "test_source_id" + destinationID = "test_destination_id" + sourceJobID = "test_source_job_id" + sourceTaskRunID = "test_source_task_run_id" + 
) + + pool, err := dockertest.NewPool("") + require.NoError(t, err) + + db, ctx := setupDB(t, pool), context.Background() + + now := time.Now().Truncate(time.Second).UTC() + + sourceRepo := repo.NewSource(db, repo.WithNow(func() time.Time { + return now + })) + + createSourceJob := func(jobRunID, taskRunID, tableName string) []int64 { + metadata := fmt.Sprintf(`{"job_run_id":"%s","task_run_id":"%s","jobtype":"%s","start_time":"%s"}`, + jobRunID, + taskRunID, + model.SourceJobTypeDeleteByJobRunID, + now.Format(time.RFC3339), + ) + + ids, err := sourceRepo.Insert(ctx, []model.SourceJob{ + { + SourceID: sourceID, + DestinationID: destinationID, + WorkspaceID: workspaceID, + TableName: tableName, + JobType: model.SourceJobTypeDeleteByJobRunID, + Metadata: []byte(metadata), + }, + }) + require.NoError(t, err) + + return ids + } + + t.Run("invalid payload", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/warehouse/jobs/status", nil) + resp := httptest.NewRecorder() + + sourceManager := New(config.New(), logger.NOP, db, &mockPublisher{}) + sourceManager.StatusJobHandler(resp, req) + require.Equal(t, http.StatusBadRequest, resp.Code) + + b, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "invalid request: source_id is required\n", string(b)) + }) + t.Run("status waiting", func(t *testing.T) { + _ = createSourceJob(sourceJobID+"-1", sourceTaskRunID+"-1", "test_table-1") + + qp := url.Values{} + qp.Add("task_run_id", sourceTaskRunID+"-1") + qp.Add("job_run_id", sourceJobID+"-1") + qp.Add("source_id", sourceID) + qp.Add("destination_id", destinationID) + qp.Add("workspace_id", workspaceID) + + req := httptest.NewRequest(http.MethodGet, "/v1/warehouse/jobs/status?"+qp.Encode(), nil) + resp := httptest.NewRecorder() + + sourceManager := New(config.New(), logger.NOP, db, &mockPublisher{}) + sourceManager.StatusJobHandler(resp, req) + require.Equal(t, http.StatusOK, resp.Code) + + var statusResponse jobStatusResponse + err = 
json.NewDecoder(resp.Body).Decode(&statusResponse) + require.NoError(t, err) + require.Equal(t, statusResponse.Status, model.SourceJobStatusWaiting.String()) + require.Empty(t, statusResponse.Err) + }) + t.Run("status aborted", func(t *testing.T) { + ids := createSourceJob(sourceJobID+"-2", sourceTaskRunID+"-2", "test_table-2") + + for _, id := range ids { + err := sourceRepo.OnUpdateFailure(ctx, id, errors.New("test error"), -1) + require.NoError(t, err) + } + + qp := url.Values{} + qp.Add("task_run_id", sourceTaskRunID+"-2") + qp.Add("job_run_id", sourceJobID+"-2") + qp.Add("source_id", sourceID) + qp.Add("destination_id", destinationID) + qp.Add("workspace_id", workspaceID) + + req := httptest.NewRequest(http.MethodGet, "/v1/warehouse/jobs/status?"+qp.Encode(), nil) + resp := httptest.NewRecorder() + + sourceManager := New(config.New(), logger.NOP, db, &mockPublisher{}) + sourceManager.StatusJobHandler(resp, req) + require.Equal(t, http.StatusOK, resp.Code) + + var statusResponse jobStatusResponse + err = json.NewDecoder(resp.Body).Decode(&statusResponse) + require.NoError(t, err) + require.Equal(t, statusResponse.Status, model.SourceJobStatusAborted.String()) + require.Equal(t, statusResponse.Err, errors.New("test error").Error()) + }) + t.Run("job not found", func(t *testing.T) { + qp := url.Values{} + qp.Add("task_run_id", sourceTaskRunID+"-unknown") + qp.Add("job_run_id", sourceJobID+"-unknown") + qp.Add("source_id", sourceID) + qp.Add("destination_id", destinationID) + qp.Add("workspace_id", workspaceID) + + req := httptest.NewRequest(http.MethodGet, "/v1/warehouse/jobs/status?"+qp.Encode(), nil) + resp := httptest.NewRecorder() + + sourceManager := New(config.New(), logger.NOP, db, &mockPublisher{}) + sourceManager.StatusJobHandler(resp, req) + require.Equal(t, http.StatusNotFound, resp.Code) + + b, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "sources job not found\n", string(b)) + }) +} diff --git a/warehouse/source/source.go 
b/warehouse/source/source.go new file mode 100644 index 00000000000..4d0a11488b4 --- /dev/null +++ b/warehouse/source/source.go @@ -0,0 +1,306 @@ +package source + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "github.com/lib/pq" + + "github.com/samber/lo" + + "github.com/rudderlabs/rudder-server/services/notifier" + + "github.com/rudderlabs/rudder-server/warehouse/internal/model" + "github.com/rudderlabs/rudder-server/warehouse/internal/repo" + whutils "github.com/rudderlabs/rudder-server/warehouse/utils" + + sqlmw "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" + + "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/logger" +) + +type Manager struct { + logger logger.Logger + sourceRepo sourceRepo + tableUploadsRepo tableUploadsRepo + publisher publisher + + config struct { + maxBatchSizeToProcess int64 + maxAttemptsPerJob int + } + + trigger struct { + processingTimeout func() <-chan time.Time + processingSleepInterval func() <-chan time.Time + } +} + +func New(conf *config.Config, log logger.Logger, db *sqlmw.DB, publisher publisher) *Manager { + m := &Manager{ + logger: log.Child("source"), + tableUploadsRepo: repo.NewTableUploads(db), + sourceRepo: repo.NewSource(db), + publisher: publisher, + } + + m.config.maxBatchSizeToProcess = conf.GetInt64("Warehouse.jobs.maxBatchSizeToProcess", 10) + m.config.maxAttemptsPerJob = conf.GetInt("Warehouse.jobs.maxAttemptsPerJob", 3) + + m.trigger.processingTimeout = func() <-chan time.Time { + return time.After(conf.GetDuration("Warehouse.jobs.processingTimeout", 300, time.Second)) + } + m.trigger.processingSleepInterval = func() <-chan time.Time { + return time.After(conf.GetDuration("Warehouse.jobs.processingSleepInterval", 10, time.Second)) + } + return m +} + +func (m *Manager) InsertJobs(ctx context.Context, payload insertJobRequest) ([]int64, error) { + jobType, err := model.FromSourceJobType(payload.JobType) 
+ if err != nil { + return nil, fmt.Errorf("invalid job type %s", payload.JobType) + } + + tableUploads, err := m.tableUploadsRepo.GetByJobRunTaskRun( + ctx, + payload.SourceID, + payload.DestinationID, + payload.JobRunID, + payload.TaskRunID, + ) + if err != nil { + return nil, fmt.Errorf("getting table uploads: %w", err) + } + + tableNames := lo.Map(tableUploads, func(item model.TableUpload, index int) string { + return item.TableName + }) + tableNames = lo.Filter(lo.Uniq(tableNames), func(tableName string, i int) bool { + switch strings.ToLower(tableName) { + case whutils.DiscardsTable, whutils.IdentityMappingsTable, whutils.IdentityMergeRulesTable: + return false + default: + return true + } + }) + + type metadata struct { + JobRunID string `json:"job_run_id"` + TaskRunID string `json:"task_run_id"` + JobType string `json:"jobtype"` + StartTime string `json:"start_time"` + } + metadataJson, err := json.Marshal(metadata{ + JobRunID: payload.JobRunID, + TaskRunID: payload.TaskRunID, + StartTime: payload.StartTime, + JobType: string(notifier.JobTypeAsync), + }) + if err != nil { + return nil, fmt.Errorf("marshalling metadata: %w", err) + } + jobIds, err := m.sourceRepo.Insert(ctx, lo.Map(tableNames, func(tableName string, _ int) model.SourceJob { + return model.SourceJob{ + SourceID: payload.SourceID, + DestinationID: payload.DestinationID, + WorkspaceID: payload.WorkspaceID, + TableName: tableName, + JobType: jobType, + Metadata: metadataJson, + } + })) + if err != nil { + return nil, fmt.Errorf("inserting source jobs: %w", err) + } + return jobIds, nil +} + +func (m *Manager) Run(ctx context.Context) error { + if err := m.sourceRepo.Reset(ctx); err != nil { + return fmt.Errorf("resetting source jobs with error %w", err) + } + + if err := m.process(ctx); err != nil { + var pqErr *pq.Error + + switch { + case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded), errors.As(err, &pqErr) && pqErr.Code == "57014": + return nil + default: + return 
fmt.Errorf("processing source jobs with error %w", err) + } + } + return nil +} + +func (m *Manager) process(ctx context.Context) error { + m.logger.Infow("starting source jobs processing") + + for { + pendingJobs, err := m.sourceRepo.GetToProcess(ctx, m.config.maxBatchSizeToProcess) + if err != nil { + return fmt.Errorf("getting pending source jobs with error %w", err) + } + if len(pendingJobs) == 0 { + continue + } + + if err = m.processPendingJobs(ctx, pendingJobs); err != nil { + return fmt.Errorf("process pending source jobs with error %w", err) + } + + select { + case <-ctx.Done(): + m.logger.Infow("source jobs processing stopped due to context cancelled") + return nil + case <-m.trigger.processingSleepInterval(): + } + } +} + +// processPendingJobs +// 1. Prepare and publish claims to notifier +// 2. Mark source jobs as executing +// 3. Mark source jobs as failed if notifier returned timeout +// 4. Mark source jobs as failed if notifier returned error else mark as succeeded +func (m *Manager) processPendingJobs(ctx context.Context, pendingJobs []model.SourceJob) error { + claims := make([]json.RawMessage, 0, len(pendingJobs)) + for _, job := range pendingJobs { + message, err := json.Marshal(NotifierRequest{ + ID: job.ID, + SourceID: job.SourceID, + DestinationID: job.DestinationID, + WorkspaceID: job.WorkspaceID, + TableName: job.TableName, + JobType: job.JobType.String(), + MetaData: job.Metadata, + }) + if err != nil { + return fmt.Errorf("marshalling source job %d: %w", job.ID, err) + } + claims = append(claims, message) + } + + ch, err := m.publisher.Publish(ctx, ¬ifier.PublishRequest{ + Payloads: claims, + JobType: notifier.JobTypeAsync, + Priority: 100, + }) + if err != nil { + return fmt.Errorf("publishing source jobs: %w", err) + } + + pendingJobsMap := lo.SliceToMap(pendingJobs, func(item model.SourceJob) (int64, *model.SourceJob) { + return item.ID, &item + }) + pendingJobIDs := lo.Map(pendingJobs, func(item model.SourceJob, index int) int64 { + 
return item.ID + }) + + if err = m.sourceRepo.MarkExecuting(ctx, pendingJobIDs); err != nil { + return fmt.Errorf("marking status executing: %w", err) + } + + select { + case <-ctx.Done(): + m.logger.Infow("pending jobs process stopped due to context cancelled", "ids", pendingJobIDs) + return nil + case responses, ok := <-ch: + if !ok { + if err := m.markFailed(ctx, pendingJobIDs, ErrReceivingChannelClosed); err != nil { + return fmt.Errorf("marking status failed for receiving channel closed: %w", err) + } + return ErrReceivingChannelClosed + } + if responses.Err != nil { + if err := m.markFailed(ctx, pendingJobIDs, responses.Err); err != nil { + return fmt.Errorf("marking status failed for publishing source jobs: %w", err) + } + return fmt.Errorf("publishing source jobs: %w", responses.Err) + } + + for _, job := range responses.Jobs { + var response NotifierResponse + var jobStatus model.SourceJobStatus + + if err = json.Unmarshal(job.Payload, &response); err != nil { + return fmt.Errorf("unmarshalling notifier response for source job %d: %w", job.ID, err) + } + if jobStatus, err = model.FromSourceJobStatus(string(job.Status)); err != nil { + return fmt.Errorf("invalid job status %s for source job %d: %w", job.Status, job.ID, err) + } + if pendingJob, ok := pendingJobsMap[response.ID]; ok { + pendingJob.Status = jobStatus + pendingJob.Error = job.Error + } + } + + for _, job := range pendingJobsMap { + if job.Error != nil { + err = m.sourceRepo.OnUpdateFailure( + ctx, + job.ID, + job.Error, + m.config.maxAttemptsPerJob, + ) + if err != nil { + return fmt.Errorf("on update failure for source job %d: %w", job.ID, err) + } + continue + } + + if err = m.sourceRepo.OnUpdateSuccess(ctx, job.ID); err != nil { + return fmt.Errorf("marking status success for source job %d: %w", job.ID, err) + } + } + case <-m.trigger.processingTimeout(): + if err = m.markFailed(ctx, pendingJobIDs, ErrProcessingTimedOut); err != nil { + return fmt.Errorf("marking status failed for 
processing timed out: %w", err) + } + return ErrProcessingTimedOut + } + return nil +} + +func (m *Manager) markFailed(ctx context.Context, ids []int64, failError error) error { + for _, id := range ids { + err := m.sourceRepo.OnUpdateFailure( + ctx, + id, + failError, + m.config.maxAttemptsPerJob, + ) + if err != nil { + return fmt.Errorf("updating failure for source job %d: %w", id, err) + } + } + return nil +} + +type Uploader struct{} + +func (*Uploader) IsWarehouseSchemaEmpty() bool { return true } +func (*Uploader) UpdateLocalSchema(context.Context, model.Schema) error { return nil } +func (*Uploader) GetTableSchemaInUpload(string) model.TableSchema { return model.TableSchema{} } +func (*Uploader) ShouldOnDedupUseNewRecord() bool { return false } +func (*Uploader) UseRudderStorage() bool { return false } +func (*Uploader) CanAppend() bool { return false } +func (*Uploader) GetLoadFileGenStartTIme() time.Time { return time.Time{} } +func (*Uploader) GetLoadFileType() string { return "" } +func (*Uploader) GetFirstLastEvent() (time.Time, time.Time) { return time.Now(), time.Now() } +func (*Uploader) GetLocalSchema(context.Context) (model.Schema, error) { return model.Schema{}, nil } +func (*Uploader) GetTableSchemaInWarehouse(string) model.TableSchema { return model.TableSchema{} } +func (*Uploader) GetSampleLoadFileLocation(context.Context, string) (string, error) { return "", nil } +func (*Uploader) GetLoadFilesMetadata(context.Context, whutils.GetLoadFilesOptions) ([]whutils.LoadFile, error) { + return []whutils.LoadFile{}, nil +} + +func (*Uploader) GetSingleLoadFile(context.Context, string) (whutils.LoadFile, error) { + return whutils.LoadFile{}, nil +} diff --git a/warehouse/source/source_test.go b/warehouse/source/source_test.go new file mode 100644 index 00000000000..9aa96a61b1e --- /dev/null +++ b/warehouse/source/source_test.go @@ -0,0 +1,504 @@ +package source + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/samber/lo" 
+ + "golang.org/x/sync/errgroup" + + "github.com/rudderlabs/rudder-server/warehouse/internal/model" + "github.com/rudderlabs/rudder-server/warehouse/internal/repo" + + "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/logger" + sqlmiddleware "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" + + "github.com/ory/dockertest/v3" + "github.com/stretchr/testify/require" + + "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource" + "github.com/rudderlabs/rudder-server/services/notifier" + migrator "github.com/rudderlabs/rudder-server/services/sql-migrator" +) + +type jobRunTaskRun struct { + jobRunID string + taskRunID string +} + +func newMockPublisher(mockResponse <-chan chan *notifier.PublishResponse, mockError error) *mockPublisher { + return &mockPublisher{ + mockResponse: mockResponse, + mockError: mockError, + } +} + +type mockPublisher struct { + mockResponse <-chan chan *notifier.PublishResponse + mockError error +} + +func (m *mockPublisher) Publish(context.Context, *notifier.PublishRequest) (<-chan *notifier.PublishResponse, error) { + if m.mockError != nil { + return nil, m.mockError + } + return <-m.mockResponse, nil +} + +func TestSource(t *testing.T) { + const ( + workspaceID = "test_workspace_id" + sourceID = "test_source_id" + destinationID = "test_destination_id" + sourceTaskRunID = "test_source_task_run_id" + sourceJobRunID = "test_source_job_run_id" + ) + + pool, err := dockertest.NewPool("") + require.NoError(t, err) + + ctx := context.Background() + now := time.Now().Truncate(time.Second).UTC() + + createSourceJob := func(sourceRepo *repo.Source, jobRunID, taskRunID, tableName string) []int64 { + metadata := fmt.Sprintf(`{"job_run_id":"%s","task_run_id":"%s","jobtype":"%s","start_time":"%s"}`, + jobRunID, + taskRunID, + model.SourceJobTypeDeleteByJobRunID, + now.Format(time.RFC3339), + ) + + ids, err := sourceRepo.Insert(ctx, []model.SourceJob{ + { + SourceID: sourceID, 
+ DestinationID: destinationID, + WorkspaceID: workspaceID, + TableName: tableName, + JobType: model.SourceJobTypeDeleteByJobRunID, + Metadata: []byte(metadata), + }, + }) + require.NoError(t, err) + return ids + } + + t.Run("channel closed", func(t *testing.T) { + db := setupDB(t, pool) + + sr := repo.NewSource(db, repo.WithNow(func() time.Time { + return now + })) + + sourceJobs := createSourceJob(sr, sourceJobRunID, sourceTaskRunID, "test_table") + + response := make(chan *notifier.PublishResponse) + close(response) + publishResponse := make(chan chan *notifier.PublishResponse, 1) + publishResponse <- response + close(publishResponse) + + m := New(config.New(), logger.NOP, db, newMockPublisher(publishResponse, nil)) + require.Error(t, m.Run(ctx)) + + job, err := m.sourceRepo.GetByJobRunTaskRun(ctx, sourceJobRunID, sourceTaskRunID) + require.NoError(t, err) + + require.Equal(t, sourceJobs, []int64{job.ID}) + require.Equal(t, model.SourceJobStatusFailed, job.Status) + require.Equal(t, int64(1), job.Attempts) + require.EqualError(t, ErrReceivingChannelClosed, job.Error.Error()) + }) + t.Run("publishing error", func(t *testing.T) { + db := setupDB(t, pool) + + sr := repo.NewSource(db, repo.WithNow(func() time.Time { + return now + })) + + sourceJobs := createSourceJob(sr, sourceJobRunID, sourceTaskRunID, "test_table") + + m := New(config.New(), logger.NOP, db, newMockPublisher(nil, errors.New("test error"))) + require.Error(t, m.Run(ctx)) + + job, err := m.sourceRepo.GetByJobRunTaskRun(ctx, sourceJobRunID, sourceTaskRunID) + require.NoError(t, err) + + require.Equal(t, sourceJobs, []int64{job.ID}) + require.Equal(t, model.SourceJobStatusWaiting, job.Status) + require.Zero(t, job.Attempts) + require.NoError(t, job.Error) + }) + t.Run("publishing response error", func(t *testing.T) { + db := setupDB(t, pool) + + sr := repo.NewSource(db, repo.WithNow(func() time.Time { + return now + })) + + sourceJobs := createSourceJob(sr, sourceJobRunID, sourceTaskRunID, 
"test_table") + + response := make(chan *notifier.PublishResponse, 1) + response <- ¬ifier.PublishResponse{ + Err: errors.New("test error"), + } + close(response) + publishResponse := make(chan chan *notifier.PublishResponse, 1) + publishResponse <- response + close(publishResponse) + + m := New(config.New(), logger.NOP, db, newMockPublisher(publishResponse, nil)) + require.Error(t, m.Run(ctx)) + + job, err := m.sourceRepo.GetByJobRunTaskRun(ctx, sourceJobRunID, sourceTaskRunID) + require.NoError(t, err) + + require.Equal(t, sourceJobs, []int64{job.ID}) + require.Equal(t, model.SourceJobStatusFailed, job.Status) + require.Equal(t, int64(1), job.Attempts) + require.Error(t, job.Error) + }) + t.Run("timeout", func(t *testing.T) { + db := setupDB(t, pool) + + sr := repo.NewSource(db, repo.WithNow(func() time.Time { + return now + })) + + sourceJobs := createSourceJob(sr, sourceJobRunID, sourceTaskRunID, "test_table") + + response := make(chan *notifier.PublishResponse) + defer close(response) + publishResponse := make(chan chan *notifier.PublishResponse, 1) + publishResponse <- response + close(publishResponse) + + c := config.New() + c.Set("Warehouse.jobs.processingTimeout", "1s") + c.Set("Warehouse.jobs.processingSleepInterval", "1s") + + m := New(c, logger.NOP, db, newMockPublisher(publishResponse, nil)) + require.Error(t, m.Run(ctx)) + + job, err := m.sourceRepo.GetByJobRunTaskRun(ctx, sourceJobRunID, sourceTaskRunID) + require.NoError(t, err) + + require.Equal(t, sourceJobs, []int64{job.ID}) + require.Equal(t, model.SourceJobStatusFailed, job.Status) + require.Equal(t, int64(1), job.Attempts) + require.EqualError(t, ErrProcessingTimedOut, job.Error.Error()) + }) + t.Run("some succeeded, some failed", func(t *testing.T) { + db := setupDB(t, pool) + + sr := repo.NewSource(db, repo.WithNow(func() time.Time { + return now + })) + + sourceJobs1 := createSourceJob(sr, sourceJobRunID+"-1", sourceTaskRunID+"-1", "test_table-1") + sourceJobs2 := createSourceJob(sr, 
sourceJobRunID+"-2", sourceTaskRunID+"-2", "test_table-1") + sourceJobs3 := createSourceJob(sr, sourceJobRunID+"-3", sourceTaskRunID+"-3", "test_table-1") + sourceJobs4 := createSourceJob(sr, sourceJobRunID+"-4", sourceTaskRunID+"-4", "test_table-1") + + response := make(chan *notifier.PublishResponse, 1) + response <- ¬ifier.PublishResponse{ + Jobs: []notifier.Job{ + { + Payload: []byte(fmt.Sprintf(`{"id": %d}`, sourceJobs1[0])), + Status: notifier.Succeeded, + }, + { + Payload: []byte(fmt.Sprintf(`{"id": %d}`, sourceJobs2[0])), + Status: notifier.Failed, + Error: errors.New("test error"), + }, + { + Payload: []byte(fmt.Sprintf(`{"id": %d}`, sourceJobs3[0])), + Status: notifier.Failed, + Error: errors.New("test error"), + }, + { + Payload: []byte(fmt.Sprintf(`{"id": %d}`, sourceJobs4[0])), + Status: notifier.Succeeded, + }, + }, + } + close(response) + publishResponse := make(chan chan *notifier.PublishResponse, 1) + publishResponse <- response + close(publishResponse) + + c := config.New() + c.Set("Warehouse.jobs.processingSleepInterval", "1ms") + c.Set("Warehouse.jobs.maxAttemptsPerJob", -1) + + m := New(c, logger.NOP, db, newMockPublisher(publishResponse, nil)) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + g, gCtx := errgroup.WithContext(ctx) + g.Go(func() error { + return m.Run(gCtx) + }) + g.Go(func() error { + runs := []jobRunTaskRun{ + {jobRunID: sourceJobRunID + "-1", taskRunID: sourceTaskRunID + "-1"}, + {jobRunID: sourceJobRunID + "-2", taskRunID: sourceTaskRunID + "-2"}, + {jobRunID: sourceJobRunID + "-3", taskRunID: sourceTaskRunID + "-3"}, + {jobRunID: sourceJobRunID + "-4", taskRunID: sourceTaskRunID + "-4"}, + } + require.Eventually(t, func() bool { + defer cancel() + + jobs := getAll(t, context.Background(), m.sourceRepo, runs...) 
+ filteredJobs := lo.Filter(jobs, func(j *model.SourceJob, index int) bool { + return j.Status == model.SourceJobStatusAborted || j.Status == model.SourceJobStatusSucceeded + }) + return len(filteredJobs) == len(runs) + }, + 60*time.Second, + 100*time.Millisecond, + ) + return nil + }) + require.NoError(t, g.Wait()) + + job1, err := m.sourceRepo.GetByJobRunTaskRun(context.Background(), sourceJobRunID+"-1", sourceTaskRunID+"-1") + require.NoError(t, err) + job2, err := m.sourceRepo.GetByJobRunTaskRun(context.Background(), sourceJobRunID+"-2", sourceTaskRunID+"-2") + require.NoError(t, err) + job3, err := m.sourceRepo.GetByJobRunTaskRun(context.Background(), sourceJobRunID+"-3", sourceTaskRunID+"-3") + require.NoError(t, err) + job4, err := m.sourceRepo.GetByJobRunTaskRun(context.Background(), sourceJobRunID+"-4", sourceTaskRunID+"-4") + require.NoError(t, err) + + require.Equal(t, sourceJobs1, []int64{job1.ID}) + require.Equal(t, model.SourceJobStatusSucceeded, job1.Status) + require.Equal(t, int64(0), job1.Attempts) + require.NoError(t, job1.Error) + require.Equal(t, sourceJobs2, []int64{job2.ID}) + require.Equal(t, model.SourceJobStatusAborted, job2.Status) + require.Equal(t, int64(1), job2.Attempts) + require.Error(t, job2.Error) + require.Equal(t, sourceJobs3, []int64{job3.ID}) + require.Equal(t, model.SourceJobStatusAborted, job3.Status) + require.Equal(t, int64(1), job3.Attempts) + require.Error(t, job3.Error) + require.Equal(t, sourceJobs4, []int64{job4.ID}) + require.Equal(t, model.SourceJobStatusSucceeded, job4.Status) + require.Equal(t, int64(0), job4.Attempts) + require.NoError(t, job4.Error) + }) + t.Run("failed and then aborted", func(t *testing.T) { + db := setupDB(t, pool) + + sr := repo.NewSource(db, repo.WithNow(func() time.Time { + return now + })) + + sourceJobs1 := createSourceJob(sr, sourceJobRunID+"-1", sourceTaskRunID+"-1", "test_table-1") + sourceJobs2 := createSourceJob(sr, sourceJobRunID+"-2", sourceTaskRunID+"-2", "test_table-2") + + 
publishResponse := make(chan chan *notifier.PublishResponse, 10) + for i := 0; i < 10; i++ { + response := make(chan *notifier.PublishResponse, 1) + response <- ¬ifier.PublishResponse{ + Jobs: []notifier.Job{ + { + Payload: []byte(fmt.Sprintf(`{"id": %d}`, sourceJobs1[0])), + Status: notifier.Succeeded, + }, + { + Payload: []byte(fmt.Sprintf(`{"id": %d}`, sourceJobs2[0])), + Status: notifier.Failed, + Error: errors.New("test error"), + }, + }, + Err: nil, + } + close(response) + publishResponse <- response + } + close(publishResponse) + + c := config.New() + c.Set("Warehouse.jobs.processingSleepInterval", "1ms") + c.Set("Warehouse.jobs.maxAttemptsPerJob", 5) + + m := New(c, logger.NOP, db, newMockPublisher(publishResponse, nil)) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + g, gCtx := errgroup.WithContext(ctx) + g.Go(func() error { + return m.Run(gCtx) + }) + g.Go(func() error { + runs := []jobRunTaskRun{ + {jobRunID: sourceJobRunID + "-1", taskRunID: sourceTaskRunID + "-1"}, + {jobRunID: sourceJobRunID + "-2", taskRunID: sourceTaskRunID + "-2"}, + } + require.Eventually(t, func() bool { + defer cancel() + + jobs := getAll(t, context.Background(), m.sourceRepo, runs...) 
+ filteredJobs := lo.Filter(jobs, func(j *model.SourceJob, index int) bool { + return j.Status == model.SourceJobStatusAborted || j.Status == model.SourceJobStatusSucceeded + }) + return len(filteredJobs) == len(runs) + }, + 60*time.Second, + 100*time.Millisecond, + ) + return nil + }) + require.NoError(t, g.Wait()) + + job1, err := m.sourceRepo.GetByJobRunTaskRun(context.Background(), sourceJobRunID+"-1", sourceTaskRunID+"-1") + require.NoError(t, err) + job2, err := m.sourceRepo.GetByJobRunTaskRun(context.Background(), sourceJobRunID+"-2", sourceTaskRunID+"-2") + require.NoError(t, err) + + require.Equal(t, sourceJobs1, []int64{job1.ID}) + require.Equal(t, model.SourceJobStatusSucceeded, job1.Status) + require.Equal(t, int64(0), job1.Attempts) + require.NoError(t, job1.Error) + require.Equal(t, sourceJobs2, []int64{job2.ID}) + require.Equal(t, model.SourceJobStatusAborted, job2.Status) + require.Equal(t, int64(7), job2.Attempts) + require.Error(t, job2.Error) + }) + t.Run("failed then succeeded", func(t *testing.T) { + db := setupDB(t, pool) + + sr := repo.NewSource(db, repo.WithNow(func() time.Time { + return now + })) + + sourceJobs1 := createSourceJob(sr, sourceJobRunID+"-1", sourceTaskRunID+"-1", "test_table-1") + sourceJobs2 := createSourceJob(sr, sourceJobRunID+"-2", sourceTaskRunID+"-2", "test_table-2") + + failedResponse := ¬ifier.PublishResponse{ + Jobs: []notifier.Job{ + { + Payload: []byte(fmt.Sprintf(`{"id": %d}`, sourceJobs1[0])), + Status: notifier.Failed, + Error: errors.New("test error"), + }, + { + Payload: []byte(fmt.Sprintf(`{"id": %d}`, sourceJobs2[0])), + Status: notifier.Failed, + Error: errors.New("test error"), + }, + }, + Err: nil, + } + succeededResponse := ¬ifier.PublishResponse{ + Jobs: []notifier.Job{ + { + Payload: []byte(fmt.Sprintf(`{"id": %d}`, sourceJobs1[0])), + Status: notifier.Succeeded, + }, + { + Payload: []byte(fmt.Sprintf(`{"id": %d}`, sourceJobs2[0])), + Status: notifier.Succeeded, + }, + }, + Err: nil, + } + 
publishResponse := make(chan chan *notifier.PublishResponse, 10) + for i := 0; i < 10; i++ { + response := make(chan *notifier.PublishResponse, 1) + if i < 5 { + response <- failedResponse + } else { + response <- succeededResponse + } + close(response) + publishResponse <- response + } + close(publishResponse) + + c := config.New() + c.Set("Warehouse.jobs.processingSleepInterval", "1ms") + c.Set("Warehouse.jobs.maxAttemptsPerJob", 7) + + m := New(c, logger.NOP, db, newMockPublisher(publishResponse, nil)) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + g, gCtx := errgroup.WithContext(ctx) + g.Go(func() error { + return m.Run(gCtx) + }) + g.Go(func() error { + runs := []jobRunTaskRun{ + {jobRunID: sourceJobRunID + "-1", taskRunID: sourceTaskRunID + "-1"}, + {jobRunID: sourceJobRunID + "-2", taskRunID: sourceTaskRunID + "-2"}, + } + require.Eventually(t, func() bool { + defer cancel() + + jobs := getAll(t, context.Background(), m.sourceRepo, runs...) + filteredJobs := lo.Filter(jobs, func(j *model.SourceJob, index int) bool { + return j.Status == model.SourceJobStatusAborted || j.Status == model.SourceJobStatusSucceeded + }) + return len(filteredJobs) == len(runs) + }, + 60*time.Second, + 100*time.Millisecond, + ) + return nil + }) + require.NoError(t, g.Wait()) + + job1, err := m.sourceRepo.GetByJobRunTaskRun(context.Background(), sourceJobRunID+"-1", sourceTaskRunID+"-1") + require.NoError(t, err) + job2, err := m.sourceRepo.GetByJobRunTaskRun(context.Background(), sourceJobRunID+"-2", sourceTaskRunID+"-2") + require.NoError(t, err) + + require.Equal(t, sourceJobs1, []int64{job1.ID}) + require.Equal(t, model.SourceJobStatusSucceeded, job1.Status) + require.Equal(t, int64(5), job1.Attempts) + require.Error(t, job1.Error) + require.Equal(t, sourceJobs2, []int64{job2.ID}) + require.Equal(t, model.SourceJobStatusSucceeded, job2.Status) // Failed + require.Equal(t, int64(5), job2.Attempts) + require.Error(t, job2.Error) + }) +} + +func setupDB(t 
*testing.T, pool *dockertest.Pool) *sqlmiddleware.DB { + t.Helper() + + pgResource, err := resource.SetupPostgres(pool, t) + require.NoError(t, err) + t.Log("db:", pgResource.DBDsn) + + err = (&migrator.Migrator{ + Handle: pgResource.DB, + MigrationsTable: "wh_schema_migrations", + }).Migrate("warehouse") + require.NoError(t, err) + + return sqlmiddleware.New(pgResource.DB) +} + +func getAll(t testing.TB, ctx context.Context, sourceRepo sourceRepo, runs ...jobRunTaskRun) (jobs []*model.SourceJob) { + t.Helper() + + for _, run := range runs { + job, err := sourceRepo.GetByJobRunTaskRun(ctx, run.jobRunID, run.taskRunID) + require.NoError(t, err) + + jobs = append(jobs, job) + } + return +} diff --git a/warehouse/source/types.go b/warehouse/source/types.go new file mode 100644 index 00000000000..9a6e88916e7 --- /dev/null +++ b/warehouse/source/types.go @@ -0,0 +1,67 @@ +package source + +import ( + "context" + "encoding/json" + "errors" + + "github.com/rudderlabs/rudder-server/services/notifier" + "github.com/rudderlabs/rudder-server/warehouse/internal/model" +) + +var ( + ErrReceivingChannelClosed = errors.New("receiving channel closed") + ErrProcessingTimedOut = errors.New("processing timed out") +) + +type insertJobRequest struct { + SourceID string `json:"source_id"` + DestinationID string `json:"destination_id"` + StartTime string `json:"start_time"` + JobRunID string `json:"job_run_id"` + TaskRunID string `json:"task_run_id"` + JobType string `json:"async_job_type"` + WorkspaceID string `json:"workspace_id"` +} + +type insertJobResponse struct { + JobIds []int64 `json:"jobids"` + Err error `json:"error"` +} + +type jobStatusResponse struct { + Status string + Err string +} + +type NotifierRequest struct { + ID int64 `json:"id"` + SourceID string `json:"source_id"` + DestinationID string `json:"destination_id"` + WorkspaceID string `json:"workspace_id"` + TableName string `json:"tablename"` + JobType string `json:"async_job_type"` + MetaData json.RawMessage 
`json:"metadata"` +} + +type NotifierResponse struct { + ID int64 `json:"id"` +} + +type publisher interface { + Publish(context.Context, *notifier.PublishRequest) (<-chan *notifier.PublishResponse, error) +} + +type sourceRepo interface { + Insert(context.Context, []model.SourceJob) ([]int64, error) + Reset(context.Context) error + GetToProcess(context.Context, int64) ([]model.SourceJob, error) + GetByJobRunTaskRun(context.Context, string, string) (*model.SourceJob, error) + OnUpdateSuccess(context.Context, int64) error + OnUpdateFailure(context.Context, int64, error, int) error + MarkExecuting(context.Context, []int64) error +} + +type tableUploadsRepo interface { + GetByJobRunTaskRun(ctx context.Context, sourceID, destinationID, jobRunID, taskRunID string) ([]model.TableUpload, error) +} From bc3aa37c8dd94f65cfe5b25a501c27462810b53e Mon Sep 17 00:00:00 2001 From: Leonidas Vrachnis Date: Thu, 2 Nov 2023 06:23:19 +0100 Subject: [PATCH 4/5] chore: upgrade net library (#4065) --- go.mod | 4 ++-- go.sum | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index be1182facb9..d9898ec921f 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,7 @@ replace ( go.mongodb.org/mongo-driver => go.mongodb.org/mongo-driver v1.12.1 golang.org/x/crypto => golang.org/x/crypto v0.13.0 golang.org/x/image => golang.org/x/image v0.12.0 - golang.org/x/net => golang.org/x/net v0.15.0 + golang.org/x/net => golang.org/x/net v0.17.0 golang.org/x/text => golang.org/x/text v0.13.0 gopkg.in/yaml.v2 => gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v3 => gopkg.in/yaml.v3 v3.0.1 @@ -289,7 +289,7 @@ require ( golang.org/x/mod v0.12.0 // indirect golang.org/x/net v0.17.0 // indirect golang.org/x/sys v0.13.0 // indirect - golang.org/x/term v0.12.0 // indirect + golang.org/x/term v0.13.0 // indirect golang.org/x/text v0.13.0 // indirect golang.org/x/tools v0.13.0 // indirect golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect diff --git a/go.sum b/go.sum index 
f1fcffcc648..e9034d83ad5 100644 --- a/go.sum +++ b/go.sum @@ -1200,8 +1200,8 @@ golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91 golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8= -golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= +golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= +golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -1331,8 +1331,9 @@ golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20220526004731-065cf7ba2467/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.12.0 h1:/ZfYdc3zq+q02Rv9vGqTeSItdzZTSNDmfTi0mBAuidU= golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= +golang.org/x/term v0.13.0 h1:bb+I9cTfFazGW51MZqBVmZy7+JEJMouUHTUSKVQLBek= +golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= From 
5cd8f7bb201e1e80e3ce1fdd54f6285700205497 Mon Sep 17 00:00:00 2001 From: Akash Chetty Date: Thu, 2 Nov 2023 14:00:36 +0530 Subject: [PATCH 5/5] fix: minio healthcheck (#4068) --- warehouse/integrations/testdata/docker-compose.minio.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/warehouse/integrations/testdata/docker-compose.minio.yml b/warehouse/integrations/testdata/docker-compose.minio.yml index 2910716fa91..2db892db548 100644 --- a/warehouse/integrations/testdata/docker-compose.minio.yml +++ b/warehouse/integrations/testdata/docker-compose.minio.yml @@ -11,6 +11,6 @@ services: - MINIO_SITE_REGION=us-east-1 command: server /data healthcheck: - test: curl --fail http://localhost:9000/minio/health/live || exit 1 + test: timeout 5s bash -c ':> /dev/tcp/127.0.0.1/9000' || exit 1 interval: 1s retries: 25