From 3d01baf836d65849a563fd7f4846cca958523264 Mon Sep 17 00:00:00 2001 From: Luo Yangzhixin Date: Fri, 5 Jan 2024 12:24:33 +0800 Subject: [PATCH] br: Solve SQL Injection Risk - Format String (#49666) ref pingcap/tidb#30699 --- br/pkg/lightning/checkpoints/checkpoints.go | 51 +++++++++---------- .../checkpoints/checkpoints_sql_test.go | 32 ++++++------ br/pkg/lightning/importer/meta_manager.go | 5 -- .../lightning/importer/meta_manager_test.go | 1 + dumpling/tests/s3/import.go | 2 +- 5 files changed, 42 insertions(+), 49 deletions(-) diff --git a/br/pkg/lightning/checkpoints/checkpoints.go b/br/pkg/lightning/checkpoints/checkpoints.go index 64e1b9f972d39..aebbadc275190 100644 --- a/br/pkg/lightning/checkpoints/checkpoints.go +++ b/br/pkg/lightning/checkpoints/checkpoints.go @@ -1615,12 +1615,12 @@ func (cpdb *MySQLCheckpointsDB) IgnoreErrorCheckpoint(ctx context.Context, table Logger: log.FromContext(ctx).With(zap.String("table", tableName)), } err := s.Transact(ctx, "ignore error checkpoints", func(c context.Context, tx *sql.Tx) error { - query := fmt.Sprintf("UPDATE %s.%s SET status = ? WHERE ? = ? AND status <= ?", cpdb.schema, CheckpointTableNameEngine) - if _, e := tx.ExecContext(c, query, CheckpointStatusLoaded, colName, tableName, CheckpointStatusMaxInvalid); e != nil { + query := fmt.Sprintf("UPDATE %s.%s SET status = ? WHERE %s = ? AND status <= ?", cpdb.schema, CheckpointTableNameEngine, colName) + if _, e := tx.ExecContext(c, query, CheckpointStatusLoaded, tableName, CheckpointStatusMaxInvalid); e != nil { return errors.Trace(e) } - query = fmt.Sprintf("UPDATE %s.%s SET status = ? WHERE ? = ? AND status <= ?", cpdb.schema, CheckpointTableNameTable) - if _, e := tx.ExecContext(c, query, CheckpointStatusLoaded, colName, tableName, CheckpointStatusMaxInvalid); e != nil { + query = fmt.Sprintf("UPDATE %s.%s SET status = ? WHERE %s = ? AND status <= ?", cpdb.schema, CheckpointTableNameTable, colName) + if _, e := tx.ExecContext(c, query, CheckpointStatusLoaded, tableName, CheckpointStatusMaxInvalid); e != nil { return errors.Trace(e) } return nil @@ -1643,30 +1643,27 @@ func (cpdb *MySQLCheckpointsDB) DestroyErrorCheckpoint(ctx context.Context, tabl } selectQuery := fmt.Sprintf(` - SELECT - t.table_name, - COALESCE(MIN(e.engine_id), 0), - COALESCE(MAX(e.engine_id), -1) - FROM %[1]s.%[4]s t - LEFT JOIN %[1]s.%[5]s e ON t.table_name = e.table_name - WHERE %[2]s = ? AND t.status <= %[3]d - GROUP BY t.table_name; - `, cpdb.schema, aliasedColName, CheckpointStatusMaxInvalid, CheckpointTableNameTable, CheckpointTableNameEngine) - - // nolint:gosec + SELECT + t.table_name, + COALESCE(MIN(e.engine_id), 0), + COALESCE(MAX(e.engine_id), -1) + FROM %[1]s.%[3]s t + LEFT JOIN %[1]s.%[4]s e ON t.table_name = e.table_name + WHERE %[2]s = ? AND t.status <= ? + GROUP BY t.table_name; + `, cpdb.schema, aliasedColName, CheckpointTableNameTable, CheckpointTableNameEngine) + deleteChunkQuery := fmt.Sprintf(` - DELETE FROM %[1]s.%[4]s WHERE table_name IN (SELECT table_name FROM %[1]s.%[5]s WHERE %[2]s = ? AND status <= %[3]d) - `, cpdb.schema, colName, CheckpointStatusMaxInvalid, CheckpointTableNameChunk, CheckpointTableNameTable) + DELETE FROM %[1]s.%[3]s WHERE table_name IN (SELECT table_name FROM %[1]s.%[4]s WHERE %[2]s = ? AND status <= ?) + `, cpdb.schema, colName, CheckpointTableNameChunk, CheckpointTableNameTable) - // nolint:gosec deleteEngineQuery := fmt.Sprintf(` - DELETE FROM %[1]s.%[4]s WHERE table_name IN (SELECT table_name FROM %[1]s.%[5]s WHERE %[2]s = ? AND status <= %[3]d) - `, cpdb.schema, colName, CheckpointStatusMaxInvalid, CheckpointTableNameEngine, CheckpointTableNameTable) + DELETE FROM %[1]s.%[3]s WHERE table_name IN (SELECT table_name FROM %[1]s.%[4]s WHERE %[2]s = ? AND status <= ?) + `, cpdb.schema, colName, CheckpointTableNameEngine, CheckpointTableNameTable) - // nolint:gosec deleteTableQuery := fmt.Sprintf(` - DELETE FROM %s.%s WHERE %s = ? AND status <= %d - `, cpdb.schema, CheckpointTableNameTable, colName, CheckpointStatusMaxInvalid) + DELETE FROM %s.%s WHERE %s = ? AND status <= ? + `, cpdb.schema, CheckpointTableNameTable, colName) var targetTables []DestroyedTableCheckpoint @@ -1677,7 +1674,7 @@ func (cpdb *MySQLCheckpointsDB) DestroyErrorCheckpoint(ctx context.Context, tabl err := s.Transact(ctx, "destroy error checkpoints", func(c context.Context, tx *sql.Tx) error { // Obtain the list of tables targetTables = nil - rows, e := tx.QueryContext(c, selectQuery, tableName) // #nosec G201 + rows, e := tx.QueryContext(c, selectQuery, tableName, CheckpointStatusMaxInvalid) if e != nil { return errors.Trace(e) } @@ -1695,13 +1692,13 @@ func (cpdb *MySQLCheckpointsDB) DestroyErrorCheckpoint(ctx context.Context, tabl } // Delete the checkpoints - if _, e := tx.ExecContext(c, deleteChunkQuery, tableName); e != nil { + if _, e := tx.ExecContext(c, deleteChunkQuery, tableName, CheckpointStatusMaxInvalid); e != nil { return errors.Trace(e) } - if _, e := tx.ExecContext(c, deleteEngineQuery, tableName); e != nil { + if _, e := tx.ExecContext(c, deleteEngineQuery, tableName, CheckpointStatusMaxInvalid); e != nil { return errors.Trace(e) } - if _, e := tx.ExecContext(c, deleteTableQuery, tableName); e != nil { + if _, e := tx.ExecContext(c, deleteTableQuery, tableName, CheckpointStatusMaxInvalid); e != nil { return errors.Trace(e) } return nil diff --git a/br/pkg/lightning/checkpoints/checkpoints_sql_test.go b/br/pkg/lightning/checkpoints/checkpoints_sql_test.go index 4b4b35269bda5..6c8aa823edcd1 100644 --- a/br/pkg/lightning/checkpoints/checkpoints_sql_test.go +++ b/br/pkg/lightning/checkpoints/checkpoints_sql_test.go @@ -395,12 +395,12 @@ func TestIgnoreAllErrorCheckpoints_SQL(t *testing.T) { s.mock.ExpectBegin() s.mock. - ExpectExec("UPDATE `mock-schema`\\.engine_v\\d+ SET status = \\? WHERE \\? = \\? AND status <= \\?"). - WithArgs(checkpoints.CheckpointStatusLoaded, "'all'", sqlmock.AnyArg(), 25). + ExpectExec("UPDATE `mock-schema`\\.engine_v\\d+ SET status = \\? WHERE 'all' = \\? AND status <= \\?"). + WithArgs(checkpoints.CheckpointStatusLoaded, sqlmock.AnyArg(), 25). WillReturnResult(sqlmock.NewResult(5, 3)) s.mock. - ExpectExec("UPDATE `mock-schema`\\.table_v\\d+ SET status = \\? WHERE \\? = \\? AND status <= \\?"). - WithArgs(checkpoints.CheckpointStatusLoaded, "'all'", sqlmock.AnyArg(), 25). + ExpectExec("UPDATE `mock-schema`\\.table_v\\d+ SET status = \\? WHERE 'all' = \\? AND status <= \\?"). + WithArgs(checkpoints.CheckpointStatusLoaded, sqlmock.AnyArg(), 25). WillReturnResult(sqlmock.NewResult(6, 2)) s.mock.ExpectCommit() @@ -413,12 +413,12 @@ func TestIgnoreOneErrorCheckpoint(t *testing.T) { s.mock.ExpectBegin() s.mock. - ExpectExec("UPDATE `mock-schema`\\.engine_v\\d+ SET status = \\? WHERE \\? = \\? AND status <= \\?"). - WithArgs(checkpoints.CheckpointStatusLoaded, "table_name", "`db1`.`t2`", 25). + ExpectExec("UPDATE `mock-schema`\\.engine_v\\d+ SET status = \\? WHERE table_name = \\? AND status <= \\?"). + WithArgs(checkpoints.CheckpointStatusLoaded, "`db1`.`t2`", 25). WillReturnResult(sqlmock.NewResult(5, 2)) s.mock. - ExpectExec("UPDATE `mock-schema`\\.table_v\\d+ SET status = \\? WHERE \\? = \\? AND status <= \\?"). - WithArgs(checkpoints.CheckpointStatusLoaded, "table_name", "`db1`.`t2`", 25). + ExpectExec("UPDATE `mock-schema`\\.table_v\\d+ SET status = \\? WHERE table_name = \\? AND status <= \\?"). + WithArgs(checkpoints.CheckpointStatusLoaded, "`db1`.`t2`", 25). WillReturnResult(sqlmock.NewResult(6, 1)) s.mock.ExpectCommit() @@ -432,22 +432,22 @@ func TestDestroyAllErrorCheckpoints_SQL(t *testing.T) { s.mock.ExpectBegin() s.mock. ExpectQuery("SELECT (?s:.+)'all' = \\?"). - WithArgs(sqlmock.AnyArg()). + WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()). WillReturnRows( sqlmock.NewRows([]string{"table_name", "__min__", "__max__"}). AddRow("`db1`.`t2`", -1, 0), ) s.mock. ExpectExec("DELETE FROM `mock-schema`\\.chunk_v\\d+ WHERE table_name IN .+ 'all' = \\?"). - WithArgs(sqlmock.AnyArg()). + WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()). WillReturnResult(sqlmock.NewResult(0, 5)) s.mock. ExpectExec("DELETE FROM `mock-schema`\\.engine_v\\d+ WHERE table_name IN .+ 'all' = \\?"). - WithArgs(sqlmock.AnyArg()). + WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()). WillReturnResult(sqlmock.NewResult(0, 3)) s.mock. ExpectExec("DELETE FROM `mock-schema`\\.table_v\\d+ WHERE 'all' = \\?"). - WithArgs(sqlmock.AnyArg()). + WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()). WillReturnResult(sqlmock.NewResult(0, 2)) s.mock.ExpectCommit() @@ -466,22 +466,22 @@ func TestDestroyOneErrorCheckpoints(t *testing.T) { s.mock.ExpectBegin() s.mock. ExpectQuery("SELECT (?s:.+)table_name = \\?"). - WithArgs("`db1`.`t2`"). + WithArgs("`db1`.`t2`", sqlmock.AnyArg()). WillReturnRows( sqlmock.NewRows([]string{"table_name", "__min__", "__max__"}). AddRow("`db1`.`t2`", -1, 0), ) s.mock. ExpectExec("DELETE FROM `mock-schema`\\.chunk_v\\d+ WHERE .+table_name = \\?"). - WithArgs("`db1`.`t2`"). + WithArgs("`db1`.`t2`", sqlmock.AnyArg()). WillReturnResult(sqlmock.NewResult(0, 4)) s.mock. ExpectExec("DELETE FROM `mock-schema`\\.engine_v\\d+ WHERE .+table_name = \\?"). - WithArgs("`db1`.`t2`"). + WithArgs("`db1`.`t2`", sqlmock.AnyArg()). WillReturnResult(sqlmock.NewResult(0, 2)) s.mock. ExpectExec("DELETE FROM `mock-schema`\\.table_v\\d+ WHERE table_name = \\?"). - WithArgs("`db1`.`t2`"). + WithArgs("`db1`.`t2`", sqlmock.AnyArg()). WillReturnResult(sqlmock.NewResult(0, 1)) s.mock.ExpectCommit() diff --git a/br/pkg/lightning/importer/meta_manager.go b/br/pkg/lightning/importer/meta_manager.go index fff80ddb09718..98b68e3090777 100644 --- a/br/pkg/lightning/importer/meta_manager.go +++ b/br/pkg/lightning/importer/meta_manager.go @@ -270,7 +270,6 @@ func (m *dbTableMetaMgr) AllocTableRowIDs(ctx context.Context, rawRowIDMax int64 newStatus = metaStatusRestoreStarted } - // nolint:gosec query := fmt.Sprintf("update %s set row_id_base = ?, row_id_max = ?, status = ? where table_id = ? and task_id = ?", m.tableName) _, err := tx.ExecContext(ctx, query, newRowIDBase, newRowIDMax, newStatus.String(), m.tr.tableInfo.ID, m.taskID) if err != nil { @@ -459,7 +458,6 @@ func (m *dbTableMetaMgr) CheckAndUpdateLocalChecksum(ctx context.Context, checks return errors.Trace(err) } - // nolint:gosec query := fmt.Sprintf("update %s set total_kvs = ?, total_bytes = ?, checksum = ?, status = ?, has_duplicates = ? where table_id = ? and task_id = ?", m.tableName) _, err = tx.ExecContext(ctx, query, checksum.SumKVS(), checksum.SumSize(), checksum.Sum(), newStatus.String(), hasLocalDupes, m.tr.tableInfo.ID, m.taskID) return errors.Annotate(err, "update local checksum failed") @@ -687,7 +685,6 @@ func (m *dbTaskMetaMgr) CheckTasksExclusively(ctx context.Context, action func(t return errors.Trace(err) } for _, task := range newTasks { - // nolint:gosec query := fmt.Sprintf("REPLACE INTO %s (task_id, pd_cfgs, status, state, tikv_source_bytes, tiflash_source_bytes, tikv_avail, tiflash_avail) VALUES(?, ?, ?, ?, ?, ?, ?, ?)", m.tableName) if _, err = tx.ExecContext(ctx, query, task.taskID, task.pdCfgs, task.status.String(), task.state, task.tikvSourceBytes, task.tiflashSourceBytes, task.tikvAvail, task.tiflashAvail); err != nil { return errors.Trace(err) @@ -796,7 +793,6 @@ func (m *dbTaskMetaMgr) CheckAndPausePdSchedulers(ctx context.Context) (pdutil.U return errors.Trace(err) } - // nolint:gosec query := fmt.Sprintf("update %s set pd_cfgs = ?, status = ? where task_id = ?", m.tableName) _, err = tx.ExecContext(ctx, query, string(jsonByts), taskMetaStatusScheduleSet.String(), m.taskID) @@ -914,7 +910,6 @@ func (m *dbTaskMetaMgr) CheckAndFinishRestore(ctx context.Context, finished bool newStatus = taskMetaStatusSwitchSkipped } - // nolint:gosec query := fmt.Sprintf("update %s set status = ?, state = ? where task_id = ?", m.tableName) if _, err = tx.ExecContext(ctx, query, newStatus.String(), newState, m.taskID); err != nil { return errors.Trace(err) diff --git a/br/pkg/lightning/importer/meta_manager_test.go b/br/pkg/lightning/importer/meta_manager_test.go index d778c04645dc0..fe7713d21d7b6 100644 --- a/br/pkg/lightning/importer/meta_manager_test.go +++ b/br/pkg/lightning/importer/meta_manager_test.go @@ -313,6 +313,7 @@ func (s *metaMgrSuite) prepareMock(rowsVal [][]driver.Value, nextRowID *int64, u WillReturnResult(sqlmock.NewResult(int64(0), int64(0))) s.prepareMockInner(rowsVal, nextRowID, updateArgs, checksum, updateStatus, rollback) } + func (s *metaMgrSuite) prepareMockInner(rowsVal [][]driver.Value, nextRowID *int64, updateArgs []driver.Value, checksum *verification.KVChecksum, updateStatus *string, rollback bool) { s.mockDB.ExpectBegin() diff --git a/dumpling/tests/s3/import.go b/dumpling/tests/s3/import.go index 76b2c75932a68..d5ae518fd5608 100644 --- a/dumpling/tests/s3/import.go +++ b/dumpling/tests/s3/import.go @@ -66,7 +66,7 @@ func main() { return errors.Trace(err) } - query := fmt.Sprintf("insert into %s values('aaaaaaaaaa')", table) // nolint:gosec + query := fmt.Sprintf("insert into %s values('aaaaaaaaaa')", table) for i := 1; i < 10000; i++ { query += ",('aaaaaaaaaa')" }