diff --git a/server/mock/store.go b/server/mock/store.go index bceb73a92c3..460c2a677f4 100644 --- a/server/mock/store.go +++ b/server/mock/store.go @@ -103,7 +103,9 @@ func (ms multiStore) Snapshot(height uint64, format uint32) (<-chan io.ReadClose panic("not implemented") } -func (ms multiStore) Restore(height uint64, format uint32, chunks <-chan io.ReadCloser) error { +func (ms multiStore) Restore( + height uint64, format uint32, chunks <-chan io.ReadCloser, ready chan<- struct{}, +) error { panic("not implemented") } diff --git a/snapshots/helpers_test.go b/snapshots/helpers_test.go index eede8b2d585..bda3c315545 100644 --- a/snapshots/helpers_test.go +++ b/snapshots/helpers_test.go @@ -63,13 +63,18 @@ type mockSnapshotter struct { chunks [][]byte } -func (m *mockSnapshotter) Restore(height uint64, format uint32, chunks <-chan io.ReadCloser) error { +func (m *mockSnapshotter) Restore( + height uint64, format uint32, chunks <-chan io.ReadCloser, ready chan<- struct{}, +) error { if format == 0 { return types.ErrUnknownFormat } if m.chunks != nil { return errors.New("already has contents") } + if ready != nil { + close(ready) + } m.chunks = [][]byte{} for reader := range chunks { @@ -140,6 +145,8 @@ func (m *hungSnapshotter) Snapshot(height uint64, format uint32) (<-chan io.Read return ch, nil } -func (m *hungSnapshotter) Restore(height uint64, format uint32, chunks <-chan io.ReadCloser) error { +func (m *hungSnapshotter) Restore( + height uint64, format uint32, chunks <-chan io.ReadCloser, ready chan<- struct{}, +) error { panic("not implemented") } diff --git a/snapshots/manager.go b/snapshots/manager.go index 3ccf502b3b2..1341bcf958a 100644 --- a/snapshots/manager.go +++ b/snapshots/manager.go @@ -8,7 +8,6 @@ import ( "io" "io/ioutil" "sync" - "time" "github.com/cosmos/cosmos-sdk/snapshots/types" ) @@ -169,9 +168,10 @@ func (m *Manager) Restore(snapshot types.Snapshot) error { // Start an asynchronous snapshot restoration, passing chunks and completion status via channels. chChunks := make(chan io.ReadCloser, chunkBufferSize) + chReady := make(chan struct{}, 1) chDone := make(chan restoreDone, 1) go func() { - err := m.target.Restore(snapshot.Height, snapshot.Format, chChunks) + err := m.target.Restore(snapshot.Height, snapshot.Format, chChunks, chReady) chDone <- restoreDone{ complete: err == nil, err: err, @@ -187,7 +187,7 @@ func (m *Manager) Restore(snapshot types.Snapshot) error { return done.err } return errors.New("restore ended unexpectedly") - case <-time.After(20 * time.Millisecond): + case <-chReady: } m.chRestore = chChunks diff --git a/snapshots/types/snapshotter.go b/snapshots/types/snapshotter.go index d59a2a0874b..1ebd763b5d7 100644 --- a/snapshots/types/snapshotter.go +++ b/snapshots/types/snapshotter.go @@ -10,5 +10,7 @@ type Snapshotter interface { Snapshot(height uint64, format uint32) (<-chan io.ReadCloser, error) // Restore restores a state snapshot, taking snapshot chunk readers as input. - Restore(height uint64, format uint32, chunks <-chan io.ReadCloser) error + // If the ready channel is non-nil, it returns a ready signal (by being closed) once the + // restorer is ready to accept chunks. + Restore(height uint64, format uint32, chunks <-chan io.ReadCloser, ready chan<- struct{}) error } diff --git a/store/rootmulti/store.go b/store/rootmulti/store.go index 9cc4b3f3d63..4e55d531362 100644 --- a/store/rootmulti/store.go +++ b/store/rootmulti/store.go @@ -658,7 +658,9 @@ func (rs *Store) Snapshot(height uint64, format uint32) (<-chan io.ReadCloser, e } // Restore implements snapshottypes.Snapshotter. -func (rs *Store) Restore(height uint64, format uint32, chunks <-chan io.ReadCloser) error { +func (rs *Store) Restore( + height uint64, format uint32, chunks <-chan io.ReadCloser, ready chan<- struct{}, +) error { if format != snapshottypes.CurrentFormat { return fmt.Errorf("%w %v", snapshottypes.ErrUnknownFormat, format) } @@ -670,6 +672,12 @@ func (rs *Store) Restore(height uint64, format uint32, chunks <-chan io.ReadClos height, math.MaxInt64) } + // Signal readiness. Must be done before the readers below are set up, since the zlib + // reader reads from the stream on initialization, potentially causing deadlocks. + if ready != nil { + close(ready) + } + // Set up a restore stream pipeline // chan io.ReadCloser -> chunkReader -> zlib -> delimited Protobuf -> ExportNode chunkReader := snapshots.NewChunkReader(chunks) diff --git a/store/rootmulti/store_test.go b/store/rootmulti/store_test.go index 55af45cc3ae..06509b0caca 100644 --- a/store/rootmulti/store_test.go +++ b/store/rootmulti/store_test.go @@ -593,7 +593,7 @@ func TestMultistoreRestore_Errors(t *testing.T) { for name, tc := range testcases { tc := tc t.Run(name, func(t *testing.T) { - err := store.Restore(tc.height, tc.format, nil) + err := store.Restore(tc.height, tc.format, nil, nil) require.Error(t, err) if tc.expectType != nil { assert.True(t, errors.Is(err, tc.expectType)) @@ -610,8 +610,10 @@ func TestMultistoreSnapshotRestore(t *testing.T) { chunks, err := source.Snapshot(version, snapshottypes.CurrentFormat) require.NoError(t, err) - err = target.Restore(version, snapshottypes.CurrentFormat, chunks) + ready := make(chan struct{}) + err = target.Restore(version, snapshottypes.CurrentFormat, chunks, ready) require.NoError(t, err) + assert.EqualValues(t, struct{}{}, <-ready) assert.Equal(t, source.LastCommitID(), target.LastCommitID()) for key, sourceStore := range source.stores { @@ -687,7 +689,7 @@ func benchmarkMultistoreSnapshotRestore(b *testing.B, stores uint8, storeKeys ui chunks, err := source.Snapshot(version, snapshottypes.CurrentFormat) require.NoError(b, err) - err = target.Restore(version, snapshottypes.CurrentFormat, chunks) + err = target.Restore(version, snapshottypes.CurrentFormat, chunks, nil) require.NoError(b, err) require.Equal(b, source.LastCommitID(), target.LastCommitID()) }