27bb03bbc0
* adding copyright header * fix fmt and a test
583 lines
15 KiB
Go
583 lines
15 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package command
|
|
|
|
import (
|
|
"os"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/hashicorp/vault/api"
|
|
"github.com/posener/complete"
|
|
)
|
|
|
|
type Predict struct {
|
|
client *api.Client
|
|
clientOnce sync.Once
|
|
}
|
|
|
|
func NewPredict() *Predict {
|
|
return &Predict{}
|
|
}
|
|
|
|
func (p *Predict) Client() *api.Client {
|
|
p.clientOnce.Do(func() {
|
|
if p.client == nil { // For tests
|
|
client, _ := api.NewClient(nil)
|
|
|
|
if client.Token() == "" {
|
|
helper, err := DefaultTokenHelper()
|
|
if err != nil {
|
|
return
|
|
}
|
|
token, err := helper.Get()
|
|
if err != nil {
|
|
return
|
|
}
|
|
client.SetToken(token)
|
|
}
|
|
|
|
// Turn off retries for prediction
|
|
if os.Getenv(api.EnvVaultMaxRetries) == "" {
|
|
client.SetMaxRetries(0)
|
|
}
|
|
|
|
p.client = client
|
|
}
|
|
})
|
|
return p.client
|
|
}
|
|
|
|
// defaultPredictVaultMounts is the default list of mounts to return to the
// user. This is a best-guess, given we haven't communicated with the Vault
// server. If the user has no token or if the token does not have the default
// policy attached, it won't be able to read cubbyhole/, but it's a better UX
// than returning nothing.
var defaultPredictVaultMounts = []string{"cubbyhole/"}
|
|
|
|
// predictClient is the API client to use for prediction. We create this at the
|
|
// beginning once, because completions are generated for each command (and this
|
|
// doesn't change), and the only way to configure the predict/autocomplete
|
|
// client is via environment variables. Even if the user specifies a flag, we
|
|
// can't parse that flag until after the command is submitted.
|
|
var (
|
|
predictClient *api.Client
|
|
predictClientOnce sync.Once
|
|
)
|
|
|
|
// PredictClient returns the cached API client for the predictor.
|
|
func PredictClient() *api.Client {
|
|
predictClientOnce.Do(func() {
|
|
if predictClient == nil { // For tests
|
|
predictClient, _ = api.NewClient(nil)
|
|
}
|
|
})
|
|
return predictClient
|
|
}
|
|
|
|
// PredictVaultAvailableMounts returns a predictor for the available mounts in
|
|
// Vault. For now, there is no way to programmatically get this list. If, in the
|
|
// future, such a list exists, we can adapt it here. Until then, it's
|
|
// hard-coded.
|
|
func (b *BaseCommand) PredictVaultAvailableMounts() complete.Predictor {
|
|
// This list does not contain deprecated backends. At present, there is no
|
|
// API that lists all available secret backends, so this is hard-coded :(.
|
|
return complete.PredictSet(
|
|
"aws",
|
|
"consul",
|
|
"database",
|
|
"generic",
|
|
"pki",
|
|
"plugin",
|
|
"rabbitmq",
|
|
"ssh",
|
|
"totp",
|
|
"transit",
|
|
)
|
|
}
|
|
|
|
// PredictVaultAvailableAuths returns a predictor for the available auths in
|
|
// Vault. For now, there is no way to programmatically get this list. If, in the
|
|
// future, such a list exists, we can adapt it here. Until then, it's
|
|
// hard-coded.
|
|
func (b *BaseCommand) PredictVaultAvailableAuths() complete.Predictor {
|
|
return complete.PredictSet(
|
|
"app-id",
|
|
"approle",
|
|
"aws",
|
|
"cert",
|
|
"gcp",
|
|
"github",
|
|
"ldap",
|
|
"okta",
|
|
"plugin",
|
|
"radius",
|
|
"userpass",
|
|
)
|
|
}
|
|
|
|
// PredictVaultFiles returns a predictor for Vault mounts and paths based on the
|
|
// configured client for the base command. Unfortunately this happens pre-flag
|
|
// parsing, so users must rely on environment variables for autocomplete if they
|
|
// are not using Vault at the default endpoints.
|
|
func (b *BaseCommand) PredictVaultFiles() complete.Predictor {
|
|
return NewPredict().VaultFiles()
|
|
}
|
|
|
|
// PredictVaultFolders returns a predictor for "folders". See PredictVaultFiles
|
|
// for more information and restrictions.
|
|
func (b *BaseCommand) PredictVaultFolders() complete.Predictor {
|
|
return NewPredict().VaultFolders()
|
|
}
|
|
|
|
// PredictVaultNamespaces returns a predictor for "namespaces". See PredictVaultFiles
|
|
// for more information an restrictions.
|
|
func (b *BaseCommand) PredictVaultNamespaces() complete.Predictor {
|
|
return NewPredict().VaultNamespaces()
|
|
}
|
|
|
|
// PredictVaultMounts returns a predictor for "folders". See PredictVaultFiles
|
|
// for more information and restrictions.
|
|
func (b *BaseCommand) PredictVaultMounts() complete.Predictor {
|
|
return NewPredict().VaultMounts()
|
|
}
|
|
|
|
// PredictVaultAudits returns a predictor for "folders". See PredictVaultFiles
|
|
// for more information and restrictions.
|
|
func (b *BaseCommand) PredictVaultAudits() complete.Predictor {
|
|
return NewPredict().VaultAudits()
|
|
}
|
|
|
|
// PredictVaultAuths returns a predictor for "folders". See PredictVaultFiles
|
|
// for more information and restrictions.
|
|
func (b *BaseCommand) PredictVaultAuths() complete.Predictor {
|
|
return NewPredict().VaultAuths()
|
|
}
|
|
|
|
// PredictVaultPlugins returns a predictor for installed plugins.
|
|
func (b *BaseCommand) PredictVaultPlugins(pluginTypes ...api.PluginType) complete.Predictor {
|
|
return NewPredict().VaultPlugins(pluginTypes...)
|
|
}
|
|
|
|
// PredictVaultPolicies returns a predictor for "folders". See PredictVaultFiles
|
|
// for more information and restrictions.
|
|
func (b *BaseCommand) PredictVaultPolicies() complete.Predictor {
|
|
return NewPredict().VaultPolicies()
|
|
}
|
|
|
|
func (b *BaseCommand) PredictVaultDebugTargets() complete.Predictor {
|
|
return complete.PredictSet(
|
|
"config",
|
|
"host",
|
|
"metrics",
|
|
"pprof",
|
|
"replication-status",
|
|
"server-status",
|
|
)
|
|
}
|
|
|
|
// VaultFiles returns a predictor for Vault "files". This is a public API for
|
|
// consumers, but you probably want BaseCommand.PredictVaultFiles instead.
|
|
func (p *Predict) VaultFiles() complete.Predictor {
|
|
return p.vaultPaths(true)
|
|
}
|
|
|
|
// VaultFolders returns a predictor for Vault "folders". This is a public
|
|
// API for consumers, but you probably want BaseCommand.PredictVaultFolders
|
|
// instead.
|
|
func (p *Predict) VaultFolders() complete.Predictor {
|
|
return p.vaultPaths(false)
|
|
}
|
|
|
|
// VaultNamespaces returns a predictor for Vault "namespaces". This is a public
|
|
// API for consumers, but you probably want BaseCommand.PredictVaultNamespaces
|
|
// instead.
|
|
func (p *Predict) VaultNamespaces() complete.Predictor {
|
|
return p.filterFunc(p.namespaces)
|
|
}
|
|
|
|
// VaultMounts returns a predictor for Vault "folders". This is a public
|
|
// API for consumers, but you probably want BaseCommand.PredictVaultMounts
|
|
// instead.
|
|
func (p *Predict) VaultMounts() complete.Predictor {
|
|
return p.filterFunc(p.mounts)
|
|
}
|
|
|
|
// VaultAudits returns a predictor for Vault "folders". This is a public API for
|
|
// consumers, but you probably want BaseCommand.PredictVaultAudits instead.
|
|
func (p *Predict) VaultAudits() complete.Predictor {
|
|
return p.filterFunc(p.audits)
|
|
}
|
|
|
|
// VaultAuths returns a predictor for Vault "folders". This is a public API for
|
|
// consumers, but you probably want BaseCommand.PredictVaultAuths instead.
|
|
func (p *Predict) VaultAuths() complete.Predictor {
|
|
return p.filterFunc(p.auths)
|
|
}
|
|
|
|
// VaultPlugins returns a predictor for Vault's plugin catalog. This is a public
|
|
// API for consumers, but you probably want BaseCommand.PredictVaultPlugins
|
|
// instead.
|
|
func (p *Predict) VaultPlugins(pluginTypes ...api.PluginType) complete.Predictor {
|
|
filterFunc := func() []string {
|
|
return p.plugins(pluginTypes...)
|
|
}
|
|
return p.filterFunc(filterFunc)
|
|
}
|
|
|
|
// VaultPolicies returns a predictor for Vault "folders". This is a public API for
|
|
// consumers, but you probably want BaseCommand.PredictVaultPolicies instead.
|
|
func (p *Predict) VaultPolicies() complete.Predictor {
|
|
return p.filterFunc(p.policies)
|
|
}
|
|
|
|
// vaultPaths parses the CLI options and returns the "best" list of possible
|
|
// paths. If there are any errors, this function returns an empty result. All
|
|
// errors are suppressed since this is a prediction function.
|
|
func (p *Predict) vaultPaths(includeFiles bool) complete.PredictFunc {
|
|
return func(args complete.Args) []string {
|
|
// Do not predict more than one paths
|
|
if p.hasPathArg(args.All) {
|
|
return nil
|
|
}
|
|
|
|
client := p.Client()
|
|
if client == nil {
|
|
return nil
|
|
}
|
|
|
|
path := args.Last
|
|
|
|
// Trim path with potential mount
|
|
var relativePath string
|
|
mountInfos, err := p.mountInfos()
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
|
|
var mountType, mountVersion string
|
|
for mount, mountInfo := range mountInfos {
|
|
if strings.HasPrefix(path, mount) {
|
|
relativePath = strings.TrimPrefix(path, mount+"/")
|
|
mountType = mountInfo.Type
|
|
if mountInfo.Options != nil {
|
|
mountVersion = mountInfo.Options["version"]
|
|
}
|
|
break
|
|
}
|
|
}
|
|
|
|
// Predict path or mount depending on path separator
|
|
var predictions []string
|
|
if strings.Contains(relativePath, "/") {
|
|
predictions = p.paths(mountType, mountVersion, path, includeFiles)
|
|
} else {
|
|
predictions = p.filter(p.mounts(), path)
|
|
}
|
|
|
|
// Either no results or many results, so return.
|
|
if len(predictions) != 1 {
|
|
return predictions
|
|
}
|
|
|
|
// If this is not a "folder", do not try to recurse.
|
|
if !strings.HasSuffix(predictions[0], "/") {
|
|
return predictions
|
|
}
|
|
|
|
// If the prediction is the same as the last guess, return it (we have no
|
|
// new information and we won't get anymore).
|
|
if predictions[0] == args.Last {
|
|
return predictions
|
|
}
|
|
|
|
// Re-predict with the remaining path
|
|
args.Last = predictions[0]
|
|
return p.vaultPaths(includeFiles).Predict(args)
|
|
}
|
|
}
|
|
|
|
// paths predicts all paths which start with the given path.
|
|
func (p *Predict) paths(mountType, mountVersion, path string, includeFiles bool) []string {
|
|
client := p.Client()
|
|
if client == nil {
|
|
return nil
|
|
}
|
|
|
|
// Vault does not support listing based on a sub-key, so we have to back-pedal
|
|
// to the last "/" and return all paths on that "folder". Then we perform
|
|
// client-side filtering.
|
|
root := path
|
|
idx := strings.LastIndex(root, "/")
|
|
if idx > 0 && idx < len(root) {
|
|
root = root[:idx+1]
|
|
}
|
|
|
|
paths := p.listPaths(buildAPIListPath(root, mountType, mountVersion))
|
|
|
|
var predictions []string
|
|
for _, p := range paths {
|
|
// Calculate the absolute "path" for matching.
|
|
p = root + p
|
|
|
|
if strings.HasPrefix(p, path) {
|
|
// Ensure this is a directory or we've asked to include files.
|
|
if includeFiles || strings.HasSuffix(p, "/") {
|
|
predictions = append(predictions, p)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Add root to the path
|
|
if len(predictions) == 0 {
|
|
predictions = append(predictions, path)
|
|
}
|
|
|
|
return predictions
|
|
}
|
|
|
|
func buildAPIListPath(path, mountType, mountVersion string) string {
|
|
if mountType == "kv" && mountVersion == "2" {
|
|
return toKVv2ListPath(path)
|
|
}
|
|
return path
|
|
}
|
|
|
|
// toKVv2ListPath inserts a "metadata" segment after the first path segment,
// e.g. "secret/foo/" becomes "secret/metadata/foo/". A path with no "/" is
// returned unchanged.
//
// The mount is assumed to be exactly the first segment. For a mount nested
// under a prefix (e.g. "team/kv/"), the segment would be inserted at the wrong
// position.
func toKVv2ListPath(path string) string {
	parts := strings.SplitN(path, "/", 2)
	if len(parts) < 2 {
		return path
	}
	return parts[0] + "/metadata/" + parts[1]
}
|
|
|
|
// audits returns a sorted list of the audit backends for Vault server for
|
|
// which the client is configured to communicate with.
|
|
func (p *Predict) audits() []string {
|
|
client := p.Client()
|
|
if client == nil {
|
|
return nil
|
|
}
|
|
|
|
audits, err := client.Sys().ListAudit()
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
|
|
list := make([]string, 0, len(audits))
|
|
for m := range audits {
|
|
list = append(list, m)
|
|
}
|
|
sort.Strings(list)
|
|
return list
|
|
}
|
|
|
|
// auths returns a sorted list of the enabled auth provides for Vault server for
|
|
// which the client is configured to communicate with.
|
|
func (p *Predict) auths() []string {
|
|
client := p.Client()
|
|
if client == nil {
|
|
return nil
|
|
}
|
|
|
|
auths, err := client.Sys().ListAuth()
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
|
|
list := make([]string, 0, len(auths))
|
|
for m := range auths {
|
|
list = append(list, m)
|
|
}
|
|
sort.Strings(list)
|
|
return list
|
|
}
|
|
|
|
// plugins returns a sorted list of the plugins in the catalog.
|
|
func (p *Predict) plugins(pluginTypes ...api.PluginType) []string {
|
|
// This method's signature doesn't enforce that a pluginType must be passed in.
|
|
// If it's not, it's likely the caller's intent is go get a list of all of them,
|
|
// so let's help them out.
|
|
if len(pluginTypes) == 0 {
|
|
pluginTypes = append(pluginTypes, api.PluginTypeUnknown)
|
|
}
|
|
|
|
client := p.Client()
|
|
if client == nil {
|
|
return nil
|
|
}
|
|
|
|
var plugins []string
|
|
pluginsAdded := make(map[string]bool)
|
|
for _, pluginType := range pluginTypes {
|
|
result, err := client.Sys().ListPlugins(&api.ListPluginsInput{Type: api.PluginType(pluginType)})
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
if result == nil {
|
|
return nil
|
|
}
|
|
for _, names := range result.PluginsByType {
|
|
for _, name := range names {
|
|
if _, ok := pluginsAdded[name]; !ok {
|
|
plugins = append(plugins, name)
|
|
pluginsAdded[name] = true
|
|
}
|
|
}
|
|
}
|
|
}
|
|
sort.Strings(plugins)
|
|
return plugins
|
|
}
|
|
|
|
// policies returns a sorted list of the policies stored in this Vault
|
|
// server.
|
|
func (p *Predict) policies() []string {
|
|
client := p.Client()
|
|
if client == nil {
|
|
return nil
|
|
}
|
|
|
|
policies, err := client.Sys().ListPolicies()
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
sort.Strings(policies)
|
|
return policies
|
|
}
|
|
|
|
// mountInfos returns a map with mount paths as keys and MountOutputs as values
|
|
// for the Vault server which the client is configured to communicate with.
|
|
// Returns error if server communication fails.
|
|
func (p *Predict) mountInfos() (map[string]*api.MountOutput, error) {
|
|
client := p.Client()
|
|
if client == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
mounts, err := client.Sys().ListMounts()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return mounts, nil
|
|
}
|
|
|
|
// mounts returns a sorted list of the mount paths for Vault server for
|
|
// which the client is configured to communicate with. This function returns the
|
|
// default list of mounts if an error occurs.
|
|
func (p *Predict) mounts() []string {
|
|
mounts, err := p.mountInfos()
|
|
if err != nil {
|
|
return defaultPredictVaultMounts
|
|
}
|
|
|
|
list := make([]string, 0, len(mounts))
|
|
for m := range mounts {
|
|
list = append(list, m)
|
|
}
|
|
sort.Strings(list)
|
|
return list
|
|
}
|
|
|
|
// namespaces returns a sorted list of the namespace paths for Vault server for
|
|
// which the client is configured to communicate with. This function returns
|
|
// an empty list in any error occurs.
|
|
func (p *Predict) namespaces() []string {
|
|
client := p.Client()
|
|
if client == nil {
|
|
return nil
|
|
}
|
|
|
|
secret, err := client.Logical().List("sys/namespaces")
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
namespaces, ok := extractListData(secret)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
list := make([]string, 0, len(namespaces))
|
|
for _, n := range namespaces {
|
|
s, ok := n.(string)
|
|
if !ok {
|
|
continue
|
|
}
|
|
list = append(list, s)
|
|
}
|
|
sort.Strings(list)
|
|
return list
|
|
}
|
|
|
|
// listPaths returns a list of paths (HTTP LIST) for the given path. This
|
|
// function returns an empty list of any errors occur.
|
|
func (p *Predict) listPaths(path string) []string {
|
|
client := p.Client()
|
|
if client == nil {
|
|
return nil
|
|
}
|
|
|
|
secret, err := client.Logical().List(path)
|
|
if err != nil || secret == nil || secret.Data == nil {
|
|
return nil
|
|
}
|
|
|
|
paths, ok := secret.Data["keys"].([]interface{})
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
list := make([]string, 0, len(paths))
|
|
for _, p := range paths {
|
|
if str, ok := p.(string); ok {
|
|
list = append(list, str)
|
|
}
|
|
}
|
|
sort.Strings(list)
|
|
return list
|
|
}
|
|
|
|
// hasPathArg determines if the args have already accepted a path.
|
|
func (p *Predict) hasPathArg(args []string) bool {
|
|
var nonFlags []string
|
|
for _, a := range args {
|
|
if !strings.HasPrefix(a, "-") {
|
|
nonFlags = append(nonFlags, a)
|
|
}
|
|
}
|
|
|
|
return len(nonFlags) > 2
|
|
}
|
|
|
|
// filterFunc is used to compose a complete predictor that filters an array
|
|
// of strings as per the filter function.
|
|
func (p *Predict) filterFunc(f func() []string) complete.Predictor {
|
|
return complete.PredictFunc(func(args complete.Args) []string {
|
|
if p.hasPathArg(args.All) {
|
|
return nil
|
|
}
|
|
|
|
client := p.Client()
|
|
if client == nil {
|
|
return nil
|
|
}
|
|
|
|
return p.filter(f(), args.Last)
|
|
})
|
|
}
|
|
|
|
// filter filters the given list for items that start with the prefix.
|
|
func (p *Predict) filter(list []string, prefix string) []string {
|
|
var predictions []string
|
|
for _, item := range list {
|
|
if strings.HasPrefix(item, prefix) {
|
|
predictions = append(predictions, item)
|
|
}
|
|
}
|
|
return predictions
|
|
}
|