Skip to content

Commit

Permalink
db: check tx.Rollback error / serrors: add Join (#4307)
Browse files Browse the repository at this point in the history
Check returned error of tx.Rollback where this had previously just been
a `defer tx.Rollback`. The errcheck linter (rightly) complains about this.
The Rollback was previously also be called after a successful commit, in
which case it would have returned the error `ErrTxDone`. Now, Rollback
is only called in error cases.

Not fixed the `defer tx.Rollback` pattern in colibri code, as this
appears to be pretty much dead code by now.

Add a `serrors.Join` helper function to make it easier to return
multiple potential errors. This API is modeled closely after
`errors.Join`, which will be available in the go standard library in the
next release (go 1.20).
  • Loading branch information
matzf authored Jan 17, 2023
1 parent 62706ec commit 013b5cf
Showing 6 changed files with 73 additions and 10 deletions.
2 changes: 0 additions & 2 deletions .golangcilint.yml
Original file line number Diff line number Diff line change
@@ -138,9 +138,7 @@ issues:
^private/mgmtapi/segments/api/api.go$|\
^private/path/combinator/combinator.go$|\
^private/revcache/revcachetest/revcachetest.go$|\
^private/segment/seghandler/storage.go$|\
^private/service/statuspages.go$|\
^private/storage/db/sqler.go$|\
^private/storage/trust/fspersister/db_test.go$|\
^private/svc/internal/ctxconn/ctxconn.go$|\
^private/trust/db_inspector.go$|\
22 changes: 22 additions & 0 deletions pkg/private/serrors/errors.go
Original file line number Diff line number Diff line change
@@ -263,6 +263,28 @@ func (e List) MarshalLogArray(ae zapcore.ArrayEncoder) error {
return nil
}

// Join returns an error that wraps the given errors in a List error.
// Any nil error values are discarded.
// Join returns nil if errs contains no non-nil values.
func Join(errs ...error) error {
n := 0
for _, err := range errs {
if err != nil {
n++
}
}
if n == 0 {
return nil
}
l := make(List, 0, n)
for _, err := range errs {
if err != nil {
l = append(l, err)
}
}
return l
}

func errCtxToFields(errCtx []interface{}) map[string]interface{} {
if len(errCtx) == 0 {
return nil
30 changes: 30 additions & 0 deletions pkg/private/serrors/errors_test.go
Original file line number Diff line number Diff line change
@@ -270,6 +270,36 @@ func TestList(t *testing.T) {
assert.NotNil(t, combinedErr)
}

func TestJoinNil(t *testing.T) {
assert.Nil(t, serrors.Join())
assert.Nil(t, serrors.Join(nil))
assert.Nil(t, serrors.Join(nil, nil))
}

func TestJoin(t *testing.T) {
err1 := serrors.New("err1")
err2 := serrors.New("err2")
for _, test := range []struct {
errs []error
want serrors.List
}{{
errs: []error{err1},
want: serrors.List{err1},
}, {
errs: []error{err1},
want: serrors.List{err1},
}, {
errs: []error{err1, err2},
want: serrors.List{err1, err2},
}, {
errs: []error{err1, nil, err2},
want: serrors.List{err1, err2},
}} {
got := serrors.Join(test.errs...)
assert.Equal(t, got, test.want)
}
}

func TestAtMostOneStacktrace(t *testing.T) {
err := errors.New("core")
for i := range [20]int{} {
6 changes: 3 additions & 3 deletions private/segment/seghandler/storage.go
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@ import (

"github.com/scionproto/scion/pkg/log"
"github.com/scionproto/scion/pkg/private/ctrl/path_mgmt"
"github.com/scionproto/scion/pkg/private/serrors"
seg "github.com/scionproto/scion/pkg/segment"
"github.com/scionproto/scion/private/pathdb"
"github.com/scionproto/scion/private/revcache"
@@ -68,7 +69,6 @@ func (s *DefaultStorage) StoreSegs(ctx context.Context, segs []*seg.Meta) (SegSt
if err != nil {
return SegStats{}, err
}
defer tx.Rollback()
// Sort to prevent sql deadlock.
sort.Slice(segs, func(i, j int) bool {
return segs[i].Segment.GetLoggingID() < segs[j].Segment.GetLoggingID()
@@ -77,7 +77,7 @@ func (s *DefaultStorage) StoreSegs(ctx context.Context, segs []*seg.Meta) (SegSt
for _, seg := range segs {
stats, err := tx.Insert(ctx, seg)
if err != nil {
return SegStats{}, err
return SegStats{}, serrors.Join(err, tx.Rollback())
}
if stats.Inserted > 0 {
segStats.InsertedSegs = append(segStats.InsertedSegs, seg.Segment.GetLoggingID())
@@ -86,7 +86,7 @@ func (s *DefaultStorage) StoreSegs(ctx context.Context, segs []*seg.Meta) (SegSt
}
}
if err := tx.Commit(); err != nil {
return SegStats{}, err
return SegStats{}, serrors.Join(err, tx.Rollback())
}
segStats.Log(ctx)
return segStats, nil
16 changes: 14 additions & 2 deletions private/segment/seghandler/storage_test.go
Original file line number Diff line number Diff line change
@@ -60,7 +60,6 @@ func TestDefaultStorageStoreSegs(t *testing.T) {
pathDB.EXPECT().BeginTransaction(gomock.Any(), gomock.Any()).
Return(tx, nil),
tx.EXPECT().Commit(),
tx.EXPECT().Rollback(),
)
return pathDB
},
@@ -80,6 +79,20 @@ func TestDefaultStorageStoreSegs(t *testing.T) {
},
ErrorAssertion: assert.Error,
},
"Rollback error": {
PathDB: func(ctrl *gomock.Controller) pathdb.DB {
pathDB := mock_pathdb.NewMockDB(ctrl)
tx := mock_pathdb.NewMockTransaction(ctrl)
gomock.InOrder(
pathDB.EXPECT().BeginTransaction(gomock.Any(), gomock.Any()).
Return(tx, nil),
tx.EXPECT().Commit().Return(errors.New("test err")),
tx.EXPECT().Rollback().Return(errors.New("test rollback err")),
)
return pathDB
},
ErrorAssertion: assert.Error,
},
"Stats correct": {
Segs: []*seg.Meta{
{Segment: seg110To130, Type: seg.TypeCore},
@@ -104,7 +117,6 @@ func TestDefaultStorageStoreSegs(t *testing.T) {
},
).Return(pathdb.InsertStats{Inserted: 1}, nil),
tx.EXPECT().Commit(),
tx.EXPECT().Rollback(),
)
return pathDB
},
7 changes: 4 additions & 3 deletions private/storage/db/sqler.go
Original file line number Diff line number Diff line change
@@ -17,6 +17,8 @@ package db
import (
"context"
"database/sql"

"github.com/scionproto/scion/pkg/private/serrors"
)

var _ Sqler = (*sql.DB)(nil)
@@ -42,12 +44,11 @@ func DoInTx(ctx context.Context, db Sqler, action func(context.Context, *sql.Tx)
if tx, err = db.(*sql.DB).BeginTx(ctx, nil); err != nil {
return NewTxError("create tx", err)
}
defer tx.Rollback()
if err := action(ctx, tx); err != nil {
return err
return serrors.Join(err, tx.Rollback())
}
if err := tx.Commit(); err != nil {
return NewTxError("commit", err)
return serrors.Join(NewTxError("commit", err), tx.Rollback())
}
return nil
}

0 comments on commit 013b5cf

Please sign in to comment.