Make key completion work for both kv-v1 and kv-v2 (#16553)

Co-authored-by: Kieron Browne <kbrowne@vmware.com>
Co-authored-by: Georgi Sabev <georgethebeatle@gmail.com>
Co-authored-by: Danail Branekov <danailster@gmail.com>
This commit is contained in:
georgethebeatle 2022-09-13 19:11:00 +03:00 committed by GitHub
parent fcf6467cbf
commit f9439a9c41
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 160 additions and 21 deletions

3
changelog/16553.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:improvement
command: Fix shell completion for KV v2 mounts
```

View File

@ -250,9 +250,19 @@ func (p *Predict) vaultPaths(includeFiles bool) complete.PredictFunc {
// Trim path with potential mount // Trim path with potential mount
var relativePath string var relativePath string
for _, mount := range p.mounts() { mountInfos, err := p.mountInfos()
if err != nil {
return nil
}
var mountType, mountVersion string
for mount, mountInfo := range mountInfos {
if strings.HasPrefix(path, mount) { if strings.HasPrefix(path, mount) {
relativePath = strings.TrimPrefix(path, mount+"/") relativePath = strings.TrimPrefix(path, mount+"/")
mountType = mountInfo.Type
if mountInfo.Options != nil {
mountVersion = mountInfo.Options["version"]
}
break break
} }
} }
@ -260,7 +270,7 @@ func (p *Predict) vaultPaths(includeFiles bool) complete.PredictFunc {
// Predict path or mount depending on path separator // Predict path or mount depending on path separator
var predictions []string var predictions []string
if strings.Contains(relativePath, "/") { if strings.Contains(relativePath, "/") {
predictions = p.paths(path, includeFiles) predictions = p.paths(mountType, mountVersion, path, includeFiles)
} else { } else {
predictions = p.filter(p.mounts(), path) predictions = p.filter(p.mounts(), path)
} }
@ -288,7 +298,7 @@ func (p *Predict) vaultPaths(includeFiles bool) complete.PredictFunc {
} }
// paths predicts all paths which start with the given path. // paths predicts all paths which start with the given path.
func (p *Predict) paths(path string, includeFiles bool) []string { func (p *Predict) paths(mountType, mountVersion, path string, includeFiles bool) []string {
client := p.Client() client := p.Client()
if client == nil { if client == nil {
return nil return nil
@ -303,7 +313,7 @@ func (p *Predict) paths(path string, includeFiles bool) []string {
root = root[:idx+1] root = root[:idx+1]
} }
paths := p.listPaths(root) paths := p.listPaths(buildAPIListPath(root, mountType, mountVersion))
var predictions []string var predictions []string
for _, p := range paths { for _, p := range paths {
@ -326,6 +336,22 @@ func (p *Predict) paths(path string, includeFiles bool) []string {
return predictions return predictions
} }
func buildAPIListPath(path, mountType, mountVersion string) string {
if mountType == "kv" && mountVersion == "2" {
return toKVv2ListPath(path)
}
return path
}
func toKVv2ListPath(path string) string {
firstSlashIdx := strings.Index(path, "/")
if firstSlashIdx < 0 {
return path
}
return path[:firstSlashIdx] + "/metadata" + path[firstSlashIdx:]
}
// audits returns a sorted list of the audit backends for Vault server for // audits returns a sorted list of the audit backends for Vault server for
// which the client is configured to communicate with. // which the client is configured to communicate with.
func (p *Predict) audits() []string { func (p *Predict) audits() []string {
@ -421,16 +447,28 @@ func (p *Predict) policies() []string {
return policies return policies
} }
// mountInfos returns a map with mount paths as keys and MountOutputs as values
// for the Vault server which the client is configured to communicate with.
// Returns error if server communication fails.
func (p *Predict) mountInfos() (map[string]*api.MountOutput, error) {
client := p.Client()
if client == nil {
return nil, nil
}
mounts, err := client.Sys().ListMounts()
if err != nil {
return nil, err
}
return mounts, nil
}
// mounts returns a sorted list of the mount paths for Vault server for // mounts returns a sorted list of the mount paths for Vault server for
// which the client is configured to communicate with. This function returns the // which the client is configured to communicate with. This function returns the
// default list of mounts if an error occurs. // default list of mounts if an error occurs.
func (p *Predict) mounts() []string { func (p *Predict) mounts() []string {
client := p.Client() mounts, err := p.mountInfos()
if client == nil {
return nil
}
mounts, err := client.Sys().ListMounts()
if err != nil { if err != nil {
return defaultPredictVaultMounts return defaultPredictVaultMounts
} }

View File

@ -554,7 +554,80 @@ func TestPredict_Paths(t *testing.T) {
p := NewPredict() p := NewPredict()
p.client = client p.client = client
act := p.paths(tc.path, tc.includeFiles) act := p.paths("kv", "1", tc.path, tc.includeFiles)
if !reflect.DeepEqual(act, tc.exp) {
t.Errorf("expected %q to be %q", act, tc.exp)
}
})
}
})
}
func TestPredict_PathsKVv2(t *testing.T) {
t.Parallel()
client, closer := testVaultServerWithKVVersion(t, "2")
defer closer()
data := map[string]interface{}{"data": map[string]interface{}{"a": "b"}}
if _, err := client.Logical().Write("secret/data/bar", data); err != nil {
t.Fatal(err)
}
if _, err := client.Logical().Write("secret/data/foo", data); err != nil {
t.Fatal(err)
}
if _, err := client.Logical().Write("secret/data/zip/zap", data); err != nil {
t.Fatal(err)
}
cases := []struct {
name string
path string
includeFiles bool
exp []string
}{
{
"bad_path",
"nope/not/a/real/path/ever",
true,
[]string{"nope/not/a/real/path/ever"},
},
{
"good_path",
"secret/",
true,
[]string{"secret/bar", "secret/foo", "secret/zip/"},
},
{
"good_path_no_files",
"secret/",
false,
[]string{"secret/zip/"},
},
{
"partial_match",
"secret/z",
true,
[]string{"secret/zip/"},
},
{
"partial_match_no_files",
"secret/z",
false,
[]string{"secret/zip/"},
},
}
t.Run("group", func(t *testing.T) {
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
p := NewPredict()
p.client = client
act := p.paths("kv", "2", tc.path, tc.includeFiles)
if !reflect.DeepEqual(act, tc.exp) { if !reflect.DeepEqual(act, tc.exp) {
t.Errorf("expected %q to be %q", act, tc.exp) t.Errorf("expected %q to be %q", act, tc.exp)
} }

View File

@ -67,6 +67,13 @@ func testVaultServer(tb testing.TB) (*api.Client, func()) {
return client, closer return client, closer
} }
func testVaultServerWithKVVersion(tb testing.TB, kvVersion string) (*api.Client, func()) {
tb.Helper()
client, _, closer := testVaultServerUnsealWithKVVersion(tb, kvVersion)
return client, closer
}
func testVaultServerAllBackends(tb testing.TB) (*api.Client, func()) { func testVaultServerAllBackends(tb testing.TB) (*api.Client, func()) {
tb.Helper() tb.Helper()
@ -85,6 +92,10 @@ func testVaultServerAllBackends(tb testing.TB) (*api.Client, func()) {
// testVaultServerUnseal creates a test vault cluster and returns a configured // testVaultServerUnseal creates a test vault cluster and returns a configured
// API client, list of unseal keys (as strings), and a closer function. // API client, list of unseal keys (as strings), and a closer function.
func testVaultServerUnseal(tb testing.TB) (*api.Client, []string, func()) { func testVaultServerUnseal(tb testing.TB) (*api.Client, []string, func()) {
return testVaultServerUnsealWithKVVersion(tb, "1")
}
func testVaultServerUnsealWithKVVersion(tb testing.TB, kvVersion string) (*api.Client, []string, func()) {
tb.Helper() tb.Helper()
logger := log.NewInterceptLogger(&log.LoggerOptions{ logger := log.NewInterceptLogger(&log.LoggerOptions{
Output: log.DefaultOutput, Output: log.DefaultOutput,
@ -92,7 +103,7 @@ func testVaultServerUnseal(tb testing.TB) (*api.Client, []string, func()) {
JSONFormat: logging.ParseEnvLogFormat() == logging.JSONFormat, JSONFormat: logging.ParseEnvLogFormat() == logging.JSONFormat,
}) })
return testVaultServerCoreConfig(tb, &vault.CoreConfig{ return testVaultServerCoreConfigWithOpts(tb, &vault.CoreConfig{
DisableMlock: true, DisableMlock: true,
DisableCache: true, DisableCache: true,
Logger: logger, Logger: logger,
@ -100,6 +111,10 @@ func testVaultServerUnseal(tb testing.TB) (*api.Client, []string, func()) {
AuditBackends: defaultVaultAuditBackends, AuditBackends: defaultVaultAuditBackends,
LogicalBackends: defaultVaultLogicalBackends, LogicalBackends: defaultVaultLogicalBackends,
BuiltinRegistry: builtinplugins.Registry, BuiltinRegistry: builtinplugins.Registry,
}, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
NumCores: 1,
KVVersion: kvVersion,
}) })
} }
@ -121,15 +136,19 @@ func testVaultServerPluginDir(tb testing.TB, pluginDir string) (*api.Client, []s
}) })
} }
// testVaultServerCoreConfig creates a new vault cluster with the given core
// configuration. This is a lower-level test helper.
func testVaultServerCoreConfig(tb testing.TB, coreConfig *vault.CoreConfig) (*api.Client, []string, func()) { func testVaultServerCoreConfig(tb testing.TB, coreConfig *vault.CoreConfig) (*api.Client, []string, func()) {
tb.Helper() return testVaultServerCoreConfigWithOpts(tb, coreConfig, &vault.TestClusterOptions{
cluster := vault.NewTestCluster(benchhelpers.TBtoT(tb), coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler, HandlerFunc: vaulthttp.Handler,
NumCores: 1, // Default is 3, but we don't need that many NumCores: 1, // Default is 3, but we don't need that many
}) })
}
// testVaultServerCoreConfig creates a new vault cluster with the given core
// configuration. This is a lower-level test helper.
func testVaultServerCoreConfigWithOpts(tb testing.TB, coreConfig *vault.CoreConfig, opts *vault.TestClusterOptions) (*api.Client, []string, func()) {
tb.Helper()
cluster := vault.NewTestCluster(benchhelpers.TBtoT(tb), coreConfig, opts)
cluster.Start() cluster.Start()
// Make it easy to get access to the active // Make it easy to get access to the active

View File

@ -80,7 +80,7 @@ func (c *KVGetCommand) Flags() *FlagSets {
} }
func (c *KVGetCommand) AutocompleteArgs() complete.Predictor { func (c *KVGetCommand) AutocompleteArgs() complete.Predictor {
return nil return c.PredictVaultFiles()
} }
func (c *KVGetCommand) AutocompleteFlags() complete.Flags { func (c *KVGetCommand) AutocompleteFlags() complete.Flags {

View File

@ -120,7 +120,7 @@ func (c *KVPatchCommand) Flags() *FlagSets {
} }
func (c *KVPatchCommand) AutocompleteArgs() complete.Predictor { func (c *KVPatchCommand) AutocompleteArgs() complete.Predictor {
return nil return c.PredictVaultFiles()
} }
func (c *KVPatchCommand) AutocompleteFlags() complete.Flags { func (c *KVPatchCommand) AutocompleteFlags() complete.Flags {

View File

@ -96,7 +96,7 @@ func (c *KVPutCommand) Flags() *FlagSets {
} }
func (c *KVPutCommand) AutocompleteArgs() complete.Predictor { func (c *KVPutCommand) AutocompleteArgs() complete.Predictor {
return nil return c.PredictVaultFolders()
} }
func (c *KVPutCommand) AutocompleteFlags() complete.Flags { func (c *KVPutCommand) AutocompleteFlags() complete.Flags {

View File

@ -1214,6 +1214,7 @@ type TestClusterOptions struct {
// this stores the vault version that should be used for each core config // this stores the vault version that should be used for each core config
VersionMap map[int]string VersionMap map[int]string
RedundancyZoneMap map[int]string RedundancyZoneMap map[int]string
KVVersion string
} }
var DefaultNumCores = 3 var DefaultNumCores = 3
@ -2090,6 +2091,11 @@ func (tc *TestCluster) initCores(t testing.T, opts *TestClusterOptions, addAudit
TestWaitActive(t, leader.Core) TestWaitActive(t, leader.Core)
kvVersion := "1"
if opts != nil {
kvVersion = opts.KVVersion
}
// Existing tests rely on this; we can make a toggle to disable it // Existing tests rely on this; we can make a toggle to disable it
// later if we want // later if we want
kvReq := &logical.Request{ kvReq := &logical.Request{
@ -2101,7 +2107,7 @@ func (tc *TestCluster) initCores(t testing.T, opts *TestClusterOptions, addAudit
"path": "secret/", "path": "secret/",
"description": "key/value secret storage", "description": "key/value secret storage",
"options": map[string]string{ "options": map[string]string{
"version": "1", "version": kvVersion,
}, },
}, },
} }