diff --git a/lightning/backend/importer.go b/lightning/backend/importer.go index ca7624377..0b9ec80d0 100644 --- a/lightning/backend/importer.go +++ b/lightning/backend/importer.go @@ -25,6 +25,7 @@ import ( "github.com/pingcap/errors" kv "github.com/pingcap/kvproto/pkg/import_kvpb" "github.com/pingcap/parser/model" + "github.com/pingcap/tidb-lightning/lightning/glue" "github.com/pingcap/tidb/table" "go.uber.org/zap" "google.golang.org/grpc" @@ -267,6 +268,22 @@ func checkTiDBVersion(tls *common.TLS, requiredVersion semver.Version) error { return checkVersion("TiDB", requiredVersion, *version) } +func checkTiDBVersionBySQL(g glue.Glue, requiredVersion semver.Version) error { + versionStr, err := g.GetSQLExecutor().ObtainStringWithLog( + context.Background(), + "SELECT version();", + "check TiDB version", + log.L()) + if err != nil { + return errors.Trace(err) + } + version, err := common.ExtractTiDBVersion(versionStr) + if err != nil { + return errors.Trace(err) + } + return checkVersion("TiDB", requiredVersion, *version) +} + func checkPDVersion(tls *common.TLS, pdAddr string, requiredVersion semver.Version) error { version, err := common.FetchPDVersion(tls, pdAddr) if err != nil { diff --git a/lightning/backend/local.go b/lightning/backend/local.go index 0636660dd..c162c37e4 100644 --- a/lightning/backend/local.go +++ b/lightning/backend/local.go @@ -41,6 +41,7 @@ import ( "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/parser/model" + "github.com/pingcap/tidb-lightning/lightning/glue" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/util/codec" "github.com/pingcap/tidb/util/hack" @@ -167,6 +168,7 @@ type local struct { splitCli split.SplitClient tls *common.TLS pdAddr string + g glue.Glue localStoreDir string regionSplitSize int64 @@ -248,6 +250,7 @@ func NewLocalBackend( rangeConcurrency int, sendKVPairs int, enableCheckpoint bool, + g glue.Glue, ) (Backend, error) { pdCli, err := pd.NewClient([]string{pdAddr}, tls.ToPDSecurityOption()) if err != nil { @@ -278,6 +281,7 @@ func NewLocalBackend( splitCli: splitCli, tls: tls, pdAddr: pdAddr, + g: g, localStoreDir: localFile, regionSplitSize: regionSplitSize, @@ -1170,7 +1174,7 @@ func (local *local) CleanupEngine(ctx context.Context, engineUUID uuid.UUID) err } func (local *local) CheckRequirements() error { - if err := checkTiDBVersion(local.tls, localMinTiDBVersion); err != nil { + if err := checkTiDBVersionBySQL(local.g, localMinTiDBVersion); err != nil { return err } if err := checkPDVersion(local.tls, local.pdAddr, localMinPDVersion); err != nil { diff --git a/lightning/checkpoints/checkpoints.go b/lightning/checkpoints/checkpoints.go index 03b4d7985..4b8f1d896 100644 --- a/lightning/checkpoints/checkpoints.go +++ b/lightning/checkpoints/checkpoints.go @@ -67,6 +67,113 @@ const ( CheckpointTableNameChunk = "chunk_v5" ) +const ( + // shared by MySQLCheckpointsDB and GlueCheckpointsDB + CreateDBTemplate = "CREATE DATABASE IF NOT EXISTS %s;" + CreateTaskTableTemplate = ` + CREATE TABLE IF NOT EXISTS %s.%s ( + id tinyint(1) PRIMARY KEY, + task_id bigint NOT NULL, + source_dir varchar(256) NOT NULL, + backend varchar(16) NOT NULL, + importer_addr varchar(256), + tidb_host varchar(128) NOT NULL, + tidb_port int NOT NULL, + pd_addr varchar(128) NOT NULL, + sorted_kv_dir varchar(256) NOT NULL, + lightning_ver varchar(48) NOT NULL + );` + CreateTableTableTemplate = ` + CREATE TABLE IF NOT EXISTS %s.%s ( + task_id bigint NOT NULL, + table_name varchar(261) NOT NULL PRIMARY KEY, + 
hash binary(32) NOT NULL, + status tinyint unsigned DEFAULT 30, + alloc_base bigint NOT NULL DEFAULT 0, + table_id bigint NOT NULL DEFAULT 0, + create_time timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, + update_time timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + INDEX(task_id) + );` + CreateEngineTableTemplate = ` + CREATE TABLE IF NOT EXISTS %s.%s ( + table_name varchar(261) NOT NULL, + engine_id int NOT NULL, + status tinyint unsigned DEFAULT 30, + create_time timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, + update_time timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY(table_name, engine_id DESC) + );` + CreateChunkTableTemplate = ` + CREATE TABLE IF NOT EXISTS %s.%s ( + table_name varchar(261) NOT NULL, + engine_id int unsigned NOT NULL, + path varchar(2048) NOT NULL, + offset bigint NOT NULL, + type int NOT NULL, + compression int NOT NULL, + sort_key varchar(256) NOT NULL, + columns text NULL, + should_include_row_id BOOL NOT NULL, + end_offset bigint NOT NULL, + pos bigint NOT NULL, + prev_rowid_max bigint NOT NULL, + rowid_max bigint NOT NULL, + kvc_bytes bigint unsigned NOT NULL DEFAULT 0, + kvc_kvs bigint unsigned NOT NULL DEFAULT 0, + kvc_checksum bigint unsigned NOT NULL DEFAULT 0, + create_time timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, + update_time timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY(table_name, engine_id, path(500), offset) + );` + InitTaskTemplate = ` + REPLACE INTO %s.%s (id, task_id, source_dir, backend, importer_addr, tidb_host, tidb_port, pd_addr, sorted_kv_dir, lightning_ver) + VALUES (1, ?, ?, ?, ?, ?, ?, ?, ?, ?);` + InitTableTemplate = ` + INSERT INTO %s.%s (task_id, table_name, hash, table_id) VALUES (?, ?, ?, ?) + ON DUPLICATE KEY UPDATE task_id = CASE + WHEN hash = VALUES(hash) + THEN VALUES(task_id) + END;` + ReadTaskTemplate = ` + SELECT task_id, source_dir, backend, importer_addr, tidb_host, tidb_port, pd_addr, sorted_kv_dir, lightning_ver FROM %s.%s WHERE id = 1;` + ReadEngineTemplate = ` + SELECT engine_id, status FROM %s.%s WHERE table_name = ? ORDER BY engine_id DESC;` + ReadChunkTemplate = ` + SELECT + engine_id, path, offset, type, compression, sort_key, columns, + pos, end_offset, prev_rowid_max, rowid_max, + kvc_bytes, kvc_kvs, kvc_checksum, unix_timestamp(create_time) + FROM %s.%s WHERE table_name = ? + ORDER BY engine_id, path, offset;` + ReadTableRemainTemplate = ` + SELECT status, alloc_base, table_id FROM %s.%s WHERE table_name = ?;` + ReplaceEngineTemplate = ` + REPLACE INTO %s.%s (table_name, engine_id, status) VALUES (?, ?, ?);` + ReplaceChunkTemplate = ` + REPLACE INTO %s.%s ( + table_name, engine_id, + path, offset, type, compression, sort_key, columns, should_include_row_id, + pos, end_offset, prev_rowid_max, rowid_max, + kvc_bytes, kvc_kvs, kvc_checksum, create_time + ) VALUES ( + ?, ?, + ?, ?, ?, ?, ?, ?, FALSE, + ?, ?, ?, ?, + 0, 0, 0, from_unixtime(?) + );` + UpdateChunkTemplate = ` + UPDATE %s.%s SET pos = ?, prev_rowid_max = ?, kvc_bytes = ?, kvc_kvs = ?, kvc_checksum = ?, columns = ? + WHERE (table_name, engine_id, path, offset) = (?, ?, ?, ?);` + UpdateTableRebaseTemplate = ` + UPDATE %s.%s SET alloc_base = GREATEST(?, alloc_base) WHERE table_name = ?;` + UpdateTableStatusTemplate = ` + UPDATE %s.%s SET status = ? WHERE table_name = ?;` + UpdateEngineTemplate = ` + UPDATE %s.%s SET status = ? 
WHERE (table_name, engine_id) = (?, ?);` + DeleteCheckpointRecordTemplate = "DELETE FROM %s.%s WHERE table_name = ?;" +) + func IsCheckpointTable(name string) bool { switch name { case CheckpointTableNameTask, CheckpointTableNameTable, CheckpointTableNameEngine, CheckpointTableNameChunk: @@ -382,7 +489,7 @@ func OpenCheckpointsDB(ctx context.Context, cfg *config.Config) (CheckpointsDB, if err != nil { return nil, errors.Trace(err) } - cpdb, err := NewMySQLCheckpointsDB(ctx, db, cfg.Checkpoint.Schema, cfg.TaskID) + cpdb, err := NewMySQLCheckpointsDB(ctx, db, cfg.Checkpoint.Schema) if err != nil { db.Close() return nil, errors.Trace(err) @@ -432,10 +539,9 @@ func (*NullCheckpointsDB) Update(map[string]*TableCheckpointDiff) {} type MySQLCheckpointsDB struct { db *sql.DB schema string - taskID int64 } -func NewMySQLCheckpointsDB(ctx context.Context, db *sql.DB, schemaName string, taskID int64) (*MySQLCheckpointsDB, error) { +func NewMySQLCheckpointsDB(ctx context.Context, db *sql.DB, schemaName string) (*MySQLCheckpointsDB, error) { var escapedSchemaName strings.Builder common.WriteMySQLIdentifier(&escapedSchemaName, schemaName) schema := escapedSchemaName.String() @@ -445,85 +551,27 @@ func NewMySQLCheckpointsDB(ctx context.Context, db *sql.DB, schemaName string, t Logger: log.With(zap.String("schema", schemaName)), HideQueryLog: true, } - err := sql.Exec(ctx, "create checkpoints database", fmt.Sprintf(` - CREATE DATABASE IF NOT EXISTS %s; - `, schema)) + err := sql.Exec(ctx, "create checkpoints database", fmt.Sprintf(CreateDBTemplate, schema)) if err != nil { return nil, errors.Trace(err) } - err = sql.Exec(ctx, "create task checkpoints table", fmt.Sprintf(` - CREATE TABLE IF NOT EXISTS %s.%s ( - id tinyint(1) PRIMARY KEY, - task_id bigint NOT NULL, - source_dir varchar(256) NOT NULL, - backend varchar(16) NOT NULL, - importer_addr varchar(256), - tidb_host varchar(128) NOT NULL, - tidb_port int NOT NULL, - pd_addr varchar(128) NOT NULL, - sorted_kv_dir varchar(256) NOT NULL, - lightning_ver varchar(48) NOT NULL - ); - `, schema, CheckpointTableNameTask)) + err = sql.Exec(ctx, "create task checkpoints table", fmt.Sprintf(CreateTaskTableTemplate, schema, CheckpointTableNameTask)) if err != nil { return nil, errors.Trace(err) } - err = sql.Exec(ctx, "create table checkpoints table", fmt.Sprintf(` - CREATE TABLE IF NOT EXISTS %s.%s ( - task_id bigint NOT NULL, - table_name varchar(261) NOT NULL PRIMARY KEY, - hash binary(32) NOT NULL, - status tinyint unsigned DEFAULT 30, - alloc_base bigint NOT NULL DEFAULT 0, - table_id bigint NOT NULL DEFAULT 0, - create_time timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, - update_time timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, - INDEX(task_id) - ); - `, schema, CheckpointTableNameTable)) + err = sql.Exec(ctx, "create table checkpoints table", fmt.Sprintf(CreateTableTableTemplate, schema, CheckpointTableNameTable)) if err != nil { return nil, errors.Trace(err) } - err = sql.Exec(ctx, "create engine checkpoints table", fmt.Sprintf(` - CREATE TABLE IF NOT EXISTS %s.%s ( - table_name varchar(261) NOT NULL, - engine_id int NOT NULL, - status tinyint unsigned DEFAULT 30, - create_time timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, - update_time timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, - PRIMARY KEY(table_name, engine_id DESC) - ); - `, schema, CheckpointTableNameEngine)) + err = sql.Exec(ctx, "create engine checkpoints table", fmt.Sprintf(CreateEngineTableTemplate, schema, CheckpointTableNameEngine)) 
if err != nil { return nil, errors.Trace(err) } - err = sql.Exec(ctx, "create chunks checkpoints table", fmt.Sprintf(` - CREATE TABLE IF NOT EXISTS %s.%s ( - table_name varchar(261) NOT NULL, - engine_id int unsigned NOT NULL, - path varchar(2048) NOT NULL, - offset bigint NOT NULL, - type int NOT NULL, - compression int NOT NULL, - sort_key varchar(256) NOT NULL, - columns text NULL, - should_include_row_id BOOL NOT NULL, - end_offset bigint NOT NULL, - pos bigint NOT NULL, - prev_rowid_max bigint NOT NULL, - rowid_max bigint NOT NULL, - kvc_bytes bigint unsigned NOT NULL DEFAULT 0, - kvc_kvs bigint unsigned NOT NULL DEFAULT 0, - kvc_checksum bigint unsigned NOT NULL DEFAULT 0, - create_time timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, - update_time timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, - PRIMARY KEY(table_name, engine_id, path(500), offset) - ); - `, schema, CheckpointTableNameChunk)) + err = sql.Exec(ctx, "create chunks checkpoints table", fmt.Sprintf(CreateChunkTableTemplate, schema, CheckpointTableNameChunk)) if err != nil { return nil, errors.Trace(err) } @@ -531,7 +579,6 @@ func NewMySQLCheckpointsDB(ctx context.Context, db *sql.DB, schemaName string, t return &MySQLCheckpointsDB{ db: db, schema: schema, - taskID: taskID, }, nil } @@ -540,10 +587,7 @@ func (cpdb *MySQLCheckpointsDB) Initialize(ctx context.Context, cfg *config.Conf // Since this step is not performance critical, we just insert the rows one-by-one. s := common.SQLWithRetry{DB: cpdb.db, Logger: log.L()} err := s.Transact(ctx, "insert checkpoints", func(c context.Context, tx *sql.Tx) error { - taskStmt, err := tx.PrepareContext(c, fmt.Sprintf(` - REPLACE INTO %s.%s (id, task_id, source_dir, backend, importer_addr, tidb_host, tidb_port, pd_addr, sorted_kv_dir, lightning_ver) - VALUES (1, ?, ?, ?, ?, ?, ?, ?, ?, ?); - `, cpdb.schema, CheckpointTableNameTask)) + taskStmt, err := tx.PrepareContext(c, fmt.Sprintf(InitTaskTemplate, cpdb.schema, CheckpointTableNameTask)) if err != nil { return errors.Trace(err) } @@ -561,13 +605,7 @@ func (cpdb *MySQLCheckpointsDB) Initialize(ctx context.Context, cfg *config.Conf // statement to fail with an irrecoverable error. // We do need to capture the error is display a user friendly message // (multiple nodes cannot import the same table) though. - stmt, err := tx.PrepareContext(c, fmt.Sprintf(` - INSERT INTO %s.%s (task_id, table_name, hash, table_id) VALUES (?, ?, ?, ?) 
- ON DUPLICATE KEY UPDATE task_id = CASE - WHEN hash = VALUES(hash) - THEN VALUES(task_id) - END; - `, cpdb.schema, CheckpointTableNameTable)) + stmt, err := tx.PrepareContext(c, fmt.Sprintf(InitTableTemplate, cpdb.schema, CheckpointTableNameTable)) if err != nil { return errors.Trace(err) } @@ -576,7 +614,7 @@ func (cpdb *MySQLCheckpointsDB) Initialize(ctx context.Context, cfg *config.Conf for _, db := range dbInfo { for _, table := range db.Tables { tableName := common.UniqueTable(db.Name, table.Name) - _, err = stmt.ExecContext(c, cpdb.taskID, tableName, 0, table.ID) + _, err = stmt.ExecContext(c, cfg.TaskID, tableName, 0, table.ID) if err != nil { return errors.Trace(err) } @@ -598,10 +636,7 @@ func (cpdb *MySQLCheckpointsDB) TaskCheckpoint(ctx context.Context) (*TaskCheckp Logger: log.L(), } - taskQuery := fmt.Sprintf( - "SELECT task_id, source_dir, backend, importer_addr, tidb_host, tidb_port, pd_addr, sorted_kv_dir, lightning_ver FROM %s.%s WHERE id = 1", - cpdb.schema, CheckpointTableNameTask, - ) + taskQuery := fmt.Sprintf(ReadTaskTemplate, cpdb.schema, CheckpointTableNameTask) taskCp := &TaskCheckpoint{} err := s.QueryRow(ctx, "fetch task checkpoint", taskQuery, &taskCp.TaskId, &taskCp.SourceDir, &taskCp.Backend, &taskCp.ImporterAddr, &taskCp.TiDBHost, &taskCp.TiDBPort, &taskCp.PdAddr, &taskCp.SortedKVDir, &taskCp.LightningVer) @@ -633,9 +668,7 @@ func (cpdb *MySQLCheckpointsDB) Get(ctx context.Context, tableName string) (*Tab err := s.Transact(ctx, "read checkpoint", func(c context.Context, tx *sql.Tx) error { // 1. Populate the engines. - engineQuery := fmt.Sprintf(` - SELECT engine_id, status FROM %s.%s WHERE table_name = ? ORDER BY engine_id DESC; - `, cpdb.schema, CheckpointTableNameEngine) + engineQuery := fmt.Sprintf(ReadEngineTemplate, cpdb.schema, CheckpointTableNameEngine) engineRows, err := tx.QueryContext(c, engineQuery, tableName) if err != nil { return errors.Trace(err) @@ -659,14 +692,7 @@ func (cpdb *MySQLCheckpointsDB) Get(ctx context.Context, tableName string) (*Tab // 2. Populate the chunks. - chunkQuery := fmt.Sprintf(` - SELECT - engine_id, path, offset, type, compression, sort_key, columns, - pos, end_offset, prev_rowid_max, rowid_max, - kvc_bytes, kvc_kvs, kvc_checksum, unix_timestamp(create_time) - FROM %s.%s WHERE table_name = ? - ORDER BY engine_id, path, offset; - `, cpdb.schema, CheckpointTableNameChunk) + chunkQuery := fmt.Sprintf(ReadChunkTemplate, cpdb.schema, CheckpointTableNameChunk) chunkRows, err := tx.QueryContext(c, chunkQuery, tableName) if err != nil { return errors.Trace(err) @@ -702,9 +728,7 @@ func (cpdb *MySQLCheckpointsDB) Get(ctx context.Context, tableName string) (*Tab // 3. Fill in the remaining table info - tableQuery := fmt.Sprintf(` - SELECT status, alloc_base, table_id FROM %s.%s WHERE table_name = ? 
- `, cpdb.schema, CheckpointTableNameTable) + tableQuery := fmt.Sprintf(ReadTableRemainTemplate, cpdb.schema, CheckpointTableNameTable) tableRow := tx.QueryRowContext(c, tableQuery, tableName) var status uint8 @@ -727,27 +751,13 @@ func (cpdb *MySQLCheckpointsDB) InsertEngineCheckpoints(ctx context.Context, tab Logger: log.With(zap.String("table", tableName)), } err := s.Transact(ctx, "update engine checkpoints", func(c context.Context, tx *sql.Tx) error { - engineStmt, err := tx.PrepareContext(c, fmt.Sprintf(` - REPLACE INTO %s.%s (table_name, engine_id, status) VALUES (?, ?, ?); - `, cpdb.schema, CheckpointTableNameEngine)) + engineStmt, err := tx.PrepareContext(c, fmt.Sprintf(ReplaceEngineTemplate, cpdb.schema, CheckpointTableNameEngine)) if err != nil { return errors.Trace(err) } defer engineStmt.Close() - chunkStmt, err := tx.PrepareContext(c, fmt.Sprintf(` - REPLACE INTO %s.%s ( - table_name, engine_id, - path, offset, type, compression, sort_key, columns, should_include_row_id, - pos, end_offset, prev_rowid_max, rowid_max, - kvc_bytes, kvc_kvs, kvc_checksum, create_time - ) VALUES ( - ?, ?, - ?, ?, ?, ?, ?, ?, FALSE, - ?, ?, ?, ?, - 0, 0, 0, from_unixtime(?) - ); - `, cpdb.schema, CheckpointTableNameChunk)) + chunkStmt, err := tx.PrepareContext(c, fmt.Sprintf(ReplaceChunkTemplate, cpdb.schema, CheckpointTableNameChunk)) if err != nil { return errors.Trace(err) } @@ -785,19 +795,10 @@ func (cpdb *MySQLCheckpointsDB) InsertEngineCheckpoints(ctx context.Context, tab } func (cpdb *MySQLCheckpointsDB) Update(checkpointDiffs map[string]*TableCheckpointDiff) { - chunkQuery := fmt.Sprintf(` - UPDATE %s.%s SET pos = ?, prev_rowid_max = ?, kvc_bytes = ?, kvc_kvs = ?, kvc_checksum = ?, columns = ? - WHERE (table_name, engine_id, path, offset) = (?, ?, ?, ?); - `, cpdb.schema, CheckpointTableNameChunk) - rebaseQuery := fmt.Sprintf(` - UPDATE %s.%s SET alloc_base = GREATEST(?, alloc_base) WHERE table_name = ?; - `, cpdb.schema, CheckpointTableNameTable) - tableStatusQuery := fmt.Sprintf(` - UPDATE %s.%s SET status = ? WHERE table_name = ?; - `, cpdb.schema, CheckpointTableNameTable) - engineStatusQuery := fmt.Sprintf(` - UPDATE %s.%s SET status = ? 
WHERE (table_name, engine_id) = (?, ?); - `, cpdb.schema, CheckpointTableNameEngine) + chunkQuery := fmt.Sprintf(UpdateChunkTemplate, cpdb.schema, CheckpointTableNameChunk) + rebaseQuery := fmt.Sprintf(UpdateTableRebaseTemplate, cpdb.schema, CheckpointTableNameTable) + tableStatusQuery := fmt.Sprintf(UpdateTableStatusTemplate, cpdb.schema, CheckpointTableNameTable) + engineStatusQuery := fmt.Sprintf(UpdateEngineTemplate, cpdb.schema, CheckpointTableNameEngine) s := common.SQLWithRetry{DB: cpdb.db, Logger: log.L()} err := s.Transact(context.Background(), "update checkpoints", func(c context.Context, tx *sql.Tx) error { @@ -1152,9 +1153,9 @@ func (cpdb *MySQLCheckpointsDB) RemoveCheckpoint(ctx context.Context, tableName return s.Exec(ctx, "remove all checkpoints", "DROP SCHEMA "+cpdb.schema) } - deleteChunkQuery := fmt.Sprintf("DELETE FROM %s.%s WHERE table_name = ?", cpdb.schema, CheckpointTableNameChunk) - deleteEngineQuery := fmt.Sprintf("DELETE FROM %s.%s WHERE table_name = ?", cpdb.schema, CheckpointTableNameEngine) - deleteTableQuery := fmt.Sprintf("DELETE FROM %s.%s WHERE table_name = ?", cpdb.schema, CheckpointTableNameTable) + deleteChunkQuery := fmt.Sprintf(DeleteCheckpointRecordTemplate, cpdb.schema, CheckpointTableNameChunk) + deleteEngineQuery := fmt.Sprintf(DeleteCheckpointRecordTemplate, cpdb.schema, CheckpointTableNameEngine) + deleteTableQuery := fmt.Sprintf(DeleteCheckpointRecordTemplate, cpdb.schema, CheckpointTableNameTable) return s.Transact(ctx, "remove checkpoints", func(c context.Context, tx *sql.Tx) error { if _, e := tx.ExecContext(c, deleteChunkQuery, tableName); e != nil { diff --git a/lightning/checkpoints/checkpoints_sql_test.go b/lightning/checkpoints/checkpoints_sql_test.go index a7b15cc5d..1595f1feb 100644 --- a/lightning/checkpoints/checkpoints_sql_test.go +++ b/lightning/checkpoints/checkpoints_sql_test.go @@ -46,7 +46,7 @@ func (s *cpSQLSuite) SetUpTest(c *C) { ExpectExec("CREATE TABLE IF NOT EXISTS `mock-schema`\\.chunk_v\\d+ .+"). WillReturnResult(sqlmock.NewResult(5, 1)) - cpdb, err := checkpoints.NewMySQLCheckpointsDB(context.Background(), s.db, "mock-schema", 1234) + cpdb, err := checkpoints.NewMySQLCheckpointsDB(context.Background(), s.db, "mock-schema") c.Assert(err, IsNil) c.Assert(s.mock.ExpectationsWereMet(), IsNil) s.cpdb = cpdb @@ -73,13 +73,13 @@ func (s *cpSQLSuite) TestNormalOperations(c *C) { initializeStmt = s.mock. ExpectPrepare("INSERT INTO `mock-schema`\\.table_v\\d+") initializeStmt.ExpectExec(). - WithArgs(1234, "`db1`.`t1`", sqlmock.AnyArg(), int64(1)). + WithArgs(123, "`db1`.`t1`", sqlmock.AnyArg(), int64(1)). WillReturnResult(sqlmock.NewResult(7, 1)) initializeStmt.ExpectExec(). - WithArgs(1234, "`db1`.`t2`", sqlmock.AnyArg(), int64(2)). + WithArgs(123, "`db1`.`t2`", sqlmock.AnyArg(), int64(2)). WillReturnResult(sqlmock.NewResult(8, 1)) initializeStmt.ExpectExec(). - WithArgs(1234, "`db2`.`t3`", sqlmock.AnyArg(), int64(3)). + WithArgs(123, "`db2`.`t3`", sqlmock.AnyArg(), int64(3)). WillReturnResult(sqlmock.NewResult(9, 1)) s.mock.ExpectCommit() diff --git a/lightning/checkpoints/glue_checkpoint.go b/lightning/checkpoints/glue_checkpoint.go new file mode 100644 index 000000000..e154ea00d --- /dev/null +++ b/lightning/checkpoints/glue_checkpoint.go @@ -0,0 +1,726 @@ +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package checkpoints
+
+import (
+    "context"
+    "encoding/json"
+    "fmt"
+    "io"
+    "strings"
+
+    "github.com/pingcap/errors"
+    "github.com/pingcap/parser/ast"
+    "github.com/pingcap/tidb-lightning/lightning/common"
+    "github.com/pingcap/tidb-lightning/lightning/config"
+    "github.com/pingcap/tidb-lightning/lightning/log"
+    "github.com/pingcap/tidb-lightning/lightning/mydump"
+    verify "github.com/pingcap/tidb-lightning/lightning/verification"
+    "github.com/pingcap/tidb/types"
+    "github.com/pingcap/tidb/util/chunk"
+    "github.com/pingcap/tidb/util/sqlexec"
+    "go.uber.org/zap"
+)
+
+type Session interface {
+    Close()
+    Execute(context.Context, string) ([]sqlexec.RecordSet, error)
+    CommitTxn(context.Context) error
+    RollbackTxn(context.Context)
+    PrepareStmt(sql string) (stmtID uint32, paramCount int, fields []*ast.ResultField, err error)
+    ExecutePreparedStmt(ctx context.Context, stmtID uint32, param []types.Datum) (sqlexec.RecordSet, error)
+    DropPreparedStmt(stmtID uint32) error
+}
+
+// GlueCheckpointsDB is almost the same as MySQLCheckpointsDB, but it uses TiDB's internal data structures, which
+// take a lot of work to keep consistent with database/sql.
+// TODO: Encapsulate Begin/Commit/Rollback txn, form SQL with args and query/iter/scan TiDB's RecordSet into an
+// interface to reuse MySQLCheckpointsDB.
+type GlueCheckpointsDB struct {
+    // getSessionFunc will get a new session from TiDB
+    getSessionFunc func() (Session, error)
+    schema         string
+}
+
+func NewGlueCheckpointsDB(ctx context.Context, se Session, f func() (Session, error), schemaName string) (*GlueCheckpointsDB, error) {
+    var escapedSchemaName strings.Builder
+    common.WriteMySQLIdentifier(&escapedSchemaName, schemaName)
+    schema := escapedSchemaName.String()
+    logger := log.With(zap.String("schema", schemaName))
+
+    sql := fmt.Sprintf(CreateDBTemplate, schema)
+    err := common.Retry("create checkpoints database", logger, func() error {
+        _, err := se.Execute(ctx, sql)
+        return err
+    })
+    if err != nil {
+        return nil, errors.Trace(err)
+    }
+
+    sql = fmt.Sprintf(CreateTaskTableTemplate, schema, CheckpointTableNameTask)
+    err = common.Retry("create task checkpoints table", logger, func() error {
+        _, err := se.Execute(ctx, sql)
+        return err
+    })
+    if err != nil {
+        return nil, errors.Trace(err)
+    }
+
+    sql = fmt.Sprintf(CreateTableTableTemplate, schema, CheckpointTableNameTable)
+    err = common.Retry("create table checkpoints table", logger, func() error {
+        _, err := se.Execute(ctx, sql)
+        return err
+    })
+    if err != nil {
+        return nil, errors.Trace(err)
+    }
+
+    sql = fmt.Sprintf(CreateEngineTableTemplate, schema, CheckpointTableNameEngine)
+    err = common.Retry("create engine checkpoints table", logger, func() error {
+        _, err := se.Execute(ctx, sql)
+        return err
+    })
+    if err != nil {
+        return nil, errors.Trace(err)
+    }
+
+    sql = fmt.Sprintf(CreateChunkTableTemplate, schema, CheckpointTableNameChunk)
+    err = common.Retry("create chunks checkpoints table", logger, func() error {
+        _, err := se.Execute(ctx, sql)
+        return err
+    })
+    if err != nil {
+        return nil, errors.Trace(err)
+    }
+
+    return &GlueCheckpointsDB{
+        getSessionFunc: f,
+        schema:         schema,
+    }, nil
+}
+
+func (g GlueCheckpointsDB) 
Initialize(ctx context.Context, cfg *config.Config, dbInfo map[string]*TidbDBInfo) error { + logger := log.L() + se, err := g.getSessionFunc() + if err != nil { + return errors.Trace(err) + } + defer se.Close() + + err = Transact(ctx, "insert checkpoints", se, logger, func(c context.Context, s Session) error { + stmtID, _, _, err := s.PrepareStmt(fmt.Sprintf(InitTaskTemplate, g.schema, CheckpointTableNameTask)) + if err != nil { + return errors.Trace(err) + } + defer s.DropPreparedStmt(stmtID) + _, err = s.ExecutePreparedStmt(c, stmtID, []types.Datum{ + types.NewIntDatum(cfg.TaskID), + types.NewStringDatum(cfg.Mydumper.SourceDir), + types.NewStringDatum(cfg.TikvImporter.Backend), + types.NewStringDatum(cfg.TikvImporter.Addr), + types.NewStringDatum(cfg.TiDB.Host), + types.NewIntDatum(int64(cfg.TiDB.Port)), + types.NewStringDatum(cfg.TiDB.PdAddr), + types.NewStringDatum(cfg.TikvImporter.SortedKVDir), + types.NewStringDatum(common.ReleaseVersion), + }) + if err != nil { + return errors.Trace(err) + } + + stmtID2, _, _, err := s.PrepareStmt(fmt.Sprintf(InitTableTemplate, g.schema, CheckpointTableNameTable)) + if err != nil { + return errors.Trace(err) + } + defer s.DropPreparedStmt(stmtID2) + + for _, db := range dbInfo { + for _, table := range db.Tables { + tableName := common.UniqueTable(db.Name, table.Name) + _, err = s.ExecutePreparedStmt(c, stmtID2, []types.Datum{ + types.NewIntDatum(cfg.TaskID), + types.NewStringDatum(tableName), + types.NewIntDatum(0), + types.NewIntDatum(table.ID), + }) + if err != nil { + return errors.Trace(err) + } + } + } + return nil + }) + return errors.Trace(err) +} + +func (g GlueCheckpointsDB) TaskCheckpoint(ctx context.Context) (*TaskCheckpoint, error) { + logger := log.L() + sql := fmt.Sprintf(ReadTaskTemplate, g.schema, CheckpointTableNameTask) + se, err := g.getSessionFunc() + if err != nil { + return nil, errors.Trace(err) + } + defer se.Close() + + var taskCp *TaskCheckpoint + err = common.Retry("fetch task checkpoint", logger, func() error { + rs, err := se.Execute(ctx, sql) + if err != nil { + return errors.Trace(err) + } + r := rs[0] + defer r.Close() + req := r.NewChunk() + err = r.Next(ctx, req) + if err != nil { + return err + } + if req.NumRows() == 0 { + return nil + } + + row := req.GetRow(0) + taskCp = &TaskCheckpoint{} + taskCp.TaskId = row.GetInt64(0) + taskCp.SourceDir = row.GetString(1) + taskCp.Backend = row.GetString(2) + taskCp.ImporterAddr = row.GetString(3) + taskCp.TiDBHost = row.GetString(4) + taskCp.TiDBPort = int(row.GetInt64(5)) + taskCp.PdAddr = row.GetString(6) + taskCp.SortedKVDir = row.GetString(7) + taskCp.LightningVer = row.GetString(8) + return nil + }) + if err != nil { + return nil, errors.Trace(err) + } + return taskCp, nil +} + +func (g GlueCheckpointsDB) Get(ctx context.Context, tableName string) (*TableCheckpoint, error) { + cp := &TableCheckpoint{ + Engines: map[int32]*EngineCheckpoint{}, + } + logger := log.With(zap.String("table", tableName)) + se, err := g.getSessionFunc() + if err != nil { + return nil, errors.Trace(err) + } + defer se.Close() + + var tableNameBuilder strings.Builder + common.EscapeMySQLSingleQuote(&tableNameBuilder, tableName) + tableName = tableNameBuilder.String() + err = Transact(ctx, "read checkpoint", se, logger, func(c context.Context, s Session) error { + // 1. Populate the engines. 
+ sql := fmt.Sprintf(ReadEngineTemplate, g.schema, CheckpointTableNameEngine) + sql = strings.ReplaceAll(sql, "?", tableName) + rs, err := s.Execute(ctx, sql) + if err != nil { + return errors.Trace(err) + } + r := rs[0] + req := r.NewChunk() + it := chunk.NewIterator4Chunk(req) + for { + err = r.Next(ctx, req) + if err != nil { + r.Close() + return err + } + if req.NumRows() == 0 { + break + } + + for row := it.Begin(); row != it.End(); row = it.Next() { + engineID := int32(row.GetInt64(0)) + status := uint8(row.GetUint64(1)) + cp.Engines[engineID] = &EngineCheckpoint{ + Status: CheckpointStatus(status), + } + } + } + r.Close() + + // 2. Populate the chunks. + sql = fmt.Sprintf(ReadChunkTemplate, g.schema, CheckpointTableNameChunk) + sql = strings.ReplaceAll(sql, "?", tableName) + rs, err = s.Execute(ctx, sql) + if err != nil { + return errors.Trace(err) + } + r = rs[0] + req = r.NewChunk() + it = chunk.NewIterator4Chunk(req) + for { + err = r.Next(ctx, req) + if err != nil { + r.Close() + return err + } + if req.NumRows() == 0 { + break + } + + for row := it.Begin(); row != it.End(); row = it.Next() { + value := &ChunkCheckpoint{} + engineID := int32(row.GetInt64(0)) + value.Key.Path = row.GetString(1) + value.Key.Offset = row.GetInt64(2) + value.FileMeta.Type = mydump.SourceType(row.GetInt64(3)) + value.FileMeta.Compression = mydump.Compression(row.GetInt64(4)) + value.FileMeta.SortKey = row.GetString(5) + colPerm := row.GetBytes(6) + value.Chunk.Offset = row.GetInt64(7) + value.Chunk.EndOffset = row.GetInt64(8) + value.Chunk.PrevRowIDMax = row.GetInt64(9) + value.Chunk.RowIDMax = row.GetInt64(10) + kvcBytes := row.GetUint64(11) + kvcKVs := row.GetUint64(12) + kvcChecksum := row.GetUint64(13) + value.Timestamp = row.GetInt64(14) + + value.FileMeta.Path = value.Key.Path + value.Checksum = verify.MakeKVChecksum(kvcBytes, kvcKVs, kvcChecksum) + if err := json.Unmarshal(colPerm, &value.ColumnPermutation); err != nil { + r.Close() + return errors.Trace(err) + } + cp.Engines[engineID].Chunks = append(cp.Engines[engineID].Chunks, value) + } + } + r.Close() + + // 3. 
Fill in the remaining table info + sql = fmt.Sprintf(ReadTableRemainTemplate, g.schema, CheckpointTableNameTable) + sql = strings.ReplaceAll(sql, "?", tableName) + rs, err = s.Execute(ctx, sql) + if err != nil { + return errors.Trace(err) + } + r = rs[0] + defer r.Close() + req = r.NewChunk() + err = r.Next(ctx, req) + if err != nil { + return err + } + if req.NumRows() == 0 { + return nil + } + + row := req.GetRow(0) + cp.Status = CheckpointStatus(row.GetUint64(0)) + cp.AllocBase = row.GetInt64(1) + cp.TableID = row.GetInt64(2) + return nil + }) + + if err != nil { + return nil, errors.Trace(err) + } + + return cp, nil +} + +func (g GlueCheckpointsDB) Close() error { + return nil +} + +func (g GlueCheckpointsDB) InsertEngineCheckpoints(ctx context.Context, tableName string, checkpointMap map[int32]*EngineCheckpoint) error { + logger := log.With(zap.String("table", tableName)) + se, err := g.getSessionFunc() + if err != nil { + return errors.Trace(err) + } + defer se.Close() + + err = Transact(ctx, "update engine checkpoints", se, logger, func(c context.Context, s Session) error { + engineStmt, _, _, err := s.PrepareStmt(fmt.Sprintf(ReplaceEngineTemplate, g.schema, CheckpointTableNameEngine)) + if err != nil { + return errors.Trace(err) + } + defer s.DropPreparedStmt(engineStmt) + + chunkStmt, _, _, err := s.PrepareStmt(fmt.Sprintf(ReplaceChunkTemplate, g.schema, CheckpointTableNameChunk)) + if err != nil { + return errors.Trace(err) + } + defer s.DropPreparedStmt(chunkStmt) + + for engineID, engine := range checkpointMap { + _, err := s.ExecutePreparedStmt(c, engineStmt, []types.Datum{ + types.NewStringDatum(tableName), + types.NewIntDatum(int64(engineID)), + types.NewUintDatum(uint64(engine.Status)), + }) + if err != nil { + return errors.Trace(err) + } + for _, value := range engine.Chunks { + columnPerm, err := json.Marshal(value.ColumnPermutation) + if err != nil { + return errors.Trace(err) + } + _, err = s.ExecutePreparedStmt(c, chunkStmt, []types.Datum{ + types.NewStringDatum(tableName), + types.NewIntDatum(int64(engineID)), + types.NewStringDatum(value.Key.Path), + types.NewIntDatum(value.Key.Offset), + types.NewIntDatum(int64(value.FileMeta.Type)), + types.NewIntDatum(int64(value.FileMeta.Compression)), + types.NewStringDatum(value.FileMeta.SortKey), + types.NewBytesDatum(columnPerm), + types.NewIntDatum(value.Chunk.Offset), + types.NewIntDatum(value.Chunk.EndOffset), + types.NewIntDatum(value.Chunk.PrevRowIDMax), + types.NewIntDatum(value.Chunk.RowIDMax), + types.NewIntDatum(value.Timestamp), + }) + if err != nil { + return errors.Trace(err) + } + } + } + return nil + }) + return errors.Trace(err) +} + +func (g GlueCheckpointsDB) Update(checkpointDiffs map[string]*TableCheckpointDiff) { + logger := log.L() + se, err := g.getSessionFunc() + if err != nil { + log.L().Error("can't get a session to update GlueCheckpointsDB", zap.Error(errors.Trace(err))) + return + } + defer se.Close() + + chunkQuery := fmt.Sprintf(UpdateChunkTemplate, g.schema, CheckpointTableNameChunk) + rebaseQuery := fmt.Sprintf(UpdateTableRebaseTemplate, g.schema, CheckpointTableNameTable) + tableStatusQuery := fmt.Sprintf(UpdateTableStatusTemplate, g.schema, CheckpointTableNameTable) + engineStatusQuery := fmt.Sprintf(UpdateEngineTemplate, g.schema, CheckpointTableNameEngine) + err = Transact(context.Background(), "update checkpoints", se, logger, func(c context.Context, s Session) error { + chunkStmt, _, _, err := s.PrepareStmt(chunkQuery) + if err != nil { + return errors.Trace(err) + } + defer 
s.DropPreparedStmt(chunkStmt) + rebaseStmt, _, _, err := s.PrepareStmt(rebaseQuery) + if err != nil { + return errors.Trace(err) + } + defer s.DropPreparedStmt(rebaseStmt) + tableStatusStmt, _, _, err := s.PrepareStmt(tableStatusQuery) + if err != nil { + return errors.Trace(err) + } + defer s.DropPreparedStmt(tableStatusStmt) + engineStatusStmt, _, _, err := s.PrepareStmt(engineStatusQuery) + if err != nil { + return errors.Trace(err) + } + defer s.DropPreparedStmt(engineStatusStmt) + + for tableName, cpd := range checkpointDiffs { + if cpd.hasStatus { + _, err := s.ExecutePreparedStmt(c, tableStatusStmt, []types.Datum{ + types.NewUintDatum(uint64(cpd.status)), + types.NewStringDatum(tableName), + }) + if err != nil { + return errors.Trace(err) + } + } + if cpd.hasRebase { + _, err := s.ExecutePreparedStmt(c, rebaseStmt, []types.Datum{ + types.NewIntDatum(cpd.allocBase), + types.NewStringDatum(tableName), + }) + if err != nil { + return errors.Trace(err) + } + } + for engineID, engineDiff := range cpd.engines { + if engineDiff.hasStatus { + _, err := s.ExecutePreparedStmt(c, engineStatusStmt, []types.Datum{ + types.NewUintDatum(uint64(engineDiff.status)), + types.NewStringDatum(tableName), + types.NewIntDatum(int64(engineID)), + }) + if err != nil { + return errors.Trace(err) + } + } + for key, diff := range engineDiff.chunks { + columnPerm, err := json.Marshal(diff.columnPermutation) + if err != nil { + return errors.Trace(err) + } + _, err = s.ExecutePreparedStmt(c, chunkStmt, []types.Datum{ + types.NewIntDatum(diff.pos), + types.NewIntDatum(diff.rowID), + types.NewUintDatum(diff.checksum.SumSize()), + types.NewUintDatum(diff.checksum.SumKVS()), + types.NewUintDatum(diff.checksum.Sum()), + types.NewBytesDatum(columnPerm), + types.NewStringDatum(tableName), + types.NewIntDatum(int64(engineID)), + types.NewStringDatum(key.Path), + types.NewIntDatum(key.Offset), + }) + if err != nil { + return errors.Trace(err) + } + } + } + } + return nil + }) + if err != nil { + log.L().Error("save checkpoint failed", zap.Error(err)) + } +} + +func (g GlueCheckpointsDB) RemoveCheckpoint(ctx context.Context, tableName string) error { + logger := log.With(zap.String("table", tableName)) + se, err := g.getSessionFunc() + if err != nil { + return errors.Trace(err) + } + defer se.Close() + + if tableName == "all" { + return common.Retry("remove all checkpoints", logger, func() error { + _, err := se.Execute(ctx, "DROP SCHEMA "+g.schema) + return err + }) + } + var tableNameBuilder strings.Builder + common.EscapeMySQLSingleQuote(&tableNameBuilder, tableName) + tableName = tableNameBuilder.String() + deleteChunkQuery := fmt.Sprintf(DeleteCheckpointRecordTemplate, g.schema, CheckpointTableNameChunk) + deleteChunkQuery = strings.ReplaceAll(deleteChunkQuery, "?", tableName) + deleteEngineQuery := fmt.Sprintf(DeleteCheckpointRecordTemplate, g.schema, CheckpointTableNameEngine) + deleteEngineQuery = strings.ReplaceAll(deleteEngineQuery, "?", tableName) + deleteTableQuery := fmt.Sprintf(DeleteCheckpointRecordTemplate, g.schema, CheckpointTableNameTable) + deleteTableQuery = strings.ReplaceAll(deleteTableQuery, "?", tableName) + + return errors.Trace(Transact(ctx, "remove checkpoints", se, logger, func(c context.Context, s Session) error { + if _, e := s.Execute(c, deleteChunkQuery); e != nil { + return e + } + if _, e := s.Execute(c, deleteEngineQuery); e != nil { + return e + } + if _, e := s.Execute(c, deleteTableQuery); e != nil { + return e + } + return nil + })) +} + +func (g GlueCheckpointsDB) 
MoveCheckpoints(ctx context.Context, taskID int64) error {
+    newSchema := fmt.Sprintf("`%s.%d.bak`", g.schema[1:len(g.schema)-1], taskID)
+    logger := log.With(zap.Int64("taskID", taskID))
+    se, err := g.getSessionFunc()
+    if err != nil {
+        return errors.Trace(err)
+    }
+    defer se.Close()
+
+    err = common.Retry("create backup checkpoints schema", logger, func() error {
+        _, err := se.Execute(ctx, "CREATE SCHEMA IF NOT EXISTS "+newSchema)
+        return err
+    })
+    if err != nil {
+        return errors.Trace(err)
+    }
+    for _, tbl := range []string{CheckpointTableNameChunk, CheckpointTableNameEngine,
+        CheckpointTableNameTable, CheckpointTableNameTask} {
+        query := fmt.Sprintf("RENAME TABLE %[1]s.%[3]s TO %[2]s.%[3]s", g.schema, newSchema, tbl)
+        err := common.Retry(fmt.Sprintf("move %s checkpoints table", tbl), logger, func() error {
+            _, err := se.Execute(ctx, query)
+            return err
+        })
+        if err != nil {
+            return errors.Trace(err)
+        }
+    }
+    return nil
+}
+
+func (g GlueCheckpointsDB) IgnoreErrorCheckpoint(ctx context.Context, tableName string) error {
+    logger := log.With(zap.String("table", tableName))
+    se, err := g.getSessionFunc()
+    if err != nil {
+        return errors.Trace(err)
+    }
+    defer se.Close()
+
+    var colName string
+    if tableName == "all" {
+        // This will expand to `WHERE 'all' = 'all'`, effectively allowing
+        // all tables to be included.
+        colName = "'all'"
+    } else {
+        colName = "table_name"
+    }
+
+    var tableNameBuilder strings.Builder
+    common.EscapeMySQLSingleQuote(&tableNameBuilder, tableName)
+    tableName = tableNameBuilder.String()
+
+    engineQuery := fmt.Sprintf(`
+        UPDATE %s.%s SET status = %d WHERE %s = %s AND status <= %d;
+    `, g.schema, CheckpointTableNameEngine, CheckpointStatusLoaded, colName, tableName, CheckpointStatusMaxInvalid)
+    tableQuery := fmt.Sprintf(`
+        UPDATE %s.%s SET status = %d WHERE %s = %s AND status <= %d;
+    `, g.schema, CheckpointTableNameTable, CheckpointStatusLoaded, colName, tableName, CheckpointStatusMaxInvalid)
+    return errors.Trace(Transact(ctx, "ignore error checkpoints", se, logger, func(c context.Context, s Session) error {
+        if _, e := s.Execute(c, engineQuery); e != nil {
+            return e
+        }
+        if _, e := s.Execute(c, tableQuery); e != nil {
+            return e
+        }
+        return nil
+    }))
+}
+
+func (g GlueCheckpointsDB) DestroyErrorCheckpoint(ctx context.Context, tableName string) ([]DestroyedTableCheckpoint, error) {
+    logger := log.With(zap.String("table", tableName))
+    se, err := g.getSessionFunc()
+    if err != nil {
+        return nil, errors.Trace(err)
+    }
+    defer se.Close()
+
+    var colName, aliasedColName string
+
+    if tableName == "all" {
+        // These will expand to `WHERE 'all' = 'all'`, effectively allowing
+        // all tables to be included. 
+        colName = "'all'"
+        aliasedColName = "'all'"
+    } else {
+        colName = "table_name"
+        aliasedColName = "t.table_name"
+    }
+
+    var tableNameBuilder strings.Builder
+    common.EscapeMySQLSingleQuote(&tableNameBuilder, tableName)
+    tableName = tableNameBuilder.String()
+
+    selectQuery := fmt.Sprintf(`
+        SELECT
+            t.table_name,
+            COALESCE(MIN(e.engine_id), 0),
+            COALESCE(MAX(e.engine_id), -1)
+        FROM %[1]s.%[4]s t
+        LEFT JOIN %[1]s.%[5]s e ON t.table_name = e.table_name
+        WHERE %[2]s = %[6]s AND t.status <= %[3]d
+        GROUP BY t.table_name;
+    `, g.schema, aliasedColName, CheckpointStatusMaxInvalid, CheckpointTableNameTable, CheckpointTableNameEngine, tableName)
+    deleteChunkQuery := fmt.Sprintf(`
+        DELETE FROM %[1]s.%[4]s WHERE table_name IN (SELECT table_name FROM %[1]s.%[5]s WHERE %[2]s = %[6]s AND status <= %[3]d)
+    `, g.schema, colName, CheckpointStatusMaxInvalid, CheckpointTableNameChunk, CheckpointTableNameTable, tableName)
+    deleteEngineQuery := fmt.Sprintf(`
+        DELETE FROM %[1]s.%[4]s WHERE table_name IN (SELECT table_name FROM %[1]s.%[5]s WHERE %[2]s = %[6]s AND status <= %[3]d)
+    `, g.schema, colName, CheckpointStatusMaxInvalid, CheckpointTableNameEngine, CheckpointTableNameTable, tableName)
+    deleteTableQuery := fmt.Sprintf(`
+        DELETE FROM %s.%s WHERE %s = %s AND status <= %d
+    `, g.schema, CheckpointTableNameTable, colName, tableName, CheckpointStatusMaxInvalid)
+
+    var targetTables []DestroyedTableCheckpoint
+    err = Transact(ctx, "destroy error checkpoints", se, logger, func(c context.Context, s Session) error {
+        // reset targetTables, because this closure may be retried
+        targetTables = nil
+        rs, err := s.Execute(c, selectQuery)
+        if err != nil {
+            return errors.Trace(err)
+        }
+        r := rs[0]
+        req := r.NewChunk()
+        it := chunk.NewIterator4Chunk(req)
+        for {
+            err = r.Next(ctx, req)
+            if err != nil {
+                r.Close()
+                return err
+            }
+            if req.NumRows() == 0 {
+                break
+            }
+
+            for row := it.Begin(); row != it.End(); row = it.Next() {
+                var dtc DestroyedTableCheckpoint
+                dtc.TableName = row.GetString(0)
+                dtc.MinEngineID = int32(row.GetInt64(1))
+                dtc.MaxEngineID = int32(row.GetInt64(2))
+                targetTables = append(targetTables, dtc)
+            }
+        }
+        r.Close()
+
+        if _, e := s.Execute(c, deleteChunkQuery); e != nil {
+            return errors.Trace(e)
+        }
+        if _, e := s.Execute(c, deleteEngineQuery); e != nil {
+            return errors.Trace(e)
+        }
+        if _, e := s.Execute(c, deleteTableQuery); e != nil {
+            return errors.Trace(e)
+        }
+        return nil
+    })
+
+    if err != nil {
+        return nil, errors.Trace(err)
+    }
+
+    return targetTables, nil
+}
+
+func (g GlueCheckpointsDB) DumpTables(ctx context.Context, csv io.Writer) error {
+    return errors.Errorf("dumping glue checkpoint into CSV is not supported")
+}
+
+func (g GlueCheckpointsDB) DumpEngines(ctx context.Context, csv io.Writer) error {
+    return errors.Errorf("dumping glue checkpoint into CSV is not supported")
+}
+
+func (g GlueCheckpointsDB) DumpChunks(ctx context.Context, csv io.Writer) error {
+    return errors.Errorf("dumping glue checkpoint into CSV is not supported")
+}
+
+func Transact(ctx context.Context, purpose string, s Session, logger log.Logger, action func(context.Context, Session) error) error {
+    return common.Retry(purpose, logger, func() error {
+        _, err := s.Execute(ctx, "BEGIN")
+        if err != nil {
+            return errors.Annotate(err, "begin transaction failed")
+        }
+        err = action(ctx, s)
+        if err != nil {
+            s.RollbackTxn(ctx)
+            return err
+        }
+        err = s.CommitTxn(ctx)
+        if err != nil {
+            return errors.Annotate(err, "commit transaction failed")
+        }
+        return nil
+    })
+}
diff --git a/lightning/common/util.go b/lightning/common/util.go
index 96f532d0b..909c26410 100644
--- a/lightning/common/util.go
+++ b/lightning/common/util.go
@@ -107,6 +107,11 @@ type SQLWithRetry struct {
 }
 
 func (t SQLWithRetry) perform(ctx context.Context, parentLogger log.Logger, purpose string, action func() error) error {
+    return Retry(purpose, parentLogger, action)
+}
+
+// Retry is shared by SQLWithRetry.perform, the GlueCheckpointsDB implementation, and TiDB's glue implementation.
+func Retry(purpose string, parentLogger log.Logger, action func() error) error {
     var err error
 outside:
     for i := 0; i < defaultMaxRetry; i++ {
@@ -264,6 +269,20 @@ func WriteMySQLIdentifier(builder *strings.Builder, identifier string) {
     builder.WriteByte('`')
 }
 
+func EscapeMySQLSingleQuote(builder *strings.Builder, s string) {
+    builder.Grow(len(s) + 2)
+    builder.WriteByte('\'')
+    for i := 0; i < len(s); i++ {
+        b := s[i]
+        if b == '\'' {
+            builder.WriteString("''")
+        } else {
+            builder.WriteByte(b)
+        }
+    }
+    builder.WriteByte('\'')
+}
+
 // GetJSON fetches a page and parses it as JSON. The parsed result will be
 // stored into the `v`. The variable `v` must be a pointer to a type that can be
 // unmarshalled from JSON.
diff --git a/lightning/glue/glue.go b/lightning/glue/glue.go
index 56f95b084..e4e045844 100644
--- a/lightning/glue/glue.go
+++ b/lightning/glue/glue.go
@@ -16,6 +16,7 @@ package glue
 import (
     "context"
     "database/sql"
+    "errors"
 
     "github.com/pingcap/parser"
     "github.com/pingcap/parser/model"
@@ -24,7 +25,6 @@ import (
     "github.com/pingcap/tidb-lightning/lightning/common"
     "github.com/pingcap/tidb-lightning/lightning/config"
     "github.com/pingcap/tidb-lightning/lightning/log"
-    "go.uber.org/zap"
 )
 
 type Glue interface {
@@ -33,14 +33,17 @@ type Glue interface {
     GetDB() (*sql.DB, error)
     GetParser() *parser.Parser
     GetTables(context.Context, string) ([]*model.TableInfo, error)
+    GetSession() (checkpoints.Session, error)
     OpenCheckpointsDB(context.Context, *config.Config) (checkpoints.CheckpointsDB, error)
     // Record is used to report some information (key, value) to host TiDB, including progress, stage currently
     Record(string, uint64)
 }
 
 type SQLExecutor interface {
-    ExecuteWithLog(ctx context.Context, query string, purpose string, logger *zap.Logger) error
-    ObtainStringWithLog(ctx context.Context, query string, purpose string, logger *zap.Logger) (string, error)
+    // ExecuteWithLog and ObtainStringWithLog must support concurrent calls; different calls are not guaranteed to
+    // go to the same underlying connection.
+    ExecuteWithLog(ctx context.Context, query string, purpose string, logger log.Logger) error
+    ObtainStringWithLog(ctx context.Context, query string, purpose string, logger log.Logger) (string, error)
     Close()
 }
 
@@ -60,19 +63,19 @@ func (e *ExternalTiDBGlue) GetSQLExecutor() SQLExecutor {
     return e
 }
 
-func (e *ExternalTiDBGlue) ExecuteWithLog(ctx context.Context, query string, purpose string, logger *zap.Logger) error {
+func (e *ExternalTiDBGlue) ExecuteWithLog(ctx context.Context, query string, purpose string, logger log.Logger) error {
     sql := common.SQLWithRetry{
         DB:     e.db,
-        Logger: log.Logger{Logger: logger},
+        Logger: logger,
     }
     return sql.Exec(ctx, purpose, query)
 }
 
-func (e *ExternalTiDBGlue) ObtainStringWithLog(ctx context.Context, query string, purpose string, logger *zap.Logger) (string, error) {
+func (e *ExternalTiDBGlue) ObtainStringWithLog(ctx context.Context, query string, purpose string, logger log.Logger) (string, error) {
     var s string
     err := common.SQLWithRetry{
         DB:     e.db,
-        Logger: 
log.Logger{Logger: logger}, + Logger: logger, }.QueryRow(ctx, purpose, query, &s) return s, err } @@ -85,8 +88,12 @@ func (e *ExternalTiDBGlue) GetParser() *parser.Parser { return e.parser } -func (e *ExternalTiDBGlue) GetTables(context.Context, string) ([]*model.TableInfo, error) { - return nil, nil +func (e ExternalTiDBGlue) GetTables(context.Context, string) ([]*model.TableInfo, error) { + return nil, errors.New("ExternalTiDBGlue doesn't have a valid GetTables function") +} + +func (e ExternalTiDBGlue) GetSession() (checkpoints.Session, error) { + return nil, errors.New("ExternalTiDBGlue doesn't have a valid GetSession function") } func (e *ExternalTiDBGlue) OpenCheckpointsDB(ctx context.Context, cfg *config.Config) (checkpoints.CheckpointsDB, error) { diff --git a/lightning/restore/restore.go b/lightning/restore/restore.go index 22c629839..011d7c4b7 100644 --- a/lightning/restore/restore.go +++ b/lightning/restore/restore.go @@ -209,7 +209,7 @@ func NewRestoreControllerWithPauser( case config.BackendLocal: backend, err = kv.NewLocalBackend(ctx, tls, cfg.TiDB.PdAddr, int64(cfg.TikvImporter.RegionSplitSize), cfg.TikvImporter.SortedKVDir, cfg.TikvImporter.RangeConcurrency, cfg.TikvImporter.SendKVPairs, - cfg.Checkpoint.Enable) + cfg.Checkpoint.Enable, g) if err != nil { return nil, err } @@ -294,15 +294,14 @@ outside: } func (rc *RestoreController) restoreSchema(ctx context.Context) error { - tidbMgr, err := NewTiDBManager(rc.cfg.TiDB, rc.tls) - if err != nil { - return errors.Trace(err) - } - defer tidbMgr.Close() - if !rc.cfg.Mydumper.NoSchema { if rc.tidbGlue.OwnsSQLExecutor() { - tidbMgr.db.ExecContext(ctx, "SET SQL_MODE = ?", rc.cfg.TiDB.StrSQLMode) + db, err := DBFromConfig(rc.cfg.TiDB) + if err != nil { + return errors.Trace(err) + } + defer db.Close() + db.ExecContext(ctx, "SET SQL_MODE = ?", rc.cfg.TiDB.StrSQLMode) } for _, dbMeta := range rc.dbMetas { @@ -312,7 +311,7 @@ func (rc *RestoreController) restoreSchema(ctx context.Context) error { for _, tblMeta := range dbMeta.Tables { tablesSchema[tblMeta.Name] = tblMeta.GetSchema(ctx, rc.store) } - err = tidbMgr.InitSchema(ctx, rc.tidbGlue, dbMeta.Name, tablesSchema) + err := InitSchema(ctx, rc.tidbGlue, dbMeta.Name, tablesSchema) task.End(zap.ErrorLevel, err) if err != nil { @@ -324,7 +323,7 @@ func (rc *RestoreController) restoreSchema(ctx context.Context) error { if !rc.tidbGlue.OwnsSQLExecutor() { getTableFunc = rc.tidbGlue.GetTables } - dbInfos, err := tidbMgr.LoadSchemaInfo(ctx, rc.dbMetas, getTableFunc) + dbInfos, err := LoadSchemaInfo(ctx, rc.dbMetas, getTableFunc) if err != nil { return errors.Trace(err) } @@ -1632,7 +1631,7 @@ func (tr *TableRestore) compareChecksum(ctx context.Context, localChecksum verif func (tr *TableRestore) analyzeTable(ctx context.Context, g glue.SQLExecutor) error { task := tr.logger.Begin(zap.InfoLevel, "analyze") - err := g.ExecuteWithLog(ctx, "ANALYZE TABLE "+tr.tableName, "analyze table", tr.logger.Logger) + err := g.ExecuteWithLog(ctx, "ANALYZE TABLE "+tr.tableName, "analyze table", tr.logger) task.End(zap.ErrorLevel, err) return err } diff --git a/lightning/restore/tidb.go b/lightning/restore/tidb.go index cdc390af8..4d7517863 100644 --- a/lightning/restore/tidb.go +++ b/lightning/restore/tidb.go @@ -124,21 +124,14 @@ func (timgr *TiDBManager) Close() { timgr.db.Close() } -func (timgr *TiDBManager) InitSchema(ctx context.Context, g glue.Glue, database string, tablesSchema map[string]string) error { +func InitSchema(ctx context.Context, g glue.Glue, database string, tablesSchema 
map[string]string) error { logger := log.With(zap.String("db", database)) sqlExecutor := g.GetSQLExecutor() var createDatabase strings.Builder createDatabase.WriteString("CREATE DATABASE IF NOT EXISTS ") common.WriteMySQLIdentifier(&createDatabase, database) - err := sqlExecutor.ExecuteWithLog(ctx, createDatabase.String(), "create database", logger.Logger) - if err != nil { - return errors.Trace(err) - } - var useDB strings.Builder - useDB.WriteString("USE ") - common.WriteMySQLIdentifier(&useDB, database) - err = sqlExecutor.ExecuteWithLog(ctx, useDB.String(), "use database", logger.Logger) + err := sqlExecutor.ExecuteWithLog(ctx, createDatabase.String(), "create database", logger) if err != nil { return errors.Trace(err) } @@ -147,7 +140,7 @@ func (timgr *TiDBManager) InitSchema(ctx context.Context, g glue.Glue, database for tbl, sqlCreateTable := range tablesSchema { task.Debug("create table", zap.String("schema", sqlCreateTable)) - sqlCreateTable, err = timgr.createTableIfNotExistsStmt(g.GetParser(), sqlCreateTable, tbl) + sqlCreateTable, err = createTableIfNotExistsStmt(g.GetParser(), sqlCreateTable, database, tbl) if err != nil { break } @@ -156,7 +149,7 @@ func (timgr *TiDBManager) InitSchema(ctx context.Context, g glue.Glue, database ctx, sqlCreateTable, "create table", - logger.Logger.With(zap.String("table", common.UniqueTable(database, tbl))), + logger.With(zap.String("table", common.UniqueTable(database, tbl))), ) if err != nil { break @@ -167,7 +160,7 @@ func (timgr *TiDBManager) InitSchema(ctx context.Context, g glue.Glue, database return errors.Trace(err) } -func (timgr *TiDBManager) createTableIfNotExistsStmt(p *parser.Parser, createTable, tblName string) (string, error) { +func createTableIfNotExistsStmt(p *parser.Parser, createTable, dbName, tblName string) (string, error) { stmts, _, err := p.Parse(createTable, "", "") if err != nil { return "", err @@ -179,7 +172,7 @@ func (timgr *TiDBManager) createTableIfNotExistsStmt(p *parser.Parser, createTab for _, stmt := range stmts { if createTableNode, ok := stmt.(*ast.CreateTableStmt); ok { - createTableNode.Table.Schema = model.NewCIStr("") + createTableNode.Table.Schema = model.NewCIStr(dbName) createTableNode.Table.Name = model.NewCIStr(tblName) createTableNode.IfNotExists = true } @@ -200,7 +193,7 @@ func (timgr *TiDBManager) DropTable(ctx context.Context, tableName string) error return sql.Exec(ctx, "drop table", "DROP TABLE "+tableName) } -func (timgr *TiDBManager) LoadSchemaInfo( +func LoadSchemaInfo( ctx context.Context, schemas []*mydump.MDDatabaseMeta, getTables func(context.Context, string) ([]*model.TableInfo, error), @@ -269,7 +262,7 @@ func ObtainRowFormatVersion(ctx context.Context, g glue.SQLExecutor) string { ctx, "SELECT @@tidb_row_format_version", "obtain row format version", - log.L().Logger, + log.L(), ) if err != nil { rowFormatVersion = "1" @@ -283,7 +276,7 @@ func ObtainNewCollationEnabled(ctx context.Context, g glue.SQLExecutor) bool { ctx, "SELECT variable_value FROM mysql.tidb WHERE variable_name = 'new_collation_enabled'", "obtain new collation enabled", - log.L().Logger, + log.L(), ) if err == nil && newCollationVal == "True" { newCollationEnabled = true @@ -301,7 +294,7 @@ func AlterAutoIncrement(ctx context.Context, g glue.SQLExecutor, tableName strin logger := log.With(zap.String("table", tableName), zap.Int64("auto_increment", incr)) query := fmt.Sprintf("ALTER TABLE %s AUTO_INCREMENT=%d", tableName, incr) task := logger.Begin(zap.InfoLevel, "alter table auto_increment") - err := 
g.ExecuteWithLog(ctx, query, "alter table auto_increment", logger.Logger) + err := g.ExecuteWithLog(ctx, query, "alter table auto_increment", logger) task.End(zap.ErrorLevel, err) if err != nil { task.Error( @@ -316,7 +309,7 @@ func AlterAutoRandom(ctx context.Context, g glue.SQLExecutor, tableName string, logger := log.With(zap.String("table", tableName), zap.Int64("auto_random", randomBase)) query := fmt.Sprintf("ALTER TABLE %s AUTO_RANDOM_BASE=%d", tableName, randomBase) task := logger.Begin(zap.InfoLevel, "alter table auto_random") - err := g.ExecuteWithLog(ctx, query, "alter table auto_random_base", logger.Logger) + err := g.ExecuteWithLog(ctx, query, "alter table auto_random_base", logger) task.End(zap.ErrorLevel, err) if err != nil { task.Error( diff --git a/lightning/restore/tidb_test.go b/lightning/restore/tidb_test.go index a7bf2cab3..bfe3b74b9 100644 --- a/lightning/restore/tidb_test.go +++ b/lightning/restore/tidb_test.go @@ -64,8 +64,9 @@ func (s *tidbSuite) TearDownTest(c *C) { } func (s *tidbSuite) TestCreateTableIfNotExistsStmt(c *C) { + dbName := "testdb" createTableIfNotExistsStmt := func(createTable, tableName string) string { - res, err := s.timgr.createTableIfNotExistsStmt(s.tiGlue.GetParser(), createTable, tableName) + res, err := createTableIfNotExistsStmt(s.tiGlue.GetParser(), createTable, dbName, tableName) c.Assert(err, IsNil) return res } @@ -73,61 +74,61 @@ func (s *tidbSuite) TestCreateTableIfNotExistsStmt(c *C) { c.Assert( createTableIfNotExistsStmt("CREATE TABLE `foo`(`bar` TINYINT(1));", "foo"), Equals, - "CREATE TABLE IF NOT EXISTS `foo` (`bar` TINYINT(1));", + "CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` TINYINT(1));", ) c.Assert( createTableIfNotExistsStmt("CREATE TABLE IF NOT EXISTS `foo`(`bar` TINYINT(1));", "foo"), Equals, - "CREATE TABLE IF NOT EXISTS `foo` (`bar` TINYINT(1));", + "CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` TINYINT(1));", ) // case insensitive c.Assert( createTableIfNotExistsStmt("/* cOmmEnt */ creAte tablE `fOo`(`bar` TinyinT(1));", "fOo"), Equals, - "CREATE TABLE IF NOT EXISTS `fOo` (`bar` TINYINT(1));", + "CREATE TABLE IF NOT EXISTS `testdb`.`fOo` (`bar` TINYINT(1));", ) c.Assert( createTableIfNotExistsStmt("/* coMMenT */ crEatE tAble If not EXISts `FoO`(`bAR` tiNyInT(1));", "FoO"), Equals, - "CREATE TABLE IF NOT EXISTS `FoO` (`bAR` TINYINT(1));", + "CREATE TABLE IF NOT EXISTS `testdb`.`FoO` (`bAR` TINYINT(1));", ) // only one "CREATE TABLE" is replaced c.Assert( createTableIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) COMMENT 'CREATE TABLE');", "foo"), Equals, - "CREATE TABLE IF NOT EXISTS `foo` (`bar` INT(1) COMMENT 'CREATE TABLE');", + "CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) COMMENT 'CREATE TABLE');", ) // upper case becomes shorter c.Assert( createTableIfNotExistsStmt("CREATE TABLE `ſ`(`ı` TINYINT(1));", "ſ"), Equals, - "CREATE TABLE IF NOT EXISTS `ſ` (`ı` TINYINT(1));", + "CREATE TABLE IF NOT EXISTS `testdb`.`ſ` (`ı` TINYINT(1));", ) // upper case becomes longer c.Assert( createTableIfNotExistsStmt("CREATE TABLE `ɑ`(`ȿ` TINYINT(1));", "ɑ"), Equals, - "CREATE TABLE IF NOT EXISTS `ɑ` (`ȿ` TINYINT(1));", + "CREATE TABLE IF NOT EXISTS `testdb`.`ɑ` (`ȿ` TINYINT(1));", ) // non-utf-8 c.Assert( createTableIfNotExistsStmt("CREATE TABLE `\xcc\xcc\xcc`(`\xdd\xdd\xdd` TINYINT(1));", "\xcc\xcc\xcc"), Equals, - "CREATE TABLE IF NOT EXISTS `\xcc\xcc\xcc` (`ÝÝÝ` TINYINT(1));", + "CREATE TABLE IF NOT EXISTS `testdb`.`\xcc\xcc\xcc` (`ÝÝÝ` TINYINT(1));", ) // renaming a table c.Assert( 
createTableIfNotExistsStmt("create table foo(x int);", "ba`r"), Equals, - "CREATE TABLE IF NOT EXISTS `ba``r` (`x` INT);", + "CREATE TABLE IF NOT EXISTS `testdb`.`ba``r` (`x` INT);", ) // conditional comments @@ -138,7 +139,7 @@ func (s *tidbSuite) TestCreateTableIfNotExistsStmt(c *C) { CREATE TABLE x.y (z double) ENGINE=InnoDB AUTO_INCREMENT=8343230 DEFAULT CHARSET=utf8; `, "m"), Equals, - "SET NAMES 'binary';SET @@SESSION.`FOREIGN_KEY_CHECKS`=0;CREATE TABLE IF NOT EXISTS `m` (`z` DOUBLE) ENGINE = InnoDB AUTO_INCREMENT = 8343230 DEFAULT CHARACTER SET = UTF8;", + "SET NAMES 'binary';SET @@SESSION.`FOREIGN_KEY_CHECKS`=0;CREATE TABLE IF NOT EXISTS `testdb`.`m` (`z` DOUBLE) ENGINE = InnoDB AUTO_INCREMENT = 8343230 DEFAULT CHARACTER SET = UTF8;", ) } @@ -149,19 +150,16 @@ func (s *tidbSuite) TestInitSchema(c *C) { ExpectExec("CREATE DATABASE IF NOT EXISTS `db`"). WillReturnResult(sqlmock.NewResult(1, 1)) s.mockDB. - ExpectExec("USE `db`"). - WillReturnResult(sqlmock.NewResult(0, 0)) - s.mockDB. - ExpectExec("\\QCREATE TABLE IF NOT EXISTS `t1` (`a` INT PRIMARY KEY,`b` VARCHAR(200));\\E"). + ExpectExec("\\QCREATE TABLE IF NOT EXISTS `db`.`t1` (`a` INT PRIMARY KEY,`b` VARCHAR(200));\\E"). WillReturnResult(sqlmock.NewResult(2, 1)) s.mockDB. - ExpectExec("\\QSET @@SESSION.`FOREIGN_KEY_CHECKS`=0;CREATE TABLE IF NOT EXISTS `t2` (`xx` TEXT) AUTO_INCREMENT = 11203;\\E"). + ExpectExec("\\QSET @@SESSION.`FOREIGN_KEY_CHECKS`=0;CREATE TABLE IF NOT EXISTS `db`.`t2` (`xx` TEXT) AUTO_INCREMENT = 11203;\\E"). WillReturnResult(sqlmock.NewResult(2, 1)) s.mockDB. ExpectClose() s.mockDB.MatchExpectationsInOrder(false) // maps are unordered. - err := s.timgr.InitSchema(ctx, s.tiGlue, "db", map[string]string{ + err := InitSchema(ctx, s.tiGlue, "db", map[string]string{ "t1": "create table t1 (a int primary key, b varchar(200));", "t2": "/*!40014 SET FOREIGN_KEY_CHECKS=0*/;CREATE TABLE `db`.`t2` (xx TEXT) AUTO_INCREMENT=11203;", }) @@ -175,13 +173,10 @@ func (s *tidbSuite) TestInitSchemaSyntaxError(c *C) { s.mockDB. ExpectExec("CREATE DATABASE IF NOT EXISTS `db`"). WillReturnResult(sqlmock.NewResult(1, 1)) - s.mockDB. - ExpectExec("USE `db`"). - WillReturnResult(sqlmock.NewResult(0, 0)) s.mockDB. ExpectClose() - err := s.timgr.InitSchema(ctx, s.tiGlue, "db", map[string]string{ + err := InitSchema(ctx, s.tiGlue, "db", map[string]string{ "t1": "create table `t1` with invalid syntax;", }) c.Assert(err, NotNil) @@ -194,10 +189,7 @@ func (s *tidbSuite) TestInitSchemaUnsupportedSchemaError(c *C) { ExpectExec("CREATE DATABASE IF NOT EXISTS `db`"). WillReturnResult(sqlmock.NewResult(1, 1)) s.mockDB. - ExpectExec("USE `db`"). - WillReturnResult(sqlmock.NewResult(0, 0)) - s.mockDB. - ExpectExec("CREATE TABLE IF NOT EXISTS `t1`.*"). + ExpectExec("CREATE TABLE IF NOT EXISTS `db`.`t1`.*"). WillReturnError(&mysql.MySQLError{ Number: tmysql.ErrTooBigFieldlength, Message: "Column length too big", @@ -205,7 +197,7 @@ func (s *tidbSuite) TestInitSchemaUnsupportedSchemaError(c *C) { s.mockDB. 
ExpectClose() - err := s.timgr.InitSchema(ctx, s.tiGlue, "db", map[string]string{ + err := InitSchema(ctx, s.tiGlue, "db", map[string]string{ "t1": "create table `t1` (a VARCHAR(999999999));", }) c.Assert(err, ErrorMatches, ".*Column length too big.*") @@ -243,7 +235,7 @@ func (s *tidbSuite) TestLoadSchemaInfo(c *C) { tableInfos = append(tableInfos, info) } - loaded, err := s.timgr.LoadSchemaInfo(ctx, []*mydump.MDDatabaseMeta{{Name: "db"}}, func(ctx context.Context, schema string) ([]*model.TableInfo, error) { + loaded, err := LoadSchemaInfo(ctx, []*mydump.MDDatabaseMeta{{Name: "db"}}, func(ctx context.Context, schema string) ([]*model.TableInfo, error) { c.Assert(schema, Equals, "db") return tableInfos, nil }) @@ -272,7 +264,7 @@ func (s *tidbSuite) TestLoadSchemaInfo(c *C) { func (s *tidbSuite) TestLoadSchemaInfoMissing(c *C) { ctx := context.Background() - _, err := s.timgr.LoadSchemaInfo(ctx, []*mydump.MDDatabaseMeta{{Name: "asdjalsjdlas"}}, func(ctx context.Context, schema string) ([]*model.TableInfo, error) { + _, err := LoadSchemaInfo(ctx, []*mydump.MDDatabaseMeta{{Name: "asdjalsjdlas"}}, func(ctx context.Context, schema string) ([]*model.TableInfo, error) { return nil, errors.Errorf("[schema:1049]Unknown database '%s'", schema) }) c.Assert(err, ErrorMatches, ".*Unknown database.*")
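
Note on the checkTiDBVersionBySQL change at the top of this diff: the version is now obtained through the glue SQLExecutor rather than the status port, and `SELECT version();` returns a MySQL-compatible string such as `5.7.25-TiDB-v4.0.4`, from which the TiDB semver is extracted. A minimal standalone sketch of that extraction step; extractTiDBVersion here is a simplified, hypothetical stand-in for common.ExtractTiDBVersion, not the real implementation:

package main

import (
    "fmt"
    "strings"

    "github.com/coreos/go-semver/semver"
)

// extractTiDBVersion is a simplified stand-in for common.ExtractTiDBVersion:
// take whatever follows the last "-v" in a "SELECT version()" result such as
// "5.7.25-TiDB-v4.0.4" and parse it as a semver.
func extractTiDBVersion(versionStr string) (*semver.Version, error) {
    idx := strings.LastIndex(versionStr, "-v")
    if idx < 0 {
        return nil, fmt.Errorf("not a TiDB version: %s", versionStr)
    }
    return semver.NewVersion(versionStr[idx+2:])
}

func main() {
    v, err := extractTiDBVersion("5.7.25-TiDB-v4.0.4")
    fmt.Println(v, err) // 4.0.4 <nil>
}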
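Because Session.Execute takes no bind arguments, GlueCheckpointsDB inlines the escaped table name in place of each `?` placeholder (see Get and RemoveCheckpoint above) using the new common.EscapeMySQLSingleQuote helper. A self-contained sketch of that escape-then-inline step; the helper body is copied from the common/util.go hunk above, while the schema and table names are made up for illustration:

package main

import (
    "fmt"
    "strings"
)

// Copy of the EscapeMySQLSingleQuote helper added in lightning/common/util.go,
// so this sketch compiles on its own: single quotes are doubled, the MySQL way
// of escaping them inside a quoted string literal.
func EscapeMySQLSingleQuote(builder *strings.Builder, s string) {
    builder.Grow(len(s) + 2)
    builder.WriteByte('\'')
    for i := 0; i < len(s); i++ {
        if s[i] == '\'' {
            builder.WriteString("''")
        } else {
            builder.WriteByte(s[i])
        }
    }
    builder.WriteByte('\'')
}

func main() {
    // "`db`.`it's`" is a hypothetical table name containing a single quote.
    var b strings.Builder
    EscapeMySQLSingleQuote(&b, "`db`.`it's`")

    // GlueCheckpointsDB substitutes the escaped literal for the `?`
    // placeholder, since Session.Execute cannot bind parameters.
    query := strings.ReplaceAll(
        "DELETE FROM `lightning`.`chunk_v5` WHERE table_name = ?;", "?", b.String())
    fmt.Println(query)
    // Prints: DELETE FROM `lightning`.`chunk_v5` WHERE table_name = '`db`.`it''s`';
}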
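The new Transact helper re-runs the whole BEGIN / action / CommitTxn sequence under common.Retry, so the action closure must be safe to re-run from scratch; that is why DestroyErrorCheckpoint resets targetTables at the top of its closure. A rough sketch of that contract, assuming a hypothetical fakeSession and a simplified retry loop in place of the real Session interface and common.Retry:

package main

import (
    "context"
    "errors"
    "fmt"
)

// fakeSession is a hypothetical stand-in exposing just the methods the
// Transact helper uses; it fails the first commit to force one retry.
type fakeSession struct{ commits int }

func (s *fakeSession) Execute(ctx context.Context, sql string) error {
    fmt.Println("execute:", sql)
    return nil
}

func (s *fakeSession) CommitTxn(ctx context.Context) error {
    s.commits++
    if s.commits == 1 {
        return errors.New("injected commit failure")
    }
    return nil
}

func (s *fakeSession) RollbackTxn(ctx context.Context) { fmt.Println("rollback") }

// retry stands in for common.Retry: re-run the action a few times on error.
func retry(action func() error) (err error) {
    for i := 0; i < 3; i++ {
        if err = action(); err == nil {
            return nil
        }
    }
    return err
}

// transact mirrors the shape of the new Transact helper: BEGIN, run the
// closure, roll back on failure, commit on success, all inside the retry
// loop. Because the whole sequence may re-run, the closure must reset any
// state it accumulates.
func transact(ctx context.Context, s *fakeSession, action func() error) error {
    return retry(func() error {
        if err := s.Execute(ctx, "BEGIN"); err != nil {
            return err
        }
        if err := action(); err != nil {
            s.RollbackTxn(ctx)
            return err
        }
        return s.CommitTxn(ctx)
    })
}

func main() {
    var rows []string
    err := transact(context.Background(), &fakeSession{}, func() error {
        rows = nil // reset: this closure can run more than once
        rows = append(rows, "`db1`.`t1`")
        return nil
    })
    fmt.Println(err, rows) // <nil> [`db1`.`t1`] after one retried commit
}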