Skip to content

Commit

Permalink
br: Solve SQL Injection Risk - Format String (pingcap#49666)
Browse files Browse the repository at this point in the history
  • Loading branch information
lyzx2001 authored and AilinKid committed Jan 17, 2024
1 parent 5dc5145 commit 3d01baf
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 49 deletions.
51 changes: 24 additions & 27 deletions br/pkg/lightning/checkpoints/checkpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -1615,12 +1615,12 @@ func (cpdb *MySQLCheckpointsDB) IgnoreErrorCheckpoint(ctx context.Context, table
Logger: log.FromContext(ctx).With(zap.String("table", tableName)),
}
err := s.Transact(ctx, "ignore error checkpoints", func(c context.Context, tx *sql.Tx) error {
query := fmt.Sprintf("UPDATE %s.%s SET status = ? WHERE ? = ? AND status <= ?", cpdb.schema, CheckpointTableNameEngine)
if _, e := tx.ExecContext(c, query, CheckpointStatusLoaded, colName, tableName, CheckpointStatusMaxInvalid); e != nil {
query := fmt.Sprintf("UPDATE %s.%s SET status = ? WHERE %s = ? AND status <= ?", cpdb.schema, CheckpointTableNameEngine, colName)
if _, e := tx.ExecContext(c, query, CheckpointStatusLoaded, tableName, CheckpointStatusMaxInvalid); e != nil {
return errors.Trace(e)
}
query = fmt.Sprintf("UPDATE %s.%s SET status = ? WHERE ? = ? AND status <= ?", cpdb.schema, CheckpointTableNameTable)
if _, e := tx.ExecContext(c, query, CheckpointStatusLoaded, colName, tableName, CheckpointStatusMaxInvalid); e != nil {
query = fmt.Sprintf("UPDATE %s.%s SET status = ? WHERE %s = ? AND status <= ?", cpdb.schema, CheckpointTableNameTable, colName)
if _, e := tx.ExecContext(c, query, CheckpointStatusLoaded, tableName, CheckpointStatusMaxInvalid); e != nil {
return errors.Trace(e)
}
return nil
Expand All @@ -1643,30 +1643,27 @@ func (cpdb *MySQLCheckpointsDB) DestroyErrorCheckpoint(ctx context.Context, tabl
}

selectQuery := fmt.Sprintf(`
SELECT
t.table_name,
COALESCE(MIN(e.engine_id), 0),
COALESCE(MAX(e.engine_id), -1)
FROM %[1]s.%[4]s t
LEFT JOIN %[1]s.%[5]s e ON t.table_name = e.table_name
WHERE %[2]s = ? AND t.status <= %[3]d
GROUP BY t.table_name;
`, cpdb.schema, aliasedColName, CheckpointStatusMaxInvalid, CheckpointTableNameTable, CheckpointTableNameEngine)

// nolint:gosec
SELECT
t.table_name,
COALESCE(MIN(e.engine_id), 0),
COALESCE(MAX(e.engine_id), -1)
FROM %[1]s.%[3]s t
LEFT JOIN %[1]s.%[4]s e ON t.table_name = e.table_name
WHERE %[2]s = ? AND t.status <= ?
GROUP BY t.table_name;
`, cpdb.schema, aliasedColName, CheckpointTableNameTable, CheckpointTableNameEngine)

deleteChunkQuery := fmt.Sprintf(`
DELETE FROM %[1]s.%[4]s WHERE table_name IN (SELECT table_name FROM %[1]s.%[5]s WHERE %[2]s = ? AND status <= %[3]d)
`, cpdb.schema, colName, CheckpointStatusMaxInvalid, CheckpointTableNameChunk, CheckpointTableNameTable)
DELETE FROM %[1]s.%[3]s WHERE table_name IN (SELECT table_name FROM %[1]s.%[4]s WHERE %[2]s = ? AND status <= ?)
`, cpdb.schema, colName, CheckpointTableNameChunk, CheckpointTableNameTable)

// nolint:gosec
deleteEngineQuery := fmt.Sprintf(`
DELETE FROM %[1]s.%[4]s WHERE table_name IN (SELECT table_name FROM %[1]s.%[5]s WHERE %[2]s = ? AND status <= %[3]d)
`, cpdb.schema, colName, CheckpointStatusMaxInvalid, CheckpointTableNameEngine, CheckpointTableNameTable)
DELETE FROM %[1]s.%[3]s WHERE table_name IN (SELECT table_name FROM %[1]s.%[4]s WHERE %[2]s = ? AND status <= ?)
`, cpdb.schema, colName, CheckpointTableNameEngine, CheckpointTableNameTable)

// nolint:gosec
deleteTableQuery := fmt.Sprintf(`
DELETE FROM %s.%s WHERE %s = ? AND status <= %d
`, cpdb.schema, CheckpointTableNameTable, colName, CheckpointStatusMaxInvalid)
DELETE FROM %s.%s WHERE %s = ? AND status <= ?
`, cpdb.schema, CheckpointTableNameTable, colName)

var targetTables []DestroyedTableCheckpoint

Expand All @@ -1677,7 +1674,7 @@ func (cpdb *MySQLCheckpointsDB) DestroyErrorCheckpoint(ctx context.Context, tabl
err := s.Transact(ctx, "destroy error checkpoints", func(c context.Context, tx *sql.Tx) error {
// Obtain the list of tables
targetTables = nil
rows, e := tx.QueryContext(c, selectQuery, tableName) // #nosec G201
rows, e := tx.QueryContext(c, selectQuery, tableName, CheckpointStatusMaxInvalid)
if e != nil {
return errors.Trace(e)
}
Expand All @@ -1695,13 +1692,13 @@ func (cpdb *MySQLCheckpointsDB) DestroyErrorCheckpoint(ctx context.Context, tabl
}

// Delete the checkpoints
if _, e := tx.ExecContext(c, deleteChunkQuery, tableName); e != nil {
if _, e := tx.ExecContext(c, deleteChunkQuery, tableName, CheckpointStatusMaxInvalid); e != nil {
return errors.Trace(e)
}
if _, e := tx.ExecContext(c, deleteEngineQuery, tableName); e != nil {
if _, e := tx.ExecContext(c, deleteEngineQuery, tableName, CheckpointStatusMaxInvalid); e != nil {
return errors.Trace(e)
}
if _, e := tx.ExecContext(c, deleteTableQuery, tableName); e != nil {
if _, e := tx.ExecContext(c, deleteTableQuery, tableName, CheckpointStatusMaxInvalid); e != nil {
return errors.Trace(e)
}
return nil
Expand Down
32 changes: 16 additions & 16 deletions br/pkg/lightning/checkpoints/checkpoints_sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,12 +395,12 @@ func TestIgnoreAllErrorCheckpoints_SQL(t *testing.T) {

s.mock.ExpectBegin()
s.mock.
ExpectExec("UPDATE `mock-schema`\\.engine_v\\d+ SET status = \\? WHERE \\? = \\? AND status <= \\?").
WithArgs(checkpoints.CheckpointStatusLoaded, "'all'", sqlmock.AnyArg(), 25).
ExpectExec("UPDATE `mock-schema`\\.engine_v\\d+ SET status = \\? WHERE 'all' = \\? AND status <= \\?").
WithArgs(checkpoints.CheckpointStatusLoaded, sqlmock.AnyArg(), 25).
WillReturnResult(sqlmock.NewResult(5, 3))
s.mock.
ExpectExec("UPDATE `mock-schema`\\.table_v\\d+ SET status = \\? WHERE \\? = \\? AND status <= \\?").
WithArgs(checkpoints.CheckpointStatusLoaded, "'all'", sqlmock.AnyArg(), 25).
ExpectExec("UPDATE `mock-schema`\\.table_v\\d+ SET status = \\? WHERE 'all' = \\? AND status <= \\?").
WithArgs(checkpoints.CheckpointStatusLoaded, sqlmock.AnyArg(), 25).
WillReturnResult(sqlmock.NewResult(6, 2))
s.mock.ExpectCommit()

Expand All @@ -413,12 +413,12 @@ func TestIgnoreOneErrorCheckpoint(t *testing.T) {

s.mock.ExpectBegin()
s.mock.
ExpectExec("UPDATE `mock-schema`\\.engine_v\\d+ SET status = \\? WHERE \\? = \\? AND status <= \\?").
WithArgs(checkpoints.CheckpointStatusLoaded, "table_name", "`db1`.`t2`", 25).
ExpectExec("UPDATE `mock-schema`\\.engine_v\\d+ SET status = \\? WHERE table_name = \\? AND status <= \\?").
WithArgs(checkpoints.CheckpointStatusLoaded, "`db1`.`t2`", 25).
WillReturnResult(sqlmock.NewResult(5, 2))
s.mock.
ExpectExec("UPDATE `mock-schema`\\.table_v\\d+ SET status = \\? WHERE \\? = \\? AND status <= \\?").
WithArgs(checkpoints.CheckpointStatusLoaded, "table_name", "`db1`.`t2`", 25).
ExpectExec("UPDATE `mock-schema`\\.table_v\\d+ SET status = \\? WHERE table_name = \\? AND status <= \\?").
WithArgs(checkpoints.CheckpointStatusLoaded, "`db1`.`t2`", 25).
WillReturnResult(sqlmock.NewResult(6, 1))
s.mock.ExpectCommit()

Expand All @@ -432,22 +432,22 @@ func TestDestroyAllErrorCheckpoints_SQL(t *testing.T) {
s.mock.ExpectBegin()
s.mock.
ExpectQuery("SELECT (?s:.+)'all' = \\?").
WithArgs(sqlmock.AnyArg()).
WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()).
WillReturnRows(
sqlmock.NewRows([]string{"table_name", "__min__", "__max__"}).
AddRow("`db1`.`t2`", -1, 0),
)
s.mock.
ExpectExec("DELETE FROM `mock-schema`\\.chunk_v\\d+ WHERE table_name IN .+ 'all' = \\?").
WithArgs(sqlmock.AnyArg()).
WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(0, 5))
s.mock.
ExpectExec("DELETE FROM `mock-schema`\\.engine_v\\d+ WHERE table_name IN .+ 'all' = \\?").
WithArgs(sqlmock.AnyArg()).
WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(0, 3))
s.mock.
ExpectExec("DELETE FROM `mock-schema`\\.table_v\\d+ WHERE 'all' = \\?").
WithArgs(sqlmock.AnyArg()).
WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(0, 2))
s.mock.ExpectCommit()

Expand All @@ -466,22 +466,22 @@ func TestDestroyOneErrorCheckpoints(t *testing.T) {
s.mock.ExpectBegin()
s.mock.
ExpectQuery("SELECT (?s:.+)table_name = \\?").
WithArgs("`db1`.`t2`").
WithArgs("`db1`.`t2`", sqlmock.AnyArg()).
WillReturnRows(
sqlmock.NewRows([]string{"table_name", "__min__", "__max__"}).
AddRow("`db1`.`t2`", -1, 0),
)
s.mock.
ExpectExec("DELETE FROM `mock-schema`\\.chunk_v\\d+ WHERE .+table_name = \\?").
WithArgs("`db1`.`t2`").
WithArgs("`db1`.`t2`", sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(0, 4))
s.mock.
ExpectExec("DELETE FROM `mock-schema`\\.engine_v\\d+ WHERE .+table_name = \\?").
WithArgs("`db1`.`t2`").
WithArgs("`db1`.`t2`", sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(0, 2))
s.mock.
ExpectExec("DELETE FROM `mock-schema`\\.table_v\\d+ WHERE table_name = \\?").
WithArgs("`db1`.`t2`").
WithArgs("`db1`.`t2`", sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(0, 1))
s.mock.ExpectCommit()

Expand Down
5 changes: 0 additions & 5 deletions br/pkg/lightning/importer/meta_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,6 @@ func (m *dbTableMetaMgr) AllocTableRowIDs(ctx context.Context, rawRowIDMax int64
newStatus = metaStatusRestoreStarted
}

// nolint:gosec
query := fmt.Sprintf("update %s set row_id_base = ?, row_id_max = ?, status = ? where table_id = ? and task_id = ?", m.tableName)
_, err := tx.ExecContext(ctx, query, newRowIDBase, newRowIDMax, newStatus.String(), m.tr.tableInfo.ID, m.taskID)
if err != nil {
Expand Down Expand Up @@ -459,7 +458,6 @@ func (m *dbTableMetaMgr) CheckAndUpdateLocalChecksum(ctx context.Context, checks
return errors.Trace(err)
}

// nolint:gosec
query := fmt.Sprintf("update %s set total_kvs = ?, total_bytes = ?, checksum = ?, status = ?, has_duplicates = ? where table_id = ? and task_id = ?", m.tableName)
_, err = tx.ExecContext(ctx, query, checksum.SumKVS(), checksum.SumSize(), checksum.Sum(), newStatus.String(), hasLocalDupes, m.tr.tableInfo.ID, m.taskID)
return errors.Annotate(err, "update local checksum failed")
Expand Down Expand Up @@ -687,7 +685,6 @@ func (m *dbTaskMetaMgr) CheckTasksExclusively(ctx context.Context, action func(t
return errors.Trace(err)
}
for _, task := range newTasks {
// nolint:gosec
query := fmt.Sprintf("REPLACE INTO %s (task_id, pd_cfgs, status, state, tikv_source_bytes, tiflash_source_bytes, tikv_avail, tiflash_avail) VALUES(?, ?, ?, ?, ?, ?, ?, ?)", m.tableName)
if _, err = tx.ExecContext(ctx, query, task.taskID, task.pdCfgs, task.status.String(), task.state, task.tikvSourceBytes, task.tiflashSourceBytes, task.tikvAvail, task.tiflashAvail); err != nil {
return errors.Trace(err)
Expand Down Expand Up @@ -796,7 +793,6 @@ func (m *dbTaskMetaMgr) CheckAndPausePdSchedulers(ctx context.Context) (pdutil.U
return errors.Trace(err)
}

// nolint:gosec
query := fmt.Sprintf("update %s set pd_cfgs = ?, status = ? where task_id = ?", m.tableName)
_, err = tx.ExecContext(ctx, query, string(jsonByts), taskMetaStatusScheduleSet.String(), m.taskID)

Expand Down Expand Up @@ -914,7 +910,6 @@ func (m *dbTaskMetaMgr) CheckAndFinishRestore(ctx context.Context, finished bool
newStatus = taskMetaStatusSwitchSkipped
}

// nolint:gosec
query := fmt.Sprintf("update %s set status = ?, state = ? where task_id = ?", m.tableName)
if _, err = tx.ExecContext(ctx, query, newStatus.String(), newState, m.taskID); err != nil {
return errors.Trace(err)
Expand Down
1 change: 1 addition & 0 deletions br/pkg/lightning/importer/meta_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ func (s *metaMgrSuite) prepareMock(rowsVal [][]driver.Value, nextRowID *int64, u
WillReturnResult(sqlmock.NewResult(int64(0), int64(0)))
s.prepareMockInner(rowsVal, nextRowID, updateArgs, checksum, updateStatus, rollback)
}

func (s *metaMgrSuite) prepareMockInner(rowsVal [][]driver.Value, nextRowID *int64, updateArgs []driver.Value, checksum *verification.KVChecksum, updateStatus *string, rollback bool) {
s.mockDB.ExpectBegin()

Expand Down
2 changes: 1 addition & 1 deletion dumpling/tests/s3/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func main() {
return errors.Trace(err)
}

query := fmt.Sprintf("insert into %s values('aaaaaaaaaa')", table) // nolint:gosec
query := fmt.Sprintf("insert into %s values('aaaaaaaaaa')", table)
for i := 1; i < 10000; i++ {
query += ",('aaaaaaaaaa')"
}
Expand Down

0 comments on commit 3d01baf

Please sign in to comment.