diff --git a/.golangcilint.yml b/.golangcilint.yml index 05d03ec78..100bb83af 100644 --- a/.golangcilint.yml +++ b/.golangcilint.yml @@ -138,9 +138,7 @@ issues: ^private/mgmtapi/segments/api/api.go$|\ ^private/path/combinator/combinator.go$|\ ^private/revcache/revcachetest/revcachetest.go$|\ - ^private/segment/seghandler/storage.go$|\ ^private/service/statuspages.go$|\ - ^private/storage/db/sqler.go$|\ ^private/storage/trust/fspersister/db_test.go$|\ ^private/svc/internal/ctxconn/ctxconn.go$|\ ^private/trust/db_inspector.go$|\ diff --git a/pkg/private/serrors/errors.go b/pkg/private/serrors/errors.go index ea4ee0fff..33c1c7aa3 100644 --- a/pkg/private/serrors/errors.go +++ b/pkg/private/serrors/errors.go @@ -263,6 +263,28 @@ func (e List) MarshalLogArray(ae zapcore.ArrayEncoder) error { return nil } +// Join returns an error that wraps the given errors in a List error. +// Any nil error values are discarded. +// Join returns nil if errs contains no non-nil values. +func Join(errs ...error) error { + n := 0 + for _, err := range errs { + if err != nil { + n++ + } + } + if n == 0 { + return nil + } + l := make(List, 0, n) + for _, err := range errs { + if err != nil { + l = append(l, err) + } + } + return l +} + func errCtxToFields(errCtx []interface{}) map[string]interface{} { if len(errCtx) == 0 { return nil diff --git a/pkg/private/serrors/errors_test.go b/pkg/private/serrors/errors_test.go index 0deb23aae..38a261070 100644 --- a/pkg/private/serrors/errors_test.go +++ b/pkg/private/serrors/errors_test.go @@ -270,6 +270,36 @@ func TestList(t *testing.T) { assert.NotNil(t, combinedErr) } +func TestJoinNil(t *testing.T) { + assert.Nil(t, serrors.Join()) + assert.Nil(t, serrors.Join(nil)) + assert.Nil(t, serrors.Join(nil, nil)) +} + +func TestJoin(t *testing.T) { + err1 := serrors.New("err1") + err2 := serrors.New("err2") + for _, test := range []struct { + errs []error + want serrors.List + }{{ + errs: []error{err1}, + want: serrors.List{err1}, + }, { + errs: []error{err1}, + want: serrors.List{err1}, + }, { + errs: []error{err1, err2}, + want: serrors.List{err1, err2}, + }, { + errs: []error{err1, nil, err2}, + want: serrors.List{err1, err2}, + }} { + got := serrors.Join(test.errs...) + assert.Equal(t, got, test.want) + } +} + func TestAtMostOneStacktrace(t *testing.T) { err := errors.New("core") for i := range [20]int{} { diff --git a/private/segment/seghandler/storage.go b/private/segment/seghandler/storage.go index ddc693f55..33e16ca54 100644 --- a/private/segment/seghandler/storage.go +++ b/private/segment/seghandler/storage.go @@ -20,6 +20,7 @@ import ( "github.com/scionproto/scion/pkg/log" "github.com/scionproto/scion/pkg/private/ctrl/path_mgmt" + "github.com/scionproto/scion/pkg/private/serrors" seg "github.com/scionproto/scion/pkg/segment" "github.com/scionproto/scion/private/pathdb" "github.com/scionproto/scion/private/revcache" @@ -68,7 +69,6 @@ func (s *DefaultStorage) StoreSegs(ctx context.Context, segs []*seg.Meta) (SegSt if err != nil { return SegStats{}, err } - defer tx.Rollback() // Sort to prevent sql deadlock. sort.Slice(segs, func(i, j int) bool { return segs[i].Segment.GetLoggingID() < segs[j].Segment.GetLoggingID() @@ -77,7 +77,7 @@ func (s *DefaultStorage) StoreSegs(ctx context.Context, segs []*seg.Meta) (SegSt for _, seg := range segs { stats, err := tx.Insert(ctx, seg) if err != nil { - return SegStats{}, err + return SegStats{}, serrors.Join(err, tx.Rollback()) } if stats.Inserted > 0 { segStats.InsertedSegs = append(segStats.InsertedSegs, seg.Segment.GetLoggingID()) @@ -86,7 +86,7 @@ func (s *DefaultStorage) StoreSegs(ctx context.Context, segs []*seg.Meta) (SegSt } } if err := tx.Commit(); err != nil { - return SegStats{}, err + return SegStats{}, serrors.Join(err, tx.Rollback()) } segStats.Log(ctx) return segStats, nil diff --git a/private/segment/seghandler/storage_test.go b/private/segment/seghandler/storage_test.go index 9993ada9b..8c4a27672 100644 --- a/private/segment/seghandler/storage_test.go +++ b/private/segment/seghandler/storage_test.go @@ -60,7 +60,6 @@ func TestDefaultStorageStoreSegs(t *testing.T) { pathDB.EXPECT().BeginTransaction(gomock.Any(), gomock.Any()). Return(tx, nil), tx.EXPECT().Commit(), - tx.EXPECT().Rollback(), ) return pathDB }, @@ -80,6 +79,20 @@ func TestDefaultStorageStoreSegs(t *testing.T) { }, ErrorAssertion: assert.Error, }, + "Rollback error": { + PathDB: func(ctrl *gomock.Controller) pathdb.DB { + pathDB := mock_pathdb.NewMockDB(ctrl) + tx := mock_pathdb.NewMockTransaction(ctrl) + gomock.InOrder( + pathDB.EXPECT().BeginTransaction(gomock.Any(), gomock.Any()). + Return(tx, nil), + tx.EXPECT().Commit().Return(errors.New("test err")), + tx.EXPECT().Rollback().Return(errors.New("test rollback err")), + ) + return pathDB + }, + ErrorAssertion: assert.Error, + }, "Stats correct": { Segs: []*seg.Meta{ {Segment: seg110To130, Type: seg.TypeCore}, @@ -104,7 +117,6 @@ func TestDefaultStorageStoreSegs(t *testing.T) { }, ).Return(pathdb.InsertStats{Inserted: 1}, nil), tx.EXPECT().Commit(), - tx.EXPECT().Rollback(), ) return pathDB }, diff --git a/private/storage/db/sqler.go b/private/storage/db/sqler.go index 1ed1fd549..946474a13 100644 --- a/private/storage/db/sqler.go +++ b/private/storage/db/sqler.go @@ -17,6 +17,8 @@ package db import ( "context" "database/sql" + + "github.com/scionproto/scion/pkg/private/serrors" ) var _ Sqler = (*sql.DB)(nil) @@ -42,12 +44,11 @@ func DoInTx(ctx context.Context, db Sqler, action func(context.Context, *sql.Tx) if tx, err = db.(*sql.DB).BeginTx(ctx, nil); err != nil { return NewTxError("create tx", err) } - defer tx.Rollback() if err := action(ctx, tx); err != nil { - return err + return serrors.Join(err, tx.Rollback()) } if err := tx.Commit(); err != nil { - return NewTxError("commit", err) + return serrors.Join(NewTxError("commit", err), tx.Rollback()) } return nil }