Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

init xa trx log #835

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 42 additions & 12 deletions pkg/runtime/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,29 @@ const (
_flagWrite
)

// TxState Transaction status
type TxState int64

const (
_ TxState = iota
TrxStarted // CompositeTx Default state
TrxPreparing // All SQL statements are executed, and before the Commit statement executes
TrxPrepared // All SQL statements are executed, and before the Commit statement executes
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这两个comment note是不是要改一下

TrxCommitting // After preparing is completed, ready to start execution
TrxCommitted // Officially complete the Commit action
TrxRolledBacking
TrxRolledBacked
TrxAborted
TrxUnknown // Unknown transaction
)

type (
keyFlag struct{}
keyNodeLabel struct{}
keyDefaultDBGroup struct{}
keyHints struct{}
keyTransactionID struct{}
keyFlag struct{}
keyNodeLabel struct{}
keyDefaultDBGroup struct{}
keyHints struct{}
keyTransactionID struct{}
keyTransactionStatus struct{}
)

type cFlag uint8
Expand Down Expand Up @@ -75,7 +92,7 @@ func WithHints(ctx context.Context, hints []*hint.Hint) context.Context {

// Tenant extracts the tenant.
func Tenant(ctx context.Context) string {
return isString(ctx, proto.ContextKeyTenant{})
return getString(ctx, proto.ContextKeyTenant{})
}

// IsRead returns true if this is a read operation
Expand All @@ -95,25 +112,29 @@ func IsDirect(ctx context.Context) bool {

// SQL returns the original sql string.
func SQL(ctx context.Context) string {
return isString(ctx, proto.ContextKeySQL{})
return getString(ctx, proto.ContextKeySQL{})
}

func Schema(ctx context.Context) string {
return isString(ctx, proto.ContextKeySchema{})
return getString(ctx, proto.ContextKeySchema{})
}

func Version(ctx context.Context) string {
return isString(ctx, proto.ContextKeyServerVersion{})
return getString(ctx, proto.ContextKeyServerVersion{})
}

// NodeLabel returns the label of node.
func NodeLabel(ctx context.Context) string {
return isString(ctx, keyNodeLabel{})
return getString(ctx, keyNodeLabel{})
}

// TransactionID returns the transactions id
func TransactionID(ctx context.Context) string {
return isString(ctx, keyTransactionID{})
return getString(ctx, keyTransactionID{})
}

func TransactionStatus(ctx context.Context) TxState {
return getTxStatus(ctx, keyTransactionStatus{})
}

// Hints extracts the hints.
Expand Down Expand Up @@ -144,9 +165,18 @@ func getFlag(ctx context.Context) cFlag {
return f
}

func isString(ctx context.Context, v any) string {
func getString(ctx context.Context, v any) string {
if data, ok := ctx.Value(v).(string); ok {
return data
}
return ""
}

func getTxStatus(ctx context.Context, v any) TxState {
if data, ok := ctx.Value(v).(int32); ok {
if data >= int32(TrxStarted) && data <= int32(TrxAborted) {
return TxState(data)
}
}
return TrxUnknown
}
8 changes: 4 additions & 4 deletions pkg/runtime/transaction/fault_decision.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ type TxFaultDecisionExecutor struct {
func (bm *TxFaultDecisionExecutor) Run() {
}

func (bm *TxFaultDecisionExecutor) scanUnFinishTxLog() ([]TrxLog, error) {
func (bm *TxFaultDecisionExecutor) scanUnFinishTxLog() ([]GlobalTrxLog, error) {
return nil, nil
}

func (bm *TxFaultDecisionExecutor) handlePreparing(tx TrxLog) {
func (bm *TxFaultDecisionExecutor) handlePreparing(tx GlobalTrxLog) {
}

func (bm *TxFaultDecisionExecutor) handleCommitting(tx TrxLog) {
func (bm *TxFaultDecisionExecutor) handleCommitting(tx GlobalTrxLog) {
}

func (bm *TxFaultDecisionExecutor) handleAborting(tx TrxLog) {
func (bm *TxFaultDecisionExecutor) handleAborting(tx GlobalTrxLog) {
}
63 changes: 34 additions & 29 deletions pkg/runtime/transaction/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package transaction

import (
"context"
rcontext "github.com/arana-db/arana/pkg/runtime/context"
)

import (
Expand All @@ -40,19 +41,19 @@ func NewXAHook(tenant string, enable bool) (*xaHook, error) {
enable: enable,
}

trxStateChangeFunc := map[runtime.TxState]handleFunc{
runtime.TrxActive: xh.onActive,
runtime.TrxPreparing: xh.onPreparing,
runtime.TrxPrepared: xh.onPrepared,
runtime.TrxCommitting: xh.onCommitting,
runtime.TrxCommitted: xh.onCommitted,
runtime.TrxAborting: xh.onAborting,
runtime.TrxRollback: xh.onRollbackOnly,
runtime.TrxRolledBack: xh.onRolledBack,
trxStateChangeFunc := map[rcontext.TxState]handleFunc{
rcontext.TrxStarted: xh.onStarted,
rcontext.TrxPreparing: xh.onPreparing,
rcontext.TrxPrepared: xh.onPrepared,
rcontext.TrxCommitting: xh.onCommitting,
rcontext.TrxCommitted: xh.onCommitted,
rcontext.TrxAborted: xh.onAborting,
rcontext.TrxRolledBacking: xh.onRollbackOnly,
rcontext.TrxRolledBacked: xh.onRolledBack,
}

xh.trxMgr = trxMgr
xh.trxLog = &TrxLog{}
xh.trxLog = &GlobalTrxLog{}
xh.trxStateChangeFunc = trxStateChangeFunc

return xh, nil
Expand All @@ -63,15 +64,15 @@ func NewXAHook(tenant string, enable bool) (*xaHook, error) {
type xaHook struct {
enable bool
trxMgr *TrxManager
trxLog *TrxLog
trxStateChangeFunc map[runtime.TxState]handleFunc
trxLog *GlobalTrxLog
trxStateChangeFunc map[rcontext.TxState]handleFunc
}

func (xh *xaHook) OnTxStateChange(ctx context.Context, state runtime.TxState, tx runtime.CompositeTx) error {
func (xh *xaHook) OnTxStateChange(ctx context.Context, state rcontext.TxState, tx runtime.CompositeTx) error {
if !xh.enable {
return nil
}
xh.trxLog.State = state
xh.trxLog.Status = state
handle, ok := xh.trxStateChangeFunc[state]
if ok {
return handle(ctx, tx)
Expand All @@ -84,33 +85,37 @@ func (xh *xaHook) OnCreateBranchTx(ctx context.Context, tx runtime.BranchTx) {
if !xh.enable {
return
}
xh.trxLog.Participants = append(xh.trxLog.Participants, TrxParticipant{
NodeID: "",
RemoteAddr: tx.GetConn().GetDatabaseConn().GetNetConn().RemoteAddr().String(),
Schema: tx.GetConn().DBName(),
})
// TODO: add branch trx log
//xh.trxLog.BranchTrxLogs = append(xh.trxLog.BranchTrxLogs, BranchTrxLog{
// NodeID: "",
// RemoteAddr: tx.GetConn().GetDatabaseConn().GetNetConn().RemoteAddr().String(),
// Schema: tx.GetConn().DBName(),
//})
}

func (xh *xaHook) onActive(ctx context.Context, tx runtime.CompositeTx) error {
func (xh *xaHook) onStarted(ctx context.Context, tx runtime.CompositeTx) error {
tx.SetBeginFunc(StartXA)
xh.trxLog.TrxID = tx.GetTrxID()
xh.trxLog.State = tx.GetTxState()
xh.trxLog.Status = tx.GetTxState()
xh.trxLog.Tenant = tx.GetTenant()
xh.trxLog.StartTime = tx.GetStartTime()
xh.trxLog.ExpectedEndTime = tx.GetExpectedEndTime()

return nil
}

func (xh *xaHook) onPreparing(ctx context.Context, tx runtime.CompositeTx) error {
tx.Range(func(tx runtime.BranchTx) {
tx.SetPrepareFunc(PrepareXA)
})
if err := xh.trxMgr.trxLog.AddOrUpdateTxLog(*xh.trxLog); err != nil {
if err := xh.trxMgr.trxLog.AddOrUpdateGlobalTxLog(*xh.trxLog); err != nil {
return err
}
return nil
}

func (xh *xaHook) onPrepared(ctx context.Context, tx runtime.CompositeTx) error {
if err := xh.trxMgr.trxLog.AddOrUpdateTxLog(*xh.trxLog); err != nil {
if err := xh.trxMgr.trxLog.AddOrUpdateGlobalTxLog(*xh.trxLog); err != nil {
return err
}
return nil
Expand All @@ -120,14 +125,14 @@ func (xh *xaHook) onCommitting(ctx context.Context, tx runtime.CompositeTx) erro
tx.Range(func(tx runtime.BranchTx) {
tx.SetCommitFunc(CommitXA)
})
if err := xh.trxMgr.trxLog.AddOrUpdateTxLog(*xh.trxLog); err != nil {
if err := xh.trxMgr.trxLog.AddOrUpdateGlobalTxLog(*xh.trxLog); err != nil {
return err
}
return nil
}

func (xh *xaHook) onCommitted(ctx context.Context, tx runtime.CompositeTx) error {
if err := xh.trxMgr.trxLog.AddOrUpdateTxLog(*xh.trxLog); err != nil {
if err := xh.trxMgr.trxLog.AddOrUpdateGlobalTxLog(*xh.trxLog); err != nil {
return err
}
return nil
Expand All @@ -137,7 +142,7 @@ func (xh *xaHook) onAborting(ctx context.Context, tx runtime.CompositeTx) error
tx.Range(func(bTx runtime.BranchTx) {
bTx.SetCommitFunc(RollbackXA)
})
if err := xh.trxMgr.trxLog.AddOrUpdateTxLog(*xh.trxLog); err != nil {
if err := xh.trxMgr.trxLog.AddOrUpdateGlobalTxLog(*xh.trxLog); err != nil {
return err
}
// auto execute XA rollback action
Expand All @@ -151,15 +156,15 @@ func (xh *xaHook) onRollbackOnly(ctx context.Context, tx runtime.CompositeTx) er
tx.Range(func(tx runtime.BranchTx) {
tx.SetCommitFunc(RollbackXA)
})
if err := xh.trxMgr.trxLog.AddOrUpdateTxLog(*xh.trxLog); err != nil {
if err := xh.trxMgr.trxLog.AddOrUpdateGlobalTxLog(*xh.trxLog); err != nil {
return err
}
return nil
}

func (xh *xaHook) onRolledBack(ctx context.Context, tx runtime.CompositeTx) error {
xh.trxLog.State = runtime.TrxRolledBack
if err := xh.trxMgr.trxLog.AddOrUpdateTxLog(*xh.trxLog); err != nil {
xh.trxLog.Status = rcontext.TrxRolledBacking
if err := xh.trxMgr.trxLog.AddOrUpdateGlobalTxLog(*xh.trxLog); err != nil {
return err
}
return nil
Expand Down
Loading
Loading