cli: fix usage of gzip.Reader to better detect corrupt snapshots during save/restore (#7697)

This commit is contained in:
R.B. Boyer 2020-04-24 17:18:56 -05:00 committed by GitHub
parent 82b0fbd975
commit 032e0ae901
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 327 additions and 144 deletions

View File

@ -3,7 +3,7 @@ package inspect
import (
"io"
"os"
"path"
"path/filepath"
"strings"
"testing"
@ -68,7 +68,7 @@ func TestSnapshotInspectCommand(t *testing.T) {
dir := testutil.TempDir(t, "snapshot")
defer os.RemoveAll(dir)
file := path.Join(dir, "backup.tgz")
file := filepath.Join(dir, "backup.tgz")
// Save a snapshot of the current Consul state
f, err := os.Create(file)

View File

@ -1,15 +1,20 @@
package restore
import (
"crypto/rand"
"fmt"
"io"
"io/ioutil"
"os"
"path"
"path/filepath"
"strings"
"testing"
"github.com/hashicorp/consul/agent"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/mitchellh/cli"
"github.com/stretchr/testify/require"
)
func TestSnapshotRestoreCommand_noTabs(t *testing.T) {
@ -71,7 +76,7 @@ func TestSnapshotRestoreCommand(t *testing.T) {
dir := testutil.TempDir(t, "snapshot")
defer os.RemoveAll(dir)
file := path.Join(dir, "backup.tgz")
file := filepath.Join(dir, "backup.tgz")
args := []string{
"-http-addr=" + a.HTTPAddr(),
file,
@ -100,3 +105,58 @@ func TestSnapshotRestoreCommand(t *testing.T) {
t.Fatalf("bad: %d. %#v", code, ui.ErrorWriter.String())
}
}
func TestSnapshotRestoreCommand_TruncatedSnapshot(t *testing.T) {
t.Parallel()
a := agent.NewTestAgent(t, ``)
defer a.Shutdown()
client := a.Client()
// Seed it with 64K of random data just so we have something to work with.
{
blob := make([]byte, 64*1024)
_, err := rand.Read(blob)
require.NoError(t, err)
_, err = client.KV().Put(&api.KVPair{Key: "blob", Value: blob}, nil)
require.NoError(t, err)
}
// Do a manual snapshot so we can send back roughly reasonable data.
var inputData []byte
{
rc, _, err := client.Snapshot().Save(nil)
require.NoError(t, err)
defer rc.Close()
inputData, err = ioutil.ReadAll(rc)
require.NoError(t, err)
}
dir := testutil.TempDir(t, "snapshot")
defer os.RemoveAll(dir)
for _, removeBytes := range []int{200, 16, 8, 4, 2, 1} {
t.Run(fmt.Sprintf("truncate %d bytes from end", removeBytes), func(t *testing.T) {
// Lop off part of the end.
data := inputData[0 : len(inputData)-removeBytes]
ui := cli.NewMockUi()
c := New(ui)
file := filepath.Join(dir, "backup.tgz")
require.NoError(t, ioutil.WriteFile(file, data, 0644))
args := []string{
"-http-addr=" + a.HTTPAddr(),
file,
}
code := c.Run(args)
require.Equal(t, 1, code, "expected non-zero exit")
output := ui.ErrorWriter.String()
require.Contains(t, output, "Error restoring snapshot")
require.Contains(t, output, "EOF")
})
}
}

View File

@ -1,14 +1,22 @@
package save
import (
"crypto/rand"
"fmt"
"io/ioutil"
"net/http"
"os"
"path"
"path/filepath"
"strings"
"sync/atomic"
"testing"
"github.com/hashicorp/consul/agent"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/mitchellh/cli"
"github.com/stretchr/testify/require"
)
func TestSnapshotSaveCommand_noTabs(t *testing.T) {
@ -17,6 +25,7 @@ func TestSnapshotSaveCommand_noTabs(t *testing.T) {
t.Fatal("help has tabs")
}
}
func TestSnapshotSaveCommand_Validation(t *testing.T) {
t.Parallel()
@ -70,7 +79,7 @@ func TestSnapshotSaveCommand(t *testing.T) {
dir := testutil.TempDir(t, "snapshot")
defer os.RemoveAll(dir)
file := path.Join(dir, "backup.tgz")
file := filepath.Join(dir, "backup.tgz")
args := []string{
"-http-addr=" + a.HTTPAddr(),
file,
@ -91,3 +100,82 @@ func TestSnapshotSaveCommand(t *testing.T) {
t.Fatalf("err: %v", err)
}
}
func TestSnapshotSaveCommand_TruncatedStream(t *testing.T) {
t.Parallel()
a := agent.NewTestAgent(t, ``)
defer a.Shutdown()
client := a.Client()
// Seed it with 64K of random data just so we have something to work with.
{
blob := make([]byte, 64*1024)
_, err := rand.Read(blob)
require.NoError(t, err)
_, err = client.KV().Put(&api.KVPair{Key: "blob", Value: blob}, nil)
require.NoError(t, err)
}
// Do a manual snapshot so we can send back roughly reasonable data.
var inputData []byte
{
rc, _, err := client.Snapshot().Save(nil)
require.NoError(t, err)
defer rc.Close()
inputData, err = ioutil.ReadAll(rc)
require.NoError(t, err)
}
var fakeResult atomic.Value
// Run a fake webserver to pretend to be the snapshot API.
fakeAddr := lib.StartTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.URL.Path != "/v1/snapshot" {
w.WriteHeader(http.StatusNotFound)
return
}
if req.Method != "GET" {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
raw := fakeResult.Load()
if raw == nil {
w.WriteHeader(http.StatusNotFound)
return
}
data := raw.([]byte)
_, _ = w.Write(data)
}))
dir := testutil.TempDir(t, "snapshot")
defer os.RemoveAll(dir)
for _, removeBytes := range []int{200, 16, 8, 4, 2, 1} {
t.Run(fmt.Sprintf("truncate %d bytes from end", removeBytes), func(t *testing.T) {
// Lop off part of the end.
data := inputData[0 : len(inputData)-removeBytes]
fakeResult.Store(data)
ui := cli.NewMockUi()
c := New(ui)
file := filepath.Join(dir, "backup.tgz")
args := []string{
"-http-addr=" + fakeAddr, // point to the fake
file,
}
code := c.Run(args)
require.Equal(t, 1, code, "expected non-zero exit")
output := ui.ErrorWriter.String()
require.Contains(t, output, "Error verifying snapshot file")
require.Contains(t, output, "EOF")
})
}
}

2
go.mod
View File

@ -63,7 +63,7 @@ require (
github.com/miekg/dns v1.1.26
github.com/mitchellh/cli v1.1.0
github.com/mitchellh/copystructure v1.0.0
github.com/mitchellh/go-testing-interface v1.0.0
github.com/mitchellh/go-testing-interface v1.14.0
github.com/mitchellh/hashstructure v0.0.0-20170609045927-2bca23e0e452
github.com/mitchellh/mapstructure v1.1.2
github.com/mitchellh/reflectwalk v1.0.1

2
go.sum
View File

@ -328,6 +328,8 @@ github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrk
github.com/mitchellh/go-testing-interface v0.0.0-20171004221916-a61a99592b77/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI=
github.com/mitchellh/go-testing-interface v1.0.0 h1:fzU/JVNcaqHQEcVFAKeR41fkiLdIPrefOvVG1VZ96U0=
github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI=
github.com/mitchellh/go-testing-interface v1.14.0 h1:/x0XQ6h+3U3nAyk1yx+bHPURrKa9sVVvYbuqZ7pIAtI=
github.com/mitchellh/go-testing-interface v1.14.0/go.mod h1:gfgS7OtZj6MA4U1UrDRp04twqAjfvlZyCfX3sDjEym8=
github.com/mitchellh/go-wordwrap v1.0.0/go.mod h1:ZXFpozHsX6DPmq2I0TCekCxypsnAUbP2oI0UX1GXzOo=
github.com/mitchellh/hashstructure v0.0.0-20170609045927-2bca23e0e452 h1:hOY53G+kBFhbYFpRVxHl5eS7laP6B1+Cq+Z9Dry1iMU=
github.com/mitchellh/hashstructure v0.0.0-20170609045927-2bca23e0e452/go.mod h1:QjSHrPWS+BGUVBYkbTZWEnOh3G1DutKwClXU/ABz6AQ=

36
lib/testing_httpserver.go Normal file
View File

@ -0,0 +1,36 @@
package lib
import (
"net/http"
"github.com/hashicorp/consul/ipaddr"
"github.com/hashicorp/consul/sdk/freeport"
"github.com/mitchellh/go-testing-interface"
)
// StartTestServer fires up a web server on a random unused port to serve the
// given handler body. The address it is listening on is returned. When the
// test case terminates the server will be stopped via cleanup functions.
//
// We can't directly use httptest.Server here because that only thinks a port
// is free if it's not bound. Consul tests frequently reserve ports via
// `sdk/freeport` so you can have one part of the test try to use a port and
// _know_ nothing is listening. If you simply assumed unbound ports were free
// you'd end up with test cross-talk and weirdness.
func StartTestServer(t testing.T, handler http.Handler) string {
ports := freeport.MustTake(1)
t.Cleanup(func() {
freeport.Return(ports)
})
addr := ipaddr.FormatAddressPort("127.0.0.1", ports[0])
server := &http.Server{Addr: addr, Handler: handler}
t.Cleanup(func() {
server.Close()
})
go server.ListenAndServe()
return addr
}

View File

@ -197,7 +197,7 @@ func read(in io.Reader, metadata *raft.SnapshotMeta, snap io.Writer) error {
// Previously we used json.Decode to decode the archive stream. There are
// edgecases in which it doesn't read all the bytes from the stream, even
// though the json object is still being parsed properly. Since we
// simutaniously feeded everything to metaHash, our hash ended up being
// simultaneously feeded everything to metaHash, our hash ended up being
// different than what we calculated when creating the snapshot. Which in
// turn made the snapshot verification fail. By explicitly reading the
// whole thing first we ensure that we calculate the correct hash
@ -223,7 +223,6 @@ func read(in io.Reader, metadata *raft.SnapshotMeta, snap io.Writer) error {
default:
return fmt.Errorf("unexpected file %q in snapshot", hdr.Name)
}
}
// Verify all the hashes.

View File

@ -137,9 +137,29 @@ func Verify(in io.Reader) (*raft.SnapshotMeta, error) {
if err := read(decomp, &metadata, ioutil.Discard); err != nil {
return nil, fmt.Errorf("failed to read snapshot file: %v", err)
}
if err := concludeGzipRead(decomp); err != nil {
return nil, err
}
return &metadata, nil
}
// concludeGzipRead should be invoked after you think you've consumed all of
// the data from the gzip stream. It will error if the stream was corrupt.
//
// The docs for gzip.Reader say: "Clients should treat data returned by Read as
// tentative until they receive the io.EOF marking the end of the data."
func concludeGzipRead(decomp *gzip.Reader) error {
extra, err := ioutil.ReadAll(decomp) // ReadAll consumes the EOF
if err != nil {
return err
} else if len(extra) != 0 {
return fmt.Errorf("%d unread uncompressed bytes remain", len(extra))
}
return nil
}
// Restore takes the snapshot from the reader and attempts to apply it to the
// given Raft instance.
func Restore(logger hclog.Logger, in io.Reader, r *raft.Raft) error {
@ -175,6 +195,10 @@ func Restore(logger hclog.Logger, in io.Reader, r *raft.Raft) error {
return fmt.Errorf("failed to read snapshot file: %v", err)
}
if err := concludeGzipRead(decomp); err != nil {
return err
}
// Sync and rewind the file so it's ready to be read again.
if err := snap.Sync(); err != nil {
return fmt.Errorf("failed to sync temp snapshot: %v", err)

View File

@ -6,7 +6,7 @@ import (
"fmt"
"io"
"os"
"path"
"path/filepath"
"strings"
"sync"
"testing"
@ -16,6 +16,7 @@ import (
"github.com/hashicorp/consul/sdk/testutil"
"github.com/hashicorp/go-msgpack/codec"
"github.com/hashicorp/raft"
"github.com/stretchr/testify/require"
)
// MockFSM is a simple FSM for testing that simply stores its logs in a slice of
@ -131,7 +132,7 @@ func TestSnapshot(t *testing.T) {
// apply off to a buffer for checking post-snapshot.
var expected []bytes.Buffer
entries := 64 * 1024
before, _ := makeRaft(t, path.Join(dir, "before"))
before, _ := makeRaft(t, filepath.Join(dir, "before"))
defer before.Shutdown()
for i := 0; i < entries; i++ {
var log bytes.Buffer
@ -174,7 +175,7 @@ func TestSnapshot(t *testing.T) {
}
// Make a new, independent Raft.
after, fsm := makeRaft(t, path.Join(dir, "after"))
after, fsm := makeRaft(t, filepath.Join(dir, "after"))
defer after.Shutdown()
// Put some initial data in there that the snapshot should overwrite.
@ -232,12 +233,60 @@ func TestSnapshot_BadVerify(t *testing.T) {
}
}
func TestSnapshot_TruncatedVerify(t *testing.T) {
dir := testutil.TempDir(t, "snapshot")
defer os.RemoveAll(dir)
// Make a Raft and populate it with some data. We tee everything we
// apply off to a buffer for checking post-snapshot.
var expected []bytes.Buffer
entries := 64 * 1024
before, _ := makeRaft(t, filepath.Join(dir, "before"))
defer before.Shutdown()
for i := 0; i < entries; i++ {
var log bytes.Buffer
var copy bytes.Buffer
both := io.MultiWriter(&log, &copy)
_, err := io.CopyN(both, rand.Reader, 256)
require.NoError(t, err)
future := before.Apply(log.Bytes(), time.Second)
require.NoError(t, future.Error())
expected = append(expected, copy)
}
// Take a snapshot.
logger := testutil.Logger(t)
snap, err := New(logger, before)
require.NoError(t, err)
defer snap.Close()
var data []byte
{
var buf bytes.Buffer
_, err = io.Copy(&buf, snap)
require.NoError(t, err)
data = buf.Bytes()
}
for _, removeBytes := range []int{200, 16, 8, 4, 2, 1} {
t.Run(fmt.Sprintf("truncate %d bytes from end", removeBytes), func(t *testing.T) {
// Lop off part of the end.
buf := bytes.NewReader(data[0 : len(data)-removeBytes])
_, err = Verify(buf)
require.Error(t, err)
})
}
}
func TestSnapshot_BadRestore(t *testing.T) {
dir := testutil.TempDir(t, "snapshot")
defer os.RemoveAll(dir)
// Make a Raft and populate it with some data.
before, _ := makeRaft(t, path.Join(dir, "before"))
before, _ := makeRaft(t, filepath.Join(dir, "before"))
defer before.Shutdown()
for i := 0; i < 16*1024; i++ {
var log bytes.Buffer
@ -258,7 +307,7 @@ func TestSnapshot_BadRestore(t *testing.T) {
}
// Make a new, independent Raft.
after, fsm := makeRaft(t, path.Join(dir, "after"))
after, fsm := makeRaft(t, filepath.Join(dir, "after"))
defer after.Shutdown()
// Put some initial data in there that should not be harmed by the

View File

@ -38,6 +38,14 @@ You can also call the test helper at runtime if needed:
TestHelper(&testing.RuntimeT{})
}
## Versioning
The tagged version matches the version of Go that the interface is
compatible with. For example, the version "1.14.0" is for Go 1.14 and
introduced the `Cleanup` function. The patch version (the ".0" in the
prior example) is used to fix any bugs found in this library and has no
correlation to the supported Go version.
## Why?!
**Why would I call a test helper that takes a *testing.T at runtime?**

View File

@ -1 +1,3 @@
module github.com/mitchellh/go-testing-interface
go 1.14

View File

@ -1,5 +1,3 @@
// +build !go1.9
package testing
import (
@ -12,6 +10,7 @@ import (
// In unit tests you can just pass a *testing.T struct. At runtime, outside
// of tests, you can pass in a RuntimeT struct from this package.
type T interface {
Cleanup(func())
Error(args ...interface{})
Errorf(format string, args ...interface{})
Fail()
@ -19,6 +18,7 @@ type T interface {
Failed() bool
Fatal(args ...interface{})
Fatalf(format string, args ...interface{})
Helper()
Log(args ...interface{})
Logf(format string, args ...interface{})
Name() string
@ -31,10 +31,13 @@ type T interface {
// RuntimeT implements T and can be instantiated and run at runtime to
// mimic *testing.T behavior. Unlike *testing.T, this will simply panic
// for calls to Fatal. For calls to Error, you'll have to check the errors
// list to determine whether to exit yourself. Name and Skip methods are
// unimplemented noops.
// list to determine whether to exit yourself.
//
// Cleanup does NOT work, so if you're using a helper that uses Cleanup,
// there may be dangling resources.
type RuntimeT struct {
failed bool
skipped bool
failed bool
}
func (t *RuntimeT) Error(args ...interface{}) {
@ -43,20 +46,10 @@ func (t *RuntimeT) Error(args ...interface{}) {
}
func (t *RuntimeT) Errorf(format string, args ...interface{}) {
log.Println(fmt.Sprintf(format, args...))
log.Printf(format, args...)
t.Fail()
}
func (t *RuntimeT) Fatal(args ...interface{}) {
log.Println(fmt.Sprintln(args...))
t.FailNow()
}
func (t *RuntimeT) Fatalf(format string, args ...interface{}) {
log.Println(fmt.Sprintf(format, args...))
t.FailNow()
}
func (t *RuntimeT) Fail() {
t.failed = true
}
@ -69,6 +62,16 @@ func (t *RuntimeT) Failed() bool {
return t.failed
}
func (t *RuntimeT) Fatal(args ...interface{}) {
log.Print(args...)
t.FailNow()
}
func (t *RuntimeT) Fatalf(format string, args ...interface{}) {
log.Printf(format, args...)
t.FailNow()
}
func (t *RuntimeT) Log(args ...interface{}) {
log.Println(fmt.Sprintln(args...))
}
@ -77,8 +80,28 @@ func (t *RuntimeT) Logf(format string, args ...interface{}) {
log.Println(fmt.Sprintf(format, args...))
}
func (t *RuntimeT) Name() string { return "" }
func (t *RuntimeT) Skip(args ...interface{}) {}
func (t *RuntimeT) SkipNow() {}
func (t *RuntimeT) Skipf(format string, args ...interface{}) {}
func (t *RuntimeT) Skipped() bool { return false }
func (t *RuntimeT) Name() string {
return ""
}
func (t *RuntimeT) Skip(args ...interface{}) {
log.Print(args...)
t.SkipNow()
}
func (t *RuntimeT) SkipNow() {
t.skipped = true
}
func (t *RuntimeT) Skipf(format string, args ...interface{}) {
log.Printf(format, args...)
t.SkipNow()
}
func (t *RuntimeT) Skipped() bool {
return t.skipped
}
func (t *RuntimeT) Helper() {}
func (t *RuntimeT) Cleanup(func()) {}

View File

@ -1,108 +0,0 @@
// +build go1.9
// NOTE: This is a temporary copy of testing.go for Go 1.9 with the addition
// of "Helper" to the T interface. Go 1.9 at the time of typing is in RC
// and is set for release shortly. We'll support this on master as the default
// as soon as 1.9 is released.
package testing
import (
"fmt"
"log"
)
// T is the interface that mimics the standard library *testing.T.
//
// In unit tests you can just pass a *testing.T struct. At runtime, outside
// of tests, you can pass in a RuntimeT struct from this package.
type T interface {
Error(args ...interface{})
Errorf(format string, args ...interface{})
Fail()
FailNow()
Failed() bool
Fatal(args ...interface{})
Fatalf(format string, args ...interface{})
Log(args ...interface{})
Logf(format string, args ...interface{})
Name() string
Skip(args ...interface{})
SkipNow()
Skipf(format string, args ...interface{})
Skipped() bool
Helper()
}
// RuntimeT implements T and can be instantiated and run at runtime to
// mimic *testing.T behavior. Unlike *testing.T, this will simply panic
// for calls to Fatal. For calls to Error, you'll have to check the errors
// list to determine whether to exit yourself.
type RuntimeT struct {
skipped bool
failed bool
}
func (t *RuntimeT) Error(args ...interface{}) {
log.Println(fmt.Sprintln(args...))
t.Fail()
}
func (t *RuntimeT) Errorf(format string, args ...interface{}) {
log.Printf(format, args...)
t.Fail()
}
func (t *RuntimeT) Fail() {
t.failed = true
}
func (t *RuntimeT) FailNow() {
panic("testing.T failed, see logs for output (if any)")
}
func (t *RuntimeT) Failed() bool {
return t.failed
}
func (t *RuntimeT) Fatal(args ...interface{}) {
log.Print(args...)
t.FailNow()
}
func (t *RuntimeT) Fatalf(format string, args ...interface{}) {
log.Printf(format, args...)
t.FailNow()
}
func (t *RuntimeT) Log(args ...interface{}) {
log.Println(fmt.Sprintln(args...))
}
func (t *RuntimeT) Logf(format string, args ...interface{}) {
log.Println(fmt.Sprintf(format, args...))
}
func (t *RuntimeT) Name() string {
return ""
}
func (t *RuntimeT) Skip(args ...interface{}) {
log.Print(args...)
t.SkipNow()
}
func (t *RuntimeT) SkipNow() {
t.skipped = true
}
func (t *RuntimeT) Skipf(format string, args ...interface{}) {
log.Printf(format, args...)
t.SkipNow()
}
func (t *RuntimeT) Skipped() bool {
return t.skipped
}
func (t *RuntimeT) Helper() {}

2
vendor/modules.txt vendored
View File

@ -299,7 +299,7 @@ github.com/mitchellh/cli
github.com/mitchellh/copystructure
# github.com/mitchellh/go-homedir v1.1.0
github.com/mitchellh/go-homedir
# github.com/mitchellh/go-testing-interface v1.0.0
# github.com/mitchellh/go-testing-interface v1.14.0
github.com/mitchellh/go-testing-interface
# github.com/mitchellh/hashstructure v0.0.0-20170609045927-2bca23e0e452
github.com/mitchellh/hashstructure