Skip to content

Commit

Permalink
Use pgx instead of sql
Browse files Browse the repository at this point in the history
  • Loading branch information
arielshaqed committed Oct 27, 2020
1 parent 725e181 commit a7affc3
Show file tree
Hide file tree
Showing 33 changed files with 185 additions and 182 deletions.
5 changes: 2 additions & 3 deletions auth/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@ func insertOrGetInstallationID(tx db.Tx) (string, error) {
if err != nil {
return "", err
}
if affected, err := res.RowsAffected(); err != nil {
return "", err
} else if affected == 1 {
affected := res.RowsAffected()
if affected == 1 {
return newInstallationID, nil
}
return getInstallationID(tx)
Expand Down
10 changes: 3 additions & 7 deletions auth/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"time"

sq "github.com/Masterminds/squirrel"
"github.com/georgysavva/scany/sqlscan"
"github.com/georgysavva/scany/pgxscan"

"github.com/treeverse/lakefs/auth/crypt"
"github.com/treeverse/lakefs/auth/model"
Expand Down Expand Up @@ -98,13 +98,12 @@ func ListPaged(db db.Database, retType reflect.Type, params *model.PaginationPar
if err != nil {
return nil, nil, fmt.Errorf("query DB: %w", err)
}
rowScanner := sqlscan.NewRowScanner(rows)
rowScanner := pgxscan.NewRowScanner(rows)
for rows.Next() {
value := reflect.New(retType)
if err = rowScanner.Scan(value.Interface()); err != nil {
return nil, nil, fmt.Errorf("scan value from DB: %w", err)
}
fmt.Printf("[DEBUG] row %+v value %+v scanner %+v\n", rows, value, rowScanner)
slice = reflect.Append(slice, value)
}
p := &model.Paginator{}
Expand Down Expand Up @@ -151,10 +150,7 @@ func deleteOrNotFound(tx db.Tx, stmt string, args ...interface{}) error {
if err != nil {
return err
}
numRows, err := res.RowsAffected()
if err != nil {
return err
}
numRows := res.RowsAffected()
if numRows == 0 {
return db.ErrNotFound
}
Expand Down
1 change: 0 additions & 1 deletion auth/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ func userWithPolicies(t testing.TB, s auth.Service, policies []*model.Policy) st
func TestDBAuthService_ListPaged(t *testing.T) {
const chars = "abcdefghijklmnopqrstuvwxyz"
adb, _ := testutil.GetDB(t, databaseURI)
defer adb.Close()
type row struct{ A string }
if _, err := adb.Exec(`CREATE TABLE test_pages (a text PRIMARY KEY)`); err != nil {
t.Fatalf("CREATE TABLE test_pages: %s", err)
Expand Down
7 changes: 3 additions & 4 deletions catalog/cataloger.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ var ErrExpired = errors.New("expired from storage")

// ExpiryRows is a database iterator over ExpiryResults. Use Next to advance from row to row.
type ExpiryRows interface {
io.Closer
Close()
Next() bool
Err() error
// Read returns the current from ExpiryRows, or an error on failure. Call it only after
Expand Down Expand Up @@ -400,9 +400,8 @@ func (c *cataloger) dedupBatch(batch []*dedupRequest) {
if err != nil {
return nil, err
}
if rowsAffected, err := res.RowsAffected(); err != nil {
return nil, err
} else if rowsAffected == 1 {
rowsAffected := res.RowsAffected()
if rowsAffected == 1 {
// new address was added - continue
continue
}
Expand Down
6 changes: 3 additions & 3 deletions catalog/cataloger_commit.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func commitUpdateCommittedEntriesWithMaxCommit(tx db.Tx, branchID int64, commitI
if err != nil {
return 0, err
}
return res.RowsAffected()
return res.RowsAffected(), nil
}

func commitDeleteUncommittedTombstones(tx db.Tx, branchID int64, commitID CommitID) (int64, error) {
Expand All @@ -110,7 +110,7 @@ func commitDeleteUncommittedTombstones(tx db.Tx, branchID int64, commitID Commit
if err != nil {
return 0, err
}
return res.RowsAffected()
return res.RowsAffected(), nil
}

func commitEntries(tx db.Tx, branchID int64, commitID CommitID) (int64, error) {
Expand All @@ -119,5 +119,5 @@ func commitEntries(tx db.Tx, branchID int64, commitID CommitID) (int64, error) {
if err != nil {
return 0, err
}
return res.RowsAffected()
return res.RowsAffected(), nil
}
5 changes: 2 additions & 3 deletions catalog/cataloger_delete_branch.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,8 @@ func (c *cataloger) DeleteBranch(ctx context.Context, repository, branch string)
if err != nil {
return nil, fmt.Errorf("delete branch: %w", err)
}
if affected, err := res.RowsAffected(); err != nil {
return nil, err
} else if affected != 1 {
affected := res.RowsAffected()
if affected != 1 {
return nil, ErrBranchNotFound
}
return nil, nil
Expand Down
5 changes: 1 addition & 4 deletions catalog/cataloger_delete_entry.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,7 @@ func (c *cataloger) DeleteEntry(ctx context.Context, repository, branch string,
if err != nil {
return nil, fmt.Errorf("uncommitted: %w", err)
}
deletedUncommittedCount, err := res.RowsAffected()
if err != nil {
return nil, fmt.Errorf("rows affected: %w", err)
}
deletedUncommittedCount := res.RowsAffected()

// get uncommitted entry based on path
lineage, err := getLineage(tx, branchID, UncommittedID)
Expand Down
5 changes: 2 additions & 3 deletions catalog/cataloger_delete_multipart_upload.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@ func (c *cataloger) DeleteMultipartUpload(ctx context.Context, repository string
if err != nil {
return nil, err
}
if affected, err := res.RowsAffected(); err != nil {
return nil, err
} else if affected != 1 {
affected := res.RowsAffected()
if affected != 1 {
return nil, ErrMultipartUploadNotFound
}
return nil, nil
Expand Down
5 changes: 2 additions & 3 deletions catalog/cataloger_delete_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@ func (c *cataloger) DeleteRepository(ctx context.Context, repository string) err
if err != nil {
return nil, err
}
if affected, err := res.RowsAffected(); err != nil {
return nil, err
} else if affected != 1 {
affected := res.RowsAffected()
if affected != 1 {
return nil, ErrRepositoryNotFound
}
return nil, nil
Expand Down
2 changes: 1 addition & 1 deletion catalog/cataloger_diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ func (c *cataloger) diffFromChild(ctx context.Context, tx db.Tx, params *doDiffP
return batch.Flush()
}

func createDiffResultsTable(ctx context.Context, executor sq.Execer) (string, error) {
func createDiffResultsTable(ctx context.Context, executor db.Tx) (string, error) {
diffResultsTableName, err := diffResultsTableNameFromContext(ctx)
if err != nil {
return "", err
Expand Down
13 changes: 4 additions & 9 deletions catalog/cataloger_export.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"regexp"

"github.com/georgysavva/scany/pgxscan"
"github.com/lib/pq"
"github.com/treeverse/lakefs/db"
)
Expand Down Expand Up @@ -35,7 +36,7 @@ func (c *cataloger) GetExportConfigurationForBranch(repository string, branch st
if err != nil {
return nil, err
}
err = c.db.Get(&ret,
err = c.db.GetStruct(&ret,
`SELECT export_path, export_status_path, last_keys_in_prefix_regexp
FROM catalog_branches_export
WHERE branch_id = $1`, branchID)
Expand All @@ -58,14 +59,8 @@ func (c *cataloger) GetExportConfigurations() ([]ExportConfigurationForBranch, e
if err != nil {
return nil, err
}
for rows.Next() {
var rec ExportConfigurationForBranch
if err = rows.StructScan(&rec); err != nil {
return nil, fmt.Errorf("scan configuration %+v: %w", rows, err)
}
ret = append(ret, rec)
}
return ret, nil
err = pgxscan.ScanAll(&ret, rows)
return ret, err
}

func (c *cataloger) PutExportConfiguration(repository string, branch string, conf *ExportConfiguration) error {
Expand Down
2 changes: 0 additions & 2 deletions catalog/cataloger_list_commits.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,7 @@ func (c *cataloger) ListCommits(ctx context.Context, repository, branch string,
if err := tx.Select(&rawCommits, query, fromCommitID, limit+1); err != nil {
return nil, err
}
fmt.Printf("[DEBUG] rawcommits %+v\n", rawCommits)
commits := convertRawCommits(rawCommits)
fmt.Printf("[DEBUG] commits %+v\n", commits)
return commits, nil
}, c.txOpts(ctx, db.ReadOnly())...)

Expand Down
4 changes: 2 additions & 2 deletions catalog/cataloger_list_entries.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ package catalog

import (
"context"
"database/sql"
"fmt"
"strings"

sq "github.com/Masterminds/squirrel"
"github.com/jackc/pgx/v4"
"github.com/treeverse/lakefs/db"
)

Expand Down Expand Up @@ -250,7 +250,7 @@ func getMoreRows(path string, branch int64, branchRanges map[int64][]entryPathPr
}
err = readParams.tx.Select(&readBuf, s, args...)
if len(readBuf) == 0 {
err = sql.ErrNoRows
err = pgx.ErrNoRows
}
if err != nil {
return err
Expand Down
6 changes: 1 addition & 5 deletions catalog/cataloger_reset_branch.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,7 @@ func (c *cataloger) ResetBranch(ctx context.Context, repository, branch string)
if err != nil {
return nil, err
}
res, err := tx.Exec(`DELETE FROM catalog_entries WHERE branch_id=$1 AND min_commit=$2`, branchID, MinCommitUncommittedIndicator)
if err != nil {
return nil, err
}
_, err = res.RowsAffected()
_, err = tx.Exec(`DELETE FROM catalog_entries WHERE branch_id=$1 AND min_commit=$2`, branchID, MinCommitUncommittedIndicator)
return nil, err
}, c.txOpts(ctx)...)
return err
Expand Down
5 changes: 2 additions & 3 deletions catalog/cataloger_reset_entry.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@ func (c *cataloger) ResetEntry(ctx context.Context, repository, branch string, p
if err != nil {
return nil, err
}
if affected, err := res.RowsAffected(); err != nil {
return nil, err
} else if affected != 1 {
affected := res.RowsAffected()
if affected != 1 {
return nil, ErrEntryNotFound
}
return nil, nil
Expand Down
26 changes: 11 additions & 15 deletions catalog/cataloger_retention.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@ import (
"strings"
"time"

"database/sql"

sq "github.com/Masterminds/squirrel"
"github.com/georgysavva/scany/pgxscan"
"github.com/jackc/pgx/v4"

"github.com/georgysavva/scany/sqlscan"
"github.com/treeverse/lakefs/db"
"github.com/treeverse/lakefs/logging"
)
Expand Down Expand Up @@ -162,7 +161,7 @@ func buildRetentionQuery(repositoryName string, policy *Policy, afterRow sq.RowS

// expiryRows implements ExpiryRows.
type expiryRows struct {
rows *sql.Rows
rows pgx.Rows
RepositoryName string
}

Expand All @@ -174,13 +173,13 @@ func (e *expiryRows) Err() error {
return e.rows.Err()
}

func (e *expiryRows) Close() error {
return e.rows.Close()
func (e *expiryRows) Close() {
e.rows.Close()
}

func (e *expiryRows) Read() (*ExpireResult, error) {
var record retentionQueryRecord
err := sqlscan.ScanRow(&record, e.rows)
err := pgxscan.ScanRow(&record, e.rows)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -284,10 +283,7 @@ func (c *cataloger) MarkEntriesExpired(ctx context.Context, repositoryName strin
if err != nil {
return nil, fmt.Errorf("updating entries to expire: %w", err)
}
count, err := result.RowsAffected()
if err != nil {
return nil, fmt.Errorf("getting number of updated rows: %w", err)
}
count := result.RowsAffected()
return int(count), nil
})
if err != nil {
Expand Down Expand Up @@ -319,11 +315,11 @@ func (c *cataloger) MarkObjectsForDeletion(ctx context.Context, repositoryName s
if err != nil {
return 0, err
}
return result.RowsAffected()
return result.RowsAffected(), nil
}

type StringRows struct {
rows *sql.Rows
rows pgx.Rows
}

func (s *StringRows) Next() bool {
Expand All @@ -334,8 +330,8 @@ func (s *StringRows) Err() error {
return s.rows.Err()
}

func (s *StringRows) Close() error {
return s.rows.Close()
func (s *StringRows) Close() {
s.rows.Close()
}

func (s *StringRows) Read() (string, error) {
Expand Down
16 changes: 6 additions & 10 deletions catalog/cataloger_retention_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ package catalog

import (
"context"
"database/sql"
"errors"
"fmt"
"sort"
"testing"
"time"

"github.com/jackc/pgx/v4"
"github.com/treeverse/lakefs/db/params"

"github.com/go-test/deep"
Expand All @@ -26,11 +26,7 @@ func readEntriesToExpire(t *testing.T, ctx context.Context, c Cataloger, reposit
if err != nil {
t.Fatalf("scan for expired failed: %s", err)
}
defer func() {
if err := rows.Close(); err != nil {
t.Fatalf("close rows from expire result %s", err)
}
}()
defer rows.Close()
ret := make([]*ExpireResult, 0, 10)
for rows.Next() {
e, err := rows.Read()
Expand Down Expand Up @@ -559,7 +555,7 @@ func TestCataloger_MarkEntriesExpired(t *testing.T) {
}
}

func getDeleting(t *testing.T, rows *sql.Rows) map[string]bool {
func getDeleting(t *testing.T, rows pgx.Rows) map[string]bool {
t.Helper()
deleting := make(map[string]bool, 2)
for rows.Next() {
Expand Down Expand Up @@ -771,9 +767,9 @@ func TestCataloger_DeleteOrUnmarkObjectsForDeletion(t *testing.T) {
if err != nil {
t.Fatalf("[internal] failed to set 2 objects to state deleting, %s", err)
}
numRows, err := res.RowsAffected()
if err != nil || numRows != 2 {
t.Fatalf("[internal] failed to set 2 objects to state deleting: %d objects set, %s", numRows, err)
numRows := res.RowsAffected()
if numRows != 2 {
t.Fatalf("[internal] failed to set 2 objects to state deleting: %d objects set", numRows)
}

deleteRows, err := c.DeleteOrUnmarkObjectsForDeletion(ctx, repository)
Expand Down
4 changes: 2 additions & 2 deletions catalog/db_batch_entry_read.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package catalog

import (
"database/sql"
"fmt"
"sync"
"time"

sq "github.com/Masterminds/squirrel"
"github.com/jackc/pgx/v4"
"github.com/treeverse/lakefs/db"
)

Expand Down Expand Up @@ -148,7 +148,7 @@ func (c *cataloger) dbSelectBatchEntries(repository string, ref Ref, pathReqList
return nil, fmt.Errorf("select entries: %w", err)
}
return entries, nil
}, db.WithLogger(c.log), db.ReadOnly(), db.WithIsolationLevel(sql.LevelReadCommitted))
}, db.WithLogger(c.log), db.ReadOnly(), db.WithIsolationLevel(pgx.ReadCommitted))
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit a7affc3

Please sign in to comment.