Skip to content

Commit

Permalink
refactor: uploads default logFields and repo load files queries (#4089)
Browse files Browse the repository at this point in the history
  • Loading branch information
achettyiitr authored Nov 7, 2023
1 parent 094dc2b commit 4610c7e
Show file tree
Hide file tree
Showing 4 changed files with 304 additions and 146 deletions.
118 changes: 110 additions & 8 deletions warehouse/internal/repo/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package repo

import (
"context"
"database/sql"
jsonstd "encoding/json"
"errors"
"fmt"

"github.com/lib/pq"
Expand Down Expand Up @@ -101,14 +103,15 @@ func (lf *LoadFiles) Insert(ctx context.Context, loadFiles []model.LoadFile) err
// Ordered by id ascending.
func (lf *LoadFiles) GetByStagingFiles(ctx context.Context, stagingFileIDs []int64) ([]model.LoadFile, error) {
sqlStatement := `
WITH row_numbered_load_files as (
WITH row_numbered_load_files AS (
SELECT
` + loadTableColumns + `,
row_number() OVER (
PARTITION BY staging_file_id,
table_name
ORDER BY
id DESC
PARTITION BY
staging_file_id,
table_name
ORDER BY
id DESC
) AS row_number
FROM
` + loadTableName + `
Expand All @@ -118,10 +121,11 @@ func (lf *LoadFiles) GetByStagingFiles(ctx context.Context, stagingFileIDs []int
SELECT
` + loadTableColumns + `
FROM
row_numbered_load_files
row_numbered_load_files
WHERE
row_number = 1
ORDER BY id ASC
row_number = 1
ORDER BY
id ASC;
`

rows, err := lf.db.QueryContext(ctx, sqlStatement, pq.Array(stagingFileIDs))
Expand Down Expand Up @@ -173,3 +177,101 @@ func (lf *LoadFiles) GetByStagingFiles(ctx context.Context, stagingFileIDs []int

return loadFiles, nil
}

// TotalExportedEvents returns the total number of events exported by the
// corresponding staging files, excluding the tables listed in skipTables.
//
// Only the most recent load file per (staging_file_id, table_name) pair —
// the row with the highest id — contributes to the total, so retried
// uploads are not double-counted.
func (lf *LoadFiles) TotalExportedEvents(
	ctx context.Context,
	stagingFileIDs []int64,
	skipTables []string,
) (int64, error) {
	// pq.Array(nil) serializes as NULL, which would make the
	// `table_name != ALL($2)` predicate match nothing; substitute an
	// empty list so no table is skipped.
	if skipTables == nil {
		skipTables = []string{}
	}

	sqlStatement := `
		WITH row_numbered_load_files AS (
			SELECT
				total_events,
				table_name,
				row_number() OVER (
					PARTITION BY
						staging_file_id,
						table_name
					ORDER BY
						id DESC
				) AS row_number
			FROM
				` + loadTableName + `
			WHERE
				staging_file_id = ANY($1)
		)
		SELECT
			COALESCE(sum(total_events), 0) AS total_events
		FROM
			row_numbered_load_files
		WHERE
			row_number = 1
			AND
			table_name != ALL($2);`

	var total sql.NullInt64
	row := lf.db.QueryRowContext(ctx, sqlStatement, pq.Array(stagingFileIDs), pq.Array(skipTables))
	if err := row.Scan(&total); err != nil {
		return 0, fmt.Errorf(`counting total exported events: %w`, err)
	}
	// Defensive: COALESCE should always yield a value, but guard anyway.
	if !total.Valid {
		return 0, errors.New(`count is not valid`)
	}
	return total.Int64, nil
}

// DistinctTableName returns the distinct table names for the given source and
// destination whose load file ids fall in the inclusive range [startID, endID].
//
// When no rows match, it returns a nil slice and a nil error: iterating an
// empty result set simply leaves tableNames nil. (A previous revision checked
// errors.Is(err, sql.ErrNoRows) here, but QueryContext never returns
// sql.ErrNoRows — that sentinel is produced only by Row.Scan — so the check
// was dead code and has been removed.)
func (lf *LoadFiles) DistinctTableName(
	ctx context.Context,
	sourceID string,
	destinationID string,
	startID int64,
	endID int64,
) ([]string, error) {
	rows, err := lf.db.QueryContext(ctx, `
		SELECT
			distinct table_name
		FROM
			`+loadTableName+`
		WHERE
			source_id = $1
			AND destination_id = $2
			AND id >= $3
			AND id <= $4;`,
		sourceID,
		destinationID,
		startID,
		endID,
	)
	if err != nil {
		return nil, fmt.Errorf("querying load files: %w", err)
	}
	// Release the connection back to the pool even on early return.
	defer func() { _ = rows.Close() }()

	var tableNames []string
	for rows.Next() {
		var tableName string
		if err := rows.Scan(&tableName); err != nil {
			return nil, fmt.Errorf(`scanning table names: %w`, err)
		}
		tableNames = append(tableNames, tableName)
	}
	// rows.Err surfaces errors encountered during iteration (e.g. a
	// cancelled context), which rows.Next reports only as exhaustion.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("querying table names: %w", err)
	}
	return tableNames, nil
}
149 changes: 149 additions & 0 deletions warehouse/internal/repo/load_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ package repo_test
import (
"context"
"fmt"
"strconv"
"testing"
"time"

"github.com/samber/lo"

"github.com/stretchr/testify/require"

"github.com/rudderlabs/rudder-server/warehouse/internal/model"
Expand Down Expand Up @@ -109,3 +112,149 @@ func Test_LoadFiles(t *testing.T) {
require.Equal(t, lastLoadFile, gotLoadFiles[0])
})
}

// TestLoadFiles_TotalExportedEvents exercises LoadFiles.TotalExportedEvents
// against a fixture where each (staging file, table) pair is inserted
// retriesCount times, simulating retried uploads. Only the latest retry
// (highest id) must be counted per pair.
func TestLoadFiles_TotalExportedEvents(t *testing.T) {
	ctx := context.Background()
	now := time.Now().Truncate(time.Second).UTC()
	db := setupDB(t)

	// Pin the repo clock so inserted timestamps are deterministic.
	r := repo.NewLoadFiles(db, repo.WithNow(func() time.Time {
		return now
	}))

	stagingFilesCount := 960
	loadFilesCount := 25
	retriesCount := 3

	loadFiles := make([]model.LoadFile, 0, stagingFilesCount*loadFilesCount*retriesCount)
	stagingFileIDs := make([]int64, 0, stagingFilesCount)

	// For staging file i+1 and table j+1, insert retriesCount rows with
	// TotalRows = (i+1)+(j+1)+(k+1). Rows are appended in ascending k order,
	// so the last retry (k+1 == retriesCount) should receive the highest id
	// and be the one TotalExportedEvents picks per (staging_file_id, table_name).
	// NOTE(review): this assumes Insert assigns ids in slice order — confirm
	// against repo.LoadFiles.Insert.
	for i := 0; i < stagingFilesCount; i++ {
		for j := 0; j < loadFilesCount; j++ {
			for k := 0; k < retriesCount; k++ {
				loadFiles = append(loadFiles, model.LoadFile{
					TableName:             "table_name_" + strconv.Itoa(j+1),
					Location:              "s3://bucket/path/to/file",
					TotalRows:             (i + 1) + (j + 1) + (k + 1),
					ContentLength:         1000,
					StagingFileID:         int64(i + 1),
					DestinationRevisionID: "revision_id",
					UseRudderStorage:      true,
					SourceID:              "source_id",
					DestinationID:         "destination_id",
					DestinationType:       "RS",
				})
			}
		}
		stagingFileIDs = append(stagingFileIDs, int64(i+1))
	}

	err := r.Insert(ctx, loadFiles)
	require.NoError(t, err)

	// A staging file id that was never inserted must contribute zero events.
	t.Run("no staging files", func(t *testing.T) {
		exportedEvents, err := r.TotalExportedEvents(ctx, []int64{-1}, []string{})
		require.NoError(t, err)
		require.Zero(t, exportedEvents)
	})
	t.Run("without skip tables", func(t *testing.T) {
		exportedEvents, err := r.TotalExportedEvents(ctx, stagingFileIDs, nil)
		require.NoError(t, err)

		// Expected total: for each staging file (item == i+1) and each table
		// (j+1), only the final retry counts, i.e. item + (j+1) + retriesCount.
		actualEvents := lo.SumBy(stagingFileIDs, func(item int64) int64 {
			sum := 0
			for j := 0; j < loadFilesCount; j++ {
				sum += int(item) + (j + 1) + retriesCount
			}
			return int64(sum)
		})
		require.Equal(t, actualEvents, exportedEvents)
	})
	t.Run("with skip tables", func(t *testing.T) {
		// Skip the odd-numbered tables (table_name_1, _3, _5, _7, _9).
		excludeIDS := []int64{1, 3, 5, 7, 9}

		skipTable := lo.Map(excludeIDS, func(item int64, index int) string {
			return "table_name_" + strconv.Itoa(int(item))
		})

		exportedEvents, err := r.TotalExportedEvents(ctx, stagingFileIDs, skipTable) // expected total: 11916000
		require.NoError(t, err)

		// Same sum as above, minus the contributions of the skipped tables.
		actualEvents := lo.SumBy(stagingFileIDs, func(item int64) int64 {
			sum := 0
			for j := 0; j < loadFilesCount; j++ {
				if lo.Contains(excludeIDS, int64(j+1)) {
					continue
				}
				sum += int(item) + (j + 1) + retriesCount
			}
			return int64(sum)
		})
		require.Equal(t, actualEvents, exportedEvents)
	})
	// A pre-cancelled context must surface context.Canceled from the query.
	t.Run("context cancelled", func(t *testing.T) {
		ctx, cancel := context.WithCancel(ctx)
		cancel()

		exportedEvents, err := r.TotalExportedEvents(ctx, stagingFileIDs, nil)
		require.ErrorIs(t, err, context.Canceled)
		require.Zero(t, exportedEvents)
	})
}

// TestLoadFiles_DistinctTableName seeds one load file per (staging file,
// table) combination and checks that DistinctTableName reports exactly the
// distinct table names within an id range.
func TestLoadFiles_DistinctTableName(t *testing.T) {
	sourceID := "source_id"
	destinationID := "destination_id"

	ctx := context.Background()
	now := time.Now().Truncate(time.Second).UTC()
	db := setupDB(t)

	// Pin the repo clock so inserted timestamps are deterministic.
	r := repo.NewLoadFiles(db, repo.WithNow(func() time.Time {
		return now
	}))

	const (
		numStagingFiles = 960
		numTables       = 25
	)

	fixtures := make([]model.LoadFile, 0, numStagingFiles*numTables)
	for stagingIdx := 1; stagingIdx <= numStagingFiles; stagingIdx++ {
		for tableIdx := 1; tableIdx <= numTables; tableIdx++ {
			fixtures = append(fixtures, model.LoadFile{
				TableName:             "table_name_" + strconv.Itoa(tableIdx),
				Location:              "s3://bucket/path/to/file",
				TotalRows:             stagingIdx + tableIdx,
				ContentLength:         1000,
				StagingFileID:         int64(stagingIdx),
				DestinationRevisionID: "revision_id",
				UseRudderStorage:      true,
				SourceID:              sourceID,
				DestinationID:         destinationID,
				DestinationType:       "RS",
			})
		}
	}

	require.NoError(t, r.Insert(ctx, fixtures))

	// An id range below every inserted row must yield no table names.
	t.Run("no staging files", func(t *testing.T) {
		tables, err := r.DistinctTableName(ctx, sourceID, destinationID, -1, -1)
		require.NoError(t, err)
		require.Zero(t, tables)
	})
	// The full id range must yield exactly one entry per distinct table.
	t.Run("some staging files", func(t *testing.T) {
		tables, err := r.DistinctTableName(ctx, sourceID, destinationID, 1, int64(len(fixtures)))
		require.NoError(t, err)
		require.Len(t, tables, numTables)
	})
	// A pre-cancelled context must surface context.Canceled.
	t.Run("context cancelled", func(t *testing.T) {
		cancelledCtx, cancel := context.WithCancel(ctx)
		cancel()

		tables, err := r.DistinctTableName(cancelledCtx, sourceID, destinationID, -1, -1)
		require.ErrorIs(t, err, context.Canceled)
		require.Zero(t, tables)
	})
}
Loading

0 comments on commit 4610c7e

Please sign in to comment.