From dfeba0ff8651f49f6299190ada82bf96ac7b7b42 Mon Sep 17 00:00:00 2001 From: "Ariel Shaqed (Scolnicov)" Date: Tue, 24 Nov 2020 11:34:56 +0200 Subject: [PATCH 1/3] [bugfix] return ErrNotFound correctly Package `scany` depends on a different version of `pgx` than the rest of lakeFS. So `errors.Is(err, pgx.ErrNoRows)` fails. Luckily it (sort-of) knows of this issue and wraps this call inside it as `pgxscan.NotFound`. Also make `ErrNotFound` wrap `pgx.ErrNoRows` rather than a new error. --- db/errors.go | 5 ++++- db/tx.go | 9 ++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/db/errors.go b/db/errors.go index fab714d5de0..6a8881840ff 100644 --- a/db/errors.go +++ b/db/errors.go @@ -2,10 +2,13 @@ package db import ( "errors" + "fmt" + + "github.com/jackc/pgx/v4" ) var ( - ErrNotFound = errors.New("not found") + ErrNotFound = fmt.Errorf("not found: %w", pgx.ErrNoRows) ErrAlreadyExists = errors.New("already exists") ErrSerialization = errors.New("serialization error") ErrNotASlice = errors.New("results must be a pointer to a slice") diff --git a/db/tx.go b/db/tx.go index 472efd2f2b2..fcad080da00 100644 --- a/db/tx.go +++ b/db/tx.go @@ -2,7 +2,6 @@ package db import ( "context" - "errors" "fmt" "strings" "time" @@ -74,7 +73,9 @@ func (d *dbTx) Get(dest interface{}, query string, args ...interface{}) error { "took": time.Since(start), }) err := pgxscan.Get(context.Background(), d.tx, dest, query, args...) - if errors.Is(err, pgx.ErrNoRows) { + if pgxscan.NotFound(err) { + // Don't wrap err: it might come from a different version of pgx and then + // !errors.Is(err, pgx.ErrNoRows). log.Trace("SQL query returned no results") return ErrNotFound } @@ -97,7 +98,9 @@ func (d *dbTx) GetPrimitive(dest interface{}, query string, args ...interface{}) }) row := d.tx.QueryRow(context.Background(), query, args...) err := row.Scan(dest) - if errors.Is(err, pgx.ErrNoRows) { + if pgxscan.NotFound(err) { + // Don't wrap err: it might come from a different version of pgx and then + // !errors.Is(err, pgx.ErrNoRows). log.Trace("SQL query returned no results") return ErrNotFound } From 64c93c5717e36bb22e00f6e17fe30ee9bdf4dad1 Mon Sep 17 00:00:00 2001 From: "Ariel Shaqed (Scolnicov)" Date: Tue, 24 Nov 2020 12:28:23 +0200 Subject: [PATCH 2/3] Test Get, GetPrimitive --- db/tx_test.go | 90 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 db/tx_test.go diff --git a/db/tx_test.go b/db/tx_test.go new file mode 100644 index 00000000000..48e2484f2be --- /dev/null +++ b/db/tx_test.go @@ -0,0 +1,90 @@ +package db_test + +import ( + "errors" + "testing" + + "github.com/treeverse/lakefs/db" + "github.com/treeverse/lakefs/db/params" +) + +func getDB(t *testing.T) db.Database { + t.Helper() + ret, err := db.ConnectDB(params.Database{Driver: "pgx", ConnectionString: databaseURI}) + if err != nil { + t.Fatal("failed to get DB") + } + return ret +} + +func TestGetPrimitive(t *testing.T) { + d := getDB(t) + + t.Run("success", func(t *testing.T) { + ret, err := d.Transact(func(tx db.Tx) (interface{}, error) { + var i int64 + err := tx.GetPrimitive(&i, "SELECT 17") + return i, err + }) + + if err != nil { + t.Errorf("failed to SELECT 17: %s", err) + } + i := ret.(int64) + if i != 17 { + t.Errorf("got %d not 17 from SELECT 17", i) + } + }) + + t.Run("failure", func(t *testing.T) { + _, err := d.Transact(func(tx db.Tx) (interface{}, error) { + var i int64 + err := tx.GetPrimitive(&i, "SELECT 17 WHERE 2=1") + return i, err + }) + + if !errors.Is(err, db.ErrNotFound) { + t.Errorf("got %s wanted not found", err) + } + }) +} + +func TestGet(t *testing.T) { + type R struct { + A int64 + B string + } + + d := getDB(t) + + t.Run("success", func(t *testing.T) { + ret, err := d.Transact(func(tx db.Tx) (interface{}, error) { + var r R + err := tx.Get(&r, "SELECT 17 A, 'foo' B") + return &r, err + }) + + if err != nil { + t.Errorf("failed to SELECT 17 and 'foo': %s", err) + } + r := ret.(*R) + if r.A != 17 { + t.Errorf("got %+v with A != 17 from SELECT 17 and 'foo'", r) + } + if r.B != "foo" { + t.Errorf("got %+v with B != 'foo' from SELECT 17 and 'foo'", r) + } + }) + + t.Run("failure", func(t *testing.T) { + _, err := d.Transact(func(tx db.Tx) (interface{}, error) { + var r R + err := tx.Get(&r, "SELECT 17 A, 'foo' B WHERE 2=1") + return &r, err + }) + + if !errors.Is(err, db.ErrNotFound) { + t.Errorf("got %s wanted not found", err) + } + }) +} From 36c0fbb93b81119de59c2f7830013fb43dbde793 Mon Sep 17 00:00:00 2001 From: "Ariel Shaqed (Scolnicov)" Date: Tue, 24 Nov 2020 13:11:22 +0200 Subject: [PATCH 3/3] [CR] GetPrimitive doesn't call pgxscan, use pgx directly there --- db/tx.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/db/tx.go b/db/tx.go index fcad080da00..4a15ef84bc8 100644 --- a/db/tx.go +++ b/db/tx.go @@ -2,6 +2,7 @@ package db import ( "context" + "errors" "fmt" "strings" "time" @@ -74,7 +75,7 @@ func (d *dbTx) Get(dest interface{}, query string, args ...interface{}) error { }) err := pgxscan.Get(context.Background(), d.tx, dest, query, args...) if pgxscan.NotFound(err) { - // Don't wrap err: it might come from a different version of pgx and then + // Don't wrap this err: it might come from a different version of pgx and then // !errors.Is(err, pgx.ErrNoRows). log.Trace("SQL query returned no results") return ErrNotFound @@ -98,9 +99,7 @@ func (d *dbTx) GetPrimitive(dest interface{}, query string, args ...interface{}) }) row := d.tx.QueryRow(context.Background(), query, args...) err := row.Scan(dest) - if pgxscan.NotFound(err) { - // Don't wrap err: it might come from a different version of pgx and then - // !errors.Is(err, pgx.ErrNoRows). + if errors.Is(err, pgx.ErrNoRows) { log.Trace("SQL query returned no results") return ErrNotFound }