diff --git a/api/sys_raft.go b/api/sys_raft.go index 5677cf454..043a69801 100644 --- a/api/sys_raft.go +++ b/api/sys_raft.go @@ -1,21 +1,25 @@ package api import ( + "archive/tar" + "compress/gzip" "context" "encoding/json" "errors" "fmt" "io" + "io/ioutil" "net/http" + "sync" "time" "github.com/hashicorp/go-secure-stdlib/parseutil" - - "github.com/mitchellh/mapstructure" - "github.com/hashicorp/vault/sdk/helper/consts" + "github.com/mitchellh/mapstructure" ) +var ErrIncompleteSnapshot = errors.New("incomplete snapshot, unable to read SHA256SUMS.sealed file") + // RaftJoinResponse represents the response of the raft join API type RaftJoinResponse struct { Joined bool `json:"joined"` @@ -210,11 +214,60 @@ func (c *Sys) RaftSnapshot(snapWriter io.Writer) error { return err } - _, err = io.Copy(snapWriter, resp.Body) + // Make sure that the last file in the archive, SHA256SUMS.sealed, is present + // and non-empty. This is to catch cases where the snapshot failed midstream, + // e.g. due to a problem with the seal that prevented encryption of that file. + var wg sync.WaitGroup + wg.Add(1) + var verified bool + + rPipe, wPipe := io.Pipe() + dup := io.TeeReader(resp.Body, wPipe) + go func() { + defer func() { + io.Copy(ioutil.Discard, rPipe) + rPipe.Close() + wg.Done() + }() + + uncompressed, err := gzip.NewReader(rPipe) + if err != nil { + return + } + + t := tar.NewReader(uncompressed) + var h *tar.Header + for { + h, err = t.Next() + if err != nil { + return + } + if h.Name != "SHA256SUMS.sealed" { + continue + } + var b []byte + b, err = ioutil.ReadAll(t) + if err != nil || len(b) == 0 { + return + } + verified = true + return + } + }() + + // Copy bytes from dup to snapWriter. This will have a side effect that + // everything read from dup will be written to wPipe. + _, err = io.Copy(snapWriter, dup) + wPipe.Close() if err != nil { + rPipe.CloseWithError(err) return err } + wg.Wait() + if !verified { + return ErrIncompleteSnapshot + } return nil } diff --git a/changelog/12388.txt b/changelog/12388.txt new file mode 100644 index 000000000..f384c90bd --- /dev/null +++ b/changelog/12388.txt @@ -0,0 +1,3 @@ +```release-note:bug +storage/raft: Detect incomplete raft snapshots in api.RaftSnapshot(), and thereby in `vault operator raft snapshot save`. +``` diff --git a/vault/external_tests/raft/raft_test.go b/vault/external_tests/raft/raft_test.go index f98b57558..0ab2a3032 100644 --- a/vault/external_tests/raft/raft_test.go +++ b/vault/external_tests/raft/raft_test.go @@ -4,14 +4,19 @@ import ( "bytes" "context" "crypto/md5" + "errors" "fmt" + "io" "io/ioutil" "net/http" "strings" + "sync" "sync/atomic" "testing" "time" + vaultseal "github.com/hashicorp/vault/vault/seal" + "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/api" @@ -36,6 +41,8 @@ type RaftClusterOpts struct { PhysicalFactoryConfig map[string]interface{} DisablePerfStandby bool EnableResponseHeaderRaftNodeID bool + NumCores int + Seal vault.Seal } func raftCluster(t testing.TB, ropts *RaftClusterOpts) *vault.TestCluster { @@ -49,6 +56,7 @@ func raftCluster(t testing.TB, ropts *RaftClusterOpts) *vault.TestCluster { }, DisableAutopilot: !ropts.EnableAutopilot, EnableResponseHeaderRaftNodeID: ropts.EnableResponseHeaderRaftNodeID, + Seal: ropts.Seal, } opts := vault.TestClusterOptions{ @@ -57,6 +65,7 @@ func raftCluster(t testing.TB, ropts *RaftClusterOpts) *vault.TestCluster { opts.InmemClusterLayers = ropts.InmemCluster opts.PhysicalFactoryConfig = ropts.PhysicalFactoryConfig conf.DisablePerformanceStandby = ropts.DisablePerfStandby + opts.NumCores = ropts.NumCores teststorage.RaftBackendSetup(conf, &opts) @@ -542,6 +551,58 @@ func TestRaft_SnapshotAPI(t *testing.T) { } } +func TestRaft_SnapshotAPI_MidstreamFailure(t *testing.T) { + // defer goleak.VerifyNone(t) + t.Parallel() + + seal, errptr := vaultseal.NewToggleableTestSeal(nil) + cluster := raftCluster(t, &RaftClusterOpts{ + NumCores: 1, + Seal: vault.NewAutoSeal(seal), + }) + defer cluster.Cleanup() + + leaderClient := cluster.Cores[0].Client + + // Write a bunch of keys; if too few, the detection code in api.RaftSnapshot + // will never make it into the tar part, it'll fail merely when trying to + // decompress the stream. + for i := 0; i < 1000; i++ { + _, err := leaderClient.Logical().Write(fmt.Sprintf("secret/%d", i), map[string]interface{}{ + "test": "data", + }) + if err != nil { + t.Fatal(err) + } + } + + r, w := io.Pipe() + var snap []byte + var wg sync.WaitGroup + wg.Add(1) + + var readErr error + go func() { + snap, readErr = ioutil.ReadAll(r) + wg.Done() + }() + + *errptr = errors.New("seal failure") + // Take a snapshot + err := leaderClient.Sys().RaftSnapshot(w) + w.Close() + if err == nil || err != api.ErrIncompleteSnapshot { + t.Fatalf("expected err=%v, got: %v", api.ErrIncompleteSnapshot, err) + } + wg.Wait() + if len(snap) == 0 && readErr == nil { + readErr = errors.New("no bytes read") + } + if readErr != nil { + t.Fatal(readErr) + } +} + func TestRaft_SnapshotAPI_RekeyRotate_Backward(t *testing.T) { type testCase struct { Name string diff --git a/vault/seal/seal_testing.go b/vault/seal/seal_testing.go index e4f4db3ce..1a7130b2d 100644 --- a/vault/seal/seal_testing.go +++ b/vault/seal/seal_testing.go @@ -1,6 +1,8 @@ package seal import ( + "context" + "github.com/hashicorp/go-hclog" wrapping "github.com/hashicorp/go-kms-wrapping" ) @@ -22,3 +24,36 @@ func NewTestSeal(opts *TestSealOpts) *Access { OverriddenType: opts.Name, } } + +func NewToggleableTestSeal(opts *TestSealOpts) (*Access, *error) { + if opts == nil { + opts = new(TestSealOpts) + } + + w := &ToggleableWrapper{Wrapper: wrapping.NewTestWrapper(opts.Secret)} + return &Access{ + Wrapper: w, + OverriddenType: opts.Name, + }, &w.Error +} + +type ToggleableWrapper struct { + wrapping.Wrapper + Error error +} + +func (t ToggleableWrapper) Encrypt(ctx context.Context, bytes []byte, bytes2 []byte) (*wrapping.EncryptedBlobInfo, error) { + if t.Error != nil { + return nil, t.Error + } + return t.Wrapper.Encrypt(ctx, bytes, bytes2) +} + +func (t ToggleableWrapper) Decrypt(ctx context.Context, info *wrapping.EncryptedBlobInfo, bytes []byte) ([]byte, error) { + if t.Error != nil { + return nil, t.Error + } + return t.Wrapper.Decrypt(ctx, info, bytes) +} + +var _ wrapping.Wrapper = &ToggleableWrapper{}