417 lines
11 KiB
Go
417 lines
11 KiB
Go
package command
|
|
|
|
import (
|
|
"bytes"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"regexp"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/hashicorp/vault/api"
|
|
"github.com/hashicorp/vault/command/token"
|
|
"github.com/kr/text"
|
|
"github.com/mitchellh/cli"
|
|
"github.com/pkg/errors"
|
|
"github.com/posener/complete"
|
|
)
|
|
|
|
// maxLineLength is the column budget used when wrapping help text: no
// wrapped line, including its left padding, is meant to be wider than this.
const maxLineLength int = 78
|
|
|
|
// reRemoveWhitespace matches any run of one or more whitespace characters.
// printFlagDetail uses it to collapse each such run in a flag's usage text
// into a single space before the text is re-wrapped.
var reRemoveWhitespace = regexp.MustCompile(`[\s]+`)
|
|
|
|
type BaseCommand struct {
|
|
UI cli.Ui
|
|
|
|
flags *FlagSets
|
|
flagsOnce sync.Once
|
|
|
|
flagAddress string
|
|
flagCACert string
|
|
flagCAPath string
|
|
flagClientCert string
|
|
flagClientKey string
|
|
flagTLSServerName string
|
|
flagTLSSkipVerify bool
|
|
flagWrapTTL time.Duration
|
|
|
|
flagFormat string
|
|
flagField string
|
|
|
|
tokenHelper token.TokenHelper
|
|
|
|
// For testing
|
|
client *api.Client
|
|
}
|
|
|
|
// Client returns the HTTP API client. The client is cached on the command to
|
|
// save performance on future calls.
|
|
func (c *BaseCommand) Client() (*api.Client, error) {
|
|
// Read the test client if present
|
|
if c.client != nil {
|
|
return c.client, nil
|
|
}
|
|
|
|
config := api.DefaultConfig()
|
|
|
|
if err := config.ReadEnvironment(); err != nil {
|
|
return nil, errors.Wrap(err, "failed to read environment")
|
|
}
|
|
|
|
if c.flagAddress != "" {
|
|
config.Address = c.flagAddress
|
|
}
|
|
|
|
// If we need custom TLS configuration, then set it
|
|
if c.flagCACert != "" || c.flagCAPath != "" || c.flagClientCert != "" ||
|
|
c.flagClientKey != "" || c.flagTLSServerName != "" || c.flagTLSSkipVerify {
|
|
t := &api.TLSConfig{
|
|
CACert: c.flagCACert,
|
|
CAPath: c.flagCAPath,
|
|
ClientCert: c.flagClientCert,
|
|
ClientKey: c.flagClientKey,
|
|
TLSServerName: c.flagTLSServerName,
|
|
Insecure: c.flagTLSSkipVerify,
|
|
}
|
|
config.ConfigureTLS(t)
|
|
}
|
|
|
|
// Build the client
|
|
client, err := api.NewClient(config)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "failed to create client")
|
|
}
|
|
|
|
// Set the wrapping function
|
|
client.SetWrappingLookupFunc(c.DefaultWrappingLookupFunc)
|
|
|
|
// Get the token if it came in from the environment
|
|
token := client.Token()
|
|
|
|
// If we don't have a token, check the token helper
|
|
if token == "" {
|
|
helper, err := c.TokenHelper()
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "failed to get token helper")
|
|
}
|
|
token, err = helper.Get()
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "failed to get token from token helper")
|
|
}
|
|
}
|
|
|
|
// Set the token
|
|
if token != "" {
|
|
client.SetToken(token)
|
|
}
|
|
|
|
return client, nil
|
|
}
|
|
|
|
// TokenHelper returns the token helper attached to the command.
|
|
func (c *BaseCommand) TokenHelper() (token.TokenHelper, error) {
|
|
if c.tokenHelper != nil {
|
|
return c.tokenHelper, nil
|
|
}
|
|
|
|
helper, err := DefaultTokenHelper()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return helper, nil
|
|
}
|
|
|
|
// DefaultWrappingLookupFunc is the default wrapping function based on the
|
|
// CLI flag.
|
|
func (c *BaseCommand) DefaultWrappingLookupFunc(operation, path string) string {
|
|
if c.flagWrapTTL != 0 {
|
|
return c.flagWrapTTL.String()
|
|
}
|
|
|
|
return api.DefaultWrappingLookupFunc(operation, path)
|
|
}
|
|
|
|
// FlagSetBit is a bitmask that selects which groups of common flags flagSet
// registers for a command. Values may be OR-ed together.
type FlagSetBit uint

const (
	// FlagSetNone is the first bit of the mask. Note that, despite its name,
	// its value is 1 (1 << 0), not 0.
	FlagSetNone FlagSetBit = 1 << iota

	// FlagSetHTTP selects the "HTTP Options" flags.
	FlagSetHTTP

	// FlagSetOutputField selects the -field flag under "Output Options".
	FlagSetOutputField

	// FlagSetOutputFormat selects the -format flag under "Output Options".
	FlagSetOutputFormat
)
|
|
|
|
// flagSet creates the flags for this command. The result is cached on the
|
|
// command to save performance on future calls.
|
|
func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
|
|
c.flagsOnce.Do(func() {
|
|
set := NewFlagSets(c.UI)
|
|
|
|
if bit&FlagSetHTTP != 0 {
|
|
f := set.NewFlagSet("HTTP Options")
|
|
|
|
f.StringVar(&StringVar{
|
|
Name: "address",
|
|
Target: &c.flagAddress,
|
|
Default: "https://127.0.0.1:8200",
|
|
EnvVar: "VAULT_ADDR",
|
|
Completion: complete.PredictAnything,
|
|
Usage: "Address of the Vault server.",
|
|
})
|
|
|
|
f.StringVar(&StringVar{
|
|
Name: "ca-cert",
|
|
Target: &c.flagCACert,
|
|
Default: "",
|
|
EnvVar: "VAULT_CACERT",
|
|
Completion: complete.PredictFiles("*"),
|
|
Usage: "Path on the local disk to a single PEM-encoded CA " +
|
|
"certificate to verify the Vault server's SSL certificate. This " +
|
|
"takes precendence over -ca-path.",
|
|
})
|
|
|
|
f.StringVar(&StringVar{
|
|
Name: "ca-path",
|
|
Target: &c.flagCAPath,
|
|
Default: "",
|
|
EnvVar: "VAULT_CAPATH",
|
|
Completion: complete.PredictDirs("*"),
|
|
Usage: "Path on the local disk to a directory of PEM-encoded CA " +
|
|
"certificates to verify the Vault server's SSL certificate.",
|
|
})
|
|
|
|
f.StringVar(&StringVar{
|
|
Name: "client-cert",
|
|
Target: &c.flagClientCert,
|
|
Default: "",
|
|
EnvVar: "VAULT_CLIENT_CERT",
|
|
Completion: complete.PredictFiles("*"),
|
|
Usage: "Path on the local disk to a single PEM-encoded CA " +
|
|
"certificate to use for TLS authentication to the Vault server. If " +
|
|
"this flag is specified, -client-key is also required.",
|
|
})
|
|
|
|
f.StringVar(&StringVar{
|
|
Name: "client-key",
|
|
Target: &c.flagClientKey,
|
|
Default: "",
|
|
EnvVar: "VAULT_CLIENT_KEY",
|
|
Completion: complete.PredictFiles("*"),
|
|
Usage: "Path on the local disk to a single PEM-encoded private key " +
|
|
"matching the client certificate from -client-cert.",
|
|
})
|
|
|
|
f.StringVar(&StringVar{
|
|
Name: "tls-server-name",
|
|
Target: &c.flagTLSServerName,
|
|
Default: "",
|
|
EnvVar: "VAULT_TLS_SERVER_NAME",
|
|
Completion: complete.PredictAnything,
|
|
Usage: "Name to use as the SNI host when connecting to the Vault " +
|
|
"server via TLS.",
|
|
})
|
|
|
|
f.BoolVar(&BoolVar{
|
|
Name: "tls-skip-verify",
|
|
Target: &c.flagTLSSkipVerify,
|
|
Default: false,
|
|
EnvVar: "VAULT_SKIP_VERIFY",
|
|
Usage: "Disable verification of TLS certificates. Using this option " +
|
|
"is highly discouraged and decreases the security of data " +
|
|
"transmissions to and from the Vault server.",
|
|
})
|
|
|
|
f.DurationVar(&DurationVar{
|
|
Name: "wrap-ttl",
|
|
Target: &c.flagWrapTTL,
|
|
Default: 0,
|
|
EnvVar: "VAULT_WRAP_TTL",
|
|
Completion: complete.PredictAnything,
|
|
Usage: "Wraps the response in a cubbyhole token with the requested " +
|
|
"TTL. The response is available via the \"vault unwrap\" command. " +
|
|
"The TTL is specified as a numeric string with suffix like \"30s\" " +
|
|
"or \"5m\".",
|
|
})
|
|
}
|
|
|
|
if bit&(FlagSetOutputField|FlagSetOutputFormat) != 0 {
|
|
f := set.NewFlagSet("Output Options")
|
|
|
|
if bit&FlagSetOutputField != 0 {
|
|
f.StringVar(&StringVar{
|
|
Name: "field",
|
|
Target: &c.flagField,
|
|
Default: "",
|
|
Completion: complete.PredictAnything,
|
|
Usage: "Print only the field with the given name. Specifying " +
|
|
"this option will take precedence over other formatting " +
|
|
"directives. The result will not have a trailing newline " +
|
|
"making it idea for piping to other processes.",
|
|
})
|
|
}
|
|
|
|
if bit&FlagSetOutputFormat != 0 {
|
|
f.StringVar(&StringVar{
|
|
Name: "format",
|
|
Target: &c.flagFormat,
|
|
Default: "table",
|
|
EnvVar: "VAULT_FORMAT",
|
|
Completion: complete.PredictSet("table", "json", "yaml"),
|
|
Usage: "Print the output in the given format. Valid formats " +
|
|
"are \"table\", \"json\", or \"yaml\".",
|
|
})
|
|
}
|
|
}
|
|
|
|
c.flags = set
|
|
})
|
|
|
|
return c.flags
|
|
}
|
|
|
|
// wrapAtLengthWithPadding wraps the given text at the maxLineLength, taking
|
|
// into account any provided left padding.
|
|
func wrapAtLengthWithPadding(s string, pad int) string {
|
|
wrapped := text.Wrap(s, maxLineLength-pad)
|
|
lines := strings.Split(wrapped, "\n")
|
|
for i, line := range lines {
|
|
lines[i] = strings.Repeat(" ", pad) + line
|
|
}
|
|
return strings.Join(lines, "\n")
|
|
}
|
|
|
|
// wrapAtLength wraps the given text to maxLineLength.
|
|
func wrapAtLength(s string) string {
|
|
return wrapAtLengthWithPadding(s, 0)
|
|
}
|
|
|
|
// FlagSets is a group of flag sets.
|
|
type FlagSets struct {
|
|
flagSets []*FlagSet
|
|
mainSet *flag.FlagSet
|
|
hiddens map[string]struct{}
|
|
completions complete.Flags
|
|
}
|
|
|
|
// NewFlagSets creates a new flag sets.
|
|
func NewFlagSets(ui cli.Ui) *FlagSets {
|
|
mainSet := flag.NewFlagSet("", flag.ContinueOnError)
|
|
|
|
// Errors and usage are controlled by the CLI.
|
|
mainSet.Usage = func() {}
|
|
mainSet.SetOutput(ioutil.Discard)
|
|
|
|
return &FlagSets{
|
|
flagSets: make([]*FlagSet, 0, 6),
|
|
mainSet: mainSet,
|
|
hiddens: make(map[string]struct{}),
|
|
completions: complete.Flags{},
|
|
}
|
|
}
|
|
|
|
// NewFlagSet creates a new flag set from the given flag sets.
|
|
func (f *FlagSets) NewFlagSet(name string) *FlagSet {
|
|
flagSet := NewFlagSet(name)
|
|
flagSet.mainSet = f.mainSet
|
|
flagSet.completions = f.completions
|
|
f.flagSets = append(f.flagSets, flagSet)
|
|
return flagSet
|
|
}
|
|
|
|
// Completions returns the completions for this flag set.
|
|
func (f *FlagSets) Completions() complete.Flags {
|
|
return f.completions
|
|
}
|
|
|
|
// Parse parses the given flags, returning any errors.
|
|
func (f *FlagSets) Parse(args []string) error {
|
|
return f.mainSet.Parse(args)
|
|
}
|
|
|
|
// Args returns the remaining args after parsing.
|
|
func (f *FlagSets) Args() []string {
|
|
return f.mainSet.Args()
|
|
}
|
|
|
|
// Help builds custom help for this command, grouping by flag set.
|
|
func (fs *FlagSets) Help() string {
|
|
var out bytes.Buffer
|
|
|
|
for _, set := range fs.flagSets {
|
|
printFlagTitle(&out, set.name+":")
|
|
set.VisitAll(func(f *flag.Flag) {
|
|
// Skip any hidden flags
|
|
if v, ok := f.Value.(FlagVisibility); ok && v.Hidden() {
|
|
return
|
|
}
|
|
printFlagDetail(&out, f)
|
|
})
|
|
}
|
|
|
|
return strings.TrimRight(out.String(), "\n")
|
|
}
|
|
|
|
// FlagSet is a grouped wrapper around a real flag set and a grouped flag set.
|
|
type FlagSet struct {
|
|
name string
|
|
flagSet *flag.FlagSet
|
|
mainSet *flag.FlagSet
|
|
completions complete.Flags
|
|
}
|
|
|
|
// NewFlagSet creates a new flag set.
|
|
func NewFlagSet(name string) *FlagSet {
|
|
return &FlagSet{
|
|
name: name,
|
|
flagSet: flag.NewFlagSet(name, flag.ContinueOnError),
|
|
}
|
|
}
|
|
|
|
// Name returns the name of this flag set.
|
|
func (f *FlagSet) Name() string {
|
|
return f.name
|
|
}
|
|
|
|
func (f *FlagSet) Visit(fn func(*flag.Flag)) {
|
|
f.flagSet.Visit(fn)
|
|
}
|
|
|
|
func (f *FlagSet) VisitAll(fn func(*flag.Flag)) {
|
|
f.flagSet.VisitAll(fn)
|
|
}
|
|
|
|
// printFlagTitle writes the title s to w followed by a blank line, i.e. the
// title and two newlines.
func printFlagTitle(w io.Writer, s string) {
	fmt.Fprint(w, s, "\n\n")
}
|
|
|
|
// printFlagDetail prints a single flag to the given writer.
|
|
func printFlagDetail(w io.Writer, f *flag.Flag) {
|
|
// Check if the flag is hidden - do not print any flag detail or help output
|
|
// if it is hidden.
|
|
if h, ok := f.Value.(FlagVisibility); ok && h.Hidden() {
|
|
return
|
|
}
|
|
|
|
// Check for a detailed example
|
|
example := ""
|
|
if t, ok := f.Value.(FlagExample); ok {
|
|
example = t.Example()
|
|
}
|
|
|
|
if example != "" {
|
|
fmt.Fprintf(w, " -%s=<%s>\n", f.Name, example)
|
|
} else {
|
|
fmt.Fprintf(w, " -%s\n", f.Name)
|
|
}
|
|
|
|
usage := reRemoveWhitespace.ReplaceAllString(f.Usage, " ")
|
|
indented := wrapAtLengthWithPadding(usage, 6)
|
|
fmt.Fprintf(w, "%s\n\n", indented)
|
|
}
|