diff --git a/db/errors.go b/db/errors.go index fab714d5de0..6a8881840ff 100644 --- a/db/errors.go +++ b/db/errors.go @@ -2,10 +2,13 @@ package db import ( "errors" + "fmt" + + "github.com/jackc/pgx/v4" ) var ( - ErrNotFound = errors.New("not found") + ErrNotFound = fmt.Errorf("not found: %w", pgx.ErrNoRows) ErrAlreadyExists = errors.New("already exists") ErrSerialization = errors.New("serialization error") ErrNotASlice = errors.New("results must be a pointer to a slice") diff --git a/db/tx.go b/db/tx.go index 472efd2f2b2..4a15ef84bc8 100644 --- a/db/tx.go +++ b/db/tx.go @@ -74,7 +74,9 @@ func (d *dbTx) Get(dest interface{}, query string, args ...interface{}) error { "took": time.Since(start), }) err := pgxscan.Get(context.Background(), d.tx, dest, query, args...) - if errors.Is(err, pgx.ErrNoRows) { + if pgxscan.NotFound(err) { + // Don't wrap this err: it might come from a different version of pgx and then + // !errors.Is(err, pgx.ErrNoRows). log.Trace("SQL query returned no results") return ErrNotFound } diff --git a/db/tx_test.go b/db/tx_test.go new file mode 100644 index 00000000000..48e2484f2be --- /dev/null +++ b/db/tx_test.go @@ -0,0 +1,90 @@ +package db_test + +import ( + "errors" + "testing" + + "github.com/treeverse/lakefs/db" + "github.com/treeverse/lakefs/db/params" +) + +func getDB(t *testing.T) db.Database { + t.Helper() + ret, err := db.ConnectDB(params.Database{Driver: "pgx", ConnectionString: databaseURI}) + if err != nil { + t.Fatal("failed to get DB") + } + return ret +} + +func TestGetPrimitive(t *testing.T) { + d := getDB(t) + + t.Run("success", func(t *testing.T) { + ret, err := d.Transact(func(tx db.Tx) (interface{}, error) { + var i int64 + err := tx.GetPrimitive(&i, "SELECT 17") + return i, err + }) + + if err != nil { + t.Errorf("failed to SELECT 17: %s", err) + } + i := ret.(int64) + if i != 17 { + t.Errorf("got %d not 17 from SELECT 17", i) + } + }) + + t.Run("failure", func(t *testing.T) { + _, err := d.Transact(func(tx db.Tx) (interface{}, error) { + var i int64 + err := tx.GetPrimitive(&i, "SELECT 17 WHERE 2=1") + return i, err + }) + + if !errors.Is(err, db.ErrNotFound) { + t.Errorf("got %s wanted not found", err) + } + }) +} + +func TestGet(t *testing.T) { + type R struct { + A int64 + B string + } + + d := getDB(t) + + t.Run("success", func(t *testing.T) { + ret, err := d.Transact(func(tx db.Tx) (interface{}, error) { + var r R + err := tx.Get(&r, "SELECT 17 A, 'foo' B") + return &r, err + }) + + if err != nil { + t.Errorf("failed to SELECT 17 and 'foo': %s", err) + } + r := ret.(*R) + if r.A != 17 { + t.Errorf("got %+v with A != 17 from SELECT 17 and 'foo'", r) + } + if r.B != "foo" { + t.Errorf("got %+v with B != 'foo' from SELECT 17 and 'foo'", r) + } + }) + + t.Run("failure", func(t *testing.T) { + _, err := d.Transact(func(tx db.Tx) (interface{}, error) { + var r R + err := tx.Get(&r, "SELECT 17 A, 'foo' B WHERE 2=1") + return &r, err + }) + + if !errors.Is(err, db.ErrNotFound) { + t.Errorf("got %s wanted not found", err) + } + }) +}