diff --git a/pkg/schema/tracker.go b/pkg/schema/tracker.go new file mode 100644 index 0000000000..2a768d99ea --- /dev/null +++ b/pkg/schema/tracker.go @@ -0,0 +1,231 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package schema + +import ( + "context" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/parser/ast" + "github.com/pingcap/parser/model" + "github.com/pingcap/tidb/domain" + "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/meta" + "github.com/pingcap/tidb/session" + "github.com/pingcap/tidb/store/mockstore" +) + +const ( + waitDDLRetryCount = 10 + schemaLeaseTime = 10 * time.Millisecond +) + +// Tracker is used to track schema locally. +type Tracker struct { + store kv.Storage + dom *domain.Domain + se session.Session +} + +// NewTracker creates a new tracker. +func NewTracker() (*Tracker, error) { + store, err := mockstore.NewMockTikvStore() + if err != nil { + return nil, err + } + + // shorten the schema lease, since the time needed to confirm DDL sync is + // proportional to this duration (default = 1 second) + session.SetSchemaLease(schemaLeaseTime) + dom, err := session.BootstrapSession(store) + if err != nil { + return nil, err + } + + se, err := session.CreateSession(store) + if err != nil { + return nil, err + } + + return &Tracker{ + store: store, + dom: dom, + se: se, + }, nil +} + +// Exec runs an SQL (DDL) statement. +func (tr *Tracker) Exec(ctx context.Context, db string, sql string) error { + tr.se.GetSessionVars().CurrentDB = db + _, err := tr.se.Execute(ctx, sql) + return err +} + +// GetTable returns the schema associated with the table. +func (tr *Tracker) GetTable(db, table string) (*model.TableInfo, error) { + t, err := tr.dom.InfoSchema().TableByName(model.NewCIStr(db), model.NewCIStr(table)) + if err != nil { + return nil, err + } + return t.Meta(), nil +} + +// AllSchemas returns all schemas visible to the tracker (excluding system tables). +func (tr *Tracker) AllSchemas() []*model.DBInfo { + allSchemas := tr.dom.InfoSchema().AllSchemas() + filteredSchemas := make([]*model.DBInfo, 0, len(allSchemas)-3) + for _, db := range allSchemas { + switch db.Name.L { + case "mysql", "performance_schema", "information_schema": + default: + filteredSchemas = append(filteredSchemas, db) + } + } + return filteredSchemas +} + +// IsTableNotExists checks if err means the database or table does not exist. +func IsTableNotExists(err error) bool { + return infoschema.ErrTableNotExists.Equal(err) || infoschema.ErrDatabaseDropExists.Equal(err) +} + +// Reset drops all tables inserted into this tracker. +func (tr *Tracker) Reset() error { + allDBs := tr.dom.InfoSchema().AllSchemaNames() + ddl := tr.dom.DDL() + for _, db := range allDBs { + dbName := model.NewCIStr(db) + switch dbName.L { + case "mysql", "performance_schema", "information_schema": + continue + } + if err := ddl.DropSchema(tr.se, dbName); err != nil { + return err + } + } + return nil +} + +// DropTable drops a table from this tracker. +func (tr *Tracker) DropTable(db, table string) error { + return tr.dom.DDL().DropTable(tr.se, ast.Ident{Schema: model.NewCIStr(db), Name: model.NewCIStr(table)}) +} + +// CreateSchemaIfNotExists creates a SCHEMA of the given name if it did not exist. +func (tr *Tracker) CreateSchemaIfNotExists(db string) error { + dbName := model.NewCIStr(db) + if tr.dom.InfoSchema().SchemaExists(dbName) { + return nil + } + return tr.dom.DDL().CreateSchema(tr.se, dbName, nil) +} + +// cloneTableInfo creates a clone of the TableInfo. +func cloneTableInfo(ti *model.TableInfo) *model.TableInfo { + ret := ti.Clone() + ret.Lock = nil + // FIXME pingcap/parser's Clone() doesn't clone Partition yet + if ret.Partition != nil { + pi := *ret.Partition + pi.Definitions = append([]model.PartitionDefinition(nil), ret.Partition.Definitions...) + ret.Partition = &pi + } + return ret +} + +// CreateTableIfNotExists creates a TABLE of the given name if it did not exist. +func (tr *Tracker) CreateTableIfNotExists(db, table string, ti *model.TableInfo) error { + infoSchema := tr.dom.InfoSchema() + dbName := model.NewCIStr(db) + tableName := model.NewCIStr(table) + if infoSchema.TableExists(dbName, tableName) { + return nil + } + + dbInfo, exists := infoSchema.SchemaByName(dbName) + if !exists || dbInfo == nil { + return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(dbName) + } + + // we need to go through the low-level DDL Job API since we don't have a way + // to recover a CreateTableStmt from a TableInfo yet. + + // First enqueue the DDL job. + var jobID int64 + err := kv.RunInNewTxn(tr.store, true /*retryable*/, func(txn kv.Transaction) error { + // reallocate IDs + idsCount := 2 + if ti.Partition != nil { + idsCount += len(ti.Partition.Definitions) + } + m := meta.NewMeta(txn) + ids, err := m.GenGlobalIDs(idsCount) + if err != nil { + return err + } + + jobID = ids[0] + tableInfo := cloneTableInfo(ti) + tableInfo.ID = ids[1] + tableInfo.Name = tableName + if tableInfo.Partition != nil { + for i := range tableInfo.Partition.Definitions { + tableInfo.Partition.Definitions[i].ID = ids[i+2] + } + } + + return m.EnQueueDDLJob(&model.Job{ + ID: jobID, + Type: model.ActionCreateTable, + SchemaID: dbInfo.ID, + TableID: tableInfo.ID, + SchemaName: dbName.O, + Version: 1, + StartTS: txn.StartTS(), + BinlogInfo: &model.HistoryInfo{}, + Args: []interface{}{tableInfo}, + }) + }) + if err != nil { + return err + } + + // Then wait until the DDL job is synchronized (should take 2 * lease) + lease := tr.dom.DDL().GetLease() * 2 + for i := 0; i < waitDDLRetryCount; i++ { + var job *model.Job + err = kv.RunInNewTxn(tr.store, false /*retryable*/, func(txn kv.Transaction) error { + m := meta.NewMeta(txn) + var e error + job, e = m.GetHistoryDDLJob(jobID) + return e + }) + if err == nil && job != nil { + if job.IsSynced() { + return nil + } + if job.Error != nil { + return job.Error + } + } + time.Sleep(lease) + } + if err == nil { + // reaching here is basically a bug. + return errors.Errorf("Cannot create table %s.%s, the DDL job never returned", db, table) + } + return err +} diff --git a/pkg/schema/tracker_test.go b/pkg/schema/tracker_test.go new file mode 100644 index 0000000000..4256725f11 --- /dev/null +++ b/pkg/schema/tracker_test.go @@ -0,0 +1,190 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package schema_test + +import ( + "context" + "encoding/json" + "testing" + + . "github.com/pingcap/check" + "github.com/pingcap/dm/pkg/schema" + "github.com/pingcap/log" + "github.com/pingcap/parser/model" + "go.uber.org/zap/zapcore" +) + +func Test(t *testing.T) { + TestingT(t) +} + +var _ = Suite(&trackerSuite{}) + +type trackerSuite struct{} + +func (s *trackerSuite) TestDDL(c *C) { + log.SetLevel(zapcore.ErrorLevel) + + tracker, err := schema.NewTracker() + c.Assert(err, IsNil) + + // Table shouldn't exist before initialization. + _, err = tracker.GetTable("testdb", "foo") + c.Assert(err, ErrorMatches, `.*Table 'testdb\.foo' doesn't exist`) + c.Assert(schema.IsTableNotExists(err), IsTrue) + + // Now create the table with 3 columns. + ctx := context.Background() + err = tracker.Exec(ctx, "", "create database testdb;") + c.Assert(err, IsNil) + + err = tracker.Exec(ctx, "testdb", "create table foo (a varchar(255) primary key, b varchar(255) as (concat(a, a)), c int)") + c.Assert(err, IsNil) + + // Verify the table has 3 columns. + ti, err := tracker.GetTable("testdb", "foo") + c.Assert(err, IsNil) + c.Assert(ti.Columns, HasLen, 3) + c.Assert(ti.Columns[0].Name.L, Equals, "a") + c.Assert(ti.Columns[0].IsGenerated(), IsFalse) + c.Assert(ti.Columns[1].Name.L, Equals, "b") + c.Assert(ti.Columns[1].IsGenerated(), IsTrue) + c.Assert(ti.Columns[2].Name.L, Equals, "c") + c.Assert(ti.Columns[2].IsGenerated(), IsFalse) + + // Drop one column from the table. + err = tracker.Exec(ctx, "testdb", "alter table foo drop column b") + c.Assert(err, IsNil) + + // Verify that 2 columns remain. + ti2, err := tracker.GetTable("testdb", "foo") + c.Assert(err, IsNil) + c.Assert(ti, Not(Equals), ti2) + c.Assert(ti2.Columns, HasLen, 2) + c.Assert(ti2.Columns[0].Name.L, Equals, "a") + c.Assert(ti2.Columns[0].IsGenerated(), IsFalse) + c.Assert(ti2.Columns[1].Name.L, Equals, "c") + c.Assert(ti2.Columns[1].IsGenerated(), IsFalse) +} + +func (s *trackerSuite) TestCreateSchemaIfNotExists(c *C) { + log.SetLevel(zapcore.ErrorLevel) + + tracker, err := schema.NewTracker() + c.Assert(err, IsNil) + + // We cannot create a table without a database. + ctx := context.Background() + err = tracker.Exec(ctx, "testdb", "create table foo(a int)") + c.Assert(err, ErrorMatches, `.*Unknown database 'testdb'`) + + // We can create the database directly. + err = tracker.CreateSchemaIfNotExists("testdb") + c.Assert(err, IsNil) + + // Creating the same database twice is no-op. + err = tracker.CreateSchemaIfNotExists("testdb") + c.Assert(err, IsNil) + + // Now creating a table should be successful + err = tracker.Exec(ctx, "testdb", "create table foo(a int)") + c.Assert(err, IsNil) + + ti, err := tracker.GetTable("testdb", "foo") + c.Assert(err, IsNil) + c.Assert(ti.Name.L, Equals, "foo") +} + +// clearVolatileInfo removes generated information like TS and ID so DeepEquals +// of two compatible schemas can pass. +func clearVolatileInfo(ti *model.TableInfo) { + ti.ID = 0 + ti.UpdateTS = 0 + if ti.Partition != nil { + for i := range ti.Partition.Definitions { + ti.Partition.Definitions[i].ID = 0 + } + } +} + +// asJson is a convenient wrapper to print a TableInfo in its JSON representation. +type asJson struct{ *model.TableInfo } + +func (aj asJson) String() string { + b, _ := json.Marshal(aj.TableInfo) + return string(b) +} + +func (s *trackerSuite) TestCreateTableIfNotExists(c *C) { + log.SetLevel(zapcore.ErrorLevel) + + tracker, err := schema.NewTracker() + c.Assert(err, IsNil) + + // Create some sort of complicated table. + err = tracker.CreateSchemaIfNotExists("testdb") + c.Assert(err, IsNil) + + ctx := context.Background() + err = tracker.Exec(ctx, "testdb", ` + create table foo( + a int primary key auto_increment, + b int as (c+1) not null, + c int comment 'some cmt', + d text, + key dk(d(255)) + ) comment 'more cmt' partition by range columns (a) ( + partition x41 values less than (41), + partition x82 values less than (82), + partition rest values less than maxvalue comment 'part cmt' + ); + `) + c.Assert(err, IsNil) + + // Save the table info + ti1, err := tracker.GetTable("testdb", "foo") + c.Assert(err, IsNil) + c.Assert(ti1, NotNil) + c.Assert(ti1.Name.O, Equals, "foo") + ti1 = ti1.Clone() + clearVolatileInfo(ti1) + + // Remove the table. Should not be found anymore. + err = tracker.Exec(ctx, "testdb", "drop table foo") + c.Assert(err, IsNil) + + _, err = tracker.GetTable("testdb", "foo") + c.Assert(err, ErrorMatches, `.*Table 'testdb\.foo' doesn't exist`) + + // Recover the table using the table info. + err = tracker.CreateTableIfNotExists("testdb", "foo", ti1) + c.Assert(err, IsNil) + + // The new table info should be equivalent to the old one except the TS and generated IDs. + ti2, err := tracker.GetTable("testdb", "foo") + c.Assert(err, IsNil) + clearVolatileInfo(ti2) + c.Assert(ti2, DeepEquals, ti1, Commentf("ti2 = %s\nti1 = %s", asJson{ti2}, asJson{ti1})) + + // Can use the table info to recover a table using a different name. + err = tracker.CreateTableIfNotExists("testdb", "bar", ti1) + c.Assert(err, IsNil) + + ti3, err := tracker.GetTable("testdb", "bar") + c.Assert(err, IsNil) + c.Assert(ti3.Name.O, Equals, "bar") + clearVolatileInfo(ti3) + ti3.Name = ti1.Name + c.Assert(ti3, DeepEquals, ti1, Commentf("ti3 = %s\nti1 = %s", asJson{ti3}, asJson{ti1})) +} diff --git a/syncer/checkpoint.go b/syncer/checkpoint.go index c4a1d7c0f0..569bee7846 100644 --- a/syncer/checkpoint.go +++ b/syncer/checkpoint.go @@ -14,6 +14,7 @@ package syncer import ( + "encoding/json" "fmt" "path" "sync" @@ -23,10 +24,14 @@ import ( "github.com/pingcap/dm/pkg/conn" tcontext "github.com/pingcap/dm/pkg/context" "github.com/pingcap/dm/pkg/log" + "github.com/pingcap/dm/pkg/schema" "github.com/pingcap/dm/pkg/terror" "github.com/pingcap/dm/pkg/utils" + "github.com/pingcap/tidb-tools/pkg/dbutil" + "github.com/pingcap/errors" "github.com/pingcap/failpoint" + "github.com/pingcap/parser/model" tmysql "github.com/pingcap/parser/mysql" "github.com/siddontang/go-mysql/mysql" "go.uber.org/zap" @@ -54,18 +59,22 @@ var ( type binlogPoint struct { sync.RWMutex mysql.Position + ti *model.TableInfo flushedPos mysql.Position // pos which flushed permanently + flushedTI *model.TableInfo } -func newBinlogPoint(pos mysql.Position, flushedPos mysql.Position) *binlogPoint { +func newBinlogPoint(pos mysql.Position, ti *model.TableInfo, flushedPos mysql.Position, flushedTI *model.TableInfo) *binlogPoint { return &binlogPoint{ Position: pos, + ti: ti, flushedPos: flushedPos, + flushedTI: flushedTI, } } -func (b *binlogPoint) save(pos mysql.Position) error { +func (b *binlogPoint) save(pos mysql.Position, ti *model.TableInfo) error { b.Lock() defer b.Unlock() if pos.Compare(b.Position) < 0 { @@ -73,6 +82,7 @@ func (b *binlogPoint) save(pos mysql.Position) error { return terror.ErrCheckpointSaveInvalidPos.Generate(pos, b.Position) } b.Position = pos + b.ti = ti return nil } @@ -80,12 +90,17 @@ func (b *binlogPoint) flush() { b.Lock() defer b.Unlock() b.flushedPos = b.Position + b.flushedTI = b.ti } -func (b *binlogPoint) rollback() { +func (b *binlogPoint) rollback() (isSchemaChanged bool) { b.Lock() defer b.Unlock() b.Position = b.flushedPos + if isSchemaChanged = b.ti != b.flushedTI; isSchemaChanged { + b.ti = b.flushedTI + } + return } func (b *binlogPoint) outOfDate() bool { @@ -108,6 +123,13 @@ func (b *binlogPoint) FlushedMySQLPos() mysql.Position { return b.flushedPos } +// TableInfo returns the table schema associated at the current binlog position. +func (b *binlogPoint) TableInfo() *model.TableInfo { + b.RLock() + defer b.RUnlock() + return b.ti +} + func (b *binlogPoint) String() string { b.RLock() defer b.RUnlock() @@ -135,17 +157,20 @@ type CheckPoint interface { Clear() error // Load loads all checkpoints saved by CheckPoint - Load() error + Load(schemaTracker *schema.Tracker) error // LoadMeta loads checkpoints from meta config item or file LoadMeta() error // SaveTablePoint saves checkpoint for specified table in memory - SaveTablePoint(sourceSchema, sourceTable string, pos mysql.Position) + SaveTablePoint(sourceSchema, sourceTable string, pos mysql.Position, ti *model.TableInfo) // DeleteTablePoint deletes checkpoint for specified table in memory and storage DeleteTablePoint(sourceSchema, sourceTable string) error + // DeleteSchemaPoint deletes checkpoint for specified schema + DeleteSchemaPoint(sourceSchema string) error + // IsNewerTablePoint checks whether job's checkpoint is newer than previous saved checkpoint IsNewerTablePoint(sourceSchema, sourceTable string, pos mysql.Position) bool @@ -173,7 +198,7 @@ type CheckPoint interface { CheckGlobalPoint() bool // Rollback rolls global checkpoint and all table checkpoints back to flushed checkpoints - Rollback() + Rollback(schemaTracker *schema.Tracker) // String return text of global position String() string @@ -188,11 +213,10 @@ type RemoteCheckPoint struct { cfg *config.SubTaskConfig - db *conn.BaseDB - dbConn *DBConn - schema string // schema name, set through task config - table string // table name, now it's task name - id string // checkpoint ID, now it is `source-id` + db *conn.BaseDB + dbConn *DBConn + tableName string // qualified table name: schema is set through task config, table is task name + id string // checkpoint ID, now it is `source-id` // source-schema -> source-table -> checkpoint // used to filter the synced binlog when re-syncing for sharding group @@ -212,16 +236,14 @@ type RemoteCheckPoint struct { // NewRemoteCheckPoint creates a new RemoteCheckPoint func NewRemoteCheckPoint(tctx *tcontext.Context, cfg *config.SubTaskConfig, id string) CheckPoint { - newtctx := tctx.WithLogger(tctx.L().WithFields(zap.String("component", "remote checkpoint"))) cp := &RemoteCheckPoint{ cfg: cfg, - schema: cfg.MetaSchema, - table: fmt.Sprintf("%s_syncer_checkpoint", cfg.Name), + tableName: dbutil.TableName(cfg.MetaSchema, cfg.Name+"_syncer_checkpoint"), id: id, points: make(map[string]map[string]*binlogPoint), - globalPoint: newBinlogPoint(minCheckpoint, minCheckpoint), + globalPoint: newBinlogPoint(minCheckpoint, nil, minCheckpoint, nil), tctx: newtctx, } @@ -258,14 +280,16 @@ func (cp *RemoteCheckPoint) Clear() error { defer cp.Unlock() // delete all checkpoints - sql2 := fmt.Sprintf("DELETE FROM `%s`.`%s` WHERE `id` = '%s'", cp.schema, cp.table, cp.id) - args := make([]interface{}, 0) - _, err := cp.dbConn.executeSQL(cp.tctx, []string{sql2}, [][]interface{}{args}...) + _, err := cp.dbConn.executeSQL( + cp.tctx, + []string{`DELETE FROM ` + cp.tableName + ` WHERE id = ?`}, + []interface{}{cp.id}, + ) if err != nil { return err } - cp.globalPoint = newBinlogPoint(minCheckpoint, minCheckpoint) + cp.globalPoint = newBinlogPoint(minCheckpoint, nil, minCheckpoint, nil) cp.points = make(map[string]map[string]*binlogPoint) @@ -273,14 +297,14 @@ func (cp *RemoteCheckPoint) Clear() error { } // SaveTablePoint implements CheckPoint.SaveTablePoint -func (cp *RemoteCheckPoint) SaveTablePoint(sourceSchema, sourceTable string, pos mysql.Position) { +func (cp *RemoteCheckPoint) SaveTablePoint(sourceSchema, sourceTable string, pos mysql.Position, ti *model.TableInfo) { cp.Lock() defer cp.Unlock() - cp.saveTablePoint(sourceSchema, sourceTable, pos) + cp.saveTablePoint(sourceSchema, sourceTable, pos, ti) } // saveTablePoint saves single table's checkpoint without mutex.Lock -func (cp *RemoteCheckPoint) saveTablePoint(sourceSchema, sourceTable string, pos mysql.Position) { +func (cp *RemoteCheckPoint) saveTablePoint(sourceSchema, sourceTable string, pos mysql.Position, ti *model.TableInfo) { if cp.globalPoint.Compare(pos) > 0 { panic(fmt.Sprintf("table checkpoint %+v less than global checkpoint %+v", pos, cp.globalPoint)) } @@ -294,11 +318,9 @@ func (cp *RemoteCheckPoint) saveTablePoint(sourceSchema, sourceTable string, pos } point, ok := mSchema[sourceTable] if !ok { - mSchema[sourceTable] = newBinlogPoint(pos, minCheckpoint) - } else { - if err := point.save(pos); err != nil { - cp.tctx.L().Error("fail to save table point", zap.String("schema", sourceSchema), zap.String("table", sourceTable), log.ShortError(err)) - } + mSchema[sourceTable] = newBinlogPoint(pos, ti, minCheckpoint, nil) + } else if err := point.save(pos, ti); err != nil { + cp.tctx.L().Error("fail to save table point", zap.String("schema", sourceSchema), zap.String("table", sourceTable), log.ShortError(err)) } } @@ -316,10 +338,11 @@ func (cp *RemoteCheckPoint) DeleteTablePoint(sourceSchema, sourceTable string) e } cp.tctx.L().Info("delete table checkpoint", zap.String("schema", sourceSchema), zap.String("table", sourceTable)) - // delete checkpoint - sql2 := fmt.Sprintf("DELETE FROM `%s`.`%s` WHERE `id` = '%s' AND `cp_schema` = '%s' AND `cp_table` = '%s'", cp.schema, cp.table, cp.id, sourceSchema, sourceTable) - args := make([]interface{}, 0) - _, err := cp.dbConn.executeSQL(cp.tctx, []string{sql2}, [][]interface{}{args}...) + _, err := cp.dbConn.executeSQL( + cp.tctx, + []string{`DELETE FROM ` + cp.tableName + ` WHERE id = ? AND cp_schema = ? AND cp_table = ?`}, + []interface{}{cp.id, sourceSchema, sourceTable}, + ) if err != nil { return err } @@ -327,6 +350,28 @@ func (cp *RemoteCheckPoint) DeleteTablePoint(sourceSchema, sourceTable string) e return nil } +// DeleteSchemaPoint implements CheckPoint.DeleteSchemaPoint +func (cp *RemoteCheckPoint) DeleteSchemaPoint(sourceSchema string) error { + cp.Lock() + defer cp.Unlock() + _, ok := cp.points[sourceSchema] + if !ok { + return nil + } + + cp.tctx.L().Info("delete schema checkpoint", zap.String("schema", sourceSchema)) + _, err := cp.dbConn.executeSQL( + cp.tctx, + []string{`DELETE FROM ` + cp.tableName + ` WHERE id = ? AND cp_schema = ?`}, + []interface{}{cp.id, sourceSchema}, + ) + if err != nil { + return err + } + delete(cp.points, sourceSchema) + return nil +} + // IsNewerTablePoint implements CheckPoint.IsNewerTablePoint func (cp *RemoteCheckPoint) IsNewerTablePoint(sourceSchema, sourceTable string, pos mysql.Position) bool { cp.RLock() @@ -349,7 +394,7 @@ func (cp *RemoteCheckPoint) SaveGlobalPoint(pos mysql.Position) { defer cp.Unlock() cp.tctx.L().Debug("save global checkpoint", zap.Stringer("position", pos)) - if err := cp.globalPoint.save(pos); err != nil { + if err := cp.globalPoint.save(pos, nil); err != nil { cp.tctx.L().Error("fail to save global checkpoint", log.ShortError(err)) } } @@ -376,7 +421,7 @@ func (cp *RemoteCheckPoint) FlushPointsExcept(exceptTables [][]string, extraSQLs if cp.globalPoint.outOfDate() { posG := cp.GlobalPoint() - sqlG, argG := cp.genUpdateSQL(globalCpSchema, globalCpTable, posG.Name, posG.Pos, true) + sqlG, argG := cp.genUpdateSQL(globalCpSchema, globalCpTable, posG.Name, posG.Pos, nil, true) sqls = append(sqls, sqlG) args = append(args, argG) } @@ -391,8 +436,13 @@ func (cp *RemoteCheckPoint) FlushPointsExcept(exceptTables [][]string, extraSQLs } } if point.outOfDate() { + tiBytes, err := json.Marshal(point.ti) + if err != nil { + return errors.Annotatef(err, "failed to serialize table info for %s.%s", schema, table) + } + pos := point.MySQLPos() - sql2, arg := cp.genUpdateSQL(schema, table, pos.Name, pos.Pos, false) + sql2, arg := cp.genUpdateSQL(schema, table, pos.Name, pos.Pos, tiBytes, false) sqls = append(sqls, sql2) args = append(args, arg) @@ -442,14 +492,22 @@ func (cp *RemoteCheckPoint) CheckGlobalPoint() bool { } // Rollback implements CheckPoint.Rollback -func (cp *RemoteCheckPoint) Rollback() { +func (cp *RemoteCheckPoint) Rollback(schemaTracker *schema.Tracker) { cp.RLock() defer cp.RUnlock() cp.globalPoint.rollback() for schema, mSchema := range cp.points { for table, point := range mSchema { cp.tctx.L().Info("rollback checkpoint", log.WrapStringerField("checkpoint", point), zap.String("schema", schema), zap.String("table", table)) - point.rollback() + if point.rollback() { + // schema changed + _ = schemaTracker.DropTable(schema, table) + if point.ti != nil { + if err := schemaTracker.CreateTableIfNotExists(schema, table, point.ti); err != nil { + cp.tctx.L().Warn("failed to rollback schema on schema tracker", zap.String("schema", schema), zap.String("table", table), log.ShortError(err)) + } + } + } } } } @@ -466,7 +524,7 @@ func (cp *RemoteCheckPoint) prepare() error { } func (cp *RemoteCheckPoint) createSchema() error { - sql2 := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS `%s`", cp.schema) + sql2 := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS `%s`", cp.cfg.MetaSchema) args := make([]interface{}, 0) _, err := cp.dbConn.executeSQL(cp.tctx, []string{sql2}, [][]interface{}{args}...) cp.tctx.L().Info("create checkpoint schema", zap.String("statement", sql2)) @@ -474,28 +532,29 @@ func (cp *RemoteCheckPoint) createSchema() error { } func (cp *RemoteCheckPoint) createTable() error { - tableName := fmt.Sprintf("`%s`.`%s`", cp.schema, cp.table) - sql2 := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s ( + sqls := []string{ + `CREATE TABLE IF NOT EXISTS ` + cp.tableName + ` ( id VARCHAR(32) NOT NULL, cp_schema VARCHAR(128) NOT NULL, cp_table VARCHAR(128) NOT NULL, binlog_name VARCHAR(128), binlog_pos INT UNSIGNED, + table_info JSON NOT NULL, is_global BOOLEAN, create_time timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, update_time timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, UNIQUE KEY uk_id_schema_table (id, cp_schema, cp_table) - )`, tableName) - args := make([]interface{}, 0) - _, err := cp.dbConn.executeSQL(cp.tctx, []string{sql2}, [][]interface{}{args}...) - cp.tctx.L().Info("create checkpoint table", zap.String("statement", sql2)) + )`, + } + _, err := cp.dbConn.executeSQL(cp.tctx, sqls) + cp.tctx.L().Info("create checkpoint table", zap.Strings("statements", sqls)) return err } // Load implements CheckPoint.Load -func (cp *RemoteCheckPoint) Load() error { - query := fmt.Sprintf("SELECT `cp_schema`, `cp_table`, `binlog_name`, `binlog_pos`, `is_global` FROM `%s`.`%s` WHERE `id`='%s'", cp.schema, cp.table, cp.id) - rows, err := cp.dbConn.querySQL(cp.tctx, query) +func (cp *RemoteCheckPoint) Load(schemaTracker *schema.Tracker) error { + query := `SELECT cp_schema, cp_table, binlog_name, binlog_pos, table_info, is_global FROM ` + cp.tableName + ` WHERE id = ?` + rows, err := cp.dbConn.querySQL(cp.tctx, query, cp.id) defer func() { if rows != nil { rows.Close() @@ -518,10 +577,11 @@ func (cp *RemoteCheckPoint) Load() error { cpTable string binlogName string binlogPos uint32 + tiBytes []byte isGlobal bool ) for rows.Next() { - err := rows.Scan(&cpSchema, &cpTable, &binlogName, &binlogPos, &isGlobal) + err := rows.Scan(&cpSchema, &cpTable, &binlogName, &binlogPos, &tiBytes, &isGlobal) if err != nil { return terror.WithScope(terror.DBErrorAdapt(err, terror.ErrDBDriverError), terror.ScopeDownstream) } @@ -531,18 +591,31 @@ func (cp *RemoteCheckPoint) Load() error { } if isGlobal { if pos.Compare(minCheckpoint) > 0 { - cp.globalPoint = newBinlogPoint(pos, pos) + cp.globalPoint = newBinlogPoint(pos, nil, pos, nil) cp.tctx.L().Info("fetch global checkpoint from DB", log.WrapStringerField("global checkpoint", cp.globalPoint)) } continue // skip global checkpoint } + + var ti model.TableInfo + if err = json.Unmarshal(tiBytes, &ti); err != nil { + return errors.Annotatef(err, "saved schema of %s.%s is not proper JSON", cpSchema, cpTable) + } + if err = schemaTracker.CreateSchemaIfNotExists(cpSchema); err != nil { + return errors.Annotatef(err, "failed to create database for `%s` in schema tracker", cpSchema) + } + if err = schemaTracker.CreateTableIfNotExists(cpSchema, cpTable, &ti); err != nil { + return errors.Annotatef(err, "failed to create table for `%s`.`%s` in schema tracker", cpSchema, cpTable) + } + mSchema, ok := cp.points[cpSchema] if !ok { mSchema = make(map[string]*binlogPoint) cp.points[cpSchema] = mSchema } - mSchema[cpTable] = newBinlogPoint(pos, pos) + mSchema[cpTable] = newBinlogPoint(pos, &ti, pos, &ti) } + return terror.WithScope(terror.DBErrorAdapt(rows.Err(), terror.ErrDBDriverError), terror.ScopeDownstream) } @@ -577,7 +650,7 @@ func (cp *RemoteCheckPoint) LoadMeta() error { // if meta loaded, we will start syncing from meta's pos if pos != nil { - cp.globalPoint = newBinlogPoint(*pos, *pos) + cp.globalPoint = newBinlogPoint(*pos, nil, *pos, nil) cp.tctx.L().Info("loaded checkpoints from meta", log.WrapStringerField("global checkpoint", cp.globalPoint)) } @@ -585,16 +658,28 @@ func (cp *RemoteCheckPoint) LoadMeta() error { } // genUpdateSQL generates SQL and arguments for update checkpoint -func (cp *RemoteCheckPoint) genUpdateSQL(cpSchema, cpTable string, binlogName string, binlogPos uint32, isGlobal bool) (string, []interface{}) { +func (cp *RemoteCheckPoint) genUpdateSQL(cpSchema, cpTable string, binlogName string, binlogPos uint32, tiBytes []byte, isGlobal bool) (string, []interface{}) { // use `INSERT INTO ... ON DUPLICATE KEY UPDATE` rather than `REPLACE INTO` // to keep `create_time`, `update_time` correctly - sql2 := fmt.Sprintf("INSERT INTO `%s`.`%s` (`id`, `cp_schema`, `cp_table`, `binlog_name`, `binlog_pos`, `is_global`) VALUES(?,?,?,?,?,?) ON DUPLICATE KEY UPDATE `binlog_name`=?, `binlog_pos`=?", - cp.schema, cp.table) + sql2 := `INSERT INTO ` + cp.tableName + ` + (id, cp_schema, cp_table, binlog_name, binlog_pos, table_info, is_global) VALUES + (?, ?, ?, ?, ?, ?, ?) + ON DUPLICATE KEY UPDATE + binlog_name = VALUES(binlog_name), + binlog_pos = VALUES(binlog_pos), + table_info = VALUES(table_info), + is_global = VALUES(is_global); + ` + if isGlobal { cpSchema = globalCpSchema cpTable = globalCpTable } - args := []interface{}{cp.id, cpSchema, cpTable, binlogName, binlogPos, isGlobal, binlogName, binlogPos} + + if len(tiBytes) == 0 { + tiBytes = []byte("null") + } + args := []interface{}{cp.id, cpSchema, cpTable, binlogName, binlogPos, tiBytes, isGlobal} return sql2, args } diff --git a/syncer/checkpoint_test.go b/syncer/checkpoint_test.go index 58228dae20..a500bc1da0 100644 --- a/syncer/checkpoint_test.go +++ b/syncer/checkpoint_test.go @@ -24,10 +24,13 @@ import ( "github.com/pingcap/dm/pkg/conn" tcontext "github.com/pingcap/dm/pkg/context" "github.com/pingcap/dm/pkg/retry" + "github.com/pingcap/dm/pkg/schema" "github.com/DATA-DOG/go-sqlmock" . "github.com/pingcap/check" + "github.com/pingcap/log" "github.com/siddontang/go-mysql/mysql" + "go.uber.org/zap/zapcore" ) var ( @@ -42,8 +45,9 @@ var ( var _ = Suite(&testCheckpointSuite{}) type testCheckpointSuite struct { - cfg *config.SubTaskConfig - mock sqlmock.Sqlmock + cfg *config.SubTaskConfig + mock sqlmock.Sqlmock + tracker *schema.Tracker } func (s *testCheckpointSuite) SetUpSuite(c *C) { @@ -52,17 +56,24 @@ func (s *testCheckpointSuite) SetUpSuite(c *C) { MetaSchema: "test", Name: "syncer_checkpoint_ut", } + + log.SetLevel(zapcore.ErrorLevel) + var err error + s.tracker, err = schema.NewTracker() + c.Assert(err, IsNil) } -func (s *testCheckpointSuite) TearDownSuite(c *C) { +func (s *testCheckpointSuite) TestUpTest(c *C) { + err := s.tracker.Reset() + c.Assert(err, IsNil) } func (s *testCheckpointSuite) prepareCheckPointSQL() { schemaCreateSQL = fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS `%s`", s.cfg.MetaSchema) tableCreateSQL = fmt.Sprintf("CREATE TABLE IF NOT EXISTS `%s`.`%s_syncer_checkpoint` .*", s.cfg.MetaSchema, s.cfg.Name) flushCheckPointSQL = fmt.Sprintf("INSERT INTO `%s`.`%s_syncer_checkpoint` .* VALUES.* ON DUPLICATE KEY UPDATE .*", s.cfg.MetaSchema, s.cfg.Name) - clearCheckPointSQL = fmt.Sprintf("DELETE FROM `%s`.`%s_syncer_checkpoint` WHERE `id` = '%s'", s.cfg.MetaSchema, s.cfg.Name, cpid) - loadCheckPointSQL = fmt.Sprintf("SELECT .* FROM `%s`.`%s_syncer_checkpoint` WHERE `id`='%s'", s.cfg.MetaSchema, s.cfg.Name, cpid) + clearCheckPointSQL = fmt.Sprintf("DELETE FROM `%s`.`%s_syncer_checkpoint` WHERE id = \\?", s.cfg.MetaSchema, s.cfg.Name) + loadCheckPointSQL = fmt.Sprintf("SELECT .* FROM `%s`.`%s_syncer_checkpoint` WHERE id = \\?", s.cfg.MetaSchema, s.cfg.Name) } // this test case uses sqlmock to simulate all SQL operations in tests @@ -87,7 +98,7 @@ func (s *testCheckpointSuite) TestCheckPoint(c *C) { mock.ExpectExec(tableCreateSQL).WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() mock.ExpectBegin() - mock.ExpectExec(clearCheckPointSQL).WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectExec(clearCheckPointSQL).WithArgs(cpid).WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() dbConn, err := db.Conn(tcontext.Background().Context()) @@ -97,7 +108,7 @@ func (s *testCheckpointSuite) TestCheckPoint(c *C) { cp.(*RemoteCheckPoint).dbConn = conn err = cp.(*RemoteCheckPoint).prepare() c.Assert(err, IsNil) - cp.Clear() + c.Assert(cp.Clear(), IsNil) // test operation for global checkpoint s.testGlobalCheckPoint(c, cp) @@ -113,7 +124,7 @@ func (s *testCheckpointSuite) testGlobalCheckPoint(c *C, cp CheckPoint) { // try load, but should load nothing s.mock.ExpectQuery(loadCheckPointSQL).WillReturnRows(sqlmock.NewRows(nil)) - err := cp.Load() + err := cp.Load(s.tracker) c.Assert(err, IsNil) c.Assert(cp.GlobalPoint(), Equals, minCheckpoint) c.Assert(cp.FlushedGlobalPoint(), Equals, minCheckpoint) @@ -142,13 +153,13 @@ func (s *testCheckpointSuite) testGlobalCheckPoint(c *C, cp CheckPoint) { s.cfg.Mode = config.ModeAll s.cfg.Dir = dir - s.mock.ExpectQuery(loadCheckPointSQL).WillReturnRows(sqlmock.NewRows(nil)) - err = cp.Load() + s.mock.ExpectQuery(loadCheckPointSQL).WithArgs(cpid).WillReturnRows(sqlmock.NewRows(nil)) + err = cp.Load(s.tracker) c.Assert(err, IsNil) cp.SaveGlobalPoint(pos1) s.mock.ExpectBegin() - s.mock.ExpectExec(flushCheckPointSQL).WithArgs(cpid, "", "", pos1.Name, pos1.Pos, true, pos1.Name, pos1.Pos).WillReturnResult(sqlmock.NewResult(0, 1)) + s.mock.ExpectExec("(162)?"+flushCheckPointSQL).WithArgs(cpid, "", "", pos1.Name, pos1.Pos, []byte("null"), true).WillReturnResult(sqlmock.NewResult(0, 1)) s.mock.ExpectCommit() err = cp.FlushPointsExcept(nil, nil, nil) c.Assert(err, IsNil) @@ -177,7 +188,7 @@ func (s *testCheckpointSuite) testGlobalCheckPoint(c *C, cp CheckPoint) { c.Assert(cp.FlushedGlobalPoint(), Equals, pos1) // test rollback - cp.Rollback() + cp.Rollback(s.tracker) c.Assert(cp.GlobalPoint(), Equals, pos1) c.Assert(cp.FlushedGlobalPoint(), Equals, pos1) @@ -188,11 +199,11 @@ func (s *testCheckpointSuite) testGlobalCheckPoint(c *C, cp CheckPoint) { // flush + rollback s.mock.ExpectBegin() - s.mock.ExpectExec(flushCheckPointSQL).WithArgs(cpid, "", "", pos2.Name, pos2.Pos, true, pos2.Name, pos2.Pos).WillReturnResult(sqlmock.NewResult(0, 1)) + s.mock.ExpectExec("(202)?"+flushCheckPointSQL).WithArgs(cpid, "", "", pos2.Name, pos2.Pos, []byte("null"), true).WillReturnResult(sqlmock.NewResult(0, 1)) s.mock.ExpectCommit() err = cp.FlushPointsExcept(nil, nil, nil) c.Assert(err, IsNil) - cp.Rollback() + cp.Rollback(s.tracker) c.Assert(cp.GlobalPoint(), Equals, pos2) c.Assert(cp.FlushedGlobalPoint(), Equals, pos2) @@ -200,9 +211,9 @@ func (s *testCheckpointSuite) testGlobalCheckPoint(c *C, cp CheckPoint) { pos3 := pos2 pos3.Pos = pos2.Pos + 1000 // > pos2 to enable save cp.SaveGlobalPoint(pos3) - columns := []string{"cp_schema", "cp_table", "binlog_name", "binlog_pos", "is_global"} - s.mock.ExpectQuery(loadCheckPointSQL).WillReturnRows(sqlmock.NewRows(columns).AddRow("", "", pos2.Name, pos2.Pos, true)) - err = cp.Load() + columns := []string{"cp_schema", "cp_table", "binlog_name", "binlog_pos", "table_info", "is_global"} + s.mock.ExpectQuery(loadCheckPointSQL).WithArgs(cpid).WillReturnRows(sqlmock.NewRows(columns).AddRow("", "", pos2.Name, pos2.Pos, []byte("null"), true)) + err = cp.Load(s.tracker) c.Assert(err, IsNil) c.Assert(cp.GlobalPoint(), Equals, pos2) c.Assert(cp.FlushedGlobalPoint(), Equals, pos2) @@ -220,7 +231,7 @@ func (s *testCheckpointSuite) testGlobalCheckPoint(c *C, cp CheckPoint) { // test clear s.mock.ExpectBegin() - s.mock.ExpectExec(clearCheckPointSQL).WillReturnResult(sqlmock.NewResult(0, 1)) + s.mock.ExpectExec(clearCheckPointSQL).WithArgs(cpid).WillReturnResult(sqlmock.NewResult(0, 1)) s.mock.ExpectCommit() err = cp.Clear() c.Assert(err, IsNil) @@ -228,7 +239,7 @@ func (s *testCheckpointSuite) testGlobalCheckPoint(c *C, cp CheckPoint) { c.Assert(cp.FlushedGlobalPoint(), Equals, minCheckpoint) s.mock.ExpectQuery(loadCheckPointSQL).WillReturnRows(sqlmock.NewRows(nil)) - err = cp.Load() + err = cp.Load(s.tracker) c.Assert(err, IsNil) c.Assert(cp.GlobalPoint(), Equals, minCheckpoint) c.Assert(cp.FlushedGlobalPoint(), Equals, minCheckpoint) @@ -254,33 +265,33 @@ func (s *testCheckpointSuite) testTableCheckPoint(c *C, cp CheckPoint) { c.Assert(newer, IsTrue) // save - cp.SaveTablePoint(schema, table, pos2) + cp.SaveTablePoint(schema, table, pos2, nil) newer = cp.IsNewerTablePoint(schema, table, pos1) c.Assert(newer, IsFalse) // rollback, to min - cp.Rollback() + cp.Rollback(s.tracker) newer = cp.IsNewerTablePoint(schema, table, pos1) c.Assert(newer, IsTrue) // save again - cp.SaveTablePoint(schema, table, pos2) + cp.SaveTablePoint(schema, table, pos2, nil) newer = cp.IsNewerTablePoint(schema, table, pos1) c.Assert(newer, IsFalse) // flush + rollback s.mock.ExpectBegin() - s.mock.ExpectExec(flushCheckPointSQL).WithArgs(cpid, schema, table, pos2.Name, pos2.Pos, false, pos2.Name, pos2.Pos).WillReturnResult(sqlmock.NewResult(0, 1)) + s.mock.ExpectExec("(284)?"+flushCheckPointSQL).WithArgs(cpid, schema, table, pos2.Name, pos2.Pos, sqlmock.AnyArg(), false).WillReturnResult(sqlmock.NewResult(0, 1)) s.mock.ExpectCommit() err = cp.FlushPointsExcept(nil, nil, nil) c.Assert(err, IsNil) - cp.Rollback() + cp.Rollback(s.tracker) newer = cp.IsNewerTablePoint(schema, table, pos1) c.Assert(newer, IsFalse) // clear, to min s.mock.ExpectBegin() - s.mock.ExpectExec(clearCheckPointSQL).WillReturnResult(sqlmock.NewResult(0, 1)) + s.mock.ExpectExec(clearCheckPointSQL).WithArgs(cpid).WillReturnResult(sqlmock.NewResult(0, 1)) s.mock.ExpectCommit() err = cp.Clear() c.Assert(err, IsNil) @@ -288,7 +299,7 @@ func (s *testCheckpointSuite) testTableCheckPoint(c *C, cp CheckPoint) { c.Assert(newer, IsTrue) // save - cp.SaveTablePoint(schema, table, pos2) + cp.SaveTablePoint(schema, table, pos2, nil) newer = cp.IsNewerTablePoint(schema, table, pos1) c.Assert(newer, IsFalse) @@ -301,16 +312,16 @@ func (s *testCheckpointSuite) testTableCheckPoint(c *C, cp CheckPoint) { c.Assert(r, Matches, matchStr) }() cp.SaveGlobalPoint(pos2) - cp.SaveTablePoint(schema, table, pos1) + cp.SaveTablePoint(schema, table, pos1, nil) }() // flush but except + rollback s.mock.ExpectBegin() - s.mock.ExpectExec(flushCheckPointSQL).WithArgs(cpid, "", "", pos2.Name, pos2.Pos, true, pos2.Name, pos2.Pos).WillReturnResult(sqlmock.NewResult(0, 1)) + s.mock.ExpectExec("(320)?"+flushCheckPointSQL).WithArgs(cpid, "", "", pos2.Name, pos2.Pos, []byte("null"), true).WillReturnResult(sqlmock.NewResult(0, 1)) s.mock.ExpectCommit() err = cp.FlushPointsExcept([][]string{{schema, table}}, nil, nil) c.Assert(err, IsNil) - cp.Rollback() + cp.Rollback(s.tracker) newer = cp.IsNewerTablePoint(schema, table, pos1) c.Assert(newer, IsTrue) } diff --git a/syncer/db.go b/syncer/db.go index 12b7d22a0c..2d716fb121 100644 --- a/syncer/db.go +++ b/syncer/db.go @@ -15,8 +15,6 @@ package syncer import ( "database/sql" - "fmt" - "strings" "time" "github.com/pingcap/dm/dm/config" @@ -37,27 +35,6 @@ import ( "go.uber.org/zap" ) -type column struct { - idx int - name string - NotNull bool - unsigned bool - tp string - extra string -} - -func (c *column) isGeneratedColumn() bool { - return strings.Contains(c.extra, "VIRTUAL GENERATED") || strings.Contains(c.extra, "STORED GENERATED") -} - -type table struct { - schema string - name string - - columns []*column - indexColumns map[string][]*column -} - // in MySQL, we can set `max_binlog_size` to control the max size of a binlog file. // but this is not absolute: // > A transaction is written in one chunk to the binary log, so it is never split between several binary logs. @@ -304,134 +281,6 @@ func createConns(tctx *tcontext.Context, cfg *config.SubTaskConfig, dbCfg config return baseDB, conns, nil } -func getTableIndex(tctx *tcontext.Context, conn *DBConn, table *table) error { - if table.schema == "" || table.name == "" { - return terror.ErrDBUnExpect.Generate("schema/table is empty") - } - - query := fmt.Sprintf("SHOW INDEX FROM `%s`.`%s`", table.schema, table.name) - rows, err := conn.querySQL(tctx, query) - if err != nil { - return terror.DBErrorAdapt(err, terror.ErrDBDriverError) - } - defer rows.Close() - - rowColumns, err := rows.Columns() - if err != nil { - return terror.DBErrorAdapt(err, terror.ErrDBDriverError) - } - - // Show an example. - /* - mysql> show index from test.t; - +-------+------------+----------+--------------+-------------+-----------+-------------+----------+--------+------+------------+---------+---------------+ - | Table | Non_unique | Key_name | Seq_in_index | Column_name | Collation | Cardinality | Sub_part | Packed | Null | Index_type | Comment | Index_comment | - +-------+------------+----------+--------------+-------------+-----------+-------------+----------+--------+------+------------+---------+---------------+ - | t | 0 | PRIMARY | 1 | a | A | 0 | NULL | NULL | | BTREE | | | - | t | 0 | PRIMARY | 2 | b | A | 0 | NULL | NULL | | BTREE | | | - | t | 0 | ucd | 1 | c | A | 0 | NULL | NULL | YES | BTREE | | | - | t | 0 | ucd | 2 | d | A | 0 | NULL | NULL | YES | BTREE | | | - +-------+------------+----------+--------------+-------------+-----------+-------------+----------+--------+------+------------+---------+---------------+ - */ - var columns = make(map[string][]string) - for rows.Next() { - data := make([]sql.RawBytes, len(rowColumns)) - values := make([]interface{}, len(rowColumns)) - - for i := range values { - values[i] = &data[i] - } - - err = rows.Scan(values...) - if err != nil { - return terror.DBErrorAdapt(err, terror.ErrDBDriverError) - } - - nonUnique := string(data[1]) - if nonUnique == "0" { - keyName := strings.ToLower(string(data[2])) - columns[keyName] = append(columns[keyName], string(data[4])) - } - } - if rows.Err() != nil { - return terror.DBErrorAdapt(rows.Err(), terror.ErrDBDriverError) - } - - table.indexColumns = findColumns(table.columns, columns) - return nil -} - -func getTableColumns(tctx *tcontext.Context, conn *DBConn, table *table) error { - if table.schema == "" || table.name == "" { - return terror.ErrDBUnExpect.Generate("schema/table is empty") - } - - query := fmt.Sprintf("SHOW COLUMNS FROM `%s`.`%s`", table.schema, table.name) - rows, err := conn.querySQL(tctx, query) - if err != nil { - return terror.DBErrorAdapt(err, terror.ErrDBDriverError) - } - defer rows.Close() - - rowColumns, err := rows.Columns() - if err != nil { - return terror.DBErrorAdapt(err, terror.ErrDBDriverError) - } - - // Show an example. - /* - mysql> show columns from test.t; - +-------+---------+------+-----+---------+-------------------+ - | Field | Type | Null | Key | Default | Extra | - +-------+---------+------+-----+---------+-------------------+ - | a | int(11) | NO | PRI | NULL | | - | b | int(11) | NO | PRI | NULL | | - | c | int(11) | YES | MUL | NULL | | - | d | int(11) | YES | | NULL | | - | d | json | YES | | NULL | VIRTUAL GENERATED | - +-------+---------+------+-----+---------+-------------------+ - */ - - idx := 0 - for rows.Next() { - data := make([]sql.RawBytes, len(rowColumns)) - values := make([]interface{}, len(rowColumns)) - - for i := range values { - values[i] = &data[i] - } - - err = rows.Scan(values...) - if err != nil { - return terror.DBErrorAdapt(err, terror.ErrDBDriverError) - } - - column := &column{} - column.idx = idx - column.name = string(data[0]) - column.tp = string(data[1]) - column.extra = string(data[5]) - - if strings.ToLower(string(data[2])) == "no" { - column.NotNull = true - } - - // Check whether column has unsigned flag. - if strings.Contains(strings.ToLower(string(data[1])), "unsigned") { - column.unsigned = true - } - - table.columns = append(table.columns, column) - idx++ - } - - if rows.Err() != nil { - return terror.DBErrorAdapt(rows.Err(), terror.ErrDBDriverError) - } - - return nil -} - func countBinaryLogsSize(fromFile mysql.Position, db *sql.DB) (int64, error) { files, err := getBinaryLogs(db) if err != nil { diff --git a/syncer/dml.go b/syncer/dml.go index 4043bbca2b..18782b2a08 100644 --- a/syncer/dml.go +++ b/syncer/dml.go @@ -14,101 +14,44 @@ package syncer import ( - "bytes" "encoding/binary" "fmt" "strconv" "strings" - "github.com/pingcap/dm/pkg/log" "github.com/pingcap/dm/pkg/terror" - + "github.com/pingcap/parser/model" + "github.com/pingcap/parser/mysql" + "github.com/pingcap/parser/types" "github.com/pingcap/tidb-tools/pkg/dbutil" ) -type genColumnCacheStatus uint8 - -const ( - genColumnNoCache genColumnCacheStatus = iota - hasGenColumn - noGenColumn -) - -// GenColCache stores generated column information for all tables -type GenColCache struct { - // `schema`.`table` -> whether this table has generated column - hasGenColumn map[string]bool - - // `schema`.`table` -> column list - columns map[string][]*column - - // `schema`.`table` -> a bool slice representing whether it is generated for each column - isGenColumn map[string][]bool -} - // genDMLParam stores pruned columns, data as well as the original columns, data, index type genDMLParam struct { - schema string - table string - safeMode bool // only used in update - data [][]interface{} // pruned data - originalData [][]interface{} // all data - columns []*column // pruned columns - originalColumns []*column // all columns - originalIndexColumns map[string][]*column // all index information -} - -// NewGenColCache creates a GenColCache. -func NewGenColCache() *GenColCache { - c := &GenColCache{} - c.reset() - return c -} - -// status returns `NotFound` if a `schema`.`table` has no generated column -// information cached, otherwise returns `hasGenColumn` if cache found and -// it has generated column and returns `noGenColumn` if it has no generated column. -func (c *GenColCache) status(key string) genColumnCacheStatus { - val, ok := c.hasGenColumn[key] - if !ok { - return genColumnNoCache - } - if val { - return hasGenColumn - } - return noGenColumn -} - -func (c *GenColCache) clearTable(schema, table string) { - key := dbutil.TableName(schema, table) - delete(c.hasGenColumn, key) - delete(c.columns, key) - delete(c.isGenColumn, key) -} - -func (c *GenColCache) reset() { - c.hasGenColumn = make(map[string]bool) - c.columns = make(map[string][]*column) - c.isGenColumn = make(map[string][]bool) + schema string + table string + safeMode bool // only used in update + data [][]interface{} // pruned data + originalData [][]interface{} // all data + columns []*model.ColumnInfo // pruned columns + originalTableInfo *model.TableInfo // all table info } -func extractValueFromData(data []interface{}, columns []*column) []interface{} { +func extractValueFromData(data []interface{}, columns []*model.ColumnInfo) []interface{} { value := make([]interface{}, 0, len(data)) for i := range data { - value = append(value, castUnsigned(data[i], columns[i].unsigned, columns[i].tp)) + value = append(value, castUnsigned(data[i], &columns[i].FieldType)) } return value } func genInsertSQLs(param *genDMLParam) ([]string, [][]string, [][]interface{}, error) { var ( - schema = param.schema - table = param.table - dataSeq = param.data - originalDataSeq = param.originalData - columns = param.columns - originalColumns = param.originalColumns - originalIndexColumns = param.originalIndexColumns + qualifiedName = dbutil.TableName(param.schema, param.table) + dataSeq = param.data + originalDataSeq = param.originalData + columns = param.columns + ti = param.originalTableInfo ) sqls := make([]string, 0, len(dataSeq)) keys := make([][]string, 0, len(dataSeq)) @@ -123,10 +66,10 @@ func genInsertSQLs(param *genDMLParam) ([]string, [][]string, [][]interface{}, e value := extractValueFromData(data, columns) originalData := originalDataSeq[dataIdx] var originalValue []interface{} - if len(columns) == len(originalColumns) { + if len(columns) == len(ti.Columns) { originalValue = value } else { - originalValue = extractValueFromData(originalData, originalColumns) + originalValue = extractValueFromData(originalData, ti.Columns) } var insertOrReplace string @@ -136,8 +79,8 @@ func genInsertSQLs(param *genDMLParam) ([]string, [][]string, [][]interface{}, e insertOrReplace = "INSERT" } - sql := fmt.Sprintf("%s INTO `%s`.`%s` (%s) VALUES (%s);", insertOrReplace, schema, table, columnList, columnPlaceholders) - ks := genMultipleKeys(originalColumns, originalValue, originalIndexColumns) + sql := fmt.Sprintf("%s INTO %s (%s) VALUES (%s);", insertOrReplace, qualifiedName, columnList, columnPlaceholders) + ks := genMultipleKeys(ti, originalValue) sqls = append(sqls, sql) values = append(values, value) keys = append(keys, ks) @@ -148,21 +91,19 @@ func genInsertSQLs(param *genDMLParam) ([]string, [][]string, [][]interface{}, e func genUpdateSQLs(param *genDMLParam) ([]string, [][]string, [][]interface{}, error) { var ( - schema = param.schema - table = param.table - safeMode = param.safeMode - data = param.data - originalData = param.originalData - columns = param.columns - originalColumns = param.originalColumns - originalIndexColumns = param.originalIndexColumns + qualifiedName = dbutil.TableName(param.schema, param.table) + safeMode = param.safeMode + data = param.data + originalData = param.originalData + columns = param.columns + ti = param.originalTableInfo ) sqls := make([]string, 0, len(data)/2) keys := make([][]string, 0, len(data)/2) values := make([][]interface{}, 0, len(data)/2) columnList := genColumnList(columns) columnPlaceholders := genColumnPlaceholders(len(columns)) - defaultIndexColumns := findFitIndex(originalIndexColumns) + defaultIndexColumns := findFitIndex(ti) for i := 0; i < len(data); i += 2 { oldData := data[i] @@ -182,37 +123,37 @@ func genUpdateSQLs(param *genDMLParam) ([]string, [][]string, [][]interface{}, e changedValues := extractValueFromData(changedData, columns) var oriOldValues, oriChangedValues []interface{} - if len(columns) == len(originalColumns) { + if len(columns) == len(ti.Columns) { oriOldValues = oldValues oriChangedValues = changedValues } else { - oriOldValues = extractValueFromData(oriOldData, originalColumns) - oriChangedValues = extractValueFromData(oriChangedData, originalColumns) + oriOldValues = extractValueFromData(oriOldData, ti.Columns) + oriChangedValues = extractValueFromData(oriChangedData, ti.Columns) } - if len(defaultIndexColumns) == 0 { - defaultIndexColumns = getAvailableIndexColumn(originalIndexColumns, oriOldValues) + if defaultIndexColumns == nil { + defaultIndexColumns = getAvailableIndexColumn(ti, oriOldValues) } - ks := genMultipleKeys(originalColumns, oriOldValues, originalIndexColumns) - ks = append(ks, genMultipleKeys(originalColumns, oriChangedValues, originalIndexColumns)...) + ks := genMultipleKeys(ti, oriOldValues) + ks = append(ks, genMultipleKeys(ti, oriChangedValues)...) if safeMode { // generate delete sql from old data - sql, value := genDeleteSQL(schema, table, oriOldValues, originalColumns, defaultIndexColumns) + sql, value := genDeleteSQL(qualifiedName, oriOldValues, ti.Columns, defaultIndexColumns) sqls = append(sqls, sql) values = append(values, value) keys = append(keys, ks) // generate replace sql from new data - sql = fmt.Sprintf("REPLACE INTO `%s`.`%s` (%s) VALUES (%s);", schema, table, columnList, columnPlaceholders) + sql = fmt.Sprintf("REPLACE INTO %s (%s) VALUES (%s);", qualifiedName, columnList, columnPlaceholders) sqls = append(sqls, sql) values = append(values, changedValues) keys = append(keys, ks) continue } - updateColumns := make([]*column, 0, len(defaultIndexColumns)) - updateValues := make([]interface{}, 0, len(defaultIndexColumns)) + updateColumns := make([]*model.ColumnInfo, 0, indexColumnsCount(defaultIndexColumns)) + updateValues := make([]interface{}, 0, indexColumnsCount(defaultIndexColumns)) for j := range oldValues { updateColumns = append(updateColumns, columns[j]) updateValues = append(updateValues, changedValues[j]) @@ -227,15 +168,15 @@ func genUpdateSQLs(param *genDMLParam) ([]string, [][]string, [][]interface{}, e kvs := genKVs(updateColumns) value = append(value, updateValues...) - whereColumns, whereValues := originalColumns, oriOldValues - if len(defaultIndexColumns) > 0 { - whereColumns, whereValues = getColumnData(originalColumns, defaultIndexColumns, oriOldValues) + whereColumns, whereValues := ti.Columns, oriOldValues + if defaultIndexColumns != nil { + whereColumns, whereValues = getColumnData(ti.Columns, defaultIndexColumns, oriOldValues) } where := genWhere(whereColumns, whereValues) value = append(value, whereValues...) - sql := fmt.Sprintf("UPDATE `%s`.`%s` SET %s WHERE %s LIMIT 1;", schema, table, kvs, where) + sql := fmt.Sprintf("UPDATE %s SET %s WHERE %s LIMIT 1;", qualifiedName, kvs, where) sqls = append(sqls, sql) values = append(values, value) keys = append(keys, ks) @@ -246,30 +187,28 @@ func genUpdateSQLs(param *genDMLParam) ([]string, [][]string, [][]interface{}, e func genDeleteSQLs(param *genDMLParam) ([]string, [][]string, [][]interface{}, error) { var ( - schema = param.schema - table = param.table - dataSeq = param.originalData - columns = param.originalColumns - indexColumns = param.originalIndexColumns + qualifiedName = dbutil.TableName(param.schema, param.table) + dataSeq = param.originalData + ti = param.originalTableInfo ) sqls := make([]string, 0, len(dataSeq)) keys := make([][]string, 0, len(dataSeq)) values := make([][]interface{}, 0, len(dataSeq)) - defaultIndexColumns := findFitIndex(indexColumns) + defaultIndexColumns := findFitIndex(ti) for _, data := range dataSeq { - if len(data) != len(columns) { - return nil, nil, nil, terror.ErrSyncerUnitDMLColumnNotMatch.Generate(len(columns), len(data)) + if len(data) != len(ti.Columns) { + return nil, nil, nil, terror.ErrSyncerUnitDMLColumnNotMatch.Generate(len(ti.Columns), len(data)) } - value := extractValueFromData(data, columns) + value := extractValueFromData(data, ti.Columns) - if len(defaultIndexColumns) == 0 { - defaultIndexColumns = getAvailableIndexColumn(indexColumns, value) + if defaultIndexColumns == nil { + defaultIndexColumns = getAvailableIndexColumn(ti, value) } - ks := genMultipleKeys(columns, value, indexColumns) + ks := genMultipleKeys(ti, value) - sql, value := genDeleteSQL(schema, table, value, columns, defaultIndexColumns) + sql, value := genDeleteSQL(qualifiedName, value, ti.Columns, defaultIndexColumns) sqls = append(sqls, sql) values = append(values, value) keys = append(keys, ks) @@ -278,28 +217,35 @@ func genDeleteSQLs(param *genDMLParam) ([]string, [][]string, [][]interface{}, e return sqls, keys, values, nil } -func genDeleteSQL(schema string, table string, value []interface{}, columns []*column, indexColumns []*column) (string, []interface{}) { +func genDeleteSQL(qualifiedName string, value []interface{}, columns []*model.ColumnInfo, indexColumns *model.IndexInfo) (string, []interface{}) { whereColumns, whereValues := columns, value - if len(indexColumns) > 0 { + if indexColumns != nil { whereColumns, whereValues = getColumnData(columns, indexColumns, value) } where := genWhere(whereColumns, whereValues) - sql := fmt.Sprintf("DELETE FROM `%s`.`%s` WHERE %s LIMIT 1;", schema, table, where) + sql := fmt.Sprintf("DELETE FROM %s WHERE %s LIMIT 1;", qualifiedName, where) return sql, whereValues } -func genColumnList(columns []*column) string { +func indexColumnsCount(index *model.IndexInfo) int { + if index == nil { + return 0 + } + return len(index.Columns) +} + +func genColumnList(columns []*model.ColumnInfo) string { var buf strings.Builder for i, column := range columns { - if i != len(columns)-1 { - buf.WriteString("`" + column.name + "`,") - } else { - buf.WriteString("`" + column.name + "`") + if i != 0 { + buf.WriteByte(',') } + buf.WriteByte('`') + buf.WriteString(strings.ReplaceAll(column.Name.O, "`", "``")) + buf.WriteByte('`') } - return buf.String() } @@ -311,8 +257,8 @@ func genColumnPlaceholders(length int) string { return strings.Join(values, ",") } -func castUnsigned(data interface{}, unsigned bool, tp string) interface{} { - if !unsigned { +func castUnsigned(data interface{}, ft *types.FieldType) interface{} { + if !mysql.HasUnsignedFlag(ft.Flag) { return data } @@ -324,7 +270,7 @@ func castUnsigned(data interface{}, unsigned bool, tp string) interface{} { case int16: return uint16(v) case int32: - if strings.Contains(strings.ToLower(tp), "mediumint") { + if ft.Tp == mysql.TypeInt24 { // we use int32 to store MEDIUMINT, if the value is signed, it's fine // but if the value is un-signed, simply convert it use `uint32` may out of the range // like -4692783 converted to 4290274513 (2^32 - 4692783), but we expect 12084433 (2^24 - 4692783) @@ -340,8 +286,8 @@ func castUnsigned(data interface{}, unsigned bool, tp string) interface{} { return data } -func columnValue(value interface{}, unsigned bool, tp string) string { - castValue := castUnsigned(value, unsigned, tp) +func columnValue(value interface{}, ft *types.FieldType) string { + castValue := castUnsigned(value, ft) var data string switch v := castValue.(type) { @@ -386,86 +332,76 @@ func columnValue(value interface{}, unsigned bool, tp string) string { return data } -func findColumn(columns []*column, indexColumn string) *column { - for _, column := range columns { - if column.name == indexColumn { - return column - } - } - - return nil -} - -func findColumns(columns []*column, indexColumns map[string][]string) map[string][]*column { - result := make(map[string][]*column) - - for keyName, indexCols := range indexColumns { - cols := make([]*column, 0, len(indexCols)) - for _, name := range indexCols { - column := findColumn(columns, name) - if column != nil { - cols = append(cols, column) - } - } - result[keyName] = cols - } - - return result -} - -func genKeyList(columns []*column, dataSeq []interface{}) string { +func genKeyList(columns []*model.ColumnInfo, dataSeq []interface{}) string { values := make([]string, 0, len(dataSeq)) for i, data := range dataSeq { - values = append(values, columnValue(data, columns[i].unsigned, columns[i].tp)) + values = append(values, columnValue(data, &columns[i].FieldType)) } return strings.Join(values, ",") } -func genMultipleKeys(columns []*column, value []interface{}, indexColumns map[string][]*column) []string { - multipleKeys := make([]string, 0, len(indexColumns)) - for _, indexCols := range indexColumns { - cols, vals := getColumnData(columns, indexCols, value) +func genMultipleKeys(ti *model.TableInfo, value []interface{}) []string { + multipleKeys := make([]string, 0, len(ti.Indices)+1) + if pk := ti.GetPkColInfo(); pk != nil { + cols := []*model.ColumnInfo{pk} + vals := []interface{}{value[pk.Offset]} + multipleKeys = append(multipleKeys, genKeyList(cols, vals)) + } + for _, indexCols := range ti.Indices { + cols, vals := getColumnData(ti.Columns, indexCols, value) multipleKeys = append(multipleKeys, genKeyList(cols, vals)) } return multipleKeys } -func findFitIndex(indexColumns map[string][]*column) []*column { - cols, ok := indexColumns["primary"] - if ok { - if len(cols) == 0 { - log.L().Error("cols is empty") - } else { - return cols +func findFitIndex(ti *model.TableInfo) *model.IndexInfo { + for _, idx := range ti.Indices { + if idx.Primary { + return idx + } + } + + if pk := ti.GetPkColInfo(); pk != nil { + return &model.IndexInfo{ + Table: ti.Name, + Unique: true, + Primary: true, + State: model.StatePublic, + Tp: model.IndexTypeBtree, + Columns: []*model.IndexColumn{{ + Name: pk.Name, + Offset: pk.Offset, + Length: types.UnspecifiedLength, + }}, } } // second find not null unique key - fn := func(c *column) bool { - return !c.NotNull + fn := func(i int) bool { + return !mysql.HasNotNullFlag(ti.Columns[i].Flag) } - return getSpecifiedIndexColumn(indexColumns, fn) + return getSpecifiedIndexColumn(ti, fn) } -func getAvailableIndexColumn(indexColumns map[string][]*column, data []interface{}) []*column { - fn := func(c *column) bool { - return data[c.idx] == nil +func getAvailableIndexColumn(ti *model.TableInfo, data []interface{}) *model.IndexInfo { + fn := func(i int) bool { + return data[i] == nil } - return getSpecifiedIndexColumn(indexColumns, fn) + return getSpecifiedIndexColumn(ti, fn) } -func getSpecifiedIndexColumn(indexColumns map[string][]*column, fn func(col *column) bool) []*column { - for _, indexCols := range indexColumns { - if len(indexCols) == 0 { +func getSpecifiedIndexColumn(ti *model.TableInfo, fn func(i int) bool) *model.IndexInfo { + for _, indexCols := range ti.Indices { + if !indexCols.Unique { continue } findFitIndex := true - for _, col := range indexCols { - if fn(col) { + for _, col := range indexCols.Columns { + if fn(col.Offset) { findFitIndex = false break } @@ -479,52 +415,59 @@ func getSpecifiedIndexColumn(indexColumns map[string][]*column, fn func(col *col return nil } -func getColumnData(columns []*column, indexColumns []*column, data []interface{}) ([]*column, []interface{}) { - cols := make([]*column, 0, len(columns)) +func getColumnData(columns []*model.ColumnInfo, indexColumns *model.IndexInfo, data []interface{}) ([]*model.ColumnInfo, []interface{}) { + cols := make([]*model.ColumnInfo, 0, len(columns)) values := make([]interface{}, 0, len(columns)) - for _, column := range indexColumns { - cols = append(cols, column) - values = append(values, data[column.idx]) + for _, column := range indexColumns.Columns { + cols = append(cols, columns[column.Offset]) + values = append(values, data[column.Offset]) } return cols, values } -func genWhere(columns []*column, data []interface{}) string { - var kvs bytes.Buffer - for i := range columns { - kvSplit := "=" - if data[i] == nil { - kvSplit = "IS" +func genWhere(columns []*model.ColumnInfo, data []interface{}) string { + var kvs strings.Builder + for i, col := range columns { + if i != 0 { + kvs.WriteString(" AND ") } - - if i == len(columns)-1 { - fmt.Fprintf(&kvs, "`%s` %s ?", columns[i].name, kvSplit) + kvs.WriteByte('`') + kvs.WriteString(strings.ReplaceAll(col.Name.O, "`", "``")) + if data[i] == nil { + kvs.WriteString("` IS ?") } else { - fmt.Fprintf(&kvs, "`%s` %s ? AND ", columns[i].name, kvSplit) + kvs.WriteString("` = ?") } } return kvs.String() } -func genKVs(columns []*column) string { - var kvs bytes.Buffer - for i := range columns { - if i == len(columns)-1 { - fmt.Fprintf(&kvs, "`%s` = ?", columns[i].name) - } else { - fmt.Fprintf(&kvs, "`%s` = ?, ", columns[i].name) +func genKVs(columns []*model.ColumnInfo) string { + var kvs strings.Builder + for i, col := range columns { + if i != 0 { + kvs.WriteString(", ") } + kvs.WriteByte('`') + kvs.WriteString(strings.ReplaceAll(col.Name.O, "`", "``")) + kvs.WriteString("` = ?") } return kvs.String() } -func (s *Syncer) mappingDML(schema, table string, columns []string, data [][]interface{}) ([][]interface{}, error) { +func (s *Syncer) mappingDML(schema, table string, ti *model.TableInfo, data [][]interface{}) ([][]interface{}, error) { if s.columnMapping == nil { return data, nil } + + columns := make([]string, 0, len(ti.Columns)) + for _, col := range ti.Columns { + columns = append(columns, col.Name.O) + } + var ( err error rows = make([][]interface{}, len(data)) @@ -542,82 +485,41 @@ func (s *Syncer) mappingDML(schema, table string, columns []string, data [][]int // generated column. because generated column is not support setting value // directly in DML, we must remove generated column from DML, including column // list and data list including generated columns. -func pruneGeneratedColumnDML(columns []*column, data [][]interface{}, schema, table string, cache *GenColCache) ([]*column, [][]interface{}, error) { - var ( - cacheKey = dbutil.TableName(schema, table) - cacheStatus = cache.status(cacheKey) - ) - - if cacheStatus == noGenColumn { - return columns, data, nil - } - if cacheStatus == hasGenColumn { - rows := make([][]interface{}, 0, len(data)) - filters, ok1 := cache.isGenColumn[cacheKey] - if !ok1 { - return nil, nil, terror.ErrSyncerUnitCacheKeyNotFound.Generate(cacheKey, "isGenColumn") - } - cols, ok2 := cache.columns[cacheKey] - if !ok2 { - return nil, nil, terror.ErrSyncerUnitCacheKeyNotFound.Generate(cacheKey, "columns") - } - for _, row := range data { - value := make([]interface{}, 0, len(row)) - for i := range row { - if !filters[i] { - value = append(value, row[i]) - } - } - rows = append(rows, value) +func pruneGeneratedColumnDML(ti *model.TableInfo, data [][]interface{}) ([]*model.ColumnInfo, [][]interface{}, error) { + // search for generated columns. if none found, return everything as-is. + firstGeneratedColumnIndex := -1 + for i, c := range ti.Columns { + if c.IsGenerated() { + firstGeneratedColumnIndex = i + break } - return cols, rows, nil } - - var ( - needPrune bool - colIndexfilters = make([]bool, 0, len(columns)) - genColumnNames = make(map[string]bool) - ) - - for _, c := range columns { - isGenColumn := c.isGeneratedColumn() - colIndexfilters = append(colIndexfilters, isGenColumn) - if isGenColumn { - needPrune = true - genColumnNames[c.name] = true - } - } - - if !needPrune { - cache.hasGenColumn[cacheKey] = false - return columns, data, nil + if firstGeneratedColumnIndex < 0 { + return ti.Columns, data, nil } - var ( - cols = make([]*column, 0, len(columns)) - rows = make([][]interface{}, 0, len(data)) - ) - - for i := range columns { - if !colIndexfilters[i] { - cols = append(cols, columns[i]) + // remove generated columns from the list of columns + cols := make([]*model.ColumnInfo, 0, len(ti.Columns)) + cols = append(cols, ti.Columns[:firstGeneratedColumnIndex]...) + for _, c := range ti.Columns[(firstGeneratedColumnIndex + 1):] { + if !c.IsGenerated() { + cols = append(cols, c) } } + + // remove generated columns from the list of data. + rows := make([][]interface{}, 0, len(data)) for _, row := range data { - if len(row) != len(columns) { - return nil, nil, terror.ErrSyncerUnitDMLPruneColumnMismatch.Generate(len(columns), len(data)) + if len(row) != len(ti.Columns) { + return nil, nil, terror.ErrSyncerUnitDMLPruneColumnMismatch.Generate(len(ti.Columns), len(data)) } - value := make([]interface{}, 0, len(row)) + value := make([]interface{}, 0, len(cols)) for i := range row { - if !colIndexfilters[i] { + if !ti.Columns[i].IsGenerated() { value = append(value, row[i]) } } rows = append(rows, value) } - cache.hasGenColumn[cacheKey] = true - cache.columns[cacheKey] = cols - cache.isGenColumn[cacheKey] = colIndexfilters - return cols, rows, nil } diff --git a/syncer/dml_test.go b/syncer/dml_test.go index ba365a9a92..9d498c84ca 100644 --- a/syncer/dml_test.go +++ b/syncer/dml_test.go @@ -18,6 +18,14 @@ import ( "strconv" . "github.com/pingcap/check" + "github.com/pingcap/parser" + "github.com/pingcap/parser/ast" + "github.com/pingcap/parser/model" + "github.com/pingcap/parser/mysql" + "github.com/pingcap/parser/types" + tiddl "github.com/pingcap/tidb/ddl" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/util/mock" ) func (s *testSyncerSuite) TestCastUnsigned(c *C) { @@ -25,22 +33,26 @@ func (s *testSyncerSuite) TestCastUnsigned(c *C) { cases := []struct { data interface{} unsigned bool - Type string + Type byte expected interface{} }{ - {int8(-math.Exp2(7)), false, "tinyint(4)", int8(-math.Exp2(7))}, // TINYINT - {int8(-math.Exp2(7)), true, "tinyint(3) unsigned", uint8(math.Exp2(7))}, - {int16(-math.Exp2(15)), false, "smallint(6)", int16(-math.Exp2(15))}, //SMALLINT - {int16(-math.Exp2(15)), true, "smallint(5) unsigned", uint16(math.Exp2(15))}, - {int32(-math.Exp2(23)), false, "mediumint(9)", int32(-math.Exp2(23))}, //MEDIUMINT - {int32(-math.Exp2(23)), true, "mediumint(8) unsigned", uint32(math.Exp2(23))}, - {int32(-math.Exp2(31)), false, "int(11)", int32(-math.Exp2(31))}, // INT - {int32(-math.Exp2(31)), true, "int(10) unsigned", uint32(math.Exp2(31))}, - {int64(-math.Exp2(63)), false, "bigint(20)", int64(-math.Exp2(63))}, // BIGINT - {int64(-math.Exp2(63)), true, "bigint(20) unsigned", strconv.FormatUint(uint64(math.Exp2(63)), 10)}, // special case use string to represent uint64 + {int8(-math.Exp2(7)), false, mysql.TypeTiny, int8(-math.Exp2(7))}, // TINYINT + {int8(-math.Exp2(7)), true, mysql.TypeTiny, uint8(math.Exp2(7))}, + {int16(-math.Exp2(15)), false, mysql.TypeShort, int16(-math.Exp2(15))}, //SMALLINT + {int16(-math.Exp2(15)), true, mysql.TypeShort, uint16(math.Exp2(15))}, + {int32(-math.Exp2(23)), false, mysql.TypeInt24, int32(-math.Exp2(23))}, //MEDIUMINT + {int32(-math.Exp2(23)), true, mysql.TypeInt24, uint32(math.Exp2(23))}, + {int32(-math.Exp2(31)), false, mysql.TypeLong, int32(-math.Exp2(31))}, // INT + {int32(-math.Exp2(31)), true, mysql.TypeLong, uint32(math.Exp2(31))}, + {int64(-math.Exp2(63)), false, mysql.TypeLonglong, int64(-math.Exp2(63))}, // BIGINT + {int64(-math.Exp2(63)), true, mysql.TypeLonglong, strconv.FormatUint(uint64(math.Exp2(63)), 10)}, // special case use string to represent uint64 } for _, cs := range cases { - obtained := castUnsigned(cs.data, cs.unsigned, cs.Type) + ft := types.NewFieldType(cs.Type) + if cs.unsigned { + ft.Flag |= mysql.UnsignedFlag + } + obtained := castUnsigned(cs.data, ft) c.Assert(obtained, Equals, cs.expected) } } @@ -53,14 +65,22 @@ func (s *testSyncerSuite) TestGenColumnPlaceholders(c *C) { c.Assert(placeholderStr, Equals, "?,?,?") } +func createTableInfo(p *parser.Parser, se sessionctx.Context, tableID int64, sql string) (*model.TableInfo, error) { + node, err := p.ParseOneStmt(sql, "utf8mb4", "utf8mb4_bin") + if err != nil { + return nil, err + } + return tiddl.MockTableInfo(se, node.(*ast.CreateTableStmt), tableID) +} + func (s *testSyncerSuite) TestGenColumnList(c *C) { - columns := []*column{ + columns := []*model.ColumnInfo{ { - name: "a", + Name: model.NewCIStr("a"), }, { - name: "b", + Name: model.NewCIStr("b"), }, { - name: "c", + Name: model.NewCIStr("c`d"), }, } @@ -68,45 +88,60 @@ func (s *testSyncerSuite) TestGenColumnList(c *C) { c.Assert(columnList, Equals, "`a`") columnList = genColumnList(columns) - c.Assert(columnList, Equals, "`a`,`b`,`c`") + c.Assert(columnList, Equals, "`a`,`b`,`c``d`") } func (s *testSyncerSuite) TestFindFitIndex(c *C) { - pkColumns := []*column{ - { - name: "a", - }, { - name: "b", - }, - } - indexColumns := []*column{ - { - name: "c", - }, - } - indexColumnsNotNull := []*column{ - { - name: "d", - NotNull: true, - }, - } + p := parser.New() + se := mock.NewContext() + + ti, err := createTableInfo(p, se, 1, ` + create table t1( + a int, + b int, + c int, + d int not null, + primary key(a, b), + unique key(c), + unique key(d) + ); + `) + c.Assert(err, IsNil) + + columns := findFitIndex(ti) + c.Assert(columns, NotNil) + c.Assert(columns.Columns, HasLen, 2) + c.Assert(columns.Columns[0].Name.L, Equals, "a") + c.Assert(columns.Columns[1].Name.L, Equals, "b") + + ti, err = createTableInfo(p, se, 2, `create table t2(c int unique);`) + c.Assert(err, IsNil) + columns = findFitIndex(ti) + c.Assert(columns, IsNil) + + ti, err = createTableInfo(p, se, 3, `create table t3(d int not null unique);`) + c.Assert(err, IsNil) + columns = findFitIndex(ti) + c.Assert(columns, NotNil) + c.Assert(columns.Columns, HasLen, 1) + c.Assert(columns.Columns[0].Name.L, Equals, "d") + + ti, err = createTableInfo(p, se, 4, `create table t4(e int not null, key(e));`) + c.Assert(err, IsNil) + columns = findFitIndex(ti) + c.Assert(columns, IsNil) + + ti, err = createTableInfo(p, se, 5, `create table t5(f datetime primary key);`) + c.Assert(err, IsNil) + columns = findFitIndex(ti) + c.Assert(columns, NotNil) + c.Assert(columns.Columns, HasLen, 1) + c.Assert(columns.Columns[0].Name.L, Equals, "f") - columns := findFitIndex(map[string][]*column{ - "primary": pkColumns, - "index": indexColumns, - }) - c.Assert(columns, HasLen, 2) - c.Assert(columns[0].name, Equals, "a") - c.Assert(columns[1].name, Equals, "b") - - columns = findFitIndex(map[string][]*column{ - "index": indexColumns, - }) - c.Assert(columns, HasLen, 0) - - columns = findFitIndex(map[string][]*column{ - "index": indexColumnsNotNull, - }) - c.Assert(columns, HasLen, 1) - c.Assert(columns[0].name, Equals, "d") + ti, err = createTableInfo(p, se, 6, `create table t6(g int primary key);`) + c.Assert(err, IsNil) + columns = findFitIndex(ti) + c.Assert(columns, NotNil) + c.Assert(columns.Columns, HasLen, 1) + c.Assert(columns.Columns[0].Name.L, Equals, "g") } diff --git a/syncer/syncer.go b/syncer/syncer.go index 8c2963eade..d294e3aed7 100644 --- a/syncer/syncer.go +++ b/syncer/syncer.go @@ -23,9 +23,12 @@ import ( "sync" "time" + "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/parser" "github.com/pingcap/parser/ast" + "github.com/pingcap/parser/format" + "github.com/pingcap/parser/model" bf "github.com/pingcap/tidb-tools/pkg/binlog-filter" cm "github.com/pingcap/tidb-tools/pkg/column-mapping" "github.com/pingcap/tidb-tools/pkg/dbutil" @@ -45,6 +48,7 @@ import ( fr "github.com/pingcap/dm/pkg/func-rollback" "github.com/pingcap/dm/pkg/gtid" "github.com/pingcap/dm/pkg/log" + "github.com/pingcap/dm/pkg/schema" "github.com/pingcap/dm/pkg/streamer" "github.com/pingcap/dm/pkg/terror" "github.com/pingcap/dm/pkg/tracing" @@ -143,9 +147,7 @@ type Syncer struct { wg sync.WaitGroup jobWg sync.WaitGroup - tables map[string]*table // table cache: `target-schema`.`target-table` -> table - cacheColumns map[string][]string // table columns cache: `target-schema`.`target-table` -> column names list - genColsCache *GenColCache + schemaTracker *schema.Tracker fromDB *UpStreamConn @@ -226,9 +228,6 @@ func NewSyncer(cfg *config.SubTaskConfig) *Syncer { syncer.binlogSizeCount.Set(0) syncer.lastCount.Set(0) syncer.count.Set(0) - syncer.tables = make(map[string]*table) - syncer.cacheColumns = make(map[string][]string) - syncer.genColsCache = NewGenColCache() syncer.c = newCausality() syncer.done = nil syncer.injectEventCh = make(chan *replication.BinlogEvent) @@ -266,6 +265,12 @@ func NewSyncer(cfg *config.SubTaskConfig) *Syncer { syncer.ddlExecInfo = NewDDLExecInfo() } + var err error + syncer.schemaTracker, err = schema.NewTracker() + if err != nil { + syncer.tctx.L().DPanic("cannot create schema tracker", zap.Error(err)) + } + return syncer } @@ -380,7 +385,7 @@ func (s *Syncer) Init() (err error) { s.tctx.L().Info("all previous meta cleared") } - err = s.checkpoint.Load() + err = s.checkpoint.Load(s.schemaTracker) if err != nil { return err } @@ -497,8 +502,6 @@ func (s *Syncer) reset() { s.resetReplicationSyncer() // create new job chans s.newJobChans(s.cfg.WorkerCount + 1) - // clear tables info - s.clearAllTables() s.execErrorDetected.Set(false) s.resetExecErrors() @@ -621,7 +624,7 @@ func (s *Syncer) Process(ctx context.Context, pr chan pb.ProcessResult) { // try to rollback checkpoints, if they already flushed, no effect prePos := s.checkpoint.GlobalPoint() - s.checkpoint.Rollback() + s.checkpoint.Rollback(s.schemaTracker) currPos := s.checkpoint.GlobalPoint() if prePos.Compare(currPos) != 0 { s.tctx.L().Warn("something wrong with rollback global checkpoint", zap.Stringer("previous position", prePos), zap.Stringer("current position", currPos)) @@ -637,66 +640,60 @@ func (s *Syncer) getMasterStatus() (mysql.Position, gtid.Set, error) { return s.fromDB.getMasterStatus(s.cfg.Flavor) } -// clearTables is used for clear table cache of given table. this function must -// be called when DDL is applied to this table. -func (s *Syncer) clearTables(schema, table string) { - key := dbutil.TableName(schema, table) - delete(s.tables, key) - delete(s.cacheColumns, key) - s.genColsCache.clearTable(schema, table) -} - -func (s *Syncer) clearAllTables() { - s.tables = make(map[string]*table) - s.cacheColumns = make(map[string][]string) - s.genColsCache.reset() -} - -func (s *Syncer) getTableFromDB(db *DBConn, schema string, name string) (*table, error) { - table := &table{} - table.schema = schema - table.name = name - table.indexColumns = make(map[string][]*column) +func (s *Syncer) getTable(origSchema, origTable, renamedSchema, renamedTable string, p *parser.Parser) (*model.TableInfo, error) { + ti, err := s.schemaTracker.GetTable(origSchema, origTable) + if err == nil || !schema.IsTableNotExists(err) { + return ti, err + } - err := getTableColumns(s.tctx, db, table) - if err != nil { + ctx := context.Background() + if err := s.schemaTracker.CreateSchemaIfNotExists(origSchema); err != nil { return nil, err } - err = getTableIndex(s.tctx, db, table) + // TODO: Switch to use the HTTP interface to retrieve the TableInfo directly + // (and get rid of ddlDBConn). + rows, err := s.ddlDBConn.querySQL(s.tctx, "SHOW CREATE TABLE "+dbutil.TableName(renamedSchema, renamedTable)) if err != nil { return nil, err } + defer rows.Close() - if len(table.columns) == 0 { - return nil, terror.ErrSyncerUnitGetTableFromDB.Generate(schema, name) - } - - return table, nil -} - -func (s *Syncer) getTable(schema string, table string) (*table, []string, error) { - key := dbutil.TableName(schema, table) - - value, ok := s.tables[key] - if ok { - return value, s.cacheColumns[key], nil - } + for rows.Next() { + var tableName, createSQL string + if err := rows.Scan(&tableName, &createSQL); err != nil { + return nil, errors.Trace(err) + } - t, err := s.getTableFromDB(s.ddlDBConn, schema, table) - if err != nil { - return nil, nil, err - } + // rename the table back to original. + createNode, err := p.ParseOneStmt(createSQL, "", "") + if err != nil { + return nil, err + } + createStmt := createNode.(*ast.CreateTableStmt) + createStmt.IfNotExists = true + createStmt.Table.Schema = model.NewCIStr(origSchema) + createStmt.Table.Name = model.NewCIStr(origTable) - // compute cache column list for column mapping - columns := make([]string, 0, len(t.columns)) - for _, c := range t.columns { - columns = append(columns, c.name) + var newCreateSQLBuilder strings.Builder + restoreCtx := format.NewRestoreCtx(format.DefaultRestoreFlags, &newCreateSQLBuilder) + if err := createStmt.Restore(restoreCtx); err != nil { + return nil, err + } + newCreateSQL := newCreateSQLBuilder.String() + s.tctx.L().Debug("reverse-synchronized table schema", + zap.String("origSchema", origSchema), + zap.String("origTable", origTable), + zap.String("renamedSchema", renamedSchema), + zap.String("renamedTable", renamedTable), + zap.String("sql", newCreateSQL), + ) + if err := s.schemaTracker.Exec(ctx, origSchema, newCreateSQL); err != nil { + return nil, err + } } - s.tables[key] = t - s.cacheColumns[key] = columns - return t, columns, nil + return s.schemaTracker.GetTable(origSchema, origTable) } func (s *Syncer) addCount(isFinished bool, queueBucket string, tp opType, n int64) { @@ -738,6 +735,18 @@ func (s *Syncer) checkWait(job *job) bool { return false } +func (s *Syncer) saveTablePoint(db, table string, pos mysql.Position) { + ti, err := s.schemaTracker.GetTable(db, table) + if err != nil { + s.tctx.L().DPanic("table info missing from schema tracker", + zap.String("schema", db), + zap.String("table", table), + zap.Stringer("pos", pos), + zap.Error(err)) + } + s.checkpoint.SaveTablePoint(db, table, pos, ti) +} + func (s *Syncer) addJob(job *job) error { var ( queueBucket int @@ -791,14 +800,14 @@ func (s *Syncer) addJob(job *job) error { // only save checkpoint for DDL and XID (see above) s.saveGlobalPoint(job.pos) if len(job.sourceSchema) > 0 { - s.checkpoint.SaveTablePoint(job.sourceSchema, job.sourceTable, job.pos) + s.saveTablePoint(job.sourceSchema, job.sourceTable, job.pos) } // reset sharding group after checkpoint saved s.resetShardingGroup(job.targetSchema, job.targetTable) case insert, update, del: // save job's current pos for DML events if len(job.sourceSchema) > 0 { - s.checkpoint.SaveTablePoint(job.sourceSchema, job.sourceTable, job.currentPos) + s.saveTablePoint(job.sourceSchema, job.sourceTable, job.currentPos) } } @@ -1443,15 +1452,15 @@ func (s *Syncer) handleRowsEvent(ev *replication.RowsEvent, ec eventContext) err } } - table, columns, err := s.getTable(schemaName, tableName) + ti, err := s.getTable(originSchema, originTable, schemaName, tableName, ec.parser2) if err != nil { return terror.WithScope(err, terror.ScopeDownstream) } - rows, err := s.mappingDML(originSchema, originTable, columns, ev.Rows) + rows, err := s.mappingDML(originSchema, originTable, ti, ev.Rows) if err != nil { return err } - prunedColumns, prunedRows, err := pruneGeneratedColumnDML(table.columns, rows, schemaName, tableName, s.genColsCache) + prunedColumns, prunedRows, err := pruneGeneratedColumnDML(ti, rows) if err != nil { return err } @@ -1471,13 +1480,12 @@ func (s *Syncer) handleRowsEvent(ev *replication.RowsEvent, ec eventContext) err return err } param := &genDMLParam{ - schema: table.schema, - table: table.name, - data: prunedRows, - originalData: rows, - columns: prunedColumns, - originalColumns: table.columns, - originalIndexColumns: table.indexColumns, + schema: schemaName, + table: tableName, + data: prunedRows, + originalData: rows, + columns: prunedColumns, + originalTableInfo: ti, } switch ec.header.EventType { @@ -1486,7 +1494,7 @@ func (s *Syncer) handleRowsEvent(ev *replication.RowsEvent, ec eventContext) err param.safeMode = ec.safeMode.Enable() sqls, keys, args, err = genInsertSQLs(param) if err != nil { - return terror.Annotatef(err, "gen insert sqls failed, schema: %s, table: %s", table.schema, table.name) + return terror.Annotatef(err, "gen insert sqls failed, schema: %s, table: %s", schemaName, tableName) } } binlogEvent.WithLabelValues("write_rows", s.cfg.Name).Observe(time.Since(ec.startTime).Seconds()) @@ -1497,7 +1505,7 @@ func (s *Syncer) handleRowsEvent(ev *replication.RowsEvent, ec eventContext) err param.safeMode = ec.safeMode.Enable() sqls, keys, args, err = genUpdateSQLs(param) if err != nil { - return terror.Annotatef(err, "gen update sqls failed, schema: %s, table: %s", table.schema, table.name) + return terror.Annotatef(err, "gen update sqls failed, schema: %s, table: %s", schemaName, tableName) } } binlogEvent.WithLabelValues("update_rows", s.cfg.Name).Observe(time.Since(ec.startTime).Seconds()) @@ -1507,7 +1515,7 @@ func (s *Syncer) handleRowsEvent(ev *replication.RowsEvent, ec eventContext) err if !applied { sqls, keys, args, err = genDeleteSQLs(param) if err != nil { - return terror.Annotatef(err, "gen delete sqls failed, schema: %s, table: %s", table.schema, table.name) + return terror.Annotatef(err, "gen delete sqls failed, schema: %s, table: %s", schemaName, tableName) } } binlogEvent.WithLabelValues("delete_rows", s.cfg.Name).Observe(time.Since(ec.startTime).Seconds()) @@ -1535,7 +1543,7 @@ func (s *Syncer) handleRowsEvent(ev *replication.RowsEvent, ec eventContext) err if keys != nil { key = keys[i] } - err = s.commitJob(*ec.latestOp, originSchema, originTable, table.schema, table.name, sqls[i], arg, key, true, *ec.lastPos, *ec.currentPos, nil, *ec.traceID) + err = s.commitJob(*ec.latestOp, originSchema, originTable, schemaName, tableName, sqls[i], arg, key, true, *ec.lastPos, *ec.currentPos, nil, *ec.traceID) if err != nil { return err } @@ -1621,9 +1629,15 @@ func (s *Syncer) handleQueryEvent(ev *replication.QueryEvent, ec eventContext) e * online ddl: we would ignore rename ghost table, make no difference * other rename: we don't allow user to execute more than one rename operation in one ddl event, then it would make no difference */ + type trackedDDL struct { + rawSQL string + stmt ast.StmtNode + tableNames [][]*filter.Table + } var ( ddlInfo *shardingDDLInfo needHandleDDLs []string + needTrackDDLs []trackedDDL targetTbls = make(map[string]*filter.Table) ) for _, sql := range sqls { @@ -1683,6 +1697,7 @@ func (s *Syncer) handleQueryEvent(ev *replication.QueryEvent, ec eventContext) e } needHandleDDLs = append(needHandleDDLs, sqlDDL) + needTrackDDLs = append(needTrackDDLs, trackedDDL{rawSQL: sql, stmt: stmt, tableNames: tableNames}) targetTbls[tableNames[1][0].String()] = tableNames[1][0] } @@ -1718,10 +1733,14 @@ func (s *Syncer) handleQueryEvent(ev *replication.QueryEvent, ec eventContext) e } s.tctx.L().Info("finish to handle ddls in normal mode", zap.String("event", "query"), zap.Strings("ddls", needHandleDDLs), zap.ByteString("raw statement", ev.Query), log.WrapStringerField("position", ec.currentPos)) + for _, td := range needTrackDDLs { + if err := s.trackDDL(usedSchema, td.rawSQL, td.tableNames, td.stmt, &ec); err != nil { + return err + } + } for _, tbl := range targetTbls { - s.clearTables(tbl.Schema, tbl.Name) // save checkpoint of each table - s.checkpoint.SaveTablePoint(tbl.Schema, tbl.Name, *ec.currentPos) + s.saveTablePoint(tbl.Schema, tbl.Name, *ec.currentPos) } for _, table := range onlineDDLTableNames { @@ -1787,11 +1806,17 @@ func (s *Syncer) handleQueryEvent(ev *replication.QueryEvent, ec eventContext) e return err } + for _, td := range needTrackDDLs { + if err := s.trackDDL(usedSchema, td.rawSQL, td.tableNames, td.stmt, &ec); err != nil { + return err + } + } + // save checkpoint in memory, don't worry, if error occurred, we can rollback it // for non-last sharding DDL's table, this checkpoint will be used to skip binlog event when re-syncing // NOTE: when last sharding DDL executed, all this checkpoints will be flushed in the same txn s.tctx.L().Info("save table checkpoint for source", zap.String("event", "query"), zap.String("source", source), zap.Stringer("start position", startPos), log.WrapStringerField("end position", ec.currentPos)) - s.checkpoint.SaveTablePoint(ddlInfo.tableNames[0][0].Schema, ddlInfo.tableNames[0][0].Name, *ec.currentPos) + s.saveTablePoint(ddlInfo.tableNames[0][0].Schema, ddlInfo.tableNames[0][0].Name, *ec.currentPos) if !synced { s.tctx.L().Info("source shard group is not synced", zap.String("event", "query"), zap.String("source", source), zap.Stringer("start position", startPos), log.WrapStringerField("end position", ec.currentPos)) return nil @@ -1906,8 +1931,67 @@ func (s *Syncer) handleQueryEvent(ev *replication.QueryEvent, ec eventContext) e } s.tctx.L().Info("finish to handle ddls in shard mode", zap.String("event", "query"), zap.Strings("ddls", needHandleDDLs), zap.ByteString("raw statement", ev.Query), zap.Stringer("start position", startPos), log.WrapStringerField("end position", ec.currentPos)) + return nil +} + +func (s *Syncer) trackDDL(usedSchema string, sql string, tableNames [][]*filter.Table, stmt ast.StmtNode, ec *eventContext) error { + srcTable := tableNames[0][0] + + // Make sure the tables are all loaded into the schema tracker. + var shouldExecDDLOnSchemaTracker, shouldSchemaExist, shouldTableExist bool + switch stmt.(type) { + case *ast.CreateDatabaseStmt: + shouldExecDDLOnSchemaTracker = true + case *ast.AlterDatabaseStmt: + shouldExecDDLOnSchemaTracker = true + shouldSchemaExist = true + case *ast.DropDatabaseStmt: + shouldExecDDLOnSchemaTracker = true + shouldSchemaExist = true + if !s.cfg.IsSharding { + if err := s.checkpoint.DeleteSchemaPoint(srcTable.Schema); err != nil { + return err + } + } + case *ast.CreateTableStmt, *ast.CreateViewStmt, *ast.RecoverTableStmt: + shouldExecDDLOnSchemaTracker = true + shouldSchemaExist = true + case *ast.DropTableStmt: + shouldExecDDLOnSchemaTracker = true + shouldSchemaExist = true + shouldTableExist = true + if err := s.checkpoint.DeleteTablePoint(srcTable.Schema, srcTable.Name); err != nil { + return err + } + case *ast.RenameTableStmt, *ast.CreateIndexStmt, *ast.DropIndexStmt, *ast.RepairTableStmt, *ast.AlterTableStmt: + // TODO: RENAME TABLE / ALTER TABLE RENAME should require special treatment. + shouldExecDDLOnSchemaTracker = true + shouldSchemaExist = true + shouldTableExist = true + case *ast.LockTablesStmt, *ast.UnlockTablesStmt, *ast.CleanupTableLockStmt, *ast.TruncateTableStmt: + break + default: + s.tctx.L().DPanic("unhandled DDL type cannot be tracked", zap.Stringer("type", reflect.TypeOf(stmt))) + } + + if shouldSchemaExist { + if err := s.schemaTracker.CreateSchemaIfNotExists(srcTable.Schema); err != nil { + return err + } + } + if shouldTableExist { + targetTable := tableNames[1][0] + if _, err := s.getTable(srcTable.Schema, srcTable.Name, targetTable.Schema, targetTable.Name, ec.parser2); err != nil { + return err + } + } + if shouldExecDDLOnSchemaTracker { + if err := s.schemaTracker.Exec(s.tctx.Ctx, usedSchema, sql); err != nil { + s.tctx.L().Error("cannot track DDL", zap.String("schema", usedSchema), zap.String("statement", sql), log.WrapStringerField("position", ec.currentPos), log.ShortError(err)) + return errors.Annotatef(err, "cannot track DDL: %s", sql) + } + } - s.clearTables(ddlInfo.tableNames[1][0].Schema, ddlInfo.tableNames[1][0].Name) return nil } diff --git a/syncer/syncer_test.go b/syncer/syncer_test.go index ce7c58d779..3dafda9fae 100644 --- a/syncer/syncer_test.go +++ b/syncer/syncer_test.go @@ -217,22 +217,24 @@ func (s *testSyncerSuite) mockParser(db *sql.DB, mock sqlmock.Sqlmock) (*parser. return utils.GetParser(db, false) } -func (s *testSyncerSuite) mockCheckPointCreate(checkPointMock sqlmock.Sqlmock) { +func (s *testSyncerSuite) mockCheckPointCreate(checkPointMock sqlmock.Sqlmock, tag string) { checkPointMock.ExpectBegin() - checkPointMock.ExpectExec(fmt.Sprintf("INSERT INTO `%s`.`%s_syncer_checkpoint`", s.cfg.MetaSchema, s.cfg.Name)).WillReturnResult(sqlmock.NewResult(1, 1)) - checkPointMock.ExpectExec(fmt.Sprintf("INSERT INTO `%s`.`%s_syncer_checkpoint`", s.cfg.MetaSchema, s.cfg.Name)).WillReturnResult(sqlmock.NewResult(1, 1)) + // we encode the line number to make it easier to figure out which expectation has failed. + checkPointMock.ExpectExec(fmt.Sprintf("(223:"+tag+")?INSERT INTO `%s`.`%s_syncer_checkpoint`", s.cfg.MetaSchema, s.cfg.Name)).WillReturnResult(sqlmock.NewResult(1, 1)) + checkPointMock.ExpectExec(fmt.Sprintf("(224:"+tag+")?INSERT INTO `%s`.`%s_syncer_checkpoint`", s.cfg.MetaSchema, s.cfg.Name)).WillReturnResult(sqlmock.NewResult(1, 1)) // TODO because shardGroup DB is same as checkpoint DB, next time split them is better - checkPointMock.ExpectExec(fmt.Sprintf("DELETE FROM `%s`.`%s_syncer_sharding_meta", s.cfg.MetaSchema, s.cfg.Name)).WillReturnResult(sqlmock.NewResult(1, 1)) + checkPointMock.ExpectExec(fmt.Sprintf("(226:"+tag+")?DELETE FROM `%s`.`%s_syncer_sharding_meta(228)?", s.cfg.MetaSchema, s.cfg.Name)).WillReturnResult(sqlmock.NewResult(1, 1)) checkPointMock.ExpectCommit() } -func (s *testSyncerSuite) mockCheckPointFlush(checkPointMock sqlmock.Sqlmock) { +func (s *testSyncerSuite) mockCheckPointFlush(checkPointMock sqlmock.Sqlmock, tagInt int) { + tag := fmt.Sprintf("%d", tagInt) checkPointMock.ExpectBegin() - checkPointMock.ExpectExec(fmt.Sprintf("INSERT INTO `%s`.`%s_syncer_checkpoint`", s.cfg.MetaSchema, s.cfg.Name)).WillReturnResult(sqlmock.NewResult(1, 1)) - checkPointMock.ExpectExec(fmt.Sprintf("INSERT INTO `%s`.`%s_syncer_checkpoint`", s.cfg.MetaSchema, s.cfg.Name)).WillReturnResult(sqlmock.NewResult(1, 1)) - checkPointMock.ExpectExec(fmt.Sprintf("INSERT INTO `%s`.`%s_syncer_checkpoint`", s.cfg.MetaSchema, s.cfg.Name)).WillReturnResult(sqlmock.NewResult(1, 1)) + checkPointMock.ExpectExec(fmt.Sprintf("(242:"+tag+")?INSERT INTO `%s`.`%s_syncer_checkpoint`", s.cfg.MetaSchema, s.cfg.Name)).WillReturnResult(sqlmock.NewResult(1, 1)) + checkPointMock.ExpectExec(fmt.Sprintf("(243:"+tag+")?INSERT INTO `%s`.`%s_syncer_checkpoint`", s.cfg.MetaSchema, s.cfg.Name)).WillReturnResult(sqlmock.NewResult(1, 1)) + checkPointMock.ExpectExec(fmt.Sprintf("(244:"+tag+")?INSERT INTO `%s`.`%s_syncer_checkpoint`", s.cfg.MetaSchema, s.cfg.Name)).WillReturnResult(sqlmock.NewResult(1, 1)) // TODO because shardGroup DB is same as checkpoint DB, next time split them is better - checkPointMock.ExpectExec(fmt.Sprintf("DELETE FROM `%s`.`%s_syncer_sharding_meta", s.cfg.MetaSchema, s.cfg.Name)).WillReturnResult(sqlmock.NewResult(1, 1)) + checkPointMock.ExpectExec(fmt.Sprintf("(246:"+tag+")?DELETE FROM `%s`.`%s_syncer_sharding_meta(239)?", s.cfg.MetaSchema, s.cfg.Name)).WillReturnResult(sqlmock.NewResult(1, 1)) checkPointMock.ExpectCommit() } @@ -794,11 +796,13 @@ func (s *testSyncerSuite) TestColumnMapping(c *C) { func (s *testSyncerSuite) TestGeneratedColumn(c *C) { // TODO Currently mock eventGenerator don't support generate json,varchar field event, so use real mysql binlog event here - dbAddr := fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8", s.cfg.From.User, s.cfg.From.Password, s.cfg.From.Host, s.cfg.From.Port) + dbAddr := fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8mb4", s.cfg.From.User, s.cfg.From.Password, s.cfg.From.Host, s.cfg.From.Port) db, err := sql.Open("mysql", dbAddr) if err != nil { c.Fatal(err) } + p, err := utils.GetParser(db, false) + c.Assert(err, IsNil) _, err = db.Exec("SET GLOBAL binlog_format = 'ROW';") c.Assert(err, IsNil) @@ -967,23 +971,24 @@ func (s *testSyncerSuite) TestGeneratedColumn(c *C) { c.Assert(err, IsNil) switch ev := e.Event.(type) { case *replication.RowsEvent: - table, _, err := syncer.getTable(string(ev.Table.Schema), string(ev.Table.Table)) + schemaName := string(ev.Table.Schema) + tableName := string(ev.Table.Table) + ti, err := syncer.getTable(schemaName, tableName, schemaName, tableName, p) c.Assert(err, IsNil) var ( sqls []string args [][]interface{} ) - prunedColumns, prunedRows, err := pruneGeneratedColumnDML(table.columns, ev.Rows, table.schema, table.name, syncer.genColsCache) + prunedColumns, prunedRows, err := pruneGeneratedColumnDML(ti, ev.Rows) c.Assert(err, IsNil) param := &genDMLParam{ - schema: table.schema, - table: table.name, - data: prunedRows, - originalData: ev.Rows, - columns: prunedColumns, - originalColumns: table.columns, - originalIndexColumns: table.indexColumns, + schema: schemaName, + table: tableName, + data: prunedRows, + originalData: ev.Rows, + columns: prunedColumns, + originalTableInfo: ti, } switch e.Header.EventType { case replication.WRITE_ROWS_EVENTv0, replication.WRITE_ROWS_EVENTv1, replication.WRITE_ROWS_EVENTv2: @@ -1259,9 +1264,9 @@ func (s *testSyncerSuite) TestSharding(c *C) { AddRow("sql_mode", "ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_AUTO_CREATE_USER,NO_ENGINE_SUBSTITUTION")) // mock checkpoint db after create db table1 table2 - s.mockCheckPointCreate(checkPointMock) - s.mockCheckPointCreate(checkPointMock) - s.mockCheckPointCreate(checkPointMock) + s.mockCheckPointCreate(checkPointMock, "db") + s.mockCheckPointCreate(checkPointMock, "table1") + s.mockCheckPointCreate(checkPointMock, "table2") // mock downstream db result mock.ExpectBegin() @@ -1275,13 +1280,11 @@ func (s *testSyncerSuite) TestSharding(c *C) { mock.ExpectExec("CREATE TABLE").WillReturnError(e) mock.ExpectCommit() - // mock get table in first handle RowEvent - mock.ExpectQuery("SHOW COLUMNS").WillReturnRows( - sqlmock.NewRows([]string{"Field", "Type", "Null", "Key", "Default", "Extra"}).AddRow("id", "int", "NO", "PRI", null, "").AddRow("age", "int", "NO", "", null, "")) - mock.ExpectQuery("SHOW INDEX").WillReturnRows( - sqlmock.NewRows([]string{"Table", "Non_unique", "Key_name", "Seq_in_index", "Column_name", - "Collation", "Cardinality", "Sub_part", "Packed", "Null", "Index_type", "Comment", "Index_comment"}, - ).AddRow("st", 0, "PRIMARY", 1, "id", "A", 0, null, null, null, "BTREE", "", "")) + // mock fetching table schema from downstream + mock.ExpectQuery("SHOW CREATE TABLE `stest`.`st`"). + WillReturnRows(sqlmock.NewRows([]string{"Table", "Create Table"}).AddRow("-", "create table st(id int, age int)")) + mock.ExpectQuery("SHOW CREATE TABLE `stest`.`st`"). + WillReturnRows(sqlmock.NewRows([]string{"Table", "Create Table"}).AddRow("-", "create table st(id int, age int)")) // mock expect sql for i, expectSQL := range _case.expectSQLS { @@ -1289,14 +1292,7 @@ func (s *testSyncerSuite) TestSharding(c *C) { if strings.HasPrefix(expectSQL.sql, "ALTER") { mock.ExpectExec(expectSQL.sql).WillReturnResult(sqlmock.NewResult(1, int64(i)+1)) mock.ExpectCommit() - // mock get table after ddl sql exec - mock.ExpectQuery("SHOW COLUMNS").WillReturnRows( - sqlmock.NewRows([]string{"Field", "Type", "Null", "Key", "Default", "Extra"}).AddRow("id", "int", "NO", "PRI", null, "").AddRow("age", "int", "NO", "", null, "").AddRow("name", "varchar", "NO", "", null, "")) - mock.ExpectQuery("SHOW INDEX").WillReturnRows( - sqlmock.NewRows([]string{"Table", "Non_unique", "Key_name", "Seq_in_index", "Column_name", - "Collation", "Cardinality", "Sub_part", "Packed", "Null", "Index_type", "Comment", "Index_comment"}, - ).AddRow("st", 0, "PRIMARY", 1, "id", "A", 0, null, null, null, "BTREE", "", "")) - s.mockCheckPointFlush(checkPointMock) + s.mockCheckPointFlush(checkPointMock, i) } else { // change insert to replace because of safe mode mock.ExpectExec(expectSQL.sql).WithArgs(expectSQL.args...).WillReturnResult(sqlmock.NewResult(1, 1)) @@ -1306,7 +1302,7 @@ func (s *testSyncerSuite) TestSharding(c *C) { ctx, cancel := context.WithCancel(context.Background()) resultCh := make(chan pb.ProcessResult) - s.mockCheckPointFlush(checkPointMock) + s.mockCheckPointFlush(checkPointMock, -1) go syncer.Process(ctx, resultCh) @@ -1440,22 +1436,6 @@ func (s *testSyncerSuite) TestRun(c *C) { WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}). AddRow("sql_mode", "ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_AUTO_CREATE_USER,NO_ENGINE_SUBSTITUTION")) - // mock get table schema at handling first row event - mock.ExpectQuery("SHOW COLUMNS").WillReturnRows( - sqlmock.NewRows([]string{"Field", "Type", "Null", "Key", "Default", "Extra"}).AddRow("id", "int", "NO", "PRI", null, "").AddRow("name", "varchar", "NO", "", null, "")) - mock.ExpectQuery("SHOW INDEX").WillReturnRows( - sqlmock.NewRows([]string{"Table", "Non_unique", "Key_name", "Seq_in_index", "Column_name", - "Collation", "Cardinality", "Sub_part", "Packed", "Null", "Index_type", "Comment", "Index_comment"}, - ).AddRow("t_1", 0, "PRIMARY", 1, "id", "A", 0, null, null, null, "BTREE", "", "")) - - // mock get table schema after handle first query event - mock.ExpectQuery("SHOW COLUMNS").WillReturnRows( - sqlmock.NewRows([]string{"Field", "Type", "Null", "Key", "Default", "Extra"}).AddRow("id", "int", "NO", "PRI", null, "").AddRow("name", "varchar", "NO", "", null, "")) - mock.ExpectQuery("SHOW INDEX").WillReturnRows( - sqlmock.NewRows([]string{"Table", "Non_unique", "Key_name", "Seq_in_index", "Column_name", - "Collation", "Cardinality", "Sub_part", "Packed", "Null", "Index_type", "Comment", "Index_comment"}, - ).AddRow("t_1", 0, "PRIMARY", 1, "id", "A", 0, null, null, null, "BTREE", "", "").AddRow("t_1", 1, "index1", 1, "name", "A", 0, null, null, "YES", "BTREE", "", "")) - go syncer.Process(ctx, resultCh) expectJobs1 := []*expectJob{ @@ -1542,14 +1522,6 @@ func (s *testSyncerSuite) TestRun(c *C) { WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}). AddRow("sql_mode", "ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_AUTO_CREATE_USER,NO_ENGINE_SUBSTITUTION")) - mock.ExpectQuery("SHOW COLUMNS").WillReturnRows( - sqlmock.NewRows([]string{"Field", "Type", "Null", "Key", "Default", "Extra"}).AddRow("id", "int", "NO", "PRI", null, "").AddRow("name", "varchar", "NO", "", null, "")) - - mock.ExpectQuery("SHOW INDEX").WillReturnRows( - sqlmock.NewRows([]string{"Table", "Non_unique", "Key_name", "Seq_in_index", "Column_name", - "Collation", "Cardinality", "Sub_part", "Packed", "Null", "Index_type", "Comment", "Index_comment"}, - ).AddRow("t_2", 0, "PRIMARY", 1, "id", "A", 0, null, null, null, "BTREE", "", "").AddRow("t_2", 1, "index1", 1, "name", "A", 0, null, null, "YES", "BTREE", "", "")) - ctx, cancel = context.WithCancel(context.Background()) resultCh = make(chan pb.ProcessResult) // simulate `syncer.Resume` here, but doesn't reset database conns @@ -1615,7 +1587,7 @@ type expectJob struct { } func checkJobs(c *C, jobs []*job, expectJobs []*expectJob) { - c.Assert(jobs, HasLen, len(expectJobs)) + c.Assert(len(jobs), Equals, len(expectJobs), Commentf("jobs = %q", jobs)) for i, job := range jobs { c.Log(i, job.tp, job.ddls, job.sql, job.args) diff --git a/tests/_utils/check_log_contains b/tests/_utils/check_log_contains index d221505454..226878c47e 100755 --- a/tests/_utils/check_log_contains +++ b/tests/_utils/check_log_contains @@ -13,7 +13,7 @@ fi got=`grep "$text" $log | wc -l` if [ $num -eq 0 ]; then - if [[ $got -eq 0 ]]; then + if [ $got -eq 0 ]; then cat $log echo "$log dosen't contain $text" exit 1