open-consul/snapshot/snapshot_test.go
Daniel Nephin ea6c2b2adc ci: Add staticcheck and fix most errors
Three of the checks are temporarily disabled to limit the size of the
diff, and allow us to enable all the other checks in CI.

In a follow up we can fix the issues reported by the other checks one
at a time, and enable them.
2020-05-28 11:59:58 -04:00

348 lines
8.2 KiB
Go

package snapshot
import (
"bytes"
"crypto/rand"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"time"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/hashicorp/go-msgpack/codec"
"github.com/hashicorp/raft"
"github.com/stretchr/testify/require"
)
// MockFSM is a minimal raft.FSM for tests. The payload of every applied log
// is kept in memory, and access is serialized by the embedded mutex.
type MockFSM struct {
	sync.Mutex

	// logs holds the data of each applied log entry, in apply order.
	logs [][]byte
}
// MockSnapshot is the raft.FSMSnapshot handed out by MockFSM.Snapshot. When
// persisted, it encodes the captured log entries using msgpack.
type MockSnapshot struct {
	// logs is the slice of entries captured from the FSM.
	logs [][]byte

	// maxIndex is how many leading entries of logs are written by Persist.
	maxIndex int
}
// Apply implements raft.FSM. It appends the log's data to the stored entries
// and returns the number of entries held afterwards.
func (m *MockFSM) Apply(log *raft.Log) interface{} {
	m.Lock()
	defer m.Unlock()

	m.logs = append(m.logs, log.Data)
	count := len(m.logs)
	return count
}
// Snapshot implements raft.FSM. The returned MockSnapshot refers to the
// current logs slice and records its length at the time of the call.
func (m *MockFSM) Snapshot() (raft.FSMSnapshot, error) {
	m.Lock()
	defer m.Unlock()

	snap := &MockSnapshot{
		logs:     m.logs,
		maxIndex: len(m.logs),
	}
	return snap, nil
}
// Restore implements raft.FSM. It replaces the stored entries with the
// msgpack-encoded contents read from in, and closes in before returning.
func (m *MockFSM) Restore(in io.ReadCloser) error {
	m.Lock()
	defer m.Unlock()
	defer in.Close()

	// Discard the existing entries before decoding the replacement set.
	m.logs = nil
	return codec.NewDecoder(in, structs.MsgpackHandle).Decode(&m.logs)
}
// Persist implements raft.FSMSnapshot. It msgpack-encodes the first maxIndex
// log entries into sink.
//
// If encoding fails, the sink is cancelled and the encode error is returned.
// Otherwise the result of closing the sink is returned, so a failure to
// finalize the snapshot is reported rather than silently dropped.
func (m *MockSnapshot) Persist(sink raft.SnapshotSink) error {
	enc := codec.NewEncoder(sink, structs.MsgpackHandle)
	if err := enc.Encode(m.logs[:m.maxIndex]); err != nil {
		// Best effort: the encode error is the one worth reporting.
		_ = sink.Cancel()
		return err
	}
	return sink.Close()
}
// Release implements raft.FSMSnapshot. It is a no-op for this mock.
func (m *MockSnapshot) Release() {}
// makeRaft returns a Raft and its FSM, with snapshots based in the given dir.
//
// The Raft is bootstrapped as a single-voter cluster backed by an in-memory
// store and transport. The function blocks until the node reports a leader,
// and fails the test if that does not happen within 10 seconds or if any
// setup step returns an error.
func makeRaft(t *testing.T, dir string) (*raft.Raft, *MockFSM) {
	// Attribute failures to the calling test rather than to this helper.
	t.Helper()

	snaps, err := raft.NewFileSnapshotStore(dir, 5, nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	fsm := &MockFSM{}
	store := raft.NewInmemStore()
	addr, trans := raft.NewInmemTransport("")

	config := raft.DefaultConfig()
	config.LocalID = raft.ServerID(fmt.Sprintf("server-%s", addr))

	var members raft.Configuration
	members.Servers = append(members.Servers, raft.Server{
		Suffrage: raft.Voter,
		ID:       config.LocalID,
		Address:  addr,
	})

	err = raft.BootstrapCluster(config, store, store, snaps, trans, members)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	// Named r (not raft) so the imported raft package is not shadowed.
	r, err := raft.NewRaft(config, fsm, store, store, snaps, trans)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	timeout := time.After(10 * time.Second)
	for r.Leader() == "" {
		select {
		case <-r.LeaderCh():
		case <-time.After(1 * time.Second):
			// Need to poll because we might have missed the first
			// go with the leader channel.
		case <-timeout:
			t.Fatalf("timed out waiting for leader")
		}
	}
	return r, fsm
}
// TestSnapshot takes a snapshot of a populated Raft, verifies its metadata,
// restores it into a second, independent Raft, and checks that the restored
// FSM holds exactly the entries that were applied to the first one.
func TestSnapshot(t *testing.T) {
	dir := testutil.TempDir(t, "snapshot")
	defer os.RemoveAll(dir)

	// Make a Raft and populate it with some data. We tee everything we
	// apply off to a buffer for checking post-snapshot.
	var expected []bytes.Buffer
	entries := 64 * 1024
	before, _ := makeRaft(t, filepath.Join(dir, "before"))
	defer before.Shutdown()
	for i := 0; i < entries; i++ {
		var log bytes.Buffer
		// Named mirror rather than copy to avoid shadowing the builtin.
		var mirror bytes.Buffer
		both := io.MultiWriter(&log, &mirror)
		if _, err := io.CopyN(both, rand.Reader, 256); err != nil {
			t.Fatalf("err: %v", err)
		}
		future := before.Apply(log.Bytes(), time.Second)
		if err := future.Error(); err != nil {
			t.Fatalf("err: %v", err)
		}
		expected = append(expected, mirror)
	}

	// Take a snapshot.
	logger := testutil.Logger(t)
	snap, err := New(logger, before)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	defer snap.Close()

	// Verify the snapshot. We have to rewind it after for the restore.
	metadata, err := Verify(snap)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if _, err := snap.file.Seek(0, 0); err != nil {
		t.Fatalf("err: %v", err)
	}
	if int(metadata.Index) != entries+2 {
		t.Fatalf("bad: %d", metadata.Index)
	}
	if metadata.Term != 2 {
		// Report the value that was actually checked.
		t.Fatalf("bad: %d", metadata.Term)
	}
	if metadata.Version != raft.SnapshotVersionMax {
		t.Fatalf("bad: %d", metadata.Version)
	}

	// Make a new, independent Raft.
	after, fsm := makeRaft(t, filepath.Join(dir, "after"))
	defer after.Shutdown()

	// Put some initial data in there that the snapshot should overwrite.
	for i := 0; i < 16; i++ {
		var log bytes.Buffer
		if _, err := io.CopyN(&log, rand.Reader, 256); err != nil {
			t.Fatalf("err: %v", err)
		}
		future := after.Apply(log.Bytes(), time.Second)
		if err := future.Error(); err != nil {
			t.Fatalf("err: %v", err)
		}
	}

	// Restore the snapshot.
	if err := Restore(logger, snap, after); err != nil {
		t.Fatalf("err: %v", err)
	}

	// Compare the contents.
	fsm.Lock()
	defer fsm.Unlock()
	if len(fsm.logs) != len(expected) {
		t.Fatalf("bad: %d vs. %d", len(fsm.logs), len(expected))
	}
	for i := range fsm.logs {
		if !bytes.Equal(fsm.logs[i], expected[i].Bytes()) {
			t.Fatalf("bad: log %d doesn't match", i)
		}
	}
}
// TestSnapshot_Nil checks that Index, Read and Close can be called on a nil
// *Snapshot: Index reports 0, Read reports (0, io.EOF) and Close returns nil.
func TestSnapshot_Nil(t *testing.T) {
	var snap *Snapshot

	idx := snap.Index()
	if idx != 0 {
		t.Fatalf("bad: %d", idx)
	}

	scratch := make([]byte, 16)
	n, err := snap.Read(scratch)
	if !(n == 0 && err == io.EOF) {
		t.Fatalf("bad: %d %v", n, err)
	}

	closeErr := snap.Close()
	if closeErr != nil {
		t.Fatalf("err: %v", closeErr)
	}
}
// TestSnapshot_BadVerify feeds Verify a few bytes that are not a snapshot and
// expects an error mentioning "unexpected EOF".
func TestSnapshot_BadVerify(t *testing.T) {
	garbage := bytes.NewBufferString("nope")

	_, err := Verify(garbage)
	if err != nil && strings.Contains(err.Error(), "unexpected EOF") {
		return
	}
	t.Fatalf("err: %v", err)
}
// TestSnapshot_TruncatedVerify reads a complete snapshot into memory and then
// checks that Verify returns an error when 200, 16, 8, 4, 2 or 1 bytes are
// removed from the end of it.
func TestSnapshot_TruncatedVerify(t *testing.T) {
	dir := testutil.TempDir(t, "snapshot")
	defer os.RemoveAll(dir)

	// Make a Raft and populate it with some data. The applied payloads
	// are not compared afterwards, so no second copy of them is kept.
	entries := 64 * 1024
	before, _ := makeRaft(t, filepath.Join(dir, "before"))
	defer before.Shutdown()
	for i := 0; i < entries; i++ {
		var log bytes.Buffer
		_, err := io.CopyN(&log, rand.Reader, 256)
		require.NoError(t, err)

		future := before.Apply(log.Bytes(), time.Second)
		require.NoError(t, future.Error())
	}

	// Take a snapshot.
	logger := testutil.Logger(t)
	snap, err := New(logger, before)
	require.NoError(t, err)
	defer snap.Close()

	// Read the whole snapshot into memory so it can be re-sliced per case.
	var full bytes.Buffer
	_, err = io.Copy(&full, snap)
	require.NoError(t, err)
	data := full.Bytes()

	for _, removeBytes := range []int{200, 16, 8, 4, 2, 1} {
		t.Run(fmt.Sprintf("truncate %d bytes from end", removeBytes), func(t *testing.T) {
			// Lop off part of the end.
			buf := bytes.NewReader(data[0 : len(data)-removeBytes])

			// Keep err local to the subtest instead of writing to the
			// variable captured from the enclosing test.
			_, err := Verify(buf)
			require.Error(t, err)
		})
	}
}
// TestSnapshot_BadRestore attempts to restore a snapshot truncated to 512
// bytes into a Raft that already holds data. It expects Restore to fail with
// an "unexpected EOF" error and the target FSM's entries to be unchanged.
func TestSnapshot_BadRestore(t *testing.T) {
	dir := testutil.TempDir(t, "snapshot")
	defer os.RemoveAll(dir)

	// Make a Raft and populate it with some data.
	before, _ := makeRaft(t, filepath.Join(dir, "before"))
	defer before.Shutdown()
	for i := 0; i < 16*1024; i++ {
		var log bytes.Buffer
		if _, err := io.CopyN(&log, rand.Reader, 256); err != nil {
			t.Fatalf("err: %v", err)
		}
		future := before.Apply(log.Bytes(), time.Second)
		if err := future.Error(); err != nil {
			t.Fatalf("err: %v", err)
		}
	}

	// Take a snapshot.
	logger := testutil.Logger(t)
	snap, err := New(logger, before)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	// Close the snapshot when done, as the other tests in this file do.
	defer snap.Close()

	// Make a new, independent Raft.
	after, fsm := makeRaft(t, filepath.Join(dir, "after"))
	defer after.Shutdown()

	// Put some initial data in there that should not be harmed by the
	// failed restore attempt.
	var expected []bytes.Buffer
	for i := 0; i < 16; i++ {
		var log bytes.Buffer
		// Named mirror rather than copy to avoid shadowing the builtin.
		var mirror bytes.Buffer
		both := io.MultiWriter(&log, &mirror)
		if _, err := io.CopyN(both, rand.Reader, 256); err != nil {
			t.Fatalf("err: %v", err)
		}
		future := after.Apply(log.Bytes(), time.Second)
		if err := future.Error(); err != nil {
			t.Fatalf("err: %v", err)
		}
		expected = append(expected, mirror)
	}

	// Attempt to restore a truncated version of the snapshot. This is
	// expected to fail.
	err = Restore(logger, io.LimitReader(snap, 512), after)
	if err == nil || !strings.Contains(err.Error(), "unexpected EOF") {
		t.Fatalf("err: %v", err)
	}

	// Compare the contents to make sure the aborted restore didn't harm
	// anything.
	fsm.Lock()
	defer fsm.Unlock()
	if len(fsm.logs) != len(expected) {
		t.Fatalf("bad: %d vs. %d", len(fsm.logs), len(expected))
	}
	for i := range fsm.logs {
		if !bytes.Equal(fsm.logs[i], expected[i].Bytes()) {
			t.Fatalf("bad: log %d doesn't match", i)
		}
	}
}