From 8e343caeda4358ee309f7dfe9702b448dc81ba15 Mon Sep 17 00:00:00 2001 From: Seth Vargo Date: Mon, 4 Sep 2017 23:59:03 -0400 Subject: [PATCH] Update audit-enable command --- command/audit_enable.go | 203 +++++++++++++++++++---------------- command/audit_enable_test.go | 178 +++++++++++++++++++++++------- 2 files changed, 249 insertions(+), 132 deletions(-) diff --git a/command/audit_enable.go b/command/audit_enable.go index 680a94ed1..61fed2048 100644 --- a/command/audit_enable.go +++ b/command/audit_enable.go @@ -7,97 +7,34 @@ import ( "strings" "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/helper/kv-builder" - "github.com/hashicorp/vault/meta" - "github.com/mitchellh/mapstructure" + "github.com/mitchellh/cli" "github.com/posener/complete" ) +// Ensure we are implementing the right interfaces. +var _ cli.Command = (*AuditEnableCommand)(nil) +var _ cli.CommandAutocomplete = (*AuditEnableCommand)(nil) + // AuditEnableCommand is a Command that mounts a new mount. type AuditEnableCommand struct { - meta.Meta + *BaseCommand - // A test stdin that can be used for tests - testStdin io.Reader -} + flagDescription string + flagPath string + flagLocal bool -func (c *AuditEnableCommand) Run(args []string) int { - var desc, path string - var local bool - flags := c.Meta.FlagSet("audit-enable", meta.FlagSetDefault) - flags.StringVar(&desc, "description", "", "") - flags.StringVar(&path, "path", "", "") - flags.BoolVar(&local, "local", false, "") - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - args = flags.Args() - if len(args) < 1 { - flags.Usage() - c.Ui.Error(fmt.Sprintf( - "\naudit-enable expects at least one argument: the type to enable")) - return 1 - } - - auditType := args[0] - if path == "" { - path = auditType - } - - // Build the options - var stdin io.Reader = os.Stdin - if c.testStdin != nil { - stdin = c.testStdin - } - builder := &kvbuilder.Builder{Stdin: stdin} - if err := builder.Add(args[1:]...); err != nil { - c.Ui.Error(fmt.Sprintf( - "Error parsing options: %s", err)) - return 1 - } - - var opts map[string]string - if err := mapstructure.WeakDecode(builder.Map(), &opts); err != nil { - c.Ui.Error(fmt.Sprintf( - "Error parsing options: %s", err)) - return 1 - } - - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) - return 1 - } - - err = client.Sys().EnableAuditWithOptions(path, &api.EnableAuditOptions{ - Type: auditType, - Description: desc, - Options: opts, - Local: local, - }) - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error enabling audit backend: %s", err)) - return 1 - } - - c.Ui.Output(fmt.Sprintf( - "Successfully enabled audit backend '%s' with path '%s'!", auditType, path)) - return 0 + testStdin io.Reader // For tests } func (c *AuditEnableCommand) Synopsis() string { - return "Enable an audit backend" + return "Enables an audit backend" } func (c *AuditEnableCommand) Help() string { helpText := ` -Usage: vault audit-enable [options] type [config...] +Usage: vault audit-enable [options] TYPE [CONFIG K=V...] - Enable an audit backend. + Enables an audit backend at a given path. This command enables an audit backend of type "type". Additional options for configuring the audit backend can be specified after the @@ -111,24 +48,49 @@ Usage: vault audit-enable [options] type [config...] For information on available configuration options, please see the documentation. -General Options: -` + meta.GeneralOptionsUsage() + ` -Audit Enable Options: +` + c.Flags().Help() - -description= A human-friendly description for the backend. This - shows up only when querying the enabled backends. - - -path= Specify a unique path for this audit backend. This - is purely for referencing this audit backend. By - default this will be the backend type. - - -local Mark the mount as a local mount. Local mounts - are not replicated nor (if a secondary) - removed by replication. -` return strings.TrimSpace(helpText) } +func (c *AuditEnableCommand) Flags() *FlagSets { + set := c.flagSet(FlagSetHTTP) + + f := set.NewFlagSet("Command Options") + + f.StringVar(&StringVar{ + Name: "description", + Target: &c.flagDescription, + Default: "", + EnvVar: "", + Completion: complete.PredictAnything, + Usage: "Human-friendly description for the purpose of this audit " + + "backend.", + }) + + f.StringVar(&StringVar{ + Name: "path", + Target: &c.flagPath, + Default: "", // The default is complex, so we have to manually document + EnvVar: "", + Completion: complete.PredictAnything, + Usage: "Place where the audit backend will be accessible. This must be " + + "unique across all audit backends. This defaults to the \"type\" of the " + + "audit backend.", + }) + + f.BoolVar(&BoolVar{ + Name: "local", + Target: &c.flagLocal, + Default: false, + EnvVar: "", + Usage: "Mark the audit backend as a local-only backned. Local backends " + + "are not replicated nor removed by replication.", + }) + + return set +} + func (c *AuditEnableCommand) AutocompleteArgs() complete.Predictor { return complete.PredictSet( "file", @@ -138,9 +100,60 @@ func (c *AuditEnableCommand) AutocompleteArgs() complete.Predictor { } func (c *AuditEnableCommand) AutocompleteFlags() complete.Flags { - return complete.Flags{ - "-description": complete.PredictNothing, - "-path": complete.PredictNothing, - "-local": complete.PredictNothing, - } + return c.Flags().Completions() +} + +func (c *AuditEnableCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + if len(args) < 1 { + c.UI.Error("Missing TYPE!") + return 1 + } + + // Grab the type + auditType := strings.TrimSpace(args[0]) + + auditPath := c.flagPath + if auditPath == "" { + auditPath = auditType + } + auditPath = ensureTrailingSlash(auditPath) + + // Pull our fake stdin if needed + stdin := (io.Reader)(os.Stdin) + if c.testStdin != nil { + stdin = c.testStdin + } + + options, err := parseArgsDataString(stdin, args[1:]) + if err != nil { + c.UI.Error(fmt.Sprintf("Failed to parse K=V data: %s", err)) + return 1 + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + if err := client.Sys().EnableAuditWithOptions(auditPath, &api.EnableAuditOptions{ + Type: auditType, + Description: c.flagDescription, + Options: options, + Local: c.flagLocal, + }); err != nil { + c.UI.Error(fmt.Sprintf("Error enabling audit backend: %s", err)) + return 2 + } + + c.UI.Output(fmt.Sprintf("Success! Enabled the %s audit backend at: %s", auditType, auditPath)) + return 0 } diff --git a/command/audit_enable_test.go b/command/audit_enable_test.go index 118f103d3..6be5c5c68 100644 --- a/command/audit_enable_test.go +++ b/command/audit_enable_test.go @@ -1,56 +1,160 @@ package command import ( - "reflect" + "strings" "testing" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" "github.com/mitchellh/cli" ) -func TestAuditEnable(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func testAuditEnableCommand(tb testing.TB) (*cli.MockUi, *AuditEnableCommand) { + tb.Helper() - ui := new(cli.MockUi) - c := &AuditEnableCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + ui := cli.NewMockUi() + return ui, &AuditEnableCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestAuditEnableCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "empty", + nil, + "Missing TYPE!", + 1, + }, + { + "not_a_valid_type", + []string{"nope_definitely_not_a_valid_type_like_ever"}, + "", + 2, + }, + { + "enable", + []string{"file", "file_path=discard"}, + "Success! Enabled the file audit backend at: file/", + 0, + }, + { + "enable_path", + []string{ + "-path", "audit_path", + "file", + "file_path=discard", + }, + "Success! Enabled the file audit backend at: audit_path/", + 0, }, } - args := []string{ - "-address", addr, - "noop", - "foo=bar", + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testAuditEnableCommand(t) + cmd.client = client + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } + t.Run("integration", func(t *testing.T) { + t.Parallel() - // Get the client - client, err := c.Client() - if err != nil { - t.Fatalf("err: %#v", err) - } + client, closer := testVaultServer(t) + defer closer() - audits, err := client.Sys().ListAudit() - if err != nil { - t.Fatalf("err: %#v", err) - } + ui, cmd := testAuditEnableCommand(t) + cmd.client = client - audit, ok := audits["noop/"] - if !ok { - t.Fatalf("err: %#v", audits) - } + code := cmd.Run([]string{ + "-path", "audit_enable_integration/", + "-description", "The best kind of test", + "file", + "file_path=discard", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } - expected := map[string]string{"foo": "bar"} - if !reflect.DeepEqual(audit.Options, expected) { - t.Fatalf("err: %#v", audit) - } + expected := "Success! Enabled the file audit backend at: audit_enable_integration/" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + audits, err := client.Sys().ListAudit() + if err != nil { + t.Fatal(err) + } + + auditInfo, ok := audits["audit_enable_integration/"] + if !ok { + t.Fatalf("expected audit to exist") + } + if exp := "file"; auditInfo.Type != exp { + t.Errorf("expected %q to be %q", auditInfo.Type, exp) + } + if exp := "The best kind of test"; auditInfo.Description != exp { + t.Errorf("expected %q to be %q", auditInfo.Description, exp) + } + + filePath, ok := auditInfo.Options["file_path"] + if !ok || filePath != "discard" { + t.Errorf("missing some options: %#v", auditInfo) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testAuditEnableCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "pki", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error enabling audit backend: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testAuditEnableCommand(t) + assertNoTabs(t, cmd) + }) }