From 5fee75661a4c9396d0f5e16b1376d9213c9d7285 Mon Sep 17 00:00:00 2001 From: lance6716 Date: Wed, 17 Jan 2024 16:34:47 +0800 Subject: [PATCH] lightning: use SprintfWithIdentifier to process strings with identifier (#50434) ref pingcap/tidb#30699 --- br/pkg/lightning/checkpoints/BUILD.bazel | 5 - br/pkg/lightning/checkpoints/checkpoints.go | 181 ++-- .../checkpoints/checkpoints_sql_test.go | 92 +- .../lightning/checkpoints/glue_checkpoint.go | 826 ------------------ br/pkg/lightning/common/util.go | 26 +- br/pkg/lightning/errormanager/errormanager.go | 37 +- .../errormanager/errormanager_test.go | 2 +- br/pkg/lightning/importer/get_pre_info.go | 2 +- br/pkg/lightning/importer/import.go | 4 +- br/pkg/lightning/importer/meta_manager.go | 87 +- .../lightning/importer/meta_manager_test.go | 22 +- br/pkg/lightning/importer/table_import.go | 3 +- br/pkg/lightning/mydump/loader.go | 2 +- pkg/disttask/importinto/scheduler.go | 3 +- pkg/executor/importer/precheck.go | 2 +- 15 files changed, 264 insertions(+), 1030 deletions(-) delete mode 100644 br/pkg/lightning/checkpoints/glue_checkpoint.go diff --git a/br/pkg/lightning/checkpoints/BUILD.bazel b/br/pkg/lightning/checkpoints/BUILD.bazel index ed3c69ff04935..41506474aee82 100644 --- a/br/pkg/lightning/checkpoints/BUILD.bazel +++ b/br/pkg/lightning/checkpoints/BUILD.bazel @@ -4,7 +4,6 @@ go_library( name = "checkpoints", srcs = [ "checkpoints.go", - "glue_checkpoint.go", "tidb.go", ], importpath = "github.com/pingcap/tidb/br/pkg/lightning/checkpoints", @@ -18,11 +17,7 @@ go_library( "//br/pkg/lightning/verification", "//br/pkg/storage", "//br/pkg/version/build", - "//pkg/parser/ast", "//pkg/parser/model", - "//pkg/types", - "//pkg/util/chunk", - "//pkg/util/sqlexec", "@com_github_joho_sqltocsv//:sqltocsv", "@com_github_pingcap_errors//:errors", "@org_uber_go_zap//:zap", diff --git a/br/pkg/lightning/checkpoints/checkpoints.go b/br/pkg/lightning/checkpoints/checkpoints.go index aebbadc275190..65ea720a1f455 100644 --- 
a/br/pkg/lightning/checkpoints/checkpoints.go +++ b/br/pkg/lightning/checkpoints/checkpoints.go @@ -79,7 +79,6 @@ const ( const ( // Some frequently used table name or constants. allTables = "all" - stringLitAll = "'all'" columnTableName = "table_name" ) @@ -730,40 +729,39 @@ type MySQLCheckpointsDB struct { // NewMySQLCheckpointsDB creates a new MySQLCheckpointsDB. func NewMySQLCheckpointsDB(ctx context.Context, db *sql.DB, schemaName string) (*MySQLCheckpointsDB, error) { - schema := common.EscapeIdentifier(schemaName) sql := common.SQLWithRetry{ DB: db, Logger: log.FromContext(ctx).With(zap.String("schema", schemaName)), HideQueryLog: true, } - err := sql.Exec(ctx, "create checkpoints database", fmt.Sprintf(CreateDBTemplate, schema)) + err := sql.Exec(ctx, "create checkpoints database", common.SprintfWithIdentifiers(CreateDBTemplate, schemaName)) if err != nil { return nil, errors.Trace(err) } - err = sql.Exec(ctx, "create task checkpoints table", fmt.Sprintf(CreateTaskTableTemplate, schema, CheckpointTableNameTask)) + err = sql.Exec(ctx, "create task checkpoints table", common.SprintfWithIdentifiers(CreateTaskTableTemplate, schemaName, CheckpointTableNameTask)) if err != nil { return nil, errors.Trace(err) } - err = sql.Exec(ctx, "create table checkpoints table", fmt.Sprintf(CreateTableTableTemplate, schema, CheckpointTableNameTable)) + err = sql.Exec(ctx, "create table checkpoints table", common.SprintfWithIdentifiers(CreateTableTableTemplate, schemaName, CheckpointTableNameTable)) if err != nil { return nil, errors.Trace(err) } - err = sql.Exec(ctx, "create engine checkpoints table", fmt.Sprintf(CreateEngineTableTemplate, schema, CheckpointTableNameEngine)) + err = sql.Exec(ctx, "create engine checkpoints table", common.SprintfWithIdentifiers(CreateEngineTableTemplate, schemaName, CheckpointTableNameEngine)) if err != nil { return nil, errors.Trace(err) } - err = sql.Exec(ctx, "create chunks checkpoints table", fmt.Sprintf(CreateChunkTableTemplate, schema, 
CheckpointTableNameChunk)) + err = sql.Exec(ctx, "create chunks checkpoints table", common.SprintfWithIdentifiers(CreateChunkTableTemplate, schemaName, CheckpointTableNameChunk)) if err != nil { return nil, errors.Trace(err) } return &MySQLCheckpointsDB{ db: db, - schema: schema, + schema: schemaName, }, nil } @@ -773,7 +771,7 @@ func (cpdb *MySQLCheckpointsDB) Initialize(ctx context.Context, cfg *config.Conf // Since this step is not performance critical, we just insert the rows one-by-one. s := common.SQLWithRetry{DB: cpdb.db, Logger: log.FromContext(ctx)} err := s.Transact(ctx, "insert checkpoints", func(c context.Context, tx *sql.Tx) error { - taskStmt, err := tx.PrepareContext(c, fmt.Sprintf(InitTaskTemplate, cpdb.schema, CheckpointTableNameTask)) + taskStmt, err := tx.PrepareContext(c, common.SprintfWithIdentifiers(InitTaskTemplate, cpdb.schema, CheckpointTableNameTask)) if err != nil { return errors.Trace(err) } @@ -792,7 +790,7 @@ func (cpdb *MySQLCheckpointsDB) Initialize(ctx context.Context, cfg *config.Conf // statement to fail with an irrecoverable error. // We do need to capture the error is display a user friendly message // (multiple nodes cannot import the same table) though. 
- stmt, err := tx.PrepareContext(c, fmt.Sprintf(InitTableTemplate, cpdb.schema, CheckpointTableNameTable)) + stmt, err := tx.PrepareContext(c, common.SprintfWithIdentifiers(InitTableTemplate, cpdb.schema, CheckpointTableNameTable)) if err != nil { return errors.Trace(err) } @@ -829,7 +827,7 @@ func (cpdb *MySQLCheckpointsDB) TaskCheckpoint(ctx context.Context) (*TaskCheckp Logger: log.FromContext(ctx), } - taskQuery := fmt.Sprintf(ReadTaskTemplate, cpdb.schema, CheckpointTableNameTask) + taskQuery := common.SprintfWithIdentifiers(ReadTaskTemplate, cpdb.schema, CheckpointTableNameTask) taskCp := &TaskCheckpoint{} err := s.QueryRow(ctx, "fetch task checkpoint", taskQuery, &taskCp.TaskID, &taskCp.SourceDir, &taskCp.Backend, &taskCp.ImporterAddr, &taskCp.TiDBHost, &taskCp.TiDBPort, &taskCp.PdAddr, &taskCp.SortedKVDir, &taskCp.LightningVer) @@ -862,7 +860,7 @@ func (cpdb *MySQLCheckpointsDB) Get(ctx context.Context, tableName string) (*Tab err := s.Transact(ctx, "read checkpoint", func(c context.Context, tx *sql.Tx) error { // 1. Populate the engines. - engineQuery := fmt.Sprintf(ReadEngineTemplate, cpdb.schema, CheckpointTableNameEngine) + engineQuery := common.SprintfWithIdentifiers(ReadEngineTemplate, cpdb.schema, CheckpointTableNameEngine) engineRows, err := tx.QueryContext(c, engineQuery, tableName) if err != nil { return errors.Trace(err) @@ -887,7 +885,7 @@ func (cpdb *MySQLCheckpointsDB) Get(ctx context.Context, tableName string) (*Tab // 2. Populate the chunks. - chunkQuery := fmt.Sprintf(ReadChunkTemplate, cpdb.schema, CheckpointTableNameChunk) + chunkQuery := common.SprintfWithIdentifiers(ReadChunkTemplate, cpdb.schema, CheckpointTableNameChunk) chunkRows, err := tx.QueryContext(c, chunkQuery, tableName) if err != nil { return errors.Trace(err) @@ -924,7 +922,7 @@ func (cpdb *MySQLCheckpointsDB) Get(ctx context.Context, tableName string) (*Tab // 3. 
Fill in the remaining table info - tableQuery := fmt.Sprintf(ReadTableRemainTemplate, cpdb.schema, CheckpointTableNameTable) + tableQuery := common.SprintfWithIdentifiers(ReadTableRemainTemplate, cpdb.schema, CheckpointTableNameTable) tableRow := tx.QueryRowContext(c, tableQuery, tableName) var status uint8 @@ -956,14 +954,14 @@ func (cpdb *MySQLCheckpointsDB) InsertEngineCheckpoints(ctx context.Context, tab Logger: log.FromContext(ctx).With(zap.String("table", tableName)), } err := s.Transact(ctx, "update engine checkpoints", func(c context.Context, tx *sql.Tx) error { - engineStmt, err := tx.PrepareContext(c, fmt.Sprintf(ReplaceEngineTemplate, cpdb.schema, CheckpointTableNameEngine)) + engineStmt, err := tx.PrepareContext(c, common.SprintfWithIdentifiers(ReplaceEngineTemplate, cpdb.schema, CheckpointTableNameEngine)) if err != nil { return errors.Trace(err) } //nolint: errcheck defer engineStmt.Close() - chunkStmt, err := tx.PrepareContext(c, fmt.Sprintf(ReplaceChunkTemplate, cpdb.schema, CheckpointTableNameChunk)) + chunkStmt, err := tx.PrepareContext(c, common.SprintfWithIdentifiers(ReplaceChunkTemplate, cpdb.schema, CheckpointTableNameChunk)) if err != nil { return errors.Trace(err) } @@ -1003,11 +1001,11 @@ func (cpdb *MySQLCheckpointsDB) InsertEngineCheckpoints(ctx context.Context, tab // Update implements the DB interface. 
func (cpdb *MySQLCheckpointsDB) Update(taskCtx context.Context, checkpointDiffs map[string]*TableCheckpointDiff) error { - chunkQuery := fmt.Sprintf(UpdateChunkTemplate, cpdb.schema, CheckpointTableNameChunk) - rebaseQuery := fmt.Sprintf(UpdateTableRebaseTemplate, cpdb.schema, CheckpointTableNameTable) - tableStatusQuery := fmt.Sprintf(UpdateTableStatusTemplate, cpdb.schema, CheckpointTableNameTable) - tableChecksumQuery := fmt.Sprintf(UpdateTableChecksumTemplate, cpdb.schema, CheckpointTableNameTable) - engineStatusQuery := fmt.Sprintf(UpdateEngineTemplate, cpdb.schema, CheckpointTableNameEngine) + chunkQuery := common.SprintfWithIdentifiers(UpdateChunkTemplate, cpdb.schema, CheckpointTableNameChunk) + rebaseQuery := common.SprintfWithIdentifiers(UpdateTableRebaseTemplate, cpdb.schema, CheckpointTableNameTable) + tableStatusQuery := common.SprintfWithIdentifiers(UpdateTableStatusTemplate, cpdb.schema, CheckpointTableNameTable) + tableChecksumQuery := common.SprintfWithIdentifiers(UpdateTableChecksumTemplate, cpdb.schema, CheckpointTableNameTable) + engineStatusQuery := common.SprintfWithIdentifiers(UpdateEngineTemplate, cpdb.schema, CheckpointTableNameEngine) s := common.SQLWithRetry{DB: cpdb.db, Logger: log.FromContext(taskCtx)} return s.Transact(taskCtx, "update checkpoints", func(c context.Context, tx *sql.Tx) error { @@ -1499,12 +1497,12 @@ func (cpdb *MySQLCheckpointsDB) RemoveCheckpoint(ctx context.Context, tableName } if tableName == allTables { - return s.Exec(ctx, "remove all checkpoints", "DROP SCHEMA "+cpdb.schema) + return s.Exec(ctx, "remove all checkpoints", common.SprintfWithIdentifiers("DROP SCHEMA %s", cpdb.schema)) } - deleteChunkQuery := fmt.Sprintf(DeleteCheckpointRecordTemplate, cpdb.schema, CheckpointTableNameChunk) - deleteEngineQuery := fmt.Sprintf(DeleteCheckpointRecordTemplate, cpdb.schema, CheckpointTableNameEngine) - deleteTableQuery := fmt.Sprintf(DeleteCheckpointRecordTemplate, cpdb.schema, CheckpointTableNameTable) + deleteChunkQuery 
:= common.SprintfWithIdentifiers(DeleteCheckpointRecordTemplate, cpdb.schema, CheckpointTableNameChunk) + deleteEngineQuery := common.SprintfWithIdentifiers(DeleteCheckpointRecordTemplate, cpdb.schema, CheckpointTableNameEngine) + deleteTableQuery := common.SprintfWithIdentifiers(DeleteCheckpointRecordTemplate, cpdb.schema, CheckpointTableNameTable) return s.Transact(ctx, "remove checkpoints", func(c context.Context, tx *sql.Tx) error { if _, e := tx.ExecContext(c, deleteChunkQuery, tableName); e != nil { @@ -1522,16 +1520,13 @@ func (cpdb *MySQLCheckpointsDB) RemoveCheckpoint(ctx context.Context, tableName // MoveCheckpoints implements CheckpointsDB.MoveCheckpoints. func (cpdb *MySQLCheckpointsDB) MoveCheckpoints(ctx context.Context, taskID int64) error { - // The "cpdb.schema" is an escaped schema name of the form "`foo`". - // We use "x[1:len(x)-1]" instead of unescaping it to keep the - // double-backquotes (if any) intact. - newSchema := fmt.Sprintf("`%s.%d.bak`", cpdb.schema[1:len(cpdb.schema)-1], taskID) + newSchema := fmt.Sprintf("%s.%d.bak", cpdb.schema, taskID) s := common.SQLWithRetry{ DB: cpdb.db, Logger: log.FromContext(ctx).With(zap.Int64("taskID", taskID)), } - createSchemaQuery := "CREATE SCHEMA IF NOT EXISTS " + newSchema + createSchemaQuery := common.SprintfWithIdentifiers("CREATE SCHEMA IF NOT EXISTS %s", newSchema) if e := s.Exec(ctx, "create backup checkpoints schema", createSchemaQuery); e != nil { return e } @@ -1539,7 +1534,7 @@ func (cpdb *MySQLCheckpointsDB) MoveCheckpoints(ctx context.Context, taskID int6 CheckpointTableNameChunk, CheckpointTableNameEngine, CheckpointTableNameTable, CheckpointTableNameTask, } { - query := fmt.Sprintf("RENAME TABLE %[1]s.%[3]s TO %[2]s.%[3]s", cpdb.schema, newSchema, tbl) + query := common.SprintfWithIdentifiers("RENAME TABLE %[1]s.%[3]s TO %[2]s.%[3]s", cpdb.schema, newSchema, tbl) if e := s.Exec(ctx, fmt.Sprintf("move %s checkpoints table", tbl), query); e != nil { return e } @@ -1558,20 +1553,20 @@ func 
(cpdb *MySQLCheckpointsDB) GetLocalStoringTables(ctx context.Context) (map[ // 2. engine status is earlier than CheckpointStatusImported, and // 3. chunk has been read - query := fmt.Sprintf(` + query := common.SprintfWithIdentifiers(` SELECT DISTINCT t.table_name, c.engine_id FROM %s.%s t, %s.%s c, %s.%s e WHERE t.table_name = c.table_name AND t.table_name = e.table_name AND c.engine_id = e.engine_id - AND %d < t.status AND t.status < %d - AND %d < e.status AND e.status < %d + AND ? < t.status AND t.status < ? + AND ? < e.status AND e.status < ? AND c.pos > c.offset;`, - cpdb.schema, CheckpointTableNameTable, cpdb.schema, CheckpointTableNameChunk, cpdb.schema, CheckpointTableNameEngine, - CheckpointStatusMaxInvalid, CheckpointStatusIndexImported, - CheckpointStatusMaxInvalid, CheckpointStatusImported) + cpdb.schema, CheckpointTableNameTable, cpdb.schema, CheckpointTableNameChunk, cpdb.schema, CheckpointTableNameEngine) err := common.Retry("get local storing tables", log.FromContext(ctx), func() error { targetTables = make(map[string][]int32) - rows, err := cpdb.db.QueryContext(ctx, query) // #nosec G201 + rows, err := cpdb.db.QueryContext(ctx, query, + CheckpointStatusMaxInvalid, CheckpointStatusIndexImported, + CheckpointStatusMaxInvalid, CheckpointStatusImported) if err != nil { return errors.Trace(err) } @@ -1601,13 +1596,18 @@ func (cpdb *MySQLCheckpointsDB) GetLocalStoringTables(ctx context.Context) (map[ // IgnoreErrorCheckpoint implements CheckpointsDB.IgnoreErrorCheckpoint. func (cpdb *MySQLCheckpointsDB) IgnoreErrorCheckpoint(ctx context.Context, tableName string) error { - var colName string + var ( + query, query2 string + args []any + ) if tableName == allTables { - // This will expand to `WHERE 'all' = 'all'` and effectively allowing - // all tables to be included. - colName = stringLitAll + query = common.SprintfWithIdentifiers("UPDATE %s.%s SET status = ? 
WHERE status <= ?", cpdb.schema, CheckpointTableNameEngine) + query2 = common.SprintfWithIdentifiers("UPDATE %s.%s SET status = ? WHERE status <= ?", cpdb.schema, CheckpointTableNameTable) + args = []any{CheckpointStatusLoaded, CheckpointStatusMaxInvalid} } else { - colName = columnTableName + query = common.SprintfWithIdentifiers("UPDATE %s.%s SET status = ? WHERE table_name = ? AND status <= ?", cpdb.schema, CheckpointTableNameEngine) + query2 = common.SprintfWithIdentifiers("UPDATE %s.%s SET status = ? WHERE table_name = ? AND status <= ?", cpdb.schema, CheckpointTableNameTable) + args = []any{CheckpointStatusLoaded, tableName, CheckpointStatusMaxInvalid} } s := common.SQLWithRetry{ @@ -1615,12 +1615,10 @@ func (cpdb *MySQLCheckpointsDB) IgnoreErrorCheckpoint(ctx context.Context, table Logger: log.FromContext(ctx).With(zap.String("table", tableName)), } err := s.Transact(ctx, "ignore error checkpoints", func(c context.Context, tx *sql.Tx) error { - query := fmt.Sprintf("UPDATE %s.%s SET status = ? WHERE %s = ? AND status <= ?", cpdb.schema, CheckpointTableNameEngine, colName) - if _, e := tx.ExecContext(c, query, CheckpointStatusLoaded, tableName, CheckpointStatusMaxInvalid); e != nil { + if _, e := tx.ExecContext(c, query, args...); e != nil { return errors.Trace(e) } - query = fmt.Sprintf("UPDATE %s.%s SET status = ? WHERE %s = ? AND status <= ?", cpdb.schema, CheckpointTableNameTable, colName) - if _, e := tx.ExecContext(c, query, CheckpointStatusLoaded, tableName, CheckpointStatusMaxInvalid); e != nil { + if _, e := tx.ExecContext(c, query2, args...); e != nil { return errors.Trace(e) } return nil @@ -1630,41 +1628,54 @@ func (cpdb *MySQLCheckpointsDB) IgnoreErrorCheckpoint(ctx context.Context, table // DestroyErrorCheckpoint implements CheckpointsDB.DestroyErrorCheckpoint. 
func (cpdb *MySQLCheckpointsDB) DestroyErrorCheckpoint(ctx context.Context, tableName string) ([]DestroyedTableCheckpoint, error) { - var colName, aliasedColName string - + var ( + selectQuery, deleteChunkQuery, deleteEngineQuery, deleteTableQuery string + args []any + ) if tableName == allTables { - // These will expand to `WHERE 'all' = 'all'` and effectively allowing - // all tables to be included. - colName = stringLitAll - aliasedColName = stringLitAll + selectQuery = common.SprintfWithIdentifiers(` + SELECT + t.table_name, + COALESCE(MIN(e.engine_id), 0), + COALESCE(MAX(e.engine_id), -1) + FROM %[1]s.%[2]s t + LEFT JOIN %[1]s.%[3]s e ON t.table_name = e.table_name + WHERE t.status <= ? + GROUP BY t.table_name; + `, cpdb.schema, CheckpointTableNameTable, CheckpointTableNameEngine) + deleteChunkQuery = common.SprintfWithIdentifiers(` + DELETE FROM %[1]s.%[2]s WHERE table_name IN (SELECT table_name FROM %[1]s.%[3]s WHERE status <= ?) + `, cpdb.schema, CheckpointTableNameChunk, CheckpointTableNameTable) + deleteEngineQuery = common.SprintfWithIdentifiers(` + DELETE FROM %[1]s.%[2]s WHERE table_name IN (SELECT table_name FROM %[1]s.%[3]s WHERE status <= ?) + `, cpdb.schema, CheckpointTableNameEngine, CheckpointTableNameTable) + deleteTableQuery = common.SprintfWithIdentifiers(` + DELETE FROM %s.%s WHERE status <= ? + `, cpdb.schema, CheckpointTableNameTable) + args = []any{CheckpointStatusMaxInvalid} } else { - colName = columnTableName - aliasedColName = "t.table_name" + selectQuery = common.SprintfWithIdentifiers(` + SELECT + t.table_name, + COALESCE(MIN(e.engine_id), 0), + COALESCE(MAX(e.engine_id), -1) + FROM %[1]s.%[2]s t + LEFT JOIN %[1]s.%[3]s e ON t.table_name = e.table_name + WHERE t.table_name = ? AND t.status <= ?
+ GROUP BY t.table_name; + `, cpdb.schema, CheckpointTableNameTable, CheckpointTableNameEngine) + deleteChunkQuery = common.SprintfWithIdentifiers(` + DELETE FROM %[1]s.%[2]s WHERE table_name IN (SELECT table_name FROM %[1]s.%[3]s WHERE table_name = ? AND status <= ?) + `, cpdb.schema, CheckpointTableNameChunk, CheckpointTableNameTable) + deleteEngineQuery = common.SprintfWithIdentifiers(` + DELETE FROM %[1]s.%[2]s WHERE table_name IN (SELECT table_name FROM %[1]s.%[3]s WHERE table_name = ? AND status <= ?) + `, cpdb.schema, CheckpointTableNameEngine, CheckpointTableNameTable) + deleteTableQuery = common.SprintfWithIdentifiers(` + DELETE FROM %s.%s WHERE table_name = ? AND status <= ? + `, cpdb.schema, CheckpointTableNameTable) + args = []any{tableName, CheckpointStatusMaxInvalid} } - selectQuery := fmt.Sprintf(` - SELECT - t.table_name, - COALESCE(MIN(e.engine_id), 0), - COALESCE(MAX(e.engine_id), -1) - FROM %[1]s.%[3]s t - LEFT JOIN %[1]s.%[4]s e ON t.table_name = e.table_name - WHERE %[2]s = ? AND t.status <= ? - GROUP BY t.table_name; - `, cpdb.schema, aliasedColName, CheckpointTableNameTable, CheckpointTableNameEngine) - - deleteChunkQuery := fmt.Sprintf(` - DELETE FROM %[1]s.%[3]s WHERE table_name IN (SELECT table_name FROM %[1]s.%[4]s WHERE %[2]s = ? AND status <= ?) - `, cpdb.schema, colName, CheckpointTableNameChunk, CheckpointTableNameTable) - - deleteEngineQuery := fmt.Sprintf(` - DELETE FROM %[1]s.%[3]s WHERE table_name IN (SELECT table_name FROM %[1]s.%[4]s WHERE %[2]s = ? AND status <= ?) - `, cpdb.schema, colName, CheckpointTableNameEngine, CheckpointTableNameTable) - - deleteTableQuery := fmt.Sprintf(` - DELETE FROM %s.%s WHERE %s = ? AND status <= ? 
- `, cpdb.schema, CheckpointTableNameTable, colName) - var targetTables []DestroyedTableCheckpoint s := common.SQLWithRetry{ @@ -1674,7 +1685,7 @@ func (cpdb *MySQLCheckpointsDB) DestroyErrorCheckpoint(ctx context.Context, tabl err := s.Transact(ctx, "destroy error checkpoints", func(c context.Context, tx *sql.Tx) error { // Obtain the list of tables targetTables = nil - rows, e := tx.QueryContext(c, selectQuery, tableName, CheckpointStatusMaxInvalid) + rows, e := tx.QueryContext(c, selectQuery, args...) if e != nil { return errors.Trace(e) } @@ -1692,13 +1703,13 @@ func (cpdb *MySQLCheckpointsDB) DestroyErrorCheckpoint(ctx context.Context, tabl } // Delete the checkpoints - if _, e := tx.ExecContext(c, deleteChunkQuery, tableName, CheckpointStatusMaxInvalid); e != nil { + if _, e := tx.ExecContext(c, deleteChunkQuery, args...); e != nil { return errors.Trace(e) } - if _, e := tx.ExecContext(c, deleteEngineQuery, tableName, CheckpointStatusMaxInvalid); e != nil { + if _, e := tx.ExecContext(c, deleteEngineQuery, args...); e != nil { return errors.Trace(e) } - if _, e := tx.ExecContext(c, deleteTableQuery, tableName, CheckpointStatusMaxInvalid); e != nil { + if _, e := tx.ExecContext(c, deleteTableQuery, args...); e != nil { return errors.Trace(e) } return nil @@ -1715,7 +1726,7 @@ func (cpdb *MySQLCheckpointsDB) DestroyErrorCheckpoint(ctx context.Context, tabl //nolint:rowserrcheck // sqltocsv.Write will check this. func (cpdb *MySQLCheckpointsDB) DumpTables(ctx context.Context, writer io.Writer) error { //nolint: rowserrcheck - rows, err := cpdb.db.QueryContext(ctx, fmt.Sprintf(` + rows, err := cpdb.db.QueryContext(ctx, common.SprintfWithIdentifiers(` SELECT task_id, table_name, @@ -1740,7 +1751,7 @@ func (cpdb *MySQLCheckpointsDB) DumpTables(ctx context.Context, writer io.Writer //nolint:rowserrcheck // sqltocsv.Write will check this. 
func (cpdb *MySQLCheckpointsDB) DumpEngines(ctx context.Context, writer io.Writer) error { //nolint: rowserrcheck - rows, err := cpdb.db.QueryContext(ctx, fmt.Sprintf(` + rows, err := cpdb.db.QueryContext(ctx, common.SprintfWithIdentifiers(` SELECT table_name, engine_id, @@ -1763,7 +1774,7 @@ func (cpdb *MySQLCheckpointsDB) DumpEngines(ctx context.Context, writer io.Write //nolint:rowserrcheck // sqltocsv.Write will check this. func (cpdb *MySQLCheckpointsDB) DumpChunks(ctx context.Context, writer io.Writer) error { //nolint: rowserrcheck - rows, err := cpdb.db.QueryContext(ctx, fmt.Sprintf(` + rows, err := cpdb.db.QueryContext(ctx, common.SprintfWithIdentifiers(` SELECT table_name, path, diff --git a/br/pkg/lightning/checkpoints/checkpoints_sql_test.go b/br/pkg/lightning/checkpoints/checkpoints_sql_test.go index 6c8aa823edcd1..d85bfb69c9156 100644 --- a/br/pkg/lightning/checkpoints/checkpoints_sql_test.go +++ b/br/pkg/lightning/checkpoints/checkpoints_sql_test.go @@ -36,16 +36,16 @@ func newCPSQLSuite(t *testing.T) *cpSQLSuite { ExpectExec("CREATE DATABASE IF NOT EXISTS `mock-schema`"). WillReturnResult(sqlmock.NewResult(1, 1)) s.mock. - ExpectExec("CREATE TABLE IF NOT EXISTS `mock-schema`\\.task_v\\d+ .+"). + ExpectExec("CREATE TABLE IF NOT EXISTS `mock-schema`\\.`task_v\\d+` .+"). WillReturnResult(sqlmock.NewResult(2, 1)) s.mock. - ExpectExec("CREATE TABLE IF NOT EXISTS `mock-schema`\\.table_v\\d+ .+"). + ExpectExec("CREATE TABLE IF NOT EXISTS `mock-schema`\\.`table_v\\d+` .+"). WillReturnResult(sqlmock.NewResult(3, 1)) s.mock. - ExpectExec("CREATE TABLE IF NOT EXISTS `mock-schema`\\.engine_v\\d+ .+"). + ExpectExec("CREATE TABLE IF NOT EXISTS `mock-schema`\\.`engine_v\\d+` .+"). WillReturnResult(sqlmock.NewResult(4, 1)) s.mock. - ExpectExec("CREATE TABLE IF NOT EXISTS `mock-schema`\\.chunk_v\\d+ .+"). + ExpectExec("CREATE TABLE IF NOT EXISTS `mock-schema`\\.`chunk_v\\d+` .+"). 
WillReturnResult(sqlmock.NewResult(5, 1)) cpdb, err := checkpoints.NewMySQLCheckpointsDB(context.Background(), s.db, "mock-schema") @@ -82,12 +82,12 @@ func TestNormalOperations(t *testing.T) { s.mock.ExpectBegin() initializeStmt := s.mock.ExpectPrepare( - "REPLACE INTO `mock-schema`\\.task_v\\d+") + "REPLACE INTO `mock-schema`\\.`task_v\\d+`") initializeStmt.ExpectExec(). WithArgs(123, "/data", "local", "127.0.0.1:8287", "127.0.0.1", 4000, "127.0.0.1:2379", "/tmp/sorted-kv", build.ReleaseVersion). WillReturnResult(sqlmock.NewResult(6, 1)) initializeStmt = s.mock. - ExpectPrepare("INSERT INTO `mock-schema`\\.table_v\\d+") + ExpectPrepare("INSERT INTO `mock-schema`\\.`table_v\\d+`") initializeStmt.ExpectExec(). WithArgs(123, "`db1`.`t1`", sqlmock.AnyArg(), int64(1), t1Info). WillReturnResult(sqlmock.NewResult(7, 1)) @@ -142,7 +142,7 @@ func TestNormalOperations(t *testing.T) { s.mock.ExpectBegin() insertEngineStmt := s.mock. - ExpectPrepare("REPLACE INTO `mock-schema`\\.engine_v\\d+ .+") + ExpectPrepare("REPLACE INTO `mock-schema`\\.`engine_v\\d+` .+") insertEngineStmt. ExpectExec(). WithArgs("`db1`.`t2`", 0, 30). @@ -152,7 +152,7 @@ func TestNormalOperations(t *testing.T) { WithArgs("`db1`.`t2`", -1, 30). WillReturnResult(sqlmock.NewResult(9, 1)) insertChunkStmt := s.mock. - ExpectPrepare("REPLACE INTO `mock-schema`\\.chunk_v\\d+ .+") + ExpectPrepare("REPLACE INTO `mock-schema`\\.`chunk_v\\d+` .+") insertChunkStmt. ExpectExec(). WithArgs("`db1`.`t2`", 0, "/tmp/path/1.sql", 0, mydump.SourceTypeSQL, 0, "", 123, []byte("null"), 12, 102400, 1, 5000, 1234567890). @@ -223,7 +223,7 @@ func TestNormalOperations(t *testing.T) { s.mock.ExpectBegin() s.mock. - ExpectPrepare("UPDATE `mock-schema`\\.chunk_v\\d+ SET pos = .+"). + ExpectPrepare("UPDATE `mock-schema`\\.`chunk_v\\d+` SET pos = .+"). ExpectExec(). WithArgs( 55904, 681, 4491, 586, 486070148917, []byte("null"), @@ -231,22 +231,22 @@ func TestNormalOperations(t *testing.T) { ). 
WillReturnResult(sqlmock.NewResult(11, 1)) s.mock. - ExpectPrepare("UPDATE `mock-schema`\\.table_v\\d+ SET alloc_base = .+"). + ExpectPrepare("UPDATE `mock-schema`\\.`table_v\\d+` SET alloc_base = .+"). ExpectExec(). WithArgs(132861, "`db1`.`t2`"). WillReturnResult(sqlmock.NewResult(12, 1)) s.mock. - ExpectPrepare("UPDATE `mock-schema`\\.engine_v\\d+ SET status = .+"). + ExpectPrepare("UPDATE `mock-schema`\\.`engine_v\\d+` SET status = .+"). ExpectExec(). WithArgs(120, "`db1`.`t2`", 0). WillReturnResult(sqlmock.NewResult(13, 1)) s.mock. - ExpectPrepare("UPDATE `mock-schema`\\.table_v\\d+ SET status = .+"). + ExpectPrepare("UPDATE `mock-schema`\\.`table_v\\d+` SET status = .+"). ExpectExec(). WithArgs(60, "`db1`.`t2`"). WillReturnResult(sqlmock.NewResult(14, 1)) s.mock. - ExpectPrepare("UPDATE `mock-schema`\\.table_v\\d+ SET kv_bytes = .+"). + ExpectPrepare("UPDATE `mock-schema`\\.`table_v\\d+` SET kv_bytes = .+"). ExpectExec(). WithArgs(4492, 686, 486070148910, "`db1`.`t2`"). WillReturnResult(sqlmock.NewResult(15, 1)) @@ -262,7 +262,7 @@ func TestNormalOperations(t *testing.T) { s.mock.ExpectBegin() s.mock. - ExpectQuery("SELECT .+ FROM `mock-schema`\\.engine_v\\d+"). + ExpectQuery("SELECT .+ FROM `mock-schema`\\.`engine_v\\d+`"). WithArgs("`db1`.`t2`"). WillReturnRows( sqlmock.NewRows([]string{"engine_id", "status"}). @@ -270,7 +270,7 @@ func TestNormalOperations(t *testing.T) { AddRow(-1, 30), ) s.mock. - ExpectQuery("SELECT (?s:.+) FROM `mock-schema`\\.chunk_v\\d+"). + ExpectQuery("SELECT (?s:.+) FROM `mock-schema`\\.`chunk_v\\d+`"). WithArgs("`db1`.`t2`"). WillReturnRows( sqlmock.NewRows([]string{ @@ -285,7 +285,7 @@ func TestNormalOperations(t *testing.T) { ), ) s.mock. - ExpectQuery("SELECT .+ FROM `mock-schema`\\.table_v\\d+"). + ExpectQuery("SELECT .+ FROM `mock-schema`\\.`table_v\\d+`"). WithArgs("`db1`.`t2`"). WillReturnRows( sqlmock.NewRows([]string{"status", "alloc_base", "table_id", "table_info", "kv_bytes", "kv_kvs", "kv_checksum"}). 
@@ -345,11 +345,11 @@ func TestRemoveAllCheckpoints_SQL(t *testing.T) { s.mock.ExpectBegin() s.mock. - ExpectQuery("SELECT .+ FROM `mock-schema`\\.engine_v\\d+"). + ExpectQuery("SELECT .+ FROM `mock-schema`\\.`engine_v\\d+`"). WithArgs("`db1`.`t2`"). WillReturnRows(sqlmock.NewRows([]string{"engine_id", "status"})) s.mock. - ExpectQuery("SELECT (?s:.+) FROM `mock-schema`\\.chunk_v\\d+"). + ExpectQuery("SELECT (?s:.+) FROM `mock-schema`\\.`chunk_v\\d+`"). WithArgs("`db1`.`t2`"). WillReturnRows( sqlmock.NewRows([]string{ @@ -358,7 +358,7 @@ func TestRemoveAllCheckpoints_SQL(t *testing.T) { "kvc_bytes", "kvc_kvs", "kvc_checksum", "unix_timestamp(create_time)", })) s.mock. - ExpectQuery("SELECT .+ FROM `mock-schema`\\.table_v\\d+"). + ExpectQuery("SELECT .+ FROM `mock-schema`\\.`table_v\\d+`"). WithArgs("`db1`.`t2`"). WillReturnRows(sqlmock.NewRows([]string{"status", "alloc_base", "table_id"})) s.mock.ExpectRollback() @@ -373,15 +373,15 @@ func TestRemoveOneCheckpoint_SQL(t *testing.T) { s.mock.ExpectBegin() s.mock. - ExpectExec("DELETE FROM `mock-schema`\\.chunk_v\\d+ WHERE table_name = \\?"). + ExpectExec("DELETE FROM `mock-schema`\\.`chunk_v\\d+` WHERE table_name = \\?"). WithArgs("`db1`.`t2`"). WillReturnResult(sqlmock.NewResult(0, 4)) s.mock. - ExpectExec("DELETE FROM `mock-schema`\\.engine_v\\d+ WHERE table_name = \\?"). + ExpectExec("DELETE FROM `mock-schema`\\.`engine_v\\d+` WHERE table_name = \\?"). WithArgs("`db1`.`t2`"). WillReturnResult(sqlmock.NewResult(0, 2)) s.mock. - ExpectExec("DELETE FROM `mock-schema`\\.table_v\\d+ WHERE table_name = \\?"). + ExpectExec("DELETE FROM `mock-schema`\\.`table_v\\d+` WHERE table_name = \\?"). WithArgs("`db1`.`t2`"). WillReturnResult(sqlmock.NewResult(0, 1)) s.mock.ExpectCommit() @@ -395,12 +395,12 @@ func TestIgnoreAllErrorCheckpoints_SQL(t *testing.T) { s.mock.ExpectBegin() s.mock. - ExpectExec("UPDATE `mock-schema`\\.engine_v\\d+ SET status = \\? WHERE 'all' = \\? AND status <= \\?"). 
- WithArgs(checkpoints.CheckpointStatusLoaded, sqlmock.AnyArg(), 25). + ExpectExec("UPDATE `mock-schema`\\.`engine_v\\d+` SET status = \\? WHERE status <= \\?"). + WithArgs(checkpoints.CheckpointStatusLoaded, 25). WillReturnResult(sqlmock.NewResult(5, 3)) s.mock. - ExpectExec("UPDATE `mock-schema`\\.table_v\\d+ SET status = \\? WHERE 'all' = \\? AND status <= \\?"). - WithArgs(checkpoints.CheckpointStatusLoaded, sqlmock.AnyArg(), 25). + ExpectExec("UPDATE `mock-schema`\\.`table_v\\d+` SET status = \\? WHERE status <= \\?"). + WithArgs(checkpoints.CheckpointStatusLoaded, 25). WillReturnResult(sqlmock.NewResult(6, 2)) s.mock.ExpectCommit() @@ -413,11 +413,11 @@ func TestIgnoreOneErrorCheckpoint(t *testing.T) { s.mock.ExpectBegin() s.mock. - ExpectExec("UPDATE `mock-schema`\\.engine_v\\d+ SET status = \\? WHERE table_name = \\? AND status <= \\?"). + ExpectExec("UPDATE `mock-schema`\\.`engine_v\\d+` SET status = \\? WHERE table_name = \\? AND status <= \\?"). WithArgs(checkpoints.CheckpointStatusLoaded, "`db1`.`t2`", 25). WillReturnResult(sqlmock.NewResult(5, 2)) s.mock. - ExpectExec("UPDATE `mock-schema`\\.table_v\\d+ SET status = \\? WHERE table_name = \\? AND status <= \\?"). + ExpectExec("UPDATE `mock-schema`\\.`table_v\\d+` SET status = \\? WHERE table_name = \\? AND status <= \\?"). WithArgs(checkpoints.CheckpointStatusLoaded, "`db1`.`t2`", 25). WillReturnResult(sqlmock.NewResult(6, 1)) s.mock.ExpectCommit() @@ -431,23 +431,23 @@ func TestDestroyAllErrorCheckpoints_SQL(t *testing.T) { s.mock.ExpectBegin() s.mock. - ExpectQuery("SELECT (?s:.+)'all' = \\?"). - WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()). + ExpectQuery("SELECT (?s:.+)"). + WithArgs(sqlmock.AnyArg()). WillReturnRows( sqlmock.NewRows([]string{"table_name", "__min__", "__max__"}). AddRow("`db1`.`t2`", -1, 0), ) s.mock. - ExpectExec("DELETE FROM `mock-schema`\\.chunk_v\\d+ WHERE table_name IN .+ 'all' = \\?"). - WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()). 
+ ExpectExec("DELETE FROM `mock-schema`\\.`chunk_v\\d+` WHERE table_name IN"). + WithArgs(sqlmock.AnyArg()). WillReturnResult(sqlmock.NewResult(0, 5)) s.mock. - ExpectExec("DELETE FROM `mock-schema`\\.engine_v\\d+ WHERE table_name IN .+ 'all' = \\?"). - WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()). + ExpectExec("DELETE FROM `mock-schema`\\.`engine_v\\d+` WHERE table_name IN"). + WithArgs(sqlmock.AnyArg()). WillReturnResult(sqlmock.NewResult(0, 3)) s.mock. - ExpectExec("DELETE FROM `mock-schema`\\.table_v\\d+ WHERE 'all' = \\?"). - WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()). + ExpectExec("DELETE FROM `mock-schema`\\.`table_v\\d+`"). + WithArgs(sqlmock.AnyArg()). WillReturnResult(sqlmock.NewResult(0, 2)) s.mock.ExpectCommit() @@ -472,15 +472,15 @@ func TestDestroyOneErrorCheckpoints(t *testing.T) { AddRow("`db1`.`t2`", -1, 0), ) s.mock. - ExpectExec("DELETE FROM `mock-schema`\\.chunk_v\\d+ WHERE .+table_name = \\?"). + ExpectExec("DELETE FROM `mock-schema`\\.`chunk_v\\d+` WHERE .+table_name = \\?"). WithArgs("`db1`.`t2`", sqlmock.AnyArg()). WillReturnResult(sqlmock.NewResult(0, 4)) s.mock. - ExpectExec("DELETE FROM `mock-schema`\\.engine_v\\d+ WHERE .+table_name = \\?"). + ExpectExec("DELETE FROM `mock-schema`\\.`engine_v\\d+` WHERE .+table_name = \\?"). WithArgs("`db1`.`t2`", sqlmock.AnyArg()). WillReturnResult(sqlmock.NewResult(0, 2)) s.mock. - ExpectExec("DELETE FROM `mock-schema`\\.table_v\\d+ WHERE table_name = \\?"). + ExpectExec("DELETE FROM `mock-schema`\\.`table_v\\d+` WHERE table_name = \\?"). WithArgs("`db1`.`t2`", sqlmock.AnyArg()). WillReturnResult(sqlmock.NewResult(0, 1)) s.mock.ExpectCommit() @@ -500,7 +500,7 @@ func TestDump(t *testing.T) { tm := time.Unix(1555555555, 0).UTC() s.mock. - ExpectQuery("SELECT (?s:.+) FROM `mock-schema`\\.chunk_v\\d+"). + ExpectQuery("SELECT (?s:.+) FROM `mock-schema`\\.`chunk_v\\d+`"). 
WillReturnRows( sqlmock.NewRows([]string{ "table_name", "path", "offset", "type", "compression", "sort_key", "file_size", "columns", @@ -525,7 +525,7 @@ func TestDump(t *testing.T) { ) s.mock. - ExpectQuery("SELECT .+ FROM `mock-schema`\\.engine_v\\d+"). + ExpectQuery("SELECT .+ FROM `mock-schema`\\.`engine_v\\d+`"). WillReturnRows( sqlmock.NewRows([]string{"table_name", "engine_id", "status", "create_time", "update_time"}). AddRow("`db1`.`t2`", -1, 30, tm, tm). @@ -541,7 +541,7 @@ func TestDump(t *testing.T) { csvBuilder.String()) s.mock. - ExpectQuery("SELECT .+ FROM `mock-schema`\\.table_v\\d+"). + ExpectQuery("SELECT .+ FROM `mock-schema`\\.`table_v\\d+`"). WillReturnRows( sqlmock.NewRows([]string{"task_id", "table_name", "hash", "status", "alloc_base", "create_time", "update_time"}). AddRow(1555555555, "`db1`.`t2`", 0, 90, 132861, tm, tm), @@ -564,16 +564,16 @@ func TestMoveCheckpoints(t *testing.T) { ExpectExec("CREATE SCHEMA IF NOT EXISTS `mock-schema\\.12345678\\.bak`"). WillReturnResult(sqlmock.NewResult(1, 1)) s.mock. - ExpectExec("RENAME TABLE `mock-schema`\\.chunk_v\\d+ TO `mock-schema\\.12345678\\.bak`\\.chunk_v\\d+"). + ExpectExec("RENAME TABLE `mock-schema`\\.`chunk_v\\d+` TO `mock-schema\\.12345678\\.bak`\\.`chunk_v\\d+`"). WillReturnResult(sqlmock.NewResult(0, 1)) s.mock. - ExpectExec("RENAME TABLE `mock-schema`\\.engine_v\\d+ TO `mock-schema\\.12345678\\.bak`\\.engine_v\\d+"). + ExpectExec("RENAME TABLE `mock-schema`\\.`engine_v\\d+` TO `mock-schema\\.12345678\\.bak`\\.`engine_v\\d+`"). WillReturnResult(sqlmock.NewResult(0, 1)) s.mock. - ExpectExec("RENAME TABLE `mock-schema`\\.table_v\\d+ TO `mock-schema\\.12345678\\.bak`\\.table_v\\d+"). + ExpectExec("RENAME TABLE `mock-schema`\\.`table_v\\d+` TO `mock-schema\\.12345678\\.bak`\\.`table_v\\d+`"). WillReturnResult(sqlmock.NewResult(0, 1)) s.mock. - ExpectExec("RENAME TABLE `mock-schema`\\.task_v\\d+ TO `mock-schema\\.12345678\\.bak`\\.task_v\\d+"). 
+ ExpectExec("RENAME TABLE `mock-schema`\\.`task_v\\d+` TO `mock-schema\\.12345678\\.bak`\\.`task_v\\d+`"). WillReturnResult(sqlmock.NewResult(0, 1)) err := s.cpdb.MoveCheckpoints(ctx, 12345678) diff --git a/br/pkg/lightning/checkpoints/glue_checkpoint.go b/br/pkg/lightning/checkpoints/glue_checkpoint.go deleted file mode 100644 index a1a8cd96fed3b..0000000000000 --- a/br/pkg/lightning/checkpoints/glue_checkpoint.go +++ /dev/null @@ -1,826 +0,0 @@ -// Copyright 2020 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package checkpoints - -import ( - "context" - "encoding/json" - "fmt" - "io" - "strings" - - "github.com/pingcap/errors" - "github.com/pingcap/tidb/br/pkg/lightning/common" - "github.com/pingcap/tidb/br/pkg/lightning/config" - "github.com/pingcap/tidb/br/pkg/lightning/log" - "github.com/pingcap/tidb/br/pkg/lightning/mydump" - verify "github.com/pingcap/tidb/br/pkg/lightning/verification" - "github.com/pingcap/tidb/br/pkg/version/build" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/sqlexec" - "go.uber.org/zap" -) - -// Session is a wrapper of TiDB's session. 
-type Session interface { - Close() - Execute(context.Context, string) ([]sqlexec.RecordSet, error) - CommitTxn(context.Context) error - RollbackTxn(context.Context) - PrepareStmt(sql string) (stmtID uint32, paramCount int, fields []*ast.ResultField, err error) - ExecutePreparedStmt(ctx context.Context, stmtID uint32, param []types.Datum) (sqlexec.RecordSet, error) - DropPreparedStmt(stmtID uint32) error -} - -// GlueCheckpointsDB is almost same with MySQLCheckpointsDB, but it uses TiDB's internal data structure which requires a -// lot to keep same with database/sql. -// TODO: Encapsulate Begin/Commit/Rollback txn, form SQL with args and query/iter/scan TiDB's RecordSet into a interface -// to reuse MySQLCheckpointsDB. -type GlueCheckpointsDB struct { - // getSessionFunc will get a new session from TiDB - getSessionFunc func() (Session, error) - schema string -} - -var _ DB = (*GlueCheckpointsDB)(nil) - -// dropPreparedStmt drops the statement and when meet an error, -// print an error message. -func dropPreparedStmt(ctx context.Context, session Session, stmtID uint32) { - if err := session.DropPreparedStmt(stmtID); err != nil { - log.FromContext(ctx).Error("failed to drop prepared statement", log.ShortError(err)) - } -} - -// NewGlueCheckpointsDB creates a new GlueCheckpointsDB. 
-func NewGlueCheckpointsDB(ctx context.Context, se Session, f func() (Session, error), schemaName string) (*GlueCheckpointsDB, error) { - var escapedSchemaName strings.Builder - common.WriteMySQLIdentifier(&escapedSchemaName, schemaName) - schema := escapedSchemaName.String() - logger := log.FromContext(ctx).With(zap.String("schema", schemaName)) - - sql := fmt.Sprintf(CreateDBTemplate, schema) - err := common.Retry("create checkpoints database", logger, func() error { - _, err := se.Execute(ctx, sql) - return err - }) - if err != nil { - return nil, errors.Trace(err) - } - - sql = fmt.Sprintf(CreateTaskTableTemplate, schema, CheckpointTableNameTask) - err = common.Retry("create task checkpoints table", logger, func() error { - _, err := se.Execute(ctx, sql) - return err - }) - if err != nil { - return nil, errors.Trace(err) - } - - sql = fmt.Sprintf(CreateTableTableTemplate, schema, CheckpointTableNameTable) - err = common.Retry("create table checkpoints table", logger, func() error { - _, err := se.Execute(ctx, sql) - return err - }) - if err != nil { - return nil, errors.Trace(err) - } - - sql = fmt.Sprintf(CreateEngineTableTemplate, schema, CheckpointTableNameEngine) - err = common.Retry("create engine checkpoints table", logger, func() error { - _, err := se.Execute(ctx, sql) - return err - }) - if err != nil { - return nil, errors.Trace(err) - } - - sql = fmt.Sprintf(CreateChunkTableTemplate, schema, CheckpointTableNameChunk) - err = common.Retry("create chunks checkpoints table", logger, func() error { - _, err := se.Execute(ctx, sql) - return err - }) - if err != nil { - return nil, errors.Trace(err) - } - - return &GlueCheckpointsDB{ - getSessionFunc: f, - schema: schema, - }, nil -} - -// Initialize implements CheckpointsDB.Initialize. 
-func (g GlueCheckpointsDB) Initialize(ctx context.Context, cfg *config.Config, dbInfo map[string]*TidbDBInfo) error { - logger := log.FromContext(ctx) - se, err := g.getSessionFunc() - if err != nil { - return errors.Trace(err) - } - defer se.Close() - - err = Transact(ctx, "insert checkpoints", se, logger, func(c context.Context, s Session) error { - stmtID, _, _, err := s.PrepareStmt(fmt.Sprintf(InitTaskTemplate, g.schema, CheckpointTableNameTask)) - if err != nil { - return errors.Trace(err) - } - defer dropPreparedStmt(ctx, s, stmtID) - _, err = s.ExecutePreparedStmt(c, stmtID, []types.Datum{ - types.NewIntDatum(cfg.TaskID), - types.NewStringDatum(cfg.Mydumper.SourceDir), - types.NewStringDatum(cfg.TikvImporter.Backend), - types.NewStringDatum(cfg.TikvImporter.Addr), - types.NewStringDatum(cfg.TiDB.Host), - types.NewIntDatum(int64(cfg.TiDB.Port)), - types.NewStringDatum(cfg.TiDB.PdAddr), - types.NewStringDatum(cfg.TikvImporter.SortedKVDir), - types.NewStringDatum(build.ReleaseVersion), - }) - if err != nil { - return errors.Trace(err) - } - - stmtID2, _, _, err := s.PrepareStmt(fmt.Sprintf(InitTableTemplate, g.schema, CheckpointTableNameTable)) - if err != nil { - return errors.Trace(err) - } - defer dropPreparedStmt(ctx, s, stmtID2) - - for _, db := range dbInfo { - for _, table := range db.Tables { - tableName := common.UniqueTable(db.Name, table.Name) - _, err = s.ExecutePreparedStmt(c, stmtID2, []types.Datum{ - types.NewIntDatum(cfg.TaskID), - types.NewStringDatum(tableName), - types.NewIntDatum(0), - types.NewIntDatum(table.ID), - }) - if err != nil { - return errors.Trace(err) - } - } - } - return nil - }) - return errors.Trace(err) -} - -// TaskCheckpoint implements CheckpointsDB.TaskCheckpoint. 
-func (g GlueCheckpointsDB) TaskCheckpoint(ctx context.Context) (*TaskCheckpoint, error) { - logger := log.FromContext(ctx) - sql := fmt.Sprintf(ReadTaskTemplate, g.schema, CheckpointTableNameTask) - se, err := g.getSessionFunc() - if err != nil { - return nil, errors.Trace(err) - } - defer se.Close() - - var taskCp *TaskCheckpoint - err = common.Retry("fetch task checkpoint", logger, func() error { - rs, err := se.Execute(ctx, sql) - if err != nil { - return errors.Trace(err) - } - r := rs[0] - //nolint: errcheck - defer r.Close() - req := r.NewChunk(nil) - err = r.Next(ctx, req) - if err != nil { - return err - } - if req.NumRows() == 0 { - return nil - } - - row := req.GetRow(0) - taskCp = &TaskCheckpoint{} - taskCp.TaskID = row.GetInt64(0) - taskCp.SourceDir = row.GetString(1) - taskCp.Backend = row.GetString(2) - taskCp.ImporterAddr = row.GetString(3) - taskCp.TiDBHost = row.GetString(4) - taskCp.TiDBPort = int(row.GetInt64(5)) - taskCp.PdAddr = row.GetString(6) - taskCp.SortedKVDir = row.GetString(7) - taskCp.LightningVer = row.GetString(8) - return nil - }) - if err != nil { - return nil, errors.Trace(err) - } - return taskCp, nil -} - -// Get implements CheckpointsDB.Get. -func (g GlueCheckpointsDB) Get(ctx context.Context, tableName string) (*TableCheckpoint, error) { - cp := &TableCheckpoint{ - Engines: map[int32]*EngineCheckpoint{}, - } - logger := log.FromContext(ctx).With(zap.String("table", tableName)) - se, err := g.getSessionFunc() - if err != nil { - return nil, errors.Trace(err) - } - defer se.Close() - - tableName = common.InterpolateMySQLString(tableName) - err = Transact(ctx, "read checkpoint", se, logger, func(c context.Context, s Session) error { - // 1. Populate the engines. 
- sql := fmt.Sprintf(ReadEngineTemplate, g.schema, CheckpointTableNameEngine) - sql = strings.ReplaceAll(sql, "?", tableName) - rs, err := s.Execute(ctx, sql) - if err != nil { - return errors.Trace(err) - } - r := rs[0] - req := r.NewChunk(nil) - it := chunk.NewIterator4Chunk(req) - for { - err = r.Next(ctx, req) - if err != nil { - _ = r.Close() - return err - } - if req.NumRows() == 0 { - break - } - - for row := it.Begin(); row != it.End(); row = it.Next() { - engineID := int32(row.GetInt64(0)) - status := uint8(row.GetUint64(1)) - cp.Engines[engineID] = &EngineCheckpoint{ - Status: CheckpointStatus(status), - } - } - } - _ = r.Close() - - // 2. Populate the chunks. - sql = fmt.Sprintf(ReadChunkTemplate, g.schema, CheckpointTableNameChunk) - sql = strings.ReplaceAll(sql, "?", tableName) - rs, err = s.Execute(ctx, sql) - if err != nil { - return errors.Trace(err) - } - r = rs[0] - req = r.NewChunk(nil) - it = chunk.NewIterator4Chunk(req) - for { - err = r.Next(ctx, req) - if err != nil { - _ = r.Close() - return err - } - if req.NumRows() == 0 { - break - } - - for row := it.Begin(); row != it.End(); row = it.Next() { - value := &ChunkCheckpoint{} - engineID := int32(row.GetInt64(0)) - value.Key.Path = row.GetString(1) - value.Key.Offset = row.GetInt64(2) - value.FileMeta.Type = mydump.SourceType(row.GetInt64(3)) - value.FileMeta.Compression = mydump.Compression(row.GetInt64(4)) - value.FileMeta.SortKey = row.GetString(5) - value.FileMeta.FileSize = row.GetInt64(6) - colPerm := row.GetBytes(7) - value.Chunk.Offset = row.GetInt64(8) - value.Chunk.EndOffset = row.GetInt64(9) - value.Chunk.PrevRowIDMax = row.GetInt64(10) - value.Chunk.RowIDMax = row.GetInt64(11) - kvcBytes := row.GetUint64(12) - kvcKVs := row.GetUint64(13) - kvcChecksum := row.GetUint64(14) - value.Timestamp = row.GetInt64(15) - - value.FileMeta.Path = value.Key.Path - value.Checksum = verify.MakeKVChecksum(kvcBytes, kvcKVs, kvcChecksum) - if err := json.Unmarshal(colPerm, 
&value.ColumnPermutation); err != nil { - _ = r.Close() - return errors.Trace(err) - } - cp.Engines[engineID].Chunks = append(cp.Engines[engineID].Chunks, value) - } - } - _ = r.Close() - - // 3. Fill in the remaining table info - sql = fmt.Sprintf(ReadTableRemainTemplate, g.schema, CheckpointTableNameTable) - sql = strings.ReplaceAll(sql, "?", tableName) - rs, err = s.Execute(ctx, sql) - if err != nil { - return errors.Trace(err) - } - r = rs[0] - //nolint: errcheck - defer r.Close() - req = r.NewChunk(nil) - err = r.Next(ctx, req) - if err != nil { - return err - } - if req.NumRows() == 0 { - return nil - } - - row := req.GetRow(0) - cp.Status = CheckpointStatus(row.GetUint64(0)) - cp.AllocBase = row.GetInt64(1) - cp.TableID = row.GetInt64(2) - rawTableInfo := row.GetBytes(3) - if err := json.Unmarshal(rawTableInfo, &cp.TableInfo); err != nil { - return errors.Trace(err) - } - return nil - }) - - if err != nil { - return nil, errors.Trace(err) - } - - return cp, nil -} - -// Close implements CheckpointsDB.Close. -func (GlueCheckpointsDB) Close() error { - return nil -} - -// InsertEngineCheckpoints implements CheckpointsDB.InsertEngineCheckpoints. 
-func (g GlueCheckpointsDB) InsertEngineCheckpoints(ctx context.Context, tableName string, checkpointMap map[int32]*EngineCheckpoint) error { - logger := log.FromContext(ctx).With(zap.String("table", tableName)) - se, err := g.getSessionFunc() - if err != nil { - return errors.Trace(err) - } - defer se.Close() - - err = Transact(ctx, "update engine checkpoints", se, logger, func(c context.Context, s Session) error { - engineStmt, _, _, err := s.PrepareStmt(fmt.Sprintf(ReplaceEngineTemplate, g.schema, CheckpointTableNameEngine)) - if err != nil { - return errors.Trace(err) - } - defer dropPreparedStmt(ctx, s, engineStmt) - - chunkStmt, _, _, err := s.PrepareStmt(fmt.Sprintf(ReplaceChunkTemplate, g.schema, CheckpointTableNameChunk)) - if err != nil { - return errors.Trace(err) - } - defer dropPreparedStmt(ctx, s, chunkStmt) - - for engineID, engine := range checkpointMap { - _, err := s.ExecutePreparedStmt(c, engineStmt, []types.Datum{ - types.NewStringDatum(tableName), - types.NewIntDatum(int64(engineID)), - types.NewUintDatum(uint64(engine.Status)), - }) - if err != nil { - return errors.Trace(err) - } - for _, value := range engine.Chunks { - columnPerm, err := json.Marshal(value.ColumnPermutation) - if err != nil { - return errors.Trace(err) - } - _, err = s.ExecutePreparedStmt(c, chunkStmt, []types.Datum{ - types.NewStringDatum(tableName), - types.NewIntDatum(int64(engineID)), - types.NewStringDatum(value.Key.Path), - types.NewIntDatum(value.Key.Offset), - types.NewIntDatum(int64(value.FileMeta.Type)), - types.NewIntDatum(int64(value.FileMeta.Compression)), - types.NewStringDatum(value.FileMeta.SortKey), - types.NewIntDatum(value.FileMeta.FileSize), - types.NewBytesDatum(columnPerm), - types.NewIntDatum(value.Chunk.Offset), - types.NewIntDatum(value.Chunk.EndOffset), - types.NewIntDatum(value.Chunk.PrevRowIDMax), - types.NewIntDatum(value.Chunk.RowIDMax), - types.NewIntDatum(value.Timestamp), - }) - if err != nil { - return errors.Trace(err) - } - } - } - return 
nil - }) - return errors.Trace(err) -} - -// Update implements CheckpointsDB.Update. -func (g GlueCheckpointsDB) Update(ctx context.Context, checkpointDiffs map[string]*TableCheckpointDiff) error { - logger := log.FromContext(ctx) - se, err := g.getSessionFunc() - if err != nil { - log.FromContext(ctx).Error("can't get a session to update GlueCheckpointsDB", zap.Error(errors.Trace(err))) - return err - } - defer se.Close() - - chunkQuery := fmt.Sprintf(UpdateChunkTemplate, g.schema, CheckpointTableNameChunk) - rebaseQuery := fmt.Sprintf(UpdateTableRebaseTemplate, g.schema, CheckpointTableNameTable) - tableStatusQuery := fmt.Sprintf(UpdateTableStatusTemplate, g.schema, CheckpointTableNameTable) - engineStatusQuery := fmt.Sprintf(UpdateEngineTemplate, g.schema, CheckpointTableNameEngine) - return Transact(context.Background(), "update checkpoints", se, logger, func(c context.Context, s Session) error { - chunkStmt, _, _, err := s.PrepareStmt(chunkQuery) - if err != nil { - return errors.Trace(err) - } - defer dropPreparedStmt(ctx, s, chunkStmt) - rebaseStmt, _, _, err := s.PrepareStmt(rebaseQuery) - if err != nil { - return errors.Trace(err) - } - defer dropPreparedStmt(ctx, s, rebaseStmt) - tableStatusStmt, _, _, err := s.PrepareStmt(tableStatusQuery) - if err != nil { - return errors.Trace(err) - } - defer dropPreparedStmt(ctx, s, tableStatusStmt) - engineStatusStmt, _, _, err := s.PrepareStmt(engineStatusQuery) - if err != nil { - return errors.Trace(err) - } - defer dropPreparedStmt(ctx, s, engineStatusStmt) - - for tableName, cpd := range checkpointDiffs { - if cpd.hasStatus { - _, err := s.ExecutePreparedStmt(c, tableStatusStmt, []types.Datum{ - types.NewUintDatum(uint64(cpd.status)), - types.NewStringDatum(tableName), - }) - if err != nil { - return errors.Trace(err) - } - } - if cpd.hasRebase { - _, err := s.ExecutePreparedStmt(c, rebaseStmt, []types.Datum{ - types.NewIntDatum(cpd.allocBase), - types.NewStringDatum(tableName), - }) - if err != nil { - return 
errors.Trace(err) - } - } - for engineID, engineDiff := range cpd.engines { - if engineDiff.hasStatus { - _, err := s.ExecutePreparedStmt(c, engineStatusStmt, []types.Datum{ - types.NewUintDatum(uint64(engineDiff.status)), - types.NewStringDatum(tableName), - types.NewIntDatum(int64(engineID)), - }) - if err != nil { - return errors.Trace(err) - } - } - for key, diff := range engineDiff.chunks { - columnPerm, err := json.Marshal(diff.columnPermutation) - if err != nil { - return errors.Trace(err) - } - _, err = s.ExecutePreparedStmt(c, chunkStmt, []types.Datum{ - types.NewIntDatum(diff.pos), - types.NewIntDatum(diff.rowID), - types.NewUintDatum(diff.checksum.SumSize()), - types.NewUintDatum(diff.checksum.SumKVS()), - types.NewUintDatum(diff.checksum.Sum()), - types.NewBytesDatum(columnPerm), - types.NewStringDatum(tableName), - types.NewIntDatum(int64(engineID)), - types.NewStringDatum(key.Path), - types.NewIntDatum(key.Offset), - }) - if err != nil { - return errors.Trace(err) - } - } - } - } - return nil - }) -} - -// RemoveCheckpoint implements CheckpointsDB.RemoveCheckpoint. 
-func (g GlueCheckpointsDB) RemoveCheckpoint(ctx context.Context, tableName string) error { - logger := log.FromContext(ctx).With(zap.String("table", tableName)) - se, err := g.getSessionFunc() - if err != nil { - return errors.Trace(err) - } - defer se.Close() - - if tableName == allTables { - return common.Retry("remove all checkpoints", logger, func() error { - _, err := se.Execute(ctx, "DROP SCHEMA "+g.schema) - return err - }) - } - tableName = common.InterpolateMySQLString(tableName) - deleteChunkQuery := fmt.Sprintf(DeleteCheckpointRecordTemplate, g.schema, CheckpointTableNameChunk) - deleteChunkQuery = strings.ReplaceAll(deleteChunkQuery, "?", tableName) - deleteEngineQuery := fmt.Sprintf(DeleteCheckpointRecordTemplate, g.schema, CheckpointTableNameEngine) - deleteEngineQuery = strings.ReplaceAll(deleteEngineQuery, "?", tableName) - deleteTableQuery := fmt.Sprintf(DeleteCheckpointRecordTemplate, g.schema, CheckpointTableNameTable) - deleteTableQuery = strings.ReplaceAll(deleteTableQuery, "?", tableName) - - return errors.Trace(Transact(ctx, "remove checkpoints", se, logger, func(c context.Context, s Session) error { - if _, e := s.Execute(c, deleteChunkQuery); e != nil { - return e - } - if _, e := s.Execute(c, deleteEngineQuery); e != nil { - return e - } - if _, e := s.Execute(c, deleteTableQuery); e != nil { - return e - } - return nil - })) -} - -// MoveCheckpoints implements CheckpointsDB.MoveCheckpoints. 
-func (g GlueCheckpointsDB) MoveCheckpoints(ctx context.Context, taskID int64) error { - newSchema := fmt.Sprintf("`%s.%d.bak`", g.schema[1:len(g.schema)-1], taskID) - logger := log.FromContext(ctx).With(zap.Int64("taskID", taskID)) - se, err := g.getSessionFunc() - if err != nil { - return errors.Trace(err) - } - defer se.Close() - - err = common.Retry("create backup checkpoints schema", logger, func() error { - _, err := se.Execute(ctx, "CREATE SCHEMA IF NOT EXISTS "+newSchema) - return err - }) - if err != nil { - return errors.Trace(err) - } - for _, tbl := range []string{ - CheckpointTableNameChunk, CheckpointTableNameEngine, - CheckpointTableNameTable, CheckpointTableNameTask, - } { - query := fmt.Sprintf("RENAME TABLE %[1]s.%[3]s TO %[2]s.%[3]s", g.schema, newSchema, tbl) - err := common.Retry(fmt.Sprintf("move %s checkpoints table", tbl), logger, func() error { - _, err := se.Execute(ctx, query) - return err - }) - if err != nil { - return errors.Trace(err) - } - } - return nil -} - -// GetLocalStoringTables implements CheckpointsDB.GetLocalStoringTables. -func (g GlueCheckpointsDB) GetLocalStoringTables(ctx context.Context) (map[string][]int32, error) { - se, err := g.getSessionFunc() - if err != nil { - return nil, errors.Trace(err) - } - defer se.Close() - - var targetTables map[string][]int32 - - // lightning didn't check CheckpointStatusMaxInvalid before this function is called, so we skip invalid ones - // engines should exist if - // 1. table status is earlier than CheckpointStatusIndexImported, and - // 2. engine status is earlier than CheckpointStatusImported, and - // 3. 
chunk has been read - query := fmt.Sprintf(` - SELECT DISTINCT t.table_name, c.engine_id - FROM %s.%s t, %s.%s c, %s.%s e - WHERE t.table_name = c.table_name AND t.table_name = e.table_name AND c.engine_id = e.engine_id - AND %d < t.status AND t.status < %d - AND %d < e.status AND e.status < %d - AND c.pos > c.offset;`, - g.schema, CheckpointTableNameTable, g.schema, CheckpointTableNameChunk, g.schema, CheckpointTableNameEngine, - CheckpointStatusMaxInvalid, CheckpointStatusIndexImported, - CheckpointStatusMaxInvalid, CheckpointStatusImported) - - err = common.Retry("get local storing tables", log.FromContext(ctx), func() error { - targetTables = make(map[string][]int32) - rs, err := se.Execute(ctx, query) - if err != nil { - return errors.Trace(err) - } - rows, err := drainFirstRecordSet(ctx, rs) - if err != nil { - return errors.Trace(err) - } - - for _, row := range rows { - tableName := row.GetString(0) - engineID := int32(row.GetInt64(1)) - targetTables[tableName] = append(targetTables[tableName], engineID) - } - return nil - }) - if err != nil { - return nil, errors.Trace(err) - } - - return targetTables, err -} - -// IgnoreErrorCheckpoint implements CheckpointsDB.IgnoreErrorCheckpoint. -func (g GlueCheckpointsDB) IgnoreErrorCheckpoint(ctx context.Context, tableName string) error { - logger := log.FromContext(ctx).With(zap.String("table", tableName)) - se, err := g.getSessionFunc() - if err != nil { - return errors.Trace(err) - } - defer se.Close() - - var colName string - if tableName == allTables { - // This will expand to `WHERE 'all' = 'all'` and effectively allowing - // all tables to be included. 
- colName = stringLitAll - } else { - colName = columnTableName - } - - tableName = common.InterpolateMySQLString(tableName) - - engineQuery := fmt.Sprintf(` - UPDATE %s.%s SET status = %d WHERE %s = %s AND status <= %d; - `, g.schema, CheckpointTableNameEngine, CheckpointStatusLoaded, colName, tableName, CheckpointStatusMaxInvalid) - tableQuery := fmt.Sprintf(` - UPDATE %s.%s SET status = %d WHERE %s = %s AND status <= %d; - `, g.schema, CheckpointTableNameTable, CheckpointStatusLoaded, colName, tableName, CheckpointStatusMaxInvalid) - return errors.Trace(Transact(ctx, "ignore error checkpoints", se, logger, func(c context.Context, s Session) error { - if _, e := s.Execute(c, engineQuery); e != nil { - return e - } - if _, e := s.Execute(c, tableQuery); e != nil { - return e - } - return nil - })) -} - -// DestroyErrorCheckpoint implements CheckpointsDB.DestroyErrorCheckpoint. -func (g GlueCheckpointsDB) DestroyErrorCheckpoint(ctx context.Context, tableName string) ([]DestroyedTableCheckpoint, error) { - logger := log.FromContext(ctx).With(zap.String("table", tableName)) - se, err := g.getSessionFunc() - if err != nil { - return nil, errors.Trace(err) - } - defer se.Close() - - var colName, aliasedColName string - - if tableName == allTables { - // These will expand to `WHERE 'all' = 'all'` and effectively allowing - // all tables to be included. 
- colName = stringLitAll - aliasedColName = stringLitAll - } else { - colName = columnTableName - aliasedColName = "t.table_name" - } - - tableName = common.InterpolateMySQLString(tableName) - - selectQuery := fmt.Sprintf(` - SELECT - t.table_name, - COALESCE(MIN(e.engine_id), 0), - COALESCE(MAX(e.engine_id), -1) - FROM %[1]s.%[4]s t - LEFT JOIN %[1]s.%[5]s e ON t.table_name = e.table_name - WHERE %[2]s = %[6]s AND t.status <= %[3]d - GROUP BY t.table_name; - `, g.schema, aliasedColName, CheckpointStatusMaxInvalid, CheckpointTableNameTable, CheckpointTableNameEngine, tableName) - deleteChunkQuery := fmt.Sprintf(` - DELETE FROM %[1]s.%[4]s WHERE table_name IN (SELECT table_name FROM %[1]s.%[5]s WHERE %[2]s = %[6]s AND status <= %[3]d) - `, g.schema, colName, CheckpointStatusMaxInvalid, CheckpointTableNameChunk, CheckpointTableNameTable, tableName) - deleteEngineQuery := fmt.Sprintf(` - DELETE FROM %[1]s.%[4]s WHERE table_name IN (SELECT table_name FROM %[1]s.%[5]s WHERE %[2]s = %[6]s AND status <= %[3]d) - `, g.schema, colName, CheckpointStatusMaxInvalid, CheckpointTableNameEngine, CheckpointTableNameTable, tableName) - deleteTableQuery := fmt.Sprintf(` - DELETE FROM %s.%s WHERE %s = %s AND status <= %d - `, g.schema, CheckpointTableNameTable, colName, tableName, CheckpointStatusMaxInvalid) - - var targetTables []DestroyedTableCheckpoint - err = Transact(ctx, "destroy error checkpoints", se, logger, func(c context.Context, s Session) error { - // clean because it's in a retry - targetTables = nil - rs, err := s.Execute(c, selectQuery) - if err != nil { - return errors.Trace(err) - } - r := rs[0] - req := r.NewChunk(nil) - it := chunk.NewIterator4Chunk(req) - for { - err = r.Next(ctx, req) - if err != nil { - _ = r.Close() - return err - } - if req.NumRows() == 0 { - break - } - - for row := it.Begin(); row != it.End(); row = it.Next() { - var dtc DestroyedTableCheckpoint - dtc.TableName = row.GetString(0) - dtc.MinEngineID = int32(row.GetInt64(1)) - dtc.MaxEngineID 
= int32(row.GetInt64(2)) - targetTables = append(targetTables, dtc) - } - } - _ = r.Close() - - if _, e := s.Execute(c, deleteChunkQuery); e != nil { - return errors.Trace(e) - } - if _, e := s.Execute(c, deleteEngineQuery); e != nil { - return errors.Trace(e) - } - if _, e := s.Execute(c, deleteTableQuery); e != nil { - return errors.Trace(e) - } - return nil - }) - - if err != nil { - return nil, errors.Trace(err) - } - - return targetTables, nil -} - -// DumpTables implements CheckpointsDB.DumpTables. -func (GlueCheckpointsDB) DumpTables(_ context.Context, _ io.Writer) error { - return errors.Errorf("dumping glue checkpoint into CSV not unsupported") -} - -// DumpEngines implements CheckpointsDB.DumpEngines. -func (GlueCheckpointsDB) DumpEngines(_ context.Context, _ io.Writer) error { - return errors.Errorf("dumping glue checkpoint into CSV not unsupported") -} - -// DumpChunks implements CheckpointsDB.DumpChunks. -func (GlueCheckpointsDB) DumpChunks(_ context.Context, _ io.Writer) error { - return errors.Errorf("dumping glue checkpoint into CSV not unsupported") -} - -// Transact is a helper function to execute a transaction. 
-func Transact(ctx context.Context, purpose string, s Session, logger log.Logger, action func(context.Context, Session) error) error { - return common.Retry(purpose, logger, func() error { - _, err := s.Execute(ctx, "BEGIN") - if err != nil { - return errors.Annotate(err, "begin transaction failed") - } - err = action(ctx, s) - if err != nil { - s.RollbackTxn(ctx) - return err - } - err = s.CommitTxn(ctx) - if err != nil { - return errors.Annotate(err, "commit transaction failed") - } - return nil - }) -} - -// TODO: will use drainFirstRecordSet to reduce repeat in GlueCheckpointsDB later -func drainFirstRecordSet(ctx context.Context, rss []sqlexec.RecordSet) ([]chunk.Row, error) { - if len(rss) != 1 { - return nil, errors.New("given result set doesn't have length 1") - } - rs := rss[0] - var rows []chunk.Row - req := rs.NewChunk(nil) - for { - err := rs.Next(ctx, req) - if err != nil || req.NumRows() == 0 { - _ = rs.Close() - return rows, err - } - iter := chunk.NewIterator4Chunk(req) - for r := iter.Begin(); r != iter.End(); r = iter.Next() { - rows = append(rows, r) - } - req = chunk.Renew(req, 1024) - } -} diff --git a/br/pkg/lightning/common/util.go b/br/pkg/lightning/common/util.go index 39a0c8f9fa363..45c4ed9330551 100644 --- a/br/pkg/lightning/common/util.go +++ b/br/pkg/lightning/common/util.go @@ -313,6 +313,26 @@ func UniqueTable(schema string, table string) string { return builder.String() } +func escapeIdentifiers(identifier []string) []any { + escaped := make([]any, len(identifier)) + for i, id := range identifier { + escaped[i] = EscapeIdentifier(id) + } + return escaped +} + +// SprintfWithIdentifiers escapes the identifiers and sprintf them. The input +// identifiers must not be escaped. +func SprintfWithIdentifiers(format string, identifiers ...string) string { + return fmt.Sprintf(format, escapeIdentifiers(identifiers)...) +} + +// FprintfWithIdentifiers escapes the identifiers and fprintf them. The input +// identifiers must not be escaped. 
+func FprintfWithIdentifiers(w io.Writer, format string, identifiers ...string) (int, error) { + return fmt.Fprintf(w, format, escapeIdentifiers(identifiers)...) +} + // EscapeIdentifier quote and escape an sql identifier func EscapeIdentifier(identifier string) string { var builder strings.Builder @@ -525,11 +545,11 @@ loop: } // BuildDropIndexSQL builds the SQL statement to drop index. -func BuildDropIndexSQL(tableName string, idxInfo *model.IndexInfo) string { +func BuildDropIndexSQL(dbName, tableName string, idxInfo *model.IndexInfo) string { if idxInfo.Primary { - return fmt.Sprintf("ALTER TABLE %s DROP PRIMARY KEY", tableName) + return SprintfWithIdentifiers("ALTER TABLE %s.%s DROP PRIMARY KEY", dbName, tableName) } - return fmt.Sprintf("ALTER TABLE %s DROP INDEX %s", tableName, EscapeIdentifier(idxInfo.Name.O)) + return SprintfWithIdentifiers("ALTER TABLE %s.%s DROP INDEX %s", dbName, tableName, idxInfo.Name.O) } // BuildAddIndexSQL builds the SQL statement to create missing indexes. 
diff --git a/br/pkg/lightning/errormanager/errormanager.go b/br/pkg/lightning/errormanager/errormanager.go index 80d1ab32df5a6..b92785762b4d8 100644 --- a/br/pkg/lightning/errormanager/errormanager.go +++ b/br/pkg/lightning/errormanager/errormanager.go @@ -170,7 +170,7 @@ const ( type ErrorManager struct { db *sql.DB taskID int64 - schemaEscaped string + schema string configError *config.MaxError remainingError config.MaxError @@ -229,7 +229,7 @@ func New(db *sql.DB, cfg *config.Config, logger log.Logger) *ErrorManager { } if len(cfg.App.TaskInfoSchemaName) != 0 { em.db = db - em.schemaEscaped = common.EscapeIdentifier(cfg.App.TaskInfoSchemaName) + em.schema = cfg.App.TaskInfoSchemaName } return em } @@ -267,7 +267,7 @@ func (em *ErrorManager) Init(ctx context.Context) error { for _, sql := range sqls { // trim spaces for unit test pattern matching - err := exec.Exec(ctx, sql[0], strings.TrimSpace(fmt.Sprintf(sql[1], em.schemaEscaped))) + err := exec.Exec(ctx, sql[0], strings.TrimSpace(common.SprintfWithIdentifiers(sql[1], em.schema))) if err != nil { return err } @@ -312,7 +312,7 @@ func (em *ErrorManager) RecordTypeError( HideQueryLog: redact.NeedRedact(), } if err := exec.Exec(ctx, "insert type error record", - fmt.Sprintf(insertIntoTypeError, em.schemaEscaped), + common.SprintfWithIdentifiers(insertIntoTypeError, em.schema), em.taskID, tableName, path, @@ -366,7 +366,10 @@ func (em *ErrorManager) RecordDataConflictError( } if err := exec.Transact(ctx, "insert data conflict error record", func(c context.Context, txn *sql.Tx) error { sb := &strings.Builder{} - fmt.Fprintf(sb, insertIntoConflictErrorData, em.schemaEscaped) + _, err := common.FprintfWithIdentifiers(sb, insertIntoConflictErrorData, em.schema) + if err != nil { + return err + } var sqlArgs []interface{} for i, conflictInfo := range conflictInfos { if i > 0 { @@ -383,7 +386,7 @@ func (em *ErrorManager) RecordDataConflictError( tablecodec.IsRecordKey(conflictInfo.RawKey), ) } - _, err := 
txn.ExecContext(c, sb.String(), sqlArgs...) + _, err = txn.ExecContext(c, sb.String(), sqlArgs...) return err }); err != nil { gerr = err @@ -425,7 +428,10 @@ func (em *ErrorManager) RecordIndexConflictError( } if err := exec.Transact(ctx, "insert index conflict error record", func(c context.Context, txn *sql.Tx) error { sb := &strings.Builder{} - fmt.Fprintf(sb, insertIntoConflictErrorIndex, em.schemaEscaped) + _, err := common.FprintfWithIdentifiers(sb, insertIntoConflictErrorIndex, em.schema) + if err != nil { + return err + } var sqlArgs []interface{} for i, conflictInfo := range conflictInfos { if i > 0 { @@ -445,7 +451,7 @@ func (em *ErrorManager) RecordIndexConflictError( tablecodec.IsRecordKey(conflictInfo.RawKey), ) } - _, err := txn.ExecContext(c, sb.String(), sqlArgs...) + _, err = txn.ExecContext(c, sb.String(), sqlArgs...) return err }); err != nil { gerr = err @@ -487,7 +493,7 @@ func (em *ErrorManager) RemoveAllConflictKeys( var handleRows [][2][]byte for start < end { rows, err := em.db.QueryContext( - gCtx, fmt.Sprintf(selectConflictKeysRemove, em.schemaEscaped), + gCtx, common.SprintfWithIdentifiers(selectConflictKeysRemove, em.schema), tableName, start, end, rowLimit) if err != nil { return errors.Trace(err) @@ -567,7 +573,7 @@ func (em *ErrorManager) ReplaceConflictKeys( // demo for "replace" algorithm: https://github.com/lyzx2001/tidb-conflict-replace // check index KV indexKvRows, err := em.db.QueryContext( - gCtx, fmt.Sprintf(selectIndexConflictKeysReplace, em.schemaEscaped), + gCtx, common.SprintfWithIdentifiers(selectIndexConflictKeysReplace, em.schema), tableName) if err != nil { return errors.Trace(err) @@ -666,7 +672,10 @@ func (em *ErrorManager) ReplaceConflictKeys( if err := exec.Transact(ctx, "insert data conflict error record for conflict detection 'replace' mode", func(c context.Context, txn *sql.Tx) error { sb := &strings.Builder{} - fmt.Fprintf(sb, insertIntoConflictErrorData, em.schemaEscaped) + _, err2 := 
common.FprintfWithIdentifiers(sb, insertIntoConflictErrorData, em.schema) + if err2 != nil { + return err2 + } var sqlArgs []interface{} sb.WriteString(sqlValuesConflictErrorData) sqlArgs = append(sqlArgs, @@ -696,7 +705,7 @@ func (em *ErrorManager) ReplaceConflictKeys( // check data KV dataKvRows, err := em.db.QueryContext( - gCtx, fmt.Sprintf(selectDataConflictKeysReplace, em.schemaEscaped), + gCtx, common.SprintfWithIdentifiers(selectDataConflictKeysReplace, em.schema), tableName) if err != nil { return errors.Trace(err) @@ -872,7 +881,7 @@ func (em *ErrorManager) recordDuplicate( HideQueryLog: redact.NeedRedact(), } return exec.Exec(ctx, "insert duplicate record", - fmt.Sprintf(insertIntoDupRecord, em.schemaEscaped), + common.SprintfWithIdentifiers(insertIntoDupRecord, em.schema), em.taskID, tableName, path, @@ -974,7 +983,7 @@ func (em *ErrorManager) LogErrorDetails() { } func (em *ErrorManager) fmtTableName(t string) string { - return fmt.Sprintf("%s.`%s`", em.schemaEscaped, t) + return common.UniqueTable(em.schema, t) } // Output renders a table which contains error summery for each error type. 
diff --git a/br/pkg/lightning/errormanager/errormanager_test.go b/br/pkg/lightning/errormanager/errormanager_test.go index 0cfa3345de1bc..5745cbfb5de70 100644 --- a/br/pkg/lightning/errormanager/errormanager_test.go +++ b/br/pkg/lightning/errormanager/errormanager_test.go @@ -623,7 +623,7 @@ func TestErrorMgrErrorOutput(t *testing.T) { remainingError: cfg.App.MaxError, configConflict: &cfg.Conflict, conflictErrRemain: atomic.NewInt64(100), - schemaEscaped: "`error_info`", + schema: "error_info", conflictV1Enabled: true, } diff --git a/br/pkg/lightning/importer/get_pre_info.go b/br/pkg/lightning/importer/get_pre_info.go index ed02a28ee14a0..114bd642b36a3 100644 --- a/br/pkg/lightning/importer/get_pre_info.go +++ b/br/pkg/lightning/importer/get_pre_info.go @@ -196,7 +196,7 @@ func (g *TargetInfoGetterImpl) IsTableEmpty(ctx context.Context, schemaName stri // the data is partially imported, but the index data has not been imported. // In this situation, if no hint is added, the SQL executor might fetch the record from index, // which is empty. This will result in missing check. 
- fmt.Sprintf("SELECT 1 FROM %s USE INDEX() LIMIT 1", common.UniqueTable(schemaName, tableName)), + common.SprintfWithIdentifiers("SELECT 1 FROM %s.%s USE INDEX() LIMIT 1", schemaName, tableName), &dump, ) diff --git a/br/pkg/lightning/importer/import.go b/br/pkg/lightning/importer/import.go index caea78d2a04ef..57b80ae1f16d8 100644 --- a/br/pkg/lightning/importer/import.go +++ b/br/pkg/lightning/importer/import.go @@ -92,7 +92,7 @@ const ( TaskMetaTableName = "task_meta_v2" TableMetaTableName = "table_meta" // CreateTableMetadataTable stores the per-table sub jobs information used by TiDB Lightning - CreateTableMetadataTable = `CREATE TABLE IF NOT EXISTS %s ( + CreateTableMetadataTable = `CREATE TABLE IF NOT EXISTS %s.%s ( task_id BIGINT(20) UNSIGNED, table_id BIGINT(64) NOT NULL, table_name VARCHAR(64) NOT NULL, @@ -109,7 +109,7 @@ const ( PRIMARY KEY (table_id, task_id) );` // CreateTaskMetaTable stores the pre-lightning metadata used by TiDB Lightning - CreateTaskMetaTable = `CREATE TABLE IF NOT EXISTS %s ( + CreateTaskMetaTable = `CREATE TABLE IF NOT EXISTS %s.%s ( task_id BIGINT(20) UNSIGNED NOT NULL, pd_cfgs VARCHAR(2048) NOT NULL DEFAULT '', status VARCHAR(32) NOT NULL, diff --git a/br/pkg/lightning/importer/meta_manager.go b/br/pkg/lightning/importer/meta_manager.go index 98b68e3090777..61c34dadc93d3 100644 --- a/br/pkg/lightning/importer/meta_manager.go +++ b/br/pkg/lightning/importer/meta_manager.go @@ -43,15 +43,15 @@ func (b *dbMetaMgrBuilder) Init(ctx context.Context) error { Logger: log.FromContext(ctx), HideQueryLog: redact.NeedRedact(), } - metaDBSQL := fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", common.EscapeIdentifier(b.schema)) + metaDBSQL := common.SprintfWithIdentifiers("CREATE DATABASE IF NOT EXISTS %s", b.schema) if err := exec.Exec(ctx, "create meta schema", metaDBSQL); err != nil { return errors.Annotate(err, "create meta schema failed") } - taskMetaSQL := fmt.Sprintf(CreateTaskMetaTable, common.UniqueTable(b.schema, 
TaskMetaTableName)) + taskMetaSQL := common.SprintfWithIdentifiers(CreateTaskMetaTable, b.schema, TaskMetaTableName) if err := exec.Exec(ctx, "create meta table", taskMetaSQL); err != nil { return errors.Annotate(err, "create task meta table failed") } - tableMetaSQL := fmt.Sprintf(CreateTableMetadataTable, common.UniqueTable(b.schema, TableMetaTableName)) + tableMetaSQL := common.SprintfWithIdentifiers(CreateTableMetadataTable, b.schema, TableMetaTableName) if err := exec.Exec(ctx, "create meta table", tableMetaSQL); err != nil { return errors.Annotate(err, "create table meta table failed") } @@ -63,7 +63,7 @@ func (b *dbMetaMgrBuilder) TaskMetaMgr(pd *pdutil.PdController) taskMetaMgr { session: b.db, taskID: b.taskID, pd: pd, - tableName: common.UniqueTable(b.schema, TaskMetaTableName), + tableName: TaskMetaTableName, schemaName: b.schema, } } @@ -73,7 +73,8 @@ func (b *dbMetaMgrBuilder) TableMetaMgr(tr *TableImporter) tableMetaMgr { session: b.db, taskID: b.taskID, tr: tr, - tableName: common.UniqueTable(b.schema, TableMetaTableName), + schemaName: b.schema, + tableName: TableMetaTableName, needChecksum: b.needChecksum, } } @@ -92,6 +93,7 @@ type dbTableMetaMgr struct { session *sql.DB taskID int64 tr *TableImporter + schemaName string tableName string needChecksum bool } @@ -102,7 +104,7 @@ func (m *dbTableMetaMgr) InitTableMeta(ctx context.Context) error { Logger: m.tr.logger, } // avoid override existing metadata if the meta is already inserted. 
- stmt := fmt.Sprintf(`INSERT IGNORE INTO %s (task_id, table_id, table_name, status) values (?, ?, ?, ?)`, m.tableName) + stmt := common.SprintfWithIdentifiers(`INSERT IGNORE INTO %s.%s (task_id, table_id, table_name, status) VALUES (?, ?, ?, ?)`, m.schemaName, m.tableName) task := m.tr.logger.Begin(zap.DebugLevel, "init table meta") err := exec.Exec(ctx, "init table meta", stmt, m.taskID, m.tr.tableInfo.ID, m.tr.tableName, metaStatusInitial.String()) task.End(zap.ErrorLevel, err) @@ -189,7 +191,7 @@ func (m *dbTableMetaMgr) AllocTableRowIDs(ctx context.Context, rawRowIDMax int64 return exec.Transact(ctx, "init table allocator base", func(ctx context.Context, tx *sql.Tx) error { rows, err := tx.QueryContext( ctx, - fmt.Sprintf("SELECT task_id, row_id_base, row_id_max, total_kvs_base, total_bytes_base, checksum_base, status from %s WHERE table_id = ? FOR UPDATE", m.tableName), + common.SprintfWithIdentifiers("SELECT task_id, row_id_base, row_id_max, total_kvs_base, total_bytes_base, checksum_base, status FROM %s.%s WHERE table_id = ? FOR UPDATE", m.schemaName, m.tableName), m.tr.tableInfo.ID, ) if err != nil { @@ -270,7 +272,7 @@ func (m *dbTableMetaMgr) AllocTableRowIDs(ctx context.Context, rawRowIDMax int64 newStatus = metaStatusRestoreStarted } - query := fmt.Sprintf("update %s set row_id_base = ?, row_id_max = ?, status = ? where table_id = ? and task_id = ?", m.tableName) + query := common.SprintfWithIdentifiers("UPDATE %s.%s SET row_id_base = ?, row_id_max = ?, status = ? WHERE table_id = ? AND task_id = ?", m.schemaName, m.tableName) _, err := tx.ExecContext(ctx, query, newRowIDBase, newRowIDMax, newStatus.String(), m.tr.tableInfo.ID, m.taskID) if err != nil { return errors.Trace(err) @@ -350,7 +352,7 @@ func (m *dbTableMetaMgr) UpdateTableBaseChecksum(ctx context.Context, checksum * DB: m.session, Logger: m.tr.logger, } - query := fmt.Sprintf("update %s set total_kvs_base = ?, total_bytes_base = ?, checksum_base = ?, status = ? where table_id = ? 
and task_id = ?", m.tableName) + query := common.SprintfWithIdentifiers("UPDATE %s.%s SET total_kvs_base = ?, total_bytes_base = ?, checksum_base = ?, status = ? WHERE table_id = ? AND task_id = ?", m.schemaName, m.tableName) return exec.Exec(ctx, "update base checksum", query, checksum.SumKVS(), checksum.SumSize(), checksum.Sum(), metaStatusRestoreStarted.String(), m.tr.tableInfo.ID, m.taskID) @@ -361,7 +363,7 @@ func (m *dbTableMetaMgr) UpdateTableStatus(ctx context.Context, status metaStatu DB: m.session, Logger: m.tr.logger, } - query := fmt.Sprintf("update %s set status = ? where table_id = ? and task_id = ?", m.tableName) + query := common.SprintfWithIdentifiers("UPDATE %s.%s SET status = ? WHERE table_id = ? AND task_id = ?", m.schemaName, m.tableName) return exec.Exec(ctx, "update meta status", query, status.String(), m.tr.tableInfo.ID, m.taskID) } @@ -394,7 +396,7 @@ func (m *dbTableMetaMgr) CheckAndUpdateLocalChecksum(ctx context.Context, checks err = exec.Transact(ctx, "checksum pre-check", func(ctx context.Context, tx *sql.Tx) error { rows, err := tx.QueryContext( ctx, - fmt.Sprintf("SELECT task_id, total_kvs_base, total_bytes_base, checksum_base, total_kvs, total_bytes, checksum, status, has_duplicates from %s WHERE table_id = ? FOR UPDATE", m.tableName), + common.SprintfWithIdentifiers("SELECT task_id, total_kvs_base, total_bytes_base, checksum_base, total_kvs, total_bytes, checksum, status, has_duplicates from %s.%s WHERE table_id = ? 
FOR UPDATE", m.schemaName, m.tableName), m.tr.tableInfo.ID, ) if err != nil { @@ -441,7 +443,7 @@ func (m *dbTableMetaMgr) CheckAndUpdateLocalChecksum(ctx context.Context, checks needRemoteDupe = false break } else if status == metaStatusChecksuming { - return common.ErrTableIsChecksuming.GenWithStackByArgs(m.tableName) + return common.ErrTableIsChecksuming.GenWithStackByArgs(common.UniqueTable(m.schemaName, m.tableName)) } totalBytes += baseTotalBytes @@ -458,7 +460,7 @@ func (m *dbTableMetaMgr) CheckAndUpdateLocalChecksum(ctx context.Context, checks return errors.Trace(err) } - query := fmt.Sprintf("update %s set total_kvs = ?, total_bytes = ?, checksum = ?, status = ?, has_duplicates = ? where table_id = ? and task_id = ?", m.tableName) + query := common.SprintfWithIdentifiers("UPDATE %s.%s SET total_kvs = ?, total_bytes = ?, checksum = ?, status = ?, has_duplicates = ? WHERE table_id = ? AND task_id = ?", m.schemaName, m.tableName) _, err = tx.ExecContext(ctx, query, checksum.SumKVS(), checksum.SumSize(), checksum.Sum(), newStatus.String(), hasLocalDupes, m.tr.tableInfo.ID, m.taskID) return errors.Annotate(err, "update local checksum failed") }) @@ -481,7 +483,7 @@ func (m *dbTableMetaMgr) FinishTable(ctx context.Context) error { DB: m.session, Logger: m.tr.logger, } - query := fmt.Sprintf("DELETE FROM %s where table_id = ? and (status = 'checksuming' or status = 'checksum_skipped')", m.tableName) + query := common.SprintfWithIdentifiers("DELETE FROM %s.%s where table_id = ? 
and (status = 'checksuming' or status = 'checksum_skipped')", m.schemaName, m.tableName) return exec.Exec(ctx, "clean up metas", query, m.tr.tableInfo.ID) } @@ -522,10 +524,9 @@ type taskMetaMgr interface { } type dbTaskMetaMgr struct { - session *sql.DB - taskID int64 - pd *pdutil.PdController - // unique name of task meta table + session *sql.DB + taskID int64 + pd *pdutil.PdController tableName string schemaName string } @@ -596,7 +597,10 @@ func (m *dbTaskMetaMgr) InitTask(ctx context.Context, tikvSourceSize, tiflashSou Logger: log.FromContext(ctx), } // avoid override existing metadata if the meta is already inserted. - stmt := fmt.Sprintf(`INSERT INTO %s (task_id, status, tikv_source_bytes, tiflash_source_bytes) values (?, ?, ?, ?) ON DUPLICATE KEY UPDATE state = ?`, m.tableName) + stmt := common.SprintfWithIdentifiers(` + INSERT INTO %s.%s (task_id, status, tikv_source_bytes, tiflash_source_bytes) + VALUES (?, ?, ?, ?) ON DUPLICATE KEY UPDATE state = ?`, + m.schemaName, m.tableName) err := exec.Exec(ctx, "init task meta", stmt, m.taskID, taskMetaStatusInitial.String(), tikvSourceSize, tiflashSourceSize, taskStateNormal) return errors.Trace(err) } @@ -610,7 +614,7 @@ func (m *dbTaskMetaMgr) CheckTaskExist(ctx context.Context) (bool, error) { exist := false err := exec.Transact(ctx, "check whether this task has started before", func(ctx context.Context, tx *sql.Tx) error { rows, err := tx.QueryContext(ctx, - fmt.Sprintf("SELECT task_id from %s WHERE task_id = ?", m.tableName), + common.SprintfWithIdentifiers("SELECT task_id from %s.%s WHERE task_id = ?", m.schemaName, m.tableName), m.taskID, ) if err != nil { @@ -656,7 +660,17 @@ func (m *dbTaskMetaMgr) CheckTasksExclusively(ctx context.Context, action func(t return exec.Transact(ctx, "check tasks exclusively", func(ctx context.Context, tx *sql.Tx) error { rows, err := tx.QueryContext( ctx, - fmt.Sprintf("SELECT task_id, pd_cfgs, status, state, tikv_source_bytes, tiflash_source_bytes, tikv_avail, tiflash_avail 
from %s FOR UPDATE", m.tableName), + common.SprintfWithIdentifiers(` + SELECT + task_id, + pd_cfgs, + status, + state, + tikv_source_bytes, + tiflash_source_bytes, + tikv_avail, + tiflash_avail + FROM %s.%s FOR UPDATE`, m.schemaName, m.tableName), ) if err != nil { return errors.Annotate(err, "fetch task metas failed") @@ -685,7 +699,10 @@ func (m *dbTaskMetaMgr) CheckTasksExclusively(ctx context.Context, action func(t return errors.Trace(err) } for _, task := range newTasks { - query := fmt.Sprintf("REPLACE INTO %s (task_id, pd_cfgs, status, state, tikv_source_bytes, tiflash_source_bytes, tikv_avail, tiflash_avail) VALUES(?, ?, ?, ?, ?, ?, ?, ?)", m.tableName) + query := common.SprintfWithIdentifiers(` + REPLACE INTO %s.%s (task_id, pd_cfgs, status, state, tikv_source_bytes, tiflash_source_bytes, tikv_avail, tiflash_avail) + VALUES(?, ?, ?, ?, ?, ?, ?, ?)`, + m.schemaName, m.tableName) if _, err = tx.ExecContext(ctx, query, task.taskID, task.pdCfgs, task.status.String(), task.state, task.tikvSourceBytes, task.tiflashSourceBytes, task.tikvAvail, task.tiflashAvail); err != nil { return errors.Trace(err) } @@ -719,7 +736,10 @@ func (m *dbTaskMetaMgr) CheckAndPausePdSchedulers(ctx context.Context) (pdutil.U err = exec.Transact(ctx, "check and pause schedulers", func(ctx context.Context, tx *sql.Tx) error { rows, err := tx.QueryContext( ctx, - fmt.Sprintf("SELECT task_id, pd_cfgs, status, state from %s FOR UPDATE", m.tableName), + common.SprintfWithIdentifiers(` + SELECT task_id, pd_cfgs, status, state + FROM %s.%s FOR UPDATE`, + m.schemaName, m.tableName), ) if err != nil { return errors.Annotate(err, "fetch task meta failed") @@ -793,7 +813,9 @@ func (m *dbTaskMetaMgr) CheckAndPausePdSchedulers(ctx context.Context) (pdutil.U return errors.Trace(err) } - query := fmt.Sprintf("update %s set pd_cfgs = ?, status = ? where task_id = ?", m.tableName) + query := common.SprintfWithIdentifiers(` + UPDATE %s.%s SET pd_cfgs = ?, status = ? 
WHERE task_id = ?`, + m.schemaName, m.tableName) _, err = tx.ExecContext(ctx, query, string(jsonByts), taskMetaStatusScheduleSet.String(), m.taskID) return errors.Annotate(err, "update task pd configs failed") @@ -850,7 +872,10 @@ func (m *dbTaskMetaMgr) CheckAndFinishRestore(ctx context.Context, finished bool switchBack = true allFinished = finished err = exec.Transact(ctx, "check and finish schedulers", func(ctx context.Context, tx *sql.Tx) error { - rows, err := tx.QueryContext(ctx, fmt.Sprintf("SELECT task_id, status, state from %s FOR UPDATE", m.tableName)) + rows, err := tx.QueryContext( + ctx, + common.SprintfWithIdentifiers("SELECT task_id, status, state FROM %s.%s FOR UPDATE", m.schemaName, m.tableName), + ) if err != nil { return errors.Annotate(err, "fetch task meta failed") } @@ -910,7 +935,7 @@ func (m *dbTaskMetaMgr) CheckAndFinishRestore(ctx context.Context, finished bool newStatus = taskMetaStatusSwitchSkipped } - query := fmt.Sprintf("update %s set status = ?, state = ? where task_id = ?", m.tableName) + query := common.SprintfWithIdentifiers("UPDATE %s.%s SET status = ?, state = ? WHERE task_id = ?", m.schemaName, m.tableName) if _, err = tx.ExecContext(ctx, query, newStatus.String(), newState, m.taskID); err != nil { return errors.Trace(err) } @@ -930,7 +955,7 @@ func (m *dbTaskMetaMgr) Cleanup(ctx context.Context) error { Logger: log.FromContext(ctx), } // avoid override existing metadata if the meta is already inserted. 
- stmt := fmt.Sprintf("DROP TABLE %s;", m.tableName) + stmt := common.SprintfWithIdentifiers("DROP TABLE %s.%s;", m.schemaName, m.tableName) if err := exec.Exec(ctx, "cleanup task meta tables", stmt); err != nil { return errors.Trace(err) } @@ -942,8 +967,8 @@ func (m *dbTaskMetaMgr) CleanupTask(ctx context.Context) error { DB: m.session, Logger: log.FromContext(ctx), } - stmt := fmt.Sprintf("DELETE FROM %s WHERE task_id = %d;", m.tableName, m.taskID) - err := exec.Exec(ctx, "clean up task", stmt) + stmt := common.SprintfWithIdentifiers("DELETE FROM %s.%s WHERE task_id = ?;", m.schemaName, m.tableName) + err := exec.Exec(ctx, "clean up task", stmt, m.taskID) return errors.Trace(err) } @@ -970,7 +995,7 @@ func MaybeCleanupAllMetas( // check if all tables are finished if tableMetaExist { - query := fmt.Sprintf("SELECT COUNT(*) from %s", common.UniqueTable(schemaName, TableMetaTableName)) + query := common.SprintfWithIdentifiers("SELECT COUNT(*) from %s.%s", schemaName, TableMetaTableName) var cnt int if err := exec.QueryRow(ctx, "fetch table meta row count", query, &cnt); err != nil { return errors.Trace(err) @@ -982,7 +1007,7 @@ func MaybeCleanupAllMetas( } // avoid override existing metadata if the meta is already inserted. 
- stmt := fmt.Sprintf("DROP DATABASE %s;", common.EscapeIdentifier(schemaName)) + stmt := common.SprintfWithIdentifiers("DROP DATABASE %s;", schemaName) if err := exec.Exec(ctx, "cleanup task meta tables", stmt); err != nil { return errors.Trace(err) } diff --git a/br/pkg/lightning/importer/meta_manager_test.go b/br/pkg/lightning/importer/meta_manager_test.go index b7cd5e104643e..847da5452402b 100644 --- a/br/pkg/lightning/importer/meta_manager_test.go +++ b/br/pkg/lightning/importer/meta_manager_test.go @@ -99,7 +99,8 @@ func newMetaMgrSuite(t *testing.T) *metaMgrSuite { taskID: 1, tr: newTableRestore(t, "test", "t1", 1, 1, "CREATE TABLE `t1` (`c1` varchar(5) NOT NULL)", kvStore), - tableName: common.UniqueTable("test", TableMetaTableName), + schemaName: "test", + tableName: TableMetaTableName, needChecksum: true, } s.mockDB = m @@ -279,7 +280,7 @@ func TestAllocTableRowIDsRetryOnTableInChecksum(t *testing.T) { s.mockDB.ExpectExec("SET SESSION tidb_txn_mode = 'pessimistic';"). WillReturnResult(sqlmock.NewResult(int64(0), int64(0))) s.mockDB.ExpectBegin() - s.mockDB.ExpectQuery("\\QSELECT task_id, row_id_base, row_id_max, total_kvs_base, total_bytes_base, checksum_base, status from `test`.`table_meta` WHERE table_id = ? FOR UPDATE\\E"). + s.mockDB.ExpectQuery("\\QSELECT task_id, row_id_base, row_id_max, total_kvs_base, total_bytes_base, checksum_base, status FROM `test`.`table_meta` WHERE table_id = ? FOR UPDATE\\E"). WithArgs(int64(1)). WillReturnError(errors.New("mock err")) s.mockDB.ExpectRollback() @@ -321,7 +322,7 @@ func (s *metaMgrSuite) prepareMockInner(rowsVal [][]driver.Value, nextRowID *int for _, r := range rowsVal { rows = rows.AddRow(r...) } - s.mockDB.ExpectQuery("\\QSELECT task_id, row_id_base, row_id_max, total_kvs_base, total_bytes_base, checksum_base, status from `test`.`table_meta` WHERE table_id = ? FOR UPDATE\\E"). 
+ s.mockDB.ExpectQuery("\\QSELECT task_id, row_id_base, row_id_max, total_kvs_base, total_bytes_base, checksum_base, status FROM `test`.`table_meta` WHERE table_id = ? FOR UPDATE\\E"). WithArgs(int64(1)). WillReturnRows(rows) @@ -332,7 +333,7 @@ func (s *metaMgrSuite) prepareMockInner(rowsVal [][]driver.Value, nextRowID *int } if len(updateArgs) > 0 { - s.mockDB.ExpectExec("\\Qupdate `test`.`table_meta` set row_id_base = ?, row_id_max = ?, status = ? where table_id = ? and task_id = ?\\E"). + s.mockDB.ExpectExec("\\QUPDATE `test`.`table_meta` SET row_id_base = ?, row_id_max = ?, status = ? WHERE table_id = ? AND task_id = ?\\E"). WithArgs(updateArgs...). WillReturnResult(sqlmock.NewResult(int64(0), int64(1))) } @@ -345,7 +346,7 @@ func (s *metaMgrSuite) prepareMockInner(rowsVal [][]driver.Value, nextRowID *int s.mockDB.ExpectCommit() if checksum != nil { - s.mockDB.ExpectExec("\\Qupdate `test`.`table_meta` set total_kvs_base = ?, total_bytes_base = ?, checksum_base = ?, status = ? where table_id = ? and task_id = ?\\E"). + s.mockDB.ExpectExec("\\QUPDATE `test`.`table_meta` SET total_kvs_base = ?, total_bytes_base = ?, checksum_base = ?, status = ? WHERE table_id = ? AND task_id = ?\\E"). WithArgs(checksum.SumKVS(), checksum.SumSize(), checksum.Sum(), metaStatusRestoreStarted.String(), int64(1), int64(1)). WillReturnResult(sqlmock.NewResult(int64(0), int64(1))) s.checksumMgr.checksum = local.RemoteChecksum{ @@ -356,7 +357,7 @@ func (s *metaMgrSuite) prepareMockInner(rowsVal [][]driver.Value, nextRowID *int } if updateStatus != nil { - s.mockDB.ExpectExec("\\Qupdate `test`.`table_meta` set status = ? where table_id = ? and task_id = ?\\E"). + s.mockDB.ExpectExec("\\QUPDATE `test`.`table_meta` SET status = ? WHERE table_id = ? AND task_id = ?\\E"). WithArgs(*updateStatus, int64(1), int64(1)). 
WillReturnResult(sqlmock.NewResult(int64(0), int64(1))) } @@ -373,9 +374,10 @@ func newTaskMetaMgrSuite(t *testing.T) *taskMetaMgrSuite { var s taskMetaMgrSuite s.mgr = &dbTaskMetaMgr{ - session: db, - taskID: 1, - tableName: common.UniqueTable("test", "t1"), + session: db, + taskID: 1, + tableName: "t1", + schemaName: "test", } s.mockDB = m return &s @@ -386,7 +388,7 @@ func TestCheckTasksExclusively(t *testing.T) { s.mockDB.ExpectExec("SET SESSION tidb_txn_mode = 'pessimistic';"). WillReturnResult(sqlmock.NewResult(int64(0), int64(0))) s.mockDB.ExpectBegin() - s.mockDB.ExpectQuery("SELECT task_id, pd_cfgs, status, state, tikv_source_bytes, tiflash_source_bytes, tikv_avail, tiflash_avail from `test`.`t1` FOR UPDATE"). + s.mockDB.ExpectQuery("SELECT task_id, pd_cfgs, status, state, tikv_source_bytes, tiflash_source_bytes, tikv_avail, tiflash_avail FROM `test`.`t1` FOR UPDATE"). WillReturnRows(sqlmock.NewRows([]string{"task_id", "pd_cfgs", "status", "state", "tikv_source_bytes", "tiflash_source_bytes", "tiflash_avail", "tiflash_avail"}). AddRow("0", "", taskMetaStatusInitial.String(), "0", "0", "0", "0", "0"). AddRow("1", "", taskMetaStatusInitial.String(), "0", "0", "0", "0", "0"). 
diff --git a/br/pkg/lightning/importer/table_import.go b/br/pkg/lightning/importer/table_import.go index 469f1360cec06..7f13c9d1cc62b 100644 --- a/br/pkg/lightning/importer/table_import.go +++ b/br/pkg/lightning/importer/table_import.go @@ -1393,10 +1393,9 @@ func (tr *TableImporter) dropIndexes(ctx context.Context, db *sql.DB) error { logger := log.FromContext(ctx).With(zap.String("table", tr.tableName)) tblInfo := tr.tableInfo - tableName := common.UniqueTable(tblInfo.DB, tblInfo.Name) remainIndexes, dropIndexes := common.GetDropIndexInfos(tblInfo.Core) for _, idxInfo := range dropIndexes { - sqlStr := common.BuildDropIndexSQL(tableName, idxInfo) + sqlStr := common.BuildDropIndexSQL(tblInfo.DB, tblInfo.Name, idxInfo) logger.Info("drop index", zap.String("sql", sqlStr)) diff --git a/br/pkg/lightning/mydump/loader.go b/br/pkg/lightning/mydump/loader.go index 630a40e015a24..3b3b95eadc8f4 100644 --- a/br/pkg/lightning/mydump/loader.go +++ b/br/pkg/lightning/mydump/loader.go @@ -69,7 +69,7 @@ func (m *MDDatabaseMeta) GetSchema(ctx context.Context, store storage.ExternalSt } } // set default if schema sql is empty or failed to extract. - return "CREATE DATABASE IF NOT EXISTS " + common.EscapeIdentifier(m.Name) + return common.SprintfWithIdentifiers("CREATE DATABASE IF NOT EXISTS %s", m.Name) } // MDTableMeta contains some parsed metadata for a table in the source by MyDumper Loader. 
diff --git a/pkg/disttask/importinto/scheduler.go b/pkg/disttask/importinto/scheduler.go index bed181d4d924c..5f91eaeb8f940 100644 --- a/pkg/disttask/importinto/scheduler.go +++ b/pkg/disttask/importinto/scheduler.go @@ -440,11 +440,10 @@ func (sch *importScheduler) Close() { // nolint:deadcode func dropTableIndexes(ctx context.Context, handle storage.TaskHandle, taskMeta *TaskMeta, logger *zap.Logger) error { tblInfo := taskMeta.Plan.TableInfo - tableName := common.UniqueTable(taskMeta.Plan.DBName, tblInfo.Name.L) remainIndexes, dropIndexes := common.GetDropIndexInfos(tblInfo) for _, idxInfo := range dropIndexes { - sqlStr := common.BuildDropIndexSQL(tableName, idxInfo) + sqlStr := common.BuildDropIndexSQL(taskMeta.Plan.DBName, tblInfo.Name.L, idxInfo) if err := executeSQL(ctx, handle, logger, sqlStr); err != nil { if merr, ok := errors.Cause(err).(*dmysql.MySQLError); ok { switch merr.Number { diff --git a/pkg/executor/importer/precheck.go b/pkg/executor/importer/precheck.go index 1cb1c6f75a445..4570186bf7dd5 100644 --- a/pkg/executor/importer/precheck.go +++ b/pkg/executor/importer/precheck.go @@ -75,7 +75,7 @@ func (e *LoadDataController) checkTotalFileSize() error { } func (e *LoadDataController) checkTableEmpty(ctx context.Context, conn sqlexec.SQLExecutor) error { - sql := fmt.Sprintf("SELECT 1 FROM %s USE INDEX() LIMIT 1", common.UniqueTable(e.DBName, e.Table.Meta().Name.L)) + sql := common.SprintfWithIdentifiers("SELECT 1 FROM %s.%s USE INDEX() LIMIT 1", e.DBName, e.Table.Meta().Name.L) rs, err := conn.ExecuteInternal(ctx, sql) if err != nil { return err