From e3d7b148c7edfce2db35748bbfadea9422adec77 Mon Sep 17 00:00:00 2001 From: Yuecai Liu <38887641+luky116@users.noreply.github.com> Date: Sat, 19 Nov 2022 21:09:40 +0800 Subject: [PATCH] optimize meta data (#352) optimize meta data --- go.mod | 2 - go.sum | 1 - pkg/datasource/sql/async_worker.go | 9 +- pkg/datasource/sql/at.go | 73 ++++----- pkg/datasource/sql/conn_at.go | 3 +- .../sql/datasource/base/meta_cache.go | 6 +- .../sql/datasource/datasource_manager.go | 55 ++----- .../sql/datasource/mysql/default.go | 30 ---- .../sql/datasource/mysql/meta_cache.go | 31 ++-- .../sql/datasource/mysql/meta_cache_test.go | 55 ------- .../sql/datasource/mysql/trigger.go | 122 +++++---------- pkg/datasource/sql/db.go | 58 ++++--- pkg/datasource/sql/driver.go | 11 +- pkg/datasource/sql/driver_test.go | 6 +- .../sql/exec/select_for_update_executor.go | 26 ++-- .../sql/mock/mock_datasource_manager.go | 96 ++++++------ pkg/datasource/sql/tx.go | 10 +- pkg/datasource/sql/types/sql.go | 66 ++++++++ pkg/datasource/sql/undo/base/undo.go | 147 +++++++----------- .../builder/mysql_update_undo_log_builder.go | 11 +- .../mysql_update_undo_log_builder_test.go | 9 +- pkg/datasource/sql/undo/executor/executor.go | 3 +- .../executor/mysql_undo_insert_executor.go | 8 +- .../executor/mysql_undo_update_executor.go | 9 +- .../sql/undo/factor/undo_executor_factory.go | 2 +- pkg/datasource/sql/undo/mysql/undo.go | 4 +- pkg/datasource/sql/undo/undo.go | 10 +- pkg/datasource/sql/undo/undo_executor.go | 4 +- pkg/datasource/sql/undo_test.go | 2 +- sample/at/basic/main.go | 1 - sample/at/basic/service.go | 2 +- 31 files changed, 372 insertions(+), 500 deletions(-) delete mode 100644 pkg/datasource/sql/datasource/mysql/default.go delete mode 100644 pkg/datasource/sql/datasource/mysql/meta_cache_test.go diff --git a/go.mod b/go.mod index 306306683..4a2473304 100644 --- a/go.mod +++ b/go.mod @@ -71,7 +71,6 @@ require ( github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/golang/snappy v0.0.4 // indirect - github.com/google/go-cmp v0.5.9 // indirect github.com/gorilla/websocket v1.4.2 // indirect github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect @@ -146,7 +145,6 @@ require ( golang.org/x/sys v0.0.0-20220915200043-7b5979e65e41 // indirect google.golang.org/genproto v0.0.0-20220630174209-ad1d48641aa7 // indirect gopkg.in/ini.v1 v1.62.0 // indirect - gotest.tools v2.2.0+incompatible moul.io/http2curl v1.0.0 // indirect vimagination.zapto.org/memio v0.0.0-20200222190306-588ebc67b97d // indirect ) diff --git a/go.sum b/go.sum index 08702f505..85e778226 100644 --- a/go.sum +++ b/go.sum @@ -1425,7 +1425,6 @@ gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo= gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/pkg/datasource/sql/async_worker.go b/pkg/datasource/sql/async_worker.go index 6d1d4a186..056b6a096 100644 --- a/pkg/datasource/sql/async_worker.go +++ b/pkg/datasource/sql/async_worker.go @@ -22,13 +22,14 @@ import ( "flag" "time" + "github.com/seata/seata-go/pkg/rm" + "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" "github.com/seata/seata-go/pkg/datasource/sql/datasource" "github.com/seata/seata-go/pkg/datasource/sql/undo" "github.com/seata/seata-go/pkg/protocol/branch" - "github.com/seata/seata-go/pkg/protocol/message" "github.com/seata/seata-go/pkg/util/fanout" "github.com/seata/seata-go/pkg/util/log" ) @@ -102,7 +103,7 @@ func NewAsyncWorker(prom prometheus.Registerer, conf AsyncWorkerConfig, sourceMa } // BranchCommit commit branch transaction -func (aw *AsyncWorker) BranchCommit(ctx context.Context, req message.BranchCommitRequest) (branch.BranchStatus, error) { +func (aw *AsyncWorker) BranchCommit(ctx context.Context, req rm.BranchResource) (branch.BranchStatus, error) { phaseCtx := phaseTwoContext{ Xid: req.Xid, BranchID: req.BranchId, @@ -170,7 +171,7 @@ func (aw *AsyncWorker) doBranchCommit(phaseCtxs *[]phaseTwoContext) { } func (aw *AsyncWorker) dealWithGroupedContexts(resID string, phaseCtxs []phaseTwoContext) { - val, ok := aw.resourceMgr.GetManagedResources()[resID] + val, ok := aw.resourceMgr.GetCachedResources().Load(resID) if !ok { for i := range phaseCtxs { aw.rePutBackToQueue.Add(1) @@ -180,7 +181,7 @@ func (aw *AsyncWorker) dealWithGroupedContexts(resID string, phaseCtxs []phaseTw } res := val.(*DBResource) - conn, err := res.target.Conn(context.Background()) + conn, err := res.db.Conn(context.Background()) if err != nil { for i := range phaseCtxs { aw.commitQueue <- phaseCtxs[i] diff --git a/pkg/datasource/sql/at.go b/pkg/datasource/sql/at.go index 7578fc7f3..0a38e065f 100644 --- a/pkg/datasource/sql/at.go +++ b/pkg/datasource/sql/at.go @@ -30,7 +30,6 @@ import ( "github.com/seata/seata-go/pkg/datasource/sql/types" "github.com/seata/seata-go/pkg/datasource/sql/undo" "github.com/seata/seata-go/pkg/protocol/branch" - "github.com/seata/seata-go/pkg/protocol/message" "github.com/seata/seata-go/pkg/rm" ) @@ -52,7 +51,7 @@ func init() { _ = fs.Parse([]string{}) atSourceManager.worker = NewAsyncWorker(prometheus.DefaultRegisterer, asyncWorkerConf, atSourceManager) - datasource.RegisterResourceManager(branch.BranchTypeAT, atSourceManager) + rm.GetRmCacheInstance().RegisterResourceManager(atSourceManager) } type ATSourceManager struct { @@ -62,51 +61,43 @@ type ATSourceManager struct { rmRemoting *rm.RMRemoting } -// Register a Resource to be managed by Resource Manager -func (mgr *ATSourceManager) RegisterResource(res rm.Resource) error { - mgr.resourceCache.Store(res.GetResourceId(), res) - - return mgr.basic.RegisterResource(res) -} - -// Unregister a Resource from the Resource Manager -func (mgr *ATSourceManager) UnregisterResource(res rm.Resource) error { - return mgr.basic.UnregisterResource(res) +func (a *ATSourceManager) GetBranchType() branch.BranchType { + return branch.BranchTypeAT } // Get all resources managed by this manager -func (mgr *ATSourceManager) GetManagedResources() map[string]rm.Resource { - ret := make(map[string]rm.Resource) - - mgr.resourceCache.Range(func(key, value interface{}) bool { - ret[key.(string)] = value.(rm.Resource) - return true - }) - - return ret +func (a *ATSourceManager) GetCachedResources() *sync.Map { + return &a.resourceCache } -// BranchRollback Rollback the corresponding transactions according to the request -func (mgr *ATSourceManager) BranchRollback(ctx context.Context, req message.BranchRollbackRequest) (branch.BranchStatus, error) { - val, ok := mgr.resourceCache.Load(req.ResourceId) +// Register a Resource to be managed by Resource Manager +func (a *ATSourceManager) RegisterResource(res rm.Resource) error { + a.resourceCache.Store(res.GetResourceId(), res) - if !ok { - return branch.BranchStatusPhaseoneFailed, fmt.Errorf("resource %s not found", req.ResourceId) - } + return a.basic.RegisterResource(res) +} - res := val.(*DBResource) +// Unregister a Resource from the Resource Manager +func (a *ATSourceManager) UnregisterResource(res rm.Resource) error { + return a.basic.UnregisterResource(res) +} - undoMgr, err := undo.GetUndoLogManager(res.dbType) - if err != nil { +// Rollback a branch transaction +func (a *ATSourceManager) BranchRollback(ctx context.Context, branchResource rm.BranchResource) (branch.BranchStatus, error) { + var dbResource *DBResource + if resource, ok := a.resourceCache.Load(branchResource.ResourceId); !ok { + err := fmt.Errorf("DB resource is not exist, resourceId: %s", branchResource.ResourceId) return branch.BranchStatusUnknown, err + } else { + dbResource, _ = resource.(*DBResource) } - /*conn, err := res.target.Conn(ctx) + undoMgr, err := undo.GetUndoLogManager(dbResource.dbType) if err != nil { return branch.BranchStatusUnknown, err - }*/ + } - if err := undoMgr.RunUndo(ctx, req.Xid, req.BranchId, res.conn); err != nil { + if err := undoMgr.RunUndo(ctx, branchResource.Xid, branchResource.BranchId, dbResource.db, dbResource.dbName); err != nil { transErr, ok := err.(*types.TransactionError) if !ok { return branch.BranchStatusPhaseoneFailed, err @@ -123,28 +114,28 @@ func (mgr *ATSourceManager) BranchRollback(ctx context.Context, req message.Bran } // BranchCommit -func (mgr *ATSourceManager) BranchCommit(ctx context.Context, req message.BranchCommitRequest) (branch.BranchStatus, error) { - mgr.worker.BranchCommit(ctx, req) +func (a *ATSourceManager) BranchCommit(ctx context.Context, resource rm.BranchResource) (branch.BranchStatus, error) { + a.worker.BranchCommit(ctx, resource) return branch.BranchStatusPhaseoneDone, nil } // LockQuery -func (mgr *ATSourceManager) LockQuery(ctx context.Context, req message.GlobalLockQueryRequest) (bool, error) { +func (a *ATSourceManager) LockQuery(ctx context.Context, param rm.LockQueryParam) (bool, error) { return false, nil } // BranchRegister -func (mgr *ATSourceManager) BranchRegister(ctx context.Context, req rm.BranchRegisterParam) (int64, error) { - return mgr.rmRemoting.BranchRegister(req) +func (a *ATSourceManager) BranchRegister(ctx context.Context, req rm.BranchRegisterParam) (int64, error) { + return a.rmRemoting.BranchRegister(req) } // BranchReport -func (mgr *ATSourceManager) BranchReport(ctx context.Context, req message.BranchReportRequest) error { +func (a *ATSourceManager) BranchReport(ctx context.Context, param rm.BranchReportParam) error { return nil } // CreateTableMetaCache -func (mgr *ATSourceManager) CreateTableMetaCache(ctx context.Context, resID string, dbType types.DBType, +func (a *ATSourceManager) CreateTableMetaCache(ctx context.Context, resID string, dbType types.DBType, db *sql.DB) (datasource.TableMetaCache, error) { - return mgr.basic.CreateTableMetaCache(ctx, resID, dbType, db) + return a.basic.CreateTableMetaCache(ctx, resID, dbType, db) } diff --git a/pkg/datasource/sql/conn_at.go b/pkg/datasource/sql/conn_at.go index e8b0b4582..4f4c11e81 100644 --- a/pkg/datasource/sql/conn_at.go +++ b/pkg/datasource/sql/conn_at.go @@ -41,7 +41,6 @@ func (c *ATConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, c.txCtx = types.NewTxCtx() }() } - return c.Conn.PrepareContext(ctx, query) } @@ -127,6 +126,7 @@ func (c *ATConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, c.txCtx = types.NewTxCtx() c.txCtx.DBType = c.res.dbType c.txCtx.TxOpt = opts + c.txCtx.ResourceID = c.res.resourceID if tm.IsGlobalTx(ctx) { c.txCtx.XID = tm.GetXID(ctx) @@ -147,6 +147,7 @@ func (c *ATConn) createOnceTxContext(ctx context.Context) bool { if onceTx { c.txCtx = types.NewTxCtx() c.txCtx.DBType = c.res.dbType + c.txCtx.ResourceID = c.res.resourceID c.txCtx.XID = tm.GetXID(ctx) c.txCtx.TransType = types.ATMode c.txCtx.GlobalLockRequire = true diff --git a/pkg/datasource/sql/datasource/base/meta_cache.go b/pkg/datasource/sql/datasource/base/meta_cache.go index 594d45bb0..e2088a64d 100644 --- a/pkg/datasource/sql/datasource/base/meta_cache.go +++ b/pkg/datasource/sql/datasource/base/meta_cache.go @@ -19,7 +19,7 @@ package base import ( "context" - "database/sql/driver" + "database/sql" "errors" "sync" "time" @@ -30,7 +30,7 @@ import ( type ( // trigger trigger interface { - LoadOne(ctx context.Context, dbName string, table string, conn driver.Conn) (*types.TableMeta, error) + LoadOne(ctx context.Context, dbName string, table string, conn *sql.Conn) (*types.TableMeta, error) LoadAll() ([]types.TableMeta, error) } @@ -134,7 +134,7 @@ func (c *BaseTableMetaCache) scanExpire(ctx context.Context) { } // GetTableMeta -func (c *BaseTableMetaCache) GetTableMeta(ctx context.Context, dbName, tableName string, conn driver.Conn) (types.TableMeta, error) { +func (c *BaseTableMetaCache) GetTableMeta(ctx context.Context, dbName, tableName string, conn *sql.Conn) (types.TableMeta, error) { c.lock.Lock() defer c.lock.Unlock() diff --git a/pkg/datasource/sql/datasource/datasource_manager.go b/pkg/datasource/sql/datasource/datasource_manager.go index c3f173d4c..0d1f4fdcf 100644 --- a/pkg/datasource/sql/datasource/datasource_manager.go +++ b/pkg/datasource/sql/datasource/datasource_manager.go @@ -20,7 +20,6 @@ package datasource import ( "context" "database/sql" - "database/sql/driver" "errors" "sync" @@ -31,33 +30,26 @@ import ( ) var ( - atOnce sync.Once - atMgr DataSourceManager - xaMgr DataSourceManager - solts = map[types.DBType]func() TableMetaCache{} + atOnce sync.Once + tableMetaCacheMap = map[types.DBType]TableMetaCache{} ) // RegisterTableCache -func RegisterTableCache(dbType types.DBType, builder func() TableMetaCache) { - solts[dbType] = builder +func RegisterTableCache(dbType types.DBType, tableMetaCache TableMetaCache) { + tableMetaCacheMap[dbType] = tableMetaCache } -func RegisterResourceManager(b branch.BranchType, d DataSourceManager) { - if b == branch.BranchTypeAT { - atMgr = d - } - - if b == branch.BranchTypeXA { - xaMgr = d - } +func GetTableCache(dbType types.DBType) TableMetaCache { + return tableMetaCacheMap[dbType] } -func GetDataSourceManager(b branch.BranchType) DataSourceManager { - if b == branch.BranchTypeAT { - return atMgr +func GetDataSourceManager(branchType branch.BranchType) DataSourceManager { + resourceManager := rm.GetRmCacheInstance().GetResourceManager(branchType) + if resourceManager == nil { + return nil } - if b == branch.BranchTypeXA { - return xaMgr + if d, ok := resourceManager.(DataSourceManager); ok { + return d } return nil } @@ -65,22 +57,7 @@ func GetDataSourceManager(b branch.BranchType) DataSourceManager { // todo implements ResourceManagerOutbound interface // DataSourceManager type DataSourceManager interface { - // Register a Resource to be managed by Resource Manager - RegisterResource(resource rm.Resource) error - // Unregister a Resource from the Resource Manager - UnregisterResource(resource rm.Resource) error - // GetManagedResources Get all resources managed by this manager - GetManagedResources() map[string]rm.Resource - // BranchRollback - BranchRollback(ctx context.Context, req message.BranchRollbackRequest) (branch.BranchStatus, error) - // BranchCommit - BranchCommit(ctx context.Context, req message.BranchCommitRequest) (branch.BranchStatus, error) - // LockQuery - LockQuery(ctx context.Context, req message.GlobalLockQueryRequest) (bool, error) - // BranchRegister - BranchRegister(ctx context.Context, req rm.BranchRegisterParam) (int64, error) - // BranchReport - BranchReport(ctx context.Context, req message.BranchReportRequest) error + rm.ResourceManager // CreateTableMetaCache CreateTableMetaCache(ctx context.Context, resID string, dbType types.DBType, db *sql.DB) (TableMetaCache, error) } @@ -95,6 +72,7 @@ type BasicSourceManager struct { // lock lock sync.RWMutex // tableMetaCache + // todo do not put meta cache here tableMetaCache map[string]*entry } @@ -178,14 +156,15 @@ type TableMetaCache interface { // Init Init(ctx context.Context, conn *sql.DB) error // GetTableMeta - GetTableMeta(ctx context.Context, dbName, table string, conn driver.Conn) (*types.TableMeta, error) + GetTableMeta(ctx context.Context, dbName, table string) (*types.TableMeta, error) // Destroy Destroy() error } // buildResource +// todo not here func buildResource(ctx context.Context, dbType types.DBType, db *sql.DB) (*entry, error) { - cache := solts[dbType]() + cache := tableMetaCacheMap[dbType] if err := cache.Init(ctx, db); err != nil { return nil, err diff --git a/pkg/datasource/sql/datasource/mysql/default.go b/pkg/datasource/sql/datasource/mysql/default.go deleted file mode 100644 index 28b573f8f..000000000 --- a/pkg/datasource/sql/datasource/mysql/default.go +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package mysql - -import ( - "github.com/seata/seata-go/pkg/datasource/sql/datasource" - "github.com/seata/seata-go/pkg/datasource/sql/types" -) - -// todo -func init() { - datasource.RegisterTableCache(types.DBTypeMySQL, func() datasource.TableMetaCache { - return &TableMetaCache{} - }) -} diff --git a/pkg/datasource/sql/datasource/mysql/meta_cache.go b/pkg/datasource/sql/datasource/mysql/meta_cache.go index 1ccae84b0..14ca88a5d 100644 --- a/pkg/datasource/sql/datasource/mysql/meta_cache.go +++ b/pkg/datasource/sql/datasource/mysql/meta_cache.go @@ -20,7 +20,6 @@ package mysql import ( "context" "database/sql" - "database/sql/driver" "sync" "time" @@ -31,24 +30,21 @@ import ( ) var ( - capacity int32 = 1024 - EexpireTime = 15 * time.Minute - tableMetaInstance *TableMetaCache - tableMetaOnce sync.Once + capacity int32 = 1024 + EexpireTime = 15 * time.Minute + tableMetaOnce sync.Once ) type TableMetaCache struct { tableMetaCache *base.BaseTableMetaCache + db *sql.DB } -func GetTableMetaInstance() *TableMetaCache { - // Todo constant.DBName get from config - tableMetaOnce.Do(func() { - tableMetaInstance = &TableMetaCache{ - tableMetaCache: base.NewBaseCache(capacity, EexpireTime, NewMysqlTrigger()), - } - }) - +func NewTableMetaInstance(db *sql.DB) *TableMetaCache { + tableMetaInstance := &TableMetaCache{ + tableMetaCache: base.NewBaseCache(capacity, EexpireTime, NewMysqlTrigger()), + db: db, + } return tableMetaInstance } @@ -58,9 +54,14 @@ func (c *TableMetaCache) Init(ctx context.Context, conn *sql.DB) error { } // GetTableMeta get table info from cache or information schema -func (c *TableMetaCache) GetTableMeta(ctx context.Context, dbName, tableName string, conn driver.Conn) (*types.TableMeta, error) { +func (c *TableMetaCache) GetTableMeta(ctx context.Context, dbName, tableName string) (*types.TableMeta, error) { if tableName == "" { - return nil, errors.New("TableMeta cannot be fetched without tableName") + return nil, errors.New("table name is empty") + } + + conn, err := c.db.Conn(ctx) + if err != nil { + return nil, err } tableMeta, err := c.tableMetaCache.GetTableMeta(ctx, dbName, tableName, conn) diff --git a/pkg/datasource/sql/datasource/mysql/meta_cache_test.go b/pkg/datasource/sql/datasource/mysql/meta_cache_test.go deleted file mode 100644 index 63e89b3a0..000000000 --- a/pkg/datasource/sql/datasource/mysql/meta_cache_test.go +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package mysql - -import ( - "context" - "database/sql" - "testing" - - _ "github.com/go-sql-driver/mysql" - "gotest.tools/assert" -) - -// TestGetTableMeta -func TestGetTableMeta(t *testing.T) { - // local test can annotation t.SkipNow() - t.SkipNow() - - testTableMeta := func() { - metaInstance := GetTableMetaInstance() - - db, err := sql.Open("mysql", "root:123456@tcp(127.0.0.1:3306)/seata?multiStatements=true") - if err != nil { - t.Fatal(err) - } - - defer db.Close() - - ctx := context.Background() - - tableMeta, err := metaInstance.GetTableMeta(ctx, "seata_client", "undo_log", nil) - assert.NilError(t, err) - - t.Logf("%+v", tableMeta) - } - - t.Run("testTableMeta", func(t *testing.T) { - testTableMeta() - }) -} diff --git a/pkg/datasource/sql/datasource/mysql/trigger.go b/pkg/datasource/sql/datasource/mysql/trigger.go index af2e2bc65..b9b076bce 100644 --- a/pkg/datasource/sql/datasource/mysql/trigger.go +++ b/pkg/datasource/sql/datasource/mysql/trigger.go @@ -19,9 +19,7 @@ package mysql import ( "context" - "database/sql/driver" - "io" - "reflect" + "database/sql" "strings" "github.com/pkg/errors" @@ -31,7 +29,7 @@ import ( const ( columnMetaSql = "SELECT `TABLE_NAME`, `TABLE_SCHEMA`, `COLUMN_NAME`, `DATA_TYPE`, `COLUMN_TYPE`, `COLUMN_KEY`, `IS_NULLABLE`, `EXTRA` FROM INFORMATION_SCHEMA.COLUMNS WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" - indexMetaSql = "SELECT `INDEX_NAME`, `COLUMN_NAME`, `NON_UNIQUE`, `INDEX_TYPE`, `COLLATION`, `CARDINALITY` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" + indexMetaSql = "SELECT `INDEX_NAME`, `COLUMN_NAME`, `NON_UNIQUE` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" ) type mysqlTrigger struct { @@ -42,7 +40,7 @@ func NewMysqlTrigger() *mysqlTrigger { } // LoadOne get table meta column and index -func (m *mysqlTrigger) LoadOne(ctx context.Context, dbName string, tableName string, conn driver.Conn) (*types.TableMeta, error) { +func (m *mysqlTrigger) LoadOne(ctx context.Context, dbName string, tableName string, conn *sql.Conn) (*types.TableMeta, error) { tableMeta := types.TableMeta{ TableName: tableName, Columns: make(map[string]types.ColumnMeta), @@ -89,52 +87,53 @@ func (m *mysqlTrigger) LoadAll() ([]types.TableMeta, error) { } // getColumns get tableMeta column -func (m *mysqlTrigger) getColumns(ctx context.Context, dbName string, table string, conn driver.Conn) ([]types.ColumnMeta, error) { +func (m *mysqlTrigger) getColumns(ctx context.Context, dbName string, table string, conn *sql.Conn) ([]types.ColumnMeta, error) { table = executor.DelEscape(table, types.DBTypeMySQL) var columnMetas []types.ColumnMeta - stmt, err := conn.Prepare(columnMetaSql) + stmt, err := conn.PrepareContext(ctx, columnMetaSql) if err != nil { return nil, err } - rowsi, err := stmt.Query([]driver.Value{dbName, table}) + rows, err := stmt.Query(dbName, table) if err != nil { return nil, err } + defer rows.Close() - columnTypes := buildColumnType(rowsi) - i := 0 - - for { - vals := make([]driver.Value, 8) - err = rowsi.Next(vals) - if err == io.EOF { - break - } - if err != nil { - return nil, err - } - + for rows.Next() { var ( - tableName = string(vals[0].([]uint8)) - tableSchema = string(vals[1].([]uint8)) - columnName = string(vals[2].([]uint8)) - dataType = string(vals[3].([]uint8)) - columnType = string(vals[4].([]uint8)) - columnKey = string(vals[5].([]uint8)) - isNullable = string(vals[6].([]uint8)) - extra = string(vals[7].([]uint8)) + tableName string + tableSchema string + columnName string + dataType string + columnType string + columnKey string + isNullable string + extra string ) columnMeta := types.ColumnMeta{} + + if err = rows.Scan( + &tableName, + &tableSchema, + &columnName, + &dataType, + &columnType, + &columnKey, + &isNullable, + &extra); err != nil { + return nil, err + } + columnMeta.Schema = tableSchema columnMeta.Table = tableName columnMeta.ColumnName = strings.Trim(columnName, "` ") columnMeta.DataType = types.GetSqlDataType(dataType) columnMeta.ColumnType = columnType columnMeta.ColumnKey = columnKey - columnMeta.ColumnTypeInfo = *columnTypes[i] if strings.ToLower(isNullable) == "yes" { columnMeta.IsNullable = 1 } else { @@ -144,7 +143,6 @@ func (m *mysqlTrigger) getColumns(ctx context.Context, dbName string, table stri columnMeta.Autoincrement = strings.Contains(strings.ToLower(extra), "auto_increment") columnMetas = append(columnMetas, columnMeta) - i++ } if len(columnMetas) == 0 { @@ -154,73 +152,34 @@ func (m *mysqlTrigger) getColumns(ctx context.Context, dbName string, table stri return columnMetas, nil } -func buildColumnType(rowsi driver.Rows) []*types.ColumnType { - names := rowsi.Columns() - - list := make([]*types.ColumnType, len(names)) - for i := range list { - ci := &types.ColumnType{ - Name: names[i], - } - list[i] = ci - - if prop, ok := rowsi.(driver.RowsColumnTypeScanType); ok { - ci.ScanType = prop.ColumnTypeScanType(i) - } else { - ci.ScanType = reflect.TypeOf(new(any)).Elem() - } - if prop, ok := rowsi.(driver.RowsColumnTypeDatabaseTypeName); ok { - ci.DatabaseType = prop.ColumnTypeDatabaseTypeName(i) - } - if prop, ok := rowsi.(driver.RowsColumnTypeLength); ok { - ci.Length, ci.HasLength = prop.ColumnTypeLength(i) - } - if prop, ok := rowsi.(driver.RowsColumnTypeNullable); ok { - ci.Nullable, ci.HasNullable = prop.ColumnTypeNullable(i) - } - if prop, ok := rowsi.(driver.RowsColumnTypePrecisionScale); ok { - ci.Precision, ci.Scale, ci.HasPrecisionScale = prop.ColumnTypePrecisionScale(i) - } - } - return list -} - // getIndex get tableMetaIndex -func (m *mysqlTrigger) getIndexes(ctx context.Context, dbName string, tableName string, conn driver.Conn) ([]types.IndexMeta, error) { +func (m *mysqlTrigger) getIndexes(ctx context.Context, dbName string, tableName string, conn *sql.Conn) ([]types.IndexMeta, error) { tableName = executor.DelEscape(tableName, types.DBTypeMySQL) result := make([]types.IndexMeta, 0) - stmt, err := conn.Prepare(indexMetaSql) + stmt, err := conn.PrepareContext(ctx, indexMetaSql) if err != nil { return nil, err } - rowsi, err := stmt.Query([]driver.Value{dbName, tableName}) + rows, err := stmt.Query(dbName, tableName) if err != nil { return nil, err } + defer rows.Close() - defer rowsi.Close() + for rows.Next() { + var ( + indexName string + columnName string + nonUnique int64 + ) - for { - vals := make([]driver.Value, 6) - err = rowsi.Next(vals) - if err == io.EOF { - break - } + err = rows.Scan(&indexName, &columnName, &nonUnique) if err != nil { return nil, err } - var ( - indexName = string(vals[0].([]uint8)) - columnName = string(vals[1].([]uint8)) - nonUnique = vals[2].(int64) - //indexType = string(vals[3].([]uint8)) - //collation = string(vals[4].([]uint8)) - //cardinality = int(vals[6].([]uint8)) - ) - index := types.IndexMeta{ Schema: dbName, Table: tableName, @@ -242,6 +201,7 @@ func (m *mysqlTrigger) getIndexes(ctx context.Context, dbName string, tableName } result = append(result, index) + } return result, nil diff --git a/pkg/datasource/sql/db.go b/pkg/datasource/sql/db.go index 8ce74e80e..6d9c4bf6a 100644 --- a/pkg/datasource/sql/db.go +++ b/pkg/datasource/sql/db.go @@ -18,14 +18,11 @@ package sql import ( - "context" - gosql "database/sql" - "database/sql/driver" - - "github.com/seata/seata-go/pkg/datasource/sql/undo" + "database/sql" "github.com/seata/seata-go/pkg/datasource/sql/datasource" "github.com/seata/seata-go/pkg/datasource/sql/types" + "github.com/seata/seata-go/pkg/datasource/sql/undo" "github.com/seata/seata-go/pkg/protocol/branch" ) @@ -55,9 +52,15 @@ func withDBType(dt types.DBType) dbOption { } } -func withTarget(source *gosql.DB) dbOption { +func withTarget(source *sql.DB) dbOption { return func(db *DBResource) { - db.target = source + db.db = source + } +} + +func withDBName(dbName string) dbOption { + return func(db *DBResource) { + db.dbName = dbName } } @@ -85,10 +88,9 @@ type DBResource struct { resourceID string // conf conf seataServerConfig - // target - target *gosql.DB - // conn - conn driver.Conn + // db + db *sql.DB + dbName string // dbType dbType types.DBType // undoLogMgr @@ -98,17 +100,22 @@ type DBResource struct { } func (db *DBResource) init() error { - mgr := datasource.GetDataSourceManager(db.GetBranchType()) - metaCache, err := mgr.CreateTableMetaCache(context.Background(), db.resourceID, db.dbType, db.target) - if err != nil { - return err - } - - db.metaCache = metaCache - return nil } +// todo do not put meta data to rm +//func (db *DBResource) init() error { +// mgr := datasource.GetDataSourceManager(db.GetBranchType()) +// metaCache, err := mgr.CreateTableMetaCache(context.Background(), db.resourceID, db.dbType, db.db) +// if err != nil { +// return err +// } +// +// db.metaCache = metaCache +// +// return nil +//} + func (db *DBResource) GetResourceGroupId() string { return db.groupID } @@ -120,3 +127,16 @@ func (db *DBResource) GetResourceId() string { func (db *DBResource) GetBranchType() branch.BranchType { return db.conf.BranchType } + +type SqlDBProxy struct { + db *sql.DB + dbName string +} + +func (s *SqlDBProxy) GetDB() *sql.DB { + return s.db +} + +func (s *SqlDBProxy) GetDBName() string { + return s.dbName +} diff --git a/pkg/datasource/sql/driver.go b/pkg/datasource/sql/driver.go index d2df87ada..6ce0b703b 100644 --- a/pkg/datasource/sql/driver.go +++ b/pkg/datasource/sql/driver.go @@ -24,6 +24,8 @@ import ( "fmt" "strings" + mysql2 "github.com/seata/seata-go/pkg/datasource/sql/datasource/mysql" + "github.com/go-sql-driver/mysql" "github.com/seata/seata-go/pkg/datasource/sql/datasource" "github.com/seata/seata-go/pkg/datasource/sql/types" @@ -101,7 +103,7 @@ type seataDriver struct { func (d *seataDriver) Open(name string) (driver.Conn, error) { conn, err := d.target.Open(name) if err != nil { - log.Errorf("open target connection: %w", err) + log.Errorf("open db connection: %w", err) return nil, err } @@ -161,12 +163,14 @@ func getOpenConnectorProxy(connector driver.Connector, dbType types.DBType, db * return connector, err } + cfg, _ := mysql.ParseDSN(dataSourceName) options := []dbOption{ withGroupID(conf.GroupID), withResourceID(parseResourceID(dataSourceName)), withConf(conf), withTarget(db), withDBType(dbType), + withDBName(cfg.DBName), } res, err := newResource(options...) @@ -175,11 +179,11 @@ func getOpenConnectorProxy(connector driver.Connector, dbType types.DBType, db * return nil, err } + datasource.RegisterTableCache(types.DBTypeMySQL, mysql2.NewTableMetaInstance(db)) if err = datasource.GetDataSourceManager(conf.BranchType).RegisterResource(res); err != nil { log.Errorf("regisiter resource: %w", err) return nil, err } - cfg, _ := mysql.ParseDSN(dataSourceName) return &seataConnector{ res: res, @@ -222,12 +226,9 @@ func loadConfig() *seataServerConfig { func parseResourceID(dsn string) string { i := strings.Index(dsn, "?") - res := dsn - if i > 0 { res = dsn[:i] } - return strings.ReplaceAll(res, ",", "|") } diff --git a/pkg/datasource/sql/driver_test.go b/pkg/datasource/sql/driver_test.go index 766a3f5e9..8e6715b49 100644 --- a/pkg/datasource/sql/driver_test.go +++ b/pkg/datasource/sql/driver_test.go @@ -21,21 +21,19 @@ import ( "context" "database/sql" "database/sql/driver" + "github.com/seata/seata-go/pkg/rm" "reflect" "testing" "github.com/golang/mock/gomock" - "github.com/seata/seata-go/pkg/datasource/sql/datasource" "github.com/seata/seata-go/pkg/datasource/sql/mock" - "github.com/seata/seata-go/pkg/protocol/branch" "github.com/seata/seata-go/pkg/util/reflectx" "github.com/stretchr/testify/assert" ) func initMockResourceManager(t *testing.T, ctrl *gomock.Controller) *mock.MockDataSourceManager { mockResourceMgr := mock.NewMockDataSourceManager(ctrl) - datasource.RegisterResourceManager(branch.BranchTypeAT, mockResourceMgr) - + rm.GetRmCacheInstance().RegisterResourceManager(mockResourceMgr) mockResourceMgr.EXPECT().RegisterResource(gomock.Any()).AnyTimes().Return(nil) mockResourceMgr.EXPECT().CreateTableMetaCache(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil, nil) diff --git a/pkg/datasource/sql/exec/select_for_update_executor.go b/pkg/datasource/sql/exec/select_for_update_executor.go index 5bd8508ce..961d0733f 100644 --- a/pkg/datasource/sql/exec/select_for_update_executor.go +++ b/pkg/datasource/sql/exec/select_for_update_executor.go @@ -32,7 +32,7 @@ import ( "github.com/seata/seata-go/pkg/datasource/sql/types" "github.com/seata/seata-go/pkg/datasource/sql/undo/builder" "github.com/seata/seata-go/pkg/protocol/branch" - "github.com/seata/seata-go/pkg/protocol/message" + "github.com/seata/seata-go/pkg/rm" seatabytes "github.com/seata/seata-go/pkg/util/bytes" "github.com/seata/seata-go/pkg/util/log" ) @@ -124,13 +124,11 @@ func (s SelectForUpdateExecutor) ExecWithNamedValue(ctx context.Context, execCtx break } // check global lock - lockable, err := datasource.GetDataSourceManager(branch.BranchTypeAT).LockQuery(ctx, message.GlobalLockQueryRequest{ - BranchRegisterRequest: message.BranchRegisterRequest{ - Xid: execCtx.TxCtx.XID, - BranchType: branch.BranchTypeAT, - ResourceId: execCtx.TxCtx.ResourceID, - LockKey: lockKey, - }, + lockable, err := datasource.GetDataSourceManager(branch.BranchTypeAT).LockQuery(ctx, rm.LockQueryParam{ + Xid: execCtx.TxCtx.XID, + BranchType: branch.BranchTypeAT, + ResourceId: execCtx.TxCtx.ResourceID, + LockKeys: lockKey, }) // if obtained global lock @@ -235,13 +233,11 @@ func (s SelectForUpdateExecutor) ExecWithValue(ctx context.Context, execCtx *typ break } // check global lock - lockable, err := datasource.GetDataSourceManager(branch.BranchTypeAT).LockQuery(ctx, message.GlobalLockQueryRequest{ - BranchRegisterRequest: message.BranchRegisterRequest{ - Xid: execCtx.TxCtx.XID, - BranchType: branch.BranchTypeAT, - ResourceId: execCtx.TxCtx.ResourceID, - LockKey: lockKey, - }, + lockable, err := datasource.GetDataSourceManager(branch.BranchTypeAT).LockQuery(ctx, rm.LockQueryParam{ + Xid: execCtx.TxCtx.XID, + BranchType: branch.BranchTypeAT, + ResourceId: execCtx.TxCtx.ResourceID, + LockKeys: lockKey, }) // has obtained global lock diff --git a/pkg/datasource/sql/mock/mock_datasource_manager.go b/pkg/datasource/sql/mock/mock_datasource_manager.go index 54d1e9211..81fc3c02e 100644 --- a/pkg/datasource/sql/mock/mock_datasource_manager.go +++ b/pkg/datasource/sql/mock/mock_datasource_manager.go @@ -1,22 +1,5 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - // Code generated by MockGen. DO NOT EDIT. -// Source: ../datasource/datasource_manager.go +// Source: datasource_manager.go // Package mock is a generated GoMock package. package mock @@ -25,12 +8,12 @@ import ( context "context" sql "database/sql" reflect "reflect" + sync "sync" gomock "github.com/golang/mock/gomock" datasource "github.com/seata/seata-go/pkg/datasource/sql/datasource" types "github.com/seata/seata-go/pkg/datasource/sql/types" branch "github.com/seata/seata-go/pkg/protocol/branch" - message "github.com/seata/seata-go/pkg/protocol/message" rm "github.com/seata/seata-go/pkg/rm" ) @@ -58,62 +41,62 @@ func (m *MockDataSourceManager) EXPECT() *MockDataSourceManagerMockRecorder { } // BranchCommit mocks base method. -func (m *MockDataSourceManager) BranchCommit(ctx context.Context, req message.BranchCommitRequest) (branch.BranchStatus, error) { +func (m *MockDataSourceManager) BranchCommit(ctx context.Context, resource rm.BranchResource) (branch.BranchStatus, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BranchCommit", ctx, req) + ret := m.ctrl.Call(m, "BranchCommit", ctx, resource) ret0, _ := ret[0].(branch.BranchStatus) ret1, _ := ret[1].(error) return ret0, ret1 } // BranchCommit indicates an expected call of BranchCommit. -func (mr *MockDataSourceManagerMockRecorder) BranchCommit(ctx, req interface{}) *gomock.Call { +func (mr *MockDataSourceManagerMockRecorder) BranchCommit(ctx, resource interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchCommit", reflect.TypeOf((*MockDataSourceManager)(nil).BranchCommit), ctx, req) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchCommit", reflect.TypeOf((*MockDataSourceManager)(nil).BranchCommit), ctx, resource) } // BranchRegister mocks base method. -func (m *MockDataSourceManager) BranchRegister(ctx context.Context, req rm.BranchRegisterParam) (int64, error) { +func (m *MockDataSourceManager) BranchRegister(ctx context.Context, param rm.BranchRegisterParam) (int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BranchRegister", ctx, req) + ret := m.ctrl.Call(m, "BranchRegister", ctx, param) ret0, _ := ret[0].(int64) ret1, _ := ret[1].(error) return ret0, ret1 } // BranchRegister indicates an expected call of BranchRegister. -func (mr *MockDataSourceManagerMockRecorder) BranchRegister(ctx, clientId, req interface{}) *gomock.Call { +func (mr *MockDataSourceManagerMockRecorder) BranchRegister(ctx, param interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchRegister", reflect.TypeOf((*MockDataSourceManager)(nil).BranchRegister), ctx, clientId, req) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchRegister", reflect.TypeOf((*MockDataSourceManager)(nil).BranchRegister), ctx, param) } // BranchReport mocks base method. -func (m *MockDataSourceManager) BranchReport(ctx context.Context, req message.BranchReportRequest) error { +func (m *MockDataSourceManager) BranchReport(ctx context.Context, param rm.BranchReportParam) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BranchReport", ctx, req) + ret := m.ctrl.Call(m, "BranchReport", ctx, param) ret0, _ := ret[0].(error) return ret0 } // BranchReport indicates an expected call of BranchReport. -func (mr *MockDataSourceManagerMockRecorder) BranchReport(ctx, req interface{}) *gomock.Call { +func (mr *MockDataSourceManagerMockRecorder) BranchReport(ctx, param interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchReport", reflect.TypeOf((*MockDataSourceManager)(nil).BranchReport), ctx, req) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchReport", reflect.TypeOf((*MockDataSourceManager)(nil).BranchReport), ctx, param) } // BranchRollback mocks base method. -func (m *MockDataSourceManager) BranchRollback(ctx context.Context, req message.BranchRollbackRequest) (branch.BranchStatus, error) { +func (m *MockDataSourceManager) BranchRollback(ctx context.Context, resource rm.BranchResource) (branch.BranchStatus, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BranchRollback", ctx, req) + ret := m.ctrl.Call(m, "BranchRollback", ctx, resource) ret0, _ := ret[0].(branch.BranchStatus) ret1, _ := ret[1].(error) return ret0, ret1 } // BranchRollback indicates an expected call of BranchRollback. -func (mr *MockDataSourceManagerMockRecorder) BranchRollback(ctx, req interface{}) *gomock.Call { +func (mr *MockDataSourceManagerMockRecorder) BranchRollback(ctx, resource interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchRollback", reflect.TypeOf((*MockDataSourceManager)(nil).BranchRollback), ctx, req) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchRollback", reflect.TypeOf((*MockDataSourceManager)(nil).BranchRollback), ctx, resource) } // CreateTableMetaCache mocks base method. @@ -131,33 +114,44 @@ func (mr *MockDataSourceManagerMockRecorder) CreateTableMetaCache(ctx, resID, db return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateTableMetaCache", reflect.TypeOf((*MockDataSourceManager)(nil).CreateTableMetaCache), ctx, resID, dbType, db) } -// GetManagedResources mocks base method. -func (m *MockDataSourceManager) GetManagedResources() map[string]rm.Resource { +// GetBranchType mocks base method. +func (m *MockDataSourceManager) GetBranchType() branch.BranchType { + return branch.BranchTypeAT +} + +// GetBranchType indicates an expected call of GetBranchType. +func (mr *MockDataSourceManagerMockRecorder) GetBranchType() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBranchType", reflect.TypeOf((*MockDataSourceManager)(nil).GetBranchType)) +} + +// GetCachedResources mocks base method. +func (m *MockDataSourceManager) GetCachedResources() *sync.Map { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetManagedResources") - ret0, _ := ret[0].(map[string]rm.Resource) + ret := m.ctrl.Call(m, "GetCachedResources") + ret0, _ := ret[0].(*sync.Map) return ret0 } -// GetManagedResources indicates an expected call of GetManagedResources. -func (mr *MockDataSourceManagerMockRecorder) GetManagedResources() *gomock.Call { +// GetCachedResources indicates an expected call of GetCachedResources. +func (mr *MockDataSourceManagerMockRecorder) GetCachedResources() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetManagedResources", reflect.TypeOf((*MockDataSourceManager)(nil).GetManagedResources)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCachedResources", reflect.TypeOf((*MockDataSourceManager)(nil).GetCachedResources)) } // LockQuery mocks base method. -func (m *MockDataSourceManager) LockQuery(ctx context.Context, req message.GlobalLockQueryRequest) (bool, error) { +func (m *MockDataSourceManager) LockQuery(ctx context.Context, param rm.LockQueryParam) (bool, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LockQuery", ctx, req) + ret := m.ctrl.Call(m, "LockQuery", ctx, param) ret0, _ := ret[0].(bool) ret1, _ := ret[1].(error) return ret0, ret1 } // LockQuery indicates an expected call of LockQuery. -func (mr *MockDataSourceManagerMockRecorder) LockQuery(ctx, req interface{}) *gomock.Call { +func (mr *MockDataSourceManagerMockRecorder) LockQuery(ctx, param interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LockQuery", reflect.TypeOf((*MockDataSourceManager)(nil).LockQuery), ctx, req) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LockQuery", reflect.TypeOf((*MockDataSourceManager)(nil).LockQuery), ctx, param) } // RegisterResource mocks base method. @@ -226,18 +220,18 @@ func (mr *MockTableMetaCacheMockRecorder) Destroy() *gomock.Call { } // GetTableMeta mocks base method. -func (m *MockTableMetaCache) GetTableMeta(table string) (types.TableMeta, error) { +func (m *MockTableMetaCache) GetTableMeta(ctx context.Context, dbName, table string) (*types.TableMeta, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTableMeta", table) - ret0, _ := ret[0].(types.TableMeta) + ret := m.ctrl.Call(m, "GetTableMeta", ctx, dbName, table) + ret0, _ := ret[0].(*types.TableMeta) ret1, _ := ret[1].(error) return ret0, ret1 } // GetTableMeta indicates an expected call of GetTableMeta. -func (mr *MockTableMetaCacheMockRecorder) GetTableMeta(table interface{}) *gomock.Call { +func (mr *MockTableMetaCacheMockRecorder) GetTableMeta(ctx, dbName, table interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTableMeta", reflect.TypeOf((*MockTableMetaCache)(nil).GetTableMeta), table) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTableMeta", reflect.TypeOf((*MockTableMetaCache)(nil).GetTableMeta), ctx, dbName, table) } // Init mocks base method. diff --git a/pkg/datasource/sql/tx.go b/pkg/datasource/sql/tx.go index 0b01ff066..a77e34b3d 100644 --- a/pkg/datasource/sql/tx.go +++ b/pkg/datasource/sql/tx.go @@ -24,7 +24,6 @@ import ( "github.com/seata/seata-go/pkg/datasource/sql/datasource" "github.com/seata/seata-go/pkg/protocol/branch" - "github.com/seata/seata-go/pkg/protocol/message" "github.com/seata/seata-go/pkg/rm" "github.com/seata/seata-go/pkg/util/log" @@ -175,11 +174,10 @@ func (tx *Tx) report(success bool) error { return nil } status := getStatus(success) - request := message.BranchReportRequest{ - Xid: tx.tranCtx.XID, - BranchId: int64(tx.tranCtx.BranchID), - ResourceId: tx.tranCtx.ResourceID, - Status: status, + request := rm.BranchReportParam{ + Xid: tx.tranCtx.XID, + BranchId: int64(tx.tranCtx.BranchID), + Status: status, } dataSourceManager := datasource.GetDataSourceManager(branch.BranchType(tx.tranCtx.TransType)) retry := REPORT_RETRY_COUNT diff --git a/pkg/datasource/sql/types/sql.go b/pkg/datasource/sql/types/sql.go index 40f60d996..3189ce646 100644 --- a/pkg/datasource/sql/types/sql.go +++ b/pkg/datasource/sql/types/sql.go @@ -120,3 +120,69 @@ func (s SQLType) MarshalText() (text []byte, err error) { } return []byte("INVALID_SQLTYPE"), nil } + +func (s *SQLType) UnmarshalText(b []byte) error { + switch string(b) { + case "SELECT": + *s = SQLTypeSelect + case "INSERT": + *s = SQLTypeInsert + case "UPDATE": + *s = SQLTypeUpdate + case "DELETE": + *s = SQLTypeDelete + case "SELECT_FOR_UPDATE": + *s = SQLTypeSelectForUpdate + case "REPLACE": + *s = SQLTypeReplace + case "TRUNCATE": + *s = SQLTypeTruncate + case "CREATE": + *s = SQLTypeCreate + case "DROP": + *s = SQLTypeDrop + case "LOAD": + *s = SQLTypeLoad + case "MERGE": + *s = SQLTypeMerge + case "SHOW": + *s = SQLTypeShow + case "ALTER": + *s = SQLTypeAlter + case "RENAME": + *s = SQLTypeRename + case "DUMP": + *s = SQLTypeDump + case "DEBUG": + *s = SQLTypeDebug + case "EXPLAIN": + *s = SQLTypeExplain + case "DESC": + *s = SQLTypeDesc + case "SET": + *s = SQLTypeSet + case "RELOAD": + *s = SQLTypeReload + case "SELECT_UNION": + *s = SQLTypeSelectUnion + case "CREATE_TABLE": + *s = SQLTypeCreateTable + case "DROP_TABLE": + *s = SQLTypeDropTable + case "ALTER_TABLE": + *s = SQLTypeAlterTable + case "SELECT_FROM_UPDATE": + *s = SQLTypeSelectFromUpdate + case "MULTI_DELETE": + *s = SQLTypeMultiDelete + case "MULTI_UPDATE": + *s = SQLTypeMultiUpdate + case "CREATE_INDEX": + *s = SQLTypeCreateIndex + case "DROP_INDEX": + *s = SQLTypeDropIndex + case "MULTI": + *s = SQLTypeMulti + } + return nil +} diff --git a/pkg/datasource/sql/undo/base/undo.go b/pkg/datasource/sql/undo/base/undo.go index 7b9b6b7c9..2b0e929b2 100644 --- a/pkg/datasource/sql/undo/base/undo.go +++ b/pkg/datasource/sql/undo/base/undo.go @@ -22,17 +22,13 @@ import ( "database/sql" "database/sql/driver" "encoding/json" - "fmt" "strconv" "strings" - "github.com/seata/seata-go/pkg/util/convert" - "github.com/arana-db/parser/mysql" "github.com/pkg/errors" - "github.com/seata/seata-go/pkg/constant" - dataSourceMysql "github.com/seata/seata-go/pkg/datasource/sql/datasource/mysql" + "github.com/seata/seata-go/pkg/datasource/sql/datasource" "github.com/seata/seata-go/pkg/datasource/sql/types" "github.com/seata/seata-go/pkg/datasource/sql/undo" "github.com/seata/seata-go/pkg/datasource/sql/undo/factor" @@ -47,6 +43,7 @@ var ( var ( checkUndoLogTableExistSql = "SELECT 1 FROM " + constant.UndoLogTableName + " LIMIT 1" insertUndoLogSql = "INSERT INTO " + constant.UndoLogTableName + "(branch_id,xid,context,rollback_info,log_status,log_created,log_modified) VALUES (?, ?, ?, ?, ?, now(6), now(6))" + selectUndoLogSql = "SELECT `branch_id`,`xid`,`context`,`rollback_info`,`log_status` FROM " + constant.UndoLogTableName + " WHERE " + constant.UndoLogBranchXid + " = ? AND " + constant.UndoLogXid + " = ? FOR UPDATE" ) const ( @@ -60,9 +57,6 @@ const ( CheckUndoLogTableExistSql = "SELECT 1 FROM " + constant.UndoLogTableName + " LIMIT 1" // DeleteUndoLogSql delete undo log DeleteUndoLogSql = constant.DeleteFrom + constant.UndoLogTableName + " WHERE " + constant.UndoLogBranchXid + " = ? AND " + constant.UndoLogXid + " = ?" - - // UndoLog Todo get from config - Seata = "seata" ) // undo log status @@ -211,17 +205,13 @@ func (m *BaseUndoLogManager) FlushUndoLog(tranCtx *types.TransactionContext, con } // RunUndo undo sql -func (m *BaseUndoLogManager) RunUndo(ctx context.Context, xid string, branchID int64, conn driver.Conn) error { +func (m *BaseUndoLogManager) RunUndo(ctx context.Context, xid string, branchID int64, conn *sql.DB, dbName string) error { return nil } // Undo undo sql -func (m *BaseUndoLogManager) Undo(ctx context.Context, dbType types.DBType, - xid string, branchID int64, conn driver.Conn) error { - - var branchUndoLogs []undo.BranchUndoLog - - tx, err := conn.Begin() +func (m *BaseUndoLogManager) Undo(ctx context.Context, dbType types.DBType, xid string, branchID int64, db *sql.DB, dbName string) error { + tx, err := db.Begin() if err != nil { return err } @@ -229,128 +219,96 @@ func (m *BaseUndoLogManager) Undo(ctx context.Context, dbType types.DBType, defer func() { if err != nil { if err = tx.Rollback(); err != nil { - log.Errorf("[RunUndo] rollback fail, xid: %s, branchID:%s err:%v", xid, branchID, err) + log.Errorf("rollback fail, xid: %s, branchID:%s err:%v", xid, branchID, err) return } } }() - selectUndoLogSql := "SELECT `log_status`,`context`,`rollback_info` FROM " + constant.UndoLogTableName + " WHERE " + constant.UndoLogBranchXid + " = ? AND " + constant.UndoLogXid + " = ? FOR UPDATE" - stmt, err := conn.Prepare(selectUndoLogSql) + conn, err := db.Conn(ctx) if err != nil { - log.Errorf("[Undo] prepare sql fail, err: %v", err) return err } - + stmt, err := conn.PrepareContext(ctx, selectUndoLogSql) + if err != nil { + log.Errorf("prepare sql fail, err: %v", err) + return err + } defer func() { if err = stmt.Close(); err != nil { - log.Errorf("[RunUndo] stmt close fail, xid: %s, branchID:%s err:%v", xid, branchID, err) + log.Errorf("stmt close fail, xid: %s, branchID:%s err:%v", xid, branchID, err) return } }() - rows, err := stmt.Query([]driver.Value{branchID, xid}) + rows, err := stmt.Query(branchID, xid) if err != nil { - log.Errorf("[Undo] query sql fail, err: %v", err) + log.Errorf("query sql fail, err: %v", err) return err } - - var ( - //logStatus string - //contextx string - //rollbackInfo []byte - logStatus sql.NullInt32 - contextx sql.NullString - rollbackInfo sql.RawBytes - ) - vals := make([]driver.Value, 5) - dest := []interface{}{&logStatus, &contextx, &rollbackInfo} - - exist := false - for { - if err = rows.Next(vals); err != nil { - break - } - - exist = true - - for i, sv := range vals { - err := convert.ConvertAssignRows(dest[i], sv) - if err != nil { - return fmt.Errorf(`sql: Scan error on column index %d, name %q: %v`, i, rows.Columns()[i], err) - } + defer func() { + if err = rows.Close(); err != nil { + log.Errorf("rows close fail, xid: %s, branchID:%s err:%v", xid, branchID, err) + return } + }() - /*if err = rows.Scan(&logStatus, &contextx, &rollbackInfo); err != nil { - log.Errorf("[Undo] get log status fail, err: %v", err) + var undoLogRecords []undo.UndologRecord + for rows.Next() { + var record undo.UndologRecord + err = rows.Scan(&record.BranchID, &record.XID, &record.Context, &record.RollbackInfo, &record.LogStatus) + if err != nil { return err } + undoLogRecords = append(undoLogRecords, record) + } - state, _ := strconv.Atoi(logStatus)*/ - - // check if it can undo - if !m.canUndo(logStatus.Int32) { - return nil + for _, record := range undoLogRecords { + if !record.CanUndo() { + log.Infof("xid %v branch %v, ignore %v undo_log", record.XID, record.BranchID, record.LogStatus) + continue } - // Todo pr 242 调用对应的 parser 方法 - /*contextMap := m.parseContext(context) - rollbackInfo := m.getRollbackInfo(rollbackInfo, contextMap) - serializer := m.getSerializer(contextMap) - branchUndoLog = parser.decode(rollbackInfo); - */ - - // Todo 替换成 parser 解析器解析 + // todo use serializer and decode var branchUndoLog undo.BranchUndoLog - if cErr := json.Unmarshal(rollbackInfo, &branchUndoLog); cErr != nil { - return cErr + if err = json.Unmarshal(record.RollbackInfo, &branchUndoLog); err != nil { + return err } - branchUndoLogs = append(branchUndoLogs, branchUndoLog) - } - - /*if err = rows.Err(); err != nil { - return err - }*/ - - if err = rows.Close(); err != nil { - return err - } - - for _, branchUndoLog := range branchUndoLogs { sqlUndoLogs := branchUndoLog.Logs - if len(sqlUndoLogs) > 1 { - branchUndoLog.Reverse() + if len(sqlUndoLogs) == 0 { + return nil } + branchUndoLog.Reverse() for _, undoLog := range sqlUndoLogs { - tableMeta, cErr := dataSourceMysql.GetTableMetaInstance().GetTableMeta(ctx, Seata, undoLog.TableName, conn) - if cErr != nil { - log.Errorf("[Undo] get table meta fail, err: %v", cErr) - return cErr + tableMeta, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, dbName, undoLog.TableName) + if err != nil { + log.Errorf("get table meta fail, err: %v", err) + return err } undoLog.SetTableMeta(*tableMeta) - undoExecutor, cErr := factor.GetUndoExecutor(dbType, undoLog) - if cErr != nil { - log.Errorf("[Undo] get undo executor, err: %v", cErr) - return cErr + undoExecutor, err := factor.GetUndoExecutor(dbType, undoLog) + if err != nil { + log.Errorf("get undo executor, err: %v", err) + return err } if err = undoExecutor.ExecuteOn(ctx, dbType, undoLog, conn); err != nil { - log.Errorf("[Undo] execute on fail, err: %v", err) + log.Errorf("execute on fail, err: %v", err) return err } } } - if exist { - if err = m.DeleteUndoLog(ctx, xid, branchID, conn); err != nil { - log.Errorf("[Undo] delete undo log fail, err: %v", err) - return err - } - } + //if exist { + // if err = m.DeleteUndoLog(ctx, xid, branchID, conn); err != nil { + // log.Errorf("[Undo] delete undo log fail, err: %v", err) + // return err + // } + //} // Todo 等 insertLog 合并后加上 insertUndoLogWithGlobalFinished 功能 /*else { @@ -490,7 +448,6 @@ func (m *BaseUndoLogManager) getSerializer(undoLogContext map[string]string) (se if undoLogContext == nil { return } - serializer, _ = undoLogContext[SerializerKey] return } diff --git a/pkg/datasource/sql/undo/builder/mysql_update_undo_log_builder.go b/pkg/datasource/sql/undo/builder/mysql_update_undo_log_builder.go index ba0312991..87ca69a96 100644 --- a/pkg/datasource/sql/undo/builder/mysql_update_undo_log_builder.go +++ b/pkg/datasource/sql/undo/builder/mysql_update_undo_log_builder.go @@ -23,11 +23,10 @@ import ( "fmt" "strings" - "github.com/arana-db/parser/model" - "github.com/seata/seata-go/pkg/datasource/sql/datasource/mysql" - "github.com/arana-db/parser/ast" "github.com/arana-db/parser/format" + "github.com/arana-db/parser/model" + "github.com/seata/seata-go/pkg/datasource/sql/datasource" "github.com/seata/seata-go/pkg/datasource/sql/types" "github.com/seata/seata-go/pkg/datasource/sql/undo" @@ -72,7 +71,7 @@ func (u *MySQLUpdateUndoLogBuilder) BeforeImage(ctx context.Context, execCtx *ty } tableName, _ := execCtx.ParseContext.GteTableName() - metaData, err := mysql.GetTableMetaInstance().GetTableMeta(ctx, execCtx.DBName, tableName, execCtx.Conn) + metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, execCtx.DBName, tableName) if err != nil { return nil, err } @@ -112,7 +111,7 @@ func (u *MySQLUpdateUndoLogBuilder) AfterImage(ctx context.Context, execCtx *typ } tableName, _ := execCtx.ParseContext.GteTableName() - metaData, err := mysql.GetTableMetaInstance().GetTableMeta(ctx, execCtx.DBName, tableName, execCtx.Conn) + metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, execCtx.DBName, tableName) if err != nil { return nil, err } @@ -172,7 +171,7 @@ func (u *MySQLUpdateUndoLogBuilder) buildBeforeImageSQL(ctx context.Context, exe // select indexes columns tableName, _ := execCtx.ParseContext.GteTableName() - metaData, err := mysql.GetTableMetaInstance().GetTableMeta(ctx, execCtx.DBName, tableName, execCtx.Conn) + metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, execCtx.DBName, tableName) if err != nil { return "", nil, err } diff --git a/pkg/datasource/sql/undo/builder/mysql_update_undo_log_builder_test.go b/pkg/datasource/sql/undo/builder/mysql_update_undo_log_builder_test.go index d0e742c64..11bb44035 100644 --- a/pkg/datasource/sql/undo/builder/mysql_update_undo_log_builder_test.go +++ b/pkg/datasource/sql/undo/builder/mysql_update_undo_log_builder_test.go @@ -20,6 +20,7 @@ package builder import ( "context" "database/sql/driver" + "github.com/seata/seata-go/pkg/datasource/sql/datasource" "reflect" "testing" @@ -39,12 +40,8 @@ func TestBuildSelectSQLByUpdate(t *testing.T) { var ( builder = MySQLUpdateUndoLogBuilder{} ) - //stub := gomonkey.ApplyMethod(reflect.TypeOf(mysql.GetTableMetaInstance()), "GetTableMeta", func(_ *datasource.TableMetaCache, ctx context.Context, dbName, tableName string, conn driver.Conn) (*types.TableMeta, error) { - // return &types.TableMeta{ - // - // }, nil - //}) - stub := gomonkey.ApplyMethod(reflect.TypeOf(mysql.GetTableMetaInstance()), "GetTableMeta", func(_ *mysql.TableMetaCache, ctx context.Context, dbName, tableName string, conn driver.Conn) (*types.TableMeta, error) { + + stub := gomonkey.ApplyMethod(reflect.TypeOf(datasource.GetTableCache(types.DBTypeMySQL)), "GetTableMeta", func(_ *mysql.TableMetaCache, ctx context.Context, dbName, tableName string, conn driver.Conn) (*types.TableMeta, error) { return &types.TableMeta{ Indexs: map[string]types.IndexMeta{ "id": { diff --git a/pkg/datasource/sql/undo/executor/executor.go b/pkg/datasource/sql/undo/executor/executor.go index 945aad542..79e383226 100644 --- a/pkg/datasource/sql/undo/executor/executor.go +++ b/pkg/datasource/sql/undo/executor/executor.go @@ -20,7 +20,6 @@ package executor import ( "context" "database/sql" - "database/sql/driver" "github.com/seata/seata-go/pkg/datasource/sql/types" "github.com/seata/seata-go/pkg/datasource/sql/undo" @@ -32,7 +31,7 @@ type BaseExecutor struct { } // ExecuteOn -func (b *BaseExecutor) ExecuteOn(ctx context.Context, dbType types.DBType, sqlUndoLog undo.SQLUndoLog, conn driver.Conn) error { +func (b *BaseExecutor) ExecuteOn(ctx context.Context, dbType types.DBType, sqlUndoLog undo.SQLUndoLog, conn *sql.Conn) error { // check data if valid return nil } diff --git a/pkg/datasource/sql/undo/executor/mysql_undo_insert_executor.go b/pkg/datasource/sql/undo/executor/mysql_undo_insert_executor.go index 73ac43e42..3a9661f74 100644 --- a/pkg/datasource/sql/undo/executor/mysql_undo_insert_executor.go +++ b/pkg/datasource/sql/undo/executor/mysql_undo_insert_executor.go @@ -19,7 +19,7 @@ package executor import ( "context" - "database/sql/driver" + "database/sql" "errors" "fmt" @@ -38,7 +38,7 @@ func NewMySQLUndoInsertExecutor() *MySQLUndoInsertExecutor { // ExecuteOn execute insert undo logic func (m *MySQLUndoInsertExecutor) ExecuteOn(ctx context.Context, dbType types.DBType, - sqlUndoLog undo.SQLUndoLog, conn driver.Conn) error { + sqlUndoLog undo.SQLUndoLog, conn *sql.Conn) error { if err := m.BaseExecutor.ExecuteOn(ctx, dbType, sqlUndoLog, conn); err != nil { return err @@ -47,7 +47,7 @@ func (m *MySQLUndoInsertExecutor) ExecuteOn(ctx context.Context, dbType types.DB // build delete sql undoSql, _ := m.buildUndoSQL(dbType, sqlUndoLog) - stmt, err := conn.Prepare(undoSql) + stmt, err := conn.PrepareContext(ctx, undoSql) if err != nil { return err } @@ -62,7 +62,7 @@ func (m *MySQLUndoInsertExecutor) ExecuteOn(ctx context.Context, dbType types.DB } } - if _, err = stmt.Exec([]driver.Value{pkValueList}); err != nil { + if _, err = stmt.Exec(pkValueList); err != nil { return err } } diff --git a/pkg/datasource/sql/undo/executor/mysql_undo_update_executor.go b/pkg/datasource/sql/undo/executor/mysql_undo_update_executor.go index 8afdd9a7f..b0329b914 100644 --- a/pkg/datasource/sql/undo/executor/mysql_undo_update_executor.go +++ b/pkg/datasource/sql/undo/executor/mysql_undo_update_executor.go @@ -19,7 +19,7 @@ package executor import ( "context" - "database/sql/driver" + "database/sql" "fmt" "strings" @@ -36,13 +36,12 @@ func NewMySQLUndoUpdateExecutor() *MySQLUndoUpdateExecutor { return &MySQLUndoUpdateExecutor{} } -func (m *MySQLUndoUpdateExecutor) ExecuteOn(ctx context.Context, dbType types.DBType, - sqlUndoLog undo.SQLUndoLog, conn driver.Conn) error { +func (m *MySQLUndoUpdateExecutor) ExecuteOn(ctx context.Context, dbType types.DBType, sqlUndoLog undo.SQLUndoLog, conn *sql.Conn) error { //m.BaseExecutor.ExecuteOn(ctx, dbType, sqlUndoLog, conn) undoSql, _ := m.buildUndoSQL(dbType, sqlUndoLog) - stmt, err := conn.Prepare(undoSql) + stmt, err := conn.PrepareContext(ctx, undoSql) if err != nil { return err } @@ -66,7 +65,7 @@ func (m *MySQLUndoUpdateExecutor) ExecuteOn(ctx context.Context, dbType types.DB undoValues = append(undoValues, col.Value) } - if _, err = stmt.Exec([]driver.Value{undoValues}); err != nil { + if _, err = stmt.Exec(undoValues); err != nil { return err } } diff --git a/pkg/datasource/sql/undo/factor/undo_executor_factory.go b/pkg/datasource/sql/undo/factor/undo_executor_factory.go index 95e712a24..2c4111c83 100644 --- a/pkg/datasource/sql/undo/factor/undo_executor_factory.go +++ b/pkg/datasource/sql/undo/factor/undo_executor_factory.go @@ -38,7 +38,7 @@ func GetUndoExecutor(dbType types.DBType, sqlUndoLog undo.SQLUndoLog) (res undo. case types.SQLTypeDelete: res = undoExecutorHolder.GetDeleteExecutor(sqlUndoLog) case types.SQLTypeUpdate: - res = undoExecutorHolder.GetDeleteExecutor(sqlUndoLog) + res = undoExecutorHolder.GetUpdateExecutor(sqlUndoLog) default: return nil, fmt.Errorf("sql type: %d not support", sqlUndoLog.SQLType) } diff --git a/pkg/datasource/sql/undo/mysql/undo.go b/pkg/datasource/sql/undo/mysql/undo.go index fc62b8702..67b8942a6 100644 --- a/pkg/datasource/sql/undo/mysql/undo.go +++ b/pkg/datasource/sql/undo/mysql/undo.go @@ -57,8 +57,8 @@ func (m *undoLogManager) FlushUndoLog(tranCtx *types.TransactionContext, conn dr } // RunUndo undo sql -func (m *undoLogManager) RunUndo(ctx context.Context, xid string, branchID int64, conn driver.Conn) error { - return m.Base.Undo(ctx, m.DBType(), xid, branchID, conn) +func (m *undoLogManager) RunUndo(ctx context.Context, xid string, branchID int64, db *sql.DB, dbName string) error { + return m.Base.Undo(ctx, m.DBType(), xid, branchID, db, dbName) } // DBType diff --git a/pkg/datasource/sql/undo/undo.go b/pkg/datasource/sql/undo/undo.go index 633e19416..e7cc08b86 100644 --- a/pkg/datasource/sql/undo/undo.go +++ b/pkg/datasource/sql/undo/undo.go @@ -72,7 +72,7 @@ type UndoLogManager interface { //FlushUndoLog FlushUndoLog(tranCtx *types.TransactionContext, conn driver.Conn) error // RunUndo - RunUndo(ctx context.Context, xid string, branchID int64, conn driver.Conn) error + RunUndo(ctx context.Context, xid string, branchID int64, conn *sql.DB, dbName string) error // DBType DBType() types.DBType // HasUndoLogTable @@ -107,8 +107,12 @@ type UndologRecord struct { Context []byte `json:"context"` RollbackInfo []byte `json:"rollbackInfo"` LogStatus UndoLogStatue `json:"logStatus"` - LogCreated uint64 `json:"logCreated"` - LogModified uint64 `json:"logModified"` + LogCreated []byte `json:"logCreated"` + LogModified []byte `json:"logModified"` +} + +func (u *UndologRecord) CanUndo() bool { + return u.LogStatus == UndoLogStatueNormnal } // BranchUndoLog diff --git a/pkg/datasource/sql/undo/undo_executor.go b/pkg/datasource/sql/undo/undo_executor.go index ab003f152..1728e6b0b 100644 --- a/pkg/datasource/sql/undo/undo_executor.go +++ b/pkg/datasource/sql/undo/undo_executor.go @@ -19,11 +19,11 @@ package undo import ( "context" - "database/sql/driver" + "database/sql" "github.com/seata/seata-go/pkg/datasource/sql/types" ) type UndoExecutor interface { - ExecuteOn(ctx context.Context, dbType types.DBType, sqlUndoLog SQLUndoLog, conn driver.Conn) error + ExecuteOn(ctx context.Context, dbType types.DBType, sqlUndoLog SQLUndoLog, conn *sql.Conn) error } diff --git a/pkg/datasource/sql/undo_test.go b/pkg/datasource/sql/undo_test.go index 5aef7148d..497add99b 100644 --- a/pkg/datasource/sql/undo_test.go +++ b/pkg/datasource/sql/undo_test.go @@ -118,7 +118,7 @@ func TestUndo(t *testing.T) { _ = sqlConn.Close() }() - if err = manager.RunUndo(ctx, "1", 1, nil); err != nil { + if err = manager.RunUndo(ctx, "1", 1, nil, ""); err != nil { t.Logf("%+v", err) } diff --git a/sample/at/basic/main.go b/sample/at/basic/main.go index d517b9240..2c60f0d8d 100644 --- a/sample/at/basic/main.go +++ b/sample/at/basic/main.go @@ -38,7 +38,6 @@ type OrderTbl struct { func main() { client.Init() initService() - selectData() tm.WithGlobalTx(context.Background(), &tm.TransactionInfo{ Name: "ATSampleLocalGlobalTx", TimeOut: time.Second * 30, diff --git a/sample/at/basic/service.go b/sample/at/basic/service.go index 9ea0da774..6e5ca6542 100644 --- a/sample/at/basic/service.go +++ b/sample/at/basic/service.go @@ -29,7 +29,7 @@ var ( func initService() { var err error - db, err = sql.Open(sql2.SeataATMySQLDriver, "root:123456@tcp(127.0.0.1:3306)/seata_client?multiStatements=true&interpolateParams=true") + db, err = sql.Open(sql2.SeataATMySQLDriver, "root:12345678@tcp(127.0.0.1:3306)/seata_client?multiStatements=true&interpolateParams=true") if err != nil { panic("init service error") }