diff --git a/dm/pkg/schema/tracker.go b/dm/pkg/schema/tracker.go index 343fffc0806..6ac63128732 100644 --- a/dm/pkg/schema/tracker.go +++ b/dm/pkg/schema/tracker.go @@ -35,7 +35,6 @@ import ( "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/store/mockstore" - "github.com/pingcap/tidb/types" "go.uber.org/zap" tcontext "github.com/pingcap/tiflow/dm/pkg/context" @@ -44,6 +43,7 @@ import ( dmterror "github.com/pingcap/tiflow/dm/pkg/terror" "github.com/pingcap/tiflow/dm/pkg/utils" "github.com/pingcap/tiflow/dm/syncer/dbconn" + "github.com/pingcap/tiflow/pkg/sqlmodel" ) const ( @@ -78,9 +78,8 @@ type downstreamTracker struct { // DownstreamTableInfo contains tableinfo and index cache. type DownstreamTableInfo struct { - TableInfo *model.TableInfo // tableInfo which comes from parse create statement syntaxtree - AbsoluteUKIndexInfo *model.IndexInfo // absolute uk index is a pk/uk(not null) - AvailableUKIndexList []*model.IndexInfo // index list which is all uks + TableInfo *model.TableInfo // tableInfo which comes from parse create statement syntaxtree + WhereHandle *sqlmodel.WhereHandle } // NewTracker creates a new tracker. `sessionCfg` will be set as tracker's session variables if specified, or retrieve @@ -430,7 +429,7 @@ func (tr *Tracker) GetSystemVar(name string) (string, bool) { // GetDownStreamTableInfo gets downstream table info. // note. this function will init downstreamTrack's table info. -func (tr *Tracker) GetDownStreamTableInfo(tctx *tcontext.Context, tableID string, originTi *model.TableInfo) (*DownstreamTableInfo, error) { +func (tr *Tracker) GetDownStreamTableInfo(tctx *tcontext.Context, tableID string, originTI *model.TableInfo) (*DownstreamTableInfo, error) { dti, ok := tr.dsTracker.tableInfos[tableID] if !ok { tctx.Logger.Info("Downstream schema tracker init. 
", zap.String("tableID", tableID)) @@ -440,39 +439,15 @@ func (tr *Tracker) GetDownStreamTableInfo(tctx *tcontext.Context, tableID string return nil, err } - dti = GetDownStreamTI(downstreamTI, originTi) + dti = &DownstreamTableInfo{ + TableInfo: downstreamTI, + WhereHandle: sqlmodel.GetWhereHandle(originTI, downstreamTI), + } tr.dsTracker.tableInfos[tableID] = dti } return dti, nil } -// GetAvailableDownStreamUKIndexInfo gets available downstream UK whose data is not null. -// note. this function will not init downstreamTrack. -func (tr *Tracker) GetAvailableDownStreamUKIndexInfo(tableID string, data []interface{}) *model.IndexInfo { - dti := tr.dsTracker.tableInfos[tableID] - - return GetIdentityUKByData(dti, data) -} - -// GetIdentityUKByData gets available downstream UK whose data is not null. -func GetIdentityUKByData(downstreamTI *DownstreamTableInfo, data []interface{}) *model.IndexInfo { - if downstreamTI == nil || len(downstreamTI.AvailableUKIndexList) == 0 { - return nil - } - // func for check data is not null - fn := func(i int) bool { - return data[i] != nil - } - - for _, uk := range downstreamTI.AvailableUKIndexList { - // check uk's column data is not null - if isSpecifiedIndexColumn(uk, fn) { - return uk - } - } - return nil -} - // RemoveDownstreamSchema just remove schema or table in downstreamTrack. func (tr *Tracker) RemoveDownstreamSchema(tctx *tcontext.Context, targetTables []*filter.Table) { if len(targetTables) == 0 { @@ -541,119 +516,3 @@ func (tr *Tracker) initDownStreamSQLModeAndParser(tctx *tcontext.Context) error tr.dsTracker.stmtParser = stmtParser return nil } - -// GetDownStreamTI constructs downstreamTable index cache by tableinfo. 
-func GetDownStreamTI(downstreamTI *model.TableInfo, originTi *model.TableInfo) *DownstreamTableInfo { - var ( - absoluteUKIndexInfo *model.IndexInfo - availableUKIndexList = []*model.IndexInfo{} - hasPk = false - absoluteUKPosition = -1 - ) - - // func for check not null constraint - fn := func(i int) bool { - return mysql.HasNotNullFlag(downstreamTI.Columns[i].Flag) - } - - for i, idx := range downstreamTI.Indices { - if !idx.Primary && !idx.Unique { - continue - } - indexRedirect := redirectIndexKeys(idx, originTi) - if indexRedirect == nil { - continue - } - availableUKIndexList = append(availableUKIndexList, indexRedirect) - if idx.Primary { - absoluteUKIndexInfo = indexRedirect - absoluteUKPosition = i - hasPk = true - } else if absoluteUKIndexInfo == nil && isSpecifiedIndexColumn(idx, fn) { - // second check not null unique key - absoluteUKIndexInfo = indexRedirect - absoluteUKPosition = i - } - } - - // handle pk exceptional case. - // e.g. "create table t(a int primary key, b int)". - if !hasPk { - exPk := redirectIndexKeys(handlePkExCase(downstreamTI), originTi) - if exPk != nil { - absoluteUKIndexInfo = exPk - absoluteUKPosition = len(availableUKIndexList) - availableUKIndexList = append(availableUKIndexList, absoluteUKIndexInfo) - } - } - - // move absoluteUKIndexInfo to the first in availableUKIndexList - if absoluteUKPosition != -1 && len(availableUKIndexList) > 1 { - availableUKIndexList[0], availableUKIndexList[absoluteUKPosition] = availableUKIndexList[absoluteUKPosition], availableUKIndexList[0] - } - - return &DownstreamTableInfo{ - TableInfo: downstreamTI, - AbsoluteUKIndexInfo: absoluteUKIndexInfo, - AvailableUKIndexList: availableUKIndexList, - } -} - -// redirectIndexKeys redirect index's columns offset in origin tableinfo. 
-func redirectIndexKeys(index *model.IndexInfo, originTi *model.TableInfo) *model.IndexInfo { - if index == nil || originTi == nil { - return nil - } - - columns := make([]*model.IndexColumn, 0, len(index.Columns)) - for _, key := range index.Columns { - originColumn := model.FindColumnInfo(originTi.Columns, key.Name.L) - if originColumn == nil { - return nil - } - column := &model.IndexColumn{ - Name: key.Name, - Offset: originColumn.Offset, - Length: key.Length, - } - columns = append(columns, column) - } - return &model.IndexInfo{ - Table: index.Table, - Unique: index.Unique, - Primary: index.Primary, - State: index.State, - Tp: index.Tp, - Columns: columns, - } -} - -// handlePkExCase is handle pk exceptional case. -// e.g. "create table t(a int primary key, b int)". -func handlePkExCase(ti *model.TableInfo) *model.IndexInfo { - if pk := ti.GetPkColInfo(); pk != nil { - return &model.IndexInfo{ - Table: ti.Name, - Unique: true, - Primary: true, - State: model.StatePublic, - Tp: model.IndexTypeBtree, - Columns: []*model.IndexColumn{{ - Name: pk.Name, - Offset: pk.Offset, - Length: types.UnspecifiedLength, - }}, - } - } - return nil -} - -// isSpecifiedIndexColumn checks all of index's columns are matching 'fn'. -func isSpecifiedIndexColumn(index *model.IndexInfo, fn func(i int) bool) bool { - for _, col := range index.Columns { - if !fn(col.Offset) { - return false - } - } - return true -} diff --git a/dm/pkg/schema/tracker_test.go b/dm/pkg/schema/tracker_test.go index dc2d4ad43b9..5df9e70971d 100644 --- a/dm/pkg/schema/tracker_test.go +++ b/dm/pkg/schema/tracker_test.go @@ -768,248 +768,13 @@ func (s *trackerSuite) TestGetDownStreamIndexInfo(c *C) { tableID := "`test`.`test`" - // downstream has no pk/uk - mock.ExpectQuery("SHOW CREATE TABLE " + tableID).WillReturnRows( - sqlmock.NewRows([]string{"Table", "Create Table"}). 
- AddRow("test", "create table t(a int, b int, c varchar(10))")) - dti, err := tracker.GetDownStreamTableInfo(tcontext.Background(), tableID, oriTi) - c.Assert(err, IsNil) - c.Assert(dti, NotNil) - c.Assert(dti.AbsoluteUKIndexInfo, IsNil) - delete(tracker.dsTracker.tableInfos, tableID) - - // downstream has pk(not constraints like "create table t(a int primary key,b int not null)" - mock.ExpectQuery("SHOW CREATE TABLE " + tableID).WillReturnRows( - sqlmock.NewRows([]string{"Table", "Create Table"}). - AddRow("test", "create table t(a int, b int, c varchar(10), PRIMARY KEY (c))")) - dti, err = tracker.GetDownStreamTableInfo(tcontext.Background(), tableID, oriTi) - c.Assert(err, IsNil) - c.Assert(dti, NotNil) - c.Assert(dti.AbsoluteUKIndexInfo, NotNil) - delete(tracker.dsTracker.tableInfos, tableID) - mock.ExpectQuery("SHOW CREATE TABLE " + tableID).WillReturnRows( sqlmock.NewRows([]string{"Table", "Create Table"}). AddRow("test", "create table t(a int primary key, b int, c varchar(10))")) - dti, err = tracker.GetDownStreamTableInfo(tcontext.Background(), tableID, oriTi) - c.Assert(err, IsNil) - c.Assert(dti, NotNil) - c.Assert(dti.AbsoluteUKIndexInfo, NotNil) - delete(tracker.dsTracker.tableInfos, tableID) - - // downstream has composite pks - mock.ExpectQuery("SHOW CREATE TABLE " + tableID).WillReturnRows( - sqlmock.NewRows([]string{"Table", "Create Table"}). - AddRow("test", "create table t(a int, b int, c varchar(10), PRIMARY KEY (a,b))")) - dti, err = tracker.GetDownStreamTableInfo(tcontext.Background(), tableID, oriTi) - c.Assert(err, IsNil) - c.Assert(dti, NotNil) - c.Assert(dti.AbsoluteUKIndexInfo, NotNil) - c.Assert(len(dti.AbsoluteUKIndexInfo.Columns) == 2, IsTrue) - delete(tracker.dsTracker.tableInfos, tableID) - - // downstream has uk(not null) - mock.ExpectQuery("SHOW CREATE TABLE " + tableID).WillReturnRows( - sqlmock.NewRows([]string{"Table", "Create Table"}). 
- AddRow("test", "create table t(a int unique not null, b int, c varchar(10))")) - dti, err = tracker.GetDownStreamTableInfo(tcontext.Background(), tableID, oriTi) - c.Assert(err, IsNil) - c.Assert(dti, NotNil) - c.Assert(dti.AbsoluteUKIndexInfo, NotNil) - c.Assert(dti.AbsoluteUKIndexInfo.Columns, NotNil) - delete(tracker.dsTracker.tableInfos, tableID) - - // downstream has uk(without not null) - mock.ExpectQuery("SHOW CREATE TABLE " + tableID).WillReturnRows( - sqlmock.NewRows([]string{"Table", "Create Table"}). - AddRow("test", "create table t(a int unique, b int, c varchar(10))")) - dti, err = tracker.GetDownStreamTableInfo(tcontext.Background(), tableID, oriTi) - c.Assert(err, IsNil) - c.Assert(dti, NotNil) - c.Assert(dti.AbsoluteUKIndexInfo, IsNil) - c.Assert(dti.AvailableUKIndexList, NotNil) - delete(tracker.dsTracker.tableInfos, tableID) - - // downstream has uks - mock.ExpectQuery("SHOW CREATE TABLE " + tableID).WillReturnRows( - sqlmock.NewRows([]string{"Table", "Create Table"}). - AddRow("test", "create table t(a int unique, b int unique, c varchar(10) unique not null)")) - dti, err = tracker.GetDownStreamTableInfo(tcontext.Background(), tableID, oriTi) - c.Assert(err, IsNil) - c.Assert(dti, NotNil) - c.Assert(dti.AbsoluteUKIndexInfo, NotNil) - c.Assert(len(dti.AvailableUKIndexList) == 3, IsTrue) - c.Assert(dti.AvailableUKIndexList[0] == dti.AbsoluteUKIndexInfo, IsTrue) - delete(tracker.dsTracker.tableInfos, tableID) - - // downstream has pk and uk, pk has priority - mock.ExpectQuery("SHOW CREATE TABLE " + tableID).WillReturnRows( - sqlmock.NewRows([]string{"Table", "Create Table"}). 
- AddRow("test", "create table t(a int unique not null , b int, c varchar(10), PRIMARY KEY (c))")) - dti, err = tracker.GetDownStreamTableInfo(tcontext.Background(), tableID, oriTi) - c.Assert(err, IsNil) - c.Assert(dti.AbsoluteUKIndexInfo.Primary, IsTrue) - c.Assert(len(dti.AvailableUKIndexList) == 2, IsTrue) - c.Assert(dti.AvailableUKIndexList[0] == dti.AbsoluteUKIndexInfo, IsTrue) - delete(tracker.dsTracker.tableInfos, tableID) - - // downstream has more columns than upstream, and that column in used in PK - mock.ExpectQuery("SHOW CREATE TABLE " + tableID).WillReturnRows( - sqlmock.NewRows([]string{"Table", "Create Table"}). - AddRow("test", "create table t(a int , d int PRIMARY KEY, c varchar(10), b int unique not null)")) - dti, err = tracker.GetDownStreamTableInfo(tcontext.Background(), tableID, oriTi) - c.Assert(err, IsNil) - c.Assert(dti.AbsoluteUKIndexInfo, NotNil) - c.Assert(dti.AbsoluteUKIndexInfo.Primary, IsFalse) - c.Assert(len(dti.AvailableUKIndexList) == 1, IsTrue) - delete(tracker.dsTracker.tableInfos, tableID) - - mock.ExpectQuery("SHOW CREATE TABLE " + tableID).WillReturnRows( - sqlmock.NewRows([]string{"Table", "Create Table"}). - AddRow("test", "create table t(a int , d int PRIMARY KEY, c varchar(10), b int unique)")) - dti, err = tracker.GetDownStreamTableInfo(tcontext.Background(), tableID, oriTi) - c.Assert(err, IsNil) - c.Assert(dti.AbsoluteUKIndexInfo, IsNil) - c.Assert(len(dti.AvailableUKIndexList) == 1, IsTrue) - delete(tracker.dsTracker.tableInfos, tableID) - - mock.ExpectQuery("SHOW CREATE TABLE " + tableID).WillReturnRows( - sqlmock.NewRows([]string{"Table", "Create Table"}). 
- AddRow("test", "create table t(a int , d int PRIMARY KEY, c varchar(10), b int)")) - dti, err = tracker.GetDownStreamTableInfo(tcontext.Background(), tableID, oriTi) - c.Assert(err, IsNil) - c.Assert(dti.AbsoluteUKIndexInfo, IsNil) - delete(tracker.dsTracker.tableInfos, tableID) - - // downstream has more columns than upstream, and that column in used in UK(not null) - mock.ExpectQuery("SHOW CREATE TABLE " + tableID).WillReturnRows( - sqlmock.NewRows([]string{"Table", "Create Table"}). - AddRow("test", "create table t(a int , d int unique not null, c varchar(10), b int unique not null)")) - dti, err = tracker.GetDownStreamTableInfo(tcontext.Background(), tableID, oriTi) - c.Assert(err, IsNil) - c.Assert(dti.AbsoluteUKIndexInfo, NotNil) - c.Assert(dti.AbsoluteUKIndexInfo.Columns[0].Name.L == "b", IsTrue) - delete(tracker.dsTracker.tableInfos, tableID) - - mock.ExpectQuery("SHOW CREATE TABLE " + tableID).WillReturnRows( - sqlmock.NewRows([]string{"Table", "Create Table"}). - AddRow("test", "create table t(a int , d int unique not null, c varchar(10), b int unique)")) - dti, err = tracker.GetDownStreamTableInfo(tcontext.Background(), tableID, oriTi) - c.Assert(err, IsNil) - c.Assert(dti.AbsoluteUKIndexInfo, IsNil) - c.Assert(len(dti.AvailableUKIndexList) == 1, IsTrue) - delete(tracker.dsTracker.tableInfos, tableID) - - mock.ExpectQuery("SHOW CREATE TABLE " + tableID).WillReturnRows( - sqlmock.NewRows([]string{"Table", "Create Table"}). 
- AddRow("test", "create table t(a int , d int unique not null, c varchar(10), b int)")) - dti, err = tracker.GetDownStreamTableInfo(tcontext.Background(), tableID, oriTi) - c.Assert(err, IsNil) - c.Assert(dti.AbsoluteUKIndexInfo, IsNil) - delete(tracker.dsTracker.tableInfos, tableID) -} - -func (s *trackerSuite) TestGetAvailableDownStreamUKIIndexInfo(c *C) { - log.SetLevel(zapcore.ErrorLevel) - - // origin table info - p := parser.New() - se := timock.NewContext() - node, err := p.ParseOneStmt("create table t(a int, b int, c varchar(10))", "utf8mb4", "utf8mb4_bin") - c.Assert(err, IsNil) - oriTi, err := ddl.MockTableInfo(se, node.(*ast.CreateTableStmt), 1) - c.Assert(err, IsNil) - - // tracker and sqlmock - db, mock, err := sqlmock.New() - c.Assert(err, IsNil) - defer db.Close() - con, err := db.Conn(context.Background()) - c.Assert(err, IsNil) - baseConn := conn.NewBaseConn(con, nil) - dbConn := &dbconn.DBConn{Cfg: s.cfg, BaseConn: baseConn} - tracker, err := NewTracker(context.Background(), "test-tracker", defaultTestSessionCfg, dbConn) - c.Assert(err, IsNil) - defer func() { - err = tracker.Close() - c.Assert(err, IsNil) - }() - - mock.ExpectBegin() - mock.ExpectExec(fmt.Sprintf("SET SESSION SQL_MODE = '%s'", mysql.DefaultSQLMode)).WillReturnResult(sqlmock.NewResult(0, 0)) - mock.ExpectCommit() - - tableID := "`test`.`test`" - - // downstream has no uk - mock.ExpectQuery("SHOW CREATE TABLE " + tableID).WillReturnRows( - sqlmock.NewRows([]string{"Table", "Create Table"}). 
- AddRow("test", "create table t(a int, b int, c varchar(10))")) dti, err := tracker.GetDownStreamTableInfo(tcontext.Background(), tableID, oriTi) c.Assert(err, IsNil) - c.Assert(dti.AbsoluteUKIndexInfo, IsNil) - data := []interface{}{1, 2, 3} - indexinfo := tracker.GetAvailableDownStreamUKIndexInfo(tableID, data) - c.Assert(indexinfo, IsNil) - delete(tracker.dsTracker.tableInfos, tableID) - - // downstream has uk but data is null - mock.ExpectQuery("SHOW CREATE TABLE " + tableID).WillReturnRows( - sqlmock.NewRows([]string{"Table", "Create Table"}). - AddRow("test", "create table t(a int unique, b int, c varchar(10))")) - dti, err = tracker.GetDownStreamTableInfo(tcontext.Background(), tableID, oriTi) - c.Assert(err, IsNil) - c.Assert(dti.AbsoluteUKIndexInfo, IsNil) - data = []interface{}{nil, 2, 3} - indexinfo = tracker.GetAvailableDownStreamUKIndexInfo(tableID, data) - c.Assert(indexinfo, IsNil) - delete(tracker.dsTracker.tableInfos, tableID) - - // downstream has uk and data is not null - mock.ExpectQuery("SHOW CREATE TABLE " + tableID).WillReturnRows( - sqlmock.NewRows([]string{"Table", "Create Table"}). - AddRow("test", "create table t(a int unique, b int, c varchar(10))")) - dti, err = tracker.GetDownStreamTableInfo(tcontext.Background(), tableID, oriTi) - c.Assert(err, IsNil) - c.Assert(dti.AbsoluteUKIndexInfo, IsNil) - data = []interface{}{1, 2, 3} - indexinfo = tracker.GetAvailableDownStreamUKIndexInfo(tableID, data) - c.Assert(indexinfo, NotNil) - delete(tracker.dsTracker.tableInfos, tableID) - - // downstream has union uk but data has null - mock.ExpectQuery("SHOW CREATE TABLE " + tableID).WillReturnRows( - sqlmock.NewRows([]string{"Table", "Create Table"}). 
- AddRow("test", "create table t(a int, b int, c varchar(10), unique key(a, b))")) - dti, err = tracker.GetDownStreamTableInfo(tcontext.Background(), tableID, oriTi) - c.Assert(err, IsNil) - c.Assert(dti.AbsoluteUKIndexInfo, IsNil) - data = []interface{}{1, nil, 3} - indexinfo = tracker.GetAvailableDownStreamUKIndexInfo(tableID, data) - c.Assert(indexinfo, IsNil) - delete(tracker.dsTracker.tableInfos, tableID) - - mock.ExpectQuery("SHOW CREATE TABLE " + tableID).WillReturnRows( - sqlmock.NewRows([]string{"Table", "Create Table"}). - AddRow("test", "create table t(a int, b int, c varchar(10), unique key(a, b))")) - dti, err = tracker.GetDownStreamTableInfo(tcontext.Background(), tableID, oriTi) - c.Assert(err, IsNil) - c.Assert(dti.AbsoluteUKIndexInfo, IsNil) - data = []interface{}{1, nil, nil} - indexinfo = tracker.GetAvailableDownStreamUKIndexInfo(tableID, data) - c.Assert(indexinfo, IsNil) - delete(tracker.dsTracker.tableInfos, tableID) - - // downstream has union uk but data has null - mock.ExpectQuery("SHOW CREATE TABLE " + tableID).WillReturnRows( - sqlmock.NewRows([]string{"Table", "Create Table"}). 
- AddRow("test", "create table t(a int, b int, c varchar(10), unique key(a, b))")) - dti, err = tracker.GetDownStreamTableInfo(tcontext.Background(), tableID, oriTi) - c.Assert(err, IsNil) - c.Assert(dti.AbsoluteUKIndexInfo, IsNil) - data = []interface{}{1, 2, 3} - indexinfo = tracker.GetAvailableDownStreamUKIndexInfo(tableID, data) - c.Assert(indexinfo, NotNil) + c.Assert(dti, NotNil) + c.Assert(dti.WhereHandle.UniqueNotNullIdx, NotNil) delete(tracker.dsTracker.tableInfos, tableID) } diff --git a/dm/syncer/dml.go b/dm/syncer/dml.go index 7bf8c16a3c1..9eb94b0bbe5 100644 --- a/dm/syncer/dml.go +++ b/dm/syncer/dml.go @@ -129,7 +129,7 @@ RowLoop: downstreamTableInfo.TableInfo, s.sessCtx, ) - rowChange.SetIdentifyInfo(downstreamTableInfo) + rowChange.SetWhereHandle(downstreamTableInfo.WhereHandle) dmls = append(dmls, rowChange) } @@ -203,7 +203,7 @@ RowLoop: downstreamTableInfo.TableInfo, s.sessCtx, ) - rowChange.SetIdentifyInfo(downstreamTableInfo) + rowChange.SetWhereHandle(downstreamTableInfo.WhereHandle) dmls = append(dmls, rowChange) } @@ -257,7 +257,7 @@ RowLoop: downstreamTableInfo.TableInfo, s.sessCtx, ) - rowChange.SetIdentifyInfo(downstreamTableInfo) + rowChange.SetWhereHandle(downstreamTableInfo.WhereHandle) dmls = append(dmls, rowChange) } diff --git a/pkg/sqlmodel/causality.go b/pkg/sqlmodel/causality.go index cc8d8f4572b..6d79db670a5 100644 --- a/pkg/sqlmodel/causality.go +++ b/pkg/sqlmodel/causality.go @@ -30,7 +30,7 @@ import ( // CausalityKeys returns all string representation of causality keys. If two row // changes has the same causality keys, they must be replicated sequentially. 
func (r *RowChange) CausalityKeys() []string { - r.lazyInitIdentityInfo() + r.lazyInitWhereHandle() ret := make([]string, 0, 1) if r.preValues != nil { @@ -136,7 +136,7 @@ func truncateIndexValues( } func (r *RowChange) getCausalityString(values []interface{}) []string { - pkAndUks := r.identityInfo.AvailableUKIndexList + pkAndUks := r.whereHandle.UniqueIdxs if len(pkAndUks) == 0 { // the table has no PK/UK, all values of the row consists the causality key return []string{genKeyString(r.sourceTable.String(), r.sourceTableInfo.Columns, values)} diff --git a/pkg/sqlmodel/causality_test.go b/pkg/sqlmodel/causality_test.go index 998239168df..3c983bef6e8 100644 --- a/pkg/sqlmodel/causality_test.go +++ b/pkg/sqlmodel/causality_test.go @@ -38,13 +38,13 @@ func TestCausalityKeys(t *testing.T) { "CREATE TABLE tb1 (c INT PRIMARY KEY, c2 INT, c3 VARCHAR(10) UNIQUE)", []interface{}{1, 2, "abc"}, []interface{}{3, 4, "abc"}, - []string{"1.c.db.tb1", "abc.c3.db.tb1", "3.c.db.tb1", "abc.c3.db.tb1"}, + []string{"abc.c3.db.tb1", "1.c.db.tb1", "abc.c3.db.tb1", "3.c.db.tb1"}, }, { "CREATE TABLE tb1 (c INT PRIMARY KEY, c2 INT, c3 VARCHAR(10), UNIQUE INDEX(c3(1)))", []interface{}{1, 2, "abc"}, []interface{}{3, 4, "adef"}, - []string{"1.c.db.tb1", "a.c3.db.tb1", "3.c.db.tb1", "a.c3.db.tb1"}, + []string{"a.c3.db.tb1", "1.c.db.tb1", "a.c3.db.tb1", "3.c.db.tb1"}, }, // test not string key @@ -148,13 +148,13 @@ func TestGetCausalityString(t *testing.T) { // multiple keys with primary key schema: `create table t6(a int primary key, b varchar(16) unique)`, values: []interface{}{16, "xyz"}, - keys: []string{"16.a.db.tbl", "xyz.b.db.tbl"}, + keys: []string{"xyz.b.db.tbl", "16.a.db.tbl"}, }, { // non-integer primary key schema: `create table t65(a int unique, b varchar(16) primary key)`, values: []interface{}{16, "xyz"}, - keys: []string{"xyz.b.db.tbl", "16.a.db.tbl"}, + keys: []string{"16.a.db.tbl", "xyz.b.db.tbl"}, }, { // primary key of multiple columns @@ -199,7 +199,7 @@ func 
TestGetCausalityString(t *testing.T) { for _, ca := range testCases { ti := mockTableInfo(t, ca.schema) change := NewRowChange(source, nil, nil, ca.values, ti, nil, nil) - change.lazyInitIdentityInfo() + change.lazyInitWhereHandle() require.Equal(t, ca.keys, change.getCausalityString(ca.values)) } } diff --git a/pkg/sqlmodel/reduce.go b/pkg/sqlmodel/reduce.go index 7a0aeafd617..e32f2471c78 100644 --- a/pkg/sqlmodel/reduce.go +++ b/pkg/sqlmodel/reduce.go @@ -25,9 +25,9 @@ import ( // HasNotNullUniqueIdx returns true when the target table structure has PK or UK // whose columns are all NOT NULL. func (r *RowChange) HasNotNullUniqueIdx() bool { - r.lazyInitIdentityInfo() + r.lazyInitWhereHandle() - return r.identityInfo.AbsoluteUKIndexInfo != nil + return r.whereHandle.UniqueNotNullIdx != nil } // IdentityValues returns the two group of values that can be used to identify @@ -37,9 +37,9 @@ func (r *RowChange) HasNotNullUniqueIdx() bool { // We always use same index for same table structure to get IdentityValues. // two groups returned are from preValues and postValues. 
func (r *RowChange) IdentityValues() ([]interface{}, []interface{}) { - r.lazyInitIdentityInfo() + r.lazyInitWhereHandle() - indexInfo := r.identityInfo.AbsoluteUKIndexInfo + indexInfo := r.whereHandle.UniqueNotNullIdx if indexInfo == nil { return r.preValues, r.postValues } @@ -64,7 +64,7 @@ func (r *RowChange) IsIdentityUpdated() bool { return false } - r.lazyInitIdentityInfo() + r.lazyInitWhereHandle() pre, post := r.IdentityValues() if len(pre) != len(post) { // should not happen @@ -138,7 +138,7 @@ func (r *RowChange) SplitUpdate() (*RowChange, *RowChange) { targetTableInfo: r.targetTableInfo, tiSessionCtx: r.tiSessionCtx, tp: RowChangeDelete, - identityInfo: r.identityInfo, + whereHandle: r.whereHandle, } post := &RowChange{ sourceTable: r.sourceTable, @@ -148,7 +148,7 @@ func (r *RowChange) SplitUpdate() (*RowChange, *RowChange) { targetTableInfo: r.targetTableInfo, tiSessionCtx: r.tiSessionCtx, tp: RowChangeInsert, - identityInfo: r.identityInfo, + whereHandle: r.whereHandle, } return pre, post diff --git a/pkg/sqlmodel/reduce_test.go b/pkg/sqlmodel/reduce_test.go index 8592b91b48f..fff3206dede 100644 --- a/pkg/sqlmodel/reduce_test.go +++ b/pkg/sqlmodel/reduce_test.go @@ -115,7 +115,7 @@ func (s *dpanicSuite) TestReduce() { change1 := NewRowChange(source, nil, c.pre1, c.post1, sourceTI, nil, nil) change2 := NewRowChange(source, nil, c.pre2, c.post2, sourceTI, nil, nil) changeAfter := NewRowChange(source, nil, c.preAfter, c.postAfter, sourceTI, nil, nil) - changeAfter.lazyInitIdentityInfo() + changeAfter.lazyInitWhereHandle() change2.Reduce(change1) s.Equal(changeAfter, change2) diff --git a/pkg/sqlmodel/row_change.go b/pkg/sqlmodel/row_change.go index b13d6317736..424a6de086a 100644 --- a/pkg/sqlmodel/row_change.go +++ b/pkg/sqlmodel/row_change.go @@ -24,7 +24,6 @@ import ( cdcmodel "github.com/pingcap/tiflow/cdc/model" "github.com/pingcap/tiflow/dm/pkg/log" - "github.com/pingcap/tiflow/dm/pkg/schema" "github.com/pingcap/tiflow/dm/pkg/utils" 
"github.com/pingcap/tiflow/pkg/quotes" ) @@ -69,8 +68,8 @@ type RowChange struct { tiSessionCtx sessionctx.Context - tp RowChangeType - identityInfo *schema.DownstreamTableInfo + tp RowChangeType + whereHandle *WhereHandle } // NewRowChange creates a new RowChange. @@ -175,32 +174,28 @@ func (r *RowChange) TargetTableID() string { return r.targetTable.QuoteString() } -// SetIdentifyInfo can be used when caller has calculated and cached -// identityInfo, to avoid every RowChange lazily initialize it. -func (r *RowChange) SetIdentifyInfo(info *schema.DownstreamTableInfo) { - r.identityInfo = info +// SetWhereHandle can be used when caller has cached whereHandle, to avoid every +// RowChange lazily initialize it. +func (r *RowChange) SetWhereHandle(whereHandle *WhereHandle) { + r.whereHandle = whereHandle } -func (r *RowChange) lazyInitIdentityInfo() { - if r.identityInfo != nil { +func (r *RowChange) lazyInitWhereHandle() { + if r.whereHandle != nil { return } - // TODO: move below function into this package - r.identityInfo = schema.GetDownStreamTI(r.targetTableInfo, r.sourceTableInfo) + r.whereHandle = GetWhereHandle(r.sourceTableInfo, r.targetTableInfo) } // whereColumnsAndValues returns columns and values to identify the row, to form // the WHERE clause. 
func (r *RowChange) whereColumnsAndValues() ([]string, []interface{}) { - r.lazyInitIdentityInfo() - - uniqueIndex := r.identityInfo.AbsoluteUKIndexInfo - if uniqueIndex == nil { - uniqueIndex = schema.GetIdentityUKByData(r.identityInfo, r.preValues) - } + r.lazyInitWhereHandle() columns, values := r.sourceTableInfo.Columns, r.preValues + + uniqueIndex := r.whereHandle.getWhereIdxByData(r.preValues) if uniqueIndex != nil { columns, values = getColsAndValuesOfIdx(r.sourceTableInfo.Columns, uniqueIndex, values) } diff --git a/pkg/sqlmodel/row_change_test.go b/pkg/sqlmodel/row_change_test.go index 733c540cecd..eda1f42470a 100644 --- a/pkg/sqlmodel/row_change_test.go +++ b/pkg/sqlmodel/row_change_test.go @@ -72,21 +72,21 @@ func TestNewRowChange(t *testing.T) { targetTableInfo: targetTI, tiSessionCtx: tiSession, tp: RowChangeUpdate, - identityInfo: nil, + whereHandle: nil, } actual := NewRowChange(source, target, []interface{}{1, 2}, []interface{}{1, 3}, sourceTI, targetTI, tiSession) require.Equal(t, expected, actual) - actual.lazyInitIdentityInfo() - require.NotNil(t, actual.identityInfo) + actual.lazyInitWhereHandle() + require.NotNil(t, actual.whereHandle) // test some arguments of NewRowChange can be nil expected.targetTable = expected.sourceTable expected.targetTableInfo = expected.sourceTableInfo expected.tiSessionCtx = utils.ZeroSessionCtx - expected.identityInfo = nil + expected.whereHandle = nil actual = NewRowChange(source, nil, []interface{}{1, 2}, []interface{}{1, 3}, sourceTI, nil, nil) require.Equal(t, expected, actual) } diff --git a/pkg/sqlmodel/where_handle.go b/pkg/sqlmodel/where_handle.go new file mode 100644 index 00000000000..a4f36989304 --- /dev/null +++ b/pkg/sqlmodel/where_handle.go @@ -0,0 +1,154 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlmodel + +import ( + "github.com/pingcap/log" + "github.com/pingcap/tidb/parser/model" + "github.com/pingcap/tidb/parser/mysql" + "github.com/pingcap/tidb/types" +) + +// WhereHandle is used to generate a WHERE clause in SQL. +type WhereHandle struct { + UniqueNotNullIdx *model.IndexInfo + // If the index and columns have no NOT NULL constraint, but all data is NOT + // NULL, we can still use it. + // every index that is UNIQUE should be added to UniqueIdxs, even for + // PK and NOT NULL. + UniqueIdxs []*model.IndexInfo +} + +// GetWhereHandle calculates a WhereHandle by source/target TableInfo's indices, +// columns and state. Other component can cache the result. +func GetWhereHandle(source, target *model.TableInfo) *WhereHandle { + ret := WhereHandle{} + indices := make([]*model.IndexInfo, 0, len(target.Indices)+1) + indices = append(indices, target.Indices...) + if idx := getPKIsHandleIdx(target); target.PKIsHandle && idx != nil { + indices = append(indices, idx) + } + + for _, idx := range indices { + if !idx.Unique { + continue + } + // when the tableInfo is from CDC, it may contain some index that is + // creating. 
+ if idx.State != model.StatePublic { + continue + } + + rewritten := rewriteColsOffset(idx, source) + if rewritten == nil { + continue + } + ret.UniqueIdxs = append(ret.UniqueIdxs, rewritten) + + if rewritten.Primary { + // PK is prior to UNIQUE NOT NULL for better performance + ret.UniqueNotNullIdx = rewritten + continue + } + // use downstream columns to check NOT NULL constraint + if ret.UniqueNotNullIdx == nil && allColsNotNull(idx, target.Columns) { + ret.UniqueNotNullIdx = rewritten + continue + } + } + return &ret +} + +// rewriteColsOffset rewrites index columns offset to those from source table. +// Returns nil when any column does not represent in source. +func rewriteColsOffset(index *model.IndexInfo, source *model.TableInfo) *model.IndexInfo { + if index == nil || source == nil { + return nil + } + + columns := make([]*model.IndexColumn, 0, len(index.Columns)) + for _, key := range index.Columns { + sourceColumn := model.FindColumnInfo(source.Columns, key.Name.L) + if sourceColumn == nil { + return nil + } + column := &model.IndexColumn{ + Name: key.Name, + Offset: sourceColumn.Offset, + Length: key.Length, + } + columns = append(columns, column) + } + clone := *index + clone.Columns = columns + return &clone +} + +func getPKIsHandleIdx(ti *model.TableInfo) *model.IndexInfo { + if pk := ti.GetPkColInfo(); pk != nil { + return &model.IndexInfo{ + Table: ti.Name, + Unique: true, + Primary: true, + State: model.StatePublic, + Tp: model.IndexTypeBtree, + Columns: []*model.IndexColumn{{ + Name: pk.Name, + Offset: pk.Offset, + Length: types.UnspecifiedLength, + }}, + } + } + return nil +} + +func allColsNotNull(idx *model.IndexInfo, cols []*model.ColumnInfo) bool { + for _, idxCol := range idx.Columns { + col := cols[idxCol.Offset] + if !mysql.HasNotNullFlag(col.Flag) { + return false + } + } + return true +} + +// getWhereIdxByData returns the index that is identical to a row change, it +// may be +// - a PK, or +// - an UNIQUE index whose columns are all 
NOT NULL, or +// - an UNIQUE index and the data are all NOT NULL. +// For the last case, last used index is swapped to front. +func (h *WhereHandle) getWhereIdxByData(data []interface{}) *model.IndexInfo { + if h == nil { + log.L().DPanic("WhereHandle is nil") + return nil + } + if h.UniqueNotNullIdx != nil { + return h.UniqueNotNullIdx + } + for i, idx := range h.UniqueIdxs { + ok := true + for _, idxCol := range idx.Columns { + if data[idxCol.Offset] == nil { + ok = false + break + } + } + if ok { + h.UniqueIdxs[0], h.UniqueIdxs[i] = h.UniqueIdxs[i], h.UniqueIdxs[0] + return idx + } + } + return nil +} diff --git a/pkg/sqlmodel/where_handle_test.go b/pkg/sqlmodel/where_handle_test.go new file mode 100644 index 00000000000..ace6668337c --- /dev/null +++ b/pkg/sqlmodel/where_handle_test.go @@ -0,0 +1,216 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlmodel + +import ( + "testing" + + "github.com/pingcap/tidb/ddl" + "github.com/pingcap/tidb/parser" + "github.com/pingcap/tidb/parser/ast" + "github.com/stretchr/testify/require" +) + +func TestGenWhereHandle(t *testing.T) { + t.Parallel() + + // 1. 
// TestGenWhereHandle covers rewriteColsOffset and GetWhereHandle across the
// main source/target table layouts: identical tables, a target with extra
// columns, a PK-is-handle target, a target with no usable index, and a
// composite PK that must win over other UNIQUE indexes.
func TestGenWhereHandle(t *testing.T) {
	t.Parallel()

	// 1. target is same as source

	createSQL := `
CREATE TABLE t (
	c INT, c2 INT NOT NULL, c3 VARCHAR(20) NOT NULL,
	UNIQUE INDEX idx3 (c2, c3)
)`
	p := parser.New()
	node, err := p.ParseOneStmt(createSQL, "", "")
	require.NoError(t, err)
	ti, err := ddl.BuildTableInfoFromAST(node.(*ast.CreateTableStmt))
	require.NoError(t, err)
	require.Len(t, ti.Indices, 1)
	idx := ti.Indices[0]
	// rewriting against the same table must be an identity transform
	rewritten := rewriteColsOffset(idx, ti)
	require.Equal(t, idx, rewritten)

	// check GetWhereHandle when target is same as source
	// idx3 (c2, c3) is UNIQUE with all columns NOT NULL -> offsets 1 and 2
	handle := GetWhereHandle(ti, ti)
	require.Len(t, handle.UniqueNotNullIdx.Columns, 2)
	require.Equal(t, handle.UniqueNotNullIdx.Columns[0].Offset, 1)
	require.Equal(t, handle.UniqueNotNullIdx.Columns[1].Offset, 2)
	require.Len(t, handle.UniqueIdxs, 1)

	// 2. target has more columns, some index doesn't use it

	targetCreateSQL := `
CREATE TABLE t (
	pk INT PRIMARY KEY, c INT, c2 INT NOT NULL, c3 VARCHAR(20) NOT NULL, extra INT,
	UNIQUE INDEX idx2 (c2, c3),
	UNIQUE INDEX idx3 (extra)
)`
	node, err = p.ParseOneStmt(targetCreateSQL, "", "")
	require.NoError(t, err)
	targetTI, err := ddl.BuildTableInfoFromAST(node.(*ast.CreateTableStmt))
	require.NoError(t, err)
	require.Len(t, targetTI.Indices, 2)
	// idx2 (c2, c3) in the target sits at offsets 2 and 3 ...
	targetIdx := targetTI.Indices[0]
	require.Len(t, targetIdx.Columns, 2)
	require.Equal(t, targetIdx.Columns[0].Offset, 2)
	require.Equal(t, targetIdx.Columns[1].Offset, 3)

	// ... and is remapped to the source offsets 1 and 2
	rewritten = rewriteColsOffset(targetIdx, ti)
	require.Len(t, rewritten.Columns, 2)
	require.Equal(t, rewritten.Columns[0].Offset, 1)
	require.Equal(t, rewritten.Columns[1].Offset, 2)

	// target has more columns, some index uses it
	targetIdx = targetTI.Indices[1]
	require.Len(t, targetIdx.Columns, 1)
	require.Equal(t, targetIdx.Columns[0].Offset, 4)

	// `extra` does not exist in the source, so idx3 cannot be rewritten
	rewritten = rewriteColsOffset(targetIdx, ti)
	require.Nil(t, rewritten)

	// check GetWhereHandle when target has more columns
	handle = GetWhereHandle(ti, targetTI)
	require.Len(t, handle.UniqueNotNullIdx.Columns, 2)
	require.Equal(t, handle.UniqueNotNullIdx.Columns[0].Offset, 1)
	require.Equal(t, handle.UniqueNotNullIdx.Columns[1].Offset, 2)
	// PRIMARY and idx3 is not usable
	require.Len(t, handle.UniqueIdxs, 1)

	// 3. PKIsHandle case

	targetCreateSQL = `
CREATE TABLE t (
	extra INT, c INT PRIMARY KEY
)`
	node, err = p.ParseOneStmt(targetCreateSQL, "", "")
	require.NoError(t, err)
	targetTI, err = ddl.BuildTableInfoFromAST(node.(*ast.CreateTableStmt))
	require.NoError(t, err)
	// PKIsHandle has no entry in Indices
	require.Len(t, targetTI.Indices, 0)

	// the synthesized PK index maps to source column c at offset 0
	handle = GetWhereHandle(ti, targetTI)
	require.Len(t, handle.UniqueNotNullIdx.Columns, 1)
	require.Equal(t, handle.UniqueNotNullIdx.Columns[0].Offset, 0)
	require.Len(t, handle.UniqueIdxs, 1)

	// 4. target has no available index

	targetCreateSQL = `
CREATE TABLE t (
	extra INT PRIMARY KEY
)`
	node, err = p.ParseOneStmt(targetCreateSQL, "", "")
	require.NoError(t, err)
	targetTI, err = ddl.BuildTableInfoFromAST(node.(*ast.CreateTableStmt))
	require.NoError(t, err)
	// PKIsHandle has no entry in Indices
	require.Len(t, targetTI.Indices, 0)

	// `extra` is missing from the source, so nothing is usable
	handle = GetWhereHandle(ti, targetTI)
	require.Nil(t, handle.UniqueNotNullIdx)
	require.Len(t, handle.UniqueIdxs, 0)

	// 5. composite PK, and PK has higher priority

	targetCreateSQL = `
CREATE TABLE t (
	extra INT, c INT NOT NULL, c2 INT NOT NULL, c3 VARCHAR(20) NOT NULL,
	UNIQUE INDEX idx (c, c3),
	PRIMARY KEY (c, c2),
	UNIQUE INDEX idx3 (c2, c3)
)`
	node, err = p.ParseOneStmt(targetCreateSQL, "", "")
	require.NoError(t, err)
	targetTI, err = ddl.BuildTableInfoFromAST(node.(*ast.CreateTableStmt))
	require.NoError(t, err)

	// PK (c, c2) wins over idx (declared first) -> source offsets 0 and 1;
	// all three indexes are rewritable, so all appear in UniqueIdxs
	handle = GetWhereHandle(ti, targetTI)
	require.Len(t, handle.UniqueNotNullIdx.Columns, 2)
	require.Equal(t, handle.UniqueNotNullIdx.Columns[0].Offset, 0)
	require.Equal(t, handle.UniqueNotNullIdx.Columns[1].Offset, 1)
	require.Len(t, handle.UniqueIdxs, 3)
}

// TestAllColsNotNull checks allColsNotNull over indexes whose key columns are
// all NOT NULL (PRIMARY, idx3), partially nullable (idx2), and fully nullable
// (idx1).
func TestAllColsNotNull(t *testing.T) {
	t.Parallel()

	createSQL := `
CREATE TABLE t (
	pk VARCHAR(20) PRIMARY KEY,
	c1 INT,
	c2 INT,
	c3 INT NOT NULL,
	c4 INT NOT NULL,
	INDEX idx1 (c1, c2),
	INDEX idx2 (c2, c3),
	INDEX idx3 (c3, c4)
)`
	p := parser.New()
	node, err := p.ParseOneStmt(createSQL, "", "")
	require.NoError(t, err)
	ti, err := ddl.BuildTableInfoFromAST(node.(*ast.CreateTableStmt))
	require.NoError(t, err)
	require.Len(t, ti.Indices, 4)

	// the PRIMARY index is appended after the declared secondary indexes
	pk := ti.Indices[3]
	require.Equal(t, "PRIMARY", pk.Name.O)
	require.True(t, allColsNotNull(pk, ti.Columns))

	idx1 := ti.Indices[0]
	require.Equal(t, "idx1", idx1.Name.O)
	require.False(t, allColsNotNull(idx1, ti.Columns))

	idx2 := ti.Indices[1]
	require.Equal(t, "idx2", idx2.Name.O)
	require.False(t, allColsNotNull(idx2, ti.Columns))

	idx3 := ti.Indices[2]
	require.Equal(t, "idx3", idx3.Name.O)
	require.True(t, allColsNotNull(idx3, ti.Columns))
}

// TestGetWhereIdxByData checks NULL-aware index selection and the
// move-to-front caching of the last used UNIQUE index.
func TestGetWhereIdxByData(t *testing.T) {
	t.Parallel()

	// all columns nullable, so UniqueNotNullIdx is nil and selection falls
	// through to the per-row NULL check over UniqueIdxs
	createSQL := `
CREATE TABLE t (
	c1 INT,
	c2 INT,
	c3 INT,
	c4 INT,
	UNIQUE INDEX idx1 (c1, c2),
	UNIQUE INDEX idx2 (c3, c4)
)`
	p := parser.New()
	node, err := p.ParseOneStmt(createSQL, "", "")
	require.NoError(t, err)
	ti, err := ddl.BuildTableInfoFromAST(node.(*ast.CreateTableStmt))
	require.NoError(t, err)

	// c1 is NULL -> idx1 unusable, idx2 chosen and moved to front
	handle := GetWhereHandle(ti, ti)
	idx := handle.getWhereIdxByData([]interface{}{nil, 2, 3, 4})
	require.Equal(t, "idx2", idx.Name.L)
	require.Equal(t, idx, handle.UniqueIdxs[0])

	// last used index is moved to front
	idx = handle.getWhereIdxByData([]interface{}{1, 2, 3, nil})
	require.Equal(t, "idx1", idx.Name.L)
	require.Equal(t, idx, handle.UniqueIdxs[0])

	// no index available
	idx = handle.getWhereIdxByData([]interface{}{1, nil, 3, nil})
	require.Nil(t, idx)
}