diff --git a/br/pkg/backup/prepare_snap/BUILD.bazel b/br/pkg/backup/prepare_snap/BUILD.bazel
index f99a723e0a1ad..ec2f8bbaf1f56 100644
--- a/br/pkg/backup/prepare_snap/BUILD.bazel
+++ b/br/pkg/backup/prepare_snap/BUILD.bazel
@@ -35,7 +35,7 @@ go_test(
     timeout = "short",
     srcs = ["prepare_test.go"],
     flaky = True,
-    shard_count = 9,
+    shard_count = 10,
     deps = [
         ":prepare_snap",
         "//br/pkg/utils",
diff --git a/br/pkg/backup/prepare_snap/prepare.go b/br/pkg/backup/prepare_snap/prepare.go
index dfb44300e5bf1..405fc3a1502c2 100644
--- a/br/pkg/backup/prepare_snap/prepare.go
+++ b/br/pkg/backup/prepare_snap/prepare.go
@@ -192,19 +192,36 @@ func (p *Preparer) Finalize(ctx context.Context) error {
             return nil
         })
     }
-    if err := eg.Wait(); err != nil {
-        logutil.CL(ctx).Warn("failed to finalize some prepare streams.", logutil.ShortError(err))
-        return err
-    }
-    logutil.CL(ctx).Info("all connections to store have shuted down.")
+    errCh := make(chan error, 1)
+    go func() {
+        if err := eg.Wait(); err != nil {
+            logutil.CL(ctx).Warn("failed to finalize some prepare streams.", logutil.ShortError(err))
+            errCh <- err
+            return
+        }
+        logutil.CL(ctx).Info("all connections to store have shut down.")
+        errCh <- nil
+    }()
     for {
         select {
-        case event := <-p.eventChan:
+        case event, ok := <-p.eventChan:
+            if !ok {
+                return nil
+            }
             if err := p.onEvent(ctx, event); err != nil {
                 return err
             }
-        default:
-            return nil
+        case err, ok := <-errCh:
+            if !ok {
+                panic("unreachable.")
+            }
+            if err != nil {
+                return err
+            }
+            // All streams are finalized; they shouldn't send more events to the event channel.
+            close(p.eventChan)
+        case <-ctx.Done():
+            return ctx.Err()
        }
    }
 }
@@ -407,6 +424,10 @@ func (p *Preparer) streamOf(ctx context.Context, storeID uint64) (*prepareStream
 }
 
 func (p *Preparer) createAndCacheStream(ctx context.Context, cli PrepareClient, storeID uint64) error {
+    if _, ok := p.clients[storeID]; ok {
+        return nil
+    }
+
     s := new(prepareStream)
     s.storeID = storeID
     s.output = p.eventChan
diff --git a/br/pkg/backup/prepare_snap/prepare_test.go b/br/pkg/backup/prepare_snap/prepare_test.go
index 5ce51e8448315..d6a5a7c16ae31 100644
--- a/br/pkg/backup/prepare_snap/prepare_test.go
+++ b/br/pkg/backup/prepare_snap/prepare_test.go
@@ -47,7 +47,12 @@ type mockStore struct {
     successRegions []metapb.Region
     onWaitApply    func(*metapb.Region) error
-    now            func() time.Time
+
+    waitApplyDelay     func()
+    delayedWaitApplies sync.WaitGroup
+
+    injectConnErr <-chan error
+    now           func() time.Time
 }
 
 func (s *mockStore) Send(req *brpb.PrepareSnapshotBackupRequest) error {
@@ -67,7 +72,16 @@ func (s *mockStore) Send(req *brpb.PrepareSnapshotBackupRequest) error {
             }
         }
     }
-    s.sendResp(resp)
+    if s.waitApplyDelay != nil {
+        s.delayedWaitApplies.Add(1)
+        go func() {
+            defer s.delayedWaitApplies.Done()
+            s.waitApplyDelay()
+            s.sendResp(resp)
+        }()
+    } else {
+        s.sendResp(resp)
+    }
     if resp.Error == nil {
         s.successRegions = append(s.successRegions, *region)
     }
@@ -100,11 +114,21 @@ func (s *mockStore) sendResp(resp brpb.PrepareSnapshotBackupResponse) {
 }
 
 func (s *mockStore) Recv() (*brpb.PrepareSnapshotBackupResponse, error) {
-    out, ok := <-s.output
-    if !ok {
-        return nil, io.EOF
+    for {
+        select {
+        case out, ok := <-s.output:
+            if !ok {
+                return nil, io.EOF
+            }
+            return &out, nil
+        case err, ok := <-s.injectConnErr:
+            if ok {
+                return nil, err
+            } else {
+                s.injectConnErr = nil
+            }
+        }
     }
-    return &out, nil
 }
 
 type mockStores struct {
@@ -167,7 +191,7 @@ func (m *mockStores) ConnectToStore(ctx context.Context, storeID uint64) (Prepar
     s, ok := m.stores[storeID]
     if !ok || s == nil {
         m.stores[storeID] = &mockStore{
-            output:         make(chan brpb.PrepareSnapshotBackupResponse, 16),
+            output:         make(chan brpb.PrepareSnapshotBackupResponse, 20480),
             successRegions: []metapb.Region{},
             onWaitApply: func(r *metapb.Region) error {
                 return nil
@@ -538,3 +562,37 @@ func TestHooks(t *testing.T) {
     req.NoError(adv.Finalize(context.Background()))
     ms.AssertIsNormalMode(t)
 }
+
+func TestManyMessagesWhenFinalizing(t *testing.T) {
+    req := require.New(t)
+    pdc := fakeCluster(t, 3, dummyRegions(10240)...)
+    ms := newTestEnv(pdc)
+    blockCh := make(chan struct{})
+    injectErr := make(chan error)
+    ms.onCreateStore = func(ms *mockStore) {
+        ms.waitApplyDelay = func() {
+            <-blockCh
+        }
+        ms.injectConnErr = injectErr
+    }
+    prep := New(ms)
+    ctx := context.Background()
+    req.NoError(prep.PrepareConnections(ctx))
+    errC := async(func() error { return prep.DriveLoopAndWaitPrepare(ctx) })
+    injectErr <- errors.NewNoStackError("whoa!")
+    req.Error(<-errC)
+    close(blockCh)
+    for _, s := range ms.stores {
+        s.delayedWaitApplies.Wait()
+    }
+    // Closing the streams should return an error.
+    req.Error(prep.Finalize(ctx))
+}
+
+func async[T any](f func() T) <-chan T {
+    ch := make(chan T)
+    go func() {
+        ch <- f()
+    }()
+    return ch
+}
diff --git a/br/pkg/backup/prepare_snap/stream.go b/br/pkg/backup/prepare_snap/stream.go
index 1108731fa5002..f963899b1d826 100644
--- a/br/pkg/backup/prepare_snap/stream.go
+++ b/br/pkg/backup/prepare_snap/stream.go
@@ -74,6 +74,9 @@ func (p *prepareStream) InitConn(ctx context.Context, cli PrepareClient) error {
     return p.GoLeaseLoop(ctx, p.leaseDuration)
 }
 
+// Finalize cuts down this connection and removes the lease.
+// This blocks until all messages have been flushed to the `output` channel.
+// After this returns, no more messages should be appended to the `output` channel.
 func (p *prepareStream) Finalize(ctx context.Context) error {
     log.Info("shutting down", zap.Uint64("store", p.storeID))
     return p.stopClientLoop(ctx)
@@ -151,7 +154,8 @@ func (p *prepareStream) clientLoop(ctx context.Context, dur time.Duration) error
         return nil
     case res := <-p.serverStream:
         if err := p.onResponse(ctx, res); err != nil {
-            p.sendErr(errors.Annotate(err, "failed to recv from the stream"))
+            err = errors.Annotate(err, "failed to recv from the stream")
+            p.sendErr(err)
             return err
         }
     case <-ticker.C:
@@ -186,6 +190,10 @@ func (p *prepareStream) sendErr(err error) {
 }
 
 func (p *prepareStream) convertToEvent(resp *brpb.PrepareSnapshotBackupResponse) (event, bool) {
+    if resp == nil {
+        log.Warn("Received nil message; this shouldn't happen in a normal cluster.", zap.Uint64("store", p.storeID))
+        return event{}, false
+    }
     switch resp.Ty {
     case brpb.PrepareSnapshotBackupEventType_WaitApplyDone:
         return event{