Skip to content

Commit

Permalink
Refactor seata conn (apache#295)
Browse files Browse the repository at this point in the history
* refactor:split xa and at logic

* refactor:split xa and at logic

* refactor:split xa and at logic
  • Loading branch information
chuntaojun authored Oct 23, 2022
1 parent b20f33b commit 8fbdd3f
Show file tree
Hide file tree
Showing 27 changed files with 1,208 additions and 528 deletions.
51 changes: 29 additions & 22 deletions pkg/datasource/sql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import (
// Conn is assumed to be stateful.

type Conn struct {
txType types.TransactionType
res *DBResource
txCtx *types.TransactionContext
targetConn driver.Conn
Expand All @@ -47,8 +46,8 @@ func (c *Conn) ResetSession(ctx context.Context) error {
return driver.ErrSkip
}

c.txType = types.Local
c.txCtx = nil
c.autoCommit = true
c.txCtx = types.NewTxCtx()
return conn.ResetSession(ctx)
}

Expand Down Expand Up @@ -221,26 +220,29 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.Nam
return c.Query(query, values)
}

executor, err := exec.BuildExecutor(c.res.dbType, c.txCtx.TransType, query)
if err != nil {
return nil, err
}
ret, err := c.createNewTxOnExecIfNeed(func() (types.ExecResult, error) {
executor, err := exec.BuildExecutor(c.res.dbType, c.txCtx.TransType, query)
if err != nil {
return nil, err
}

execCtx := &types.ExecContext{
TxCtx: c.txCtx,
Query: query,
NamedValues: args,
}
execCtx := &types.ExecContext{
TxCtx: c.txCtx,
Query: query,
NamedValues: args,
}

ret, err := executor.ExecWithNamedValue(ctx, execCtx,
func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) {
ret, err := conn.QueryContext(ctx, query, args)
if err != nil {
return nil, err
}
return executor.ExecWithNamedValue(ctx, execCtx,
func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) {
ret, err := conn.QueryContext(ctx, query, args)
if err != nil {
return nil, err
}

return types.NewResult(types.WithRows(ret)), nil
})
})

return types.NewResult(types.WithRows(ret)), nil
})
if err != nil {
return nil, err
}
Expand All @@ -252,6 +254,8 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.Nam
//
// Deprecated: Drivers should implement ConnBeginTx instead (or additionally).
func (c *Conn) Begin() (driver.Tx, error) {
c.autoCommit = false

tx, err := c.targetConn.Begin()
if err != nil {
return nil, err
Expand All @@ -271,8 +275,11 @@ func (c *Conn) Begin() (driver.Tx, error) {
}

// BeginTx Open a transaction and judge whether the current transaction needs to open a
// global transaction according to ctx. If so, it needs to be included in the transaction management of seata
//
// global transaction according to ctx. If so, it needs to be included in the transaction management of seata
func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
c.autoCommit = false

if conn, ok := c.targetConn.(driver.ConnBeginTx); ok {
tx, err := conn.BeginTx(ctx, opts)
if err != nil {
Expand Down Expand Up @@ -309,7 +316,7 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
// Drivers must ensure all network calls made by Close
// do not block indefinitely (e.g. apply a timeout).
func (c *Conn) Close() error {
c.txCtx = nil
c.txCtx = types.NewTxCtx()
return c.targetConn.Close()
}

Expand Down
45 changes: 29 additions & 16 deletions pkg/datasource/sql/conn_at.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,38 @@ import (
"github.com/seata/seata-go/pkg/tm"
)

// ATConn Database connection proxy object under XA transaction model
// Conn is assumed to be stateful.
type ATConn struct {
*Conn
}

func (c *ATConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
if c.createTxCtxIfAbsent(ctx) {
if c.createOnceTxContext(ctx) {
defer func() {
c.txCtx = nil
c.txCtx = types.NewTxCtx()
}()
}

return c.Conn.PrepareContext(ctx, query)
}

// QueryContext
func (c *ATConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
if c.createOnceTxContext(ctx) {
defer func() {
c.txCtx = types.NewTxCtx()
}()
}

return c.Conn.QueryContext(ctx, query, args)
}

// ExecContext
func (c *ATConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
if c.createTxCtxIfAbsent(ctx) {
if c.createOnceTxContext(ctx) {
defer func() {
c.txCtx = nil
c.txCtx = types.NewTxCtx()
}()
}

Expand All @@ -52,33 +65,33 @@ func (c *ATConn) ExecContext(ctx context.Context, query string, args []driver.Na

// BeginTx
func (c *ATConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
c.autoCommit = false

c.txCtx = types.NewTxCtx()
c.txCtx.DBType = c.res.dbType
c.txCtx.TxOpt = opts

if IsGlobalTx(ctx) {
c.txCtx.XaID = tm.GetXID(ctx)
c.txCtx.TransType = c.txType
c.txCtx.TransType = types.ATMode
}

tx, err := c.Conn.BeginTx(ctx, opts)
if err != nil {
return nil, err
}

return c.Conn.BeginTx(ctx, opts)
return &ATTx{tx: tx.(*Tx)}, nil
}

func (c *ATConn) createTxCtxIfAbsent(ctx context.Context) bool {
var onceTx bool
func (c *ATConn) createOnceTxContext(ctx context.Context) bool {
onceTx := IsGlobalTx(ctx) && c.autoCommit

if IsGlobalTx(ctx) && c.txCtx == nil {
if onceTx {
c.txCtx = types.NewTxCtx()
c.txCtx.DBType = c.res.dbType
c.txCtx.XaID = tm.GetXID(ctx)
c.txCtx.TransType = types.ATMode
c.autoCommit = true
onceTx = true
}

if c.txCtx == nil {
c.txCtx = types.NewTxCtx()
onceTx = true
}

return onceTx
Expand Down
114 changes: 106 additions & 8 deletions pkg/datasource/sql/conn_at_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package sql
import (
"context"
"database/sql"
"database/sql/driver"
"sync/atomic"
"testing"

Expand All @@ -32,9 +33,8 @@ import (
"github.com/stretchr/testify/assert"
)

func TestATConn_ExecContext(t *testing.T) {
func initAtConnTestResource(t *testing.T) (*gomock.Controller, *sql.DB, *mockSQLInterceptor, *mockTxHook) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockMgr := initMockResourceManager(t, ctrl)
_ = mockMgr
Expand All @@ -44,9 +44,7 @@ func TestATConn_ExecContext(t *testing.T) {
t.Fatal(err)
}

defer db.Close()

_ = initMockAtConnector(t, ctrl, db, func(t *testing.T, ctrl *gomock.Controller) *mock.MockTestDriverConnector {
_ = initMockAtConnector(t, ctrl, db, func(t *testing.T, ctrl *gomock.Controller) driver.Connector {
mockTx := mock.NewMockTestDriverTx(ctrl)
mockTx.EXPECT().Commit().AnyTimes().Return(nil)
mockTx.EXPECT().Rollback().AnyTimes().Return(nil)
Expand All @@ -67,20 +65,31 @@ func TestATConn_ExecContext(t *testing.T) {

exec.CleanCommonHook()
CleanTxHooks()
exec.RegisCommonHook(mi)
exec.RegisterCommonHook(mi)
RegisterTxHook(ti)

return ctrl, db, mi, ti
}

func TestATConn_ExecContext(t *testing.T) {
ctrl, db, mi, ti := initAtConnTestResource(t)
defer func() {
ctrl.Finish()
db.Close()
CleanTxHooks()
}()

t.Run("have xid", func(t *testing.T) {
ctx := tm.InitSeataContext(context.Background())
tm.SetXID(ctx, uuid.New().String())
t.Logf("set xid=%s", tm.GetXID(ctx))

before := func(_ context.Context, execCtx *types.ExecContext) {
beforeHook := func(_ context.Context, execCtx *types.ExecContext) {
t.Logf("on exec xid=%s", execCtx.TxCtx.XaID)
assert.Equal(t, tm.GetXID(ctx), execCtx.TxCtx.XaID)
assert.Equal(t, types.ATMode, execCtx.TxCtx.TransType)
}
mi.before = before
mi.before = beforeHook

var comitCnt int32
beforeCommit := func(tx *Tx) {
Expand Down Expand Up @@ -125,3 +134,92 @@ func TestATConn_ExecContext(t *testing.T) {
assert.Equal(t, int32(0), atomic.LoadInt32(&comitCnt))
})
}

func TestATConn_BeginTx(t *testing.T) {
ctrl, db, mi, ti := initAtConnTestResource(t)
defer func() {
ctrl.Finish()
db.Close()
CleanTxHooks()
}()

t.Run("tx-local", func(t *testing.T) {
tx, err := db.Begin()
assert.NoError(t, err)

mi.before = func(_ context.Context, execCtx *types.ExecContext) {
assert.Equal(t, "", execCtx.TxCtx.XaID)
assert.Equal(t, types.Local, execCtx.TxCtx.TransType)
}

var comitCnt int32
ti.beforeCommit = func(tx *Tx) {
atomic.AddInt32(&comitCnt, 1)
}

_, err = tx.ExecContext(context.Background(), "SELECT * FROM user")
assert.NoError(t, err)

_, err = tx.ExecContext(tm.InitSeataContext(context.Background()), "SELECT * FROM user")
assert.NoError(t, err)

err = tx.Commit()
assert.NoError(t, err)

assert.Equal(t, int32(1), atomic.LoadInt32(&comitCnt))
})

t.Run("tx-local-context", func(t *testing.T) {
tx, err := db.BeginTx(context.Background(), &sql.TxOptions{})
assert.NoError(t, err)

mi.before = func(_ context.Context, execCtx *types.ExecContext) {
assert.Equal(t, "", execCtx.TxCtx.XaID)
assert.Equal(t, types.Local, execCtx.TxCtx.TransType)
}

var comitCnt int32
ti.beforeCommit = func(tx *Tx) {
atomic.AddInt32(&comitCnt, 1)
}

_, err = tx.ExecContext(context.Background(), "SELECT * FROM user")
assert.NoError(t, err)

_, err = tx.ExecContext(tm.InitSeataContext(context.Background()), "SELECT * FROM user")
assert.NoError(t, err)

err = tx.Commit()
assert.NoError(t, err)

assert.Equal(t, int32(1), atomic.LoadInt32(&comitCnt))
})

t.Run("tx-at-context", func(t *testing.T) {
ctx := tm.InitSeataContext(context.Background())
tm.SetXID(ctx, uuid.NewString())
tx, err := db.BeginTx(ctx, &sql.TxOptions{})
assert.NoError(t, err)

mi.before = func(_ context.Context, execCtx *types.ExecContext) {
assert.Equal(t, tm.GetXID(ctx), execCtx.TxCtx.XaID)
assert.Equal(t, types.ATMode, execCtx.TxCtx.TransType)
}

var comitCnt int32
ti.beforeCommit = func(tx *Tx) {
atomic.AddInt32(&comitCnt, 1)
}

_, err = tx.ExecContext(context.Background(), "SELECT * FROM user")
assert.NoError(t, err)

_, err = tx.ExecContext(context.Background(), "SELECT * FROM user")
assert.NoError(t, err)

err = tx.Commit()
assert.NoError(t, err)

assert.Equal(t, int32(1), atomic.LoadInt32(&comitCnt))
})
}
Loading

0 comments on commit 8fbdd3f

Please sign in to comment.