Skip to content

Commit

Permalink
Feat add mysql update after undo log builder (#289)
Browse files Browse the repository at this point in the history
feat: add mysql update after undo log builder
  • Loading branch information
luky116 authored Oct 9, 2022
1 parent 29c7f38 commit c272389
Show file tree
Hide file tree
Showing 19 changed files with 287 additions and 339 deletions.
8 changes: 4 additions & 4 deletions pkg/datasource/sql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func (c *Conn) Exec(query string, args []driver.Value) (driver.Result, error) {
return nil, err
}

execCtx := &exec.ExecContext{
execCtx := &types.ExecContext{
TxCtx: c.txCtx,
Query: query,
Values: args,
Expand Down Expand Up @@ -149,7 +149,7 @@ func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.Name
return nil, err
}

execCtx := &exec.ExecContext{
execCtx := &types.ExecContext{
TxCtx: c.txCtx,
Query: query,
NamedValues: args,
Expand Down Expand Up @@ -186,7 +186,7 @@ func (c *Conn) Query(query string, args []driver.Value) (driver.Rows, error) {
return nil, err
}

execCtx := &exec.ExecContext{
execCtx := &types.ExecContext{
TxCtx: c.txCtx,
Query: query,
Values: args,
Expand Down Expand Up @@ -226,7 +226,7 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.Nam
return nil, err
}

execCtx := &exec.ExecContext{
execCtx := &types.ExecContext{
TxCtx: c.txCtx,
Query: query,
NamedValues: args,
Expand Down
4 changes: 2 additions & 2 deletions pkg/datasource/sql/conn_at_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func TestATConn_ExecContext(t *testing.T) {
tm.SetXID(ctx, uuid.New().String())
t.Logf("set xid=%s", tm.GetXID(ctx))

before := func(_ context.Context, execCtx *exec.ExecContext) {
before := func(_ context.Context, execCtx *types.ExecContext) {
t.Logf("on exec xid=%s", execCtx.TxCtx.XaID)
assert.Equal(t, tm.GetXID(ctx), execCtx.TxCtx.XaID)
assert.Equal(t, types.ATMode, execCtx.TxCtx.TransType)
Expand All @@ -101,7 +101,7 @@ func TestATConn_ExecContext(t *testing.T) {
})

t.Run("not xid", func(t *testing.T) {
mi.before = func(_ context.Context, execCtx *exec.ExecContext) {
mi.before = func(_ context.Context, execCtx *types.ExecContext) {
assert.Equal(t, "", execCtx.TxCtx.XaID)
assert.Equal(t, types.Local, execCtx.TxCtx.TransType)
}
Expand Down
12 changes: 6 additions & 6 deletions pkg/datasource/sql/conn_xa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,24 @@ import (
)

type mockSQLInterceptor struct {
before func(ctx context.Context, execCtx *exec.ExecContext)
after func(ctx context.Context, execCtx *exec.ExecContext)
before func(ctx context.Context, execCtx *types.ExecContext)
after func(ctx context.Context, execCtx *types.ExecContext)
}

func (mi *mockSQLInterceptor) Type() types.SQLType {
return types.SQLTypeUnknown
}

// Before
func (mi *mockSQLInterceptor) Before(ctx context.Context, execCtx *exec.ExecContext) error {
func (mi *mockSQLInterceptor) Before(ctx context.Context, execCtx *types.ExecContext) error {
if mi.before != nil {
mi.before(ctx, execCtx)
}
return nil
}

// After
func (mi *mockSQLInterceptor) After(ctx context.Context, execCtx *exec.ExecContext) error {
func (mi *mockSQLInterceptor) After(ctx context.Context, execCtx *types.ExecContext) error {
if mi.after != nil {
mi.after(ctx, execCtx)
}
Expand Down Expand Up @@ -126,7 +126,7 @@ func TestXAConn_ExecContext(t *testing.T) {
tm.SetXID(ctx, uuid.New().String())
t.Logf("set xid=%s", tm.GetXID(ctx))

before := func(_ context.Context, execCtx *exec.ExecContext) {
before := func(_ context.Context, execCtx *types.ExecContext) {
t.Logf("on exec xid=%s", execCtx.TxCtx.XaID)
assert.Equal(t, tm.GetXID(ctx), execCtx.TxCtx.XaID)
assert.Equal(t, types.XAMode, execCtx.TxCtx.TransType)
Expand All @@ -152,7 +152,7 @@ func TestXAConn_ExecContext(t *testing.T) {
})

t.Run("not xid", func(t *testing.T) {
before := func(_ context.Context, execCtx *exec.ExecContext) {
before := func(_ context.Context, execCtx *types.ExecContext) {
assert.Equal(t, "", execCtx.TxCtx.XaID)
assert.Equal(t, types.Local, execCtx.TxCtx.TransType)
}
Expand Down
113 changes: 105 additions & 8 deletions pkg/datasource/sql/exec/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import (

"github.com/seata/seata-go/pkg/datasource/sql/parser"
"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/seata/seata-go/pkg/datasource/sql/undo"
"github.com/seata/seata-go/pkg/tm"
"github.com/seata/seata-go/pkg/util/log"
)

Expand Down Expand Up @@ -50,9 +52,9 @@ type (
// Interceptors
interceptors(interceptors []SQLHook)
// Exec
ExecWithNamedValue(ctx context.Context, execCtx *ExecContext, f CallbackWithNamedValue) (types.ExecResult, error)
ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithNamedValue) (types.ExecResult, error)
// Exec
ExecWithValue(ctx context.Context, execCtx *ExecContext, f CallbackWithValue) (types.ExecResult, error)
ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithValue) (types.ExecResult, error)
}
)

Expand Down Expand Up @@ -109,39 +111,134 @@ func (e *BaseExecutor) interceptors(interceptors []SQLHook) {
}

// ExecWithNamedValue
func (e *BaseExecutor) ExecWithNamedValue(ctx context.Context, execCtx *ExecContext, f CallbackWithNamedValue) (types.ExecResult, error) {
func (e *BaseExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithNamedValue) (types.ExecResult, error) {
for i := range e.is {
e.is[i].Before(ctx, execCtx)
}

var (
beforeImage *types.RecordImage
afterImage *types.RecordImage
result types.ExecResult
err error
)

beforeImage, err = e.beforeImage(ctx, execCtx)
if err != nil {
return nil, err
}
if beforeImage != nil {
execCtx.TxCtx.RoundImages.AppendBeofreImage(beforeImage)
}

defer func() {
for i := range e.is {
e.is[i].After(ctx, execCtx)
}
}()

if e.ex != nil {
return e.ex.ExecWithNamedValue(ctx, execCtx, f)
result, err = e.ex.ExecWithNamedValue(ctx, execCtx, f)
} else {
result, err = f(ctx, execCtx.Query, execCtx.NamedValues)
}

if err != nil {
return nil, err
}

afterImage, err = e.afterImage(ctx, execCtx, beforeImage)
if err != nil {
return nil, err
}
if afterImage != nil {
execCtx.TxCtx.RoundImages.AppendAfterImage(afterImage)
}

return f(ctx, execCtx.Query, execCtx.NamedValues)
return result, err
}

// ExecWithValue
func (e *BaseExecutor) ExecWithValue(ctx context.Context, execCtx *ExecContext, f CallbackWithValue) (types.ExecResult, error) {
func (e *BaseExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithValue) (types.ExecResult, error) {
for i := range e.is {
e.is[i].Before(ctx, execCtx)
}

var (
beforeImage *types.RecordImage
afterImage *types.RecordImage
result types.ExecResult
err error
)

beforeImage, err = e.beforeImage(ctx, execCtx)
if err != nil {
return nil, err
}
if beforeImage != nil {
execCtx.TxCtx.RoundImages.AppendBeofreImage(beforeImage)
}

defer func() {
for i := range e.is {
e.is[i].After(ctx, execCtx)
}
}()

if e.ex != nil {
return e.ex.ExecWithValue(ctx, execCtx, f)
result, err = e.ex.ExecWithValue(ctx, execCtx, f)
} else {
result, err = f(ctx, execCtx.Query, execCtx.Values)
}
if err != nil {
return nil, err
}

return f(ctx, execCtx.Query, execCtx.Values)
afterImage, err = e.afterImage(ctx, execCtx, beforeImage)
if err != nil {
return nil, err
}
if afterImage != nil {
execCtx.TxCtx.RoundImages.AppendAfterImage(afterImage)
}

return result, err
}

func (h *BaseExecutor) beforeImage(ctx context.Context, execCtx *types.ExecContext) (*types.RecordImage, error) {
if !tm.IsTransactionOpened(ctx) {
return nil, nil
}

pc, err := parser.DoParser(execCtx.Query)
if err != nil {
return nil, err
}
if !pc.HasValidStmt() {
return nil, nil
}

builder := undo.GetUndologBuilder(pc.SQLType)
if builder == nil {
return nil, nil
}
return builder.BeforeImage(ctx, execCtx)
}

func (h *BaseExecutor) afterImage(ctx context.Context, execCtx *types.ExecContext, beforeImage *types.RecordImage) (*types.RecordImage, error) {
if !tm.IsTransactionOpened(ctx) {
return nil, nil
}
pc, err := parser.DoParser(execCtx.Query)
if err != nil {
return nil, err
}
if !pc.HasValidStmt() {
return nil, nil
}
builder := undo.GetUndologBuilder(pc.SQLType)
if builder == nil {
return nil, nil
}
return builder.AfterImage(ctx, execCtx, beforeImage)
}
16 changes: 2 additions & 14 deletions pkg/datasource/sql/exec/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package exec

import (
"context"
"database/sql/driver"

"github.com/seata/seata-go/pkg/datasource/sql/types"
)
Expand Down Expand Up @@ -50,17 +49,6 @@ func RegisHook(hook SQLHook) {
hookSolts[hook.Type()] = append(hookSolts[hook.Type()], hook)
}

// ExecContext
type ExecContext struct {
TxCtx *types.TransactionContext
Query string
NamedValues []driver.NamedValue
Values []driver.Value
// metaData
MetaData types.TableMeta
Conn driver.Conn
}

// SQLHook SQL execution front and back interceptor
// case 1. Used to intercept SQL to achieve the generation of front and rear mirrors
// case 2. Burning point to report
Expand All @@ -69,8 +57,8 @@ type SQLHook interface {
Type() types.SQLType

// Before
Before(ctx context.Context, execCtx *ExecContext) error
Before(ctx context.Context, execCtx *types.ExecContext) error

// After
After(ctx context.Context, execCtx *ExecContext) error
After(ctx context.Context, execCtx *types.ExecContext) error
}
Loading

0 comments on commit c272389

Please sign in to comment.