From c3ac053218637e324c3a67c46c5043551a89de8d Mon Sep 17 00:00:00 2001 From: hc-github-team-secure-vault-core <82990506+hc-github-team-secure-vault-core@users.noreply.github.com> Date: Tue, 31 Oct 2023 16:18:21 -0400 Subject: [PATCH] backport of commit 63ab253cb429c6fd7d7d61be6f76b25c742de7d1 (#23929) Co-authored-by: Ellie --- changelog/23457.txt | 3 + command/commands.go | 5 + command/operator_raft_snapshot.go | 4 + command/operator_raft_snapshot_inspect.go | 568 ++++++++++++++++++ .../operator_raft_snapshot_inspect_test.go | 141 +++++ physical/raft/io.go | 1 + physical/raft/varint.go | 19 +- 7 files changed, 735 insertions(+), 6 deletions(-) create mode 100644 changelog/23457.txt create mode 100644 command/operator_raft_snapshot_inspect.go create mode 100644 command/operator_raft_snapshot_inspect_test.go diff --git a/changelog/23457.txt b/changelog/23457.txt new file mode 100644 index 000000000..adec8ca9b --- /dev/null +++ b/changelog/23457.txt @@ -0,0 +1,3 @@ +```release-note:feature +cli/snapshot: Add CLI tool to inspect Vault snapshots +``` \ No newline at end of file diff --git a/command/commands.go b/command/commands.go index 68e2542b0..9a27577c7 100644 --- a/command/commands.go +++ b/command/commands.go @@ -484,6 +484,11 @@ func initCommands(ui, serverCmdUi cli.Ui, runOpts *RunOptions) map[string]cli.Co BaseCommand: getBaseCommand(), }, nil }, + "operator raft snapshot inspect": func() (cli.Command, error) { + return &OperatorRaftSnapshotInspectCommand{ + BaseCommand: getBaseCommand(), + }, nil + }, "operator raft snapshot restore": func() (cli.Command, error) { return &OperatorRaftSnapshotRestoreCommand{ BaseCommand: getBaseCommand(), diff --git a/command/operator_raft_snapshot.go b/command/operator_raft_snapshot.go index 036c6ebae..500b4428f 100644 --- a/command/operator_raft_snapshot.go +++ b/command/operator_raft_snapshot.go @@ -35,6 +35,10 @@ Usage: vault operator raft snapshot [options] [args] $ vault operator raft snapshot save raft.snap + Inspects a snapshot based on a file: + + $ vault operator raft snapshot inspect raft.snap + Please see the individual subcommand help for detailed usage information. ` diff --git a/command/operator_raft_snapshot_inspect.go b/command/operator_raft_snapshot_inspect.go new file mode 100644 index 000000000..a64c5ba34 --- /dev/null +++ b/command/operator_raft_snapshot_inspect.go @@ -0,0 +1,568 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package command + +import ( + "archive/tar" + "bufio" + "bytes" + "compress/gzip" + "crypto/sha256" + "encoding/json" + "fmt" + "hash" + "io" + "math" + "os" + "sort" + "strconv" + "strings" + "text/tabwriter" + + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/raft" + protoio "github.com/hashicorp/vault/physical/raft" + "github.com/hashicorp/vault/sdk/plugin/pb" + "github.com/mitchellh/cli" + "github.com/posener/complete" +) + +var ( + _ cli.Command = (*OperatorRaftSnapshotInspectCommand)(nil) + _ cli.CommandAutocomplete = (*OperatorRaftSnapshotInspectCommand)(nil) +) + +type OperatorRaftSnapshotInspectCommand struct { + *BaseCommand + details bool + depth int + filter string +} + +func (c *OperatorRaftSnapshotInspectCommand) Synopsis() string { + return "Inspects raft snapshot" +} + +func (c *OperatorRaftSnapshotInspectCommand) Help() string { + helpText := ` + Usage: vault operator raft snapshot inspect + + Inspects a snapshot file. 
+ + $ vault operator raft snapshot inspect raft.snap + + ` + c.Flags().Help() + + return strings.TrimSpace(helpText) +} + +func (c *OperatorRaftSnapshotInspectCommand) Flags() *FlagSets { + set := c.flagSet(FlagSetHTTP | FlagSetOutputFormat) + f := set.NewFlagSet("Command Options") + + f.BoolVar(&BoolVar{ + Name: "details", + Target: &c.details, + Default: true, + Usage: "Provides information about usage for data stored in the snapshot.", + }) + + f.IntVar(&IntVar{ + Name: "depth", + Target: &c.depth, + Default: 2, + Usage: "Can only be used with -details. The key prefix depth used to breakdown KV store data. If set to 0, all keys will be returned. Defaults to 2.", + }) + + f.StringVar(&StringVar{ + Name: "filter", + Target: &c.filter, + Default: "", + Usage: "Can only be used with -details. Limits the key breakdown using this prefix filter.", + }) + + return set +} + +func (c *OperatorRaftSnapshotInspectCommand) AutocompleteArgs() complete.Predictor { + return complete.PredictAnything +} + +func (c *OperatorRaftSnapshotInspectCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +type OutputFormat struct { + Meta *MetadataInfo + StatsKV []typeStats + TotalCountKV int + TotalSizeKV int +} + +// SnapshotInfo is used for passing snapshot stat +// information between functions +type SnapshotInfo struct { + Meta MetadataInfo + StatsKV map[string]typeStats + TotalCountKV int + TotalSizeKV int +} + +type MetadataInfo struct { + ID string + Size int64 + Index uint64 + Term uint64 + Version raft.SnapshotVersion +} + +type typeStats struct { + Name string + Count int + Size int +} + +func (c *OperatorRaftSnapshotInspectCommand) Run(args []string) int { + flags := c.Flags() + + if err := flags.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + // Validate flags + if c.depth < 0 { + c.UI.Error("Depth must be equal to or greater than 0") + return 1 + } + + var file string + args = c.flags.Args() + + switch len(args) { + case 0: + c.UI.Error("Missing FILE argument") + return 1 + case 1: + file = args[0] + default: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args))) + return 1 + } + + // Open the file. + f, err := os.Open(file) + if err != nil { + c.UI.Error(fmt.Sprintf("Error opening snapshot file: %s", err)) + return 1 + } + defer f.Close() + + // Extract metadata and snapshot info from snapshot file + var info *SnapshotInfo + var meta *raft.SnapshotMeta + info, meta, err = c.Read(hclog.New(nil), f) + if err != nil { + c.UI.Error(fmt.Sprintf("Error reading snapshot: %s", err)) + return 1 + } + + if info == nil { + c.UI.Error(fmt.Sprintf("Error calculating snapshot info: %s", err)) + return 1 + } + + // Generate structs for the formatter with information we read in + metaformat := &MetadataInfo{ + ID: meta.ID, + Size: meta.Size, + Index: meta.Index, + Term: meta.Term, + Version: meta.Version, + } + + formattedStatsKV := generateKVStats(*info) + + data := &OutputFormat{ + Meta: metaformat, + StatsKV: formattedStatsKV, + TotalCountKV: info.TotalCountKV, + TotalSizeKV: info.TotalSizeKV, + } + + if Format(c.UI) != "table" { + return OutputData(c.UI, data) + } + + tableData, err := formatTable(data) + if err != nil { + c.UI.Error(err.Error()) + return 1 + } + + c.UI.Output(tableData) + + return 0 +} + +func (c *OperatorRaftSnapshotInspectCommand) kvEnhance(val *pb.StorageEntry, info *SnapshotInfo, read int) { + if !c.details { + return + } + + if val.Key == "" { + return + } + + // check for whether a filter is specified. 
if it is, skip + // any keys that don't match. + if len(c.filter) > 0 && !strings.HasPrefix(val.Key, c.filter) { + return + } + + split := strings.Split(val.Key, "/") + + // handle the situation where the key is shorter than + // the specified depth. + actualDepth := c.depth + if c.depth == 0 || c.depth > len(split) { + actualDepth = len(split) + } + + prefix := strings.Join(split[0:actualDepth], "/") + kvs := info.StatsKV[prefix] + if kvs.Name == "" { + kvs.Name = prefix + } + + kvs.Count++ + kvs.Size += read + info.TotalCountKV++ + info.TotalSizeKV += read + info.StatsKV[prefix] = kvs +} + +// Read from snapshot's state.bin and update the SnapshotInfo struct +func (c *OperatorRaftSnapshotInspectCommand) parseState(r io.Reader) (SnapshotInfo, error) { + info := SnapshotInfo{ + StatsKV: make(map[string]typeStats), + } + + protoReader := protoio.NewDelimitedReader(r, math.MaxInt32) + + for { + s := new(pb.StorageEntry) + if err := protoReader.ReadMsg(s); err != nil { + if err == io.EOF { + break + } + return info, err + } + size := protoReader.GetLastReadSize() + c.kvEnhance(s, &info, size) + } + + return info, nil +} + +// Read contents of snapshot. Parse metadata and snapshot info +// Also, verify validity of snapshot +func (c *OperatorRaftSnapshotInspectCommand) Read(logger hclog.Logger, in io.Reader) (*SnapshotInfo, *raft.SnapshotMeta, error) { + // Wrap the reader in a gzip decompressor. + decomp, err := gzip.NewReader(in) + if err != nil { + return nil, nil, fmt.Errorf("failed to decompress snapshot: %v", err) + } + + defer func() { + if decomp == nil { + return + } + + if err := decomp.Close(); err != nil { + logger.Error("Failed to close snapshot decompressor", "error", err) + } + }() + + // Read the archive. + snapshotInfo, metadata, err := c.read(decomp) + if err != nil { + return nil, nil, fmt.Errorf("failed to read snapshot file: %v", err) + } + + if err := concludeGzipRead(decomp); err != nil { + return nil, nil, err + } + + if err := decomp.Close(); err != nil { + return nil, nil, err + } + decomp = nil + return snapshotInfo, metadata, nil +} + +func formatTable(info *OutputFormat) (string, error) { + var b bytes.Buffer + tw := tabwriter.NewWriter(&b, 8, 8, 6, ' ', 0) + + fmt.Fprintf(tw, " ID\t%s", info.Meta.ID) + fmt.Fprintf(tw, "\n Size\t%d", info.Meta.Size) + fmt.Fprintf(tw, "\n Index\t%d", info.Meta.Index) + fmt.Fprintf(tw, "\n Term\t%d", info.Meta.Term) + fmt.Fprintf(tw, "\n Version\t%d", info.Meta.Version) + fmt.Fprintf(tw, "\n") + + if info.StatsKV != nil { + fmt.Fprintf(tw, "\n") + fmt.Fprintln(tw, "\n Key Name\tCount\tSize") + fmt.Fprintf(tw, " %s\t%s\t%s", "----", "----", "----") + + for _, s := range info.StatsKV { + fmt.Fprintf(tw, "\n %s\t%d\t%s", s.Name, s.Count, ByteSize(uint64(s.Size))) + } + + fmt.Fprintf(tw, "\n %s\t%s", "----", "----") + fmt.Fprintf(tw, "\n Total Size\t\t%s", ByteSize(uint64(info.TotalSizeKV))) + } + + if err := tw.Flush(); err != nil { + return b.String(), err + } + + return b.String(), nil +} + +const ( + BYTE = 1 << (10 * iota) + KILOBYTE + MEGABYTE + GIGABYTE + TERABYTE +) + +func ByteSize(bytes uint64) string { + unit := "" + value := float64(bytes) + + switch { + case bytes >= TERABYTE: + unit = "TB" + value = value / TERABYTE + case bytes >= GIGABYTE: + unit = "GB" + value = value / GIGABYTE + case bytes >= MEGABYTE: + unit = "MB" + value = value / MEGABYTE + case bytes >= KILOBYTE: + unit = "KB" + value = value / KILOBYTE + case bytes >= BYTE: + unit = "B" + case bytes == 0: + return "0" + } + + result := strconv.FormatFloat(value, 
'f', 1, 64) + result = strings.TrimSuffix(result, ".0") + return result + unit +} + +// sortTypeStats sorts the stat slice by count and then +// alphabetically in the case the counts are equal +func sortTypeStats(stats []typeStats) []typeStats { + // sort alphabetically if size is equal + sort.Slice(stats, func(i, j int) bool { + // Sort alphabetically if count is equal + if stats[i].Count == stats[j].Count { + return stats[i].Name < stats[j].Name + } + return stats[i].Count > stats[j].Count + }) + + return stats +} + +// generateKVStats reformats the KV stats to work with +// the output struct that's used to produce the printed +// output the user sees. +func generateKVStats(info SnapshotInfo) []typeStats { + kvLen := len(info.StatsKV) + if kvLen > 0 { + ks := make([]typeStats, 0, kvLen) + + for _, s := range info.StatsKV { + ks = append(ks, s) + } + + ks = sortTypeStats(ks) + + return ks + } + + return nil +} + +// hashList manages a list of filenames and their hashes. +type hashList struct { + hashes map[string]hash.Hash +} + +// newHashList returns a new hashList. +func newHashList() *hashList { + return &hashList{ + hashes: make(map[string]hash.Hash), + } +} + +// Add creates a new hash for the given file. +func (hl *hashList) Add(file string) hash.Hash { + if existing, ok := hl.hashes[file]; ok { + return existing + } + + h := sha256.New() + hl.hashes[file] = h + return h +} + +// Encode takes the current sum of all the hashes and saves the hash list as a +// SHA256SUMS-style text file. +func (hl *hashList) Encode(w io.Writer) error { + for file, h := range hl.hashes { + if _, err := fmt.Fprintf(w, "%x %s\n", h.Sum([]byte{}), file); err != nil { + return err + } + } + return nil +} + +// DecodeAndVerify reads a SHA256SUMS-style text file and checks the results +// against the current sums for all the hashes. +func (hl *hashList) DecodeAndVerify(r io.Reader) error { + // Read the file and make sure everything in there has a matching hash. + seen := make(map[string]struct{}) + s := bufio.NewScanner(r) + for s.Scan() { + sha := make([]byte, sha256.Size) + var file string + if _, err := fmt.Sscanf(s.Text(), "%x %s", &sha, &file); err != nil { + return err + } + + h, ok := hl.hashes[file] + if !ok { + return fmt.Errorf("list missing hash for %q", file) + } + if !bytes.Equal(sha, h.Sum([]byte{})) { + return fmt.Errorf("hash check failed for %q", file) + } + seen[file] = struct{}{} + } + if err := s.Err(); err != nil { + return err + } + + // Make sure everything we had a hash for was seen. + for file := range hl.hashes { + if _, ok := seen[file]; !ok { + return fmt.Errorf("file missing for %q", file) + } + } + + return nil +} + +// read takes a reader and extracts the snapshot metadata and snapshot +// info. It also checks the integrity of the snapshot data. +func (c *OperatorRaftSnapshotInspectCommand) read(in io.Reader) (*SnapshotInfo, *raft.SnapshotMeta, error) { + // Start a new tar reader. + archive := tar.NewReader(in) + + // Create a hash list that we will use to compare with the SHA256SUMS + // file in the archive. + hl := newHashList() + + // Populate the hashes for all the files we expect to see. The check at + // the end will make sure these are all present in the SHA256SUMS file + // and that the hashes match. + metaHash := hl.Add("meta.json") + snapHash := hl.Add("state.bin") + + // Look through the archive for the pieces we care about. 
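+	// The archive written by the raft backend is expected to contain:
+	//   - meta.json: JSON-encoded raft.SnapshotMeta
+	//   - state.bin: a stream of length-delimited protobuf StorageEntry messages
+	//   - SHA256SUMS: hashes of the files above, verified at the end
+	//   - SHA256SUMS.sealed: sealed hashes (not verified here yet)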
+ var shaBuffer bytes.Buffer + var snapshotInfo SnapshotInfo + var metadata raft.SnapshotMeta + for { + hdr, err := archive.Next() + if err == io.EOF { + break + } + if err != nil { + return nil, nil, fmt.Errorf("failed reading snapshot: %v", err) + } + + switch hdr.Name { + case "meta.json": + // Previously we used json.Decode to decode the archive stream. There are + // edgecases in which it doesn't read all the bytes from the stream, even + // though the json object is still being parsed properly. Since we + // simultaneously feeded everything to metaHash, our hash ended up being + // different than what we calculated when creating the snapshot. Which in + // turn made the snapshot verification fail. By explicitly reading the + // whole thing first we ensure that we calculate the correct hash + // independent of how json.Decode works internally. + buf, err := io.ReadAll(io.TeeReader(archive, metaHash)) + if err != nil { + return nil, nil, fmt.Errorf("failed to read snapshot metadata: %v", err) + } + if err := json.Unmarshal(buf, &metadata); err != nil { + return nil, nil, fmt.Errorf("failed to decode snapshot metadata: %v", err) + } + case "state.bin": + // create reader that writes to snapHash what it reads from archive + wrappedReader := io.TeeReader(archive, snapHash) + var err error + snapshotInfo, err = c.parseState(wrappedReader) + if err != nil { + return nil, nil, fmt.Errorf("error parsing snapshot state: %v", err) + } + + case "SHA256SUMS": + if _, err := io.CopyN(&shaBuffer, archive, 10000); err != nil && err != io.EOF { + return nil, nil, fmt.Errorf("failed to read snapshot hashes: %v", err) + } + + case "SHA256SUMS.sealed": + // Add verification of sealed sum in future + continue + + default: + return nil, nil, fmt.Errorf("unexpected file %q in snapshot", hdr.Name) + } + } + + // Verify all the hashes. + if err := hl.DecodeAndVerify(&shaBuffer); err != nil { + return nil, nil, fmt.Errorf("failed checking integrity of snapshot: %v", err) + } + + return &snapshotInfo, &metadata, nil +} + +// concludeGzipRead should be invoked after you think you've consumed all of +// the data from the gzip stream. It will error if the stream was corrupt. +// +// The docs for gzip.Reader say: "Clients should treat data returned by Read as +// tentative until they receive the io.EOF marking the end of the data." +func concludeGzipRead(decomp *gzip.Reader) error { + extra, err := io.ReadAll(decomp) // ReadAll consumes the EOF + if err != nil { + return err + } + if len(extra) != 0 { + return fmt.Errorf("%d unread uncompressed bytes remain", len(extra)) + } + return nil +} diff --git a/command/operator_raft_snapshot_inspect_test.go b/command/operator_raft_snapshot_inspect_test.go new file mode 100644 index 000000000..de306595e --- /dev/null +++ b/command/operator_raft_snapshot_inspect_test.go @@ -0,0 +1,141 @@ +// Copyright (c) HashiCorp, Inc. 
+// SPDX-License-Identifier: BUSL-1.1 + +package command + +import ( + "context" + "fmt" + "os" + "strings" + "testing" + + "github.com/hashicorp/vault/physical/raft" + "github.com/hashicorp/vault/sdk/physical" + "github.com/mitchellh/cli" +) + +func testOperatorRaftSnapshotInspectCommand(tb testing.TB) (*cli.MockUi, *OperatorRaftSnapshotInspectCommand) { + tb.Helper() + + ui := cli.NewMockUi() + return ui, &OperatorRaftSnapshotInspectCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func createSnapshot(tb testing.TB) (*os.File, func(), error) { + // Create new raft backend + r, raftDir := raft.GetRaft(tb, true, false) + defer os.RemoveAll(raftDir) + + // Write some data + for i := 0; i < 100; i++ { + err := r.Put(context.Background(), &physical.Entry{ + Key: fmt.Sprintf("key-%d", i), + Value: []byte(fmt.Sprintf("value-%d", i)), + }) + if err != nil { + return nil, nil, fmt.Errorf("Error adding data to snapshot %s", err) + } + } + + // Create temporary file to save snapshot to + snap, err := os.CreateTemp("", "temp_snapshot.snap") + if err != nil { + return nil, nil, fmt.Errorf("Error creating temporary file %s", err) + } + + cleanup := func() { + err := os.RemoveAll(snap.Name()) + if err != nil { + tb.Errorf("Error deleting temporary snapshot %s", err) + } + } + + // Save snapshot + err = r.Snapshot(snap, nil) + if err != nil { + return nil, nil, fmt.Errorf("Error saving raft snapshot %s", err) + } + + return snap, cleanup, nil +} + +func TestOperatorRaftSnapshotInspectCommand_Run(t *testing.T) { + t.Parallel() + + file1, cleanup1, err := createSnapshot(t) + if err != nil { + t.Fatalf("Error creating snapshot %s", err) + } + + file2, cleanup2, err := createSnapshot(t) + if err != nil { + t.Fatalf("Error creating snapshot %s", err) + } + + cases := []struct { + name string + args []string + out string + code int + cleanup func() + }{ + { + "too_many_args", + []string{"test.snap", "test"}, + "Too many arguments", + 1, + nil, + }, + { + "default", + []string{file1.Name()}, + "ID bolt-snapshot", + 0, + cleanup1, + }, + { + "all_flags", + []string{"-details", "-depth", "10", "-filter", "key", file2.Name()}, + "Key Name", + 0, + cleanup2, + }, + } + + t.Run("validations", func(t *testing.T) { + t.Parallel() + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testOperatorRaftSnapshotInspectCommand(t) + + cmd.client = client + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + + if tc.cleanup != nil { + tc.cleanup() + } + }) + } + }) +} diff --git a/physical/raft/io.go b/physical/raft/io.go index d3d3d4b4c..98f96bc97 100644 --- a/physical/raft/io.go +++ b/physical/raft/io.go @@ -45,6 +45,7 @@ type WriteCloser interface { type Reader interface { ReadMsg(msg proto.Message) error + GetLastReadSize() int } type ReadCloser interface { diff --git a/physical/raft/varint.go b/physical/raft/varint.go index b3b9bfaae..87f59eaa7 100644 --- a/physical/raft/varint.go +++ b/physical/raft/varint.go @@ -79,14 +79,19 @@ func NewDelimitedReader(r io.Reader, maxSize int) ReadCloser { if c, ok := r.(io.Closer); ok { closer = c } - return &varintReader{bufio.NewReader(r), nil, maxSize, closer} + return &varintReader{bufio.NewReader(r), nil, maxSize, closer, 0} } type 
varintReader struct { - r *bufio.Reader - buf []byte - maxSize int - closer io.Closer + r *bufio.Reader + buf []byte + maxSize int + closer io.Closer + lastReadSize int +} + +func (this *varintReader) GetLastReadSize() int { + return this.lastReadSize } func (this *varintReader) ReadMsg(msg proto.Message) error { @@ -102,9 +107,11 @@ func (this *varintReader) ReadMsg(msg proto.Message) error { this.buf = make([]byte, length) } buf := this.buf[:length] - if _, err := io.ReadFull(this.r, buf); err != nil { + size, err := io.ReadFull(this.r, buf) + if err != nil { return err } + this.lastReadSize = size return proto.Unmarshal(buf, msg) }
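
The io.go and varint.go changes above exist so a caller can find out how many bytes the previous ReadMsg consumed; the inspect command uses that figure to attribute storage to key prefixes. Below is a minimal sketch of that usage outside the command: sizeByKey and the standalone main are illustrative names, not code from this change, and it assumes state.bin has already been extracted from the gzip-compressed, tar-formatted snapshot file (the real command does that inside read()).

```go
package main

import (
	"fmt"
	"io"
	"math"
	"os"

	protoio "github.com/hashicorp/vault/physical/raft"
	"github.com/hashicorp/vault/sdk/plugin/pb"
)

// sizeByKey reads length-delimited StorageEntry messages from r and returns
// the serialized size of each entry, keyed by the entry's storage key.
func sizeByKey(r io.Reader) (map[string]int, error) {
	sizes := make(map[string]int)
	pr := protoio.NewDelimitedReader(r, math.MaxInt32)

	for {
		entry := new(pb.StorageEntry)
		if err := pr.ReadMsg(entry); err != nil {
			if err == io.EOF {
				return sizes, nil
			}
			return nil, err
		}
		// GetLastReadSize reports how many bytes the previous ReadMsg call
		// consumed for the message body; parseState feeds the same figure to
		// kvEnhance to build the per-prefix breakdown.
		sizes[entry.Key] += pr.GetLastReadSize()
	}
}

func main() {
	// Assumes state.bin was already pulled out of the snapshot archive.
	f, err := os.Open("state.bin")
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	defer f.Close()

	sizes, err := sizeByKey(f)
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	for key, n := range sizes {
		fmt.Printf("%s\t%d\n", key, n)
	}
}
```

In the command itself the same loop lives in parseState, and kvEnhance groups the sizes by key prefix up to -depth segments, which is what the Key Name/Count/Size table is built from.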