Update init command

This commit is contained in:
Seth Vargo 2017-09-05 00:02:02 -04:00
parent 9500cb7fc7
commit 076703ebc1
No known key found for this signature in database
GPG Key ID: C921994F9C27E0FF
2 changed files with 870 additions and 664 deletions

View File

@ -1,406 +1,594 @@
package command
import (
"encoding/json"
"fmt"
"net/url"
"os"
"runtime"
"strings"
consulapi "github.com/hashicorp/consul/api"
"github.com/ghodss/yaml"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/helper/pgpkeys"
"github.com/hashicorp/vault/meta"
"github.com/hashicorp/vault/physical/consul"
"github.com/mitchellh/cli"
"github.com/posener/complete"
consulapi "github.com/hashicorp/consul/api"
)
// Ensure we are implementing the right interfaces.
var _ cli.Command = (*InitCommand)(nil)
var _ cli.CommandAutocomplete = (*InitCommand)(nil)
// InitCommand is a Command that initializes a new Vault server.
type InitCommand struct {
meta.Meta
}
*BaseCommand
func (c *InitCommand) Run(args []string) int {
var threshold, shares, storedShares, recoveryThreshold, recoveryShares int
var pgpKeys, recoveryPgpKeys, rootTokenPgpKey pgpkeys.PubKeyFilesFlag
var auto, check bool
var consulServiceName string
flags := c.Meta.FlagSet("init", meta.FlagSetDefault)
flags.Usage = func() { c.Ui.Error(c.Help()) }
flags.IntVar(&shares, "key-shares", 5, "")
flags.IntVar(&threshold, "key-threshold", 3, "")
flags.IntVar(&storedShares, "stored-shares", 0, "")
flags.Var(&pgpKeys, "pgp-keys", "")
flags.Var(&rootTokenPgpKey, "root-token-pgp-key", "")
flags.IntVar(&recoveryShares, "recovery-shares", 5, "")
flags.IntVar(&recoveryThreshold, "recovery-threshold", 3, "")
flags.Var(&recoveryPgpKeys, "recovery-pgp-keys", "")
flags.BoolVar(&check, "check", false, "")
flags.BoolVar(&auto, "auto", false, "")
flags.StringVar(&consulServiceName, "consul-service", consul.DefaultServiceName, "")
if err := flags.Parse(args); err != nil {
return 1
}
flagStatus bool
flagKeyShares int
flagKeyThreshold int
flagPGPKeys []string
flagRootTokenPGPKey string
initRequest := &api.InitRequest{
SecretShares: shares,
SecretThreshold: threshold,
StoredShares: storedShares,
PGPKeys: pgpKeys,
RecoveryShares: recoveryShares,
RecoveryThreshold: recoveryThreshold,
RecoveryPGPKeys: recoveryPgpKeys,
}
// HSM
flagStoredShares int
flagRecoveryShares int
flagRecoveryThreshold int
flagRecoveryPGPKeys []string
switch len(rootTokenPgpKey) {
case 0:
case 1:
initRequest.RootTokenPGPKey = rootTokenPgpKey[0]
default:
c.Ui.Error("Only one PGP key can be specified for encrypting the root token")
return 1
}
// Consul
flagConsulAuto bool
flagConsulService string
// If running in 'auto' mode, run service discovery based on environment
// variables of Consul.
if auto {
// Create configuration for Consul
consulConfig := consulapi.DefaultConfig()
// Create a client to communicate with Consul
consulClient, err := consulapi.NewClient(consulConfig)
if err != nil {
c.Ui.Error(fmt.Sprintf("Failed to create Consul client:%v", err))
return 1
}
// Fetch Vault's protocol scheme from the client
vaultclient, err := c.Client()
if err != nil {
c.Ui.Error(fmt.Sprintf("Failed to fetch Vault client: %v", err))
return 1
}
if vaultclient.Address() == "" {
c.Ui.Error("Failed to fetch Vault client address")
return 1
}
clientURL, err := url.Parse(vaultclient.Address())
if err != nil {
c.Ui.Error(fmt.Sprintf("Failed to parse Vault address: %v", err))
return 1
}
if clientURL == nil {
c.Ui.Error("Failed to parse Vault client address")
return 1
}
var uninitializedVaults []string
var initializedVault string
// Query the nodes belonging to the cluster
if services, _, err := consulClient.Catalog().Service(consulServiceName, "", &consulapi.QueryOptions{AllowStale: true}); err == nil {
Loop:
for _, service := range services {
vaultAddress := &url.URL{
Scheme: clientURL.Scheme,
Host: fmt.Sprintf("%s:%d", service.ServiceAddress, service.ServicePort),
}
// Set VAULT_ADDR to the discovered node
os.Setenv(api.EnvVaultAddress, vaultAddress.String())
// Create a client to communicate with the discovered node
client, err := c.Client()
if err != nil {
c.Ui.Error(fmt.Sprintf("Error initializing client: %v", err))
return 1
}
// Check the initialization status of the discovered node
inited, err := client.Sys().InitStatus()
switch {
case err != nil:
c.Ui.Error(fmt.Sprintf("Error checking initialization status of discovered node: %+q. Err: %v", vaultAddress.String(), err))
return 1
case inited:
// One of the nodes in the cluster is initialized. Break out.
initializedVault = vaultAddress.String()
break Loop
default:
// Vault is uninitialized.
uninitializedVaults = append(uninitializedVaults, vaultAddress.String())
}
}
}
export := "export"
quote := "'"
if runtime.GOOS == "windows" {
export = "set"
quote = ""
}
if initializedVault != "" {
vaultURL, err := url.Parse(initializedVault)
if err != nil {
c.Ui.Error(fmt.Sprintf("Failed to parse Vault address: %+q. Err: %v", initializedVault, err))
}
c.Ui.Output(fmt.Sprintf("Discovered an initialized Vault node at %+q, using Consul service name %+q", vaultURL.String(), consulServiceName))
c.Ui.Output("\nSet the following environment variable to operate on the discovered Vault:\n")
c.Ui.Output(fmt.Sprintf("\t%s VAULT_ADDR=%s%s%s", export, quote, vaultURL.String(), quote))
return 0
}
switch len(uninitializedVaults) {
case 0:
c.Ui.Error(fmt.Sprintf("Failed to discover Vault nodes using Consul service name %+q", consulServiceName))
return 1
case 1:
// There was only one node found in the Vault cluster and it
// was uninitialized.
vaultURL, err := url.Parse(uninitializedVaults[0])
if err != nil {
c.Ui.Error(fmt.Sprintf("Failed to parse Vault address: %+q. Err: %v", uninitializedVaults[0], err))
}
// Set the VAULT_ADDR to the discovered node. This will ensure
// that the client created will operate on the discovered node.
os.Setenv(api.EnvVaultAddress, vaultURL.String())
// Let the client know that initialization is perfomed on the
// discovered node.
c.Ui.Output(fmt.Sprintf("Discovered Vault at %+q using Consul service name %+q\n", vaultURL.String(), consulServiceName))
// Attempt initializing it
ret := c.runInit(check, initRequest)
// Regardless of success or failure, instruct client to update VAULT_ADDR
c.Ui.Output("\nSet the following environment variable to operate on the discovered Vault:\n")
c.Ui.Output(fmt.Sprintf("\t%s VAULT_ADDR=%s%s%s", export, quote, vaultURL.String(), quote))
return ret
default:
// If more than one Vault node were discovered, print out all of them,
// requiring the client to update VAULT_ADDR and to run init again.
c.Ui.Output(fmt.Sprintf("Discovered more than one uninitialized Vaults using Consul service name %+q\n", consulServiceName))
c.Ui.Output("To initialize these Vaults, set any *one* of the following environment variables and run 'vault init':")
// Print valid commands to make setting the variables easier
for _, vaultNode := range uninitializedVaults {
vaultURL, err := url.Parse(vaultNode)
if err != nil {
c.Ui.Error(fmt.Sprintf("Failed to parse Vault address: %+q. Err: %v", vaultNode, err))
}
c.Ui.Output(fmt.Sprintf("\t%s VAULT_ADDR=%s%s%s", export, quote, vaultURL.String(), quote))
}
return 0
}
}
return c.runInit(check, initRequest)
}
func (c *InitCommand) runInit(check bool, initRequest *api.InitRequest) int {
client, err := c.Client()
if err != nil {
c.Ui.Error(fmt.Sprintf(
"Error initializing client: %s", err))
return 1
}
if check {
return c.checkStatus(client)
}
resp, err := client.Sys().Init(initRequest)
if err != nil {
c.Ui.Error(fmt.Sprintf(
"Error initializing Vault: %s", err))
return 1
}
for i, key := range resp.Keys {
if resp.KeysB64 != nil && len(resp.KeysB64) == len(resp.Keys) {
c.Ui.Output(fmt.Sprintf("Unseal Key %d: %s", i+1, resp.KeysB64[i]))
} else {
c.Ui.Output(fmt.Sprintf("Unseal Key %d: %s", i+1, key))
}
}
for i, key := range resp.RecoveryKeys {
if resp.RecoveryKeysB64 != nil && len(resp.RecoveryKeysB64) == len(resp.RecoveryKeys) {
c.Ui.Output(fmt.Sprintf("Recovery Key %d: %s", i+1, resp.RecoveryKeysB64[i]))
} else {
c.Ui.Output(fmt.Sprintf("Recovery Key %d: %s", i+1, key))
}
}
c.Ui.Output(fmt.Sprintf("Initial Root Token: %s", resp.RootToken))
if initRequest.StoredShares < 1 {
c.Ui.Output(fmt.Sprintf(
"\n"+
"Vault initialized with %d keys and a key threshold of %d. Please\n"+
"securely distribute the above keys. When the vault is re-sealed,\n"+
"restarted, or stopped, you must provide at least %d of these keys\n"+
"to unseal it again.\n\n"+
"Vault does not store the master key. Without at least %d keys,\n"+
"your vault will remain permanently sealed.",
initRequest.SecretShares,
initRequest.SecretThreshold,
initRequest.SecretThreshold,
initRequest.SecretThreshold,
))
} else {
c.Ui.Output(
"\n" +
"Vault initialized successfully.",
)
}
if len(resp.RecoveryKeys) > 0 {
c.Ui.Output(fmt.Sprintf(
"\n"+
"Recovery key initialized with %d keys and a key threshold of %d. Please\n"+
"securely distribute the above keys.",
initRequest.RecoveryShares,
initRequest.RecoveryThreshold,
))
}
return 0
}
func (c *InitCommand) checkStatus(client *api.Client) int {
inited, err := client.Sys().InitStatus()
switch {
case err != nil:
c.Ui.Error(fmt.Sprintf(
"Error checking initialization status: %s", err))
return 1
case inited:
c.Ui.Output("Vault has been initialized")
return 0
default:
c.Ui.Output("Vault is not initialized")
return 2
}
// Deprecations
// TODO: remove in 0.9.0
flagAuto bool
flagCheck bool
}
func (c *InitCommand) Synopsis() string {
return "Initialize a new Vault server"
return "Initializes a server"
}
func (c *InitCommand) Help() string {
helpText := `
Usage: vault init [options]
Initialize a new Vault server.
Initializes a Vault server. Initialization is the process by which Vault's
storage backend is prepared to receive data. Since Vault server's share the
same storage backend in HA mode, you only need to initialize one Vault to
initialize the storage backend.
This command connects to a Vault server and initializes it for the
first time. This sets up the initial set of master keys and the
backend data store structure.
During initialization, Vault generates an in-memory master key and applies
Shamir's secret sharing algorithm to disassemble that master key into a
configuration number of key shares such that a configurable subset of those
key shares must come together to regenerate the master key. These keys are
often called "unseal keys" in Vault's documentation.
This command can't be called on an already-initialized Vault server.
This command cannot be run against already-initialized Vault cluster.
General Options:
` + meta.GeneralOptionsUsage() + `
Init Options:
Start initialization with the default options:
-check Don't actually initialize, just check if Vault is
already initialized. A return code of 0 means Vault
is initialized; a return code of 2 means Vault is not
initialized; a return code of 1 means an error was
encountered.
$ vault init
-key-shares=5 The number of key shares to split the master key
into.
Initialize, but encrypt the unseal keys with pgp keys:
-key-threshold=3 The number of key shares required to reconstruct
the master key.
$ vault init \
-key-shares=3 \
-key-threshold=2 \
-pgp-keys="keybase:hashicorp,keybase:jefferai,keybase:sethvargo"
-stored-shares=0 The number of unseal keys to store. Only used with
Vault HSM. Must currently be equivalent to the
number of shares.
Encrypt the initial root token using a pgp key:
-pgp-keys If provided, must be a comma-separated list of
files on disk containing binary- or base64-format
public PGP keys, or Keybase usernames specified as
"keybase:<username>". The output unseal keys will
be encrypted and base64-encoded, in order, with the
given public keys. If you want to use them with the
'vault unseal' command, you will need to base64-
decode and decrypt; this will be the plaintext
unseal key. When 'stored-shares' are not used, the
number of entries in this field must match 'key-shares'.
When 'stored-shares' are used, the number of entries
should match the difference between 'key-shares'
and 'stored-shares'.
$ vault init -root-token-pgp-key="keybase:hashicorp"
-root-token-pgp-key If provided, a file on disk with a binary- or
base64-format public PGP key, or a Keybase username
specified as "keybase:<username>". The output root
token will be encrypted and base64-encoded, in
order, with the given public key. You will need
to base64-decode and decrypt the result.
For a complete list of examples, please see the documentation.
-recovery-shares=5 The number of key shares to split the recovery key
into. Only used with Vault HSM.
-recovery-threshold=3 The number of key shares required to reconstruct
the recovery key. Only used with Vault HSM.
-recovery-pgp-keys If provided, behaves like "pgp-keys" but for the
recovery key shares. Only used with Vault HSM.
-auto If set, performs service discovery using Consul.
When all the nodes of a Vault cluster are
registered with Consul, setting this flag will
trigger service discovery using the service name
with which Vault nodes are registered. This option
works well when each Vault cluster is registered
under a unique service name. Note that, when Consul
is serving as Vault's HA backend, Vault nodes are
registered with Consul by default. The service name
can be changed using 'consul-service' flag. Ensure
that environment variables required to communicate
with Consul, like (CONSUL_HTTP_ADDR,
CONSUL_HTTP_TOKEN, CONSUL_HTTP_SSL, et al) are
properly set. When only one Vault node is
discovered, it will be initialized and when more
than one Vault node is discovered, they will be
output for easy selection.
-consul-service Service name under which all the nodes of a Vault
cluster are registered with Consul. Note that, when
Vault uses Consul as its HA backend, by default,
Vault will register itself as a service with Consul
with the service name "vault". This name can be
modified in Vault's configuration file, using the
"service" option for the Consul backend.
`
` + c.Flags().Help()
return strings.TrimSpace(helpText)
}
func (c *InitCommand) Flags() *FlagSets {
set := c.flagSet(FlagSetHTTP | FlagSetOutputFormat)
// Common Options
f := set.NewFlagSet("Common Options")
f.BoolVar(&BoolVar{
Name: "status",
Target: &c.flagStatus,
Default: false,
Usage: "Print the current initialization status. An exit code of 0 means " +
"the Vault is already initialized. An exit code of 1 means an error " +
"occurred. An exit code of 2 means the mean is not initialized.",
})
f.IntVar(&IntVar{
Name: "key-shares",
Aliases: []string{"n"},
Target: &c.flagKeyShares,
Default: 5,
Completion: complete.PredictAnything,
Usage: "Number of key shares to split the generated master key into. " +
"This is the number of \"unseal keys\" to generate.",
})
f.IntVar(&IntVar{
Name: "key-threshold",
Aliases: []string{"t"},
Target: &c.flagKeyThreshold,
Default: 3,
Completion: complete.PredictAnything,
Usage: "Number of key shares required to reconstruct the master key. " +
"This must be less than or equal to -key-shares.",
})
f.VarFlag(&VarFlag{
Name: "pgp-keys",
Value: (*pgpkeys.PubKeyFilesFlag)(&c.flagPGPKeys),
Completion: complete.PredictAnything,
Usage: "Comma-separated list of paths to files on disk containing " +
"public GPG keys OR a comma-separated list of Keybase usernames using " +
"the format \"keybase:<username>\". When supplied, the generated " +
"unseal keys will be encrypted and base64-encoded in the order " +
"specified in this list. The number of entires must match -key-shares, " +
"unless -store-shares are used.",
})
f.VarFlag(&VarFlag{
Name: "root-token-pgp-key",
Value: (*pgpkeys.PubKeyFileFlag)(&c.flagRootTokenPGPKey),
Completion: complete.PredictAnything,
Usage: "Path to a file on disk containing a binary or base64-encoded " +
"public GPG key. This can also be specified as a Keybase username " +
"using the format \"keybase:<username>\". When supplied, the generated " +
"root token will be encrypted and base64-encoded with the given public " +
"key.",
})
// Consul Options
f = set.NewFlagSet("Consul Options")
f.BoolVar(&BoolVar{
Name: "consul-auto",
Target: &c.flagConsulAuto,
Default: false,
Usage: "Perform automatic service discovery using Consul in HA mode. " +
"When all nodes in a Vault HA cluster are registered with Consul, " +
"enabling this option will trigger automatic service discovery based " +
"on the provided -consul-service value. When Consul is Vault's HA " +
"backend, this functionality is automatically enabled. Ensure the " +
"proper Consul environment variables are set (CONSUL_HTTP_ADDR, etc). " +
"When only one Vault server is discovered, it will be initialized " +
"automatically. When more than one Vault server is discovered, they " +
"will each be output for selection.",
})
f.StringVar(&StringVar{
Name: "consul-service",
Target: &c.flagConsulService,
Default: "vault",
Completion: complete.PredictAnything,
Usage: "Name of the service in Consul under which the Vault servers are " +
"registered.",
})
// HSM Options
f = set.NewFlagSet("HSM Options")
f.IntVar(&IntVar{
Name: "recovery-shares",
Target: &c.flagRecoveryShares,
Default: 5,
Completion: complete.PredictAnything,
Usage: "Number of key shares to split the recovery key into. " +
"This is only used in HSM mode.",
})
f.IntVar(&IntVar{
Name: "recovery-threshold",
Target: &c.flagRecoveryThreshold,
Default: 3,
Completion: complete.PredictAnything,
Usage: "Number of key shares required to reconstruct the recovery key. " +
"This is only used in HSM mode.",
})
f.VarFlag(&VarFlag{
Name: "recovery-pgp-keys",
Value: (*pgpkeys.PubKeyFilesFlag)(&c.flagRecoveryPGPKeys),
Completion: complete.PredictAnything,
Usage: "Behaves like -pgp-keys, but for the recovery key shares. This " +
"is only used in HSM mode.",
})
f.IntVar(&IntVar{
Name: "stored-shares",
Target: &c.flagStoredShares,
Default: 0, // No default, because we need to check if was supplied
Completion: complete.PredictAnything,
Usage: "Number of unseal keys to store on an HSM. This must be equal to " +
"-key-shares. This is only used in HSM mode.",
})
// Deprecations
// TODO: remove in 0.9.0
f.BoolVar(&BoolVar{
Name: "check", // prefer -status
Target: &c.flagCheck,
Default: false,
Hidden: true,
Usage: "",
})
f.BoolVar(&BoolVar{
Name: "auto", // prefer -consul-auto
Target: &c.flagAuto,
Default: false,
Hidden: true,
Usage: "",
})
return set
}
func (c *InitCommand) AutocompleteArgs() complete.Predictor {
return complete.PredictNothing
return nil
}
func (c *InitCommand) AutocompleteFlags() complete.Flags {
return complete.Flags{
"-check": complete.PredictNothing,
"-key-shares": complete.PredictNothing,
"-key-threshold": complete.PredictNothing,
"-pgp-keys": complete.PredictNothing,
"-root-token-pgp-key": complete.PredictNothing,
"-recovery-shares": complete.PredictNothing,
"-recovery-threshold": complete.PredictNothing,
"-recovery-pgp-keys": complete.PredictNothing,
"-auto": complete.PredictNothing,
"-consul-service": complete.PredictNothing,
return c.Flags().Completions()
}
func (c *InitCommand) Run(args []string) int {
f := c.Flags()
if err := f.Parse(args); err != nil {
c.UI.Error(err.Error())
return 1
}
// Deprecations
// TODO: remove in 0.9.0
if c.flagAuto {
c.UI.Warn(wrapAtLength("WARNING! -auto is deprecated. Please use " +
"-consul-auto instead. This will be removed the next major release " +
"of Vault."))
c.flagConsulAuto = true
}
if c.flagCheck {
c.UI.Warn(wrapAtLength("WARNING! -check is deprecated. Please use " +
"-status instead. This will be removed in the next major release " +
"of Vault."))
c.flagStatus = true
}
// Build the initial init request
initReq := &api.InitRequest{
SecretShares: c.flagKeyShares,
SecretThreshold: c.flagKeyThreshold,
PGPKeys: c.flagPGPKeys,
RootTokenPGPKey: c.flagRootTokenPGPKey,
StoredShares: c.flagStoredShares,
RecoveryShares: c.flagRecoveryShares,
RecoveryThreshold: c.flagRecoveryThreshold,
RecoveryPGPKeys: c.flagRecoveryPGPKeys,
}
client, err := c.Client()
if err != nil {
c.UI.Error(err.Error())
return 2
}
// Check auto mode
switch {
case c.flagStatus:
return c.status(client)
case c.flagConsulAuto:
return c.consulAuto(client, initReq)
default:
return c.init(client, initReq)
}
}
// consulAuto enables auto-joining via Consul.
func (c *InitCommand) consulAuto(client *api.Client, req *api.InitRequest) int {
// Capture the client original address and reset it
originalAddr := client.Address()
defer client.SetAddress(originalAddr)
// Create a client to communicate with Consul
consulClient, err := consulapi.NewClient(consulapi.DefaultConfig())
if err != nil {
c.UI.Error(fmt.Sprintf("Failed to create Consul client:%v", err))
return 1
}
// Pull the scheme from the Vault client to determine if the Consul agent
// should talk via HTTP or HTTPS.
addr := client.Address()
clientURL, err := url.Parse(addr)
if err != nil || clientURL == nil {
c.UI.Error(fmt.Sprintf("Failed to parse Vault address %s: %s", addr, err))
return 1
}
var uninitedVaults []string
var initedVault string
// Query the nodes belonging to the cluster
services, _, err := consulClient.Catalog().Service(c.flagConsulService, "", &consulapi.QueryOptions{
AllowStale: true,
})
if err == nil {
for _, service := range services {
// Set the address on the client temporarily
vaultAddr := (&url.URL{
Scheme: clientURL.Scheme,
Host: fmt.Sprintf("%s:%d", service.ServiceAddress, service.ServicePort),
}).String()
client.SetAddress(vaultAddr)
// Check the initialization status of the discovered node
inited, err := client.Sys().InitStatus()
if err != nil {
c.UI.Error(fmt.Sprintf("Error checking init status of %q: %s", vaultAddr, err))
}
if inited {
initedVault = vaultAddr
break
}
// If we got this far, we communicated successfully with Vault, but it
// was not initialized.
uninitedVaults = append(uninitedVaults, vaultAddr)
}
}
// Get the correct export keywords and quotes for *nix vs Windows
export := "export"
quote := "\""
if runtime.GOOS == "windows" {
export = "set"
quote = ""
}
if initedVault != "" {
vaultURL, err := url.Parse(initedVault)
if err != nil {
c.UI.Error(fmt.Sprintf("Failed to parse Vault address %q: %s", initedVault, err))
return 2
}
vaultAddr := vaultURL.String()
c.UI.Output(wrapAtLength(fmt.Sprintf(
"Discovered an initialized Vault node at %q with Consul service name "+
"%q. Set the following environment variable to target the discovered "+
"Vault server:",
vaultURL.String(), c.flagConsulService)))
c.UI.Output("")
c.UI.Output(fmt.Sprintf(" $ %s VAULT_ADDR=%s%s%s", export, quote, vaultAddr, quote))
c.UI.Output("")
return 0
}
switch len(uninitedVaults) {
case 0:
c.UI.Error(fmt.Sprintf("No Vault nodes registered as %q in Consul", c.flagConsulService))
return 2
case 1:
// There was only one node found in the Vault cluster and it was
// uninitialized.
vaultURL, err := url.Parse(uninitedVaults[0])
if err != nil {
c.UI.Error(fmt.Sprintf("Failed to parse Vault address %q: %s", initedVault, err))
return 2
}
vaultAddr := vaultURL.String()
// Update the client to connect to this Vault server
client.SetAddress(vaultAddr)
// Let the client know that initialization is perfomed on the
// discovered node.
c.UI.Output(wrapAtLength(fmt.Sprintf(
"Discovered an initialized Vault node at %q with Consul service name "+
"%q. Set the following environment variable to target the discovered "+
"Vault server:",
vaultURL.String(), c.flagConsulService)))
c.UI.Output("")
c.UI.Output(fmt.Sprintf(" $ %s VAULT_ADDR=%s%s%s", export, quote, vaultAddr, quote))
c.UI.Output("")
c.UI.Output("Attempting to initialize it...")
c.UI.Output("")
// Attempt to initialize it
return c.init(client, req)
default:
// If more than one Vault node were discovered, print out all of them,
// requiring the client to update VAULT_ADDR and to run init again.
c.UI.Output(wrapAtLength(fmt.Sprintf(
"Discovered %d uninitialized Vault servers with Consul service name "+
"%q. To initialize these Vatuls, set any one of the following "+
"environment variables and run \"vault init\":",
len(uninitedVaults), c.flagConsulService)))
c.UI.Output("")
// Print valid commands to make setting the variables easier
for _, node := range uninitedVaults {
vaultURL, err := url.Parse(node)
if err != nil {
c.UI.Error(fmt.Sprintf("Failed to parse Vault address %q: %s", initedVault, err))
return 2
}
vaultAddr := vaultURL.String()
c.UI.Output(fmt.Sprintf(" $ %s VAULT_ADDR=%s%s%s", export, quote, vaultAddr, quote))
}
c.UI.Output("")
return 0
}
}
func (c *InitCommand) init(client *api.Client, req *api.InitRequest) int {
resp, err := client.Sys().Init(req)
if err != nil {
c.UI.Error(fmt.Sprintf("Error initializing: %s", err))
return 2
}
switch c.flagFormat {
case "yaml", "yml":
return c.initOutputYAML(req, resp)
case "json":
return c.initOutputJSON(req, resp)
case "table":
default:
c.UI.Error(fmt.Sprintf("Unknown format: %s", c.flagFormat))
return 1
}
for i, key := range resp.Keys {
if resp.KeysB64 != nil && len(resp.KeysB64) == len(resp.Keys) {
c.UI.Output(fmt.Sprintf("Unseal Key %d: %s", i+1, resp.KeysB64[i]))
} else {
c.UI.Output(fmt.Sprintf("Unseal Key %d: %s", i+1, key))
}
}
for i, key := range resp.RecoveryKeys {
if resp.RecoveryKeysB64 != nil && len(resp.RecoveryKeysB64) == len(resp.RecoveryKeys) {
c.UI.Output(fmt.Sprintf("Recovery Key %d: %s", i+1, resp.RecoveryKeysB64[i]))
} else {
c.UI.Output(fmt.Sprintf("Recovery Key %d: %s", i+1, key))
}
}
c.UI.Output("")
c.UI.Output(fmt.Sprintf("Initial Root Token: %s", resp.RootToken))
if req.StoredShares < 1 {
c.UI.Output("")
c.UI.Output(wrapAtLength(fmt.Sprintf(
"Vault initialized with %d key shares an a key threshold of %d. Please "+
"securely distributed the key shares printed above. When the Vault is "+
"re-sealed, restarted, or stopped, you must supply at least %d of "+
"these keys to unseal it before it can start servicing requests.",
req.SecretShares,
req.SecretThreshold,
req.SecretThreshold)))
c.UI.Output("")
c.UI.Output(wrapAtLength(fmt.Sprintf(
"Vault does not store the generated master key. Without at least %d "+
"key to reconstruct the master key, Vault will remain permanently "+
"sealed!",
req.SecretThreshold)))
c.UI.Output("")
c.UI.Output(wrapAtLength(
"It is possible to generate new unseal keys, provided you have a quorum " +
"of existing unseal keys shares. See \"vault rekey\" for more " +
"information."))
} else {
c.UI.Output("")
c.UI.Output("Success! Vault is initialized")
}
if len(resp.RecoveryKeys) > 0 {
c.UI.Output("")
c.UI.Output(wrapAtLength(fmt.Sprintf(
"Recovery key initialized with %d key shares and a key threshold of %d. "+
"Please securely distribute the key shares printed above.",
req.RecoveryShares,
req.RecoveryThreshold)))
}
return 0
}
// initOutputYAML outputs the init output as YAML.
func (c *InitCommand) initOutputYAML(req *api.InitRequest, resp *api.InitResponse) int {
b, err := yaml.Marshal(newMachineInit(req, resp))
if err != nil {
c.UI.Error(fmt.Sprintf("Error marshaling YAML: %s", err))
return 2
}
return PrintRaw(c.UI, strings.TrimSpace(string(b)))
}
// initOutputJSON outputs the init output as JSON.
func (c *InitCommand) initOutputJSON(req *api.InitRequest, resp *api.InitResponse) int {
b, err := json.Marshal(newMachineInit(req, resp))
if err != nil {
c.UI.Error(fmt.Sprintf("Error marshaling JSON: %s", err))
return 2
}
return PrintRaw(c.UI, strings.TrimSpace(string(b)))
}
// status inspects the init status of vault and returns an appropriate error
// code and message.
func (c *InitCommand) status(client *api.Client) int {
inited, err := client.Sys().InitStatus()
if err != nil {
c.UI.Error(fmt.Sprintf("Error checking init status: %s", err))
return 1 // Normally we'd return 2, but 2 means something special here
}
if inited {
c.UI.Output("Vault is initialized")
return 0
}
c.UI.Output("Vault is not initialized")
return 2
}
// machineInit is used to output information about the init command.
type machineInit struct {
UnsealKeysB64 []string `json:"unseal_keys_b64"`
UnsealKeysHex []string `json:"unseal_keys_hex"`
UnsealShares int `json:"unseal_shares"`
UnsealThreshold int `json:"unseal_threshold"`
RecoveryKeysB64 []string `json:"recovery_keys_b64"`
RecoveryKeysHex []string `json:"recovery_keys_hex"`
RecoveryShares int `json:"recovery_keys_shares"`
RecoveryThreshold int `json:"recovery_keys_threshold"`
RootToken string `json:"root_token"`
}
func newMachineInit(req *api.InitRequest, resp *api.InitResponse) *machineInit {
init := &machineInit{}
init.UnsealKeysHex = make([]string, len(resp.Keys))
for i, v := range resp.Keys {
init.UnsealKeysHex[i] = v
}
init.UnsealKeysB64 = make([]string, len(resp.KeysB64))
for i, v := range resp.KeysB64 {
init.UnsealKeysB64[i] = v
}
init.UnsealShares = req.SecretShares
init.UnsealThreshold = req.SecretThreshold
init.RecoveryKeysHex = make([]string, len(resp.RecoveryKeys))
for i, v := range resp.RecoveryKeys {
init.RecoveryKeysHex[i] = v
}
init.RecoveryKeysB64 = make([]string, len(resp.RecoveryKeysB64))
for i, v := range resp.RecoveryKeysB64 {
init.RecoveryKeysB64[i] = v
}
init.RecoveryShares = req.RecoveryShares
init.RecoveryThreshold = req.RecoveryThreshold
init.RootToken = resp.RootToken
return init
}

View File

@ -1,343 +1,361 @@
package command
import (
"bytes"
"encoding/base64"
"fmt"
"os"
"reflect"
"regexp"
"strconv"
"strings"
"testing"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/helper/pgpkeys"
"github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/meta"
"github.com/hashicorp/vault/vault"
"github.com/keybase/go-crypto/openpgp"
"github.com/keybase/go-crypto/openpgp/packet"
"github.com/mitchellh/cli"
)
func TestInit(t *testing.T) {
ui := new(cli.MockUi)
c := &InitCommand{
Meta: meta.Meta{
Ui: ui,
func testInitCommand(tb testing.TB) (*cli.MockUi, *InitCommand) {
tb.Helper()
ui := cli.NewMockUi()
return ui, &InitCommand{
BaseCommand: &BaseCommand{
UI: ui,
},
}
core := vault.TestCore(t)
ln, addr := http.TestServer(t, core)
defer ln.Close()
init, err := core.Initialized()
if err != nil {
t.Fatalf("err: %s", err)
}
if init {
t.Fatal("should not be initialized")
}
args := []string{"-address", addr}
if code := c.Run(args); code != 0 {
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
}
init, err = core.Initialized()
if err != nil {
t.Fatalf("err: %s", err)
}
if !init {
t.Fatal("should be initialized")
}
sealConf, err := core.SealAccess().BarrierConfig()
if err != nil {
t.Fatalf("err: %s", err)
}
expected := &vault.SealConfig{
Type: "shamir",
SecretShares: 5,
SecretThreshold: 3,
}
if !reflect.DeepEqual(expected, sealConf) {
t.Fatalf("expected:\n%#v\ngot:\n%#v\n", expected, sealConf)
}
}
func TestInit_Check(t *testing.T) {
ui := new(cli.MockUi)
c := &InitCommand{
Meta: meta.Meta{
Ui: ui,
func TestInitCommand_Run(t *testing.T) {
t.Parallel()
cases := []struct {
name string
args []string
out string
code int
}{
{
"pgp_keys_multi",
[]string{
"-pgp-keys", "keybase:hashicorp",
"-pgp-keys", "keybase:jefferai",
},
"can only be specified once",
1,
},
{
"root_token_pgp_key_multi",
[]string{
"-root-token-pgp-key", "keybase:hashicorp",
"-root-token-pgp-key", "keybase:jefferai",
},
"can only be specified once",
1,
},
{
"root_token_pgp_key_multi_inline",
[]string{
"-root-token-pgp-key", "keybase:hashicorp,keybase:jefferai",
},
"can only specify one pgp key",
1,
},
{
"recovery_pgp_keys_multi",
[]string{
"-recovery-pgp-keys", "keybase:hashicorp",
"-recovery-pgp-keys", "keybase:jefferai",
},
"can only be specified once",
1,
},
{
"key_shares_pgp_less",
[]string{
"-key-shares", "10",
"-pgp-keys", "keybase:jefferai,keybase:sethvargo",
},
"incorrect number",
2,
},
{
"key_shares_pgp_more",
[]string{
"-key-shares", "1",
"-pgp-keys", "keybase:jefferai,keybase:sethvargo",
},
"incorrect number",
2,
},
}
core := vault.TestCore(t)
ln, addr := http.TestServer(t, core)
defer ln.Close()
t.Run("validations", func(t *testing.T) {
t.Parallel()
// Should return 2, not initialized
args := []string{"-address", addr, "-check"}
if code := c.Run(args); code != 2 {
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
}
for _, tc := range cases {
tc := tc
// Now initialize it
args = []string{"-address", addr}
if code := c.Run(args); code != 0 {
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
}
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
// Should return 0, initialized
args = []string{"-address", addr, "-check"}
if code := c.Run(args); code != 0 {
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
}
client, closer := testVaultServer(t)
defer closer()
init, err := core.Initialized()
if err != nil {
t.Fatalf("err: %s", err)
}
if !init {
t.Fatal("should be initialized")
}
}
ui, cmd := testInitCommand(t)
cmd.client = client
func TestInit_custom(t *testing.T) {
ui := new(cli.MockUi)
c := &InitCommand{
Meta: meta.Meta{
Ui: ui,
},
}
code := cmd.Run(tc.args)
if code != tc.code {
t.Errorf("expected %d to be %d", code, tc.code)
}
core := vault.TestCore(t)
ln, addr := http.TestServer(t, core)
defer ln.Close()
init, err := core.Initialized()
if err != nil {
t.Fatalf("err: %s", err)
}
if init {
t.Fatal("should not be initialized")
}
args := []string{
"-address", addr,
"-key-shares", "7",
"-key-threshold", "3",
}
if code := c.Run(args); code != 0 {
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
}
init, err = core.Initialized()
if err != nil {
t.Fatalf("err: %s", err)
}
if !init {
t.Fatal("should be initialized")
}
sealConf, err := core.SealAccess().BarrierConfig()
if err != nil {
t.Fatalf("err: %s", err)
}
expected := &vault.SealConfig{
Type: "shamir",
SecretShares: 7,
SecretThreshold: 3,
}
if !reflect.DeepEqual(expected, sealConf) {
t.Fatalf("expected:\n%#v\ngot:\n%#v\n", expected, sealConf)
}
re, err := regexp.Compile("\\s+Initial Root Token:\\s+(.*)")
if err != nil {
t.Fatalf("Error compiling regex: %s", err)
}
matches := re.FindAllStringSubmatch(ui.OutputWriter.String(), -1)
if len(matches) != 1 {
t.Fatalf("Unexpected number of tokens found, got %d", len(matches))
}
rootToken := matches[0][1]
client, err := c.Client()
if err != nil {
t.Fatalf("Error fetching client: %v", err)
}
client.SetToken(rootToken)
re, err = regexp.Compile("\\s*Unseal Key \\d+: (.*)")
if err != nil {
t.Fatalf("Error compiling regex: %s", err)
}
matches = re.FindAllStringSubmatch(ui.OutputWriter.String(), -1)
if len(matches) != 7 {
t.Fatalf("Unexpected number of keys returned, got %d, matches was \n\n%#v\n\n, input was \n\n%s\n\n", len(matches), matches, ui.OutputWriter.String())
}
var unsealed bool
for i := 0; i < 3; i++ {
decodedKey, err := base64.StdEncoding.DecodeString(strings.TrimSpace(matches[i][1]))
if err != nil {
t.Fatalf("err decoding key %v: %v", matches[i][1], err)
combined := ui.OutputWriter.String() + ui.ErrorWriter.String()
if !strings.Contains(combined, tc.out) {
t.Errorf("expected %q to contain %q", combined, tc.out)
}
})
}
unsealed, err = core.Unseal(decodedKey)
if err != nil {
t.Fatalf("err during unseal: %v; key was %v", err, matches[i][1])
})
t.Run("status", func(t *testing.T) {
t.Parallel()
client, closer := testVaultServerUninit(t)
defer closer()
ui, cmd := testInitCommand(t)
cmd.client = client
// Verify the non-init response code
code := cmd.Run([]string{
"-status",
})
if exp := 2; code != exp {
t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String())
}
}
if !unsealed {
t.Fatal("expected to be unsealed")
}
tokenInfo, err := client.Auth().Token().LookupSelf()
if err != nil {
t.Fatalf("Error looking up root token info: %v", err)
}
if tokenInfo.Data["policies"].([]interface{})[0].(string) != "root" {
t.Fatalf("expected root policy")
}
}
func TestInit_PGP(t *testing.T) {
ui := new(cli.MockUi)
c := &InitCommand{
Meta: meta.Meta{
Ui: ui,
},
}
core := vault.TestCore(t)
ln, addr := http.TestServer(t, core)
defer ln.Close()
init, err := core.Initialized()
if err != nil {
t.Fatalf("err: %s", err)
}
if init {
t.Fatal("should not be initialized")
}
tempDir, pubFiles, err := getPubKeyFiles(t)
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tempDir)
args := []string{
"-address", addr,
"-key-shares", "2",
"-pgp-keys", pubFiles[0] + ",@" + pubFiles[1] + "," + pubFiles[2],
"-key-threshold", "2",
"-root-token-pgp-key", pubFiles[0],
}
// This should fail, as key-shares does not match pgp-keys size
if code := c.Run(args); code == 0 {
t.Fatalf("bad (command should have failed): %d\n\n%s", code, ui.ErrorWriter.String())
}
args = []string{
"-address", addr,
"-key-shares", "4",
"-pgp-keys", pubFiles[0] + ",@" + pubFiles[1] + "," + pubFiles[2] + "," + pubFiles[3],
"-key-threshold", "2",
"-root-token-pgp-key", pubFiles[0],
}
ui.OutputWriter.Reset()
if code := c.Run(args); code != 0 {
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
}
init, err = core.Initialized()
if err != nil {
t.Fatalf("err: %s", err)
}
if !init {
t.Fatal("should be initialized")
}
sealConf, err := core.SealAccess().BarrierConfig()
if err != nil {
t.Fatalf("err: %s", err)
}
pgpKeys := []string{}
for _, pubFile := range pubFiles {
pub, err := pgpkeys.ReadPGPFile(pubFile)
if err != nil {
t.Fatalf("bad: %v", err)
// Now init to verify the init response code
if _, err := client.Sys().Init(&api.InitRequest{
SecretShares: 1,
SecretThreshold: 1,
}); err != nil {
t.Fatal(err)
}
pgpKeys = append(pgpKeys, pub)
}
expected := &vault.SealConfig{
Type: "shamir",
SecretShares: 4,
SecretThreshold: 2,
PGPKeys: pgpKeys,
}
if !reflect.DeepEqual(expected, sealConf) {
t.Fatalf("expected:\n%#v\ngot:\n%#v\n", expected, sealConf)
}
// Verify the init response code
ui, cmd = testInitCommand(t)
cmd.client = client
code = cmd.Run([]string{
"-status",
})
if exp := 0; code != exp {
t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String())
}
})
re, err := regexp.Compile("\\s+Initial Root Token:\\s+(.*)")
if err != nil {
t.Fatalf("Error compiling regex: %s", err)
}
matches := re.FindAllStringSubmatch(ui.OutputWriter.String(), -1)
if len(matches) != 1 {
t.Fatalf("Unexpected number of tokens found, got %d", len(matches))
}
t.Run("default", func(t *testing.T) {
t.Parallel()
encRootToken := matches[0][1]
privKeyBytes, err := base64.StdEncoding.DecodeString(pgpkeys.TestPrivKey1)
if err != nil {
t.Fatalf("error decoding private key: %v", err)
}
ptBuf := bytes.NewBuffer(nil)
entity, err := openpgp.ReadEntity(packet.NewReader(bytes.NewBuffer(privKeyBytes)))
if err != nil {
t.Fatalf("Error parsing private key: %s", err)
}
var rootBytes []byte
rootBytes, err = base64.StdEncoding.DecodeString(encRootToken)
if err != nil {
t.Fatalf("Error decoding root token: %s", err)
}
entityList := &openpgp.EntityList{entity}
md, err := openpgp.ReadMessage(bytes.NewBuffer(rootBytes), entityList, nil, nil)
if err != nil {
t.Fatalf("Error decrypting root token: %s", err)
}
ptBuf.ReadFrom(md.UnverifiedBody)
rootToken := ptBuf.String()
client, closer := testVaultServerUninit(t)
defer closer()
parseDecryptAndTestUnsealKeys(t, ui.OutputWriter.String(), rootToken, false, nil, nil, core)
ui, cmd := testInitCommand(t)
cmd.client = client
client, err := c.Client()
if err != nil {
t.Fatalf("Error fetching client: %v", err)
}
code := cmd.Run([]string{})
if exp := 0; code != exp {
t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String())
}
client.SetToken(rootToken)
init, err := client.Sys().InitStatus()
if err != nil {
t.Fatal(err)
}
if !init {
t.Error("expected initialized")
}
tokenInfo, err := client.Auth().Token().LookupSelf()
if err != nil {
t.Fatalf("Error looking up root token info: %v", err)
}
re := regexp.MustCompile(`Unseal Key \d+: (.+)`)
output := ui.OutputWriter.String()
match := re.FindAllStringSubmatch(output, -1)
if len(match) < 5 || len(match[0]) < 2 {
t.Fatalf("no match: %#v", match)
}
if tokenInfo.Data["policies"].([]interface{})[0].(string) != "root" {
t.Fatalf("expected root policy")
}
keys := make([]string, len(match))
for i := range match {
keys[i] = match[i][1]
}
// Try unsealing with those keys - only use 3, which is the default
// threshold.
for i, key := range keys[:3] {
resp, err := client.Sys().Unseal(key)
if err != nil {
t.Fatal(err)
}
exp := (i + 1) % 3 // 1, 2, 0
if resp.Progress != exp {
t.Errorf("expected %d to be %d", resp.Progress, exp)
}
}
status, err := client.Sys().SealStatus()
if err != nil {
t.Fatal(err)
}
if status.Sealed {
t.Errorf("expected vault to be unsealed: %#v", status)
}
})
t.Run("custom_shares_threshold", func(t *testing.T) {
t.Parallel()
keyShares, keyThreshold := 20, 15
client, closer := testVaultServerUninit(t)
defer closer()
ui, cmd := testInitCommand(t)
cmd.client = client
code := cmd.Run([]string{
"-key-shares", strconv.Itoa(keyShares),
"-key-threshold", strconv.Itoa(keyThreshold),
})
if exp := 0; code != exp {
t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String())
}
init, err := client.Sys().InitStatus()
if err != nil {
t.Fatal(err)
}
if !init {
t.Error("expected initialized")
}
re := regexp.MustCompile(`Unseal Key \d+: (.+)`)
output := ui.OutputWriter.String()
match := re.FindAllStringSubmatch(output, -1)
if len(match) < keyShares || len(match[0]) < 2 {
t.Fatalf("no match: %#v", match)
}
keys := make([]string, len(match))
for i := range match {
keys[i] = match[i][1]
}
// Try unsealing with those keys - only use 3, which is the default
// threshold.
for i, key := range keys[:keyThreshold] {
resp, err := client.Sys().Unseal(key)
if err != nil {
t.Fatal(err)
}
exp := (i + 1) % keyThreshold
if resp.Progress != exp {
t.Errorf("expected %d to be %d", resp.Progress, exp)
}
}
status, err := client.Sys().SealStatus()
if err != nil {
t.Fatal(err)
}
if status.Sealed {
t.Errorf("expected vault to be unsealed: %#v", status)
}
})
t.Run("pgp", func(t *testing.T) {
t.Parallel()
tempDir, pubFiles, err := getPubKeyFiles(t)
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tempDir)
client, closer := testVaultServerUninit(t)
defer closer()
ui, cmd := testInitCommand(t)
cmd.client = client
code := cmd.Run([]string{
"-key-shares", "4",
"-key-threshold", "2",
"-pgp-keys", fmt.Sprintf("%s,@%s, %s, %s ",
pubFiles[0], pubFiles[1], pubFiles[2], pubFiles[3]),
"-root-token-pgp-key", pubFiles[0],
})
if exp := 0; code != exp {
t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String())
}
re := regexp.MustCompile(`Unseal Key \d+: (.+)`)
output := ui.OutputWriter.String()
match := re.FindAllStringSubmatch(output, -1)
if len(match) < 4 || len(match[0]) < 2 {
t.Fatalf("no match: %#v", match)
}
keys := make([]string, len(match))
for i := range match {
keys[i] = match[i][1]
}
// Try unsealing with one key
decryptedKey := testPGPDecrypt(t, pgpkeys.TestPrivKey1, keys[0])
if _, err := client.Sys().Unseal(decryptedKey); err != nil {
t.Fatal(err)
}
// Decrypt the root token
reToken := regexp.MustCompile(`Root Token: (.+)`)
match = reToken.FindAllStringSubmatch(output, -1)
if len(match) < 1 || len(match[0]) < 2 {
t.Fatalf("no match")
}
root := match[0][1]
decryptedRoot := testPGPDecrypt(t, pgpkeys.TestPrivKey1, root)
if l, exp := len(decryptedRoot), 36; l != exp {
t.Errorf("expected %d to be %d", l, exp)
}
})
t.Run("communication_failure", func(t *testing.T) {
t.Parallel()
client, closer := testVaultServerBad(t)
defer closer()
ui, cmd := testInitCommand(t)
cmd.client = client
code := cmd.Run([]string{
"secret/foo",
})
if exp := 2; code != exp {
t.Errorf("expected %d to be %d", code, exp)
}
expected := "Error initializing: "
combined := ui.OutputWriter.String() + ui.ErrorWriter.String()
if !strings.Contains(combined, expected) {
t.Errorf("expected %q to contain %q", combined, expected)
}
})
t.Run("no_tabs", func(t *testing.T) {
t.Parallel()
_, cmd := testInitCommand(t)
assertNoTabs(t, cmd)
})
}