diff --git a/changelog/16553.txt b/changelog/16553.txt new file mode 100644 index 000000000..7031f04a1 --- /dev/null +++ b/changelog/16553.txt @@ -0,0 +1,3 @@ +```release-note:improvement +command: Fix shell completion for KV v2 mounts +``` diff --git a/command/base_predict.go b/command/base_predict.go index 13959bb5b..61cbe092d 100644 --- a/command/base_predict.go +++ b/command/base_predict.go @@ -250,9 +250,19 @@ func (p *Predict) vaultPaths(includeFiles bool) complete.PredictFunc { // Trim path with potential mount var relativePath string - for _, mount := range p.mounts() { + mountInfos, err := p.mountInfos() + if err != nil { + return nil + } + + var mountType, mountVersion string + for mount, mountInfo := range mountInfos { if strings.HasPrefix(path, mount) { relativePath = strings.TrimPrefix(path, mount+"/") + mountType = mountInfo.Type + if mountInfo.Options != nil { + mountVersion = mountInfo.Options["version"] + } break } } @@ -260,7 +270,7 @@ func (p *Predict) vaultPaths(includeFiles bool) complete.PredictFunc { // Predict path or mount depending on path separator var predictions []string if strings.Contains(relativePath, "/") { - predictions = p.paths(path, includeFiles) + predictions = p.paths(mountType, mountVersion, path, includeFiles) } else { predictions = p.filter(p.mounts(), path) } @@ -288,7 +298,7 @@ func (p *Predict) vaultPaths(includeFiles bool) complete.PredictFunc { } // paths predicts all paths which start with the given path. -func (p *Predict) paths(path string, includeFiles bool) []string { +func (p *Predict) paths(mountType, mountVersion, path string, includeFiles bool) []string { client := p.Client() if client == nil { return nil @@ -303,7 +313,7 @@ func (p *Predict) paths(path string, includeFiles bool) []string { root = root[:idx+1] } - paths := p.listPaths(root) + paths := p.listPaths(buildAPIListPath(root, mountType, mountVersion)) var predictions []string for _, p := range paths { @@ -326,6 +336,22 @@ func (p *Predict) paths(path string, includeFiles bool) []string { return predictions } +func buildAPIListPath(path, mountType, mountVersion string) string { + if mountType == "kv" && mountVersion == "2" { + return toKVv2ListPath(path) + } + return path +} + +func toKVv2ListPath(path string) string { + firstSlashIdx := strings.Index(path, "/") + if firstSlashIdx < 0 { + return path + } + + return path[:firstSlashIdx] + "/metadata" + path[firstSlashIdx:] +} + // audits returns a sorted list of the audit backends for Vault server for // which the client is configured to communicate with. func (p *Predict) audits() []string { @@ -421,16 +447,28 @@ func (p *Predict) policies() []string { return policies } +// mountInfos returns a map with mount paths as keys and MountOutputs as values +// for the Vault server which the client is configured to communicate with. +// Returns error if server communication fails. +func (p *Predict) mountInfos() (map[string]*api.MountOutput, error) { + client := p.Client() + if client == nil { + return nil, nil + } + + mounts, err := client.Sys().ListMounts() + if err != nil { + return nil, err + } + + return mounts, nil +} + // mounts returns a sorted list of the mount paths for Vault server for // which the client is configured to communicate with. This function returns the // default list of mounts if an error occurs. func (p *Predict) mounts() []string { - client := p.Client() - if client == nil { - return nil - } - - mounts, err := client.Sys().ListMounts() + mounts, err := p.mountInfos() if err != nil { return defaultPredictVaultMounts } diff --git a/command/base_predict_test.go b/command/base_predict_test.go index 1cd245765..644a36667 100644 --- a/command/base_predict_test.go +++ b/command/base_predict_test.go @@ -554,7 +554,80 @@ func TestPredict_Paths(t *testing.T) { p := NewPredict() p.client = client - act := p.paths(tc.path, tc.includeFiles) + act := p.paths("kv", "1", tc.path, tc.includeFiles) + if !reflect.DeepEqual(act, tc.exp) { + t.Errorf("expected %q to be %q", act, tc.exp) + } + }) + } + }) +} + +func TestPredict_PathsKVv2(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerWithKVVersion(t, "2") + defer closer() + + data := map[string]interface{}{"data": map[string]interface{}{"a": "b"}} + if _, err := client.Logical().Write("secret/data/bar", data); err != nil { + t.Fatal(err) + } + if _, err := client.Logical().Write("secret/data/foo", data); err != nil { + t.Fatal(err) + } + if _, err := client.Logical().Write("secret/data/zip/zap", data); err != nil { + t.Fatal(err) + } + + cases := []struct { + name string + path string + includeFiles bool + exp []string + }{ + { + "bad_path", + "nope/not/a/real/path/ever", + true, + []string{"nope/not/a/real/path/ever"}, + }, + { + "good_path", + "secret/", + true, + []string{"secret/bar", "secret/foo", "secret/zip/"}, + }, + { + "good_path_no_files", + "secret/", + false, + []string{"secret/zip/"}, + }, + { + "partial_match", + "secret/z", + true, + []string{"secret/zip/"}, + }, + { + "partial_match_no_files", + "secret/z", + false, + []string{"secret/zip/"}, + }, + } + + t.Run("group", func(t *testing.T) { + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + p := NewPredict() + p.client = client + + act := p.paths("kv", "2", tc.path, tc.includeFiles) if !reflect.DeepEqual(act, tc.exp) { t.Errorf("expected %q to be %q", act, tc.exp) } diff --git a/command/command_test.go b/command/command_test.go index f3249f2b5..f4b8c7fd2 100644 --- a/command/command_test.go +++ b/command/command_test.go @@ -67,6 +67,13 @@ func testVaultServer(tb testing.TB) (*api.Client, func()) { return client, closer } +func testVaultServerWithKVVersion(tb testing.TB, kvVersion string) (*api.Client, func()) { + tb.Helper() + + client, _, closer := testVaultServerUnsealWithKVVersion(tb, kvVersion) + return client, closer +} + func testVaultServerAllBackends(tb testing.TB) (*api.Client, func()) { tb.Helper() @@ -85,6 +92,10 @@ func testVaultServerAllBackends(tb testing.TB) (*api.Client, func()) { // testVaultServerUnseal creates a test vault cluster and returns a configured // API client, list of unseal keys (as strings), and a closer function. func testVaultServerUnseal(tb testing.TB) (*api.Client, []string, func()) { + return testVaultServerUnsealWithKVVersion(tb, "1") +} + +func testVaultServerUnsealWithKVVersion(tb testing.TB, kvVersion string) (*api.Client, []string, func()) { tb.Helper() logger := log.NewInterceptLogger(&log.LoggerOptions{ Output: log.DefaultOutput, @@ -92,7 +103,7 @@ func testVaultServerUnseal(tb testing.TB) (*api.Client, []string, func()) { JSONFormat: logging.ParseEnvLogFormat() == logging.JSONFormat, }) - return testVaultServerCoreConfig(tb, &vault.CoreConfig{ + return testVaultServerCoreConfigWithOpts(tb, &vault.CoreConfig{ DisableMlock: true, DisableCache: true, Logger: logger, @@ -100,6 +111,10 @@ func testVaultServerUnseal(tb testing.TB) (*api.Client, []string, func()) { AuditBackends: defaultVaultAuditBackends, LogicalBackends: defaultVaultLogicalBackends, BuiltinRegistry: builtinplugins.Registry, + }, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + NumCores: 1, + KVVersion: kvVersion, }) } @@ -121,15 +136,19 @@ func testVaultServerPluginDir(tb testing.TB, pluginDir string) (*api.Client, []s }) } -// testVaultServerCoreConfig creates a new vault cluster with the given core -// configuration. This is a lower-level test helper. func testVaultServerCoreConfig(tb testing.TB, coreConfig *vault.CoreConfig) (*api.Client, []string, func()) { - tb.Helper() - - cluster := vault.NewTestCluster(benchhelpers.TBtoT(tb), coreConfig, &vault.TestClusterOptions{ + return testVaultServerCoreConfigWithOpts(tb, coreConfig, &vault.TestClusterOptions{ HandlerFunc: vaulthttp.Handler, NumCores: 1, // Default is 3, but we don't need that many }) +} + +// testVaultServerCoreConfig creates a new vault cluster with the given core +// configuration. This is a lower-level test helper. +func testVaultServerCoreConfigWithOpts(tb testing.TB, coreConfig *vault.CoreConfig, opts *vault.TestClusterOptions) (*api.Client, []string, func()) { + tb.Helper() + + cluster := vault.NewTestCluster(benchhelpers.TBtoT(tb), coreConfig, opts) cluster.Start() // Make it easy to get access to the active diff --git a/command/kv_get.go b/command/kv_get.go index 3aca29d20..391efaf88 100644 --- a/command/kv_get.go +++ b/command/kv_get.go @@ -80,7 +80,7 @@ func (c *KVGetCommand) Flags() *FlagSets { } func (c *KVGetCommand) AutocompleteArgs() complete.Predictor { - return nil + return c.PredictVaultFiles() } func (c *KVGetCommand) AutocompleteFlags() complete.Flags { diff --git a/command/kv_patch.go b/command/kv_patch.go index 1134248f5..5368d9e0d 100644 --- a/command/kv_patch.go +++ b/command/kv_patch.go @@ -120,7 +120,7 @@ func (c *KVPatchCommand) Flags() *FlagSets { } func (c *KVPatchCommand) AutocompleteArgs() complete.Predictor { - return nil + return c.PredictVaultFiles() } func (c *KVPatchCommand) AutocompleteFlags() complete.Flags { diff --git a/command/kv_put.go b/command/kv_put.go index f380e2d64..fe991711a 100644 --- a/command/kv_put.go +++ b/command/kv_put.go @@ -96,7 +96,7 @@ func (c *KVPutCommand) Flags() *FlagSets { } func (c *KVPutCommand) AutocompleteArgs() complete.Predictor { - return nil + return c.PredictVaultFolders() } func (c *KVPutCommand) AutocompleteFlags() complete.Flags { diff --git a/vault/testing.go b/vault/testing.go index 9a0c17c69..665b5c021 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -1214,6 +1214,7 @@ type TestClusterOptions struct { // this stores the vault version that should be used for each core config VersionMap map[int]string RedundancyZoneMap map[int]string + KVVersion string } var DefaultNumCores = 3 @@ -2090,6 +2091,11 @@ func (tc *TestCluster) initCores(t testing.T, opts *TestClusterOptions, addAudit TestWaitActive(t, leader.Core) + kvVersion := "1" + if opts != nil { + kvVersion = opts.KVVersion + } + // Existing tests rely on this; we can make a toggle to disable it // later if we want kvReq := &logical.Request{ @@ -2101,7 +2107,7 @@ func (tc *TestCluster) initCores(t testing.T, opts *TestClusterOptions, addAudit "path": "secret/", "description": "key/value secret storage", "options": map[string]string{ - "version": "1", + "version": kvVersion, }, }, }