Skip to content

Commit

Permalink
cli: fix usage of gzip.Reader to better detect corrupt snapshots duri…
Browse files Browse the repository at this point in the history
…ng save/restore (#7697)
  • Loading branch information
rboyer authored and dnephin committed May 5, 2020
1 parent 3cf29d4 commit 7a8034b
Show file tree
Hide file tree
Showing 14 changed files with 327 additions and 144 deletions.
4 changes: 2 additions & 2 deletions command/snapshot/inspect/snapshot_inspect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package inspect
import (
"io"
"os"
"path"
"path/filepath"
"strings"
"testing"

Expand Down Expand Up @@ -68,7 +68,7 @@ func TestSnapshotInspectCommand(t *testing.T) {
dir := testutil.TempDir(t, "snapshot")
defer os.RemoveAll(dir)

file := path.Join(dir, "backup.tgz")
file := filepath.Join(dir, "backup.tgz")

// Save a snapshot of the current Consul state
f, err := os.Create(file)
Expand Down
64 changes: 62 additions & 2 deletions command/snapshot/restore/snapshot_restore_test.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
package restore

import (
"crypto/rand"
"fmt"
"io"
"io/ioutil"
"os"
"path"
"path/filepath"
"strings"
"testing"

"github.com/hashicorp/consul/agent"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/mitchellh/cli"
"github.com/stretchr/testify/require"
)

func TestSnapshotRestoreCommand_noTabs(t *testing.T) {
Expand Down Expand Up @@ -71,7 +76,7 @@ func TestSnapshotRestoreCommand(t *testing.T) {
dir := testutil.TempDir(t, "snapshot")
defer os.RemoveAll(dir)

file := path.Join(dir, "backup.tgz")
file := filepath.Join(dir, "backup.tgz")
args := []string{
"-http-addr=" + a.HTTPAddr(),
file,
Expand Down Expand Up @@ -100,3 +105,58 @@ func TestSnapshotRestoreCommand(t *testing.T) {
t.Fatalf("bad: %d. %#v", code, ui.ErrorWriter.String())
}
}

// TestSnapshotRestoreCommand_TruncatedSnapshot takes a real snapshot from a
// test agent, chops a varying number of bytes off its end, and checks that the
// restore command exits with code 1 and reports an error mentioning "EOF" for
// every truncation length.
func TestSnapshotRestoreCommand_TruncatedSnapshot(t *testing.T) {
	t.Parallel()
	a := agent.NewTestAgent(t, ``)
	defer a.Shutdown()
	client := a.Client()

	// Put 64K of random bytes into the KV store so the snapshot has a payload.
	seed := make([]byte, 64*1024)
	_, err := rand.Read(seed)
	require.NoError(t, err)

	_, err = client.KV().Put(&api.KVPair{Key: "blob", Value: seed}, nil)
	require.NoError(t, err)

	// Capture a genuine snapshot to use as the base for the truncated inputs.
	snapshot, _, err := client.Snapshot().Save(nil)
	require.NoError(t, err)
	defer snapshot.Close()

	inputData, err := ioutil.ReadAll(snapshot)
	require.NoError(t, err)

	dir := testutil.TempDir(t, "snapshot")
	defer os.RemoveAll(dir)

	// Every case rewrites the same file; the subtests below run one at a time.
	file := filepath.Join(dir, "backup.tgz")

	for _, cut := range []int{200, 16, 8, 4, 2, 1} {
		name := fmt.Sprintf("truncate %d bytes from end", cut)
		t.Run(name, func(t *testing.T) {
			// Drop the trailing bytes.
			truncated := inputData[:len(inputData)-cut]

			ui := cli.NewMockUi()
			cmd := New(ui)

			require.NoError(t, ioutil.WriteFile(file, truncated, 0644))

			code := cmd.Run([]string{
				"-http-addr=" + a.HTTPAddr(),
				file,
			})
			require.Equal(t, 1, code, "expected non-zero exit")

			errOutput := ui.ErrorWriter.String()
			require.Contains(t, errOutput, "Error restoring snapshot")
			require.Contains(t, errOutput, "EOF")
		})
	}
}
92 changes: 90 additions & 2 deletions command/snapshot/save/snapshot_save_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
package save

import (
"crypto/rand"
"fmt"
"io/ioutil"
"net/http"
"os"
"path"
"path/filepath"
"strings"
"sync/atomic"
"testing"

"github.com/hashicorp/consul/agent"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/mitchellh/cli"
"github.com/stretchr/testify/require"
)

func TestSnapshotSaveCommand_noTabs(t *testing.T) {
Expand All @@ -17,6 +25,7 @@ func TestSnapshotSaveCommand_noTabs(t *testing.T) {
t.Fatal("help has tabs")
}
}

func TestSnapshotSaveCommand_Validation(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -70,7 +79,7 @@ func TestSnapshotSaveCommand(t *testing.T) {
dir := testutil.TempDir(t, "snapshot")
defer os.RemoveAll(dir)

file := path.Join(dir, "backup.tgz")
file := filepath.Join(dir, "backup.tgz")
args := []string{
"-http-addr=" + a.HTTPAddr(),
file,
Expand All @@ -91,3 +100,82 @@ func TestSnapshotSaveCommand(t *testing.T) {
t.Fatalf("err: %v", err)
}
}

// TestSnapshotSaveCommand_TruncatedStream points the save command at a fake
// snapshot endpoint that serves a real snapshot with a varying number of bytes
// removed from the end, and checks that the command exits with code 1 and
// reports a verification error mentioning "EOF" for every truncation length.
func TestSnapshotSaveCommand_TruncatedStream(t *testing.T) {
	t.Parallel()
	a := agent.NewTestAgent(t, ``)
	defer a.Shutdown()
	client := a.Client()

	// Put 64K of random bytes into the KV store so the snapshot has a payload.
	seed := make([]byte, 64*1024)
	_, err := rand.Read(seed)
	require.NoError(t, err)

	_, err = client.KV().Put(&api.KVPair{Key: "blob", Value: seed}, nil)
	require.NoError(t, err)

	// Capture a genuine snapshot to use as the base for the truncated responses.
	snapshot, _, err := client.Snapshot().Save(nil)
	require.NoError(t, err)
	defer snapshot.Close()

	inputData, err := ioutil.ReadAll(snapshot)
	require.NoError(t, err)

	// Holds the []byte the fake endpoint should serve for the current subtest.
	var fakeResult atomic.Value

	// Stand-in for the snapshot API: only GET /v1/snapshot is answered.
	snapshotAPI := func(w http.ResponseWriter, req *http.Request) {
		switch {
		case req.URL.Path != "/v1/snapshot":
			w.WriteHeader(http.StatusNotFound)
		case req.Method != "GET":
			w.WriteHeader(http.StatusMethodNotAllowed)
		default:
			stored := fakeResult.Load()
			if stored == nil {
				// No subtest has stored a body yet.
				w.WriteHeader(http.StatusNotFound)
				return
			}
			_, _ = w.Write(stored.([]byte))
		}
	}
	fakeAddr := lib.StartTestServer(t, http.HandlerFunc(snapshotAPI))

	dir := testutil.TempDir(t, "snapshot")
	defer os.RemoveAll(dir)

	// Every case saves to the same path; the subtests below run one at a time.
	file := filepath.Join(dir, "backup.tgz")

	for _, cut := range []int{200, 16, 8, 4, 2, 1} {
		name := fmt.Sprintf("truncate %d bytes from end", cut)
		t.Run(name, func(t *testing.T) {
			// Drop the trailing bytes and hand the result to the fake server.
			fakeResult.Store(inputData[:len(inputData)-cut])

			ui := cli.NewMockUi()
			cmd := New(ui)

			code := cmd.Run([]string{
				"-http-addr=" + fakeAddr, // point to the fake
				file,
			})
			require.Equal(t, 1, code, "expected non-zero exit")

			errOutput := ui.ErrorWriter.String()
			require.Contains(t, errOutput, "Error verifying snapshot file")
			require.Contains(t, errOutput, "EOF")
		})
	}
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ require (
github.com/miekg/dns v1.1.26
github.com/mitchellh/cli v1.0.0
github.com/mitchellh/copystructure v1.0.0
github.com/mitchellh/go-testing-interface v1.0.0
github.com/mitchellh/go-testing-interface v1.14.0
github.com/mitchellh/hashstructure v0.0.0-20170609045927-2bca23e0e452
github.com/mitchellh/mapstructure v1.2.3
github.com/mitchellh/reflectwalk v1.0.1
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,8 @@ github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrk
github.com/mitchellh/go-testing-interface v0.0.0-20171004221916-a61a99592b77/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI=
github.com/mitchellh/go-testing-interface v1.0.0 h1:fzU/JVNcaqHQEcVFAKeR41fkiLdIPrefOvVG1VZ96U0=
github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI=
github.com/mitchellh/go-testing-interface v1.14.0 h1:/x0XQ6h+3U3nAyk1yx+bHPURrKa9sVVvYbuqZ7pIAtI=
github.com/mitchellh/go-testing-interface v1.14.0/go.mod h1:gfgS7OtZj6MA4U1UrDRp04twqAjfvlZyCfX3sDjEym8=
github.com/mitchellh/go-wordwrap v1.0.0/go.mod h1:ZXFpozHsX6DPmq2I0TCekCxypsnAUbP2oI0UX1GXzOo=
github.com/mitchellh/gox v0.4.0/go.mod h1:Sd9lOJ0+aimLBi73mGofS1ycjY8lL3uZM3JPS42BGNg=
github.com/mitchellh/hashstructure v0.0.0-20170609045927-2bca23e0e452 h1:hOY53G+kBFhbYFpRVxHl5eS7laP6B1+Cq+Z9Dry1iMU=
Expand Down
36 changes: 36 additions & 0 deletions lib/testing_httpserver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package lib

import (
	"net"
	"net/http"

	"github.com/hashicorp/consul/ipaddr"
	"github.com/hashicorp/consul/sdk/freeport"
	"github.com/mitchellh/go-testing-interface"
)

// StartTestServer fires up a web server on a random unused port to serve the
// given handler body. The address it is listening on is returned. When the
// test case terminates the server will be stopped via cleanup functions.
//
// The listening socket is bound before this function returns, so callers may
// connect to the returned address immediately. If the port cannot be bound the
// test is failed via t.Fatalf.
//
// We can't directly use httptest.Server here because that only thinks a port
// is free if it's not bound. Consul tests frequently reserve ports via
// `sdk/freeport` so you can have one part of the test try to use a port and
// _know_ nothing is listening. If you simply assumed unbound ports were free
// you'd end up with test cross-talk and weirdness.
func StartTestServer(t testing.T, handler http.Handler) string {
	ports := freeport.MustTake(1)
	// Registered first so it runs last: the port is handed back only after the
	// server using it has been closed.
	t.Cleanup(func() {
		freeport.Return(ports)
	})

	addr := ipaddr.FormatAddressPort("127.0.0.1", ports[0])

	// Bind synchronously. Letting the goroutine call ListenAndServe would allow
	// this function to return before the socket exists, and a fast client
	// would then see "connection refused".
	ln, err := net.Listen("tcp", addr)
	if err != nil {
		t.Fatalf("failed to listen on %s: %v", addr, err)
	}

	server := &http.Server{Addr: addr, Handler: handler}
	t.Cleanup(func() {
		server.Close()
	})

	go func() {
		// Serve always returns a non-nil error; after the cleanup above it is
		// http.ErrServerClosed. Serve also closes ln when it returns, so there
		// is nothing further to handle here.
		_ = server.Serve(ln)
	}()

	return addr
}
3 changes: 1 addition & 2 deletions snapshot/archive.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func read(in io.Reader, metadata *raft.SnapshotMeta, snap io.Writer) error {
// Previously we used json.Decode to decode the archive stream. There are
// edge cases in which it doesn't read all the bytes from the stream, even
// though the json object is still being parsed properly. Since we
// simutaniously feeded everything to metaHash, our hash ended up being
// simultaneously fed everything to metaHash, our hash ended up being
// different than what we calculated when creating the snapshot. Which in
// turn made the snapshot verification fail. By explicitly reading the
// whole thing first we ensure that we calculate the correct hash
Expand All @@ -223,7 +223,6 @@ func read(in io.Reader, metadata *raft.SnapshotMeta, snap io.Writer) error {
default:
return fmt.Errorf("unexpected file %q in snapshot", hdr.Name)
}

}

// Verify all the hashes.
Expand Down
24 changes: 24 additions & 0 deletions snapshot/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,29 @@ func Verify(in io.Reader) (*raft.SnapshotMeta, error) {
if err := read(decomp, &metadata, ioutil.Discard); err != nil {
return nil, fmt.Errorf("failed to read snapshot file: %v", err)
}

if err := concludeGzipRead(decomp); err != nil {
return nil, err
}

return &metadata, nil
}

// concludeGzipRead should be invoked after you think you've consumed all of
// the data from the gzip stream. It will error if the stream was corrupt.
//
// The docs for gzip.Reader say: "Clients should treat data returned by Read as
// tentative until they receive the io.EOF marking the end of the data."
//
// It returns the reader's own error if draining the stream fails (for example
// io.ErrUnexpectedEOF on a truncated archive), an error reporting the byte
// count if any uncompressed data was still unread, and nil otherwise.
func concludeGzipRead(decomp *gzip.Reader) error {
	// Drain to a discarding writer rather than buffering: only the count of
	// leftover bytes matters, and a corrupt archive could carry an arbitrary
	// amount of trailing data. io.Copy consumes the EOF and reports it as nil.
	extra, err := io.Copy(ioutil.Discard, decomp)
	if err != nil {
		return err
	}
	if extra != 0 {
		return fmt.Errorf("%d unread uncompressed bytes remain", extra)
	}
	return nil
}

// Restore takes the snapshot from the reader and attempts to apply it to the
// given Raft instance.
func Restore(logger hclog.Logger, in io.Reader, r *raft.Raft) error {
Expand Down Expand Up @@ -175,6 +195,10 @@ func Restore(logger hclog.Logger, in io.Reader, r *raft.Raft) error {
return fmt.Errorf("failed to read snapshot file: %v", err)
}

if err := concludeGzipRead(decomp); err != nil {
return err
}

// Sync and rewind the file so it's ready to be read again.
if err := snap.Sync(); err != nil {
return fmt.Errorf("failed to sync temp snapshot: %v", err)
Expand Down
Loading

0 comments on commit 7a8034b

Please sign in to comment.