// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package command import ( "os" "sort" "strings" "sync" "github.com/hashicorp/vault/api" "github.com/posener/complete" ) type Predict struct { client *api.Client clientOnce sync.Once } func NewPredict() *Predict { return &Predict{} } func (p *Predict) Client() *api.Client { p.clientOnce.Do(func() { if p.client == nil { // For tests client, _ := api.NewClient(nil) if client.Token() == "" { helper, err := DefaultTokenHelper() if err != nil { return } token, err := helper.Get() if err != nil { return } client.SetToken(token) } // Turn off retries for prediction if os.Getenv(api.EnvVaultMaxRetries) == "" { client.SetMaxRetries(0) } p.client = client } }) return p.client } // defaultPredictVaultMounts is the default list of mounts to return to the // user. This is a best-guess, given we haven't communicated with the Vault // server. If the user has no token or if the token does not have the default // policy attached, it won't be able to read cubbyhole/, but it's a better UX // that returning nothing. var defaultPredictVaultMounts = []string{"cubbyhole/"} // predictClient is the API client to use for prediction. We create this at the // beginning once, because completions are generated for each command (and this // doesn't change), and the only way to configure the predict/autocomplete // client is via environment variables. Even if the user specifies a flag, we // can't parse that flag until after the command is submitted. var ( predictClient *api.Client predictClientOnce sync.Once ) // PredictClient returns the cached API client for the predictor. func PredictClient() *api.Client { predictClientOnce.Do(func() { if predictClient == nil { // For tests predictClient, _ = api.NewClient(nil) } }) return predictClient } // PredictVaultAvailableMounts returns a predictor for the available mounts in // Vault. For now, there is no way to programmatically get this list. If, in the // future, such a list exists, we can adapt it here. Until then, it's // hard-coded. func (b *BaseCommand) PredictVaultAvailableMounts() complete.Predictor { // This list does not contain deprecated backends. At present, there is no // API that lists all available secret backends, so this is hard-coded :(. return complete.PredictSet( "aws", "consul", "database", "generic", "pki", "plugin", "rabbitmq", "ssh", "totp", "transit", ) } // PredictVaultAvailableAuths returns a predictor for the available auths in // Vault. For now, there is no way to programmatically get this list. If, in the // future, such a list exists, we can adapt it here. Until then, it's // hard-coded. func (b *BaseCommand) PredictVaultAvailableAuths() complete.Predictor { return complete.PredictSet( "app-id", "approle", "aws", "cert", "gcp", "github", "ldap", "okta", "plugin", "radius", "userpass", ) } // PredictVaultFiles returns a predictor for Vault mounts and paths based on the // configured client for the base command. Unfortunately this happens pre-flag // parsing, so users must rely on environment variables for autocomplete if they // are not using Vault at the default endpoints. func (b *BaseCommand) PredictVaultFiles() complete.Predictor { return NewPredict().VaultFiles() } // PredictVaultFolders returns a predictor for "folders". See PredictVaultFiles // for more information and restrictions. func (b *BaseCommand) PredictVaultFolders() complete.Predictor { return NewPredict().VaultFolders() } // PredictVaultNamespaces returns a predictor for "namespaces". See PredictVaultFiles // for more information an restrictions. func (b *BaseCommand) PredictVaultNamespaces() complete.Predictor { return NewPredict().VaultNamespaces() } // PredictVaultMounts returns a predictor for "folders". See PredictVaultFiles // for more information and restrictions. func (b *BaseCommand) PredictVaultMounts() complete.Predictor { return NewPredict().VaultMounts() } // PredictVaultAudits returns a predictor for "folders". See PredictVaultFiles // for more information and restrictions. func (b *BaseCommand) PredictVaultAudits() complete.Predictor { return NewPredict().VaultAudits() } // PredictVaultAuths returns a predictor for "folders". See PredictVaultFiles // for more information and restrictions. func (b *BaseCommand) PredictVaultAuths() complete.Predictor { return NewPredict().VaultAuths() } // PredictVaultPlugins returns a predictor for installed plugins. func (b *BaseCommand) PredictVaultPlugins(pluginTypes ...api.PluginType) complete.Predictor { return NewPredict().VaultPlugins(pluginTypes...) } // PredictVaultPolicies returns a predictor for "folders". See PredictVaultFiles // for more information and restrictions. func (b *BaseCommand) PredictVaultPolicies() complete.Predictor { return NewPredict().VaultPolicies() } func (b *BaseCommand) PredictVaultDebugTargets() complete.Predictor { return complete.PredictSet( "config", "host", "metrics", "pprof", "replication-status", "server-status", ) } // VaultFiles returns a predictor for Vault "files". This is a public API for // consumers, but you probably want BaseCommand.PredictVaultFiles instead. func (p *Predict) VaultFiles() complete.Predictor { return p.vaultPaths(true) } // VaultFolders returns a predictor for Vault "folders". This is a public // API for consumers, but you probably want BaseCommand.PredictVaultFolders // instead. func (p *Predict) VaultFolders() complete.Predictor { return p.vaultPaths(false) } // VaultNamespaces returns a predictor for Vault "namespaces". This is a public // API for consumers, but you probably want BaseCommand.PredictVaultNamespaces // instead. func (p *Predict) VaultNamespaces() complete.Predictor { return p.filterFunc(p.namespaces) } // VaultMounts returns a predictor for Vault "folders". This is a public // API for consumers, but you probably want BaseCommand.PredictVaultMounts // instead. func (p *Predict) VaultMounts() complete.Predictor { return p.filterFunc(p.mounts) } // VaultAudits returns a predictor for Vault "folders". This is a public API for // consumers, but you probably want BaseCommand.PredictVaultAudits instead. func (p *Predict) VaultAudits() complete.Predictor { return p.filterFunc(p.audits) } // VaultAuths returns a predictor for Vault "folders". This is a public API for // consumers, but you probably want BaseCommand.PredictVaultAuths instead. func (p *Predict) VaultAuths() complete.Predictor { return p.filterFunc(p.auths) } // VaultPlugins returns a predictor for Vault's plugin catalog. This is a public // API for consumers, but you probably want BaseCommand.PredictVaultPlugins // instead. func (p *Predict) VaultPlugins(pluginTypes ...api.PluginType) complete.Predictor { filterFunc := func() []string { return p.plugins(pluginTypes...) } return p.filterFunc(filterFunc) } // VaultPolicies returns a predictor for Vault "folders". This is a public API for // consumers, but you probably want BaseCommand.PredictVaultPolicies instead. func (p *Predict) VaultPolicies() complete.Predictor { return p.filterFunc(p.policies) } // vaultPaths parses the CLI options and returns the "best" list of possible // paths. If there are any errors, this function returns an empty result. All // errors are suppressed since this is a prediction function. func (p *Predict) vaultPaths(includeFiles bool) complete.PredictFunc { return func(args complete.Args) []string { // Do not predict more than one paths if p.hasPathArg(args.All) { return nil } client := p.Client() if client == nil { return nil } path := args.Last // Trim path with potential mount var relativePath string mountInfos, err := p.mountInfos() if err != nil { return nil } var mountType, mountVersion string for mount, mountInfo := range mountInfos { if strings.HasPrefix(path, mount) { relativePath = strings.TrimPrefix(path, mount+"/") mountType = mountInfo.Type if mountInfo.Options != nil { mountVersion = mountInfo.Options["version"] } break } } // Predict path or mount depending on path separator var predictions []string if strings.Contains(relativePath, "/") { predictions = p.paths(mountType, mountVersion, path, includeFiles) } else { predictions = p.filter(p.mounts(), path) } // Either no results or many results, so return. if len(predictions) != 1 { return predictions } // If this is not a "folder", do not try to recurse. if !strings.HasSuffix(predictions[0], "/") { return predictions } // If the prediction is the same as the last guess, return it (we have no // new information and we won't get anymore). if predictions[0] == args.Last { return predictions } // Re-predict with the remaining path args.Last = predictions[0] return p.vaultPaths(includeFiles).Predict(args) } } // paths predicts all paths which start with the given path. func (p *Predict) paths(mountType, mountVersion, path string, includeFiles bool) []string { client := p.Client() if client == nil { return nil } // Vault does not support listing based on a sub-key, so we have to back-pedal // to the last "/" and return all paths on that "folder". Then we perform // client-side filtering. root := path idx := strings.LastIndex(root, "/") if idx > 0 && idx < len(root) { root = root[:idx+1] } paths := p.listPaths(buildAPIListPath(root, mountType, mountVersion)) var predictions []string for _, p := range paths { // Calculate the absolute "path" for matching. p = root + p if strings.HasPrefix(p, path) { // Ensure this is a directory or we've asked to include files. if includeFiles || strings.HasSuffix(p, "/") { predictions = append(predictions, p) } } } // Add root to the path if len(predictions) == 0 { predictions = append(predictions, path) } return predictions } func buildAPIListPath(path, mountType, mountVersion string) string { if mountType == "kv" && mountVersion == "2" { return toKVv2ListPath(path) } return path } func toKVv2ListPath(path string) string { firstSlashIdx := strings.Index(path, "/") if firstSlashIdx < 0 { return path } return path[:firstSlashIdx] + "/metadata" + path[firstSlashIdx:] } // audits returns a sorted list of the audit backends for Vault server for // which the client is configured to communicate with. func (p *Predict) audits() []string { client := p.Client() if client == nil { return nil } audits, err := client.Sys().ListAudit() if err != nil { return nil } list := make([]string, 0, len(audits)) for m := range audits { list = append(list, m) } sort.Strings(list) return list } // auths returns a sorted list of the enabled auth provides for Vault server for // which the client is configured to communicate with. func (p *Predict) auths() []string { client := p.Client() if client == nil { return nil } auths, err := client.Sys().ListAuth() if err != nil { return nil } list := make([]string, 0, len(auths)) for m := range auths { list = append(list, m) } sort.Strings(list) return list } // plugins returns a sorted list of the plugins in the catalog. func (p *Predict) plugins(pluginTypes ...api.PluginType) []string { // This method's signature doesn't enforce that a pluginType must be passed in. // If it's not, it's likely the caller's intent is go get a list of all of them, // so let's help them out. if len(pluginTypes) == 0 { pluginTypes = append(pluginTypes, api.PluginTypeUnknown) } client := p.Client() if client == nil { return nil } var plugins []string pluginsAdded := make(map[string]bool) for _, pluginType := range pluginTypes { result, err := client.Sys().ListPlugins(&api.ListPluginsInput{Type: api.PluginType(pluginType)}) if err != nil { return nil } if result == nil { return nil } for _, names := range result.PluginsByType { for _, name := range names { if _, ok := pluginsAdded[name]; !ok { plugins = append(plugins, name) pluginsAdded[name] = true } } } } sort.Strings(plugins) return plugins } // policies returns a sorted list of the policies stored in this Vault // server. func (p *Predict) policies() []string { client := p.Client() if client == nil { return nil } policies, err := client.Sys().ListPolicies() if err != nil { return nil } sort.Strings(policies) return policies } // mountInfos returns a map with mount paths as keys and MountOutputs as values // for the Vault server which the client is configured to communicate with. // Returns error if server communication fails. func (p *Predict) mountInfos() (map[string]*api.MountOutput, error) { client := p.Client() if client == nil { return nil, nil } mounts, err := client.Sys().ListMounts() if err != nil { return nil, err } return mounts, nil } // mounts returns a sorted list of the mount paths for Vault server for // which the client is configured to communicate with. This function returns the // default list of mounts if an error occurs. func (p *Predict) mounts() []string { mounts, err := p.mountInfos() if err != nil { return defaultPredictVaultMounts } list := make([]string, 0, len(mounts)) for m := range mounts { list = append(list, m) } sort.Strings(list) return list } // namespaces returns a sorted list of the namespace paths for Vault server for // which the client is configured to communicate with. This function returns // an empty list in any error occurs. func (p *Predict) namespaces() []string { client := p.Client() if client == nil { return nil } secret, err := client.Logical().List("sys/namespaces") if err != nil { return nil } namespaces, ok := extractListData(secret) if !ok { return nil } list := make([]string, 0, len(namespaces)) for _, n := range namespaces { s, ok := n.(string) if !ok { continue } list = append(list, s) } sort.Strings(list) return list } // listPaths returns a list of paths (HTTP LIST) for the given path. This // function returns an empty list of any errors occur. func (p *Predict) listPaths(path string) []string { client := p.Client() if client == nil { return nil } secret, err := client.Logical().List(path) if err != nil || secret == nil || secret.Data == nil { return nil } paths, ok := secret.Data["keys"].([]interface{}) if !ok { return nil } list := make([]string, 0, len(paths)) for _, p := range paths { if str, ok := p.(string); ok { list = append(list, str) } } sort.Strings(list) return list } // hasPathArg determines if the args have already accepted a path. func (p *Predict) hasPathArg(args []string) bool { var nonFlags []string for _, a := range args { if !strings.HasPrefix(a, "-") { nonFlags = append(nonFlags, a) } } return len(nonFlags) > 2 } // filterFunc is used to compose a complete predictor that filters an array // of strings as per the filter function. func (p *Predict) filterFunc(f func() []string) complete.Predictor { return complete.PredictFunc(func(args complete.Args) []string { if p.hasPathArg(args.All) { return nil } client := p.Client() if client == nil { return nil } return p.filter(f(), args.Last) }) } // filter filters the given list for items that start with the prefix. func (p *Predict) filter(list []string, prefix string) []string { var predictions []string for _, item := range list { if strings.HasPrefix(item, prefix) { predictions = append(predictions, item) } } return predictions }