Skip to content

Commit

Permalink
chore: master pull
Browse files Browse the repository at this point in the history
  • Loading branch information
achettyiitr committed Nov 1, 2023
2 parents b810b26 + b80d273 commit 98a2ca2
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 137 deletions.
86 changes: 37 additions & 49 deletions warehouse/internal/repo/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package repo

import (
"context"
"database/sql"
jsonstd "encoding/json"
"fmt"

Expand Down Expand Up @@ -44,14 +43,14 @@ func NewLoadFiles(db *sqlmiddleware.DB, opts ...Opt) *LoadFiles {
}

// DeleteByStagingFiles deletes load files associated with stagingFileIDs.
func (repo *LoadFiles) DeleteByStagingFiles(ctx context.Context, stagingFileIDs []int64) error {
func (lf *LoadFiles) DeleteByStagingFiles(ctx context.Context, stagingFileIDs []int64) error {
sqlStatement := `
DELETE FROM
` + loadTableName + `
WHERE
staging_file_id = ANY($1);`

_, err := repo.db.ExecContext(ctx, sqlStatement, pq.Array(stagingFileIDs))
_, err := lf.db.ExecContext(ctx, sqlStatement, pq.Array(stagingFileIDs))
if err != nil {
return fmt.Errorf(`deleting load files: %w`, err)
}
Expand All @@ -60,58 +59,47 @@ func (repo *LoadFiles) DeleteByStagingFiles(ctx context.Context, stagingFileIDs
}

// Insert loadFiles into the database.
func (repo *LoadFiles) Insert(ctx context.Context, loadFiles []model.LoadFile) (err error) {
// Using transactions for bulk copying
txn, err := repo.db.BeginTx(ctx, &sql.TxOptions{})
if err != nil {
return
}

stmt, err := txn.PrepareContext(
ctx,
pq.CopyIn(
"wh_load_files",
"staging_file_id",
"location",
"source_id",
"destination_id",
"destination_type",
"table_name",
"total_events",
"created_at",
"metadata",
),
)
if err != nil {
return fmt.Errorf(`inserting load files: CopyIn: %w`, err)
}
defer func() { _ = stmt.Close() }()

for _, loadFile := range loadFiles {
metadata := fmt.Sprintf(`{"content_length": %d, "destination_revision_id": %q, "use_rudder_storage": %t}`, loadFile.ContentLength, loadFile.DestinationRevisionID, loadFile.UseRudderStorage)
_, err = stmt.ExecContext(ctx, loadFile.StagingFileID, loadFile.Location, loadFile.SourceID, loadFile.DestinationID, loadFile.DestinationType, loadFile.TableName, loadFile.TotalRows, timeutil.Now(), metadata)
func (lf *LoadFiles) Insert(ctx context.Context, loadFiles []model.LoadFile) error {
return (*repo)(lf).WithTx(ctx, func(tx *sqlmiddleware.Tx) error {
stmt, err := tx.PrepareContext(
ctx,
pq.CopyIn(
"wh_load_files",
"staging_file_id",
"location",
"source_id",
"destination_id",
"destination_type",
"table_name",
"total_events",
"created_at",
"metadata",
),
)
if err != nil {
_ = txn.Rollback()
return fmt.Errorf(`inserting load files: CopyIn exec: %w`, err)
return fmt.Errorf(`inserting load files: CopyIn: %w`, err)

Check warning on line 80 in warehouse/internal/repo/load.go

View check run for this annotation

Codecov / codecov/patch

warehouse/internal/repo/load.go#L80

Added line #L80 was not covered by tests
}
}

_, err = stmt.ExecContext(ctx)
if err != nil {
_ = txn.Rollback()
return fmt.Errorf(`inserting load files: CopyIn final exec: %w`, err)
}
err = txn.Commit()
if err != nil {
return fmt.Errorf(`inserting load files: commit: %w`, err)
}
return
defer func() { _ = stmt.Close() }()

for _, loadFile := range loadFiles {
metadata := fmt.Sprintf(`{"content_length": %d, "destination_revision_id": %q, "use_rudder_storage": %t}`, loadFile.ContentLength, loadFile.DestinationRevisionID, loadFile.UseRudderStorage)
_, err = stmt.ExecContext(ctx, loadFile.StagingFileID, loadFile.Location, loadFile.SourceID, loadFile.DestinationID, loadFile.DestinationType, loadFile.TableName, loadFile.TotalRows, timeutil.Now(), metadata)
if err != nil {
return fmt.Errorf(`inserting load files: CopyIn exec: %w`, err)
}

Check warning on line 89 in warehouse/internal/repo/load.go

View check run for this annotation

Codecov / codecov/patch

warehouse/internal/repo/load.go#L88-L89

Added lines #L88 - L89 were not covered by tests
}
_, err = stmt.ExecContext(ctx)
if err != nil {
return fmt.Errorf(`inserting load files: CopyIn final exec: %w`, err)
}

Check warning on line 94 in warehouse/internal/repo/load.go

View check run for this annotation

Codecov / codecov/patch

warehouse/internal/repo/load.go#L93-L94

Added lines #L93 - L94 were not covered by tests
return nil
})
}

// GetByStagingFiles returns all load files matching the staging file ids.
//
// Ordered by id ascending.
func (repo *LoadFiles) GetByStagingFiles(ctx context.Context, stagingFileIDs []int64) ([]model.LoadFile, error) {
func (lf *LoadFiles) GetByStagingFiles(ctx context.Context, stagingFileIDs []int64) ([]model.LoadFile, error) {
sqlStatement := `
WITH row_numbered_load_files as (
SELECT
Expand All @@ -136,7 +124,7 @@ func (repo *LoadFiles) GetByStagingFiles(ctx context.Context, stagingFileIDs []i
ORDER BY id ASC
`

rows, err := repo.db.QueryContext(ctx, sqlStatement, pq.Array(stagingFileIDs))
rows, err := lf.db.QueryContext(ctx, sqlStatement, pq.Array(stagingFileIDs))
if err != nil {
return nil, fmt.Errorf("query staging ids: %w", err)
}
Expand Down
21 changes: 21 additions & 0 deletions warehouse/internal/repo/repo.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package repo

import (
"context"
"database/sql"
"fmt"
"time"

sqlmiddleware "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper"
Expand All @@ -18,3 +21,21 @@ func WithNow(now func() time.Time) Opt {
r.now = now
}
}

func (r *repo) WithTx(ctx context.Context, f func(tx *sqlmiddleware.Tx) error) error {
tx, err := r.db.BeginTx(ctx, &sql.TxOptions{})
if err != nil {
return fmt.Errorf("begin transaction: %w", err)
}

if err := f(tx); err != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
return fmt.Errorf("rollback transaction for %w: %w", err, rollbackErr)
}
return err

Check warning on line 35 in warehouse/internal/repo/repo.go

View check run for this annotation

Codecov / codecov/patch

warehouse/internal/repo/repo.go#L32-L35

Added lines #L32 - L35 were not covered by tests
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("committing transaction: %w", err)
}

Check warning on line 39 in warehouse/internal/repo/repo.go

View check run for this annotation

Codecov / codecov/patch

warehouse/internal/repo/repo.go#L38-L39

Added lines #L38 - L39 were not covered by tests
return nil
}
67 changes: 30 additions & 37 deletions warehouse/internal/repo/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,11 @@ func NewSource(db *sqlmw.DB, opts ...Opt) *Source {
}

func (s *Source) Insert(ctx context.Context, sourceJobs []model.SourceJob) ([]int64, error) {
txn, err := s.db.BeginTx(ctx, &sql.TxOptions{})
if err != nil {
return nil, fmt.Errorf(`begin transaction: %w`, err)
}
defer func() {
if err != nil {
_ = txn.Rollback()
}
}()
var ids []int64

stmt, err := txn.PrepareContext(
ctx, `
err := (*repo)(s).WithTx(ctx, func(tx *sqlmw.Tx) error {
stmt, err := tx.PrepareContext(
ctx, `
INSERT INTO `+sourceJobTableName+` (
source_id, destination_id, tablename,
status, created_at, updated_at, async_job_type,
Expand All @@ -66,36 +59,36 @@ func (s *Source) Insert(ctx context.Context, sourceJobs []model.SourceJob) ([]in
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id;
`,
)
if err != nil {
return nil, fmt.Errorf(`preparing statement: %w`, err)
}
defer func() { _ = stmt.Close() }()

var ids []int64
for _, sourceJob := range sourceJobs {
var id int64
err = stmt.QueryRowContext(
ctx,
sourceJob.SourceID,
sourceJob.DestinationID,
sourceJob.TableName,
model.SourceJobStatusWaiting.String(),
s.now(),
s.now(),
sourceJob.JobType.String(),
sourceJob.WorkspaceID,
sourceJob.Metadata,
).Scan(&id)
)
if err != nil {
return nil, fmt.Errorf(`executing: %w`, err)
return fmt.Errorf(`preparing statement: %w`, err)
}

Check warning on line 65 in warehouse/internal/repo/source.go

View check run for this annotation

Codecov / codecov/patch

warehouse/internal/repo/source.go#L64-L65

Added lines #L64 - L65 were not covered by tests
defer func() { _ = stmt.Close() }()

ids = append(ids, id)
}
for _, sourceJob := range sourceJobs {
var id int64
err = stmt.QueryRowContext(
ctx,
sourceJob.SourceID,
sourceJob.DestinationID,
sourceJob.TableName,
model.SourceJobStatusWaiting.String(),
s.now(),
s.now(),
sourceJob.JobType.String(),
sourceJob.WorkspaceID,
sourceJob.Metadata,
).Scan(&id)
if err != nil {
return fmt.Errorf(`executing: %w`, err)
}

Check warning on line 84 in warehouse/internal/repo/source.go

View check run for this annotation

Codecov / codecov/patch

warehouse/internal/repo/source.go#L83-L84

Added lines #L83 - L84 were not covered by tests

if err = txn.Commit(); err != nil {
return nil, fmt.Errorf(`committing: %w`, err)
ids = append(ids, id)
}
return nil
})
if err != nil {
return nil, err
}
return ids, nil
}
Expand Down
Loading

0 comments on commit 98a2ca2

Please sign in to comment.