open-nomad/helper/snapshot/snapshot_test.go
hc-github-team-nomad-core 1b2237d6a8
backport of commit 776a26bce7cf3a320fc7e7f4a6bf9da2b30f3da7 (#18375)
Co-authored-by: James Rasell <jrasell@users.noreply.github.com>
2023-09-01 10:25:08 +01:00

353 lines
8.4 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package snapshot
import (
"bytes"
"crypto/rand"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"time"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/hashicorp/go-msgpack/codec"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/raft"
"github.com/stretchr/testify/require"
)
// MockFSM is a simple FSM for testing that simply stores its logs in a slice of
// byte slices. Its Apply, Snapshot and Restore methods each hold the embedded
// mutex while they touch logs, and tests lock it before inspecting logs.
type MockFSM struct {
// Mutex guards logs.
sync.Mutex
// logs holds the data of every applied log entry, in apply order.
logs [][]byte
}
// MockSnapshot is the snapshot value that MockFSM.Snapshot hands out for
// testing. Its Persist method encodes the captured contents of a MockFSM to a
// sink using msgpack.
type MockSnapshot struct {
// logs is the log slice taken from the MockFSM when the snapshot was made.
logs [][]byte
// maxIndex is the number of leading entries of logs that Persist encodes.
maxIndex int
}
// Apply implements raft.FSM. It appends the data of the given log entry to the
// FSM's log slice and returns the number of entries now stored.
func (m *MockFSM) Apply(log *raft.Log) interface{} {
	m.Lock()
	defer m.Unlock()

	entries := append(m.logs, log.Data)
	m.logs = entries
	return len(entries)
}
// Snapshot implements raft.FSM. It wraps the current log slice and its length
// in a MockSnapshot while holding the lock; the returned error is always nil.
func (m *MockFSM) Snapshot() (raft.FSMSnapshot, error) {
	m.Lock()
	defer m.Unlock()

	snap := &MockSnapshot{
		logs:     m.logs,
		maxIndex: len(m.logs),
	}
	return snap, nil
}
// Restore implements raft.FSM. It discards the current logs and replaces them
// with the msgpack-encoded slice read from in. The reader is closed before the
// method returns, and any decode error is passed back to the caller.
func (m *MockFSM) Restore(in io.ReadCloser) error {
	m.Lock()
	defer m.Unlock()
	defer in.Close()

	// Start from an empty slice so the decoded contents fully replace the
	// previous state.
	m.logs = nil
	return codec.NewDecoder(in, structs.MsgpackHandle).Decode(&m.logs)
}
// Persist implements raft.FSMSnapshot. It msgpack-encodes the first maxIndex
// log entries to sink.
//
// If encoding fails the sink is cancelled and the encoding error is returned.
// Otherwise the sink is closed and the result of closing it is returned, so a
// snapshot that could not be finalized is not reported as a success.
func (m *MockSnapshot) Persist(sink raft.SnapshotSink) error {
	enc := codec.NewEncoder(sink, structs.MsgpackHandle)
	if err := enc.Encode(m.logs[:m.maxIndex]); err != nil {
		// Best-effort cleanup: the encode error is the one worth reporting,
		// so a failure to cancel is deliberately ignored.
		_ = sink.Cancel()
		return err
	}
	return sink.Close()
}
// Release implements raft.FSMSnapshot. It is intentionally a no-op for the
// mock snapshot.
func (m *MockSnapshot) Release() {
}
// makeRaft returns a Raft and its FSM, with snapshots based in the given dir.
//
// The node is bootstrapped as the only voter of its cluster, using an
// in-memory store for both logs and stable state and an in-memory transport.
// The function blocks until the node reports a leader. It fails the test
// immediately if any setup step returns an error or if no leader is seen
// within 10 seconds.
func makeRaft(t *testing.T, dir string) (*raft.Raft, *MockFSM) {
	// Attribute failures to the calling test rather than to this helper.
	t.Helper()

	snaps, err := raft.NewFileSnapshotStore(dir, 5, nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	fsm := &MockFSM{}
	store := raft.NewInmemStore()
	addr, trans := raft.NewInmemTransport("")

	config := raft.DefaultConfig()
	config.LocalID = raft.ServerID(fmt.Sprintf("server-%s", addr))

	var members raft.Configuration
	members.Servers = append(members.Servers, raft.Server{
		Suffrage: raft.Voter,
		ID:       config.LocalID,
		Address:  addr,
	})

	err = raft.BootstrapCluster(config, store, store, snaps, trans, members)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	// Named r rather than raft so the package identifier is not shadowed for
	// the remainder of the function.
	r, err := raft.NewRaft(config, fsm, store, store, snaps, trans)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	timeout := time.After(10 * time.Second)
	for {
		if leaderAddr, _ := r.LeaderWithID(); leaderAddr != "" {
			break
		}

		select {
		case <-r.LeaderCh():
		case <-time.After(1 * time.Second):
			// Need to poll because we might have missed the first
			// go with the leader channel.
		case <-timeout:
			t.Fatalf("timed out waiting for leader")
		}
	}

	return r, fsm
}
// TestSnapshot takes a snapshot of a populated Raft node, verifies the
// snapshot's metadata, restores it into a second, independent node that
// already holds unrelated entries, and checks that the second node's FSM ends
// up holding exactly the entries that were applied to the first.
func TestSnapshot(t *testing.T) {
	dir := testutil.TempDir(t, "snapshot")
	defer os.RemoveAll(dir)

	// Make a Raft and populate it with some data. We tee everything we
	// apply off to a buffer for checking post-snapshot.
	var expected []bytes.Buffer
	entries := 64 * 1024
	before, _ := makeRaft(t, filepath.Join(dir, "before"))
	defer before.Shutdown()
	for i := 0; i < entries; i++ {
		var log bytes.Buffer
		// Named tee rather than copy so the builtin is not shadowed.
		var tee bytes.Buffer
		both := io.MultiWriter(&log, &tee)
		if _, err := io.CopyN(both, rand.Reader, 256); err != nil {
			t.Fatalf("err: %v", err)
		}
		future := before.Apply(log.Bytes(), time.Second)
		if err := future.Error(); err != nil {
			t.Fatalf("err: %v", err)
		}
		expected = append(expected, tee)
	}

	// Take a snapshot.
	logger := testutil.Logger(t)
	snap, err := New(logger, before)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	defer snap.Close()

	// Verify the snapshot. We have to rewind it after for the restore.
	metadata, err := Verify(snap)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if _, err := snap.file.Seek(0, io.SeekStart); err != nil {
		t.Fatalf("err: %v", err)
	}
	if int(metadata.Index) != entries+2 {
		t.Fatalf("bad: %d", metadata.Index)
	}
	if metadata.Term != 2 {
		// Report the term, which is the value actually being checked here.
		t.Fatalf("bad: %d", metadata.Term)
	}
	if metadata.Version != raft.SnapshotVersionMax {
		t.Fatalf("bad: %d", metadata.Version)
	}

	// Make a new, independent Raft.
	after, fsm := makeRaft(t, filepath.Join(dir, "after"))
	defer after.Shutdown()

	// Put some initial data in there that the snapshot should overwrite.
	for i := 0; i < 16; i++ {
		var log bytes.Buffer
		if _, err := io.CopyN(&log, rand.Reader, 256); err != nil {
			t.Fatalf("err: %v", err)
		}
		future := after.Apply(log.Bytes(), time.Second)
		if err := future.Error(); err != nil {
			t.Fatalf("err: %v", err)
		}
	}

	// Restore the snapshot.
	if err := Restore(logger, snap, after); err != nil {
		t.Fatalf("err: %v", err)
	}

	// Compare the contents.
	fsm.Lock()
	defer fsm.Unlock()
	if len(fsm.logs) != len(expected) {
		t.Fatalf("bad: %d vs. %d", len(fsm.logs), len(expected))
	}
	for i := range fsm.logs {
		if !bytes.Equal(fsm.logs[i], expected[i].Bytes()) {
			t.Fatalf("bad: log %d doesn't match", i)
		}
	}
}
// TestSnapshot_Nil checks that the methods of a nil *Snapshot can be called
// safely: Index must report 0, Read must return 0 bytes together with io.EOF,
// and Close must return a nil error.
func TestSnapshot_Nil(t *testing.T) {
	var snap *Snapshot

	idx := snap.Index()
	if idx != 0 {
		t.Fatalf("bad: %d", idx)
	}

	scratch := make([]byte, 16)
	n, err := snap.Read(scratch)
	if !(n == 0 && err == io.EOF) {
		t.Fatalf("bad: %d %v", n, err)
	}

	if closeErr := snap.Close(); closeErr != nil {
		t.Fatalf("err: %v", closeErr)
	}
}
// TestSnapshot_BadVerify feeds Verify a few bytes that are not a snapshot and
// expects an error whose message mentions "unexpected EOF".
func TestSnapshot_BadVerify(t *testing.T) {
	garbage := bytes.NewBufferString("nope")

	_, err := Verify(garbage)
	if err != nil && strings.Contains(err.Error(), "unexpected EOF") {
		return
	}
	t.Fatalf("err: %v", err)
}
// TestSnapshot_TruncatedVerify checks that Verify returns an error for a
// snapshot stream that has had bytes removed from its end, across a range of
// truncation sizes.
func TestSnapshot_TruncatedVerify(t *testing.T) {
	dir := testutil.TempDir(t, "snapshot")
	defer os.RemoveAll(dir)

	// Make a Raft and populate it with some random data. The entries are not
	// compared afterwards, so no second copy of them is kept.
	entries := 64 * 1024
	before, _ := makeRaft(t, filepath.Join(dir, "before"))
	defer before.Shutdown()
	for i := 0; i < entries; i++ {
		var log bytes.Buffer
		_, err := io.CopyN(&log, rand.Reader, 256)
		require.NoError(t, err)
		future := before.Apply(log.Bytes(), time.Second)
		require.NoError(t, future.Error())
	}

	// Take a snapshot.
	logger := testutil.Logger(t)
	snap, err := New(logger, before)
	require.NoError(t, err)
	defer snap.Close()

	// Read the whole snapshot into memory so each subtest can slice it.
	var full bytes.Buffer
	_, err = io.Copy(&full, snap)
	require.NoError(t, err)
	data := full.Bytes()

	for _, removeBytes := range []int{200, 16, 8, 4, 2, 1} {
		t.Run(fmt.Sprintf("truncate %d bytes from end", removeBytes), func(t *testing.T) {
			// Lop off part of the end.
			buf := bytes.NewReader(data[0 : len(data)-removeBytes])
			// Keep err local to the subtest instead of writing to the
			// parent test's variable from inside the closure.
			_, err := Verify(buf)
			require.Error(t, err)
		})
	}
}
// TestSnapshot_BadRestore checks that restoring from a truncated snapshot
// fails with an "unexpected EOF" error and leaves the target node's existing
// FSM contents untouched.
func TestSnapshot_BadRestore(t *testing.T) {
	dir := testutil.TempDir(t, "snapshot")
	defer os.RemoveAll(dir)

	// Make a Raft and populate it with some data.
	before, _ := makeRaft(t, filepath.Join(dir, "before"))
	defer before.Shutdown()
	for i := 0; i < 16*1024; i++ {
		var log bytes.Buffer
		if _, err := io.CopyN(&log, rand.Reader, 256); err != nil {
			t.Fatalf("err: %v", err)
		}
		future := before.Apply(log.Bytes(), time.Second)
		if err := future.Error(); err != nil {
			t.Fatalf("err: %v", err)
		}
	}

	// Take a snapshot.
	logger := testutil.Logger(t)
	snap, err := New(logger, before)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	// Close the snapshot when the test ends, as the other tests in this file
	// do, so it is not left open after the test.
	defer snap.Close()

	// Make a new, independent Raft.
	after, fsm := makeRaft(t, filepath.Join(dir, "after"))
	defer after.Shutdown()

	// Put some initial data in there that should not be harmed by the
	// failed restore attempt.
	var expected []bytes.Buffer
	for i := 0; i < 16; i++ {
		var log bytes.Buffer
		// Named tee rather than copy so the builtin is not shadowed.
		var tee bytes.Buffer
		both := io.MultiWriter(&log, &tee)
		if _, err := io.CopyN(both, rand.Reader, 256); err != nil {
			t.Fatalf("err: %v", err)
		}
		future := after.Apply(log.Bytes(), time.Second)
		if err := future.Error(); err != nil {
			t.Fatalf("err: %v", err)
		}
		expected = append(expected, tee)
	}

	// Attempt to restore a truncated version of the snapshot. This is
	// expected to fail.
	err = Restore(logger, io.LimitReader(snap, 512), after)
	if err == nil || !strings.Contains(err.Error(), "unexpected EOF") {
		t.Fatalf("err: %v", err)
	}

	// Compare the contents to make sure the aborted restore didn't harm
	// anything.
	fsm.Lock()
	defer fsm.Unlock()
	if len(fsm.logs) != len(expected) {
		t.Fatalf("bad: %d vs. %d", len(fsm.logs), len(expected))
	}
	for i := range fsm.logs {
		if !bytes.Equal(fsm.logs[i], expected[i].Bytes()) {
			t.Fatalf("bad: log %d doesn't match", i)
		}
	}
}