diff --git a/br/cmd/br/backup.go b/br/cmd/br/backup.go index 925cf58c132b4..6525c85d535b9 100644 --- a/br/cmd/br/backup.go +++ b/br/cmd/br/backup.go @@ -22,10 +22,11 @@ import ( func runBackupCommand(command *cobra.Command, cmdName string) error { cfg := task.BackupConfig{Config: task.Config{LogProgress: HasLogFile()}} - if err := cfg.ParseFromFlags(command.Flags()); err != nil { + if err := cfg.ParseFromFlags(command.Flags(), false); err != nil { command.SilenceUsage = false return errors.Trace(err) } + overrideDefaultBackupConfigIfNeeded(&cfg, command) if err := metricsutil.RegisterMetricsForBR(cfg.PD, cfg.KeyspaceName); err != nil { return errors.Trace(err) @@ -211,3 +212,10 @@ func newTxnBackupCommand() *cobra.Command { task.DefineTxnBackupFlags(command) return command } + +func overrideDefaultBackupConfigIfNeeded(config *task.BackupConfig, cmd *cobra.Command) { + // override only if flag not set by user + if !cmd.Flags().Changed(task.FlagChecksum) { + config.Checksum = false + } +} diff --git a/br/cmd/br/cmd.go b/br/cmd/br/cmd.go index df0395fa1d719..afda0a6c26473 100644 --- a/br/cmd/br/cmd.go +++ b/br/cmd/br/cmd.go @@ -81,8 +81,8 @@ func timestampLogFileName() string { return filepath.Join(os.TempDir(), time.Now().Format("br.log.2006-01-02T15.04.05Z0700")) } -// AddFlags adds flags to the given cmd. -func AddFlags(cmd *cobra.Command) { +// DefineCommonFlags defines the common flags for all BR cmd operation. +func DefineCommonFlags(cmd *cobra.Command) { cmd.Version = build.Info() cmd.Flags().BoolP(flagVersion, flagVersionShort, false, "Display version information about BR") cmd.SetVersionTemplate("{{printf \"%s\" .Version}}\n") @@ -99,6 +99,8 @@ func AddFlags(cmd *cobra.Command) { "Set whether to redact sensitive info in log") cmd.PersistentFlags().String(FlagStatusAddr, "", "Set the HTTP listening address for the status report service. 
Set to empty string to disable") + + // defines BR task common flags, this is shared by cmd and sql(brie) task.DefineCommonFlags(cmd.PersistentFlags()) cmd.PersistentFlags().StringP(FlagSlowLogFile, "", "", diff --git a/br/cmd/br/main.go b/br/cmd/br/main.go index f745920f5bfba..cad081606a0ea 100644 --- a/br/cmd/br/main.go +++ b/br/cmd/br/main.go @@ -20,7 +20,7 @@ func main() { TraverseChildren: true, SilenceUsage: true, } - AddFlags(rootCmd) + DefineCommonFlags(rootCmd) SetDefaultContext(ctx) rootCmd.AddCommand( NewDebugCommand(), diff --git a/br/cmd/br/restore.go b/br/cmd/br/restore.go index 916ed3b703933..820bf1abf505d 100644 --- a/br/cmd/br/restore.go +++ b/br/cmd/br/restore.go @@ -25,7 +25,7 @@ import ( func runRestoreCommand(command *cobra.Command, cmdName string) error { cfg := task.RestoreConfig{Config: task.Config{LogProgress: HasLogFile()}} - if err := cfg.ParseFromFlags(command.Flags()); err != nil { + if err := cfg.ParseFromFlags(command.Flags(), false); err != nil { command.SilenceUsage = false return errors.Trace(err) } diff --git a/br/pkg/backup/schema.go b/br/pkg/backup/schema.go index bd33b29d70240..7b1640b0d30e6 100644 --- a/br/pkg/backup/schema.go +++ b/br/pkg/backup/schema.go @@ -106,7 +106,7 @@ func (ss *Schemas) BackupSchemas( } var checksum *checkpoint.ChecksumItem - var exists bool = false + var exists = false if ss.checkpointChecksum != nil && schema.tableInfo != nil { checksum, exists = ss.checkpointChecksum[schema.tableInfo.ID] } @@ -145,7 +145,7 @@ func (ss *Schemas) BackupSchemas( zap.Uint64("Crc64Xor", schema.crc64xor), zap.Uint64("TotalKvs", schema.totalKvs), zap.Uint64("TotalBytes", schema.totalBytes), - zap.Duration("calculate-take", calculateCost)) + zap.Duration("TimeTaken", calculateCost)) } } if statsHandle != nil { diff --git a/br/pkg/metautil/metafile.go b/br/pkg/metautil/metafile.go index 03cc95ca1b5de..814b5d75d194b 100644 --- a/br/pkg/metautil/metafile.go +++ b/br/pkg/metautil/metafile.go @@ -171,11 +171,6 @@ type Table 
struct { StatsFileIndexes []*backuppb.StatsFileIndex } -// NoChecksum checks whether the table has a calculated checksum. -func (tbl *Table) NoChecksum() bool { - return tbl.Crc64Xor == 0 && tbl.TotalKvs == 0 && tbl.TotalBytes == 0 -} - // MetaReader wraps a reader to read both old and new version of backupmeta. type MetaReader struct { storage storage.ExternalStorage @@ -240,7 +235,7 @@ func (reader *MetaReader) readDataFiles(ctx context.Context, output func(*backup } // ArchiveSize return the size of Archive data -func (*MetaReader) ArchiveSize(_ context.Context, files []*backuppb.File) uint64 { +func ArchiveSize(files []*backuppb.File) uint64 { total := uint64(0) for _, file := range files { total += file.Size_ @@ -248,6 +243,30 @@ func (*MetaReader) ArchiveSize(_ context.Context, files []*backuppb.File) uint64 return total } +type ChecksumStats struct { + Crc64Xor uint64 + TotalKvs uint64 + TotalBytes uint64 +} + +func (stats ChecksumStats) ChecksumExists() bool { + if stats.Crc64Xor == 0 && stats.TotalKvs == 0 && stats.TotalBytes == 0 { + return false + } + return true +} + +// CalculateChecksumStatsOnFiles returns the ChecksumStats for the given files +func CalculateChecksumStatsOnFiles(files []*backuppb.File) ChecksumStats { + var stats ChecksumStats + for _, file := range files { + stats.Crc64Xor ^= file.Crc64Xor + stats.TotalKvs += file.TotalKvs + stats.TotalBytes += file.TotalBytes + } + return stats +} + // ReadDDLs reads the ddls from the backupmeta. // This function is compatible with the old backupmeta. func (reader *MetaReader) ReadDDLs(ctx context.Context) ([]byte, error) { diff --git a/br/pkg/restore/snap_client/client.go b/br/pkg/restore/snap_client/client.go new file mode 100644 index 0000000000000..957ec300cff94 --- /dev/null +++ b/br/pkg/restore/snap_client/client.go @@ -0,0 +1,1121 @@ +// Copyright 2024 PingCAP, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package snapclient + +import ( + "bytes" + "cmp" + "context" + "crypto/tls" + "encoding/json" + "slices" + "strings" + "sync" + "time" + + "github.com/opentracing/opentracing-go" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/log" + "github.com/pingcap/tidb/br/pkg/checkpoint" + "github.com/pingcap/tidb/br/pkg/checksum" + "github.com/pingcap/tidb/br/pkg/conn" + "github.com/pingcap/tidb/br/pkg/conn/util" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/br/pkg/glue" + "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/br/pkg/metautil" + "github.com/pingcap/tidb/br/pkg/pdutil" + "github.com/pingcap/tidb/br/pkg/restore" + importclient "github.com/pingcap/tidb/br/pkg/restore/internal/import_client" + tidallocdb "github.com/pingcap/tidb/br/pkg/restore/internal/prealloc_db" + tidalloc "github.com/pingcap/tidb/br/pkg/restore/internal/prealloc_table_id" + "github.com/pingcap/tidb/br/pkg/restore/split" + restoreutils "github.com/pingcap/tidb/br/pkg/restore/utils" + "github.com/pingcap/tidb/br/pkg/summary" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/br/pkg/version" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/meta/model" + tidbutil "github.com/pingcap/tidb/pkg/util" + 
"github.com/pingcap/tidb/pkg/util/redact" + kvutil "github.com/tikv/client-go/v2/util" + pd "github.com/tikv/pd/client" + pdhttp "github.com/tikv/pd/client/http" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc/keepalive" +) + +const ( + strictPlacementPolicyMode = "STRICT" + ignorePlacementPolicyMode = "IGNORE" + + defaultDDLConcurrency = 100 + maxSplitKeysOnce = 10240 +) + +const minBatchDdlSize = 1 + +type SnapClient struct { + // Tool clients used by SnapClient + fileImporter *SnapFileImporter + pdClient pd.Client + pdHTTPClient pdhttp.Client + + // User configurable parameters + cipher *backuppb.CipherInfo + concurrencyPerStore uint + keepaliveConf keepalive.ClientParameters + rateLimit uint64 + tlsConf *tls.Config + + switchCh chan struct{} + + storeCount int + supportPolicy bool + workerPool *tidbutil.WorkerPool + + noSchema bool + hasSpeedLimited bool + + databases map[string]*metautil.Database + ddlJobs []*model.Job + + // store tables need to rebase info like auto id and random id and so on after create table + rebasedTablesMap map[restore.UniqueTableName]bool + + backupMeta *backuppb.BackupMeta + + // TODO Remove this field or replace it with a []*DB, + // since https://github.com/pingcap/br/pull/377 needs more DBs to speed up DDL execution. + // And for now, we must inject a pool of DBs to `Client.GoCreateTables`, otherwise there would be a race condition. + // This is dirty: why we need DBs from different sources? + // By replace it with a []*DB, we can remove the dirty parameter of `Client.GoCreateTable`, + // along with them in some private functions. + // Before you do it, you can firstly read discussions at + // https://github.com/pingcap/br/pull/377#discussion_r446594501, + // this probably isn't as easy as it seems like (however, not hard, too :D) + db *tidallocdb.DB + + // use db pool to speed up restoration in BR binary mode. 
+ dbPool []*tidallocdb.DB + + dom *domain.Domain + + // correspond to --tidb-placement-mode config. + // STRICT(default) means policy related SQL can be executed in tidb. + // IGNORE means policy related SQL will be ignored. + policyMode string + + // policy name -> policy info + policyMap *sync.Map + + batchDdlSize uint + + // if fullClusterRestore = true: + // - if there's system tables in the backup(backup data since br 5.1.0), the cluster should be a fresh cluster + // without user database or table. and system tables about privileges is restored together with user data. + // - if there no system tables in the backup(backup data from br < 5.1.0), restore all user data just like + // previous version did. + // if fullClusterRestore = false, restore all user data just like previous version did. + // fullClusterRestore = true when there is no explicit filter setting, and it's full restore or point command + // with a full backup data. + // todo: maybe change to an enum + // this feature is controlled by flag with-sys-table + fullClusterRestore bool + + // see RestoreCommonConfig.WithSysTable + withSysTable bool + + // the rewrite mode of the downloaded SST files in TiKV. + rewriteMode RewriteMode + + // checkpoint information for snapshot restore + checkpointRunner *checkpoint.CheckpointRunner[checkpoint.RestoreKeyType, checkpoint.RestoreValueType] + checkpointChecksum map[int64]*checkpoint.ChecksumItem +} + +// NewRestoreClient returns a new RestoreClient. +func NewRestoreClient( + pdClient pd.Client, + pdHTTPCli pdhttp.Client, + tlsConf *tls.Config, + keepaliveConf keepalive.ClientParameters, +) *SnapClient { + return &SnapClient{ + pdClient: pdClient, + pdHTTPClient: pdHTTPCli, + tlsConf: tlsConf, + keepaliveConf: keepaliveConf, + switchCh: make(chan struct{}), + } +} + +func (rc *SnapClient) closeConn() { + // rc.db can be nil in raw kv mode. + if rc.db != nil { + rc.db.Close() + } + for _, db := range rc.dbPool { + db.Close() + } +} + +// Close a client. 
+func (rc *SnapClient) Close() { + // close the connection first; this must succeed when running in SQL mode. + rc.closeConn() + + if err := rc.fileImporter.Close(); err != nil { + // attach the cause, otherwise the warning cannot be diagnosed from the log. + log.Warn("failed to close file importer", zap.Error(err)) + } + + log.Info("Restore client closed") +} + +// SetRateLimit stores rateLimit on the client. +func (rc *SnapClient) SetRateLimit(rateLimit uint64) { + rc.rateLimit = rateLimit +} + +// SetCrypter stores the cipher info on the client. +func (rc *SnapClient) SetCrypter(crypter *backuppb.CipherInfo) { + rc.cipher = crypter +} + +// GetClusterID gets the cluster id from down-stream cluster. +func (rc *SnapClient) GetClusterID(ctx context.Context) uint64 { + return rc.pdClient.GetClusterID(ctx) +} + +// GetDomain returns the domain held by the client. +func (rc *SnapClient) GetDomain() *domain.Domain { + return rc.dom +} + +// GetTLSConfig returns the tls config. +func (rc *SnapClient) GetTLSConfig() *tls.Config { + return rc.tlsConf +} + +// GetSupportPolicy tells whether target tidb support placement policy. +func (rc *SnapClient) GetSupportPolicy() bool { + return rc.supportPolicy +} + +// updateConcurrency rebuilds the download worker pool with storeCount * concurrencyPerStore * 32 workers. +func (rc *SnapClient) updateConcurrency() { + // we believe 32 is large enough for download worker pool. + // it won't reach the limit if sst files distribute evenly. + // when restore memory usage is still too high, we should reduce concurrencyPerStore + // to sacrifice some speed to reduce memory usage. + count := uint(rc.storeCount) * rc.concurrencyPerStore * 32 + log.Info("download coarse worker pool", zap.Uint("size", count)) + rc.workerPool = tidbutil.NewWorkerPool(count, "file") +} + +// SetConcurrencyPerStore sets the concurrency of download files for each store. 
+func (rc *SnapClient) SetConcurrencyPerStore(c uint) { + log.Info("per-store download worker pool", zap.Uint("size", c)) + rc.concurrencyPerStore = c +} + +// SetBatchDdlSize stores the batch DDL size on the client. +func (rc *SnapClient) SetBatchDdlSize(batchDdlsize uint) { + rc.batchDdlSize = batchDdlsize +} + +// GetBatchDdlSize returns the batch DDL size stored on the client. +func (rc *SnapClient) GetBatchDdlSize() uint { + return rc.batchDdlSize +} + +// SetWithSysTable stores the withSysTable flag on the client. +func (rc *SnapClient) SetWithSysTable(withSysTable bool) { + rc.withSysTable = withSysTable +} + +// SetRewriteMode picks the rewrite mode from the cluster version check: RewriteModeKeyspace when +// version.CheckClusterVersion passes, RewriteModeLegacy (with a warning) when it returns an error. +// TODO: remove this check and return RewriteModeKeyspace +func (rc *SnapClient) SetRewriteMode(ctx context.Context) { + if err := version.CheckClusterVersion(ctx, rc.pdClient, version.CheckVersionForKeyspaceBR); err != nil { + log.Warn("Keyspace BR is not supported in this cluster, fallback to legacy restore", zap.Error(err)) + rc.rewriteMode = RewriteModeLegacy + } else { + rc.rewriteMode = RewriteModeKeyspace + } +} + +// GetRewriteMode returns the rewrite mode stored on the client. +func (rc *SnapClient) GetRewriteMode() RewriteMode { + return rc.rewriteMode +} + +// SetPlacementPolicyMode sets the placement policy mode. The argument is matched case-insensitively +// against STRICT and IGNORE; any other value falls back to STRICT. +func (rc *SnapClient) SetPlacementPolicyMode(withPlacementPolicy string) { + switch strings.ToUpper(withPlacementPolicy) { + case strictPlacementPolicyMode: + rc.policyMode = strictPlacementPolicyMode + case ignorePlacementPolicyMode: + rc.policyMode = ignorePlacementPolicyMode + default: + rc.policyMode = strictPlacementPolicyMode + } + log.Info("set placement policy mode", zap.String("mode", rc.policyMode)) +} + +// AllocTableIDs would pre-allocate the table's origin ID if exists, so that the TiKV doesn't need to rewrite the key in +// the download stage. 
+func (rc *SnapClient) AllocTableIDs(ctx context.Context, tables []*metautil.Table) error { + preallocedTableIDs := tidalloc.New(tables) + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnBR) + err := kv.RunInNewTxn(ctx, rc.GetDomain().Store(), true, func(_ context.Context, txn kv.Transaction) error { + return preallocedTableIDs.Alloc(meta.NewMutator(txn)) + }) + if err != nil { + return err + } + + log.Info("registering the table IDs", zap.Stringer("ids", preallocedTableIDs)) + for i := range rc.dbPool { + rc.dbPool[i].RegisterPreallocatedIDs(preallocedTableIDs) + } + if rc.db != nil { + rc.db.RegisterPreallocatedIDs(preallocedTableIDs) + } + return nil +} + +// InitCheckpoint initializes the checkpoint status for the cluster. If the cluster is +// restored for the first time, it will initialize the checkpoint metadata. Otherwise, +// it will load checkpoint metadata and checkpoint ranges/checksum from the external +// storage. +func (rc *SnapClient) InitCheckpoint( + ctx context.Context, + g glue.Glue, store kv.Storage, + config *pdutil.ClusterConfig, + checkpointFirstRun bool, +) (checkpointSetWithTableID map[int64]map[string]struct{}, checkpointClusterConfig *pdutil.ClusterConfig, err error) { + // checkpoint sets distinguished by range key + checkpointSetWithTableID = make(map[int64]map[string]struct{}) + + if !checkpointFirstRun { + execCtx := rc.db.Session().GetSessionCtx().GetRestrictedSQLExecutor() + // load the checkpoint since this is not the first time to restore + meta, err := checkpoint.LoadCheckpointMetadataForSnapshotRestore(ctx, execCtx) + if err != nil { + return checkpointSetWithTableID, nil, errors.Trace(err) + } + + if meta.UpstreamClusterID != rc.backupMeta.ClusterId { + return checkpointSetWithTableID, nil, errors.Errorf( + "The upstream cluster id[%d] of the current snapshot restore does not match that[%d] recorded in checkpoint. 
"+ + "Perhaps you should specify the last full backup storage instead, "+ + "or just clean the checkpoint database[%s] if the cluster has been cleaned up.", + rc.backupMeta.ClusterId, meta.UpstreamClusterID, checkpoint.SnapshotRestoreCheckpointDatabaseName) + } + + if meta.RestoredTS != rc.backupMeta.EndVersion { + return checkpointSetWithTableID, nil, errors.Errorf( + "The current snapshot restore want to restore cluster to the BackupTS[%d], which is different from that[%d] recorded in checkpoint. "+ + "Perhaps you should specify the last full backup storage instead, "+ + "or just clean the checkpoint database[%s] if the cluster has been cleaned up.", + rc.backupMeta.EndVersion, meta.RestoredTS, checkpoint.SnapshotRestoreCheckpointDatabaseName, + ) + } + + // The schedulers config is nil, so the restore-schedulers operation is just nil. + // Then the undo function would use the result undo of `remove schedulers` operation, + // instead of that in checkpoint meta. + if meta.SchedulersConfig != nil { + checkpointClusterConfig = meta.SchedulersConfig + } + + // t1 is the latest time the checkpoint ranges persisted to the external storage. + t1, err := checkpoint.LoadCheckpointDataForSnapshotRestore(ctx, execCtx, func(tableID int64, rangeKey checkpoint.RestoreValueType) { + checkpointSet, exists := checkpointSetWithTableID[tableID] + if !exists { + checkpointSet = make(map[string]struct{}) + checkpointSetWithTableID[tableID] = checkpointSet + } + checkpointSet[rangeKey.RangeKey] = struct{}{} + }) + if err != nil { + return checkpointSetWithTableID, nil, errors.Trace(err) + } + // t2 is the latest time the checkpoint checksum persisted to the external storage. + checkpointChecksum, t2, err := checkpoint.LoadCheckpointChecksumForRestore(ctx, execCtx) + if err != nil { + return checkpointSetWithTableID, nil, errors.Trace(err) + } + rc.checkpointChecksum = checkpointChecksum + // use the later time to adjust the summary elapsed time. 
+ if t1 > t2 { + summary.AdjustStartTimeToEarlierTime(t1) + } else { + summary.AdjustStartTimeToEarlierTime(t2) + } + } else { + // initialize the checkpoint metadata since it is the first time to restore. + meta := &checkpoint.CheckpointMetadataForSnapshotRestore{ + UpstreamClusterID: rc.backupMeta.ClusterId, + RestoredTS: rc.backupMeta.EndVersion, + } + // a nil config means undo function + if config != nil { + meta.SchedulersConfig = &pdutil.ClusterConfig{Schedulers: config.Schedulers, ScheduleCfg: config.ScheduleCfg} + } + if err := checkpoint.SaveCheckpointMetadataForSnapshotRestore(ctx, rc.db.Session(), meta); err != nil { + return checkpointSetWithTableID, nil, errors.Trace(err) + } + } + + se, err := g.CreateSession(store) + if err != nil { + return checkpointSetWithTableID, nil, errors.Trace(err) + } + rc.checkpointRunner, err = checkpoint.StartCheckpointRunnerForRestore(ctx, se) + return checkpointSetWithTableID, checkpointClusterConfig, errors.Trace(err) +} + +// WaitForFinishCheckpoint waits for the checkpoint runner to finish when one has been started; +// flush is passed through to the runner. It does nothing when there is no runner. +func (rc *SnapClient) WaitForFinishCheckpoint(ctx context.Context, flush bool) { + if rc.checkpointRunner != nil { + rc.checkpointRunner.WaitForFinish(ctx, flush) + } +} + +// makeDBPool makes a DB pool of at most the specified size using dbFactory. +// DBs for which dbFactory returns nil are skipped; on the first factory error the DBs created so far +// are returned together with that error. +func makeDBPool(size uint, dbFactory func() (*tidallocdb.DB, error)) ([]*tidallocdb.DB, error) { + dbPool := make([]*tidallocdb.DB, 0, size) + for i := uint(0); i < size; i++ { + db, e := dbFactory() + if e != nil { + return dbPool, e + } + if db != nil { + dbPool = append(dbPool, db) + } + } + return dbPool, nil +} + +// Init creates the db connection, the domain and the DB pool for the given storage. +func (rc *SnapClient) Init(g glue.Glue, store kv.Storage) error { + // setDB must happen after set PolicyMode. + // we will use policyMode to set session variables. 
+ var err error + rc.db, rc.supportPolicy, err = tidallocdb.NewDB(g, store, rc.policyMode) + if err != nil { + return errors.Trace(err) + } + rc.dom, err = g.GetDomain(store) + if err != nil { + return errors.Trace(err) + } + + // init backupMeta only for passing unit test + if rc.backupMeta == nil { + rc.backupMeta = new(backuppb.BackupMeta) + } + + // There are different ways to create session between in binary and in SQL. + // + // Maybe allow user modify the DDL concurrency isn't necessary, + // because executing DDL is really I/O bound (or, algorithm bound?), + // and we cost most of time at waiting DDL jobs be enqueued. + // So these jobs won't be faster or slower when machine become faster or slower, + // hence make it a fixed value would be fine. + rc.dbPool, err = makeDBPool(defaultDDLConcurrency, func() (*tidallocdb.DB, error) { + db, _, err := tidallocdb.NewDB(g, store, rc.policyMode) + return db, err + }) + if err != nil { + log.Warn("create session pool failed, we will send DDLs only by created sessions", + zap.Error(err), + zap.Int("sessionCount", len(rc.dbPool)), + ) + } + // NOTE(review): the warning above says the sessions created so far will be used, yet the same + // error is still returned to the caller here - confirm which behavior is intended. + return errors.Trace(err) +} + +// initClients looks up the TiKV stores (TiFlash skipped), rebuilds the download worker pool for the +// resulting store count and creates the file importer. +func (rc *SnapClient) initClients(ctx context.Context, backend *backuppb.StorageBackend, isRawKvMode bool, isTxnKvMode bool) error { + stores, err := conn.GetAllTiKVStoresWithRetry(ctx, rc.pdClient, util.SkipTiFlash) + if err != nil { + return errors.Annotate(err, "failed to get stores") + } + rc.storeCount = len(stores) + rc.updateConcurrency() + + var splitClientOpts []split.ClientOptionalParameter + if isRawKvMode { + splitClientOpts = append(splitClientOpts, split.WithRawKV()) + } + + metaClient := split.NewClient(rc.pdClient, rc.pdHTTPClient, rc.tlsConf, maxSplitKeysOnce, rc.storeCount+1, splitClientOpts...) 
+ importCli := importclient.NewImportClient(metaClient, rc.tlsConf, rc.keepaliveConf) + rc.fileImporter, err = NewSnapFileImporter(ctx, metaClient, importCli, backend, isRawKvMode, isTxnKvMode, stores, rc.rewriteMode, rc.concurrencyPerStore) + return errors.Trace(err) +} + +// needLoadSchemas reports whether schemas have to be loaded, i.e. the backup is neither raw KV nor txn KV. +func (rc *SnapClient) needLoadSchemas(backupMeta *backuppb.BackupMeta) bool { + return !(backupMeta.IsRawKv || backupMeta.IsTxnKv) +} + +// LoadSchemaIfNeededAndInitClient loads schemas from BackupMeta to initialize RestoreClient. +func (rc *SnapClient) LoadSchemaIfNeededAndInitClient( + c context.Context, + backupMeta *backuppb.BackupMeta, + backend *backuppb.StorageBackend, + reader *metautil.MetaReader, + loadStats bool) error { + if rc.needLoadSchemas(backupMeta) { + databases, err := metautil.LoadBackupTables(c, reader, loadStats) + if err != nil { + return errors.Trace(err) + } + rc.databases = databases + + var ddlJobs []*model.Job + // ddls is the bytes of json.Marshal + ddls, err := reader.ReadDDLs(c) + if err != nil { + return errors.Trace(err) + } + if len(ddls) != 0 { + err = json.Unmarshal(ddls, &ddlJobs) + if err != nil { + return errors.Trace(err) + } + } + rc.ddlJobs = ddlJobs + } + rc.backupMeta = backupMeta + log.Info("load backupmeta", zap.Int("databases", len(rc.databases)), zap.Int("jobs", len(rc.ddlJobs))) + + return rc.initClients(c, backend, backupMeta.IsRawKv, backupMeta.IsTxnKv) +} + +// IsRawKvMode checks whether the backup data is in raw kv format, in which case transactional recover is forbidden. +func (rc *SnapClient) IsRawKvMode() bool { + return rc.backupMeta.IsRawKv +} + +// GetFilesInRawRange gets all files that are in the given range or intersects with the given range. 
+func (rc *SnapClient) GetFilesInRawRange(startKey []byte, endKey []byte, cf string) ([]*backuppb.File, error) { + if !rc.IsRawKvMode() { + return nil, errors.Annotate(berrors.ErrRestoreModeMismatch, "the backup data is not in raw kv mode") + } + + for _, rawRange := range rc.backupMeta.RawRanges { + // First check whether the given range was backed up. If not, we cannot perform the restore. + if rawRange.Cf != cf { + continue + } + + if (len(rawRange.EndKey) > 0 && bytes.Compare(startKey, rawRange.EndKey) >= 0) || + (len(endKey) > 0 && bytes.Compare(rawRange.StartKey, endKey) >= 0) { + // The restoring range is totally out of the current range. Skip it. + continue + } + + if bytes.Compare(startKey, rawRange.StartKey) < 0 || + utils.CompareEndKey(endKey, rawRange.EndKey) > 0 { + // Only part of the restoring range is in the current backed-up range. So the given range can't be fully + // restored. + return nil, errors.Annotatef(berrors.ErrRestoreRangeMismatch, + "the given range to restore [%s, %s) is not fully covered by the range that was backed up [%s, %s)", + redact.Key(startKey), redact.Key(endKey), redact.Key(rawRange.StartKey), redact.Key(rawRange.EndKey), + ) + } + + // We have found the range that contains the given range. Find all necessary files. + files := make([]*backuppb.File, 0) + + for _, file := range rc.backupMeta.Files { + if file.Cf != cf { + continue + } + + if len(file.EndKey) > 0 && bytes.Compare(file.EndKey, startKey) < 0 { + // The file is before the range to be restored. + continue + } + if len(endKey) > 0 && bytes.Compare(endKey, file.StartKey) <= 0 { + // The file is after the range to be restored. + // The specified endKey is exclusive, so when it equals a file's startKey, the file is still skipped. + continue + } + + files = append(files, file) + } + + // There should be at most one backed up range that covers the restoring range. 
+ return files, nil + } + + return nil, errors.Annotate(berrors.ErrRestoreRangeMismatch, "no backup data in the range") +} + +// ResetTS resets the timestamp of PD to the end version recorded in the backup meta, retrying with +// the PD request backoffer. +func (rc *SnapClient) ResetTS(ctx context.Context, pdCtrl *pdutil.PdController) error { + restoreTS := rc.backupMeta.GetEndVersion() + log.Info("reset pd timestamp", zap.Uint64("ts", restoreTS)) + return utils.WithRetry(ctx, func() error { + return pdCtrl.ResetTS(ctx, restoreTS) + }, utils.NewPDReqBackoffer()) +} + +// GetDatabases returns all databases as a new slice. The databases are collected by ranging over a +// map, so their order is not specified. +func (rc *SnapClient) GetDatabases() []*metautil.Database { + dbs := make([]*metautil.Database, 0, len(rc.databases)) + for _, db := range rc.databases { + dbs = append(dbs, db) + } + return dbs +} + +// HasBackedUpSysDB reports whether we have backed up system tables. +// br backs system tables up since 5.1.0 +func (rc *SnapClient) HasBackedUpSysDB() bool { + sysDBs := []string{"mysql", "sys"} + for _, db := range sysDBs { + temporaryDB := utils.TemporaryDBName(db) + _, backedUp := rc.databases[temporaryDB.O] + if backedUp { + return true + } + } + return false +} + +// GetPlacementPolicies decodes the policies recorded in the backup meta and returns them in a map +// keyed by policyInfo.Name.L. +func (rc *SnapClient) GetPlacementPolicies() (*sync.Map, error) { + policies := &sync.Map{} + for _, p := range rc.backupMeta.Policies { + policyInfo := &model.PolicyInfo{} + err := json.Unmarshal(p.Info, policyInfo) + if err != nil { + return nil, errors.Trace(err) + } + policies.Store(policyInfo.Name.L, policyInfo) + } + return policies, nil +} + +// GetDDLJobs returns ddl jobs. +func (rc *SnapClient) GetDDLJobs() []*model.Job { + return rc.ddlJobs +} + +// SetPolicyMap sets policyMap. +func (rc *SnapClient) SetPolicyMap(p *sync.Map) { + rc.policyMap = p +} + +// CreatePolicies creates all policies in full restore. 
+func (rc *SnapClient) CreatePolicies(ctx context.Context, policyMap *sync.Map) error { + var err error + policyMap.Range(func(key, value any) bool { + e := rc.db.CreatePlacementPolicy(ctx, value.(*model.PolicyInfo)) + if e != nil { + err = e + return false + } + return true + }) + return err +} + +// CreateDatabases creates databases. If the client has the db pool, it would create it. +func (rc *SnapClient) CreateDatabases(ctx context.Context, dbs []*metautil.Database) error { + if rc.IsSkipCreateSQL() { + log.Info("skip create database") + return nil + } + + if len(rc.dbPool) == 0 { + log.Info("create databases sequentially") + for _, db := range dbs { + err := rc.db.CreateDatabase(ctx, db.Info, rc.supportPolicy, rc.policyMap) + if err != nil { + return errors.Trace(err) + } + } + return nil + } + + log.Info("create databases in db pool", zap.Int("pool size", len(rc.dbPool)), zap.Int("number of db", len(dbs))) + eg, ectx := errgroup.WithContext(ctx) + workers := tidbutil.NewWorkerPool(uint(len(rc.dbPool)), "DB DDL workers") + for _, db_ := range dbs { + db := db_ + workers.ApplyWithIDInErrorGroup(eg, func(id uint64) error { + conn := rc.dbPool[id%uint64(len(rc.dbPool))] + return conn.CreateDatabase(ectx, db.Info, rc.supportPolicy, rc.policyMap) + }) + } + return eg.Wait() +} + +// generateRebasedTables generate a map[UniqueTableName]bool to represent tables that haven't updated table info. +// there are two situations: +// 1. tables that already exists in the restored cluster. +// 2. tables that are created by executing ddl jobs. +// so, only tables in incremental restoration will be added to the map +func (rc *SnapClient) generateRebasedTables(tables []*metautil.Table) { + if !rc.IsIncremental() { + // in full restoration, all tables are created by Session.CreateTable, and all tables' info is updated. 
+ rc.rebasedTablesMap = make(map[restore.UniqueTableName]bool) + return + } + + rc.rebasedTablesMap = make(map[restore.UniqueTableName]bool, len(tables)) + for _, table := range tables { + rc.rebasedTablesMap[restore.UniqueTableName{DB: table.DB.Name.String(), Table: table.Info.Name.String()}] = true + } +} + +// getRebasedTables returns tables that may need to be rebase auto increment id or auto random id +func (rc *SnapClient) getRebasedTables() map[restore.UniqueTableName]bool { + return rc.rebasedTablesMap +} + +// CreateTables create tables, and generate their information. +// this function will use workers as the same number of sessionPool, +// leave sessionPool nil to send DDLs sequential. +func (rc *SnapClient) CreateTables( + ctx context.Context, + tables []*metautil.Table, + newTS uint64, +) ([]*CreatedTable, error) { + log.Info("start create tables", zap.Int("total count", len(tables))) + rc.generateRebasedTables(tables) + + // try to restore tables in batch + if rc.batchDdlSize > minBatchDdlSize && len(rc.dbPool) > 0 { + tables, err := rc.createTablesBatch(ctx, tables, newTS) + if err == nil { + return tables, nil + } else if !utils.FallBack2CreateTable(err) { + return nil, errors.Trace(err) + } + // fall back to old create table (sequential create table) + log.Info("fall back to the sequential create table") + } + + // restore tables in db pool + if len(rc.dbPool) > 0 { + return rc.createTablesSingle(ctx, rc.dbPool, tables, newTS) + } + // restore tables in one db + return rc.createTablesSingle(ctx, []*tidallocdb.DB{rc.db}, tables, newTS) +} + +func (rc *SnapClient) createTables( + ctx context.Context, + db *tidallocdb.DB, + tables []*metautil.Table, + newTS uint64, +) ([]*CreatedTable, error) { + log.Info("client to create tables") + if rc.IsSkipCreateSQL() { + log.Info("skip create table and alter autoIncID") + } else { + err := db.CreateTables(ctx, tables, rc.getRebasedTables(), rc.supportPolicy, rc.policyMap) + if err != nil { + return nil, 
errors.Trace(err) + } + } + cts := make([]*CreatedTable, 0, len(tables)) + for _, table := range tables { + newTableInfo, err := restore.GetTableSchema(rc.dom, table.DB.Name, table.Info.Name) + if err != nil { + return nil, errors.Trace(err) + } + if newTableInfo.IsCommonHandle != table.Info.IsCommonHandle { + return nil, errors.Annotatef(berrors.ErrRestoreModeMismatch, + "Clustered index option mismatch. Restored cluster's @@tidb_enable_clustered_index should be %v (backup table = %v, created table = %v).", + restore.TransferBoolToValue(table.Info.IsCommonHandle), + table.Info.IsCommonHandle, + newTableInfo.IsCommonHandle) + } + rules := restoreutils.GetRewriteRules(newTableInfo, table.Info, newTS, true) + ct := &CreatedTable{ + RewriteRule: rules, + Table: newTableInfo, + OldTable: table, + } + log.Debug("new created tables", zap.Any("table", ct)) + cts = append(cts, ct) + } + return cts, nil +} + +func (rc *SnapClient) createTablesBatch(ctx context.Context, tables []*metautil.Table, newTS uint64) ([]*CreatedTable, error) { + eg, ectx := errgroup.WithContext(ctx) + rater := logutil.TraceRateOver(logutil.MetricTableCreatedCounter) + workers := tidbutil.NewWorkerPool(uint(len(rc.dbPool)), "Create Tables Worker") + numOfTables := len(tables) + createdTables := struct { + sync.Mutex + tables []*CreatedTable + }{ + tables: make([]*CreatedTable, 0, len(tables)), + } + + for lastSent := 0; lastSent < numOfTables; lastSent += int(rc.batchDdlSize) { + end := min(lastSent+int(rc.batchDdlSize), len(tables)) + log.Info("create tables", zap.Int("table start", lastSent), zap.Int("table end", end)) + + tableSlice := tables[lastSent:end] + workers.ApplyWithIDInErrorGroup(eg, func(id uint64) error { + db := rc.dbPool[id%uint64(len(rc.dbPool))] + cts, err := rc.createTables(ectx, db, tableSlice, newTS) // ddl job for [lastSent:i) + failpoint.Inject("restore-createtables-error", func(val failpoint.Value) { + if val.(bool) { + err = errors.New("sample error without extra message") + 
} + }) + if err != nil { + log.Error("create tables fail", zap.Error(err)) + return err + } + rater.Add(float64(len(cts))) + rater.L().Info("tables created", zap.Int("num", len(cts))) + createdTables.Lock() + createdTables.tables = append(createdTables.tables, cts...) + createdTables.Unlock() + return err + }) + } + if err := eg.Wait(); err != nil { + return nil, errors.Trace(err) + } + + return createdTables.tables, nil +} + +func (rc *SnapClient) createTable( + ctx context.Context, + db *tidallocdb.DB, + table *metautil.Table, + newTS uint64, +) (*CreatedTable, error) { + if rc.IsSkipCreateSQL() { + log.Info("skip create table and alter autoIncID", zap.Stringer("table", table.Info.Name)) + } else { + err := db.CreateTable(ctx, table, rc.getRebasedTables(), rc.supportPolicy, rc.policyMap) + if err != nil { + return nil, errors.Trace(err) + } + } + newTableInfo, err := restore.GetTableSchema(rc.dom, table.DB.Name, table.Info.Name) + if err != nil { + return nil, errors.Trace(err) + } + if newTableInfo.IsCommonHandle != table.Info.IsCommonHandle { + return nil, errors.Annotatef(berrors.ErrRestoreModeMismatch, + "Clustered index option mismatch. 
Restored cluster's @@tidb_enable_clustered_index should be %v (backup table = %v, created table = %v).", + restore.TransferBoolToValue(table.Info.IsCommonHandle), + table.Info.IsCommonHandle, + newTableInfo.IsCommonHandle) + } + rules := restoreutils.GetRewriteRules(newTableInfo, table.Info, newTS, true) + et := &CreatedTable{ + RewriteRule: rules, + Table: newTableInfo, + OldTable: table, + } + return et, nil +} + +func (rc *SnapClient) createTablesSingle( + ctx context.Context, + dbPool []*tidallocdb.DB, + tables []*metautil.Table, + newTS uint64, +) ([]*CreatedTable, error) { + eg, ectx := errgroup.WithContext(ctx) + workers := tidbutil.NewWorkerPool(uint(len(dbPool)), "DDL workers") + rater := logutil.TraceRateOver(logutil.MetricTableCreatedCounter) + createdTables := struct { + sync.Mutex + tables []*CreatedTable + }{ + tables: make([]*CreatedTable, 0, len(tables)), + } + for _, tbl := range tables { + table := tbl + workers.ApplyWithIDInErrorGroup(eg, func(id uint64) error { + db := dbPool[id%uint64(len(dbPool))] + rt, err := rc.createTable(ectx, db, table, newTS) + if err != nil { + log.Error("create table failed", + zap.Error(err), + zap.Stringer("db", table.DB.Name), + zap.Stringer("table", table.Info.Name)) + return errors.Trace(err) + } + rater.Inc() + rater.L().Info("table created", + zap.Stringer("table", table.Info.Name), + zap.Stringer("database", table.DB.Name)) + + createdTables.Lock() + createdTables.tables = append(createdTables.tables, rt) + createdTables.Unlock() + return nil + }) + } + if err := eg.Wait(); err != nil { + return nil, errors.Trace(err) + } + + return createdTables.tables, nil +} + +// InitFullClusterRestore init fullClusterRestore and set SkipGrantTable as needed +func (rc *SnapClient) InitFullClusterRestore(explicitFilter bool) { + rc.fullClusterRestore = !explicitFilter && rc.IsFull() + + log.Info("full cluster restore", zap.Bool("value", rc.fullClusterRestore)) +} + +func (rc *SnapClient) IsFullClusterRestore() bool { + 
return rc.fullClusterRestore +} + +// IsFull returns whether this backup is full. +func (rc *SnapClient) IsFull() bool { + failpoint.Inject("mock-incr-backup-data", func() { + failpoint.Return(false) + }) + return !rc.IsIncremental() +} + +// IsIncremental returns whether this backup is incremental. +func (rc *SnapClient) IsIncremental() bool { + return !(rc.backupMeta.StartVersion == rc.backupMeta.EndVersion || + rc.backupMeta.StartVersion == 0) +} + +// NeedCheckFreshCluster is every time. except restore from a checkpoint or user has not set filter argument. +func (rc *SnapClient) NeedCheckFreshCluster(ExplicitFilter bool, firstRun bool) bool { + return rc.IsFull() && !ExplicitFilter && firstRun +} + +// EnableSkipCreateSQL sets switch of skip create schema and tables. +func (rc *SnapClient) EnableSkipCreateSQL() { + rc.noSchema = true +} + +// IsSkipCreateSQL returns whether we need skip create schema and tables in restore. +func (rc *SnapClient) IsSkipCreateSQL() bool { + return rc.noSchema +} + +// CheckTargetClusterFresh check whether the target cluster is fresh or not +// if there's no user dbs or tables, we take it as a fresh cluster, although +// user may have created some users or made other changes. +func (rc *SnapClient) CheckTargetClusterFresh(ctx context.Context) error { + log.Info("checking whether target cluster is fresh") + return restore.AssertUserDBsEmpty(rc.dom) +} + +// ExecDDLs executes the queries of the ddl jobs. +func (rc *SnapClient) ExecDDLs(ctx context.Context, ddlJobs []*model.Job) error { + // Sort the ddl jobs by schema version in ascending order. 
+ slices.SortFunc(ddlJobs, func(i, j *model.Job) int { + return cmp.Compare(i.BinlogInfo.SchemaVersion, j.BinlogInfo.SchemaVersion) + }) + + for _, job := range ddlJobs { + err := rc.db.ExecDDL(ctx, job) + if err != nil { + return errors.Trace(err) + } + log.Info("execute ddl query", + zap.String("db", job.SchemaName), + zap.String("query", job.Query), + zap.Int64("historySchemaVersion", job.BinlogInfo.SchemaVersion)) + } + return nil +} + +func (rc *SnapClient) ResetSpeedLimit(ctx context.Context) error { + rc.hasSpeedLimited = false + err := rc.setSpeedLimit(ctx, 0) + if err != nil { + return errors.Trace(err) + } + return nil +} + +func (rc *SnapClient) setSpeedLimit(ctx context.Context, rateLimit uint64) error { + if !rc.hasSpeedLimited { + stores, err := util.GetAllTiKVStores(ctx, rc.pdClient, util.SkipTiFlash) + if err != nil { + return errors.Trace(err) + } + + eg, ectx := errgroup.WithContext(ctx) + for _, store := range stores { + if err := ectx.Err(); err != nil { + return errors.Trace(err) + } + + finalStore := store + rc.workerPool.ApplyOnErrorGroup(eg, + func() error { + err := rc.fileImporter.SetDownloadSpeedLimit(ectx, finalStore.GetId(), rateLimit) + if err != nil { + return errors.Trace(err) + } + return nil + }) + } + + if err := eg.Wait(); err != nil { + return errors.Trace(err) + } + rc.hasSpeedLimited = true + } + return nil +} + +func (rc *SnapClient) execAndValidateChecksum( + ctx context.Context, + tbl *CreatedTable, + kvClient kv.Client, + concurrency uint, +) error { + logger := log.L().With( + zap.String("db", tbl.OldTable.DB.Name.O), + zap.String("table", tbl.OldTable.Info.Name.O), + ) + + expectedChecksumStats := metautil.CalculateChecksumStatsOnFiles(tbl.OldTable.Files) + if !expectedChecksumStats.ChecksumExists() { + logger.Warn("table has no checksum, skipping checksum") + return nil + } + + if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { + span1 := 
span.Tracer().StartSpan("Client.execAndValidateChecksum", opentracing.ChildOf(span.Context())) + defer span1.Finish() + ctx = opentracing.ContextWithSpan(ctx, span1) + } + + item, exists := rc.checkpointChecksum[tbl.Table.ID] + if !exists { + startTS, err := restore.GetTSWithRetry(ctx, rc.pdClient) + if err != nil { + return errors.Trace(err) + } + exe, err := checksum.NewExecutorBuilder(tbl.Table, startTS). + SetOldTable(tbl.OldTable). + SetConcurrency(concurrency). + SetOldKeyspace(tbl.RewriteRule.OldKeyspace). + SetNewKeyspace(tbl.RewriteRule.NewKeyspace). + SetExplicitRequestSourceType(kvutil.ExplicitTypeBR). + Build() + if err != nil { + return errors.Trace(err) + } + checksumResp, err := exe.Execute(ctx, kvClient, func() { + // TODO: update progress here. + }) + if err != nil { + return errors.Trace(err) + } + item = &checkpoint.ChecksumItem{ + TableID: tbl.Table.ID, + Crc64xor: checksumResp.Checksum, + TotalKvs: checksumResp.TotalKvs, + TotalBytes: checksumResp.TotalBytes, + } + if rc.checkpointRunner != nil { + err = rc.checkpointRunner.FlushChecksumItem(ctx, item) + if err != nil { + return errors.Trace(err) + } + } + } + checksumMatch := item.Crc64xor == expectedChecksumStats.Crc64Xor && + item.TotalKvs == expectedChecksumStats.TotalKvs && + item.TotalBytes == expectedChecksumStats.TotalBytes + failpoint.Inject("full-restore-validate-checksum", func(_ failpoint.Value) { + checksumMatch = false + }) + if !checksumMatch { + logger.Error("failed in validate checksum", + zap.Uint64("expected tidb crc64", expectedChecksumStats.Crc64Xor), + zap.Uint64("calculated crc64", item.Crc64xor), + zap.Uint64("expected tidb total kvs", expectedChecksumStats.TotalKvs), + zap.Uint64("calculated total kvs", item.TotalKvs), + zap.Uint64("expected tidb total bytes", expectedChecksumStats.TotalBytes), + zap.Uint64("calculated total bytes", item.TotalBytes), + ) + return errors.Annotate(berrors.ErrRestoreChecksumMismatch, "failed to validate checksum") + } + 
logger.Info("success in validating checksum") + return nil +} + +func (rc *SnapClient) WaitForFilesRestored(ctx context.Context, files []*backuppb.File, updateCh glue.Progress) error { + errCh := make(chan error, len(files)) + eg, ectx := errgroup.WithContext(ctx) + defer close(errCh) + + for _, file := range files { + fileReplica := file + rc.workerPool.ApplyOnErrorGroup(eg, + func() error { + defer func() { + log.Info("import sst files done", logutil.Files(files)) + updateCh.Inc() + }() + return rc.fileImporter.ImportSSTFiles(ectx, []TableIDWithFiles{{Files: []*backuppb.File{fileReplica}, RewriteRules: restoreutils.EmptyRewriteRule()}}, rc.cipher, rc.backupMeta.ApiVersion) + }) + } + if err := eg.Wait(); err != nil { + return errors.Trace(err) + } + return nil +} + +// RestoreRaw tries to restore raw keys in the specified range. +func (rc *SnapClient) RestoreRaw( + ctx context.Context, startKey []byte, endKey []byte, files []*backuppb.File, updateCh glue.Progress, +) error { + start := time.Now() + defer func() { + elapsed := time.Since(start) + log.Info("Restore Raw", + logutil.Key("startKey", startKey), + logutil.Key("endKey", endKey), + zap.Duration("take", elapsed)) + }() + err := rc.fileImporter.SetRawRange(startKey, endKey) + if err != nil { + return errors.Trace(err) + } + + err = rc.WaitForFilesRestored(ctx, files, updateCh) + if err != nil { + return errors.Trace(err) + } + log.Info( + "finish to restore raw range", + logutil.Key("startKey", startKey), + logutil.Key("endKey", endKey), + ) + return nil +} diff --git a/br/pkg/restore/snap_client/pipeline_items.go b/br/pkg/restore/snap_client/pipeline_items.go new file mode 100644 index 0000000000000..3f74434e72f02 --- /dev/null +++ b/br/pkg/restore/snap_client/pipeline_items.go @@ -0,0 +1,320 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package snapclient + +import ( + "context" + "time" + + "github.com/pingcap/errors" + backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/log" + "github.com/pingcap/tidb/br/pkg/glue" + "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/br/pkg/metautil" + restoreutils "github.com/pingcap/tidb/br/pkg/restore/utils" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/br/pkg/summary" + "github.com/pingcap/tidb/pkg/domain/infosync" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta/model" + tidbutil "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/engine" + pdhttp "github.com/tikv/pd/client/http" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "golang.org/x/sync/errgroup" +) + +const defaultChannelSize = 1024 + +// defaultChecksumConcurrency is the default number of the concurrent +// checksum tasks. +const defaultChecksumConcurrency = 64 + +// CreatedTable is a table created on restore process, +// but not yet filled with data. +type CreatedTable struct { + RewriteRule *restoreutils.RewriteRules + Table *model.TableInfo + OldTable *metautil.Table +} + +type PhysicalTable struct { + NewPhysicalID int64 + OldPhysicalID int64 + RewriteRules *restoreutils.RewriteRules +} + +type TableIDWithFiles struct { + TableID int64 + + Files []*backuppb.File + // RewriteRules is the rewrite rules for the specify table. + // because these rules belongs to the *one table*. + // we can hold them here. 
+ RewriteRules *restoreutils.RewriteRules +} + +type zapFilesGroupMarshaler []TableIDWithFiles + +// MarshalLogObjectForFiles is an internal util function to zap something having `Files` field. +func MarshalLogObjectForFiles(files []TableIDWithFiles, encoder zapcore.ObjectEncoder) error { + return zapFilesGroupMarshaler(files).MarshalLogObject(encoder) +} + +func (fgs zapFilesGroupMarshaler) MarshalLogObject(encoder zapcore.ObjectEncoder) error { + elements := make([]string, 0) + total := 0 + totalKVs := uint64(0) + totalBytes := uint64(0) + totalSize := uint64(0) + for _, fg := range fgs { + for _, f := range fg.Files { + total += 1 + elements = append(elements, f.GetName()) + totalKVs += f.GetTotalKvs() + totalBytes += f.GetTotalBytes() + totalSize += f.GetSize_() + } + } + encoder.AddInt("total", total) + _ = encoder.AddArray("files", logutil.AbbreviatedArrayMarshaler(elements)) + encoder.AddUint64("totalKVs", totalKVs) + encoder.AddUint64("totalBytes", totalBytes) + encoder.AddUint64("totalSize", totalSize) + return nil +} + +func zapFilesGroup(filesGroup []TableIDWithFiles) zap.Field { + return zap.Object("files", zapFilesGroupMarshaler(filesGroup)) +} + +func defaultOutputTableChan() chan *CreatedTable { + return make(chan *CreatedTable, defaultChannelSize) +} + +func concurrentHandleTablesCh( + ctx context.Context, + inCh <-chan *CreatedTable, + outCh chan<- *CreatedTable, + errCh chan<- error, + workers *tidbutil.WorkerPool, + processFun func(context.Context, *CreatedTable) error, + deferFun func()) { + eg, ectx := errgroup.WithContext(ctx) + defer func() { + if err := eg.Wait(); err != nil { + errCh <- err + } + close(outCh) + deferFun() + }() + + for { + select { + // if we use ectx here, maybe canceled will mask real error. 
+ case <-ctx.Done(): + errCh <- ctx.Err() + case tbl, ok := <-inCh: + if !ok { + return + } + cloneTable := tbl + worker := workers.ApplyWorker() + eg.Go(func() error { + defer workers.RecycleWorker(worker) + err := processFun(ectx, cloneTable) + if err != nil { + return err + } + outCh <- cloneTable + return nil + }) + } + } +} + +// GoValidateChecksum forks a goroutine to validate checksum after restore. +// it returns a channel fires a struct{} when all things get done. +func (rc *SnapClient) GoValidateChecksum( + ctx context.Context, + inCh <-chan *CreatedTable, + kvClient kv.Client, + errCh chan<- error, + updateCh glue.Progress, + concurrency uint, +) chan *CreatedTable { + log.Info("Start to validate checksum") + outCh := defaultOutputTableChan() + workers := tidbutil.NewWorkerPool(defaultChecksumConcurrency, "RestoreChecksum") + go concurrentHandleTablesCh(ctx, inCh, outCh, errCh, workers, func(c context.Context, tbl *CreatedTable) error { + start := time.Now() + defer func() { + elapsed := time.Since(start) + summary.CollectSuccessUnit("table checksum", 1, elapsed) + }() + err := rc.execAndValidateChecksum(c, tbl, kvClient, concurrency) + if err != nil { + return errors.Trace(err) + } + updateCh.Inc() + return nil + }, func() { + log.Info("all checksum ended") + }) + return outCh +} + +func (rc *SnapClient) GoUpdateMetaAndLoadStats( + ctx context.Context, + s storage.ExternalStorage, + inCh <-chan *CreatedTable, + errCh chan<- error, + statsConcurrency uint, + loadStats bool, +) chan *CreatedTable { + log.Info("Start to update meta then load stats") + outCh := defaultOutputTableChan() + workers := tidbutil.NewWorkerPool(statsConcurrency, "UpdateStats") + statsHandler := rc.dom.StatsHandle() + + go concurrentHandleTablesCh(ctx, inCh, outCh, errCh, workers, func(c context.Context, tbl *CreatedTable) error { + oldTable := tbl.OldTable + var statsErr error = nil + if loadStats && oldTable.Stats != nil { + log.Info("start loads analyze after validate checksum", 
+ zap.Int64("old id", oldTable.Info.ID), + zap.Int64("new id", tbl.Table.ID), + ) + start := time.Now() + // NOTICE: skip updating cache after load stats from json + if statsErr = statsHandler.LoadStatsFromJSONNoUpdate(ctx, rc.dom.InfoSchema(), oldTable.Stats, 0); statsErr != nil { + log.Error("analyze table failed", zap.Any("table", oldTable.Stats), zap.Error(statsErr)) + } + log.Info("restore stat done", + zap.Stringer("table", oldTable.Info.Name), + zap.Stringer("db", oldTable.DB.Name), + zap.Duration("cost", time.Since(start))) + } else if loadStats && len(oldTable.StatsFileIndexes) > 0 { + log.Info("start to load statistic data for each partition", + zap.Int64("old id", oldTable.Info.ID), + zap.Int64("new id", tbl.Table.ID), + ) + start := time.Now() + rewriteIDMap := restoreutils.GetTableIDMap(tbl.Table, tbl.OldTable.Info) + if statsErr = metautil.RestoreStats(ctx, s, rc.cipher, statsHandler, tbl.Table, oldTable.StatsFileIndexes, rewriteIDMap); statsErr != nil { + log.Error("analyze table failed", zap.Any("table", oldTable.StatsFileIndexes), zap.Error(statsErr)) + } + log.Info("restore statistic data done", + zap.Stringer("table", oldTable.Info.Name), + zap.Stringer("db", oldTable.DB.Name), + zap.Duration("cost", time.Since(start))) + } + + if statsErr != nil || !loadStats || (oldTable.Stats == nil && len(oldTable.StatsFileIndexes) == 0) { + // Not need to return err when failed because of update analysis-meta + log.Info("start update metas", zap.Stringer("table", oldTable.Info.Name), zap.Stringer("db", oldTable.DB.Name)) + // the total kvs contains the index kvs, but the stats meta needs the count of rows + count := int64(oldTable.TotalKvs / uint64(len(oldTable.Info.Indices)+1)) + if statsErr = statsHandler.SaveMetaToStorage(tbl.Table.ID, count, 0, "br restore"); statsErr != nil { + log.Error("update stats meta failed", zap.Any("table", tbl.Table), zap.Error(statsErr)) + } + } + return nil + }, func() { + log.Info("all stats updated") + }) + return outCh +} 
+ +func (rc *SnapClient) GoWaitTiFlashReady( + ctx context.Context, + inCh <-chan *CreatedTable, + updateCh glue.Progress, + errCh chan<- error, +) chan *CreatedTable { + log.Info("Start to wait tiflash replica sync") + outCh := defaultOutputTableChan() + workers := tidbutil.NewWorkerPool(4, "WaitForTiflashReady") + // TODO support tiflash store changes + tikvStats, err := infosync.GetTiFlashStoresStat(context.Background()) + if err != nil { + errCh <- err + } + tiFlashStores := make(map[int64]pdhttp.StoreInfo) + for _, store := range tikvStats.Stores { + if engine.IsTiFlashHTTPResp(&store.Store) { + tiFlashStores[store.Store.ID] = store + } + } + go concurrentHandleTablesCh(ctx, inCh, outCh, errCh, workers, func(c context.Context, tbl *CreatedTable) error { + if tbl.Table != nil && tbl.Table.TiFlashReplica == nil { + log.Info("table has no tiflash replica", + zap.Stringer("table", tbl.OldTable.Info.Name), + zap.Stringer("db", tbl.OldTable.DB.Name)) + updateCh.Inc() + return nil + } + if rc.dom == nil { + // unreachable, current we have initial domain in mgr. 
+ log.Fatal("unreachable, domain is nil") + } + log.Info("table has tiflash replica, start sync..", + zap.Stringer("table", tbl.OldTable.Info.Name), + zap.Stringer("db", tbl.OldTable.DB.Name)) + for { + var progress float64 + if pi := tbl.Table.GetPartitionInfo(); pi != nil && len(pi.Definitions) > 0 { + for _, p := range pi.Definitions { + progressOfPartition, err := infosync.MustGetTiFlashProgress(p.ID, tbl.Table.TiFlashReplica.Count, &tiFlashStores) + if err != nil { + log.Warn("failed to get progress for tiflash partition replica, retry it", + zap.Int64("tableID", tbl.Table.ID), zap.Int64("partitionID", p.ID), zap.Error(err)) + time.Sleep(time.Second) + continue + } + progress += progressOfPartition + } + progress = progress / float64(len(pi.Definitions)) + } else { + var err error + progress, err = infosync.MustGetTiFlashProgress(tbl.Table.ID, tbl.Table.TiFlashReplica.Count, &tiFlashStores) + if err != nil { + log.Warn("failed to get progress for tiflash replica, retry it", + zap.Int64("tableID", tbl.Table.ID), zap.Error(err)) + time.Sleep(time.Second) + continue + } + } + // check until progress is 1 + if progress == 1 { + log.Info("tiflash replica synced", + zap.Stringer("table", tbl.OldTable.Info.Name), + zap.Stringer("db", tbl.OldTable.DB.Name)) + break + } + // just wait for next check + // tiflash check the progress every 2s + // we can wait 2.5x times + time.Sleep(5 * time.Second) + } + updateCh.Inc() + return nil + }, func() { + log.Info("all tiflash replica synced") + }) + return outCh +} diff --git a/br/pkg/task/backup.go b/br/pkg/task/backup.go index 915bdb2092bd9..642a51017d6f7 100644 --- a/br/pkg/task/backup.go +++ b/br/pkg/task/backup.go @@ -41,7 +41,6 @@ import ( "github.com/spf13/pflag" "github.com/tikv/client-go/v2/oracle" kvutil "github.com/tikv/client-go/v2/util" - "go.uber.org/multierr" "go.uber.org/zap" ) @@ -159,7 +158,7 @@ func DefineBackupFlags(flags *pflag.FlagSet) { } // ParseFromFlags parses the backup-related flags from the flag 
set. -func (cfg *BackupConfig) ParseFromFlags(flags *pflag.FlagSet) error { +func (cfg *BackupConfig) ParseFromFlags(flags *pflag.FlagSet, skipCommonConfig bool) error { timeAgo, err := flags.GetDuration(flagBackupTimeago) if err != nil { return errors.Trace(err) @@ -212,9 +211,13 @@ func (cfg *BackupConfig) ParseFromFlags(flags *pflag.FlagSet) error { } cfg.CompressionConfig = *compressionCfg - if err = cfg.Config.ParseFromFlags(flags); err != nil { - return errors.Trace(err) + // parse common flags if needed + if !skipCommonConfig { + if err = cfg.Config.ParseFromFlags(flags); err != nil { + return errors.Trace(err) + } } + cfg.RemoveSchedulers, err = flags.GetBool(flagRemoveSchedulers) if err != nil { return errors.Trace(err) @@ -789,18 +792,15 @@ func ParseTSString(ts string, tzCheck bool) (uint64, error) { return oracle.GoTimeToTS(t1), nil } -func DefaultBackupConfig() BackupConfig { +func DefaultBackupConfig(commonConfig Config) BackupConfig { fs := pflag.NewFlagSet("dummy", pflag.ContinueOnError) - DefineCommonFlags(fs) DefineBackupFlags(fs) cfg := BackupConfig{} - err := multierr.Combine( - cfg.ParseFromFlags(fs), - cfg.Config.ParseFromFlags(fs), - ) + err := cfg.ParseFromFlags(fs, true) if err != nil { - log.Panic("infallible operation failed.", zap.Error(err)) + log.Panic("failed to parse backup flags to config", zap.Error(err)) } + cfg.Config = commonConfig return cfg } diff --git a/br/pkg/task/common.go b/br/pkg/task/common.go index 343f1c0e84b16..56b1d53ba0cdd 100644 --- a/br/pkg/task/common.go +++ b/br/pkg/task/common.go @@ -64,7 +64,7 @@ const ( flagRateLimit = "ratelimit" flagRateLimitUnit = "ratelimit-unit" flagConcurrency = "concurrency" - flagChecksum = "checksum" + FlagChecksum = "checksum" flagFilter = "filter" flagCaseSensitive = "case-sensitive" flagRemoveTiFlash = "remove-tiflash" @@ -273,7 +273,7 @@ func DefineCommonFlags(flags *pflag.FlagSet) { flags.Uint(flagChecksumConcurrency, variable.DefChecksumTableConcurrency, "The concurrency of 
checksumming in one table") flags.Uint64(flagRateLimit, unlimited, "The rate limit of the task, MB/s per node") - flags.Bool(flagChecksum, true, "Run checksum at end of task") + flags.Bool(FlagChecksum, true, "Run checksum at end of task") flags.Bool(flagRemoveTiFlash, true, "Remove TiFlash replicas before backup or restore, for unsupported versions of TiFlash") @@ -318,7 +318,7 @@ func DefineCommonFlags(flags *pflag.FlagSet) { // HiddenFlagsForStream temporary hidden flags that stream cmd not support. func HiddenFlagsForStream(flags *pflag.FlagSet) { - _ = flags.MarkHidden(flagChecksum) + _ = flags.MarkHidden(FlagChecksum) _ = flags.MarkHidden(flagLoadStats) _ = flags.MarkHidden(flagChecksumConcurrency) _ = flags.MarkHidden(flagRateLimit) @@ -506,7 +506,7 @@ func (cfg *Config) ParseFromFlags(flags *pflag.FlagSet) error { return errors.Trace(err) } - if cfg.Checksum, err = flags.GetBool(flagChecksum); err != nil { + if cfg.Checksum, err = flags.GetBool(FlagChecksum); err != nil { return errors.Trace(err) } if cfg.ChecksumConcurrency, err = flags.GetUint(flagChecksumConcurrency); err != nil { @@ -619,6 +619,59 @@ func (cfg *Config) ParseFromFlags(flags *pflag.FlagSet) error { return cfg.normalizePDURLs() } +<<<<<<< HEAD +======= +func (cfg *Config) parseAndValidateMasterKeyInfo(hasPlaintextKey bool, flags *pflag.FlagSet) error { + masterKeyString, err := flags.GetString(flagMasterKeyConfig) + if err != nil { + return errors.Errorf("master key flag '%s' is not defined: %v", flagMasterKeyConfig, err) + } + + if masterKeyString == "" { + return nil + } + + if hasPlaintextKey { + return errors.Errorf("invalid argument: both plaintext data key encryption and master key based encryption are set at the same time") + } + + encryptionMethodString, err := flags.GetString(flagMasterKeyCipherType) + if err != nil { + return errors.Errorf("encryption method flag '%s' is not defined: %v", flagMasterKeyCipherType, err) + } + + encryptionMethod, err := 
parseCipherType(encryptionMethodString) + if err != nil { + return errors.Errorf("failed to parse encryption method: %v", err) + } + + if !utils.IsEffectiveEncryptionMethod(encryptionMethod) { + return errors.Errorf("invalid encryption method: %s", encryptionMethodString) + } + + masterKeyStrings := strings.Split(masterKeyString, masterKeysDelimiter) + cfg.MasterKeyConfig = backuppb.MasterKeyConfig{ + EncryptionType: encryptionMethod, + MasterKeys: make([]*encryptionpb.MasterKey, 0, len(masterKeyStrings)), + } + + for _, keyString := range masterKeyStrings { + masterKey, err := validateAndParseMasterKeyString(strings.TrimSpace(keyString)) + if err != nil { + return errors.Wrapf(err, "invalid master key configuration: %s", keyString) + } + cfg.MasterKeyConfig.MasterKeys = append(cfg.MasterKeyConfig.MasterKeys, &masterKey) + } + + return nil +} + +// OverrideDefaultForBackup override common config for backup tasks +func (cfg *Config) OverrideDefaultForBackup() { + cfg.Checksum = false +} + +>>>>>>> 4f047be191b (br: restore checksum shouldn't rely on backup checksum (#56712)) // NewMgr creates a new mgr at the given PD address. 
func NewMgr(ctx context.Context, g glue.Glue, pds []string, diff --git a/br/pkg/task/common_test.go b/br/pkg/task/common_test.go index c942b96bc531e..83d0bc7ae43e4 100644 --- a/br/pkg/task/common_test.go +++ b/br/pkg/task/common_test.go @@ -190,8 +190,10 @@ func expectedDefaultConfig() Config { } func expectedDefaultBackupConfig() BackupConfig { + defaultConfig := expectedDefaultConfig() + defaultConfig.Checksum = false return BackupConfig{ - Config: expectedDefaultConfig(), + Config: defaultConfig, GCTTL: utils.DefaultBRGCSafePointTTL, CompressionConfig: CompressionConfig{ CompressionType: backup.CompressionType_ZSTD, @@ -231,13 +233,16 @@ func TestDefault(t *testing.T) { } func TestDefaultBackup(t *testing.T) { - def := DefaultBackupConfig() + commonConfig := DefaultConfig() + commonConfig.OverrideDefaultForBackup() + def := DefaultBackupConfig(commonConfig) defaultConfig := expectedDefaultBackupConfig() require.Equal(t, defaultConfig, def) } func TestDefaultRestore(t *testing.T) { - def := DefaultRestoreConfig() + commonConfig := DefaultConfig() + def := DefaultRestoreConfig(commonConfig) defaultConfig := expectedDefaultRestoreConfig() require.Equal(t, defaultConfig, def) } diff --git a/br/pkg/task/restore.go b/br/pkg/task/restore.go index 6fb33260f9c4c..c266fca260abc 100644 --- a/br/pkg/task/restore.go +++ b/br/pkg/task/restore.go @@ -327,7 +327,7 @@ func (cfg *RestoreConfig) ParseStreamRestoreFlags(flags *pflag.FlagSet) error { } // ParseFromFlags parses the restore-related flags from the flag set. 
-func (cfg *RestoreConfig) ParseFromFlags(flags *pflag.FlagSet) error { +func (cfg *RestoreConfig) ParseFromFlags(flags *pflag.FlagSet, skipCommonConfig bool) error { var err error cfg.NoSchema, err = flags.GetBool(flagNoSchema) if err != nil { @@ -337,10 +337,15 @@ func (cfg *RestoreConfig) ParseFromFlags(flags *pflag.FlagSet) error { if err != nil { return errors.Trace(err) } - err = cfg.Config.ParseFromFlags(flags) - if err != nil { - return errors.Trace(err) + + // parse common config if needed + if !skipCommonConfig { + err = cfg.Config.ParseFromFlags(flags) + if err != nil { + return errors.Trace(err) + } } + err = cfg.RestoreCommonConfig.ParseFromFlags(flags) if err != nil { return errors.Trace(err) @@ -604,6 +609,7 @@ func registerTaskToPD(ctx context.Context, etcdCLI *clientv3.Client) (closeF fun return register.Close, errors.Trace(err) } +<<<<<<< HEAD func removeCheckpointDataForSnapshotRestore(ctx context.Context, storageName string, taskName string, config *Config) error { _, s, err := GetStorage(ctx, storageName, config) if err != nil { @@ -621,19 +627,18 @@ func removeCheckpointDataForLogRestore(ctx context.Context, storageName string, } func DefaultRestoreConfig() RestoreConfig { +======= +func DefaultRestoreConfig(commonConfig Config) RestoreConfig { +>>>>>>> 4f047be191b (br: restore checksum shouldn't rely on backup checksum (#56712)) fs := pflag.NewFlagSet("dummy", pflag.ContinueOnError) - DefineCommonFlags(fs) DefineRestoreFlags(fs) cfg := RestoreConfig{} - err := multierr.Combine( - cfg.ParseFromFlags(fs), - cfg.RestoreCommonConfig.ParseFromFlags(fs), - cfg.Config.ParseFromFlags(fs), - ) + err := cfg.ParseFromFlags(fs, true) if err != nil { - log.Panic("infallible failed.", zap.Error(err)) + log.Panic("failed to parse restore flags to config", zap.Error(err)) } + cfg.Config = commonConfig return cfg } @@ -770,7 +775,7 @@ func runRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConf } reader := 
metautil.NewMetaReader(backupMeta, s, &cfg.CipherInfo) - if err = client.InitBackupMeta(c, backupMeta, u, reader, cfg.LoadStats); err != nil { + if err = client.LoadSchemaIfNeededAndInitClient(c, backupMeta, u, reader, cfg.LoadStats); err != nil { return errors.Trace(err) } @@ -785,7 +790,17 @@ func runRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConf return errors.Annotate(berrors.ErrRestoreInvalidBackup, "contain tables but no databases") } +<<<<<<< HEAD archiveSize := reader.ArchiveSize(ctx, files) +======= + if cfg.CheckRequirements { + if err := checkDiskSpace(ctx, mgr, files, tables); err != nil { + return errors.Trace(err) + } + } + + archiveSize := metautil.ArchiveSize(files) +>>>>>>> 4f047be191b (br: restore checksum shouldn't rely on backup checksum (#56712)) g.Record(summary.RestoreDataSize, archiveSize) //restore from tidb will fetch a general Size issue https://github.com/pingcap/tidb/issues/27247 g.Record("Size", archiveSize) @@ -1077,8 +1092,9 @@ func runRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConf var finish <-chan struct{} postHandleCh := afterTableRestoredCh - // pipeline checksum - if cfg.Checksum { + // pipeline checksum only when enabled and is not incremental snapshot repair mode cuz incremental doesn't have + // enough information in backup meta to validate checksum + if cfg.Checksum && !client.IsIncremental() { postHandleCh = client.GoValidateChecksum( ctx, postHandleCh, mgr.GetStorage().GetClient(), errCh, updateCh, cfg.ChecksumConcurrency) } @@ -1093,7 +1109,7 @@ func runRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConf finish = dropToBlackhole(ctx, postHandleCh, errCh) - // Reset speed limit. ResetSpeedLimit must be called after client.InitBackupMeta has been called. + // Reset speed limit. ResetSpeedLimit must be called after client.LoadSchemaIfNeededAndInitClient has been called. 
defer func() { var resetErr error // In future we may need a mechanism to set speed limit in ttl. like what we do in switchmode. TODO diff --git a/br/pkg/task/restore_raw.go b/br/pkg/task/restore_raw.go index 5b9b009853a02..c103367be2ffa 100644 --- a/br/pkg/task/restore_raw.go +++ b/br/pkg/task/restore_raw.go @@ -116,7 +116,7 @@ func RunRestoreRaw(c context.Context, g glue.Glue, cmdName string, cfg *RestoreR return errors.Trace(err) } reader := metautil.NewMetaReader(backupMeta, s, &cfg.CipherInfo) - if err = client.InitBackupMeta(c, backupMeta, u, reader, true); err != nil { + if err = client.LoadSchemaIfNeededAndInitClient(c, backupMeta, u, reader, true); err != nil { return errors.Trace(err) } @@ -128,7 +128,7 @@ func RunRestoreRaw(c context.Context, g glue.Glue, cmdName string, cfg *RestoreR if err != nil { return errors.Trace(err) } - archiveSize := reader.ArchiveSize(ctx, files) + archiveSize := metautil.ArchiveSize(files) g.Record(summary.RestoreDataSize, archiveSize) if len(files) == 0 { diff --git a/br/pkg/task/restore_txn.go b/br/pkg/task/restore_txn.go index 596b1d29d714e..a2038c19f33fa 100644 --- a/br/pkg/task/restore_txn.go +++ b/br/pkg/task/restore_txn.go @@ -60,7 +60,7 @@ func RunRestoreTxn(c context.Context, g glue.Glue, cmdName string, cfg *Config) return errors.Trace(err) } reader := metautil.NewMetaReader(backupMeta, s, &cfg.CipherInfo) - if err = client.InitBackupMeta(c, backupMeta, u, reader, true); err != nil { + if err = client.LoadSchemaIfNeededAndInitClient(c, backupMeta, u, reader, true); err != nil { return errors.Trace(err) } @@ -69,7 +69,7 @@ func RunRestoreTxn(c context.Context, g glue.Glue, cmdName string, cfg *Config) } files := backupMeta.Files - archiveSize := reader.ArchiveSize(ctx, files) + archiveSize := metautil.ArchiveSize(files) g.Record(summary.RestoreDataSize, archiveSize) if len(files) == 0 { diff --git a/br/tests/br_file_corruption/run.sh b/br/tests/br_file_corruption/run.sh new file mode 100644 index 
0000000000000..60907ac2e7a4c --- /dev/null +++ b/br/tests/br_file_corruption/run.sh @@ -0,0 +1,83 @@ +#!/bin/sh +# +# Copyright 2024 PingCAP, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -eux + +DB="$TEST_NAME" +TABLE="usertable" +CUR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) + +run_sql "CREATE DATABASE $DB;" +go-ycsb load mysql -P $CUR/workload -p mysql.host=$TIDB_IP -p mysql.port=$TIDB_PORT -p mysql.user=root -p mysql.db=$DB +run_br --pd $PD_ADDR backup full -s "local://$TEST_DIR/$DB" --checksum=false + +# Replace the single file manipulation with a loop over all .sst files +for filename in $(find $TEST_DIR/$DB -name "*.sst"); do + filename_temp="${filename}_temp" + filename_bak="${filename}_bak" + echo "corruption" > "$filename_temp" + cat "$filename" >> "$filename_temp" + mv "$filename" "$filename_bak" +done + +# need to drop db otherwise restore will fail because of cluster not fresh but not the expected issue +run_sql "DROP DATABASE IF EXISTS $DB;" + +# file lost +export GO_FAILPOINTS="github.com/pingcap/tidb/br/pkg/utils/set-import-attempt-to-one=return(true)" +restore_fail=0 +run_br --pd $PD_ADDR restore full -s "local://$TEST_DIR/$DB" || restore_fail=1 +export GO_FAILPOINTS="" +if [ $restore_fail -ne 1 ]; then + echo 'expect restore to fail on file lost but succeed' + exit 1 +fi +run_sql "DROP DATABASE IF EXISTS $DB;" + +# file corruption +for filename in $(find $TEST_DIR/$DB -name "*.sst_temp"); do + mv "$filename" "${filename%_temp}" + 
truncate -s -11 "${filename%_temp}" +done + +export GO_FAILPOINTS="github.com/pingcap/tidb/br/pkg/utils/set-import-attempt-to-one=return(true)" +restore_fail=0 +run_br --pd $PD_ADDR restore full -s "local://$TEST_DIR/$DB" || restore_fail=1 +export GO_FAILPOINTS="" +if [ $restore_fail -ne 1 ]; then + echo 'expect restore to fail on file corruption but succeed' + exit 1 +fi +run_sql "DROP DATABASE IF EXISTS $DB;" + +# verify validating checksum is still performed even backup didn't enable it +for filename in $(find $TEST_DIR/$DB -name "*.sst_bak"); do + mv "$filename" "${filename%_bak}" +done + +export GO_FAILPOINTS="github.com/pingcap/tidb/br/pkg/restore/snap_client/full-restore-validate-checksum=return(true)" +restore_fail=0 +run_br --pd $PD_ADDR restore full -s "local://$TEST_DIR/$DB" --checksum=true || restore_fail=1 +export GO_FAILPOINTS="" +if [ $restore_fail -ne 1 ]; then + echo 'expect restore to fail on checksum mismatch but succeed' + exit 1 +fi +run_sql "DROP DATABASE IF EXISTS $DB;" + +# sanity check restore can succeed +run_br --pd $PD_ADDR restore full -s "local://$TEST_DIR/$DB" --checksum=true +echo 'file corruption tests passed' diff --git a/br/tests/br_full_ddl/run.sh b/br/tests/br_full_ddl/run.sh index 370d77dca66dd..41f96e91def8a 100755 --- a/br/tests/br_full_ddl/run.sh +++ b/br/tests/br_full_ddl/run.sh @@ -107,7 +107,7 @@ echo "backup start with stats..." 
unset BR_LOG_TO_TERM cluster_index_before_backup=$(run_sql "show variables like '%cluster%';" | awk '{print $2}') -run_br --pd $PD_ADDR backup full -s "local://$TEST_DIR/$DB" --log-file $LOG --ignore-stats=false || cat $LOG +run_br --pd $PD_ADDR backup full -s "local://$TEST_DIR/$DB" --log-file $LOG --ignore-stats=false --checksum=true || cat $LOG checksum_count=$(cat $LOG | grep "checksum success" | wc -l | xargs) if [ "${checksum_count}" -lt "1" ];then diff --git a/br/tests/br_full_index/run.sh b/br/tests/br_full_index/run.sh index edcac1bfa2377..28f959c10b5f4 100755 --- a/br/tests/br_full_index/run.sh +++ b/br/tests/br_full_index/run.sh @@ -41,7 +41,7 @@ echo "backup start..." # Do not log to terminal unset BR_LOG_TO_TERM # do not backup stats to test whether we can restore without stats. -run_br --pd $PD_ADDR backup full -s "local://$TEST_DIR/$DB" --ignore-stats=true --log-file $LOG || cat $LOG +run_br --pd $PD_ADDR backup full -s "local://$TEST_DIR/$DB" --ignore-stats=true --log-file $LOG --checksum=true || cat $LOG BR_LOG_TO_TERM=1 checksum_count=$(cat $LOG | grep "checksum success" | wc -l | xargs) diff --git a/pkg/executor/brie.go b/pkg/executor/brie.go index 1e5316881819e..e0f35f433a2cc 100644 --- a/pkg/executor/brie.go +++ b/pkg/executor/brie.go @@ -282,7 +282,15 @@ func (b *executorBuilder) buildBRIE(s *ast.BRIEStmt, schema *expression.Schema) Key: tidbCfg.Security.ClusterSSLKey, } pds := strings.Split(tidbCfg.Path, ",") + + // build common config and override for specific task if needed cfg := task.DefaultConfig() + switch s.Kind { + case ast.BRIEKindBackup: + cfg.OverrideDefaultForBackup() + default: + } + cfg.PD = pds cfg.TLS = tlsCfg @@ -357,8 +365,7 @@ func (b *executorBuilder) buildBRIE(s *ast.BRIEStmt, schema *expression.Schema) switch s.Kind { case ast.BRIEKindBackup: - bcfg := task.DefaultBackupConfig() - bcfg.Config = cfg + bcfg := task.DefaultBackupConfig(cfg) e.backupCfg = &bcfg for _, opt := range s.Options { @@ -387,8 +394,7 @@ func (b 
*executorBuilder) buildBRIE(s *ast.BRIEStmt, schema *expression.Schema) } case ast.BRIEKindRestore: - rcfg := task.DefaultRestoreConfig() - rcfg.Config = cfg + rcfg := task.DefaultRestoreConfig(cfg) e.restoreCfg = &rcfg for _, opt := range s.Options { if opt.Tp == ast.BRIEOptionOnline { diff --git a/pkg/executor/brie_test.go b/pkg/executor/brie_test.go index 266144bcf2c10..2b5c5bca3f990 100644 --- a/pkg/executor/brie_test.go +++ b/pkg/executor/brie_test.go @@ -140,3 +140,90 @@ func TestFetchShowBRIE(t *testing.T) { globalBRIEQueue.clearTask(e.Ctx().GetSessionVars().StmtCtx) require.Equal(t, info2Res, fetchShowBRIEResult(t, e, brieColTypes)) } +<<<<<<< HEAD +======= + +func TestBRIEBuilderOptions(t *testing.T) { + sctx := mock.NewContext() + sctx.GetSessionVars().User = &auth.UserIdentity{Username: "test"} + is := infoschema.MockInfoSchema([]*model.TableInfo{core.MockSignedTable(), core.MockUnsignedTable()}) + ResetGlobalBRIEQueueForTest() + builder := NewMockExecutorBuilderForTest(sctx, is) + ctx := context.Background() + p := parser.New() + p.SetParserConfig(parser.ParserConfig{EnableWindowFunction: true, EnableStrictDoubleTypeCheck: true}) + err := failpoint.Enable("github.com/pingcap/tidb/pkg/executor/modifyStore", `return("tikv")`) + require.NoError(t, err) + defer failpoint.Disable("github.com/pingcap/tidb/pkg/executor/modifyStore") + err = os.WriteFile("/tmp/keyfile", []byte(strings.Repeat("A", 128)), 0644) + + require.NoError(t, err) + stmt, err := p.ParseOneStmt("BACKUP TABLE `a` TO 'noop://' CHECKSUM_CONCURRENCY = 4 IGNORE_STATS = 1 COMPRESSION_LEVEL = 4 COMPRESSION_TYPE = 'lz4' ENCRYPTION_METHOD = 'aes256-ctr' ENCRYPTION_KEYFILE = '/tmp/keyfile'", "", "") + require.NoError(t, err) + nodeW := resolve.NewNodeW(stmt) + plan, err := core.BuildLogicalPlanForTest(ctx, sctx, nodeW, infoschema.MockInfoSchema([]*model.TableInfo{core.MockSignedTable(), core.MockUnsignedTable(), core.MockView()})) + require.NoError(t, err) + s, ok := stmt.(*ast.BRIEStmt) + 
require.True(t, ok) + require.True(t, s.Kind == ast.BRIEKindBackup) + for _, opt := range s.Options { + switch opt.Tp { + case ast.BRIEOptionChecksumConcurrency: + require.Equal(t, uint64(4), opt.UintValue) + case ast.BRIEOptionCompressionLevel: + require.Equal(t, uint64(4), opt.UintValue) + case ast.BRIEOptionIgnoreStats: + require.Equal(t, uint64(1), opt.UintValue) + case ast.BRIEOptionCompression: + require.Equal(t, "lz4", opt.StrValue) + case ast.BRIEOptionEncryptionMethod: + require.Equal(t, "aes256-ctr", opt.StrValue) + case ast.BRIEOptionEncryptionKeyFile: + require.Equal(t, "/tmp/keyfile", opt.StrValue) + } + } + schema := plan.Schema() + exec := builder.buildBRIE(s, schema) + require.NoError(t, builder.err) + e, ok := exec.(*BRIEExec) + require.True(t, ok) + require.False(t, e.backupCfg.Checksum) + require.Equal(t, uint(4), e.backupCfg.ChecksumConcurrency) + require.Equal(t, int32(4), e.backupCfg.CompressionLevel) + require.Equal(t, true, e.backupCfg.IgnoreStats) + require.Equal(t, backuppb.CompressionType_LZ4, e.backupCfg.CompressionConfig.CompressionType) + require.Equal(t, encryptionpb.EncryptionMethod_AES256_CTR, e.backupCfg.CipherInfo.CipherType) + require.Greater(t, len(e.backupCfg.CipherInfo.CipherKey), 0) + + stmt, err = p.ParseOneStmt("RESTORE TABLE `a` FROM 'noop://' CHECKSUM_CONCURRENCY = 4 WAIT_TIFLASH_READY = 1 WITH_SYS_TABLE = 1 LOAD_STATS = 1", "", "") + require.NoError(t, err) + nodeW = resolve.NewNodeW(stmt) + plan, err = core.BuildLogicalPlanForTest(ctx, sctx, nodeW, infoschema.MockInfoSchema([]*model.TableInfo{core.MockSignedTable(), core.MockUnsignedTable(), core.MockView()})) + require.NoError(t, err) + s, ok = stmt.(*ast.BRIEStmt) + require.True(t, ok) + require.True(t, s.Kind == ast.BRIEKindRestore) + for _, opt := range s.Options { + switch opt.Tp { + case ast.BRIEOptionChecksumConcurrency: + require.Equal(t, uint64(4), opt.UintValue) + case ast.BRIEOptionWaitTiflashReady: + require.Equal(t, uint64(1), opt.UintValue) + case 
ast.BRIEOptionWithSysTable: + require.Equal(t, uint64(1), opt.UintValue) + case ast.BRIEOptionLoadStats: + require.Equal(t, uint64(1), opt.UintValue) + } + } + schema = plan.Schema() + exec = builder.buildBRIE(s, schema) + require.NoError(t, builder.err) + e, ok = exec.(*BRIEExec) + require.True(t, ok) + require.Equal(t, uint(4), e.restoreCfg.ChecksumConcurrency) + require.True(t, e.restoreCfg.Checksum) + require.True(t, e.restoreCfg.WaitTiflashReady) + require.True(t, e.restoreCfg.WithSysTable) + require.True(t, e.restoreCfg.LoadStats) +} +>>>>>>> 4f047be191b (br: restore checksum shouldn't rely on backup checksum (#56712))