From 48e6460da5aae4cc31de69ce9277d71be7424327 Mon Sep 17 00:00:00 2001 From: Seth Vargo Date: Tue, 5 Sep 2017 00:04:12 -0400 Subject: [PATCH] Update rotate command --- command/rotate.go | 121 +++++++++++++++++++++++++--------------- command/rotate_test.go | 123 +++++++++++++++++++++++++++++++++++------ 2 files changed, 181 insertions(+), 63 deletions(-) diff --git a/command/rotate.go b/command/rotate.go index 9da387370..eed0bc40e 100644 --- a/command/rotate.go +++ b/command/rotate.go @@ -4,64 +4,95 @@ import ( "fmt" "strings" - "github.com/hashicorp/vault/meta" + "github.com/mitchellh/cli" + "github.com/posener/complete" ) +// Ensure we are implementing the right interfaces. +var _ cli.Command = (*RotateCommand)(nil) +var _ cli.CommandAutocomplete = (*RotateCommand)(nil) + // RotateCommand is a Command that rotates the encryption key being used type RotateCommand struct { - meta.Meta -} - -func (c *RotateCommand) Run(args []string) int { - flags := c.Meta.FlagSet("rotate", meta.FlagSetDefault) - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) - return 2 - } - - // Rotate the key - err = client.Sys().Rotate() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error with key rotation: %s", err)) - return 2 - } - - // Print the key status - status, err := client.Sys().KeyStatus() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error reading audits: %s", err)) - return 2 - } - - c.Ui.Output(fmt.Sprintf("Key Term: %d", status.Term)) - c.Ui.Output(fmt.Sprintf("Installation Time: %v", status.InstallTime)) - return 0 + *BaseCommand } func (c *RotateCommand) Synopsis() string { - return "Rotates the backend encryption key used to persist data" + return "Rotates the underlying encryption key" } func (c *RotateCommand) Help() string { helpText := ` Usage: vault rotate [options] - Rotates the backend encryption key which is used to secure data - written to the storage backend. This is done by installing a new key - which encrypts new data, while old keys are still used to decrypt - secrets written previously. This is an online operation and is not - disruptive. + Rotates the underlying encryption key which is used to secure data written + to the storage backend. This installs a new key in the key ring. This new + key is used to encrypted new data, while older keys in the ring are used to + decrypt older data. + + This is an online operation and does not cause downtime. This command is run + per-cluser (not per-server), since Vault servers in HA mode share the same + storeage backend. + + Rotate Vault's encryption key: + + $ vault rotate + + For a full list of examples, please see the documentation. + +` + c.Flags().Help() -General Options: -` + meta.GeneralOptionsUsage() return strings.TrimSpace(helpText) } + +func (c *RotateCommand) Flags() *FlagSets { + return c.flagSet(FlagSetHTTP) +} + +func (c *RotateCommand) AutocompleteArgs() complete.Predictor { + return nil +} + +func (c *RotateCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *RotateCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + if len(args) > 0 { + c.UI.Error(fmt.Sprintf("Too many arguments (expected 0, got %d)", len(args))) + return 1 + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + // Rotate the key + err = client.Sys().Rotate() + if err != nil { + c.UI.Error(fmt.Sprintf("Error rotating key: %s", err)) + return 2 + } + + // Print the key status + status, err := client.Sys().KeyStatus() + if err != nil { + c.UI.Error(fmt.Sprintf("Error reading key status: %s", err)) + return 2 + } + + c.UI.Output("Success! Rotated key") + c.UI.Output("") + c.UI.Output(printKeyStatus(status)) + return 0 +} diff --git a/command/rotate_test.go b/command/rotate_test.go index 257f28007..30790691c 100644 --- a/command/rotate_test.go +++ b/command/rotate_test.go @@ -1,31 +1,118 @@ package command import ( + "strings" "testing" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" "github.com/mitchellh/cli" ) -func TestRotate(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func testRotateCommand(tb testing.TB) (*cli.MockUi, *RotateCommand) { + tb.Helper() - ui := new(cli.MockUi) - c := &RotateCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + ui := cli.NewMockUi() + return ui, &RotateCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestRotateCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "too_many_args", + []string{"abcd1234"}, + "Too many arguments", + 1, }, } - args := []string{ - "-address", addr, - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } + t.Run("validations", func(t *testing.T) { + t.Parallel() + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ui, cmd := testRotateCommand(t) + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + }) + + t.Run("default", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testRotateCommand(t) + cmd.client = client + + code := cmd.Run([]string{}) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Success! Rotated key" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + status, err := client.Sys().KeyStatus() + if err != nil { + t.Fatal(err) + } + if exp := 1; status.Term < exp { + t.Errorf("expected %d to be less than %d", status.Term, exp) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testRotateCommand(t) + cmd.client = client + + code := cmd.Run([]string{}) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error rotating key: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testRotateCommand(t) + assertNoTabs(t, cmd) + }) }