Skip to content
This repository has been archived by the owner on May 21, 2024. It is now read-only.

Commit

Permalink
support multiple error for database transaction
Browse files Browse the repository at this point in the history
  • Loading branch information
imantung committed Apr 25, 2021
1 parent 40a4650 commit f851fc6
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 67 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func (r *RepoImpl) Delete(ctx context.Context) (int64, error) {
db := txn // transaction object or database connection
// result, err := ...
if err != nil {
txn.SetError(err) // set the error and plan for rollback
txn.AppendError(err) // append error to plan for rollback
return -1, err
}
// ...
Expand Down
18 changes: 9 additions & 9 deletions internal/generated/dbrepo/book_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func (r *BookRepoImpl) Insert(ctx context.Context, ent *entity.Book) (int64, err

var id int64
if err := scanner.Scan(&id); err != nil {
txn.SetError(err)
txn.AppendError(err)
return -1, err
}
return id, nil
Expand Down Expand Up @@ -195,11 +195,11 @@ func (r *BookRepoImpl) BulkInsert(ctx context.Context, ents ...*entity.Book) (in

res, err := builder.ExecContext(ctx)
if err != nil {
txn.SetError(err)
txn.AppendError(err)
return -1, err
}
affectedRow, err := res.RowsAffected()
txn.SetError(err)
txn.AppendError(err)
return affectedRow, err
}

Expand All @@ -224,11 +224,11 @@ func (r *BookRepoImpl) Update(ctx context.Context, ent *entity.Book, opt sqkit.U

res, err := builder.ExecContext(ctx)
if err != nil {
txn.SetError(err)
txn.AppendError(err)
return -1, err
}
affectedRow, err := res.RowsAffected()
txn.SetError(err)
txn.AppendError(err)
return affectedRow, err
}

Expand Down Expand Up @@ -258,12 +258,12 @@ func (r *BookRepoImpl) Patch(ctx context.Context, ent *entity.Book, opt sqkit.Up

res, err := builder.ExecContext(ctx)
if err != nil {
txn.SetError(err)
txn.AppendError(err)
return -1, err
}

affectedRow, err := res.RowsAffected()
txn.SetError(err)
txn.AppendError(err)
return affectedRow, err
}

Expand All @@ -285,11 +285,11 @@ func (r *BookRepoImpl) Delete(ctx context.Context, opt sqkit.DeleteOption) (int6

res, err := builder.ExecContext(ctx)
if err != nil {
txn.SetError(err)
txn.AppendError(err)
return -1, err
}

affectedRow, err := res.RowsAffected()
txn.SetError(err)
txn.AppendError(err)
return affectedRow, err
}
20 changes: 10 additions & 10 deletions internal/generated/dbrepo/song_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,12 @@ func (r *SongRepoImpl) BulkInsert(ctx context.Context, ents ...*entity.Song) (in

res, err := builder.RunWith(txn).ExecContext(ctx)
if err != nil {
txn.SetError(err)
txn.AppendError(err)
return -1, err
}

affectedRow, err := res.RowsAffected()
txn.SetError(err)
txn.AppendError(err)
return affectedRow, err
}

Expand Down Expand Up @@ -189,12 +189,12 @@ func (r *SongRepoImpl) Insert(ctx context.Context, ent *entity.Song) (int64, err

res, err := builder.RunWith(txn).ExecContext(ctx)
if err != nil {
txn.SetError(err)
txn.AppendError(err)
return -1, err
}

lastInsertID, err := res.LastInsertId()
txn.SetError(err)
txn.AppendError(err)
return lastInsertID, err
}

Expand All @@ -218,11 +218,11 @@ func (r *SongRepoImpl) Update(ctx context.Context, ent *entity.Song, opt sqkit.U

res, err := builder.ExecContext(ctx)
if err != nil {
txn.SetError(err)
txn.AppendError(err)
return -1, err
}
affectedRow, err := res.RowsAffected()
txn.SetError(err)
txn.AppendError(err)
return affectedRow, err
}

Expand All @@ -249,12 +249,12 @@ func (r *SongRepoImpl) Patch(ctx context.Context, ent *entity.Song, opt sqkit.Up

res, err := builder.ExecContext(ctx)
if err != nil {
txn.SetError(err)
txn.AppendError(err)
return -1, err
}

affectedRow, err := res.RowsAffected()
txn.SetError(err)
txn.AppendError(err)
return affectedRow, err
}

Expand All @@ -272,11 +272,11 @@ func (r *SongRepoImpl) Delete(ctx context.Context, opt sqkit.DeleteOption) (int6

res, err := builder.ExecContext(ctx)
if err != nil {
txn.SetError(err)
txn.AppendError(err)
return -1, err
}

affectedRow, err := res.RowsAffected()
txn.SetError(err)
txn.AppendError(err)
return affectedRow, err
}
21 changes: 10 additions & 11 deletions pkg/dbtxn/dbtxn.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"database/sql"
"errors"
"fmt"

sq "github.com/Masterminds/squirrel"
"github.com/typical-go/typical-go/pkg/errkit"
Expand All @@ -18,7 +17,7 @@ type (
// Context of transaction
Context struct {
TxMap map[*sql.DB]Tx
Err error
Errs errkit.Errors
}
// CommitFn is commit function to close the transaction
CommitFn func() error
Expand Down Expand Up @@ -78,7 +77,7 @@ func Find(ctx context.Context) *Context {
// Error of transaction
func Error(ctx context.Context) error {
if c := Find(ctx); c != nil {
return c.Err
return c.Errs.Unwrap()
}
return nil
}
Expand All @@ -96,8 +95,8 @@ func (c *Context) Begin(ctx context.Context, db *sql.DB) (sq.StdSqlCtx, error) {

tx, err := db.BeginTx(ctx, nil)
if err != nil {
c.Err = fmt.Errorf("dbtxn: %w", err)
return nil, c.Err
c.AppendError(err)
return nil, err
}
c.TxMap[db] = tx
return tx, nil
Expand All @@ -106,23 +105,23 @@ func (c *Context) Begin(ctx context.Context, db *sql.DB) (sq.StdSqlCtx, error) {
// Commit if no error
func (c *Context) Commit() error {
var errs errkit.Errors
if c.Err != nil {
if len(c.Errs) > 0 {
for _, tx := range c.TxMap {
errs.Append(tx.Rollback())
errs = append(errs, tx.Rollback())
}
} else {
for _, tx := range c.TxMap {
errs.Append(tx.Commit())
errs = append(errs, tx.Commit())
}
}

return errs.Unwrap()
}

// SetError to set error to txn context
func (c *Context) SetError(err error) bool {
// AppendError to append error to txn context
func (c *Context) AppendError(err error) bool {
if c != nil && err != nil {
c.Err = err
c.Errs = append(c.Errs, err)
return true
}
return false
Expand Down
58 changes: 41 additions & 17 deletions pkg/dbtxn/dbtxn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func TestUse(t *testing.T) {
return db
}(),
Ctx: context.WithValue(context.Background(), dbtxn.ContextKey, &dbtxn.Context{}),
ExpectedErr: "dbtxn: begin-error",
ExpectedErr: "begin-error",
},
}
for _, tt := range testcases {
Expand All @@ -84,15 +84,31 @@ func TestUse(t *testing.T) {
}

func TestUse_success(t *testing.T) {
db, mock, _ := sqlmock.New()
mock.ExpectBegin()

ctx := context.WithValue(context.Background(), dbtxn.ContextKey, &dbtxn.Context{TxMap: make(map[*sql.DB]dbtxn.Tx)})
handler, err := dbtxn.Use(ctx, db)
db, mock, _ := sqlmock.New()

var handler *dbtxn.UseHandler
var err error
t.Run("trigger begin transaction when no transaction object", func(t *testing.T) {

mock.ExpectBegin()

handler, err = dbtxn.Use(ctx, db)

require.NoError(t, err)
require.Equal(t, map[*sql.DB]dbtxn.Tx{
db: handler.StdSqlCtx.(dbtxn.Tx),
}, handler.Context.TxMap)
})

t.Run("using available transaction", func(t *testing.T) {
handler2, err := dbtxn.Use(ctx, db)

require.NoError(t, err)
require.Equal(t, handler, handler2)
})

require.NoError(t, err)
require.Equal(t, map[*sql.DB]dbtxn.Tx{
db: handler.StdSqlCtx.(dbtxn.Tx),
}, handler.Context.TxMap)
}

func TestContext_Commit(t *testing.T) {
Expand All @@ -103,7 +119,7 @@ func TestContext_Commit(t *testing.T) {

c := dbtxn.NewContext()
c.Begin(context.Background(), db)
c.SetError(errors.New("some-error"))
c.AppendError(errors.New("some-error"))

require.NoError(t, c.Commit())
})
Expand All @@ -118,15 +134,23 @@ func TestContext_Commit(t *testing.T) {
})
}

func TestSetError(t *testing.T) {
func TestAppendError(t *testing.T) {
ctx := context.Background()
dbtxn.Begin(&ctx)
t.Run("no txn error before begin", func(t *testing.T) {
require.Nil(t, dbtxn.Error(ctx))
})

db, mock, _ := sqlmock.New()
mock.ExpectBegin()
handler, err := dbtxn.Use(ctx, db)
require.NoError(t, err)
t.Run("append multiple error", func(t *testing.T) {
dbtxn.Begin(&ctx)

handler.SetError(errors.New("some-error"))
require.EqualError(t, dbtxn.Error(ctx), "some-error")
db, mock, _ := sqlmock.New()
mock.ExpectBegin()
handler, err := dbtxn.Use(ctx, db)
require.NoError(t, err)

require.True(t, handler.AppendError(errors.New("some-error-1")))
require.False(t, handler.AppendError(nil))
require.True(t, handler.AppendError(errors.New("some-error-2")))
require.EqualError(t, dbtxn.Error(ctx), "some-error-1; some-error-2")
})
}
20 changes: 10 additions & 10 deletions pkg/typdb/mysql_template.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,12 @@ func (r *{{.Name}}RepoImpl) BulkInsert(ctx context.Context, ents ...*{{.SourcePk
res, err := builder.RunWith(txn).ExecContext(ctx)
if err != nil {
txn.SetError(err)
txn.AppendError(err)
return -1, err
}
affectedRow, err := res.RowsAffected()
txn.SetError(err)
txn.AppendError(err)
return affectedRow, err
}
Expand All @@ -153,12 +153,12 @@ func (r *{{.Name}}RepoImpl) Insert(ctx context.Context, ent *{{.SourcePkg}}.{{.N
res, err := builder.RunWith(txn).ExecContext(ctx)
if err != nil {
txn.SetError(err)
txn.AppendError(err)
return -1, err
}
lastInsertID, err := res.LastInsertId()
txn.SetError(err)
txn.AppendError(err)
return lastInsertID, err
}
Expand All @@ -181,11 +181,11 @@ func (r *{{.Name}}RepoImpl) Update(ctx context.Context, ent *{{.SourcePkg}}.{{.N
res, err := builder.ExecContext(ctx)
if err != nil {
txn.SetError(err)
txn.AppendError(err)
return -1, err
}
affectedRow, err := res.RowsAffected()
txn.SetError(err)
txn.AppendError(err)
return affectedRow, err
}
Expand All @@ -209,12 +209,12 @@ func (r *{{.Name}}RepoImpl) Patch(ctx context.Context, ent *{{.SourcePkg}}.{{.Na
res, err := builder.ExecContext(ctx)
if err != nil {
txn.SetError(err)
txn.AppendError(err)
return -1, err
}
affectedRow, err := res.RowsAffected()
txn.SetError(err)
txn.AppendError(err)
return affectedRow, err
}
Expand All @@ -233,12 +233,12 @@ func (r *{{.Name}}RepoImpl) Delete(ctx context.Context, opt sqkit.DeleteOption)
res, err := builder.ExecContext(ctx)
if err != nil {
txn.SetError(err)
txn.AppendError(err)
return -1, err
}
affectedRow, err := res.RowsAffected()
txn.SetError(err)
txn.AppendError(err)
return affectedRow, err
}
`
Loading

0 comments on commit f851fc6

Please sign in to comment.