diff --git a/pkg/frontend/mysql_protocol_test.go b/pkg/frontend/mysql_protocol_test.go index 8f8cbb385979..2851fe4e4b2d 100644 --- a/pkg/frontend/mysql_protocol_test.go +++ b/pkg/frontend/mysql_protocol_test.go @@ -186,19 +186,17 @@ func TestKill(t *testing.T) { eng.EXPECT().Hints().Return(engine.Hints{CommitOrRollbackTimeout: time.Second * 10}).AnyTimes() txnClient := mock_frontend.NewMockTxnClient(ctrl) - txnClient.EXPECT().New(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, commitTS any, options ...any) (client.TxnOperator, error) { - wp := newTestWorkspace() - txnOp := mock_frontend.NewMockTxnOperator(ctrl) - txnOp.EXPECT().Txn().Return(txn.TxnMeta{}).AnyTimes() - txnOp.EXPECT().GetWorkspace().Return(wp).AnyTimes() - txnOp.EXPECT().Commit(gomock.Any()).Return(nil).AnyTimes() - txnOp.EXPECT().Rollback(gomock.Any()).Return(nil).AnyTimes() - txnOp.EXPECT().SetFootPrints(gomock.Any(), gomock.Any()).Return().AnyTimes() - txnOp.EXPECT().Status().Return(txn.TxnStatus_Active).AnyTimes() - txnOp.EXPECT().EnterRunSql().Return().AnyTimes() - txnOp.EXPECT().ExitRunSql().Return().AnyTimes() - return txnOp, nil - }).AnyTimes() + txnClient.EXPECT().New(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, commitTS any, options ...any) (client.TxnOperator, error) { + txnOp := newTestTxnOp() + return txnOp, nil + }).AnyTimes() + txnClient.EXPECT().RestartTxn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, txnOp TxnOperator, commitTS any, options ...any) (client.TxnOperator, error) { + tTxnOp := txnOp.(*testTxnOp) + tTxnOp.meta.Status = txn.TxnStatus_Active + return txnOp, nil + }).AnyTimes() pu, err := getParameterUnit("test/system_vars_config.toml", eng, txnClient) require.NoError(t, err) pu.SV.SkipCheckUser = true diff --git a/pkg/frontend/test/Makefile b/pkg/frontend/test/Makefile index 577090357fe6..507fac07cae2 100644 --- a/pkg/frontend/test/Makefile +++ b/pkg/frontend/test/Makefile @@ -1,10 +1,14 @@ CURRENT_DIR = $(shell pwd) .PHONY: generate_mock -generate_mock: +generate_mock: ../../../pkg/txn/client/types.go ../../../pkg/vm/engine/types.go @go install github.com/golang/mock/mockgen@v1.6.0 @echo "Current Directory " $(CURRENT_DIR) @mockgen -source ../../../pkg/txn/client/types.go -package mock_frontend > txn_mock.go @mockgen -source ../../../pkg/vm/engine/types.go -package mock_frontend > engine_mock.go + +#tricky +generate_fe_mock: ../types.go + @go install github.com/golang/mock/mockgen@v1.6.0 @mockgen -source ../types.go -package mock_frontend > types_mock.go diff --git a/pkg/frontend/txn.go b/pkg/frontend/txn.go index 515f06d79b88..f4af86920a66 100644 --- a/pkg/frontend/txn.go +++ b/pkg/frontend/txn.go @@ -28,6 +28,7 @@ import ( moruntime "github.com/matrixorigin/matrixone/pkg/common/runtime" "github.com/matrixorigin/matrixone/pkg/defines" "github.com/matrixorigin/matrixone/pkg/pb/metadata" + "github.com/matrixorigin/matrixone/pkg/pb/txn" "github.com/matrixorigin/matrixone/pkg/txn/client" "github.com/matrixorigin/matrixone/pkg/txn/clock" "github.com/matrixorigin/matrixone/pkg/txn/storage/memorystorage" @@ -235,7 +236,6 @@ func (th *TxnHandler) GetTxnCtx() context.Context { // invalidateTxnUnsafe releases the txnOp and clears the server status bit SERVER_STATUS_IN_TRANS func (th *TxnHandler) invalidateTxnUnsafe() { - th.txnOp = nil resetBits(&th.serverStatus, defaultServerStatus) resetBits(&th.optionBits, defaultOptionBits) } @@ -252,7 +252,7 @@ func (th *TxnHandler) inActiveTxnUnsafe() bool { if th.txnOp != nil && th.txnCtx == nil { panic("txnOp != nil and txnCtx == nil") } - return th.txnOp != nil && th.txnCtx != nil + return th.txnOp != nil && th.txnOp.Txn().Status == txn.TxnStatus_Active && th.txnCtx != nil } // Create starts a new txn. @@ -411,13 +411,27 @@ func (th *TxnHandler) createTxnOpUnsafe(execCtx *ExecCtx) error { } } - err, hasRecovered = ExecuteFuncWithRecover(func() error { - th.txnOp, err2 = getPu(execCtx.ses.GetService()).TxnClient.New( - th.txnCtx, - execCtx.ses.getLastCommitTS(), - opts...) - return err2 - }) + txnClient := getPu(execCtx.ses.GetService()).TxnClient + if th.txnOp == nil { + err, hasRecovered = ExecuteFuncWithRecover(func() error { + th.txnOp, err2 = txnClient.New( + th.txnCtx, + execCtx.ses.getLastCommitTS(), + opts...) + return err2 + }) + } else if th.txnOp.Txn().Status != txn.TxnStatus_Active { + err, hasRecovered = ExecuteFuncWithRecover(func() error { + _, err2 = txnClient.RestartTxn( + th.txnCtx, + th.txnOp, + execCtx.ses.getLastCommitTS(), + opts...) + return err2 + }) + } else { + return moerr.NewInternalError(execCtx.reqCtx, "NewTxnOperator: txn is already active") + } if err != nil || hasRecovered { return err } diff --git a/pkg/frontend/txn_test.go b/pkg/frontend/txn_test.go index 69516700c11c..7a3e10d7838f 100644 --- a/pkg/frontend/txn_test.go +++ b/pkg/frontend/txn_test.go @@ -28,11 +28,13 @@ import ( "github.com/matrixorigin/matrixone/pkg/common/moerr" "github.com/matrixorigin/matrixone/pkg/defines" mock_frontend "github.com/matrixorigin/matrixone/pkg/frontend/test" + "github.com/matrixorigin/matrixone/pkg/pb/lock" "github.com/matrixorigin/matrixone/pkg/pb/timestamp" "github.com/matrixorigin/matrixone/pkg/pb/txn" "github.com/matrixorigin/matrixone/pkg/sql/parsers/tree" "github.com/matrixorigin/matrixone/pkg/txn/client" "github.com/matrixorigin/matrixone/pkg/txn/clock" + "github.com/matrixorigin/matrixone/pkg/txn/rpc" "github.com/matrixorigin/matrixone/pkg/vm/engine" ) @@ -355,6 +357,28 @@ func newMockErrSession3(t *testing.T, ctx context.Context, ctrl *gomock.Controll return ses } +func newMockErrSession4(t *testing.T, ctx context.Context, ctrl *gomock.Controller, + newFunc func(ctx context.Context, commitTS timestamp.Timestamp, options ...TxnOption) (client.TxnOperator, error), + restartTxnFunc func(ctx context.Context, txnOp TxnOperator, commitTS any, options ...any) (client.TxnOperator, error), +) *Session { + txnClient := mock_frontend.NewMockTxnClient(ctrl) + txnClient.EXPECT().New(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(newFunc).AnyTimes() + txnClient.EXPECT().RestartTxn(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(restartTxnFunc).AnyTimes() + eng := mock_frontend.NewMockEngine(ctrl) + eng.EXPECT().New(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + eng.EXPECT().Hints().Return(engine.Hints{ + CommitOrRollbackTimeout: time.Second, + }).AnyTimes() + + ses := newTestSession(t, ctrl) + getPu("").TxnClient = txnClient + getPu("").StorageEngine = eng + ses.txnHandler.storage = eng + var c clock.Clock + _ = ses.GetTxnHandler().CreateTempStorage(c) + return ses +} + func Test_rollbackStatement(t *testing.T) { convey.Convey("normal rollback", t, func() { ctrl := gomock.NewController(t) @@ -364,17 +388,13 @@ func Test_rollbackStatement(t *testing.T) { txnClient := mock_frontend.NewMockTxnClient(ctrl) txnClient.EXPECT().New(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, commitTS timestamp.Timestamp, options ...TxnOption) (client.TxnOperator, error) { - txnOperator := mock_frontend.NewMockTxnOperator(ctrl) - txnOperator.EXPECT().Txn().Return(txn.TxnMeta{}).AnyTimes() - txnOperator.EXPECT().Rollback(gomock.Any()).Return(nil).AnyTimes() - txnOperator.EXPECT().Commit(gomock.Any()).Return(nil).AnyTimes() - wsp := newTestWorkspace() - txnOperator.EXPECT().GetWorkspace().Return(wsp).AnyTimes() - txnOperator.EXPECT().SetFootPrints(gomock.Any(), gomock.Any()).Return().AnyTimes() - txnOperator.EXPECT().Status().Return(txn.TxnStatus_Active).AnyTimes() - txnOperator.EXPECT().EnterRunSql().Return().AnyTimes() - txnOperator.EXPECT().ExitRunSql().Return().AnyTimes() - return txnOperator, nil + return newTestTxnOp(), nil + }).AnyTimes() + txnClient.EXPECT().RestartTxn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, txnOp TxnOperator, commitTS timestamp.Timestamp, options ...TxnOption) (client.TxnOperator, error) { + tTxnOp := txnOp.(*testTxnOp) + tTxnOp.meta.Status = txn.TxnStatus_Active + return txnOp, nil }).AnyTimes() eng := mock_frontend.NewMockEngine(ctrl) eng.EXPECT().New(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() @@ -408,7 +428,8 @@ func Test_rollbackStatement(t *testing.T) { err = ses.GetTxnHandler().Rollback(ec) convey.So(err, convey.ShouldBeNil) t2 := ses.txnHandler.GetTxn() - convey.So(t2, convey.ShouldBeNil) + convey.So(t2, convey.ShouldNotBeNil) + convey.So(t2.Txn().Status, convey.ShouldEqual, txn.TxnStatus_Aborted) //case2.1 autocommit && begin && CreateSequence (need to be committed in the active txn) ec.txnOpt = FeTxnOption{ @@ -426,7 +447,8 @@ func Test_rollbackStatement(t *testing.T) { err = ses.GetTxnHandler().Rollback(ec) convey.So(err, convey.ShouldBeNil) t2 = ses.txnHandler.GetTxn() - convey.So(t2, convey.ShouldBeNil) + convey.So(t2, convey.ShouldNotBeNil) + convey.So(t2.Status(), convey.ShouldEqual, txn.TxnStatus_Aborted) //case2.2 not_autocommit && not_begin && CreateSequence (need to be committed in the active txn) ec.txnOpt = FeTxnOption{ @@ -447,7 +469,8 @@ func Test_rollbackStatement(t *testing.T) { err = ses.GetTxnHandler().Rollback(ec) convey.So(err, convey.ShouldBeNil) t2 = ses.txnHandler.GetTxn() - convey.So(t2, convey.ShouldBeNil) + convey.So(t2, convey.ShouldNotBeNil) + convey.So(t2.Status(), convey.ShouldEqual, txn.TxnStatus_Aborted) //case3.1 not_autocommit && not_begin && Insert Stmt (need not to be committed in the active txn) ec.txnOpt = FeTxnOption{ @@ -475,6 +498,7 @@ func Test_rollbackStatement(t *testing.T) { convey.So(err, convey.ShouldBeNil) t2 = ses.txnHandler.GetTxn() convey.So(t2, convey.ShouldNotBeNil) + txnOp.GetWorkspace().EndStatement() //case3.2 not_autocommit && begin && Insert Stmt (need not to be committed in the active txn) ec.txnOpt = FeTxnOption{ @@ -504,6 +528,7 @@ func Test_rollbackStatement(t *testing.T) { convey.So(err, convey.ShouldBeNil) t2 = ses.txnHandler.GetTxn() convey.So(t2, convey.ShouldNotBeNil) + txnOp.GetWorkspace().EndStatement() }) @@ -512,7 +537,17 @@ func Test_rollbackStatement(t *testing.T) { defer ctrl.Finish() ctx := defines.AttachAccountId(context.TODO(), sysAccountID) - ses := newMockErrSession(t, ctx, ctrl) + newFunc := func(ctx context.Context, commitTS timestamp.Timestamp, options ...TxnOption) (client.TxnOperator, error) { + txnOp := newTestTxnOp() + txnOp.mod = modRollbackError + return txnOp, nil + } + restartTxnFunc := func(ctx context.Context, txnOp TxnOperator, commitTS any, options ...any) (client.TxnOperator, error) { + tTxnOp := txnOp.(*testTxnOp) + tTxnOp.meta.Status = txn.TxnStatus_Active + return txnOp, nil + } + ses := newMockErrSession4(t, ctx, ctrl, newFunc, restartTxnFunc) ec := newTestExecCtx(ctx, ctrl) ec.ses = ses //case1. autocommit && not_begin. Insert Stmt (need not to be committed in the active txn) @@ -528,7 +563,8 @@ func Test_rollbackStatement(t *testing.T) { err = ses.GetTxnHandler().Rollback(ec) convey.So(err, convey.ShouldNotBeNil) t2 := ses.txnHandler.GetTxn() - convey.So(t2, convey.ShouldBeNil) + convey.So(t2, convey.ShouldNotBeNil) + convey.So(t2.Txn().Status, convey.ShouldEqual, txn.TxnStatus_Active) }) } @@ -555,7 +591,8 @@ func Test_rollbackStatement2(t *testing.T) { err = ses.GetTxnHandler().Rollback(ec) convey.So(err, convey.ShouldNotBeNil) t2 := ses.txnHandler.GetTxn() - convey.So(t2, convey.ShouldBeNil) + convey.So(t2, convey.ShouldNotBeNil) + convey.So(t2.Status(), convey.ShouldEqual, txn.TxnStatus_Active) }) } @@ -585,7 +622,8 @@ func Test_rollbackStatement3(t *testing.T) { err = ses.GetTxnHandler().Rollback(ec) convey.So(err, convey.ShouldNotBeNil) t2 := ses.txnHandler.GetTxn() - convey.So(t2, convey.ShouldBeNil) + convey.So(t2, convey.ShouldNotBeNil) + convey.So(t2.Txn().Status, convey.ShouldEqual, txn.TxnStatus_Active) }) } @@ -614,7 +652,8 @@ func Test_rollbackStatement4(t *testing.T) { err = ses.GetTxnHandler().Rollback(ec) convey.So(err, convey.ShouldNotBeNil) t2 := ses.txnHandler.GetTxn() - convey.So(t2, convey.ShouldBeNil) + convey.So(t2, convey.ShouldNotBeNil) + convey.So(t2.Status(), convey.ShouldEqual, txn.TxnStatus_Active) }) } @@ -649,7 +688,9 @@ func Test_rollbackStatement5(t *testing.T) { err = ses.GetTxnHandler().Rollback(ec) convey.So(err, convey.ShouldNotBeNil) t2 := ses.txnHandler.GetTxn() - convey.So(t2, convey.ShouldBeNil) + convey.So(t2, convey.ShouldNotBeNil) + convey.So(t2.Status(), convey.ShouldEqual, txn.TxnStatus_Active) + txnOp.GetWorkspace().EndStatement() }) } @@ -687,7 +728,9 @@ func Test_rollbackStatement6(t *testing.T) { err = ses.GetTxnHandler().Rollback(ec) convey.So(err, convey.ShouldNotBeNil) t2 := ses.txnHandler.GetTxn() - convey.So(t2, convey.ShouldBeNil) + convey.So(t2, convey.ShouldNotBeNil) + convey.So(t2.Status(), convey.ShouldEqual, txn.TxnStatus_Active) + txnOp.GetWorkspace().EndStatement() }) convey.Convey("abnormal rollback -- rollback whole txn", t, func() { ctrl := gomock.NewController(t) @@ -722,7 +765,9 @@ func Test_rollbackStatement6(t *testing.T) { err = ses.GetTxnHandler().Rollback(ec) convey.So(err, convey.ShouldNotBeNil) t2 := ses.txnHandler.GetTxn() - convey.So(t2, convey.ShouldBeNil) + convey.So(t2, convey.ShouldNotBeNil) + convey.So(t2.Status(), convey.ShouldEqual, txn.TxnStatus_Active) + txnOp.GetWorkspace().EndStatement() }) } @@ -753,3 +798,196 @@ func Test_commit(t *testing.T) { convey.So(err, convey.ShouldNotBeNil) }) } + +var _ TxnOperator = new(testTxnOp) + +const ( + modRollbackError = 1 +) + +type testTxnOp struct { + meta txn.TxnMeta + wp *testWorkspace + mod int +} + +func newTestTxnOp() *testTxnOp { + return &testTxnOp{ + wp: newTestWorkspace(), + } +} + +func (txnop *testTxnOp) GetOverview() client.TxnOverview { + //TODO implement me + panic("implement me") +} + +func (txnop *testTxnOp) CloneSnapshotOp(snapshot timestamp.Timestamp) client.TxnOperator { + //TODO implement me + panic("implement me") +} + +func (txnop *testTxnOp) IsSnapOp() bool { + //TODO implement me + panic("implement me") +} + +func (txnop *testTxnOp) Txn() txn.TxnMeta { + return txnop.meta +} + +func (txnop *testTxnOp) TxnOptions() txn.TxnOptions { + //TODO implement me + panic("implement me") +} + +func (txnop *testTxnOp) TxnRef() *txn.TxnMeta { + //TODO implement me + panic("implement me") +} + +func (txnop *testTxnOp) Snapshot() (txn.CNTxnSnapshot, error) { + //TODO implement me + panic("implement me") +} + +func (txnop *testTxnOp) UpdateSnapshot(ctx context.Context, ts timestamp.Timestamp) error { + //TODO implement me + panic("implement me") +} + +func (txnop *testTxnOp) SnapshotTS() timestamp.Timestamp { + //TODO implement me + panic("implement me") +} + +func (txnop *testTxnOp) CreateTS() timestamp.Timestamp { + //TODO implement me + panic("implement me") +} + +func (txnop *testTxnOp) Status() txn.TxnStatus { + return txnop.meta.Status +} + +func (txnop *testTxnOp) ApplySnapshot(data []byte) error { + //TODO implement me + panic("implement me") +} + +func (txnop *testTxnOp) Read(ctx context.Context, ops []txn.TxnRequest) (*rpc.SendResult, error) { + //TODO implement me + panic("implement me") +} + +func (txnop *testTxnOp) Write(ctx context.Context, ops []txn.TxnRequest) (*rpc.SendResult, error) { + //TODO implement me + panic("implement me") +} + +func (txnop *testTxnOp) WriteAndCommit(ctx context.Context, ops []txn.TxnRequest) (*rpc.SendResult, error) { + //TODO implement me + panic("implement me") +} + +func (txnop *testTxnOp) Commit(ctx context.Context) error { + txnop.meta.Status = txn.TxnStatus_Committed + return nil +} + +func (txnop *testTxnOp) Rollback(ctx context.Context) error { + if txnop.mod == modRollbackError { + return moerr.NewInternalErrorNoCtx("throw error") + } + txnop.meta.Status = txn.TxnStatus_Aborted + return nil +} + +func (txnop *testTxnOp) AddLockTable(locktable lock.LockTable) error { + //TODO implement me + panic("implement me") +} + +func (txnop *testTxnOp) HasLockTable(table uint64) bool { + //TODO implement me + panic("implement me") +} + +func (txnop *testTxnOp) AddWaitLock(tableID uint64, rows [][]byte, opt lock.LockOptions) uint64 { + //TODO implement me + panic("implement me") +} + +func (txnop *testTxnOp) RemoveWaitLock(key uint64) { + //TODO implement me + panic("implement me") +} + +func (txnop *testTxnOp) LockTableCount() int32 { + //TODO implement me + panic("implement me") +} + +func (txnop *testTxnOp) LockSkipped(tableID uint64, mode lock.LockMode) bool { + //TODO implement me + panic("implement me") +} + +func (txnop *testTxnOp) GetWaitActiveCost() time.Duration { + //TODO implement me + panic("implement me") +} + +func (txnop *testTxnOp) AddWorkspace(workspace client.Workspace) { + //TODO implement me + panic("implement me") +} + +func (txnop *testTxnOp) GetWorkspace() client.Workspace { + return txnop.wp +} + +func (txnop *testTxnOp) AppendEventCallback(event client.EventType, callbacks ...func(client.TxnEvent)) { + //TODO implement me + panic("implement me") +} + +func (txnop *testTxnOp) Debug(ctx context.Context, ops []txn.TxnRequest) (*rpc.SendResult, error) { + //TODO implement me + panic("implement me") +} + +func (txnop *testTxnOp) NextSequence() uint64 { + //TODO implement me + panic("implement me") +} + +func (txnop *testTxnOp) EnterRunSql() { +} + +func (txnop *testTxnOp) ExitRunSql() { +} + +func (txnop *testTxnOp) EnterIncrStmt() { + //TODO implement me + panic("implement me") +} + +func (txnop *testTxnOp) ExitIncrStmt() { + //TODO implement me + panic("implement me") +} + +func (txnop *testTxnOp) EnterRollbackStmt() { + //TODO implement me + panic("implement me") +} + +func (txnop *testTxnOp) ExitRollbackStmt() { + //TODO implement me + panic("implement me") +} + +func (txnop *testTxnOp) SetFootPrints(id int, enter bool) { + +}